ChainedReg

class brainstate.nn.ChainedReg(*regularizations, weight=1.0, fit_hyper=False)

Composite regularization that chains multiple regularizations together.

Combines multiple regularization priors into a single composite regularization by summing their losses. This allows applying multiple constraints or priors simultaneously to parameters.

Parameters:
  • *regularizations (Regularization) – Variable number of regularization instances to chain together.

  • weight (float) – Overall regularization weight (lambda) that scales the combined loss. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.
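Written out, the combined loss described above has the form below. Here λ is weight, K is the number of chained regularizations, and L_k(θ) is whatever the k-th component's own loss() returns (presumably already scaled by that component's own weight). These symbols are introduced only for illustration:

```latex
\mathcal{L}_{\text{chain}}(\theta) = \lambda \sum_{k=1}^{K} \mathcal{L}_k(\theta),
\qquad \mathcal{L}_{\text{chain}}(\theta) = 0 \ \text{when } K = 0
```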

regularizations

The regularizations being combined.

Type:

tuple of Regularization

weight

Regularization weight (trainable if fit_hyper=True).

Type:

array_like or ParamState

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import ChainedReg, L1Reg, L2Reg, UniformReg
>>> # Combine L1 sparsity + L2 smoothness + bounded constraint
>>> reg = ChainedReg(
...     L1Reg(weight=0.01),
...     L2Reg(weight=0.001),
...     UniformReg(weight=1.0, lower=-1.0, upper=1.0),
...     weight=1.0
... )
>>> value = jnp.array([0.5, -0.3, 0.8])
>>> loss = reg.loss(value)

Notes

  • The loss() method returns the sum of all component regularization losses, scaled by the overall weight.

  • The sample_init() and reset_value() methods delegate only to the first regularization in the chain, which is treated as the primary prior. Place the prior that should govern initialization first.

  • An empty chain returns zero from loss(), zeros of the requested shape from sample_init(), and zero from reset_value().

  • Each regularization is stored as a submodule for proper state management.

loss(value)

Calculate the combined regularization loss.

Sums the losses from all component regularizations, scaled by weight.

Parameters:

value (Array | ndarray | bool_ | number | bool | int | float | complex | Quantity) – Parameter values to compute the regularization loss for.

Returns:

Sum of all regularization losses, scaled by weight.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity
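The summation rule can be illustrated with a minimal, framework-free sketch. It assumes each component is simply a function mapping a parameter vector to a scalar that already includes that component's own weight. The names l1_loss, l2_loss and chained_loss are hypothetical stand-ins, and their formulas are assumptions rather than brainstate's L1Reg/L2Reg implementations:

```python
from typing import Callable, Sequence

# Hypothetical component losses (assumed formulas, not brainstate's own):
# each returns a scalar that already includes the component's weight.
def l1_loss(value: Sequence[float], weight: float = 0.01) -> float:
    return weight * sum(abs(v) for v in value)


def l2_loss(value: Sequence[float], weight: float = 0.001) -> float:
    return weight * sum(v * v for v in value)


def chained_loss(
    components: Sequence[Callable[[Sequence[float]], float]],
    value: Sequence[float],
    weight: float = 1.0,
) -> float:
    """Sum every component loss, then scale by the chain's overall weight."""
    # sum() over an empty chain is 0, matching the documented zero loss.
    return weight * sum(loss(value) for loss in components)


value = [0.5, -0.3, 0.8]
total = chained_loss([l1_loss, l2_loss], value, weight=2.0)
print(total)
```

Because the overall weight multiplies the sum, doubling it doubles the combined loss, while each component's relative contribution is still set by its own weight.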

reset_value()

Return the reset value from the first regularization.

Returns:

Reset value from the first regularization, or zero if the chain is empty.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity

sample_init(shape)

Sample initial value from the first regularization’s prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the parameter to initialize.

Returns:

Sampled initial value from the first regularization, or zeros if the chain is empty.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity
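The "first regularization wins, empty chain gives zeros" rule shared by sample_init() and reset_value() can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: ConstantPrior, chained_sample_init, chained_reset_value, normalize_shape and full are hypothetical names, shapes are assumed to be plain Python ints or sequences of them, and nested lists stand in for arrays:

```python
from typing import Sequence, Union

# Assumption: a shape is a plain int or a sequence of plain ints.
Shape = Union[int, Sequence[int]]


def normalize_shape(shape: Shape) -> tuple:
    """Turn an int or a sequence of ints into a tuple of ints."""
    return (shape,) if isinstance(shape, int) else tuple(shape)


def full(shape: tuple, fill: float):
    """Build a nested list of the given shape filled with `fill` (0-d -> scalar)."""
    if not shape:
        return fill
    return [full(shape[1:], fill) for _ in range(shape[0])]


class ConstantPrior:
    """Hypothetical prior whose samples and reset value are one fixed constant."""

    def __init__(self, c: float):
        self.c = c

    def sample_init(self, shape: Shape):
        return full(normalize_shape(shape), self.c)

    def reset_value(self):
        return self.c


def chained_sample_init(components: Sequence, shape: Shape):
    # Only the first component is consulted; an empty chain yields zeros.
    if not components:
        return full(normalize_shape(shape), 0.0)
    return components[0].sample_init(shape)


def chained_reset_value(components: Sequence):
    # Same rule: the first component wins, an empty chain falls back to zero.
    return components[0].reset_value() if components else 0.0


chain = [ConstantPrior(0.5), ConstantPrior(-1.0)]
print(chained_sample_init(chain, (2, 3)))  # [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
```

The second prior in the chain never influences initialization here, which is why the order of the chained regularizations matters.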