LogNormalReg

class brainstate.nn.LogNormalReg(weight=1.0, mu=0.0, sigma=1.0, fit_hyper=False)

Log-normal prior regularization (for positive parameters).

Implements regularization based on the negative log-likelihood of a log-normal distribution, with the terms that do not depend on x omitted:

\[L = \lambda \sum_i \left(\frac{(\log x_i - \mu)^2}{2\sigma^2} + \log x_i\right)\]
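
For reference, the expression can be traced back to the log-normal density. The derivation below is the standard one and is not taken from the library source; it only shows which terms the formula above leaves out:

\[p(x_i) = \frac{1}{x_i \sigma \sqrt{2\pi}} \exp\left(-\frac{(\log x_i - \mu)^2}{2\sigma^2}\right)\]
\[-\log p(x_i) = \frac{(\log x_i - \mu)^2}{2\sigma^2} + \log x_i + \log \sigma + \frac{1}{2}\log 2\pi\]

The last two terms do not depend on x_i, so omitting them leaves the gradient with respect to x unchanged.
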
Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • mu (float) – Mean of log(x). Default is 0.0.

  • sigma (float) – Standard deviation of log(x). Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import LogNormalReg
>>> reg = LogNormalReg(weight=1.0, mu=0.0, sigma=1.0)
>>> value = jnp.array([0.5, 1.0, 2.0])  # positive values
>>> loss = reg.loss(value)
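
As a cross-check, the formula can be evaluated by hand for the same inputs. This is a minimal sketch in plain Python that assumes the loss is exactly the expression given above. `lognormal_penalty` is a hypothetical helper, not part of brainstate, and the value the library returns may differ (for example in dtype or in how it reduces over elements).

```python
import math

def lognormal_penalty(values, weight=1.0, mu=0.0, sigma=1.0):
    """Evaluate weight * sum((log x - mu)^2 / (2 sigma^2) + log x)."""
    total = 0.0
    for x in values:
        log_x = math.log(x)  # only defined for x > 0
        total += (log_x - mu) ** 2 / (2.0 * sigma ** 2) + log_x
    return weight * total

# Same inputs and hyperparameters as the example above.
penalty = lognormal_penalty([0.5, 1.0, 2.0])
print(penalty)
```

For these inputs the log x terms cancel (log 0.5 = -log 2 and log 1 = 0), so the result reduces to (log 2)^2, roughly 0.48.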

Notes

The log-normal prior is appropriate for parameters that must be positive, such as scales or variances. Values <= 0 fall outside the domain of the logarithm and will produce invalid results.
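
One way to catch invalid inputs early is to check positivity before computing the penalty. The helper below is a hypothetical sketch in plain Python, not a brainstate function. It assumes the values are available as ordinary Python floats.

```python
import math

def all_positive(values):
    """Return True if every value is finite and strictly greater than zero."""
    return all(math.isfinite(x) and x > 0.0 for x in values)

ok = all_positive([0.5, 1.0, 2.0])    # valid input for a log-normal prior
bad = all_positive([0.5, 0.0, -1.0])  # contains zero and a negative value
```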

loss(value)

Calculate the log-normal regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values (should be positive).

Returns:

Log-normal negative log-likelihood loss.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

reset_value()

Return the median of the log-normal distribution, exp(mu).

Returns:

exp(mu).

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
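
The sketch below checks, in plain Python, that exp(mu) is the median of LogNormal(mu, sigma) by evaluating the log-normal CDF at that point. `lognormal_cdf` is a hypothetical helper written for this illustration and is not part of brainstate. The hyperparameter values are arbitrary.

```python
import math

def lognormal_cdf(x, mu=0.0, sigma=1.0):
    """CDF of LogNormal(mu, sigma) at x > 0, written with the error function."""
    z = (math.log(x) - mu) / (sigma * math.sqrt(2.0))
    return 0.5 * (1.0 + math.erf(z))

mu, sigma = 0.3, 0.8    # arbitrary illustrative hyperparameters
median = math.exp(mu)   # the value reset_value() is documented to return
half = lognormal_cdf(median, mu, sigma)
```

Half of the probability mass lies below exp(mu) regardless of sigma. The median is distinct from the mode, exp(mu - sigma^2), where the density peaks.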

sample_init(shape)

Sample from the log-normal distribution.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Sample from LogNormal(mu, sigma).

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
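
A log-normal sample can be generated by exponentiating a normal draw. The sketch below illustrates this with Python's standard library. `sample_lognormal` and its `seed` argument are hypothetical, and the function returns a flat list rather than an array of the requested shape. How brainstate itself draws samples and manages random state may differ.

```python
import math
import random

def sample_lognormal(shape, mu=0.0, sigma=1.0, seed=0):
    """Draw prod(shape) values of exp(mu + sigma * z) with z ~ N(0, 1)."""
    dims = (shape,) if isinstance(shape, int) else tuple(shape)
    count = math.prod(dims)
    rng = random.Random(seed)  # local generator so results are reproducible
    return [math.exp(mu + sigma * rng.gauss(0.0, 1.0)) for _ in range(count)]

samples = sample_lognormal((2, 3), mu=0.0, sigma=1.0)
```

Every draw is strictly positive, which makes such samples valid starting points for a parameter regularized with this prior.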