GaussianReg

class brainstate.nn.GaussianReg(mean, std, weight=1.0, fit_hyper=False)

Gaussian prior regularization.

Implements a regularization penalty based on the negative log-likelihood of a Gaussian prior (up to a factor of 2 and an additive constant):

\[L = \lambda \left( \sum_i \text{precision}_i \cdot (x_i - \mu_i)^2 - \sum_i \log(\text{precision}_i) \right)\]

where \(\mu_i\) is the prior mean, \(\text{precision}_i = 1/\text{std}_i^2\), and \(\lambda\) is the regularization weight.
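The formula can be checked numerically. The following is a minimal NumPy sketch that evaluates the documented expression directly. It is an illustrative re-implementation, not brainstate's code: the helper name `gaussian_reg_penalty` is hypothetical, and it assumes that a scalar `mean` and `std` are broadcast to the parameter's shape so that the sums run over every element \(i\).

```python
import numpy as np


def gaussian_reg_penalty(value, mean, std, weight=1.0):
    """Evaluate weight * (sum(prec * (x - mu)**2) - sum(log(prec))).

    Hypothetical helper mirroring the documented formula; it assumes scalar
    mean/std are broadcast to the shape of `value`.
    """
    value = np.asarray(value, dtype=float)
    mean = np.broadcast_to(np.asarray(mean, dtype=float), value.shape)
    std = np.broadcast_to(np.asarray(std, dtype=float), value.shape)
    precision = 1.0 / std**2                             # precision_i = 1 / std_i^2
    quadratic = np.sum(precision * (value - mean) ** 2)  # squared-deviation term
    log_term = np.sum(np.log(precision))                 # log-precision term
    return weight * (quadratic - log_term)


# Same setting as the Examples section: mean=0, std=1, weight=0.01.
# precision = 1, so the log term vanishes and the penalty is 0.01 * (0.25 + 0.25).
print(gaussian_reg_penalty([0.5, -0.5], mean=0.0, std=1.0, weight=0.01))  # 0.005
```

Whether brainstate broadcasts a scalar prior over every element in exactly this way is an assumption of the sketch; with `std=1` the log-precision term is zero, so only the squared deviations contribute.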

Parameters:
  • mean (Array | ndarray | bool | number | int | float | complex | Quantity) – Prior mean value.

  • std (Array | ndarray | bool | number | int | float | complex | Quantity) – Prior standard deviation.

  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • fit_hyper (bool) – If True, mean, precision, and weight are stored as trainable parameters (ParamState) so they can be optimized together with the model. Default is False.
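When fit_hyper=True the precision is itself trainable, and the \(-\log(\text{precision}_i)\) term is what stops it from collapsing: without it, the penalty could be driven to zero simply by shrinking the precision. A short derivation, assuming each \(\text{precision}_i\) is an independent parameter, \(\lambda > 0\), and \(d_i = x_i - \mu_i\):

\[\frac{\partial L}{\partial\,\text{precision}_i} = \lambda\left(d_i^2 - \frac{1}{\text{precision}_i}\right) = 0 \quad\Longrightarrow\quad \text{precision}_i^{\ast} = \frac{1}{d_i^2}\]

Since \(\partial^2 L / \partial\,\text{precision}_i^2 = \lambda / \text{precision}_i^2 > 0\), this stationary point is a minimum, so the fitted standard deviation \(1/\sqrt{\text{precision}_i^{\ast}} = |d_i|\) tracks how far the parameter sits from the prior mean.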

mean

Prior mean (trainable if fit_hyper=True).

Type:

array_like or ParamState

precision

Prior precision, 1/std^2 (trainable if fit_hyper=True).

Type:

array_like or ParamState

weight

Regularization weight (trainable if fit_hyper=True).

Type:

array_like or ParamState

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import GaussianReg
>>> reg = GaussianReg(mean=0.0, std=1.0, weight=0.01)
>>> value = jnp.array([0.5, -0.5])
>>> loss = reg.loss(value)

loss(value)

Calculate Gaussian regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values.

Returns:

Weighted Gaussian negative log-likelihood penalty, as defined by the formula above.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
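Because the penalty is quadratic in the parameter values, its gradient with respect to \(x_i\) is \(2\lambda\,\text{precision}_i\,(x_i - \mu_i)\); the log-precision term does not depend on \(x\). In other words, it behaves like weight decay that pulls each parameter toward the prior mean rather than toward zero. The NumPy sketch below checks this closed form against central finite differences. It re-implements the documented formula with the precision held fixed; `penalty` and `penalty_grad` are hypothetical helpers, not part of brainstate.

```python
import numpy as np


def penalty(x, mean, precision, weight):
    # Documented formula with precision held fixed (hypothetical helper).
    return weight * (np.sum(precision * (x - mean) ** 2) - np.sum(np.log(precision)))


def penalty_grad(x, mean, precision, weight):
    # Closed-form gradient w.r.t. x; the log-precision term is constant in x.
    return 2.0 * weight * precision * (x - mean)


x = np.array([0.5, -0.5, 2.0])
mean, precision, weight = 1.0, 4.0, 0.01  # std = 0.5  ->  precision = 1 / 0.5**2 = 4

# Central finite differences, one coordinate at a time.
eps = 1e-6
numeric = np.array([
    (penalty(x + eps * e, mean, precision, weight)
     - penalty(x - eps * e, mean, precision, weight)) / (2 * eps)
    for e in np.eye(x.size)
])
analytic = penalty_grad(x, mean, precision, weight)
print(analytic)
print(np.allclose(numeric, analytic, atol=1e-6))
```

With mean = 0 and a uniform precision, this reduces to ordinary L2 weight decay with coefficient \(2\lambda\,\text{precision}\).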

reset_value()

Return the prior mean.

Returns:

The mean value.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

sample_init(shape)

Sample from the Gaussian prior.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

A sample drawn from \(\mathcal{N}(\text{mean}, \text{std}^2)\).

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
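For reference, drawing from \(\mathcal{N}(\text{mean}, \text{std}^2)\) amounts to shifting and scaling standard normal draws. The NumPy sketch below shows that transformation. brainstate uses its own random-number machinery, so the helper `sample_gaussian_prior`, its `seed` argument, and the use of `numpy.random.default_rng` are assumptions for illustration only.

```python
import numpy as np


def sample_gaussian_prior(shape, mean, std, seed=0):
    """Draw an array of the given shape from N(mean, std**2) (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # mean + std * z with z ~ N(0, 1) has mean `mean` and variance std**2.
    return mean + std * rng.standard_normal(shape)


# An int or a tuple of ints works as the shape, matching sample_init's signature.
w = sample_gaussian_prior((1000, 50), mean=0.0, std=0.1)
print(w.shape)  # (1000, 50)
print(round(float(w.mean()), 3), round(float(w.std()), 3))  # approximately 0.0 and 0.1
```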