SpectralNormReg

class brainstate.nn.SpectralNormReg(weight=1.0, max_value=1.0, n_power_iterations=1, fit_hyper=False)

Spectral norm regularization.

Penalizes the amount by which the spectral norm (largest singular value) of a weight matrix exceeds a threshold:

\[L = \lambda \cdot \max(0, \sigma_{\max}(W) - c)^2\]

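where \(\sigma_{\max}(W)\) is the largest singular value of \(W\), \(c\) is the maximum allowed spectral norm (max_value), and \(\lambda\) is the regularization weight (weight). The penalty is zero whenever \(\sigma_{\max}(W) \le c\).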
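The formula can be checked directly against an exact singular value. This is a reference computation, not the library's implementation: `spectral_norm_penalty` is a hypothetical helper, and it uses `jnp.linalg.norm(W, ord=2)` (the exact largest singular value of a 2-D array) instead of the power-iteration estimate controlled by `n_power_iterations`.

```python
import jax.numpy as jnp


def spectral_norm_penalty(W, weight=1.0, max_value=1.0):
    """Hypothetical reference penalty: weight * max(0, sigma_max(W) - max_value)**2."""
    # For a 2-D array, the ord=2 matrix norm is the largest singular value.
    sigma_max = jnp.linalg.norm(W, ord=2)
    # Hinge at max_value: matrices already within the bound contribute nothing.
    excess = jnp.maximum(0.0, sigma_max - max_value)
    return weight * excess ** 2


W = jnp.array([[2.0, 0.0], [0.0, 0.5]])
print(spectral_norm_penalty(W))  # 1.0, since sigma_max = 2 and 1.0 * (2 - 1)**2 = 1.0
```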

Parameters:
  • weight (float) – Regularization weight \(\lambda\). Default is 1.0.

  • max_value (float) – Maximum allowed spectral norm \(c\). Default is 1.0.

  • n_power_iterations (int) – Number of power iterations used to estimate the spectral norm; more iterations give a more accurate estimate at a higher cost. Default is 1.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
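To illustrate what `n_power_iterations` controls, here is a generic power-iteration sketch, assuming the estimator works along these lines. The helper name `power_iteration_sigma_max`, the random starting vector, and the `eps` guard are illustrative assumptions; the library may, for example, keep its iteration vector between calls instead of drawing a fresh one.

```python
import jax
import jax.numpy as jnp


def power_iteration_sigma_max(W, n_power_iterations=1, key=None, eps=1e-12):
    """Hypothetical estimator of the largest singular value of a 2-D matrix W.

    Each iteration applies u <- W v / ||W v|| and v <- W^T u / ||W^T u||,
    which rotates v toward the top right-singular vector of W.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # Random unit starting vector in the input space of W.
    v = jax.random.normal(key, (W.shape[1],))
    v = v / jnp.linalg.norm(v)
    for _ in range(n_power_iterations):
        u = W @ v
        u = u / (jnp.linalg.norm(u) + eps)
        v = W.T @ u
        v = v / (jnp.linalg.norm(v) + eps)
    # ||W v|| for a unit vector v never exceeds sigma_max and tightens as v converges.
    return jnp.linalg.norm(W @ v)


W = jnp.array([[2.0, 0.0], [0.0, 0.5]])
print(power_iteration_sigma_max(W, n_power_iterations=1))   # rough lower-bound estimate of 2.0
print(power_iteration_sigma_max(W, n_power_iterations=20))  # much closer to 2.0
```

Convergence speed depends on the ratio of the two largest singular values: each iteration shrinks the error roughly by \((\sigma_2 / \sigma_1)^2\), so a single iteration is usually adequate only when the top singular value is well separated.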

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import SpectralNormReg
>>> reg = SpectralNormReg(weight=1.0, max_value=1.0)
>>> W = jnp.array([[2.0, 0.0], [0.0, 0.5]])  # spectral norm = 2
>>> loss = reg.loss(W)  # approx. 1.0 * (2 - 1)**2 = 1.0, up to the spectral norm estimate

Notes

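Constraining the spectral norm is useful for stabilizing GAN training and for controlling the Lipschitz constant of neural networks: for a linear map \(x \mapsto Wx\), the Lipschitz constant with respect to the Euclidean norm equals \(\sigma_{\max}(W)\). Unlike spectral normalization, which divides \(W\) by its spectral norm, this regularizer adds a soft penalty that is active only when \(\sigma_{\max}(W)\) exceeds \(c\).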
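The effect of the soft penalty on the Lipschitz constant can be seen by minimizing it directly. This sketch runs plain gradient descent on the exact penalty (via `jnp.linalg.norm(W, ord=2)`, not the library's estimate); the helper name `penalty` and the learning rate are illustrative assumptions. Only the top singular value is pulled toward `max_value`; the smaller one is left untouched.

```python
import jax
import jax.numpy as jnp


def penalty(W, weight=1.0, max_value=1.0):
    """Hypothetical exact penalty: weight * max(0, sigma_max(W) - max_value)**2."""
    sigma_max = jnp.linalg.norm(W, ord=2)  # largest singular value of a 2-D array
    return weight * jnp.maximum(0.0, sigma_max - max_value) ** 2


W = jnp.array([[2.0, 0.0], [0.0, 0.5]])
grad_fn = jax.grad(penalty)  # gradient is 2 * (sigma_max - c) * u1 v1^T while sigma_max > c
lr = 0.1
for _ in range(50):
    W = W - lr * grad_fn(W)

# Singular values after training: the top one approaches max_value = 1.0 from above,
# so the Lipschitz constant of x -> W x is now close to 1.0; the 0.5 is unchanged.
print(jnp.linalg.svd(W, compute_uv=False))
```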

loss(value)

Compute the spectral norm regularization loss for a weight matrix.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Weight matrix.

Returns:

The spectral norm penalty \(\lambda \cdot \max(0, \sigma_{\max}(W) - c)^2\).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

reset_value()

Return the reset value of the regularization term, which is zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

sample_init(shape)

Draw a random sample of the given shape whose spectral norm is approximately bounded.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

A sample whose spectral norm is approximately bounded.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
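One way to produce such a sample is to draw a Gaussian matrix and shrink it until its spectral norm fits under a bound. This is a hedged sketch: the documentation above does not state which distribution or bound `sample_init` uses, so the helper name `sample_with_bounded_spectral_norm`, the Gaussian draw, and the use of `max_value` as the bound are assumptions. It also handles only 2-D shapes.

```python
import jax
import jax.numpy as jnp


def sample_with_bounded_spectral_norm(key, shape, max_value=1.0):
    """Hypothetical initializer: Gaussian matrix rescaled so sigma_max(W) <= max_value."""
    W = jax.random.normal(key, shape)
    sigma_max = jnp.linalg.norm(W, ord=2)  # exact largest singular value (2-D only)
    # Only shrink: samples already within the bound are returned unchanged.
    scale = jnp.minimum(1.0, max_value / sigma_max)
    return W * scale


W0 = sample_with_bounded_spectral_norm(jax.random.PRNGKey(0), (4, 3), max_value=1.0)
print(jnp.linalg.norm(W0, ord=2))  # at most 1.0, up to floating-point rounding
```

Rescaling the whole matrix preserves its singular vectors and the ratios between singular values; only the overall scale changes.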