InverseGammaReg

class brainstate.nn.InverseGammaReg(weight=1.0, alpha=2.0, beta=1.0, fit_hyper=False)

Inverse-Gamma prior regularization (for variance parameters).

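Implements regularization based on the negative log-likelihood of an Inverse-Gamma distribution, keeping only the terms that depend on the parameter values x_i (the normalizing constant is dropped):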

\[L = \lambda \sum_i \left((\alpha + 1) \log x_i + \frac{\beta}{x_i}\right)\]

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • alpha (float) – Shape parameter. Default is 2.0.

  • beta (float) – Scale parameter. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
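The formula above can be reproduced in a few lines of plain JAX, which also makes the "constant dropped" claim concrete. This is a minimal standalone sketch, not brainstate's implementation: `inverse_gamma_penalty` and `inverse_gamma_nll_elementwise` are hypothetical helpers, and the full negative log-density is built from `jax.scipy.stats.gamma.logpdf` via the change of variables X = 1/Y with Y ~ Gamma(alpha, rate=beta). Per element, the two differ only by log Gamma(alpha) - alpha * log(beta), which does not depend on x.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import gamma


def inverse_gamma_penalty(x, weight=1.0, alpha=2.0, beta=1.0):
    """Hypothetical helper: weight * sum((alpha + 1) * log(x) + beta / x)."""
    x = jnp.asarray(x)
    return weight * jnp.sum((alpha + 1.0) * jnp.log(x) + beta / x)


def inverse_gamma_nll_elementwise(x, alpha, beta):
    """Full per-element negative log-density of InverseGamma(alpha, beta).

    If Y ~ Gamma(alpha, rate=beta), then X = 1/Y ~ InverseGamma(alpha, beta),
    so log p_X(x) = log p_Y(1/x) - 2 * log(x).
    """
    x = jnp.asarray(x)
    log_p = gamma.logpdf(1.0 / x, alpha, scale=1.0 / beta) - 2.0 * jnp.log(x)
    return -log_p


x = jnp.array([0.5, 1.0, 2.0])

# With alpha=2, beta=1 the log terms 3*log(0.5) and 3*log(2) cancel,
# leaving 1/0.5 + 1/1 + 1/2 = 3.5 (up to float32 rounding).
penalty = inverse_gamma_penalty(x, weight=1.0, alpha=2.0, beta=1.0)

# Non-default hyperparameters, to make the dropped constant nonzero.
alpha, beta = 3.0, 2.0
per_element_terms = (alpha + 1.0) * jnp.log(x) + beta / x
gap = inverse_gamma_nll_elementwise(x, alpha, beta) - per_element_terms
const = gammaln(alpha) - alpha * jnp.log(beta)  # same value for every element of x

print(penalty, gap, const)
```

Because the gap is constant in x, dropping it changes the loss value but not its gradient with respect to the parameters.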

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import InverseGammaReg
>>> reg = InverseGammaReg(weight=1.0, alpha=2.0, beta=1.0)
>>> value = jnp.array([0.5, 1.0, 2.0])  # positive values
>>> loss = reg.loss(value)

Notes

The Inverse-Gamma distribution is commonly used as a prior for variance parameters in Bayesian models. Its mode is beta/(alpha+1), which is also the value that minimizes each term of the loss.
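Setting the derivative of one loss term to zero gives (alpha+1)/x - beta/x^2 = 0, so x = beta/(alpha+1), and the second derivative there is (alpha+1)^3/beta^2 > 0. The sketch below checks this numerically with `jax.grad`; `per_element_penalty` is a hypothetical helper mirroring one term of the formula, not a brainstate function.

```python
import jax
import jax.numpy as jnp


def per_element_penalty(x, alpha=2.0, beta=1.0):
    # One term of the loss (without the weight): (alpha + 1) * log(x) + beta / x
    return (alpha + 1.0) * jnp.log(x) + beta / x


alpha, beta = 2.0, 1.0
mode = beta / (alpha + 1.0)  # 1/3 for the default hyperparameters

# First derivative vanishes at the mode (up to float32 rounding).
first = jax.grad(per_element_penalty)(mode, alpha, beta)
# Second derivative is (alpha + 1)**3 / beta**2 = 27 > 0, so the mode is a minimum.
second = jax.grad(jax.grad(per_element_penalty))(mode, alpha, beta)

print(first, second)
```

This is why the mode is a natural reset or default value for a parameter regularized this way.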

loss(value)

Calculate the Inverse-Gamma regularization loss.

Parameters:

value (Array | ndarray | bool_ | number | bool | int | float | complex | Quantity) – Parameter values. They must be strictly positive, since the loss evaluates log(x) and 1/x.

Returns:

The weighted Inverse-Gamma negative log-likelihood (up to an additive constant), summed over all elements of value.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity

reset_value()

Return the mode of the Inverse-Gamma distribution, beta/(alpha+1).

Returns:

Mode value.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity

sample_init(shape)

Draw samples from the Inverse-Gamma distribution.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Samples drawn from InverseGamma(alpha, beta), with the requested shape.

Return type:

Array | ndarray | bool_ | number | bool | int | float | complex | Quantity
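How `sample_init` draws its values is not spelled out here. A common way to sample an Inverse-Gamma variable, assumed in this sketch, is to invert Gamma draws: if G ~ Gamma(alpha, 1), then beta / G ~ InverseGamma(alpha, beta). `sample_inverse_gamma` is a hypothetical helper, and passing an explicit `jax.random.PRNGKey` is an assumption; brainstate may manage random state differently.

```python
import jax
import jax.numpy as jnp


def sample_inverse_gamma(key, shape, alpha=2.0, beta=1.0):
    """Hypothetical helper: draw InverseGamma(alpha, beta) samples.

    jax.random.gamma returns Gamma(alpha, 1) samples; beta / G is then
    distributed as InverseGamma(alpha, scale=beta).
    """
    g = jax.random.gamma(key, alpha, shape=shape)
    return beta / g


key = jax.random.PRNGKey(0)
samples = sample_inverse_gamma(key, (200_000,), alpha=5.0, beta=4.0)

# For alpha > 1 the mean is beta / (alpha - 1), which is 1.0 here.
print(samples.shape, float(samples.mean()))
```

Every sample is strictly positive, so values drawn this way are valid inputs to loss.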