Param

class brainstate.nn.Param(value, t=IdentityT(), reg=None, precompute=None, fit=True, enable_cache_logging=False)

A module holding a neural network parameter with an optional transform and regularization.

A flexible parameter container that supports:

  • Bijective transformations for constrained optimization

  • Regularization (L1, L2, Gaussian, etc.)

  • Trainable or fixed parameter modes

  • Automatic caching of transformed values for performance

Parameters:
  • value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Initial parameter value in the constrained space.

  • t (Transform) – Bijective transformation to apply. Default is IdentityT().

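  • reg (Regularization | None) – Regularization to apply. Default is None.

  • precompute (Callable | None) – Optional function applied to the value after the transformation. Default is None.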

  • fit (bool) – Whether the parameter is trainable. Default is True.

  • enable_cache_logging (bool) – Whether to enable INFO-level logging for cache events. Default is False. Logs cache hits, misses, invalidations, and errors for debugging.

fit

Whether the parameter is trainable.

Type:

bool

t

The bijective transformation.

Type:

Transform

reg

The regularization, if any.

Type:

Regularization or None

precompute

Optional precompute function applied after transformation.

Type:

Callable or None

val

The internal parameter storage.

Type:

array_like or ParamState

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT, L2Reg
>>> # Trainable positive parameter with L2 regularization
>>> param = Param(
...     jnp.array([1.0, 2.0]),
...     t=SoftplusT(0.0),
...     reg=L2Reg(weight=0.01)
... )
>>> constrained = param.value()  # Get the constrained value
>>> loss = param.reg_loss()  # Get the regularization loss
>>> # Caching is automatic for all parameters
>>> param = Param(
...     jnp.array([1.0, 2.0]),
...     t=SoftplusT()
... )
>>> val1 = param.value()  # Computes and caches
>>> val2 = param.value()  # Returns cached value (fast)
>>> param.set_value(jnp.array([3.0, 4.0]))  # Invalidates cache
>>> val3 = param.value()  # Recomputes and caches

Notes

The internal value is stored in the unconstrained space when a transform is provided. The value() method returns the constrained value after applying the forward transformation.
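To make the constrained/unconstrained split concrete, here is a minimal sketch in plain Python. It assumes a softplus-style bijection with a lower bound; `SoftplusSketch` and `ParamSketch` are hypothetical stand-ins written for illustration, not brainstate's `SoftplusT` or `Param`, and brainstate's actual softplus parameterization may differ.

```python
import math


class SoftplusSketch:
    """Hypothetical softplus bijection (NOT brainstate's SoftplusT).

    forward maps the real line to (lower, inf); inverse maps back.
    """

    def __init__(self, lower: float = 0.0) -> None:
        self.lower = lower

    def forward(self, u: float) -> float:
        # Numerically stable softplus: log(1 + e^u) = max(u, 0) + log1p(e^-|u|)
        return self.lower + max(u, 0.0) + math.log1p(math.exp(-abs(u)))

    def inverse(self, y: float) -> float:
        # Inverse softplus: u = log(e^z - 1), where z = y - lower must be > 0
        z = y - self.lower
        if z <= 0:
            raise ValueError("value must lie strictly above the lower bound")
        return math.log(math.expm1(z))


class ParamSketch:
    """Hypothetical container: stores the unconstrained value, returns the constrained one."""

    def __init__(self, value: float, t: SoftplusSketch) -> None:
        self.t = t
        self.val = t.inverse(value)  # internal storage lives in unconstrained space

    def value(self) -> float:
        return self.t.forward(self.val)  # constrained value via the forward transform


p = ParamSketch(2.0, SoftplusSketch())
# p.val is log(e^2 - 1) ~ 1.855; p.value() recovers ~ 2.0
```

Because an optimizer only ever updates the unconstrained `val`, any real-valued step still yields a constrained value above the lower bound.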

Caching behavior: The transformed value is cached on first access and invalidated automatically whenever the parameter is updated (via set_value() or direct state writes). Use clear_cache() for manual invalidation. Cache access is guarded by a reentrant lock (threading.RLock), which makes it thread-safe.
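The caching contract described above (lazy computation, invalidation on writes, manual clearing, an RLock, optional INFO-level logging, and a stats dictionary with the keys valid and has_cached_value) can be sketched as follows. This is a hypothetical illustration, not brainstate's implementation: `CachedParamSketch`, its `forward`/`inverse` callables, and the `forward_calls` counter are assumptions made for the example.

```python
import logging
import math
import threading

# INFO messages only appear if logging is configured, e.g. logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("param_cache_sketch")


class CachedParamSketch:
    """Hypothetical lazily cached, thread-safe transformed value (not brainstate's API)."""

    def __init__(self, value, forward, inverse, enable_cache_logging=False):
        self._forward = forward              # unconstrained -> constrained
        self._inverse = inverse              # constrained -> unconstrained
        self._lock = threading.RLock()       # reentrant: value() calls cache() while holding it
        self._log = enable_cache_logging
        self._cached = None
        self._valid = False
        self.forward_calls = 0               # instrumentation for this example only
        self.val = inverse(value)            # stored in unconstrained space

    def _info(self, msg):
        if self._log:
            logger.info(msg)

    def cache(self):
        # Force recomputation even if the cache is already valid.
        with self._lock:
            self._cached = self._forward(self.val)
            self.forward_calls += 1
            self._valid = True
            return self._cached

    def value(self):
        with self._lock:
            if self._valid:
                self._info("cache hit")
                return self._cached
            self._info("cache miss")
            return self.cache()

    def clear_cache(self):
        with self._lock:
            self._cached = None
            self._valid = False
            self._info("cache invalidated")

    def set_value(self, value):
        with self._lock:
            self.val = self._inverse(value)
            self.clear_cache()  # every write invalidates the cached transform

    @property
    def cache_stats(self):
        with self._lock:
            return {"valid": self._valid, "has_cached_value": self._cached is not None}


p = CachedParamSketch(4.0, forward=math.exp, inverse=math.log)
first = p.value()   # miss: computes exp(log(4.0)) and caches it
second = p.value()  # hit: returns the cached value without recomputing
p.set_value(9.0)    # write: invalidates the cache
```

A reentrant lock is the natural choice here because the public methods call each other (value() calls cache(), set_value() calls clear_cache()) while already holding the lock; a plain Lock would deadlock.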

cache()

Manually cache the transformed value.

This method forces immediate computation and caching of the transformed value, even if the cache is already valid. Useful for warming up the cache before performance-critical sections.

Note

The cache is automatically populated on first access to value(). This method is only needed for explicit cache warming.

Example

>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0, 2.0]), t=SoftplusT())
>>> _ = param.cache()  # Warm up the cache before performance-critical code
>>> val = param.value()  # Fast: returns the cached value

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

property cache_stats: dict

Get cache statistics (for debugging/monitoring).

Returns:

Dictionary with keys: valid, has_cached_value

Return type:

dict

Example

>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0]), t=SoftplusT())
>>> param.cache_stats
{'valid': False, 'has_cached_value': False}
>>> _ = param.value()  # Compute and cache
>>> param.cache_stats
{'valid': True, 'has_cached_value': True}

clear_cache()

Explicitly clear the parameter transformation cache.

This method invalidates any cached transformed value, forcing the next call to value() to recompute the transformation. Thread-safe.

Note

The cache is invalidated automatically whenever the parameter is updated. This method is primarily useful for manual cache management or debugging.

Example

>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0, 2.0]), t=SoftplusT())
>>> _ = param.value()  # Computes and caches
>>> param.clear_cache()  # Manual invalidation
>>> _ = param.value()  # Recomputes

Return type:

None

clip(min_val=None, max_val=None)

Clamp the parameter value in place.

Parameters:
  • min_val (float | None) – Minimum value for clipping. Default is None (no lower bound).

  • max_val (float | None) – Maximum value for clipping. Default is None (no upper bound).
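The bound semantics (a None bound leaves that side open) can be illustrated with a small element-wise clamp. This is a hypothetical helper, `clip_values`, written in plain Python; the documentation above does not say whether Param.clip acts on the constrained value or on the stored unconstrained value, so the sketch only demonstrates how the two optional bounds combine. In JAX code one would typically reach for jax.numpy.clip instead.

```python
from typing import List, Optional, Sequence


def clip_values(values: Sequence[float],
                min_val: Optional[float] = None,
                max_val: Optional[float] = None) -> List[float]:
    """Hypothetical element-wise clamp; a None bound leaves that side unbounded."""
    if min_val is not None and max_val is not None and min_val > max_val:
        raise ValueError("min_val must not exceed max_val")
    out = []
    for v in values:
        if min_val is not None:
            v = max(v, min_val)  # apply the lower bound
        if max_val is not None:
            v = min(v, max_val)  # apply the upper bound
        out.append(v)
    return out


both = clip_values([-1.0, 0.5, 3.0], min_val=0.0, max_val=1.0)
lower_only = clip_values([-1.0, 3.0], min_val=0.0)
```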

classmethod init(data, sizes=None, allow_none=True, **param_kwargs)

Initialize parameters.

Parameters:
  • data (Callable | Array | ndarray | bool | number | bool | int | float | complex | Quantity | Param) –

    The initialization of the parameter.

    • If it is None, the created parameter is None.

    • If it is a callable f, f(sizes) is returned.

    • If it is an instance of init.Initializer, it is likewise called with sizes and the result is returned.

    • If it is a tensor, this function checks whether tensor.shape equals the given sizes.

  • sizes (int | Sequence[int]) – The shape of the parameter.

  • allow_none (bool) – Whether the parameter is allowed to be None.

  • **param_kwargs – Additional keyword arguments passed to the initialization.

Return type:

Param | Const
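The dispatch rules listed for data can be sketched as a standalone function. This is a hypothetical re-creation of the documented branches only: `init_sketch` and `_shape_of` are illustrative names, the Param input case and the final wrapping into Param or Const are not modeled, and passing the normalized shape tuple to the callable is an assumption.

```python
from typing import Any, Optional, Sequence, Tuple, Union


def _shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of an array-like: uses .shape if present, else walks nested lists."""
    if hasattr(x, "shape"):
        return tuple(x.shape)
    if isinstance(x, (list, tuple)):
        return (len(x),) + (_shape_of(x[0]) if x else ())
    return ()  # scalar


def init_sketch(data: Any,
                sizes: Optional[Union[int, Sequence[int]]] = None,
                allow_none: bool = True) -> Any:
    """Hypothetical mirror of the documented dispatch rules of Param.init."""
    if sizes is None:
        shape = None
    elif isinstance(sizes, int):
        shape = (sizes,)
    else:
        shape = tuple(sizes)

    if data is None:
        if not allow_none:
            raise ValueError("data is None but allow_none=False")
        return None
    if callable(data):
        # Plain functions and initializer objects both go through __call__.
        return data(shape)
    # Otherwise treat data as an array-like and validate its shape.
    if shape is not None and _shape_of(data) != shape:
        raise ValueError(f"expected shape {shape}, got {_shape_of(data)}")
    return data


zeros = init_sketch(lambda shape: [0.0] * shape[0], 3)
```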

reg_loss()

Calculate regularization loss.

Returns:

Regularization loss. Returns 0.0 for fixed parameters or parameters without regularization.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
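As a worked illustration of the reg_loss() contract, the sketch below uses a common form of L2 penalty, weight times the sum of squared entries. The exact formula behind brainstate's L2Reg (for example, whether it includes a factor of 1/2, or whether it penalizes the constrained or the stored value) is not stated here, so `l2_reg_loss_sketch` is a hypothetical helper; only the "0.0 for fixed or unregularized parameters" rule comes from the documentation above.

```python
from typing import Optional, Sequence


def l2_reg_loss_sketch(values: Sequence[float],
                       weight: Optional[float],
                       fit: bool = True) -> float:
    """Hypothetical L2 penalty: weight * sum(v ** 2).

    Returns 0.0 when the parameter is fixed (fit=False) or has no
    regularization (weight=None), matching the documented behavior.
    """
    if not fit or weight is None:
        return 0.0
    return weight * sum(v * v for v in values)


# With the values and weight from the class example: 0.01 * (1^2 + 2^2) = 0.05
loss = l2_reg_loss_sketch([1.0, 2.0], weight=0.01)
```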

reset_to_prior()

Reset the parameter value to the prior value of its regularization.

This has an effect only if a regularization is defined.

set_value(value)

Set the parameter value from the constrained space.

The value is mapped to the unconstrained space (through the inverse of t) for internal storage, and the cache is invalidated automatically.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – New value in the constrained space.

value()

Get the current parameter value after applying the transform.

Returns the cached value when the cache is valid. Otherwise, computes t.forward(val), caches the result, and returns it.

Returns:

Parameter value in the constrained space.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity