TotalVariationReg

class brainstate.nn.TotalVariationReg(weight=1.0, order=1, fit_hyper=False)

Total Variation regularization.

Encourages smoothness by penalizing the absolute differences between adjacent values. For order=1 (first differences):

\[L = \lambda \sum_i |x_{i+1} - x_i|\]

For order=2 (second derivative):

\[L = \lambda \sum_i |x_{i+2} - 2x_{i+1} + x_i|\]

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • order (int) – Order of the difference (1 for first derivative, 2 for second). Default is 1.

  • fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.
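The two penalties above can be sketched in plain JAX with `jnp.diff`, whose `n` argument sets the difference order. This is a minimal sketch of the formulas, not the library's implementation: the helper name `tv_penalty` is hypothetical, and it assumes differences are taken along the last axis and summed, which this page does not specify for multi-dimensional inputs.

```python
import jax.numpy as jnp


def tv_penalty(x, weight=1.0, order=1):
    """Hypothetical helper: weight * sum |order-th difference of x| (last axis)."""
    if order not in (1, 2):
        raise ValueError("order must be 1 or 2")
    # order=1: x[i+1] - x[i];  order=2: x[i+2] - 2*x[i+1] + x[i]
    diffs = jnp.diff(x, n=order, axis=-1)
    return weight * jnp.sum(jnp.abs(diffs))


x = jnp.array([1.0, 1.2, 1.1, 1.3, 1.2])
first = tv_penalty(x, weight=0.01, order=1)   # 0.01 * (0.2 + 0.1 + 0.2 + 0.1)
second = tv_penalty(x, weight=0.01, order=2)  # 0.01 * (0.3 + 0.3 + 0.3)
```

Using `jnp.diff(x, n=2)` gives exactly the second-difference stencil in the order=2 formula, since it applies the first difference twice.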

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import TotalVariationReg
>>> reg = TotalVariationReg(weight=0.01, order=1)
>>> value = jnp.array([1.0, 1.2, 1.1, 1.3, 1.2])
>>> loss = reg.loss(value)

Notes

Total Variation is commonly used in image processing because, with order=1, it encourages piecewise-constant solutions while preserving sharp edges. With order=2, it instead favours piecewise-linear solutions.
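The edge-preserving behaviour can be illustrated numerically. For a monotone signal, the first-order TV equals the total rise regardless of shape, so a sharp step costs no more than a gradual ramp. A quadratic (squared-difference) smoothness penalty, used here only for contrast, charges the step far more. This sketch uses plain `jnp.diff` and does not call the library; the functions `tv` and `quadratic` are hypothetical.

```python
import jax.numpy as jnp

# Two signals rising from 0 to 1 over 11 samples.
step = jnp.concatenate([jnp.zeros(5), jnp.ones(6)])  # one sharp edge
ramp = jnp.linspace(0.0, 1.0, 11)                    # gradual rise


def tv(x):
    # First-order total variation: sum |x[i+1] - x[i]|
    return jnp.sum(jnp.abs(jnp.diff(x)))


def quadratic(x):
    # Squared-difference (Tikhonov-style) penalty, for contrast only
    return jnp.sum(jnp.diff(x) ** 2)


print(tv(step), tv(ramp))                # both approximately 1.0
print(quadratic(step), quadratic(ramp))  # approximately 1.0 vs 10 * 0.1**2 = 0.1
```

Under TV the optimizer has no incentive to smear the step into a ramp, whereas the quadratic penalty strongly prefers the ramp, which is why quadratic smoothing blurs edges.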

loss(value)

Calculate Total Variation regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

Total Variation loss.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
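Because the penalty is built from absolute values, its (sub)gradient has constant magnitude. For order=1, entry j receives weight * (sign(x_j - x_{j-1}) - sign(x_{j+1} - x_j)), so gradient descent pulls each value toward its neighbours by the same amount regardless of how large the jumps are. The sketch below checks this with `jax.grad` on a plain-JAX re-implementation of the first-order formula; `tv1` is hypothetical and is not the library's `loss`.

```python
import jax
import jax.numpy as jnp


def tv1(x, weight=0.01):
    # Hypothetical first-order TV penalty: weight * sum |x[i+1] - x[i]|
    return weight * jnp.sum(jnp.abs(jnp.diff(x)))


x = jnp.array([1.0, 1.2, 1.1, 1.3, 1.2])
g = jax.grad(tv1)(x)
# Each entry is 0.01 * (sign of left difference - sign of right difference);
# the endpoints only have one neighbour.
print(g)  # approximately [-0.01, 0.02, -0.02, 0.02, -0.01]
```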

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Draw a sample from a smooth prior (correlated Gaussian).

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

A smooth sample, formed as the cumulative sum of noise.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
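A cumulative sum of i.i.d. Gaussian noise is a Gaussian random walk: each value differs from its predecessor by a single noise step, so neighbouring entries are strongly correlated and the sample is smooth in the TV sense. This is an illustration only. The helper `smooth_sample` is hypothetical, and the standard-normal noise scale and the axis of accumulation (here the last axis) are assumptions, since this page does not state them.

```python
import jax
import jax.numpy as jnp


def smooth_sample(key, shape):
    # Hypothetical sketch: accumulate i.i.d. N(0, 1) noise along the last axis
    noise = jax.random.normal(key, shape)
    return jnp.cumsum(noise, axis=-1)


key = jax.random.PRNGKey(0)
sample = smooth_sample(key, (3, 100))

# Regenerating the same noise shows that consecutive sample values differ by
# exactly one noise step, which is what makes neighbours correlated.
noise = jax.random.normal(key, (3, 100))
steps = jnp.diff(sample, axis=-1)  # equals noise[:, 1:] up to rounding
```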