L1Reg

class brainstate.nn.L1Reg(weight=1.0, fit_hyper=False)

L1 (Lasso) regularization.

Implements L1 regularization:

\[L = \lambda \sum_i |x_i|\]

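The corresponding prior is the zero-mean Laplace distribution Laplace(0, b) with scale b = 1/weight. Its negative log-density is |x|/b + log(2b), so adding L to a data loss is equivalent, up to an additive constant, to maximum a posteriori (MAP) estimation under this prior.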
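The link between the L1 penalty and the Laplace prior can be checked numerically. The sketch below uses plain JAX (not brainstate) and an arbitrary illustrative weight `lam`; it compares λ Σ|x_i| with the summed negative Laplace log-density at scale b = 1/λ and shows that they differ only by a constant that does not depend on x.

```python
import jax.numpy as jnp
from jax.scipy.stats import laplace

lam = 0.5          # illustrative regularization weight (lambda)
b = 1.0 / lam      # Laplace scale implied by lambda
x = jnp.array([1.0, -2.0, 0.5, 0.0])

# L = lambda * sum_i |x_i|
l1_loss = lam * jnp.sum(jnp.abs(x))

# Negative log-density of x under an i.i.d. Laplace(0, b) prior
neg_log_prior = -jnp.sum(laplace.logpdf(x, loc=0.0, scale=b))

# The gap is n * log(2b), independent of the values in x
gap = neg_log_prior - l1_loss
print(gap, x.size * jnp.log(2.0 * b))
```

Because the gap is constant in x, both objectives have the same minimizer, which is why the L1 penalty is described as the Laplace prior.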

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.

weight

Regularization weight (trainable if fit_hyper=True).

Type:

array_like or ParamState

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import L1Reg
>>> reg = L1Reg(weight=0.01)
>>> value = jnp.array([1.0, -2.0, 0.5])
>>> loss = reg.loss(value)  # Returns 0.01 * (1.0 + 2.0 + 0.5)

Notes

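L1 regularization encourages sparsity in the parameter values: for every nonzero entry, the (sub)gradient of the penalty has the same magnitude (the weight), so small values keep being pushed toward zero instead of shrinking ever more slowly, as they do under L2 regularization.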
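Whether parameters actually reach exactly zero depends on the optimizer, not on `loss` itself. A common way to obtain exact zeros is a proximal-gradient step, whose proximal operator for the L1 penalty is soft-thresholding. The sketch below illustrates that standard technique in plain JAX; `soft_threshold` is a hypothetical helper and is not something `L1Reg` is documented to provide.

```python
import jax.numpy as jnp

def soft_threshold(x, lam, step):
    """Proximal operator of step * lam * sum(|x|), i.e. soft-thresholding.

    Entries with |x_i| <= step * lam become exactly zero; all other
    entries are shrunk toward zero by step * lam.
    """
    t = step * lam
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - t, 0.0)

x = jnp.array([0.05, -0.2, 1.0, -0.01])
out = soft_threshold(x, lam=1.0, step=0.1)
print(out)  # the two entries with |x_i| <= 0.1 are zeroed
```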

loss(value)

Calculate L1 regularization loss.

Parameters:

value (Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity) – Parameter values.

Returns:

L1 loss: self.weight * sum(|value|).
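What a gradient-based optimizer sees from this loss is its gradient, weight * sign(value), for nonzero entries. The sketch below re-implements the documented formula as a standalone function (`l1_loss` and the fixed `weight` are illustrative, not brainstate code) and differentiates it with `jax.grad`. At exactly zero, |x| is not differentiable and JAX returns one particular subgradient, so the example avoids relying on that case.

```python
import jax
import jax.numpy as jnp

weight = 0.01

def l1_loss(value):
    # Mirrors the documented formula: weight * sum(|value|)
    return weight * jnp.sum(jnp.abs(value))

value = jnp.array([1.0, -2.0, 0.5])
grad = jax.grad(l1_loss)(value)
print(grad)  # weight * sign(value)
```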

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity

reset_value()

Return zero (the mode of Laplace(0, b)).

Returns:

Zero.

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity

sample_init(shape)

Sample from the Laplace prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Sample from Laplace(0, 1/weight).
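For reference, drawing from Laplace(0, 1/weight) in plain JAX can look like the sketch below. It assumes the caller supplies an explicit PRNG key; how `sample_init` itself obtains its randomness is not documented here, and `sample_laplace_prior` is a hypothetical helper, not part of brainstate.

```python
import jax
import jax.numpy as jnp

def sample_laplace_prior(key, shape, weight):
    """Draw samples from Laplace(0, 1/weight), the prior matching an L1 penalty."""
    scale = 1.0 / weight
    # jax.random.laplace samples the standard Laplace(0, 1); rescale to scale b.
    return scale * jax.random.laplace(key, shape)

key = jax.random.PRNGKey(0)
samples = sample_laplace_prior(key, (100_000,), weight=2.0)
# For Laplace(0, b), E|x| = b, so the mean absolute value should be near 0.5
print(samples.shape, jnp.mean(jnp.abs(samples)))
```

A larger weight gives a smaller scale, so initial values concentrate more tightly around zero, consistent with a stronger L1 penalty.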

Return type:

Array | ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity