DirichletReg

class brainstate.nn.DirichletReg(weight=1.0, alpha=1.0, fit_hyper=False)

Dirichlet prior regularization (for probability simplexes).

Implements regularization based on the negative log-density of a symmetric Dirichlet distribution (up to an additive normalization constant), evaluated on softmax-normalized values:

\[L = -\lambda \sum_i (\alpha_i - 1) \log p_i\]

where \(p = \text{softmax}(x)\), \(\lambda\) is the regularization weight, and \(\alpha_i = \alpha\) for every dimension \(i\).
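To make the formula above concrete, here is a minimal JAX sketch of the same expression. The function name `dirichlet_reg_sketch` is hypothetical and this is not the brainstate source; it only assumes a single shared `alpha` and a softmax over the last axis.

```python
import jax
import jax.numpy as jnp

def dirichlet_reg_sketch(logits, weight=1.0, alpha=1.0):
    """Hypothetical stand-in for the formula above; not the brainstate source."""
    # log p_i computed stably as x_i - logsumexp(x)
    log_p = jax.nn.log_softmax(logits, axis=-1)
    # L = -lambda * sum_i (alpha - 1) * log p_i, with one shared alpha
    return -weight * jnp.sum((alpha - 1.0) * log_p)

logits = jnp.array([1.0, 2.0, 1.0])
flat = dirichlet_reg_sketch(logits, alpha=1.0)      # (alpha - 1) = 0, so no penalty
penalized = dirichlet_reg_sketch(logits, alpha=2.0)  # equals -sum_i log p_i
```

Using `log_softmax` rather than `log(softmax(x))` avoids taking the log of very small probabilities when the logits are far apart.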

Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • alpha (float) – Concentration parameter (same for all dimensions). Default is 1.0. Values < 1 encourage sparsity, values > 1 encourage uniformity.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import DirichletReg
>>> reg = DirichletReg(weight=1.0, alpha=1.0)  # uniform prior
>>> value = jnp.array([1.0, 2.0, 1.0])  # logits
>>> loss = reg.loss(value)

Notes

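A Dirichlet prior is appropriate for attention weights, mixture proportions, and other quantities that live on a probability simplex. With alpha=1 the prior is flat, so the penalty is zero for every input; alpha<1 favors sparse (peaked) distributions and alpha>1 favors near-uniform ones.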
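The effect of alpha can be checked numerically. The sketch below evaluates the documented formula directly with a hypothetical helper, `sketch_loss`, rather than calling the class, and compares a uniform point with a peaked one in both regimes.

```python
import jax
import jax.numpy as jnp

def sketch_loss(logits, alpha, weight=1.0):
    # Hypothetical helper: -weight * sum_i (alpha - 1) * log softmax(logits)_i
    return -weight * jnp.sum((alpha - 1.0) * jax.nn.log_softmax(logits))

uniform_logits = jnp.zeros(3)               # softmax gives [1/3, 1/3, 1/3]
peaked_logits = jnp.array([4.0, 0.0, 0.0])  # softmax puts most mass on index 0

# alpha > 1: the uniform point has the lower loss
dense_uniform = sketch_loss(uniform_logits, alpha=2.0)
dense_peaked = sketch_loss(peaked_logits, alpha=2.0)

# alpha < 1: the sign of (alpha - 1) flips, so the peaked point has the lower loss
sparse_uniform = sketch_loss(uniform_logits, alpha=0.5)
sparse_peaked = sketch_loss(peaked_logits, alpha=0.5)
```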

loss(value)

Calculate Dirichlet regularization loss.

Parameters:

value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values (logits, will be softmax-normalized).

Returns:

Dirichlet negative log-likelihood loss.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

reset_value()

Return zero (gives uniform distribution under softmax).

Returns:

Zero.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
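A quick check of why zero is a natural reset value: equal logits map to equal probabilities under softmax. The sketch assumes the returned zero is broadcast to the parameter shape, here four classes.

```python
import jax
import jax.numpy as jnp

zero_logits = jnp.zeros(4)           # the reset value, broadcast to four classes
probs = jax.nn.softmax(zero_logits)  # every entry equals 1/4
```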

sample_init(shape)

Sample logits that give Dirichlet-distributed probabilities.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Logits corresponding to Dirichlet-sampled probabilities.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
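One way such logits could be produced is to draw probabilities from a symmetric Dirichlet and take their logarithm, since softmax(log p) = p whenever p sums to one. The sketch below illustrates that idea under stated assumptions: the function `sample_dirichlet_logits` and its explicit `key` argument are hypothetical, the last axis of `shape` is assumed to be the simplex dimension, and the actual `sample_init` implementation may differ.

```python
import jax
import jax.numpy as jnp

def sample_dirichlet_logits(key, shape, alpha=1.0):
    """Hypothetical sketch: draw p ~ Dirichlet(alpha, ..., alpha), return log p."""
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    # Assumption: the last axis is the simplex dimension; leading axes are independent draws.
    concentration = jnp.full((shape[-1],), alpha)
    p = jax.random.dirichlet(key, concentration, shape=shape[:-1])
    # Guard against log(0) for very small sampled probabilities.
    p = jnp.maximum(p, jnp.finfo(p.dtype).tiny)
    return jnp.log(p)

key = jax.random.PRNGKey(0)
logits = sample_dirichlet_logits(key, (5, 3), alpha=2.0)
probs = jax.nn.softmax(logits, axis=-1)  # recovers the sampled probabilities
```

Adding any constant to a row of logits leaves the softmax unchanged, so log p is only one of many valid choices of logits for a given sample.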