L2Reg
- class brainstate.nn.L2Reg(weight=1.0, fit_hyper=False)
L2 (Ridge) regularization.
Implements L2 regularization:
\[L = \lambda \sum_i x_i^2\]
where \(\lambda\) is the regularization weight (weight). The corresponding prior is a zero-mean Gaussian distribution.
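The link between the penalty and the Gaussian prior can be checked numerically. The sketch below is a minimal plain-JAX illustration, not brainstate code: `l2_penalty` and `neg_log_gaussian_prior` are hypothetical helpers, and it assumes an i.i.d. zero-mean Gaussian prior with variance \(\sigma^2 = 1/(2\lambda)\). Under that assumption the negative log-prior equals the penalty plus a constant that does not depend on the parameters.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def l2_penalty(x, weight):
    """Hypothetical helper: weight * sum_i x_i**2 (the formula above)."""
    return weight * jnp.sum(x ** 2)


def neg_log_gaussian_prior(x, weight):
    """Hypothetical helper: negative log-density of an i.i.d. zero-mean
    Gaussian prior, assuming variance sigma**2 = 1 / (2 * weight)."""
    sigma = jnp.sqrt(1.0 / (2.0 * weight))
    return -jnp.sum(norm.logpdf(x, loc=0.0, scale=sigma))


weight = 0.01
x1 = jnp.array([1.0, -2.0, 0.5])
x2 = jnp.array([3.0, 0.0, -1.0])

# The two objectives differ only by a constant that does not depend on x,
# so they share the same minimizer and the same gradient with respect to x.
gap1 = neg_log_gaussian_prior(x1, weight) - l2_penalty(x1, weight)
gap2 = neg_log_gaussian_prior(x2, weight) - l2_penalty(x2, weight)
```

Both gaps equal \(\tfrac{n}{2}\log(\pi/\lambda)\) for \(n\) parameters, which is why minimizing the penalty matches MAP estimation under this prior.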
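- Parameters:
  - weight: Regularization weight \(\lambda\) (trainable if fit_hyper=True). Defaults to 1.0.
    - Type: array_like or ParamState
  - fit_hyper: Whether the regularization weight is trainable. Defaults to False.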
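Making the weight trainable means gradients also flow into \(\lambda\). The following sketch is a plain-JAX illustration with a hypothetical `l2_penalty` helper, not brainstate's implementation; it uses `jax.grad` to obtain both derivatives of the formula: \(\partial L / \partial x_i = 2\lambda x_i\) and \(\partial L / \partial \lambda = \sum_i x_i^2\).

```python
import jax
import jax.numpy as jnp


def l2_penalty(x, weight):
    # Hypothetical helper implementing weight * sum_i x_i**2.
    return weight * jnp.sum(x ** 2)


x = jnp.array([1.0, -2.0, 0.5])
weight = jnp.array(0.01)

# argnums=(0, 1) differentiates with respect to both x and weight.
grad_x, grad_w = jax.grad(l2_penalty, argnums=(0, 1))(x, weight)
# grad_x equals 2 * weight * x: each parameter is pulled toward zero.
# grad_w equals sum_i x_i**2: the signal a trainable weight would receive.
```

Because \(\sum_i x_i^2 \ge 0\), minimizing the penalty alone would only push \(\lambda\) downward; this sketch does not show how brainstate updates the weight when fit_hyper=True.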
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import L2Reg
>>> reg = L2Reg(weight=0.01)
>>> value = jnp.array([1.0, -2.0, 0.5])
>>> loss = reg.loss(value)  # 0.01 * (1.0 + 4.0 + 0.25) = 0.0525
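In practice the penalty is added to a data-fitting loss. The sketch below assumes a ridge-regression setup (squared error plus the L2 term) on synthetic data and uses a hypothetical `l2_penalty` function in place of `reg.loss`, so it runs with JAX alone. It checks plain gradient descent against the closed-form ridge solution \(w = (X^\top X + \lambda I)^{-1} X^\top y\).

```python
import jax
import jax.numpy as jnp


def l2_penalty(w, weight):
    # Hypothetical stand-in for L2Reg.loss: weight * sum_i w_i**2.
    return weight * jnp.sum(w ** 2)


def total_loss(w, X, y, weight):
    # Squared-error data term plus the L2 penalty (ridge regression).
    residual = X @ w - y
    return jnp.sum(residual ** 2) + l2_penalty(w, weight)


# Synthetic data for illustration only.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (50, 3))
y = X @ jnp.array([1.0, -2.0, 0.5]) + 0.1 * jax.random.normal(k2, (50,))
weight = 5.0

# Plain gradient descent on the regularized objective.
grad_fn = jax.jit(jax.grad(total_loss))
w = jnp.zeros(3)
for _ in range(2000):
    w = w - 1e-3 * grad_fn(w, X, y, weight)

# Setting the gradient 2 X^T (X w - y) + 2 * weight * w to zero
# gives (X^T X + weight * I) w = X^T y.
w_closed = jnp.linalg.solve(X.T @ X + weight * jnp.eye(3), X.T @ y)
```

The regularized solution has a smaller norm than the unregularized least-squares fit, which is the shrinkage the penalty is meant to produce.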
Notes
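L2 regularization encourages small parameter values. Its penalty is smooth and differentiable everywhere (the gradient is \(2\lambda x_i\)), which makes it better behaved under gradient-based optimization than L1 regularization, whose penalty is not differentiable at zero. Unlike L1, it shrinks parameters toward zero without setting them exactly to zero, so it does not produce sparse solutions.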
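One way to see the difference from L1 is through the proximal (shrinkage) step used by proximal gradient methods. This swaps in proximal gradient descent as an illustrative technique; it is not how `L2Reg` is applied, and `prox_l2` and `prox_l1` are hypothetical helpers. The L2 step rescales every entry, while the L1 step soft-thresholds and sets small entries exactly to zero.

```python
import jax.numpy as jnp


def prox_l2(x, weight, step):
    # argmin_z ||z - x||^2 / (2 * step) + weight * ||z||^2
    # Setting the gradient to zero gives z = x / (1 + 2 * step * weight):
    # every entry is scaled toward zero, but none becomes exactly zero.
    return x / (1.0 + 2.0 * step * weight)


def prox_l1(x, weight, step):
    # argmin_z ||z - x||^2 / (2 * step) + weight * ||z||_1
    # Soft-thresholding: entries with |x_i| <= step * weight become exactly zero.
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - step * weight, 0.0)


x = jnp.array([1.0, -2.0, 0.5, 0.05, -0.02])
z_l2 = prox_l2(x, weight=1.0, step=0.1)  # all five entries stay nonzero
z_l1 = prox_l1(x, weight=1.0, step=0.1)  # the two small entries become zero
```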