HorseshoeReg
class brainstate.nn.HorseshoeReg(weight=1.0, tau=1.0, fit_hyper=False)
Horseshoe prior regularization (strong sparsity with heavy tails).
Implements an approximation to the horseshoe prior using a log-penalty formulation:
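\[L = \lambda \sum_i \log\left(1 + (x_i / \tau)^2\right)\]
where \(\lambda\) is the regularization strength set by weight and \(\tau\) is the scale set by tau. Coefficients with \(|x_i| \ll \tau\) incur a nearly quadratic penalty, while the penalty on larger coefficients grows only logarithmically.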
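To make the formula concrete, here is a minimal standalone sketch that evaluates the penalty in plain Python. It is not the brainstate implementation; the helper name horseshoe_penalty is hypothetical, and it assumes that weight plays the role of \(\lambda\) and tau the role of \(\tau\).

```python
import math

def horseshoe_penalty(values, weight=1.0, tau=1.0):
    """Hypothetical helper: weight * sum_i log(1 + (x_i / tau)**2)."""
    if tau <= 0:
        raise ValueError("tau must be positive")
    # log1p keeps precision when (x_i / tau)**2 is very small
    return weight * sum(math.log1p((x / tau) ** 2) for x in values)

# Same inputs as the doctest in the Examples section
loss = horseshoe_penalty([0.01, 0.5, 2.0], weight=1.0, tau=0.1)
```

With tau=0.1, the entry 2.0 is four times larger than 0.5, yet its contribution, log(401), is less than twice that of 0.5, log(26). This logarithmic growth is what keeps large coefficients from being shrunk aggressively.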
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import HorseshoeReg
>>> reg = HorseshoeReg(weight=1.0, tau=0.1)
>>> value = jnp.array([0.01, 0.5, 2.0])
>>> loss = reg.loss(value)
Notes
The horseshoe prior provides strong shrinkage toward zero for small coefficients while leaving large coefficients relatively unshrunk. This is useful for sparse signal recovery and variable selection.
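One way to see this behaviour is through the gradient of the penalty, \(\partial L / \partial x_i = 2\lambda x_i / (\tau^2 + x_i^2)\). The sketch below compares it with the subgradient of an L1 penalty \(\lambda |x_i|\), which is used here purely as a familiar baseline. Both helper names are hypothetical, and this is an illustration of the formula above rather than the library's code.

```python
def horseshoe_grad(x, weight=1.0, tau=1.0):
    """Hypothetical helper: derivative of weight * log(1 + (x / tau)**2) w.r.t. x."""
    return 2.0 * weight * x / (tau ** 2 + x ** 2)

def l1_grad(x, weight=1.0):
    """Hypothetical helper: subgradient of weight * |x|, taking 0 at x == 0."""
    if x == 0:
        return 0.0
    return weight if x > 0 else -weight

tau = 0.1
for x in [0.01, 0.1, 1.0, 10.0]:
    g = horseshoe_grad(x, tau=tau)
    # g / x is the pull per unit of magnitude: large for |x| << tau, tiny for |x| >> tau
    print(f"x={x:<5}  horseshoe grad={g:.4f}  grad/x={g / x:.2f}  L1 grad={l1_grad(x):.1f}")
```

The horseshoe gradient peaks at \(|x_i| = \tau\) with value \(\lambda / \tau\) and then decays roughly like \(2\lambda / |x_i|\). The L1 gradient, by contrast, stays at \(\lambda\) for every nonzero coefficient, so it keeps pulling large coefficients toward zero.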