TotalVariationReg
class brainstate.nn.TotalVariationReg(weight=1.0, order=1, fit_hyper=False)
Total Variation regularization.
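Encourages smoothness by penalizing the absolute differences between adjacent values. For order=1 (first differences):

\[L = \lambda \sum_i |x_{i+1} - x_i|\]

For order=2 (second differences, a discrete second derivative):

\[L = \lambda \sum_i |x_{i+2} - 2x_{i+1} + x_i|\]

Here \(\lambda\) is the weight argument.

Parameters: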
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import TotalVariationReg
>>> reg = TotalVariationReg(weight=0.01, order=1)
>>> value = jnp.array([1.0, 1.2, 1.1, 1.3, 1.2])
>>> loss = reg.loss(value)
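To make the two formulas concrete, the sketch below evaluates them directly with jax.numpy on the same 1-D array used in the example. The function `total_variation_loss` is hypothetical and is not part of brainstate; it only transcribes the equations above. The library's `reg.loss` may differ in details such as how multi-dimensional inputs are handled or which reduction is used.

```python
import jax.numpy as jnp


def total_variation_loss(x, weight=1.0, order=1):
    """Hypothetical transcription of the TV formulas for a 1-D array (not the brainstate API)."""
    x = jnp.asarray(x)
    if order == 1:
        # First differences: x[i+1] - x[i]
        diffs = x[1:] - x[:-1]
    elif order == 2:
        # Second differences: x[i+2] - 2 * x[i+1] + x[i]
        diffs = x[2:] - 2.0 * x[1:-1] + x[:-2]
    else:
        raise ValueError("order must be 1 or 2")
    # weight plays the role of lambda in the formulas
    return weight * jnp.sum(jnp.abs(diffs))


value = jnp.array([1.0, 1.2, 1.1, 1.3, 1.2])
tv1 = total_variation_loss(value, weight=0.01, order=1)
tv2 = total_variation_loss(value, weight=0.01, order=2)
print(float(tv1), float(tv2))
```

By hand: the first differences are 0.2, -0.1, 0.2, -0.1, so the order-1 loss is 0.01 × 0.6 ≈ 0.006. The second differences are -0.3, 0.3, -0.3, so the order-2 loss is 0.01 × 0.9 ≈ 0.009.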
Notes
Total Variation is commonly used in image processing to encourage piecewise constant solutions while preserving edges.
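The edge-preserving behavior can be seen with a small comparison. It assumes the plain order-1 formula above, with \(\lambda = 1\), and sets it against a squared-difference penalty that is chosen only for illustration. A sharp step and a gradual ramp with the same total rise have the same total variation. The squared penalty, by contrast, charges the step five times more than the ramp, so it pushes solutions toward blurred transitions.

```python
import jax.numpy as jnp


def tv(x):
    # Order-1 total variation: sum of absolute first differences
    return jnp.sum(jnp.abs(jnp.diff(x)))


def quadratic(x):
    # Squared-difference penalty, used here only for comparison
    return jnp.sum(jnp.diff(x) ** 2)


step = jnp.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])  # one sharp edge
ramp = jnp.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])  # same rise, spread over 5 steps

print(float(tv(step)), float(tv(ramp)))                # both approximately 1.0
print(float(quadratic(step)), float(quadratic(ramp)))  # approximately 1.0 vs 0.2
```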