GroupLassoReg
- class brainstate.nn.GroupLassoReg(weight=1.0, group_size=1, fit_hyper=False)
Group Lasso regularization.
Implements Group Lasso, which encourages entire groups of parameters to be zero together:
\[L = \lambda \sum_g \sqrt{\sum_{i \in g} x_i^2}\]
where \(g\) indexes groups of parameters and \(\lambda\) is the regularization weight (`weight`).
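The formula can be written directly in array code. This is a minimal sketch, not brainstate's implementation. It assumes (hypothetically) that groups are consecutive, non-overlapping chunks of `group_size` elements of the flattened parameter, that the parameter size is divisible by `group_size`, and that `weight` plays the role of \(\lambda\). The helper name `group_lasso_penalty` is hypothetical. NumPy is used here; `jax.numpy` provides the same calls.

```python
import numpy as np


def group_lasso_penalty(x, weight=1.0, group_size=1):
    """Hypothetical helper: lambda * sum_g ||x_g||_2 over consecutive groups.

    Assumes groups are consecutive chunks of `group_size` entries of the
    flattened input; this is an illustration, not brainstate's code.
    """
    flat = np.ravel(np.asarray(x, dtype=float))
    if flat.size % group_size != 0:
        raise ValueError("parameter size must be divisible by group_size")
    # Each row of `groups` holds one group of `group_size` consecutive entries.
    groups = flat.reshape(-1, group_size)
    # Euclidean (L2) norm of each group: sqrt(sum_{i in g} x_i^2).
    group_norms = np.sqrt(np.sum(groups ** 2, axis=1))
    # Sum over groups and scale by the regularization weight lambda.
    return weight * np.sum(group_norms)


value = np.array([1.0, 0.5, -0.5, 0.2, 0.0, 0.0, 0.0, 0.0])
print(group_lasso_penalty(value, weight=0.01, group_size=4))
```

With these values the first group contributes \(\sqrt{1.54} \approx 1.241\) and the all-zero second group contributes nothing, so the sketch returns about 0.0124. With `group_size=1` every group is a single entry and \(\sqrt{x_i^2} = |x_i|\), so the sketch reduces to ordinary L1 (Lasso) regularization.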
- Parameters:
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import GroupLassoReg
>>> reg = GroupLassoReg(weight=0.01, group_size=4)
>>> value = jnp.array([1.0, 0.5, -0.5, 0.2, 0.0, 0.0, 0.0, 0.0])
>>> loss = reg.loss(value)
Notes
Group Lasso is useful when parameters naturally form groups (e.g., all weights connecting to one neuron) and you want entire groups to be zeroed out together.
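The "all weights connecting to one neuron" grouping can be illustrated on a small dense weight matrix. The sketch below is an illustration under stated assumptions, not brainstate code: it takes `W` with shape `(n_in, n_out)`, treats each column (all incoming weights of one output neuron) as a group, and applies block soft-thresholding, the standard proximal operator of the group Lasso penalty, to show how a whole group reaches exactly zero. The helper names `neuron_group_norms` and `group_soft_threshold` are hypothetical, and whether `GroupLassoReg` groups matrix entries this way is not stated above.

```python
import numpy as np


def neuron_group_norms(W):
    """Hypothetical helper: L2 norm of each column of W, i.e. of all
    incoming weights of one output neuron (W has shape (n_in, n_out))."""
    return np.sqrt(np.sum(W ** 2, axis=0))


def group_soft_threshold(W, threshold):
    """Proximal operator of threshold * sum_j ||W[:, j]||_2 (block soft-thresholding).

    In proximal gradient descent, threshold = step_size * lambda. Each column is
    shrunk toward zero by `threshold` in L2 norm; columns whose norm is
    <= threshold become exactly zero as a whole.
    """
    norms = neuron_group_norms(W)
    # Per-column factor max(0, 1 - threshold / ||w_j||); the inner maximum
    # avoids division by zero for columns that are already all zero.
    scale = np.maximum(0.0, 1.0 - threshold / np.maximum(norms, 1e-12))
    return W * scale  # broadcasts one scale factor per column


W = np.array([[0.9, 0.05, 0.0],
              [0.4, -0.02, 0.3],
              [-0.2, 0.01, 0.4]])
W_new = group_soft_threshold(W, threshold=0.1)
print(neuron_group_norms(W))
print(neuron_group_norms(W_new))
```

Here the middle neuron's incoming weights have norm about 0.055, below the threshold 0.1, so its whole column is set to zero, while the other two columns are only shrunk (their norms drop by exactly 0.1, to about 0.905 and 0.4). An element-wise L1 soft-threshold, by contrast, acts on each entry independently and has no notion of a neuron.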