SpikeAndSlabReg
- class brainstate.nn.SpikeAndSlabReg(weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5, fit_hyper=False)
Spike-and-slab prior regularization (variable selection).
Implements a soft approximation to the spike-and-slab mixture prior:
\[L = -\lambda \sum_i \log\left(\pi \cdot \text{spike}(x_i) + (1-\pi) \cdot \text{slab}(x_i)\right)\]
where spike is a narrow Gaussian with scale spike_scale, slab is a wide Gaussian with scale slab_scale, \(\lambda\) is weight, and \(\pi\) is pi.
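The loss formula can be sketched directly in NumPy. This is a minimal illustration and not brainstate's implementation. It assumes spike and slab are zero-mean Gaussian densities whose standard deviations are spike_scale and slab_scale. The helper names gaussian_logpdf and spike_and_slab_loss are hypothetical. Every NumPy call used also exists in jax.numpy, so swapping the import should work unchanged.

```python
import numpy as np


def gaussian_logpdf(x, scale):
    """Log density of a zero-mean Gaussian with standard deviation `scale`."""
    return -0.5 * (x / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def spike_and_slab_loss(x, weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5):
    """L = -weight * sum_i log(pi * spike(x_i) + (1 - pi) * slab(x_i)).

    Assumption: spike and slab are zero-mean Gaussian densities whose
    standard deviations are spike_scale and slab_scale.
    """
    x = np.asarray(x, dtype=float)
    # Mix the two components in log space. logaddexp avoids underflow when a
    # coefficient lies far in the tail of the narrow spike (e.g. x = 0.5).
    log_mix = np.logaddexp(
        np.log(pi) + gaussian_logpdf(x, spike_scale),
        np.log1p(-pi) + gaussian_logpdf(x, slab_scale),
    )
    return -weight * np.sum(log_mix)


value = np.array([0.001, 0.5, -0.002])
print(spike_and_slab_loss(value))
```

A narrow spike has a density greater than 1 near zero, so the loss can be negative for near-zero coefficients. Only differences in the loss matter for optimization.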
- Parameters:
  - weight (float) – Regularization weight (lambda). Default is 1.0.
  - spike_scale (float) – Scale of the spike (narrow) component. Default is 0.01.
  - slab_scale (float) – Scale of the slab (wide) component. Default is 1.0.
  - pi (float) – Mixture weight (probability of spike). Default is 0.5.
  - fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import SpikeAndSlabReg
>>> reg = SpikeAndSlabReg(weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5)
>>> value = jnp.array([0.001, 0.5, -0.002])
>>> loss = reg.loss(value)
Notes
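Spike-and-slab priors are a standard choice for sparse Bayesian learning and variable selection. The spike component pulls small coefficients toward zero, while the slab lets large coefficients remain with little penalty. Both components are Gaussian in this soft approximation. As a result, it shrinks coefficients close to zero rather than setting them exactly to zero.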
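For variable selection, a common follow-up under the same mixture is to compute each coefficient's posterior probability of belonging to the spike, r_i = π·spike(x_i) / (π·spike(x_i) + (1-π)·slab(x_i)). Coefficients with a low r_i are then kept. The sketch below is hypothetical and not part of SpikeAndSlabReg. It assumes zero-mean Gaussian components with standard deviations spike_scale and slab_scale. The 0.5 threshold is chosen only for illustration.

```python
import numpy as np


def gaussian_logpdf(x, scale):
    """Log density of a zero-mean Gaussian with standard deviation `scale`."""
    return -0.5 * (x / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def spike_responsibility(x, spike_scale=0.01, slab_scale=1.0, pi=0.5):
    """Hypothetical helper: posterior probability that each x_i came from the spike."""
    x = np.asarray(x, dtype=float)
    log_spike = np.log(pi) + gaussian_logpdf(x, spike_scale)
    log_slab = np.log1p(-pi) + gaussian_logpdf(x, slab_scale)
    # r_i = spike term / (spike term + slab term), evaluated in log space
    return np.exp(log_spike - np.logaddexp(log_spike, log_slab))


value = np.array([0.001, 0.5, -0.002])
r = spike_responsibility(value)
# Keep coefficients that are more likely to come from the slab (illustrative 0.5 cutoff)
selected = r < 0.5
print(r.round(3), selected)
```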