HorseshoeReg

class brainstate.nn.HorseshoeReg(weight=1.0, tau=1.0, fit_hyper=False)

Horseshoe prior regularization (strong sparsity with heavy tails).

Implements an approximation to the horseshoe prior using a log-penalty formulation:

\[L = \lambda \sum_i \log\left(1 + (x_i / \tau)^2\right)\]

Parameters:
  • weight (float) – Regularization weight, the λ in the formula above. Default is 1.0.

  • tau (float) – Global scale parameter, the τ in the formula above. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
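The penalty formula can be evaluated directly with `jax.numpy` to see how each coefficient contributes. This is a standalone reference computation, not the library's `loss` method; the helper name `horseshoe_penalty` is hypothetical, and it assumes the penalty is summed over all elements of the input, as the formula suggests.

```python
import jax.numpy as jnp


def horseshoe_penalty(x, weight=1.0, tau=1.0):
    """Hypothetical helper: weight * sum_i log(1 + (x_i / tau)**2) over all elements."""
    x = jnp.asarray(x, dtype=jnp.float32)
    # log1p(u) computes log(1 + u) and stays accurate when u is small
    return weight * jnp.sum(jnp.log1p((x / tau) ** 2))


value = jnp.array([0.01, 0.5, 2.0])
# Per-element terms for tau = 0.1
per_element = jnp.log1p((value / 0.1) ** 2)
print(per_element)                                     # ≈ [0.00995, 3.258, 5.994]
print(horseshoe_penalty(value, weight=1.0, tau=0.1))   # ≈ 9.262
```

With `tau=0.1`, the coefficient 0.01 contributes almost nothing, while 0.5 and 2.0 dominate; note that 2.0 is four times larger than 0.5 but its term is less than twice as large, reflecting the logarithmic growth.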

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import HorseshoeReg
>>> reg = HorseshoeReg(weight=1.0, tau=0.1)
>>> value = jnp.array([0.01, 0.5, 2.0])
>>> loss = reg.loss(value)

Notes

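The horseshoe prior shrinks small coefficients strongly toward zero while leaving large coefficients relatively unshrunk. In the penalty above, each term behaves like (x_i / τ)² when |x_i| ≪ τ but grows only logarithmically when |x_i| ≫ τ, so the global scale τ sets where the transition between the two regimes happens. This makes it useful for sparse signal recovery and variable selection.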
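The shrinkage behaviour can be read off the gradient of a single term, d/dx log(1 + (x/τ)²) = 2x / (τ² + x²). The sketch below checks this with `jax.grad` on a hypothetical per-element penalty (weight 1, τ = 0.1); it illustrates the formula only and does not call the library.

```python
import jax
import jax.numpy as jnp

tau = 0.1


def penalty(x):
    # Single-element term of the log penalty, with weight = 1
    return jnp.log1p((x / tau) ** 2)


grad_fn = jax.grad(penalty)
for x in [0.01, 0.1, 0.5, 2.0]:
    g = grad_fn(x)
    # Closed form of the same derivative: 2x / (tau^2 + x^2)
    print(f"x={x:5.2f}  grad={float(g):8.4f}  closed_form={2 * x / (tau**2 + x**2):8.4f}")
```

The pull toward zero peaks at |x| = τ (where the gradient is 1/τ) and then decays like 2/|x|, so large coefficients feel only a weak shrinking force. Below τ the term acts like a quadratic penalty with curvature 2/τ², which is much stiffer than an unscaled L2 penalty when τ is small.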

loss(value)

Compute the horseshoe regularization loss for the given parameter values.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

The horseshoe-like penalty L defined above.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

reset_value()

Return zero, the mode of the horseshoe prior.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Draw a sample from an approximation to the horseshoe prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

A sample of the requested shape, drawn from the approximate horseshoe prior.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
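The docstring does not say which approximation `sample_init` uses. For comparison, the standard hierarchical form of the horseshoe prior draws a local scale λ_i from a half-Cauchy(0, 1) distribution and then x_i ~ Normal(0, (τ λ_i)²). The sketch below implements that hierarchy with `jax.random`; the helper `sample_horseshoe` is hypothetical and is not the brainstate implementation.

```python
import jax
import jax.numpy as jnp


def sample_horseshoe(key, shape, tau=1.0):
    """Hypothetical helper: draws from the hierarchical horseshoe prior."""
    key_lam, key_z = jax.random.split(key)
    # Local scales lambda_i ~ HalfCauchy(0, 1): absolute value of a standard Cauchy draw
    lam = jnp.abs(jax.random.cauchy(key_lam, shape))
    # x_i | lambda_i ~ Normal(0, (tau * lambda_i)^2)
    z = jax.random.normal(key_z, shape)
    return tau * lam * z


samples = sample_horseshoe(jax.random.PRNGKey(0), (10_000,), tau=0.1)
print(jnp.median(jnp.abs(samples)))  # most draws are small relative to the largest
print(jnp.max(jnp.abs(samples)))     # the heavy Cauchy tails produce occasional large draws
```

Because the half-Cauchy local scales have heavy tails, most draws concentrate near zero while a few are very large, which is the same "strong sparsity with heavy tails" behaviour the penalty encodes.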