BetaReg

class brainstate.nn.BetaReg(weight=1.0, a=2.0, b=2.0, fit_hyper=False)

Beta prior regularization (for parameters in [0, 1]).

Implements regularization based on the negative log-likelihood of a Beta(a, b) distribution, written without the constant log-normalizer log B(a, b):

\[L = -\lambda \sum_i \left((a - 1) \log x_i + (b - 1) \log(1 - x_i)\right)\]
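To see what each term contributes, the formula can be evaluated directly with `jax.numpy`. This is a minimal sketch that assumes a plain sum over all elements, exactly as written. `beta_penalty` is a hypothetical helper for illustration and is not part of `brainstate`. It is also not confirmed to match `BetaReg.loss` numerically; for example, the implementation may add a normalizer or clip its inputs.

```python
import jax.numpy as jnp


def beta_penalty(x, weight=1.0, a=2.0, b=2.0):
    """Hypothetical helper: L = -weight * sum((a - 1) log x + (b - 1) log(1 - x)).

    The log B(a, b) normalizer is omitted, as in the displayed equation.
    """
    x = jnp.asarray(x)
    # log1p(-x) computes log(1 - x) accurately when x is close to 0.
    log_density = (a - 1.0) * jnp.log(x) + (b - 1.0) * jnp.log1p(-x)
    return -weight * jnp.sum(log_density)


x = jnp.array([0.3, 0.5, 0.7])
print(beta_penalty(x))  # ≈ 4.5076 for the default a = b = 2
```

Because `weight` multiplies the whole sum, doubling it doubles the penalty.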
Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • a (float) – First shape parameter. Default is 2.0.

  • b (float) – Second shape parameter. Default is 2.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import BetaReg
>>> reg = BetaReg(weight=1.0, a=2.0, b=2.0)
>>> value = jnp.array([0.3, 0.5, 0.7])  # values in [0, 1]
>>> loss = reg.loss(value)

Notes

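A Beta prior is appropriate for probability-valued parameters. Setting a = b = 1 gives the uniform distribution; in that case both coefficients in the formula vanish, so the penalty is zero for every value in (0, 1). For a, b > 1, the mode is (a - 1)/(a + b - 2).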
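The uniform case follows directly from the formula and can be checked numerically. The sketch below uses a hypothetical `beta_penalty` helper (not a brainstate function) that implements the displayed equation. It also shows a related property: when a = b, the penalty is unchanged if every value x is replaced by 1 - x. This is why the default a = b = 2 favors values near 0.5.

```python
import jax.numpy as jnp


def beta_penalty(x, weight=1.0, a=2.0, b=2.0):
    # Hypothetical helper: the displayed formula, without the log B(a, b) term.
    return -weight * jnp.sum((a - 1.0) * jnp.log(x) + (b - 1.0) * jnp.log1p(-x))


x = jnp.array([0.1, 0.4, 0.9])

# a = b = 1: both coefficients are zero, so the penalty is zero for any x in (0, 1).
uniform_penalty = beta_penalty(x, a=1.0, b=1.0)

# a = b: the penalty is symmetric under x -> 1 - x.
left = beta_penalty(x, a=3.0, b=3.0)
right = beta_penalty(1.0 - x, a=3.0, b=3.0)
print(uniform_penalty, left, right)
```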

loss(value)

Calculate the Beta regularization loss.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Parameter values, which should lie in [0, 1]. The logarithms in the formula are infinite at the endpoints 0 and 1.

Returns:

The Beta negative log-likelihood loss, scaled by weight.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
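The following sketch shows what the displayed formula does at the boundary of [0, 1]. It describes only the formula itself. This page does not say whether `BetaReg.loss` clips its input or otherwise guards against the endpoints. `beta_penalty` is the same hypothetical helper used above, not a brainstate function.

```python
import jax.numpy as jnp


def beta_penalty(x, weight=1.0, a=2.0, b=2.0):
    # Hypothetical helper: the docstring formula, summed over all elements.
    return -weight * jnp.sum((a - 1.0) * jnp.log(x) + (b - 1.0) * jnp.log1p(-x))


x = jnp.array([0.0, 0.5])  # the first entry sits on the boundary

# a > 1: (a - 1) * log(0) = -inf, so the penalty becomes +inf.
loss_a2 = beta_penalty(x, a=2.0, b=2.0)

# a = 1: 0 * log(0) evaluates to 0 * (-inf) = nan in floating point.
loss_a1 = beta_penalty(x, a=1.0, b=2.0)
print(loss_a2, loss_a1)
```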

reset_value()

Return the mode of the Beta(a, b) distribution, (a - 1)/(a + b - 2), which is defined for a, b > 1.

Returns:

Mode value.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
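The mode is also the point where the per-element penalty is smallest. Setting the derivative of -((a - 1) log x + (b - 1) log(1 - x)) to zero gives exactly (a - 1)/(a + b - 2). The sketch below checks this with a grid search. `beta_mode` and `per_element_penalty` are hypothetical helpers for illustration; they are not claimed to be how `reset_value` is implemented.

```python
import jax.numpy as jnp


def beta_mode(a, b):
    # Hypothetical helper: mode of Beta(a, b), valid only when a > 1 and b > 1.
    return (a - 1.0) / (a + b - 2.0)


def per_element_penalty(x, a, b):
    # Docstring formula evaluated element-wise, with weight = 1.
    return -((a - 1.0) * jnp.log(x) + (b - 1.0) * jnp.log1p(-x))


a, b = 2.0, 5.0
grid = jnp.linspace(0.001, 0.999, 999)  # step of 0.001 inside the open interval
best = grid[jnp.argmin(per_element_penalty(grid, a, b))]
print(beta_mode(a, b), float(best))  # both ≈ 0.2
```

With the defaults a = b = 2, the same formula gives a mode of 0.5.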

sample_init(shape)

Draw a sample from the Beta(a, b) distribution.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Sample from Beta(a, b).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
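For comparison, plain JAX can draw Beta(a, b) samples of a given shape with `jax.random.beta`. This is an illustrative equivalent under stated assumptions, not necessarily how `sample_init` draws its sample; brainstate may manage random keys through its own random state.

```python
import jax
import jax.numpy as jnp

a, b = 2.0, 2.0
key = jax.random.PRNGKey(0)

# jax.random.beta draws i.i.d. Beta(a, b) samples with the requested shape.
samples = jax.random.beta(key, a, b, shape=(10_000,))

print(samples.shape)          # (10000,)
print(float(samples.mean()))  # close to the Beta mean a / (a + b) = 0.5
```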