HuberReg

class brainstate.nn.HuberReg(weight=1.0, delta=1.0, fit_hyper=False)

Huber regularization (robust regularization).

Implements regularization using the Huber loss function, which is quadratic (L2-like) for |x_i| ≤ δ and linear (L1-like) for |x_i| > δ:

\[L = \lambda \sum_i \begin{cases} \frac{1}{2} x_i^2 & \text{if } |x_i| \leq \delta \\ \delta \left(|x_i| - \frac{1}{2}\delta\right) & \text{otherwise} \end{cases}\]

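Writing h_δ for the per-element term (a label introduced here for convenience), so that L = λ Σ_i h_δ(x_i), the term and its derivative are:

\[h_\delta(x) = \begin{cases} \frac{1}{2} x^2 & \text{if } |x| \leq \delta \\ \delta \left(|x| - \frac{1}{2}\delta\right) & \text{otherwise} \end{cases} \qquad h_\delta'(x) = \begin{cases} x & \text{if } |x| \leq \delta \\ \delta \, \operatorname{sign}(x) & \text{otherwise} \end{cases}\]

At |x| = δ both branches give h_δ = δ²/2 and h_δ' = ±δ, so the penalty and its first derivative are continuous, and |h_δ'(x)| ≤ δ everywhere.
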
Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • delta (float) – Threshold for switching between L2 and L1 behavior. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import HuberReg
>>> reg = HuberReg(weight=0.01, delta=1.0)
>>> value = jnp.array([0.5, 2.0, -3.0])
>>> loss = reg.loss(value)
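As a cross-check of the example above, the formula can be written directly in jax.numpy. This is a minimal sketch for illustration only, assuming the loss is exactly the weighted sum in the formula; `huber_penalty` is a hypothetical helper, not part of brainstate, and `HuberReg.loss` may differ in details such as unit (Quantity) handling.

```python
import jax.numpy as jnp


def huber_penalty(x, weight=1.0, delta=1.0):
    """Hypothetical helper: weight * sum_i h_delta(x_i), written from the formula."""
    abs_x = jnp.abs(x)
    # Quadratic branch for |x_i| <= delta, linear branch otherwise.
    per_element = jnp.where(
        abs_x <= delta,
        0.5 * x ** 2,
        delta * (abs_x - 0.5 * delta),
    )
    return weight * jnp.sum(per_element)


value = jnp.array([0.5, 2.0, -3.0])
# Per-element terms: 0.125 (quadratic), 1.5 and 2.5 (linear); their sum is 4.125.
print(huber_penalty(value, weight=0.01, delta=1.0))  # approximately 0.04125
```

Only 0.5 falls inside the threshold δ = 1.0; the other two entries are penalized linearly.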

Notes

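Huber regularization is more robust to outliers than L2: its gradient is bounded in magnitude by λδ, so large parameter values are penalized less aggressively. It is also better behaved than L1 near zero: the penalty is quadratic there, so it is differentiable at zero and its gradient shrinks toward zero instead of keeping a constant magnitude.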
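The difference shows up directly in the gradients. Below is a minimal sketch comparing plain L2, L1 and Huber penalties with `jax.grad`; the three helper functions are hypothetical illustrations (with λ = 1 and δ = 1), not brainstate APIs.

```python
import jax
import jax.numpy as jnp


def l2_penalty(x):
    return 0.5 * jnp.sum(x ** 2)


def l1_penalty(x):
    return jnp.sum(jnp.abs(x))


def huber_penalty(x, delta=1.0):
    abs_x = jnp.abs(x)
    return jnp.sum(jnp.where(abs_x <= delta, 0.5 * x ** 2, delta * (abs_x - 0.5 * delta)))


x = jnp.array([0.01, 0.5, 3.0, 10.0])
grad_l2 = jax.grad(l2_penalty)(x)        # equals x: grows without bound
grad_l1 = jax.grad(l1_penalty)(x)        # magnitude 1 everywhere, even near zero
grad_huber = jax.grad(huber_penalty)(x)  # equals x inside delta, capped at delta outside
print(grad_l2, grad_l1, grad_huber)
```

The Huber gradient matches the L2 gradient for the two small entries and is clipped to δ for the two large ones, so a single large parameter cannot dominate the update.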

loss(value)

Calculate Huber regularization loss.

Parameters:

value (Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity) – Parameter values.

Returns:

The weighted Huber penalty L defined above.

Return type:

Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity

sample_init(shape)

Sample initial values approximately from the Huber prior.

Parameters:

shape (int | Sequence[int] | numpy.integer | Sequence[numpy.integer]) – Shape of the sample.

Returns:

An approximate sample from the Huber prior, drawn from a Gaussian distribution used as a stand-in.

Return type:

Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | Quantity