ElasticNetReg

class brainstate.nn.ElasticNetReg(l1_weight=1.0, l2_weight=1.0, alpha=0.5, fit_hyper=False)

Elastic Net regularization (combination of L1 and L2).

Implements a weighted combination of L1 and L2 regularization:

\[L = \alpha \cdot \lambda_1 \sum_i |x_i| + (1 - \alpha) \cdot \lambda_2 \sum_i x_i^2\]

where \(\alpha \in [0, 1]\) controls the mix between L1 and L2.
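As a minimal sketch, the formula above can be evaluated directly with `jax.numpy`. The function `elastic_net_penalty` is a hypothetical helper written for illustration. It mirrors the equation term by term and is not taken from brainstate's source.

```python
import jax.numpy as jnp


def elastic_net_penalty(x, l1_weight=1.0, l2_weight=1.0, alpha=0.5):
    """Hypothetical helper: alpha * l1_weight * sum|x| + (1 - alpha) * l2_weight * sum x^2."""
    x = jnp.asarray(x)
    l1_term = alpha * l1_weight * jnp.sum(jnp.abs(x))      # sparsity-inducing part
    l2_term = (1.0 - alpha) * l2_weight * jnp.sum(x ** 2)  # smooth shrinkage part
    return l1_term + l2_term


x = jnp.array([1.0, -2.0, 0.5])
# sum|x| = 3.5 and sum x^2 = 5.25, so with the defaults: 0.5 * 3.5 + 0.5 * 5.25 = 4.375
print(elastic_net_penalty(x))
```

The sums run over every element of `x`, so the result is a scalar for any input shape.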

Parameters:
  • l1_weight (float) – Weight for L1 regularization. Default is 1.0.

  • l2_weight (float) – Weight for L2 regularization. Default is 1.0.

  • alpha (float) – Mixing ratio between L1 and L2 (0 = pure L2, 1 = pure L1). Default is 0.5.

  • fit_hyper (bool) – Whether to treat the regularization hyperparameters as trainable parameters instead of fixed constants. Default is False.
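At the endpoints of `alpha`, one term of the formula vanishes: `alpha=1.0` leaves \(\lambda_1 \sum_i |x_i|\) and `alpha=0.0` leaves \(\lambda_2 \sum_i x_i^2\). The check below uses a hypothetical standalone helper (not brainstate's API) that follows the documented formula, to confirm this numerically.

```python
import jax.numpy as jnp


def elastic_net_penalty(x, l1_weight=1.0, l2_weight=1.0, alpha=0.5):
    # Hypothetical standalone helper mirroring the documented formula.
    x = jnp.asarray(x)
    return (alpha * l1_weight * jnp.sum(jnp.abs(x))
            + (1.0 - alpha) * l2_weight * jnp.sum(x ** 2))


x = jnp.array([1.0, -2.0, 0.5])
pure_l1 = elastic_net_penalty(x, l1_weight=0.1, l2_weight=0.2, alpha=1.0)  # 0.1 * 3.5
pure_l2 = elastic_net_penalty(x, l1_weight=0.1, l2_weight=0.2, alpha=0.0)  # 0.2 * 5.25
print(pure_l1, pure_l2)
```

Because `alpha` multiplies `l1_weight` and `(1 - alpha)` multiplies `l2_weight`, the default `alpha=0.5` halves both weights under the formula as written, so the effective coefficients are \(0.5\lambda_1\) and \(0.5\lambda_2\).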

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import ElasticNetReg
>>> reg = ElasticNetReg(l1_weight=0.01, l2_weight=0.01, alpha=0.5)
>>> value = jnp.array([1.0, -2.0, 0.5])
>>> loss = reg.loss(value)
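In training, a penalty like this is usually added to a data-fitting loss and differentiated together with it. The sketch below assumes a toy linear least-squares model and reimplements the penalty in plain JAX instead of calling `reg.loss`, so it runs without brainstate. The names `elastic_net_penalty` and `total_loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def elastic_net_penalty(w, l1_weight=0.01, l2_weight=0.01, alpha=0.5):
    # Hypothetical standalone helper mirroring the documented formula.
    return (alpha * l1_weight * jnp.sum(jnp.abs(w))
            + (1.0 - alpha) * l2_weight * jnp.sum(w ** 2))


def total_loss(w, X, y):
    # Mean squared error of a linear model plus the elastic net penalty on w.
    mse = jnp.mean((X @ w - y) ** 2)
    return mse + elastic_net_penalty(w)


X = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 1.0, -1.0],
               [1.0, 1.0, 0.0]])
y = jnp.array([1.0, 0.0, 2.0])
w = jnp.array([1.0, -2.0, 0.5])

loss, grad = jax.value_and_grad(total_loss)(w, X, y)
print(loss, grad)
```

The penalty contributes \(\alpha\lambda_1\,\mathrm{sign}(w_i) + 2(1-\alpha)\lambda_2 w_i\) to each gradient component. At \(w_i = 0\) the L1 term is not differentiable, and JAX returns a subgradient there.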

Notes

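Elastic Net combines the sparsity-inducing property of L1 with the stability of L2 regularization, which makes it useful when features are correlated: L1 alone tends to keep one feature from a correlated group and zero out the rest somewhat arbitrarily, while the added L2 term encourages correlated features to be shrunk together.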
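One way to see both properties is through the proximal operator of the penalty, the step used by proximal-gradient methods such as ISTA. For a per-element penalty \(a|x| + b x^2\) with \(a = \alpha\lambda_1\) and \(b = (1-\alpha)\lambda_2\), the closed-form step soft-thresholds (producing exact zeros, the L1 effect) and then divides by \(1 + 2tb\) for step size \(t\) (uniform shrinkage, the L2 effect). This is a textbook derivation, not something brainstate's loss method performs, and `elastic_net_prox` is a hypothetical name.

```python
import jax.numpy as jnp


def elastic_net_prox(v, step, l1_weight=1.0, l2_weight=1.0, alpha=0.5):
    # Hypothetical helper: argmin_x 0.5 * (x - v)^2 + step * (a * |x| + b * x^2)
    a = alpha * l1_weight          # coefficient of |x|
    b = (1.0 - alpha) * l2_weight  # coefficient of x^2
    shrunk = jnp.sign(v) * jnp.maximum(jnp.abs(v) - step * a, 0.0)  # soft-threshold (L1)
    return shrunk / (1.0 + 2.0 * step * b)                          # ridge scaling (L2)


v = jnp.array([0.3, -0.1, 2.0, -1.5])
print(elastic_net_prox(v, step=0.5))
```

With the defaults and `step=0.5`, entries with magnitude at most 0.25 become exactly zero, and the survivors are scaled by 1/1.5.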

loss(value)

Calculate Elastic Net regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

Combined L1 and L2 loss.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
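The loss method takes a single value. When several parameter arrays should share one penalty, one option is to apply the same formula to each leaf of a pytree and sum the results. The helpers below are a hypothetical pure-JAX sketch. They assume, as the \(\sum_i\) in the formula suggests, that the loss reduces over every element to a scalar.

```python
import jax
import jax.numpy as jnp


def elastic_net_penalty(x, l1_weight=1.0, l2_weight=1.0, alpha=0.5):
    # Hypothetical standalone helper mirroring the documented formula;
    # jnp.sum reduces over every element, whatever the array's shape.
    return (alpha * l1_weight * jnp.sum(jnp.abs(x))
            + (1.0 - alpha) * l2_weight * jnp.sum(x ** 2))


def tree_penalty(params, **kwargs):
    # Apply the penalty to each leaf of a pytree and add the per-leaf scalars.
    leaves = jax.tree_util.tree_leaves(params)
    return sum(elastic_net_penalty(leaf, **kwargs) for leaf in leaves)


params = {"w": jnp.array([[1.0, -2.0], [0.5, 0.0]]), "b": jnp.array([3.0])}
print(tree_penalty(params))
```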

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Sample from a mixture prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Sample from the mixture distribution.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
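The docstring does not say which mixture sample_init draws from. In the Bayesian reading of these penalties, the L1 term corresponds to a Laplace prior and the L2 term to a Gaussian prior. The sketch below assumes, without confirmation from the source, a two-component mixture that picks the Laplace component with probability alpha. The function name, the scale parameters, and the mixing rule are all hypothetical and may differ from brainstate's sample_init.

```python
import jax
import jax.numpy as jnp


def sample_mixture_init(key, shape, alpha=0.5, laplace_scale=1.0, normal_std=1.0):
    # Hypothetical sampler: each element comes from a Laplace component with
    # probability alpha, otherwise from a Gaussian component.
    k_pick, k_lap, k_norm = jax.random.split(key, 3)
    use_laplace = jax.random.bernoulli(k_pick, p=alpha, shape=shape)
    laplace = laplace_scale * jax.random.laplace(k_lap, shape)
    normal = normal_std * jax.random.normal(k_norm, shape)
    return jnp.where(use_laplace, laplace, normal)


sample = sample_mixture_init(jax.random.PRNGKey(0), (4, 3), alpha=0.5)
print(sample.shape)  # (4, 3)
```

Drawing both components in full and selecting with `jnp.where` keeps every shape static, which suits `jax.jit`, at the cost of generating values that are then discarded.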