MaxNormReg
class brainstate.nn.MaxNormReg(weight=1.0, max_value=1.0, fit_hyper=False)
Max Norm regularization (soft constraint).
Implements a soft constraint on the L2 norm of parameters:
\[L = \lambda \cdot \max(0, \|x\|_2 - c)^2\]
where c is the maximum allowed norm (max_value) and λ is the regularization weight (weight).
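The formula can be checked numerically with a small standalone function. This is a sketch, not brainstate's implementation: `max_norm_penalty` is a hypothetical helper, and it assumes that weight plays the role of λ, max_value the role of c, and that the norm is taken over all entries of x (the Frobenius norm for matrices). The library's own `MaxNormReg.loss` may differ in such details, for example by using per-row norms.

```python
import jax.numpy as jnp


def max_norm_penalty(x, weight=1.0, max_value=1.0):
    """Hypothetical helper: weight * max(0, ||x||_2 - max_value)^2.

    Assumes x is treated as one flattened vector (Frobenius norm for matrices).
    """
    norm = jnp.sqrt(jnp.sum(jnp.square(x)))      # L2 norm over all entries
    excess = jnp.maximum(0.0, norm - max_value)  # zero when ||x||_2 <= c
    return weight * excess ** 2


# norm = sqrt(12) ~ 3.46 exceeds c = 3, so a positive penalty results
big = max_norm_penalty(jnp.array([2.0, 2.0, 2.0]), weight=1.0, max_value=3.0)
# norm = sqrt(3) ~ 1.73 is within c = 3, so the penalty is exactly zero
small = max_norm_penalty(jnp.array([1.0, 1.0, 1.0]), weight=1.0, max_value=3.0)
print(float(big), float(small))
```

Squaring the excess rather than using it directly keeps the penalty smooth at the boundary ‖x‖₂ = c, so its gradient goes to zero continuously as the norm approaches c from above.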
- Parameters:
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import MaxNormReg
>>> reg = MaxNormReg(weight=1.0, max_value=3.0)
>>> value = jnp.array([2.0, 2.0, 2.0])  # norm = sqrt(12) ≈ 3.46 > 3
>>> loss = reg.loss(value)  # penalty applied
Notes
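Max Norm regularization is useful for constraining the capacity of neural networks without penalizing small weights: the penalty and its gradient are zero whenever the norm is at or below c, whereas L2 weight decay penalizes every nonzero weight.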
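The claim that small weights are left alone can be seen in the gradient. A minimal sketch, again using a hypothetical standalone `max_norm_penalty` (not part of brainstate's API) and assuming the norm is taken over all entries of x; `jax.grad` differentiates with respect to the first argument only.

```python
import jax
import jax.numpy as jnp


def max_norm_penalty(x, weight=1.0, max_value=1.0):
    # Hypothetical helper: weight * max(0, ||x||_2 - max_value)^2 over all entries
    norm = jnp.sqrt(jnp.sum(jnp.square(x)))
    return weight * jnp.maximum(0.0, norm - max_value) ** 2


# Gradient with respect to x (argument 0); weight and max_value are held fixed
grad_fn = jax.grad(max_norm_penalty)

# Inside the ball (||x||_2 <= c): the gradient is zero, so small weights are untouched
g_small = grad_fn(jnp.array([0.5, 0.5, 0.5]), 1.0, 3.0)

# Outside the ball: the gradient is 2 * weight * (||x||_2 - c) * x / ||x||_2,
# which points along x and pulls the norm back toward c
g_big = grad_fn(jnp.array([2.0, 2.0, 2.0]), 1.0, 3.0)
print(g_small, g_big)
```

For comparison, the gradient of an L2 penalty λ‖x‖₂² is 2λx, which is nonzero for every nonzero x. Note also that this sketch's gradient is undefined at x = 0, where the norm itself is not differentiable.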