SpikeAndSlabReg

class brainstate.nn.SpikeAndSlabReg(weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5, fit_hyper=False)

Spike-and-slab prior regularization (variable selection).

Implements a soft approximation to the spike-and-slab mixture prior:

\[L = -\lambda \sum_i \log\left(\pi \cdot \text{spike}(x_i) + (1-\pi) \cdot \text{slab}(x_i)\right)\]

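where λ is weight and π is pi. The spike is a narrow Gaussian centered at zero, a smooth stand-in for the point mass at zero used by the exact spike-and-slab prior that keeps the loss differentiable. The slab is a wide Gaussian. Their widths are set by spike_scale and slab_scale, respectively.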
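The formula above can be evaluated directly. The following is a minimal NumPy sketch, assuming that spike and slab are zero-mean Gaussian densities whose standard deviations are spike_scale and slab_scale. The helpers gaussian_logpdf and spike_and_slab_loss are hypothetical names for illustration and are not part of brainstate. The library may also drop constant terms or normalize differently, so absolute values need not match reg.loss. The mixture is computed in log space with np.logaddexp, which keeps the loss finite even for values far from zero, where the spike density underflows to 0.

```python
import numpy as np


def gaussian_logpdf(x, scale):
    """Log density of a zero-mean Gaussian with standard deviation `scale`."""
    return -0.5 * (x / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def spike_and_slab_loss(x, weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5):
    """Hypothetical sketch of the formula above (not the brainstate source).

    Assumes spike(x) = N(x; 0, spike_scale**2) and slab(x) = N(x; 0, slab_scale**2).
    """
    x = np.asarray(x, dtype=float)
    # log(pi * spike(x) + (1 - pi) * slab(x)), evaluated stably in log space
    log_mix = np.logaddexp(
        np.log(pi) + gaussian_logpdf(x, spike_scale),
        np.log1p(-pi) + gaussian_logpdf(x, slab_scale),
    )
    # Sum over all elements and scale by the regularization weight (lambda)
    return -weight * np.sum(log_mix)


values = np.array([0.001, 0.5, -0.002])
print(spike_and_slab_loss(values))
```

Values near zero receive a lower loss than values such as 0.5, because the tall, narrow spike assigns them a much higher density.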

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • spike_scale (float) – Scale of the spike (narrow) component. Default is 0.01.

  • slab_scale (float) – Scale of the slab (wide) component. Default is 1.0.

  • pi (float) – Mixture weight π, the prior probability that a value comes from the spike. Default is 0.5.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import SpikeAndSlabReg
>>> reg = SpikeAndSlabReg(weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5)
>>> value = jnp.array([0.001, 0.5, -0.002])
>>> loss = reg.loss(value)

Notes

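Spike-and-slab priors are a standard tool for sparse Bayesian learning and variable selection. The spike component pulls small coefficients toward zero, while the slab lets large coefficients stay large with only a mild penalty. Because the spike here is a narrow Gaussian rather than a point mass, the penalty drives values close to zero but does not by itself make them exactly zero.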
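Differentiating one term of the loss shows why the spike drives small values toward zero while the slab leaves large ones nearly alone. Let r(x) = π·spike(x) / (π·spike(x) + (1-π)·slab(x)) be the posterior probability that x belongs to the spike. The per-element gradient is then λ(r(x)·x/σ_spike² + (1-r(x))·x/σ_slab²). This is a blend of two ridge (L2) pulls, weighted by r(x). The NumPy sketch below computes it under the same assumption of zero-mean Gaussian components with standard deviations spike_scale and slab_scale. The helpers spike_responsibility and penalty_gradient are hypothetical and are not brainstate functions.

```python
import numpy as np


def gaussian_logpdf(x, scale):
    """Log density of a zero-mean Gaussian with standard deviation `scale`."""
    return -0.5 * (x / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def spike_responsibility(x, spike_scale=0.01, slab_scale=1.0, pi=0.5):
    """Hypothetical helper: posterior probability that x came from the spike."""
    log_spike = np.log(pi) + gaussian_logpdf(x, spike_scale)
    log_slab = np.log1p(-pi) + gaussian_logpdf(x, slab_scale)
    return np.exp(log_spike - np.logaddexp(log_spike, log_slab))


def penalty_gradient(x, weight=1.0, spike_scale=0.01, slab_scale=1.0, pi=0.5):
    """Hypothetical helper: d/dx of -weight * log(pi*spike(x) + (1-pi)*slab(x))."""
    r = spike_responsibility(x, spike_scale, slab_scale, pi)
    # Stiff spike pull (1/spike_scale**2) blended with weak slab pull (1/slab_scale**2)
    return weight * (r * x / spike_scale**2 + (1.0 - r) * x / slab_scale**2)


x = np.array([0.001, 0.5])
print(spike_responsibility(x))  # near 1 for the tiny value, near 0 for 0.5
print(penalty_gradient(x))
```

For the tiny value, r(x) is close to 1, so the stiff 1/σ_spike² term dominates and the value is pulled hard toward zero. For 0.5, r(x) is essentially 0, and only the weak slab term remains.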

loss(value)

Calculate the spike-and-slab regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

The weighted negative log of the spike-and-slab mixture density, summed over all elements of value.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

reset_value()

Return zero (the mode of the spike).

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Draw random values from the spike-and-slab mixture.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

An array of the given shape sampled from the spike-and-slab mixture.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
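One common way to sample from a two-component mixture is to pick a component for each element and then draw from that component. The NumPy sketch below follows this approach, assuming the same zero-mean Gaussian components with standard deviations spike_scale and slab_scale. The function sample_spike_and_slab and its seed argument are hypothetical and may not match how sample_init draws its values in brainstate.

```python
import numpy as np


def sample_spike_and_slab(shape, spike_scale=0.01, slab_scale=1.0, pi=0.5, seed=0):
    """Hypothetical sketch: draw from pi*N(0, spike_scale**2) + (1-pi)*N(0, slab_scale**2)."""
    rng = np.random.default_rng(seed)
    # Step 1: per element, choose the spike with probability pi, else the slab.
    from_spike = rng.random(shape) < pi
    # Step 2: scale a standard normal draw by the chosen component's width.
    scale = np.where(from_spike, spike_scale, slab_scale)
    return scale * rng.standard_normal(shape)


w = sample_spike_and_slab((1000,), pi=0.8)
print(np.mean(np.abs(w) < 0.05))  # fraction of near-zero draws; expected to be close to pi
```

Choosing the component first makes the role of pi explicit. On average, a fraction pi of the draws comes from the narrow spike and lands very near zero.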