MaxNormReg

class brainstate.nn.MaxNormReg(weight=1.0, max_value=1.0, fit_hyper=False)

Max Norm regularization (soft constraint).

Implements a soft constraint on the L2 norm of parameters:

\[L = \lambda \cdot \max(0, \|x\|_2 - c)^2\]

where λ is the regularization weight (weight) and c is the maximum allowed norm (max_value).
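As a rough illustration of the formula, the penalty can be computed directly with jax.numpy. This is a minimal sketch, not the brainstate implementation: it assumes the whole array is flattened and treated as a single vector, and the function name max_norm_penalty is hypothetical.

```python
import jax.numpy as jnp


def max_norm_penalty(x, weight=1.0, max_value=1.0):
    """Hypothetical helper computing lambda * max(0, ||x||_2 - c)^2.

    Assumes the L2 norm is taken over all elements of ``x`` as one vector.
    """
    norm = jnp.linalg.norm(jnp.ravel(x))         # ||x||_2 over the flattened array
    excess = jnp.maximum(0.0, norm - max_value)  # zero when the norm is within the bound
    return weight * excess ** 2


small = max_norm_penalty(jnp.array([1.0, 1.0, 1.0]), max_value=3.0)  # norm sqrt(3) < 3
large = max_norm_penalty(jnp.array([2.0, 2.0, 2.0]), max_value=3.0)  # norm sqrt(12) > 3
```

Only the part of the norm that exceeds c contributes, so small is exactly zero while large is positive.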

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • max_value (float) – Maximum allowed norm. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import MaxNormReg
>>> reg = MaxNormReg(weight=1.0, max_value=3.0)
>>> value = jnp.array([2.0, 2.0, 2.0])  # norm = sqrt(12) > 3
>>> loss = reg.loss(value)  # penalty applied
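Working this example through the formula above by hand, assuming (as the inline comment suggests) that the norm is taken over the whole array, gives the expected size of the penalty:

\[\|x\|_2 = \sqrt{2^2 + 2^2 + 2^2} = \sqrt{12} \approx 3.4641\]

\[L = 1.0 \cdot \max(0, \sqrt{12} - 3)^2 = (2\sqrt{3} - 3)^2 = 21 - 12\sqrt{3} \approx 0.2154\]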

Notes

Max Norm regularization is useful for constraining the capacity of neural networks without penalizing small weights: parameters whose norm is at most max_value incur no penalty at all.
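One way to see why small weights are left alone is to differentiate the penalty: inside the ball ‖x‖₂ ≤ c both the penalty and its gradient are zero, so the regularizer exerts no pull on such parameters. The sketch below checks this with jax.grad. It re-implements the formula from scratch through a hypothetical helper named penalty rather than calling brainstate, and it assumes the norm is taken over the flattened array.

```python
import jax
import jax.numpy as jnp


def penalty(x, weight=1.0, max_value=3.0):
    # Hypothetical re-implementation of lambda * max(0, ||x||_2 - c)^2
    norm = jnp.linalg.norm(jnp.ravel(x))
    return weight * jnp.maximum(0.0, norm - max_value) ** 2


grad_fn = jax.grad(penalty)

# norm sqrt(3) <= 3: no penalty, so the gradient is zero
g_inside = grad_fn(jnp.array([1.0, 1.0, 1.0]))
# norm sqrt(12) > 3: gradient is 2 * (||x|| - c) * x / ||x||, pointing outward
g_outside = grad_fn(jnp.array([2.0, 2.0, 2.0]))
```

A gradient-descent step subtracts g_outside, which shrinks the vector back toward the ball of radius c while leaving g_inside parameters untouched.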

loss(value)

Calculate Max Norm regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

Max Norm penalty (zero if norm <= max_value).

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Sample from a truncated Gaussian.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Sample with norm <= max_value.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
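The docstring does not say how the Gaussian is truncated or how the norm guarantee is enforced. One plausible sketch, and only an assumption: draw entries from a standard normal truncated to [-2, 2] with jax.random.truncated_normal, then, if the draw's L2 norm exceeds max_value, rescale it back onto the ball of radius max_value. The function name sample_within_max_norm and the truncation bounds are hypothetical, not brainstate's.

```python
import jax
import jax.numpy as jnp


def sample_within_max_norm(key, shape, max_value=1.0, lower=-2.0, upper=2.0):
    """Hypothetical sampler: truncated-normal draw projected into ||x||_2 <= max_value."""
    # Each entry is drawn from N(0, 1) restricted to [lower, upper] (assumed bounds)
    x = jax.random.truncated_normal(key, lower, upper, shape=shape)
    norm = jnp.linalg.norm(jnp.ravel(x))
    # Shrink only when the norm exceeds the bound; guard against division by zero
    scale = jnp.minimum(1.0, max_value / jnp.maximum(norm, 1e-12))
    return x * scale


key = jax.random.PRNGKey(0)
sample = sample_within_max_norm(key, (4, 8), max_value=3.0)
```

Rescaling keeps the direction of the draw and always finishes in one step; rejection sampling (redrawing until the norm fits) would preserve the truncated-Gaussian distribution more faithfully but can take many attempts for large shapes.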