OrthogonalReg

class brainstate.nn.OrthogonalReg(weight=1.0, fit_hyper=False)

Orthogonal regularization.

Encourages weight matrices to be orthogonal by penalizing the squared Frobenius-norm distance between the Gram matrix W^T W and the identity matrix I:

\[L = \lambda \|W^T W - I\|_F^2\]
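The equation translates directly into JAX. The sketch below is a standalone reference implementation of the formula, not brainstate's own `loss` code; the function name `orthogonal_penalty` is hypothetical, and it assumes a 2D `W` of shape `(m, n)`, so `I` is the `n x n` identity.

```python
import jax.numpy as jnp


def orthogonal_penalty(W, lam=1.0):
    """Hypothetical reference: lam * ||W^T W - I||_F^2 for a 2D matrix W."""
    W = jnp.asarray(W)
    n = W.shape[1]
    # Gram matrix of the columns; equals I exactly when the columns are orthonormal.
    gram = W.T @ W
    diff = gram - jnp.eye(n, dtype=gram.dtype)
    # Squared Frobenius norm = sum of squared entries.
    return lam * jnp.sum(diff ** 2)


# A rotation matrix is orthogonal, so the penalty is ~0 (up to float rounding).
theta = 0.3
R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
               [jnp.sin(theta), jnp.cos(theta)]])
print(orthogonal_penalty(R))

# Scaling breaks orthogonality: (2I)^T (2I) - I = 3I, so the penalty is 9 + 9 = 18.
print(orthogonal_penalty(2.0 * jnp.eye(2)))
```

Note that only W^T W is compared with I, so a tall matrix with orthonormal columns also has zero penalty.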
Parameters:
  • weight (float) – Regularization weight (λ in the formula above). Default is 1.0.

  • fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import OrthogonalReg
>>> reg = OrthogonalReg(weight=0.01)
>>> W = jnp.array([[1.0, 0.1], [0.1, 1.0]])
>>> loss = reg.loss(W)
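Evaluating the formula by hand for this example matrix shows where the penalty comes from: the diagonal of W^T W is only slightly off 1, while the off-diagonal entries contribute most. This assumes `loss` applies the formula exactly as written, with no extra normalization; the value the library actually returns is not confirmed here.

\[
W^T W = \begin{pmatrix} 1.01 & 0.2 \\ 0.2 & 1.01 \end{pmatrix}, \qquad
W^T W - I = \begin{pmatrix} 0.01 & 0.2 \\ 0.2 & 0.01 \end{pmatrix}
\]
\[
\|W^T W - I\|_F^2 = 2(0.01)^2 + 2(0.2)^2 = 0.0802, \qquad
L = 0.01 \times 0.0802 = 0.000802
\]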

Notes

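Orthogonal regularization is particularly useful for recurrent neural networks (RNNs). All singular values of an orthogonal matrix equal 1, so repeatedly multiplying the hidden state by an orthogonal recurrent weight matrix neither shrinks nor amplifies its norm, which helps prevent vanishing and exploding gradients. The penalty is intended primarily for 2D weight matrices.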
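The norm-preservation argument can be checked numerically. The sketch below uses plain JAX rather than the brainstate API, models the recurrence as h <- W h with no input or nonlinearity (a deliberate simplification), and uses a hypothetical helper `unroll_norm`. An orthogonal matrix keeps the state norm fixed over 100 steps, while scaling it by 0.9 or 1.1 makes the norm decay or grow geometrically.

```python
import jax
import jax.numpy as jnp


def unroll_norm(W, h0, steps):
    """Hypothetical helper: norm of h after `steps` linear updates h <- W @ h."""
    def step(h, _):
        return W @ h, None
    h, _ = jax.lax.scan(step, h0, None, length=steps)
    return jnp.linalg.norm(h)


key_w, key_h = jax.random.split(jax.random.PRNGKey(0))
n = 8
# Orthogonal matrix from the QR decomposition of a Gaussian matrix.
Q, _ = jnp.linalg.qr(jax.random.normal(key_w, (n, n)))
h0 = jax.random.normal(key_h, (n,))

print(jnp.linalg.norm(h0))            # initial norm
print(unroll_norm(Q, h0, 100))        # orthogonal: same norm up to float rounding
print(unroll_norm(0.9 * Q, h0, 100))  # singular values 0.9: norm scaled by 0.9**100
print(unroll_norm(1.1 * Q, h0, 100))  # singular values 1.1: norm scaled by 1.1**100
```

Backpropagation through time multiplies by W^T at each step, so the same geometric shrinking or growth applies to gradients.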

loss(value)

Calculate the orthogonal regularization loss.

Parameters:

value (Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity) – Weight matrix (2D) or flattened parameters.

Returns:

Orthogonality penalty.

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity
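Since an optimizer consumes the gradient of this loss, it helps to know that gradient in closed form. Writing A = W^T W - I (symmetric), dL = 2λ tr(A dA) and dA = dW^T W + W^T dW, so dL = 4λ tr(A W^T dW) and the gradient is 4λ W (W^T W - I). The sketch below checks this against `jax.grad` on a standalone reimplementation of the formula for a 2D `W` (the helpers `penalty` and `closed_form_grad` are hypothetical, not part of brainstate), then runs plain gradient descent on the penalty alone.

```python
import jax
import jax.numpy as jnp

LAM = 0.01  # regularization weight (lambda), as in the example above


def penalty(W):
    """Hypothetical standalone version of LAM * ||W^T W - I||_F^2 for 2D W."""
    A = W.T @ W - jnp.eye(W.shape[1], dtype=W.dtype)
    return LAM * jnp.sum(A ** 2)


def closed_form_grad(W):
    """Analytic gradient 4 * LAM * W (W^T W - I)."""
    A = W.T @ W - jnp.eye(W.shape[1], dtype=W.dtype)
    return 4.0 * LAM * W @ A


W0 = jax.random.normal(jax.random.PRNGKey(1), (5, 3))
grad_fn = jax.jit(jax.grad(penalty))

# Autodiff and the closed form agree up to float32 rounding.
print(jnp.max(jnp.abs(grad_fn(W0) - closed_form_grad(W0))))

# Plain gradient descent on the penalty alone pulls W^T W toward I.
lr = 0.5
W = W0
for _ in range(500):
    W = W - lr * grad_fn(W)
print(penalty(W0), penalty(W))  # the second value is far smaller
```

With the SVD W = U Σ V^T, the gradient equals 4λ U Σ(Σ² - I) V^T, so each step shrinks singular values above 1, grows those below 1, and leaves U and V unchanged.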

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity

sample_init(shape)

Sample from the uniform distribution over orthogonal matrices.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

An orthogonally initialized array, generated via QR decomposition.

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity
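A QR-based sampler can be sketched as follows. This is a hypothetical standalone function, not brainstate's `sample_init`, whose exact handling of shapes and signs is not documented here; it assumes a 2D `shape` of `(rows, cols)`. It draws a Gaussian matrix, takes its reduced QR decomposition, and multiplies each column of Q by the sign of the matching diagonal entry of R. Without that sign correction the result is still orthogonal but not uniformly distributed.

```python
import jax
import jax.numpy as jnp


def sample_orthogonal(key, shape):
    """Hypothetical sketch: QR-based orthogonal sample of a 2D `shape`.

    If rows >= cols the columns are orthonormal (W^T W = I);
    otherwise the rows are orthonormal (W W^T = I).
    """
    rows, cols = shape
    big, small = max(rows, cols), min(rows, cols)
    # Gaussian matrix with at least as many rows as columns.
    A = jax.random.normal(key, (big, small))
    Q, R = jnp.linalg.qr(A)  # reduced QR: Q is (big, small) with orthonormal columns
    # Sign correction: flip columns so that diag(R) would be positive.
    Q = Q * jnp.sign(jnp.diag(R))
    # Transpose for wide shapes so the output has the requested shape.
    return Q if rows >= cols else Q.T


key = jax.random.PRNGKey(0)
W_tall = sample_orthogonal(key, (6, 3))
W_wide = sample_orthogonal(key, (3, 6))
print(jnp.allclose(W_tall.T @ W_tall, jnp.eye(3), atol=1e-5))  # True
print(jnp.allclose(W_wide @ W_wide.T, jnp.eye(3), atol=1e-5))  # True
```

A tall sample from this sketch has W^T W = I up to float rounding, so its orthogonality penalty starts at essentially zero.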