GaussianReg
- class brainstate.nn.GaussianReg(mean, std, weight=1.0, fit_hyper=False)
Gaussian prior regularization.
Implements regularization based on the negative log-likelihood of a Gaussian distribution (up to a constant factor and an additive constant):
\[L = \lambda \left( \sum_i \text{precision}_i \cdot (x_i - \mu_i)^2 - \sum_i \log(\text{precision}_i) \right)\]
where \(\text{precision}_i = 1/\text{std}_i^2\) and \(\lambda\) is the weight.
- Parameters:
mean (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Prior mean value.
std (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Prior standard deviation.
weight (float) – Regularization weight (lambda). Default is 1.0.
fit_hyper (bool) – Whether to optimize mean, precision, and weight as trainable parameters. Default is False.
- mean
Prior mean (trainable if fit_hyper=True).
- Type:
array_like or ParamState
- precision
Prior precision (trainable if fit_hyper=True).
- Type:
array_like or ParamState
- weight
Regularization weight (trainable if fit_hyper=True).
- Type:
array_like or ParamState
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import GaussianReg
>>> reg = GaussianReg(mean=0.0, std=1.0, weight=0.01)
>>> value = jnp.array([0.5, -0.5])
>>> loss = reg.loss(value)
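To sanity-check the formula by hand, the loss can be evaluated in plain Python. This is a minimal sketch, not the library's implementation: `gaussian_reg_loss` is a hypothetical helper, and it assumes that a scalar mean and std are applied to every element, so the log-precision term is summed once per element.

```python
import math

def gaussian_reg_loss(values, mean=0.0, std=1.0, weight=1.0):
    """Hypothetical reference helper: evaluates
    L = weight * (sum_i precision * (x_i - mean)^2 - sum_i log(precision))
    for a flat list of floats with a scalar mean and std (an assumption)."""
    precision = 1.0 / std ** 2  # precision = 1 / std^2
    # Quadratic penalty for deviating from the prior mean.
    quad = sum(precision * (x - mean) ** 2 for x in values)
    # Log-precision term, counted once per element (assumed broadcasting).
    log_term = sum(math.log(precision) for _ in values)
    return weight * (quad - log_term)

# Same inputs as the example above: mean=0, std=1, weight=0.01.
loss = gaussian_reg_loss([0.5, -0.5], mean=0.0, std=1.0, weight=0.01)
print(loss)  # 0.005
```

With std=1 the precision is 1 and its logarithm is 0, so only the quadratic term contributes: 0.01 * (0.25 + 0.25) = 0.005. Whether `reg.loss(value)` returns exactly this depends on how the library reduces and broadcasts, which the text above does not spell out.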