ChainedReg
- class brainstate.nn.ChainedReg(*regularizations, weight=1.0, fit_hyper=False)
Composite regularization that chains multiple regularizations together.
Combines multiple regularization priors into a single composite regularization by summing their losses and scaling the sum by weight, so that several constraints or priors can be applied to the same parameters at once.
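Written out, the combined loss for a parameter value θ and N chained regularizations is shown below. This assumes each component's loss already includes its own weight, as the per-component weight arguments in the example suggest; λ denotes the overall weight of the chain.

$$
\mathcal{L}_{\text{chain}}(\theta) = \lambda \sum_{i=1}^{N} \mathcal{L}_i(\theta), \qquad \mathcal{L}_{\text{chain}}(\theta) = 0 \ \text{when}\ N = 0.
$$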
- Parameters:
*regularizations (Regularization) – Variable number of regularization instances to chain together.
weight (float) – Overall regularization weight (lambda) that scales the combined loss. Default is 1.0.
fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.
- regularizations
The regularizations being combined.
- Type:
- weight
Regularization weight (trainable if fit_hyper=True).
- Type:
array_like or ParamState
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import ChainedReg, L1Reg, L2Reg, UniformReg
>>> # Combine L1 sparsity + L2 smoothness + bounded constraint
>>> reg = ChainedReg(
...     L1Reg(weight=0.01),
...     L2Reg(weight=0.001),
...     UniformReg(weight=1.0, lower=-1.0, upper=1.0),
...     weight=1.0
... )
>>> value = jnp.array([0.5, -0.3, 0.8])
>>> loss = reg.loss(value)
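The example can be mirrored without brainstate to see how the pieces combine. The sketch below is a minimal pure-Python illustration, not brainstate's implementation. The component penalties l1_penalty, l2_penalty, and box_penalty are hypothetical stand-ins with assumed forms (weighted sum of absolute values, weighted sum of squares, and weighted squared distance outside [lower, upper]). The real L1Reg, L2Reg, and UniformReg may define their losses differently.

```python
from typing import Callable, List, Sequence

Penalty = Callable[[Sequence[float]], float]


# Hypothetical stand-ins for component regularizations; the real
# L1Reg / L2Reg / UniformReg losses may use different formulas.
def l1_penalty(weight: float) -> Penalty:
    # Assumed L1 form: weight * sum(|x|)
    return lambda xs: weight * sum(abs(x) for x in xs)


def l2_penalty(weight: float) -> Penalty:
    # Assumed L2 form: weight * sum(x**2)
    return lambda xs: weight * sum(x * x for x in xs)


def box_penalty(weight: float, lower: float, upper: float) -> Penalty:
    # Assumed bound form: zero inside [lower, upper], squared distance outside
    def loss(xs: Sequence[float]) -> float:
        total = 0.0
        for x in xs:
            if x < lower:
                total += (lower - x) ** 2
            elif x > upper:
                total += (x - upper) ** 2
        return weight * total

    return loss


def chained_loss(value: Sequence[float], components: List[Penalty], weight: float = 1.0) -> float:
    # Sum every component loss, then scale the sum by the overall weight
    return weight * sum(c(value) for c in components)


components = [
    l1_penalty(0.01),
    l2_penalty(0.001),
    box_penalty(1.0, lower=-1.0, upper=1.0),
]
value = [0.5, -0.3, 0.8]
print(chained_loss(value, components, weight=1.0))
```

Under these assumed forms, every entry of value lies inside [-1, 1], so the bound term contributes nothing and the total is just the L1 and L2 terms. An empty component list yields zero, matching the empty-chain note.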
Notes
The loss() method returns the sum of all component regularization losses, scaled by the overall weight.
The sample_init() and reset_value() methods use the first regularization in the chain, as it’s typically the most interpretable prior.
An empty chain returns zero loss, and zero for sample_init() and reset_value().
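The behaviors listed in these notes can be summarized with a small stand-in class. This is a hypothetical sketch, not brainstate's source: SketchChainedReg and SketchReg are illustrative names, and passing an integer length as the argument of sample_init is an assumption, since the real method signature is not shown here. It only demonstrates the stated rules: summed and weighted loss, delegation to the first component, and zeros for an empty chain.

```python
class SketchReg:
    """Hypothetical component: weighted sum of squares with a fixed init value."""

    def __init__(self, weight: float, init_value: float):
        self.weight = weight
        self.init_value = init_value

    def loss(self, value):
        return self.weight * sum(x * x for x in value)

    def sample_init(self, length: int):
        # Assumed signature: returns a constant-filled list of the given length
        return [self.init_value] * length

    def reset_value(self):
        return self.init_value


class SketchChainedReg:
    """Stand-in mirroring the documented ChainedReg rules (not the real class)."""

    def __init__(self, *regularizations, weight: float = 1.0):
        self.regularizations = tuple(regularizations)
        self.weight = weight

    def loss(self, value):
        # Sum of component losses scaled by the overall weight; 0 for an empty chain
        return self.weight * sum(r.loss(value) for r in self.regularizations)

    def sample_init(self, length: int):
        # Delegate to the first regularization; zeros for an empty chain
        if not self.regularizations:
            return [0.0] * length
        return self.regularizations[0].sample_init(length)

    def reset_value(self):
        # Delegate to the first regularization; zero for an empty chain
        if not self.regularizations:
            return 0.0
        return self.regularizations[0].reset_value()


chain = SketchChainedReg(
    SketchReg(0.1, init_value=1.0),
    SketchReg(0.2, init_value=-1.0),
    weight=2.0,
)
print(chain.loss([1.0, 2.0]))               # 2.0 * (0.1 * 5 + 0.2 * 5) = 3.0
print(chain.sample_init(3))                 # [1.0, 1.0, 1.0], from the first component
print(SketchChainedReg().loss([1.0, 2.0]))  # 0.0 for an empty chain
```

Only the first component's init value is used, which is why the second component's -1.0 never appears in sample_init() or reset_value().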
Each regularization is stored as a submodule for proper state management.
- loss(value)
Calculate combined regularization loss.
Sums the losses from all component regularizations, scaled by weight.