LogNormalReg
class brainstate.nn.LogNormalReg(weight=1.0, mu=0.0, sigma=1.0, fit_hyper=False)
Log-normal prior regularization (for positive parameters).
Implements regularization based on the negative log-likelihood of a log-normal distribution:
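\[L = \lambda \sum_i \left(\frac{(\log x_i - \mu)^2}{2\sigma^2} + \log x_i\right)\]
Parameters:
- weight (float, default 1.0) – regularization strength \(\lambda\).
- mu (float, default 0.0) – location \(\mu\), the mean of \(\log x\).
- sigma (float, default 1.0) – scale \(\sigma\), the standard deviation of \(\log x\).
- fit_hyper (bool, default False)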
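The formula can be checked with a few lines of JAX. This is a standalone sketch of the equation only, not LogNormalReg's own implementation; the helper `lognormal_penalty` is hypothetical, and whether `reg.loss` sums or averages over elements is not stated on this page. The terms \(\log\sigma + \tfrac{1}{2}\log 2\pi\) of the full negative log-likelihood do not depend on \(x\) and are absent from the formula, so the sketch omits them too.

```python
import jax.numpy as jnp


def lognormal_penalty(x, weight=1.0, mu=0.0, sigma=1.0):
    """Hypothetical helper: weight * sum_i [(log x_i - mu)^2 / (2 sigma^2) + log x_i].

    Assumes every entry of x is strictly positive.
    """
    log_x = jnp.log(x)
    # Quadratic term in log-space plus the log-Jacobian term log x_i.
    per_element = (log_x - mu) ** 2 / (2.0 * sigma ** 2) + log_x
    return weight * jnp.sum(per_element)


x = jnp.array([0.5, 1.0, 2.0])
print(lognormal_penalty(x))  # ≈ 0.4805, i.e. (log 2)^2
```

For these inputs the \(\log x_i\) terms for 0.5 and 2.0 cancel, and the element 1.0 contributes nothing, so with \(\mu = 0, \sigma = 1\) the sum reduces to \(2 \cdot (\log 2)^2 / 2 = (\log 2)^2\).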
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import LogNormalReg
>>> reg = LogNormalReg(weight=1.0, mu=0.0, sigma=1.0)
>>> value = jnp.array([0.5, 1.0, 2.0])  # positive values
>>> loss = reg.loss(value)
Notes
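A log-normal prior is appropriate for parameters that must be positive, such as scales or variances. Values <= 0 lie outside the support of the log-normal distribution: \(\log x_i\) is undefined there, so the loss is invalid (typically NaN or infinite).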
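One common way to keep such a parameter valid during optimization is to store an unconstrained value and pass its exponential to the penalty. This is a general reparameterization technique, not something this page says LogNormalReg or brainstate does; the names `lognormal_penalty` and `penalty_of_raw` are hypothetical, and the penalty is re-derived from the formula on this page rather than taken from the library.

```python
import jax
import jax.numpy as jnp


def lognormal_penalty(x, weight=1.0, mu=0.0, sigma=1.0):
    # Hypothetical helper implementing the formula on this page; undefined for x <= 0.
    log_x = jnp.log(x)
    return weight * jnp.sum((log_x - mu) ** 2 / (2.0 * sigma ** 2) + log_x)


# Non-positive inputs: log(0) = -inf and log(-1) = nan, so the sum is nan.
bad = jnp.array([0.0, -1.0])
print(lognormal_penalty(bad))  # nan

# Unconstrained raw values; scale = exp(raw) is positive unless exp underflows
# (only for very negative raw), so the penalty stays well defined.
raw = jnp.array([-0.7, 0.0, 0.7])


def penalty_of_raw(raw):
    scale = jnp.exp(raw)
    return lognormal_penalty(scale)


value, grad = jax.value_and_grad(penalty_of_raw)(raw)
```

Because \(\log(e^{r}) = r\), with \(\mu = 0, \sigma = 1\) the penalty becomes \(\sum_i (r_i^2/2 + r_i)\), whose gradient with respect to the raw values is simply \(r_i + 1\).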