Param
- class brainstate.nn.Param(value, t=IdentityT(), reg=None, precompute=None, fit=True, enable_cache_logging=False)
A module that holds a neural network parameter with an optional transformation and regularization.
A flexible parameter container that supports:
Bijective transformations for constrained optimization
Regularization (L1, L2, Gaussian, etc.)
Trainable or fixed parameter modes
Automatic caching of transformed values for performance
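To make the constrained/unconstrained split concrete, here is a minimal pure-Python sketch of a softplus bijection, the kind of mapping a positivity transform such as SoftplusT presumably applies. The helpers softplus and softplus_inverse are hypothetical and written only for illustration; they are not brainstate's SoftplusT, whose exact formula (including the meaning of its constructor argument) is not stated on this page.

```python
import math


def softplus(x: float) -> float:
    """Forward map (hypothetical helper): unconstrained real -> positive value."""
    # log(1 + exp(x)), rewritten as max(x, 0) + log1p(exp(-|x|)) for stability
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y: float) -> float:
    """Inverse map (hypothetical helper): positive value -> unconstrained real."""
    if y <= 0.0:
        raise ValueError("softplus_inverse requires y > 0")
    # log(exp(y) - 1) == y + log(1 - exp(-y)), stable for large y
    return y + math.log(-math.expm1(-y))


# The optimizer works on the unconstrained value ...
raw = softplus_inverse(2.0)
raw -= 10.0  # ... so even a large gradient step cannot make the result negative
print(softplus(raw) > 0.0)
```

Because the optimizer only ever sees raw, no projection or clipping step is needed to keep the constrained value positive.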
- Parameters:
value (Array | ndarray | bool | number | int | float | complex | Quantity) – Initial parameter value in the constrained space.
t (Transform) – Bijective transformation to apply. Default is IdentityT().
reg (Regularization | None) – Regularization to apply. Default is None.
precompute (Callable | None) – Optional precompute function applied after transformation. Default is None.
fit (bool) – Whether the parameter is trainable. Default is True.
enable_cache_logging (bool) – Whether to enable INFO-level logging for cache events. Default is False. When enabled, cache hits, misses, invalidations, and errors are logged for debugging.
- reg
The regularization, if any.
- Type:
Regularization or None
- precompute
Optional precompute function applied after transformation.
- Type:
Callable or None
- val
The internal parameter storage.
- Type:
array_like or ParamState
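The reg attribute contributes a penalty to the training loss. The sketch below shows common textbook forms of L1, L2 and Gaussian penalties in plain Python. This is an assumption about what such regularizers compute: the function names are hypothetical, and the exact weight scaling, as well as whether brainstate applies the penalty to the constrained or the unconstrained value, is not specified on this page.

```python
from typing import Sequence


def l1_penalty(values: Sequence[float], weight: float) -> float:
    """Hypothetical L1 (lasso-style) penalty: weight * sum(|v|)."""
    return weight * sum(abs(v) for v in values)


def l2_penalty(values: Sequence[float], weight: float) -> float:
    """Hypothetical L2 (ridge-style) penalty: weight * sum(v ** 2)."""
    return weight * sum(v * v for v in values)


def gaussian_penalty(values: Sequence[float], mean: float, std: float) -> float:
    """Negative log-density of an isotropic Gaussian prior, constants dropped."""
    return sum((v - mean) ** 2 for v in values) / (2.0 * std ** 2)


params = [1.0, 2.0]
print(l2_penalty(params, weight=0.01))  # 0.01 * (1 + 4)
print(gaussian_penalty(params, mean=0.0, std=1.0))  # (1 + 4) / 2
```

A Gaussian prior with mean 0 reduces to an L2 penalty with weight 1 / (2 * std ** 2), which is why the two are often used interchangeably.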
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT, L2Reg
>>> # Trainable positive parameter with L2 regularization
>>> param = Param(
...     jnp.array([1.0, 2.0]),
...     t=SoftplusT(0.0),
...     reg=L2Reg(weight=0.01)
... )
>>> param.value()  # Get constrained value
>>> param.reg_loss()  # Get regularization loss
>>> # Caching is automatic for all parameters
>>> param = Param(
...     jnp.array([1.0, 2.0]),
...     t=SoftplusT()
... )
>>> val1 = param.value()  # Computes and caches
>>> val2 = param.value()  # Returns cached value (fast)
>>> param.set_value(jnp.array([3.0, 4.0]))  # Invalidates cache
>>> val3 = param.value()  # Recomputes and caches
Notes
The internal value is stored in the unconstrained space when a transform is provided. The value() method returns the constrained value after applying the forward transformation.
Caching behavior: the transformed value is cached on first access and automatically invalidated when the parameter is updated (via set_value() or direct state writes). Use clear_cache() for manual invalidation. The caching mechanism is made thread-safe with an RLock (reentrant lock).
- cache()
Manually cache the transformed value.
This method forces immediate computation and caching of the transformed value, even if the cache is already valid. Useful for warming up the cache before performance-critical sections.
Note
The cache is automatically populated on first access to value(). This method is only needed for explicit cache warming.
Example
>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0, 2.0]), t=SoftplusT())
>>> param.cache()  # Warm up cache before performance-critical code
>>> val = param.value()  # Fast: returns cached value
- property cache_stats: dict
Get cache statistics (for debugging/monitoring).
- Returns:
Dictionary with keys: valid, has_cached_value
- Return type:
dict
Example
>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0]), t=SoftplusT())
>>> param.cache_stats
{'valid': False, 'has_cached_value': False}
>>> _ = param.value()  # Compute and cache
>>> param.cache_stats
{'valid': True, 'has_cached_value': True}
- clear_cache()
Explicitly clear the parameter transformation cache.
This method invalidates any cached transformed value, forcing the next call to value() to recompute the transformation. It is thread-safe.
Note
The cache is automatically invalidated when the parameter is updated. This method is primarily useful for manual cache management or debugging.
Example
>>> import jax.numpy as jnp
>>> from brainstate.nn import Param, SoftplusT
>>> param = Param(jnp.array([1.0, 2.0]), t=SoftplusT())
>>> _ = param.value()  # Computes and caches
>>> param.clear_cache()  # Manual invalidation
>>> _ = param.value()  # Recomputes
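The caching contract described for value(), cache(), cache_stats, clear_cache() and set_value() can be sketched as a small self-contained class. ToyParam is hypothetical and is not brainstate's Param: it stores the value in unconstrained space, evaluates the forward transform lazily, invalidates on set_value(), and guards its state with Python's threading.RLock. Passing the transform pair as plain functions (here exp/log) is an assumption made to keep the example dependency-free.

```python
import math
import threading
from typing import Callable, Optional


class ToyParam:
    """Hypothetical parameter with a lazily cached forward transform."""

    def __init__(self, value: float,
                 forward: Callable[[float], float],
                 inverse: Callable[[float], float]) -> None:
        self._forward = forward
        self._inverse = inverse
        self._lock = threading.RLock()
        self._raw = inverse(value)            # stored in unconstrained space
        self._cached: Optional[float] = None
        self._valid = False
        self.n_computes = 0                   # counts forward evaluations

    def value(self) -> float:
        with self._lock:
            if not self._valid:               # cache miss: compute and store
                self._cached = self._forward(self._raw)
                self._valid = True
                self.n_computes += 1
            return self._cached

    def cache(self) -> None:
        with self._lock:                      # recompute even if already valid
            self._valid = False
            self.value()                      # nested acquire: needs a reentrant lock

    def clear_cache(self) -> None:
        with self._lock:
            self._valid = False
            self._cached = None

    def set_value(self, value: float) -> None:
        with self._lock:
            self._raw = self._inverse(value)  # constrained -> unconstrained
            self.clear_cache()                # invalidate on every update

    @property
    def cache_stats(self) -> dict:
        with self._lock:
            return {"valid": self._valid,
                    "has_cached_value": self._cached is not None}


p = ToyParam(2.0, forward=math.exp, inverse=math.log)
print(p.cache_stats)   # {'valid': False, 'has_cached_value': False}
p.value()
p.value()
print(p.n_computes)    # 1: the second call hit the cache
p.set_value(3.0)
print(p.cache_stats)   # {'valid': False, 'has_cached_value': False}
```

An RLock rather than a plain Lock matters here because cache() and set_value() call other locked methods while already holding the lock.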
- classmethod init(data, sizes=None, allow_none=True, **param_kwargs)
Initialize parameters.
- Parameters:
data (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | Param) – The initialization of the parameter.
If it is None, the created parameter will be None.
If it is a callable function f, f(sizes) is returned.
If it is an instance of init.Initializer, the initializer called with sizes is returned.
If it is a tensor, this function checks whether tensor.shape is equal to the given sizes.
sizes – The size passed to a callable or initializer, and against which a tensor's shape is checked. Default is None.
allow_none (bool) – Whether the parameter is allowed to be None. Default is True.
**param_kwargs – Additional keyword arguments passed to the initialization.
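The dispatch rules for data can be illustrated with a plain-Python sketch. init_param and FakeTensor are hypothetical and are not brainstate's Param.init: the sketch omits the Param input case and the final wrapping into a Param, treats an initializer object like any other callable, and treats anything with a shape attribute as a tensor. Raising ValueError when data is None and allow_none=False is an assumed behavior.

```python
from typing import Any, Optional, Tuple


class FakeTensor:
    """Hypothetical stand-in for an array: it only carries a shape."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape


def init_param(data: Any,
               sizes: Optional[Tuple[int, ...]] = None,
               allow_none: bool = True) -> Any:
    """Hypothetical sketch of the None / callable / tensor dispatch."""
    if data is None:
        if not allow_none:                      # assumed behavior
            raise ValueError("data is None but allow_none=False")
        return None
    if callable(data):                          # functions and initializer objects
        return data(sizes)
    shape = getattr(data, "shape", None)        # tensors: validate the shape
    if sizes is not None and shape is not None and tuple(shape) != tuple(sizes):
        raise ValueError(f"shape {tuple(shape)} does not match sizes {tuple(sizes)}")
    return data


print(init_param(None))                                    # None
print(init_param(lambda s: [0.0] * s[0], sizes=(3,)))      # [0.0, 0.0, 0.0]
print(init_param(FakeTensor((2, 3)), sizes=(2, 3)).shape)  # (2, 3)
```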
- reset_to_prior()
Reset the parameter value to the prior value of its regularization.
This has an effect only if a regularization is defined.
- set_value(value)
Set the parameter value from a value given in the constrained space.
The value is mapped to the unconstrained space (via the inverse transformation) for internal storage, and the cache is invalidated automatically.