Dirichlet prior regularization (for probability simplexes).
Adds a penalty equal to the negative log-density (up to an additive constant)
of a symmetric Dirichlet distribution, evaluated on softmax-normalized values:
\[L = -\lambda \sum_i (\alpha_i - 1) \log p_i\]
where \(p = \text{softmax}(x)\), \(\lambda\) is weight, and \(\alpha_i\) equals alpha for every \(i\).
- Parameters:
weight (float) – Regularization weight (lambda). Default is 1.0.
alpha (float) – Concentration parameter (same for all dimensions). Default is 1.0.
Values < 1 encourage sparsity, values > 1 encourage uniformity.
fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import DirichletReg
>>> reg = DirichletReg(weight=1.0, alpha=1.0) # uniform prior
>>> value = jnp.array([1.0, 2.0, 1.0]) # logits
>>> loss = reg.loss(value)
Notes
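A Dirichlet prior is appropriate for attention weights, mixture proportions,
and other quantities that live on a probability simplex. With alpha=1 the prior
is uniform over the simplex and the loss is identically zero; alpha<1 favors
sparse (peaked) distributions; alpha>1 favors near-uniform distributions.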
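The effect of alpha can be checked by evaluating the formula above by hand. The following is a standalone NumPy sketch, not the library's implementation; softmax and dirichlet_penalty are hypothetical helper names introduced only for this illustration.

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    z = np.exp(x - np.max(x))
    return z / z.sum()

def dirichlet_penalty(logits, alpha, weight=1.0):
    # Hypothetical helper: -weight * sum((alpha - 1) * log p), p = softmax(logits).
    p = softmax(np.asarray(logits, dtype=float))
    return -weight * np.sum((alpha - 1.0) * np.log(p))

peaked = np.array([5.0, 0.0, 0.0])  # softmax puts almost all mass on index 0
flat = np.array([0.0, 0.0, 0.0])    # softmax is exactly uniform

# alpha < 1: the peaked distribution receives the lower penalty.
sparse_peaked = dirichlet_penalty(peaked, alpha=0.5)
sparse_flat = dirichlet_penalty(flat, alpha=0.5)

# alpha > 1: the flat distribution receives the lower penalty.
dense_peaked = dirichlet_penalty(peaked, alpha=2.0)
dense_flat = dirichlet_penalty(flat, alpha=2.0)
```

Because the penalty is minimized during training, the lower value in each pair indicates which shape of distribution that setting of alpha pushes toward.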
-
loss(value)
Calculate Dirichlet regularization loss.
- Parameters:
value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values, interpreted as logits; softmax normalization is applied before the loss is computed.
- Returns:
Dirichlet negative log-likelihood loss.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
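Since the loss takes the logarithm of a softmax, a direct softmax-then-log computation can hit log(0) when the logits are far apart. The sketch below shows one common way around this, computing log-softmax through a shifted log-sum-exp. It is a plain NumPy illustration under that assumption; the source does not say how loss is computed internally, and log_softmax and dirichlet_nll are hypothetical names.

```python
import numpy as np

def log_softmax(x):
    # log p_i = x_i - logsumexp(x), computed with a max shift so exp never overflows.
    x = np.asarray(x, dtype=float)
    shifted = x - np.max(x)
    return shifted - np.log(np.sum(np.exp(shifted)))

def dirichlet_nll(logits, alpha=1.0, weight=1.0):
    # Hypothetical helper mirroring the documented formula.
    return -weight * np.sum((alpha - 1.0) * log_softmax(logits))

value = np.array([1.0, 2.0, 1.0])  # same logits as the class example

# alpha = 1 makes every (alpha - 1) factor zero, so the penalty vanishes.
loss_uniform_prior = dirichlet_nll(value, alpha=1.0)

# alpha = 2 penalizes the deviation of softmax(value) from uniform.
loss_dense_prior = dirichlet_nll(value, alpha=2.0)

# Widely separated logits still give a finite result with log-softmax.
extreme = dirichlet_nll(np.array([1000.0, 0.0, 0.0]), alpha=2.0)
```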
-
reset_value()
Return zero (gives uniform distribution under softmax).
- Returns:
Zero.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
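The claim that zero logits give a uniform distribution follows directly from the softmax definition, since every entry exponentiates to 1. A minimal NumPy check, assuming a length-4 logit vector purely for illustration (the source does not state the shape that reset_value returns):

```python
import numpy as np

logits = np.zeros(4)          # assumed shape, for illustration only
weights = np.exp(logits)      # every entry is exp(0) = 1
p = weights / weights.sum()   # softmax: each entry is 1/4
```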
-
sample_init(shape)
Sample logits that give Dirichlet-distributed probabilities.
- Parameters:
shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.
- Returns:
Logits corresponding to Dirichlet-sampled probabilities.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
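One way to obtain logits with this property is to draw probabilities from a Dirichlet distribution and take their logarithm, because softmax(log p) returns p whenever p sums to one. The sketch below illustrates that idea with NumPy; it is an assumption about a possible construction, not a description of what sample_init does, and the values of alpha, k, and the clipping floor are chosen only for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 0.5   # assumed concentration, for illustration
k = 5         # assumed number of categories

# Draw one point on the probability simplex from a symmetric Dirichlet.
p = rng.dirichlet(np.full(k, alpha))

# Guard against log(0): small alpha can produce entries that underflow to zero.
p = np.clip(p, 1e-12, None)
logits = np.log(p)

# Softmax of the logits recovers the sampled probabilities
# (renormalized, in case clipping changed the sum).
recovered = np.exp(logits - logits.max())
recovered /= recovered.sum()
```

Logits are only defined up to an additive constant under softmax, so any shifted version of these logits would represent the same sampled probabilities.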