SpectralNormReg
class brainstate.nn.SpectralNormReg(weight=1.0, max_value=1.0, n_power_iterations=1, fit_hyper=False)
Spectral Norm regularization.
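Penalizes the spectral norm (largest singular value) of a weight matrix when it exceeds a threshold:
\[L = \lambda \cdot \max(0, \sigma_{\max}(W) - c)^2\]
where \(\sigma_{\max}(W)\) is the largest singular value of \(W\), \(c\) is the maximum allowed value (max_value), and \(\lambda\) is the regularization weight (weight). The penalty is zero while \(\sigma_{\max}(W) \le c\) and grows quadratically with the excess beyond \(c\).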
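To make the formula concrete, the sketch below evaluates it with NumPy using an exact SVD. The helper name `spectral_norm_penalty` is hypothetical and not part of the brainstate API; because it computes \(\sigma_{\max}\) exactly instead of by power iteration, it is best read as a reference value, not as a reproduction of `SpectralNormReg`.

```python
import numpy as np


def spectral_norm_penalty(W, weight=1.0, max_value=1.0):
    """Hypothetical reference helper: weight * max(0, sigma_max(W) - max_value)**2.

    Uses a full SVD for clarity; this is an illustrative sketch, not
    brainstate's implementation.
    """
    # Singular values come back in descending order, so index 0 is sigma_max.
    sigma_max = np.linalg.svd(np.asarray(W, dtype=float), compute_uv=False)[0]
    excess = max(0.0, sigma_max - max_value)  # hinge: zero when within the bound
    return weight * excess ** 2


W = np.array([[2.0, 0.0], [0.0, 0.5]])
# sigma_max(W) = 2 exceeds max_value = 1 by 1, so the penalty is weight * 1**2.
print(spectral_norm_penalty(W))
# With max_value = 3 the matrix is within the bound and the hinge is inactive.
print(spectral_norm_penalty(W, max_value=3.0))
```

A full SVD costs on the order of \(mn\min(m, n)\) operations for an \(m \times n\) matrix, which is why cheaper estimators such as power iteration are commonly used for large weight matrices.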
- Parameters:
weight (float) – Regularization weight \(\lambda\). Default is 1.0.
max_value (float) – Maximum allowed spectral norm \(c\). Default is 1.0.
n_power_iterations (int) – Number of power iterations used to estimate the spectral norm. Default is 1.
fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
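The n_power_iterations parameter refers to power iteration, which estimates \(\sigma_{\max}(W)\) by alternately applying \(W\) and \(W^\top\) to a vector and renormalizing. The following is a minimal NumPy sketch of that standard technique, assuming a random Gaussian start vector; the function `estimate_spectral_norm` is hypothetical, and how brainstate initializes or stores the iteration vector is not specified here.

```python
import numpy as np


def estimate_spectral_norm(W, n_power_iterations=1, seed=0):
    """Hypothetical sketch: estimate sigma_max(W) by power iteration on W^T W."""
    W = np.asarray(W, dtype=float)
    rng = np.random.default_rng(seed)
    v = rng.normal(size=W.shape[1])  # random start vector in the input space
    v /= np.linalg.norm(v)
    for _ in range(n_power_iterations):
        u = W @ v                    # push v through W ...
        u /= np.linalg.norm(u)
        v = W.T @ u                  # ... and back through W^T
        v /= np.linalg.norm(v)
    # v is a unit vector, so ||W v|| is a lower bound on sigma_max(W)
    # that tightens as v aligns with the top right singular vector.
    return np.linalg.norm(W @ v)


rng = np.random.default_rng(42)
A = rng.normal(size=(64, 32))
exact = np.linalg.norm(A, 2)  # exact largest singular value, for comparison
for k in (1, 5, 50):
    print(k, estimate_spectral_norm(A, n_power_iterations=k), exact)
```

Because the estimate is \(\|Wv\|\) for a unit vector \(v\), it never overshoots \(\sigma_{\max}(W)\); how quickly it converges depends on the gap between the two largest singular values. Spectral-normalization implementations commonly keep the iteration vector between training steps, which is why a single iteration per step is often considered sufficient.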
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import SpectralNormReg
>>> reg = SpectralNormReg(weight=1.0, max_value=1.0)
>>> W = jnp.array([[2.0, 0.0], [0.0, 0.5]])  # spectral norm = 2
>>> loss = reg.loss(W)
Notes
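Constraining the spectral norm, as spectral normalization does, is useful for stabilizing GAN training and for controlling the Lipschitz constant of neural networks: the spectral norm of \(W\) is exactly the Lipschitz constant, with respect to the Euclidean norm, of the linear map \(x \mapsto Wx\).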
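The Lipschitz connection can be checked numerically: for a linear map, \(\|Wx - Wy\| = \|W(x - y)\| \le \sigma_{\max}(W)\,\|x - y\|\), so sampling random differences \(x - y\) never yields a gain above \(\sigma_{\max}\). The sketch below is a NumPy illustration with a hypothetical helper `max_observed_gain`; it is not part of brainstate.

```python
import numpy as np


def max_observed_gain(W, n_samples=1000, seed=0):
    """Hypothetical helper: largest ||W d|| / ||d|| over random directions d."""
    rng = np.random.default_rng(seed)
    d = rng.normal(size=(n_samples, W.shape[1]))  # each row plays the role of x - y
    gains = np.linalg.norm(d @ W.T, axis=1) / np.linalg.norm(d, axis=1)
    return gains.max()


rng = np.random.default_rng(1)
W = rng.normal(size=(16, 8))
sigma_max = np.linalg.norm(W, 2)

# The observed gain is bounded by sigma_max and approaches it from below.
print(max_observed_gain(W), sigma_max)

# Rescaling so that sigma_max == 1 makes the linear map 1-Lipschitz.
W_unit = W / sigma_max
print(max_observed_gain(W_unit))  # at most 1, up to floating-point rounding
```

For a stack of such layers separated by 1-Lipschitz activations (ReLU, for example), the network's Lipschitz constant is bounded by the product of the layers' spectral norms, so keeping each \(\sigma_{\max}(W)\) near or below \(c\) limits the network as a whole.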