StudentTReg

class brainstate.nn.StudentTReg(weight=1.0, df=3.0, scale=1.0, fit_hyper=False)

Student’s t-distribution prior regularization.

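Implements regularization based on the negative log-likelihood of a Student’s t-distribution, which has heavier tails than a Gaussian. The penalty below is that negative log-likelihood with the factor \((\nu+1)/2\) and the normalizing constant dropped: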

\[L = \lambda \sum_i \log\left(1 + \frac{(x_i / s)^2}{\nu}\right)\]

where \(\lambda\) is the regularization weight, \(\nu\) is the degrees of freedom, and \(s\) is the scale.
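As a rough illustration, the penalty in the formula above can be written directly with NumPy. This is a minimal sketch, not brainstate's implementation. The function name `student_t_penalty` is hypothetical, and the sketch assumes the loss is summed over all elements exactly as displayed. NumPy is used so the snippet runs without JAX; the same calls exist in `jax.numpy`.

```python
import numpy as np


def student_t_penalty(x, weight=1.0, df=3.0, scale=1.0):
    """Hypothetical helper: lambda * sum_i log(1 + (x_i / s)^2 / nu).

    Mirrors the displayed formula only; it is not brainstate's code.
    """
    x = np.asarray(x, dtype=float)
    z2 = (x / scale) ** 2 / df             # (x_i / s)^2 / nu for each element
    return weight * np.sum(np.log1p(z2))   # log1p stays accurate when z2 is tiny


value = np.array([0.5, 2.0, -1.0])
print(student_t_penalty(value, weight=1.0, df=3.0, scale=1.0))  # ≈ 1.215
```

Because each term grows only like \(2\log|x_i|\) for large \(|x_i|\), a single large value contributes far less than it would under a squared (Gaussian) penalty.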

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • df (float) – Degrees of freedom (nu). Lower values give heavier tails. Default is 3.0.

  • scale (float) – Scale parameter. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
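To make "lower values give heavier tails" concrete, the sketch below evaluates the normalized, zero-centered Student’s t log-density, \(\log p(x) = \log\Gamma(\tfrac{\nu+1}{2}) - \log\Gamma(\tfrac{\nu}{2}) - \tfrac12\log(\nu\pi) - \log s - \tfrac{\nu+1}{2}\log\left(1 + \tfrac{(x/s)^2}{\nu}\right)\), and compares it with a Gaussian. It uses only the standard library. The function names are hypothetical, and the density is the standard textbook one, not code taken from brainstate.

```python
import math


def student_t_logpdf(x, df, scale=1.0):
    """Log-density of a zero-centered Student's t with `df` degrees of freedom."""
    z2 = (x / scale) ** 2 / df
    return (math.lgamma((df + 1) / 2) - math.lgamma(df / 2)
            - 0.5 * math.log(df * math.pi) - math.log(scale)
            - (df + 1) / 2 * math.log1p(z2))


def gaussian_logpdf(x, scale=1.0):
    """Log-density of a zero-mean Gaussian with standard deviation `scale`."""
    return -0.5 * math.log(2 * math.pi) - math.log(scale) - 0.5 * (x / scale) ** 2


# Log-density far out in the tail (5 scale units from zero) for several df values.
for df in (1.0, 3.0, 30.0):
    print(f"df={df:>4}: log p(5) = {student_t_logpdf(5.0, df):.3f}")
print(f"Gaussian: log p(5) = {gaussian_logpdf(5.0):.3f}")
```

Smaller df leaves more probability mass at large \(|x|\), which is why the corresponding prior penalizes outlying weights less severely.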

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import StudentTReg
>>> reg = StudentTReg(weight=1.0, df=3.0, scale=1.0)
>>> value = jnp.array([0.5, 2.0, -1.0])
>>> loss = reg.loss(value)

Notes

A Student’s t prior is more robust to outliers than a Gaussian prior. As df → ∞, the distribution approaches a Gaussian; df = 1 gives the Cauchy distribution.
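The limiting cases and the robustness claim in the note above follow from the full density. The worked equations below use the standard zero-centered Student’s t density with scale \(s\); they describe the distribution itself, not necessarily how brainstate normalizes its loss.

```latex
% Negative log-density of a zero-centered Student's t (scale s, degrees of freedom \nu)
-\log p(x) = \frac{\nu+1}{2}\log\left(1 + \frac{(x/s)^2}{\nu}\right)
  + \log\frac{\Gamma(\nu/2)\,\sqrt{\nu\pi}\,s}{\Gamma\left(\frac{\nu+1}{2}\right)}

% Gaussian limit, using \log(1+u) = u + O(u^2) with u = (x/s)^2/\nu:
\lim_{\nu\to\infty} \frac{\nu+1}{2}\log\left(1 + \frac{(x/s)^2}{\nu}\right) = \frac{x^2}{2s^2}

% Cauchy case, \nu = 1:
p(x) = \frac{1}{\pi s\left(1 + (x/s)^2\right)}

% Bounded gradient of the penalty term (maximum at |x| = s\sqrt{\nu}):
\frac{\partial}{\partial x}\log\left(1 + \frac{(x/s)^2}{\nu}\right) = \frac{2x}{\nu s^2 + x^2},
\qquad \left|\frac{2x}{\nu s^2 + x^2}\right| \le \frac{1}{s\sqrt{\nu}}
```

By contrast, the Gaussian penalty \(x^2/(2s^2)\) has gradient \(x/s^2\), which grows without bound, so a single outlying weight can dominate the update; the Student’s t gradient instead decays to zero as \(|x| \to \infty\).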

loss(value)

Calculate the Student’s t regularization loss.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Parameter values.

Returns:

Student’s t negative log-likelihood loss.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

reset_value()

Return zero, the mode of the zero-centered Student’s t distribution.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

sample_init(shape)

Draw samples from the Student’s t distribution.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Samples from the Student’s t distribution.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
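A minimal sketch of drawing Student’s t samples of a given shape, assuming the draws are multiplied by `scale` (the docstring does not say whether brainstate’s `sample_init` uses `scale`). It uses the textbook construction \(t = z / \sqrt{g/\nu}\) with \(z \sim \mathcal{N}(0, 1)\) and \(g \sim \chi^2_\nu\). The function name `sample_student_t` is hypothetical; NumPy’s `Generator.standard_t` draws the same distribution directly.

```python
import numpy as np


def sample_student_t(rng, shape, df=3.0, scale=1.0):
    """Hypothetical helper: draw scale * t_df samples with the given shape."""
    z = rng.standard_normal(shape)       # Gaussian numerator
    g = rng.chisquare(df, size=shape)    # chi-square with df degrees of freedom
    return scale * z / np.sqrt(g / df)   # the ratio is Student's t with df dof


rng = np.random.default_rng(0)
w = sample_student_t(rng, (200_000,), df=10.0, scale=2.0)
# For df > 2 the variance is scale**2 * df / (df - 2), here 5.0.
print(w.shape, w.mean(), w.var())
```

For df ≤ 2 the variance is infinite, so sample statistics like the one above do not settle, and initial values drawn with small df can occasionally be very large.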