OrthogonalReg
class brainstate.nn.OrthogonalReg(weight=1.0, fit_hyper=False)
Orthogonal regularization.
Encourages weight matrices to be orthogonal by penalizing deviation from orthogonality:
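\[L = \lambda \|W^T W - I\|_F^2\]
where \(W\) is an \(m \times n\) weight matrix, \(I\) is the \(n \times n\) identity matrix, and \(\|\cdot\|_F\) denotes the Frobenius norm.
Parameters:
- weight (float): Regularization strength, i.e. the coefficient \(\lambda\) in the formula above. Default: 1.0.
- fit_hyper (bool): Default: False.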
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import OrthogonalReg
>>> reg = OrthogonalReg(weight=0.01)
>>> W = jnp.array([[1.0, 0.1], [0.1, 1.0]])
>>> loss = reg.loss(W)  # orthogonality penalty for W
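As a cross-check of the formula, here is a minimal standalone JAX sketch that evaluates \(\lambda \|W^T W - I\|_F^2\) directly for the matrix used in the example. The function name `orthogonal_penalty` is hypothetical and is not part of brainstate. The sketch assumes the penalty is the plain sum of squared entries with no averaging or extra scaling. Whether `reg.loss(W)` returns exactly this value depends on brainstate's implementation, which this section does not spell out.

```python
import jax.numpy as jnp

def orthogonal_penalty(W, weight=1.0):
    """Hypothetical helper: weight * ||W^T W - I||_F^2 for a 2D matrix W.

    A sketch of the documented formula, not brainstate's source code.
    """
    gram = W.T @ W                                # (n, n) Gram matrix of the columns of W
    eye = jnp.eye(W.shape[1], dtype=W.dtype)      # n x n identity
    return weight * jnp.sum((gram - eye) ** 2)    # squared Frobenius norm, scaled by lambda

W = jnp.array([[1.0, 0.1], [0.1, 1.0]])
# W^T W     = [[1.01, 0.2], [0.2, 1.01]]
# W^T W - I = [[0.01, 0.2], [0.2, 0.01]]
# ||W^T W - I||_F^2 = 2 * 0.01**2 + 2 * 0.2**2 = 0.0802
penalty = orthogonal_penalty(W, weight=0.01)      # 0.01 * 0.0802 = 0.000802
```

The penalty is zero exactly when the columns of \(W\) are orthonormal, so an identity or rotation matrix incurs no cost, while the off-diagonal 0.1 entries above dominate the result through the 0.2 terms in \(W^T W\).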
Notes
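Orthogonal regularization is particularly useful for recurrent neural networks (RNNs), where keeping the recurrent weight matrix close to orthogonal helps mitigate vanishing and exploding gradients. It is intended primarily for 2D weight matrices, for which \(W^T W\) is defined as in the formula above.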
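To illustrate what "encourages orthogonality" means, the sketch below minimizes the penalty on its own with plain gradient descent on a random square matrix that stands in for a hypothetical RNN recurrent weight matrix. Driving \(W^T W\) toward \(I\) pushes every singular value of \(W\) toward 1, so multiplying a hidden state by \(W\) neither shrinks nor amplifies its norm; this is the mechanism behind the gradient-stability benefit mentioned above. The penalty is hand-written here rather than taken from `OrthogonalReg`, and the matrix size, step size, and iteration count are arbitrary demo choices.

```python
import jax
import jax.numpy as jnp

def orthogonal_penalty(W, weight=1.0):
    # Hypothetical helper: weight * ||W^T W - I||_F^2, following the formula above
    eye = jnp.eye(W.shape[1], dtype=W.dtype)
    return weight * jnp.sum((W.T @ W - eye) ** 2)

# Hypothetical 8x8 recurrent weight matrix with random Gaussian entries
key = jax.random.PRNGKey(0)
W = 0.5 * jax.random.normal(key, (8, 8))

grad_fn = jax.jit(jax.grad(orthogonal_penalty))  # gradient w.r.t. W
initial = orthogonal_penalty(W)

lr = 0.01
for _ in range(1000):
    W = W - lr * grad_fn(W)  # gradient descent on the penalty alone

final = orthogonal_penalty(W)

# An (approximately) orthogonal W preserves the norm of a hidden state
x = jax.random.normal(jax.random.PRNGKey(1), (8,))
norm_ratio = jnp.linalg.norm(W @ x) / jnp.linalg.norm(x)
```

In actual training the penalty would be added to a task loss rather than minimized alone, so \(W\) only moves toward orthogonality to the extent the weight \(\lambda\) allows.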