BetaReg
class brainstate.nn.BetaReg(weight=1.0, a=2.0, b=2.0, fit_hyper=False)
Beta prior regularization (for parameters in [0, 1]).
Implements regularization based on the negative log-likelihood of a Beta distribution:
\[L = -\lambda \sum_i \left((a - 1) \log x_i + (b - 1) \log(1 - x_i)\right)\]
Parameters:
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import BetaReg
>>> reg = BetaReg(weight=1.0, a=2.0, b=2.0)
>>> value = jnp.array([0.3, 0.5, 0.7])  # values in [0, 1]
>>> loss = reg.loss(value)
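To make the formula concrete, the following is a minimal pure-Python sketch that evaluates the penalty as written above for the same three values. The helper `beta_penalty` is hypothetical and not part of brainstate; it assumes that \(\lambda\) corresponds to `weight` and that the sum runs over all elements. The library's own `loss` may differ in details such as clipping or reduction.

```python
import math

def beta_penalty(values, weight=1.0, a=2.0, b=2.0):
    """Hypothetical helper: L = -weight * sum((a-1)*log(x) + (b-1)*log(1-x))."""
    total = 0.0
    for x in values:
        # Each x must lie strictly inside (0, 1); at 0 or 1 a log term diverges.
        total += (a - 1.0) * math.log(x) + (b - 1.0) * math.log(1.0 - x)
    return -weight * total

# Same values as the example above, with a = b = 2.
penalty = beta_penalty([0.3, 0.5, 0.7])
print(penalty)
```

With a = b = 2 the penalty is symmetric around 0.5, so 0.3 and 0.7 contribute equally and 0.5 contributes the least.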
Notes
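The Beta prior is appropriate for probability-like parameters. Setting a = b = 1 gives the uniform distribution, so both log terms vanish and the penalty is zero. For a, b > 1, the mode of the prior is (a - 1)/(a + b - 2), which is the value at which the per-element penalty is smallest.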
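The two statements in the notes can be checked numerically. The sketch below is plain Python that does not call brainstate, and the names `elementwise_penalty` and `beta_mode` are hypothetical. It searches a grid on (0, 1) for the value that minimizes the per-element penalty and compares it with the mode formula, then confirms that a = b = 1 gives a zero penalty everywhere.

```python
import math

def elementwise_penalty(x, a, b):
    """Per-element term of the formula with weight 1 (hypothetical helper)."""
    return -((a - 1.0) * math.log(x) + (b - 1.0) * math.log(1.0 - x))

def beta_mode(a, b):
    """Mode of Beta(a, b); valid only for a > 1 and b > 1."""
    return (a - 1.0) / (a + b - 2.0)

a, b = 3.0, 2.0  # arbitrary illustrative hyperparameters

# Grid over the open interval (0, 1); the endpoints are excluded because log diverges there.
grid = [i / 1000 for i in range(1, 1000)]
best_x = min(grid, key=lambda x: elementwise_penalty(x, a, b))

# With a = b = 1 both coefficients are zero, so the penalty is flat.
flat = [elementwise_penalty(x, 1.0, 1.0) for x in (0.1, 0.5, 0.9)]

print(best_x, beta_mode(a, b))
print(flat)
```

Choosing a and b therefore sets the value toward which the regularizer pulls each parameter, and a larger a + b makes the pull stronger.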