GroupLassoReg

class brainstate.nn.GroupLassoReg(weight=1.0, group_size=1, fit_hyper=False)

Group Lasso regularization.

Implements Group Lasso, which encourages entire groups of parameters to be zero together:

\[L = \lambda \sum_g \sqrt{\sum_{i \in g} x_i^2}\]

where λ is the regularization weight and g indexes the groups of parameters, each containing group_size elements.
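The formula can be sketched in NumPy as below. This is a minimal reference under stated assumptions, not brainstate's implementation: the hypothetical helper group_lasso_penalty flattens the input, splits it into consecutive groups of group_size elements, and requires the total size to be a multiple of group_size.

```python
import numpy as np

def group_lasso_penalty(x, weight=1.0, group_size=1):
    """Hypothetical reference: weight * sum over groups of the group's L2 norm.

    Assumes x is flattened and split into consecutive groups of
    `group_size` elements (not necessarily how brainstate groups values).
    """
    flat = np.ravel(np.asarray(x, dtype=float))
    if flat.size % group_size != 0:
        raise ValueError("value size must be a multiple of group_size")
    groups = flat.reshape(-1, group_size)                # one row per group g
    group_norms = np.sqrt(np.sum(groups ** 2, axis=1))   # sqrt(sum_{i in g} x_i^2)
    return weight * np.sum(group_norms)                  # lambda * sum_g ||x_g||

x = np.array([3.0, 4.0, 0.0, 0.0, 1.0, 0.0])
print(group_lasso_penalty(x, weight=0.5, group_size=2))
```

Here the three groups have norms 5, 0 and 1, so the penalty is 0.5 × 6 = 3.0.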

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • group_size (int) – Size of each group. Default is 1, which reduces the penalty to plain L1 (Lasso) regularization.

  • fit_hyper (bool) – Whether to optimize weight as a trainable parameter. Default is False.
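The group_size=1 case reduces to L1 because a one-element group has norm √(x_i²) = |x_i|. A quick NumPy check of that identity, with values chosen arbitrarily for illustration:

```python
import numpy as np

# With group_size=1 every group holds a single element, so each group norm
# sqrt(x_i**2) equals |x_i| and the penalty becomes weight * sum(|x_i|).
x = np.array([1.0, -0.5, 0.0, 2.5])
weight = 0.1

group_norms = np.sqrt(np.sum(x.reshape(-1, 1) ** 2, axis=1))  # one element per group
group_lasso = weight * group_norms.sum()
l1 = weight * np.abs(x).sum()

print(np.isclose(group_lasso, l1))  # True
```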

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import GroupLassoReg
>>> reg = GroupLassoReg(weight=0.01, group_size=4)
>>> value = jnp.array([1.0, 0.5, -0.5, 0.2, 0.0, 0.0, 0.0, 0.0])
>>> loss = reg.loss(value)
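Assuming value is split into consecutive groups of group_size elements (an assumption about how loss forms the groups), the example evaluates by hand as follows; the second group is all zeros and contributes nothing:

\[L = 0.01\left(\sqrt{1.0^2 + 0.5^2 + (-0.5)^2 + 0.2^2} + \sqrt{0^2 + 0^2 + 0^2 + 0^2}\right) = 0.01\sqrt{1.54} \approx 0.01241\]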

Notes

Group Lasso is useful when parameters naturally form groups (e.g., all weights connecting to one neuron) and you want entire groups to be zeroed out together.
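Whether a group actually reaches exactly zero depends on the optimizer, not only on the penalty. One standard way to obtain exact group sparsity is a proximal-gradient step, where the penalty's proximal operator (block soft-thresholding) shrinks each group's norm by a fixed amount and zeroes any group whose norm is at or below it. The sketch below shows that textbook technique; it is not something brainstate is documented to provide. group_soft_threshold is a hypothetical helper, each row of W is taken as the incoming weights of one neuron, and threshold plays the role of learning rate × weight.

```python
import numpy as np

def group_soft_threshold(W, threshold):
    """Block soft-thresholding: proximal operator of threshold * sum_g ||W[g]||_2.

    Hypothetical helper (not part of brainstate). Each row of W is one group,
    e.g. the incoming weights of one neuron.
    """
    W = np.asarray(W, dtype=float)
    norms = np.linalg.norm(W, axis=1, keepdims=True)   # ||W[g]||_2 per row
    safe = np.where(norms > 0, norms, 1.0)             # avoid 0/0 for all-zero rows
    scale = np.maximum(0.0, 1.0 - threshold / safe)    # per-group shrink factor
    return scale * W

# Rows = neurons, columns = incoming weights.
W = np.array([[3.0, 4.0],    # norm 5.0 -> shrunk but kept
              [0.3, 0.4],    # norm 0.5 -> zeroed
              [0.0, 1.0]])   # norm 1.0 -> zeroed (exactly at the threshold)
W_new = group_soft_threshold(W, threshold=1.0)
print(W_new)
```

Rows 2 and 3 are set exactly to zero as whole groups, while row 1 is rescaled by 1 − 1/5 = 0.8.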

loss(value)

Calculate Group Lasso regularization loss.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Parameter values.

Returns:

Group Lasso loss.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
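Away from zero, the gradient of this loss with respect to a group is λ x_g / ‖x_g‖, a vector of length exactly λ, so every nonzero group feels the same constant pull toward zero no matter how large its norm is. At an all-zero group the norm is not differentiable, and 0 is a valid subgradient there. A NumPy sketch of this subgradient follows; group_lasso_subgradient is a hypothetical helper that assumes consecutive groups and is not part of brainstate's API.

```python
import numpy as np

def group_lasso_subgradient(x, weight, group_size):
    """Hypothetical helper: a subgradient of weight * sum_g ||x_g||_2.

    Nonzero groups get weight * x_g / ||x_g||; all-zero groups get 0, a valid
    element of the subdifferential where the norm is not differentiable.
    """
    groups = np.asarray(x, dtype=float).reshape(-1, group_size)
    norms = np.linalg.norm(groups, axis=1, keepdims=True)
    safe = np.where(norms > 0, norms, 1.0)               # avoid dividing by zero
    grad = np.where(norms > 0, weight * groups / safe, 0.0)
    return grad.reshape(np.shape(x))

g = group_lasso_subgradient([3.0, 4.0, 0.0, 0.0], weight=0.1, group_size=2)
print(g)
```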

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

sample_init(shape)

Sample from the Group Lasso prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

A sample of the given shape, drawn from a Gaussian approximation of the Group Lasso prior.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
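For comparison with the Gaussian approximation, the prior obtained by reading the loss as a negative log-density, p(x) ∝ exp(−λ Σ_g ‖x_g‖), can be sampled exactly: groups are independent, each group's norm follows a Gamma distribution with shape group_size and scale 1/λ, and its direction is uniform on the unit sphere. The sketch below assumes that interpretation of "Group Lasso prior"; sample_group_lasso_prior is a hypothetical helper, not brainstate's sample_init.

```python
import numpy as np

def sample_group_lasso_prior(n_groups, group_size, weight, rng):
    """Hypothetical exact sampler for p(x) ∝ exp(-weight * sum_g ||x_g||_2).

    Each group's radius ||x_g|| ~ Gamma(shape=group_size, scale=1/weight);
    its direction is uniform on the unit sphere (a normalized Gaussian).
    """
    directions = rng.standard_normal((n_groups, group_size))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)  # uniform on sphere
    radii = rng.gamma(shape=group_size, scale=1.0 / weight, size=(n_groups, 1))
    return radii * directions  # shape (n_groups, group_size)

rng = np.random.default_rng(0)
samples = sample_group_lasso_prior(100_000, group_size=4, weight=2.0, rng=rng)
# Group norms should follow Gamma(4, scale=0.5), whose mean is 4 * 0.5 = 2.0.
print(np.linalg.norm(samples, axis=1).mean())
```

With group_size=1 the direction is a random sign and the radius is exponential, so the sampler reduces to a Laplace distribution with scale 1/λ, matching the L1 special case.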