Entropy regularization.
Regularizes based on the entropy of softmax-normalized values:
\[L = -\lambda \sum_i p_i \log(p_i)\]
where \(p = \text{softmax}(x)\).
When maximize=True, maximizes entropy (encourages uniform distribution).
When maximize=False, minimizes entropy (encourages concentrated distribution).
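The entropy term in the formula above can be checked by hand with plain JAX. The following is a minimal sketch, not the library's implementation: `softmax_entropy` is a hypothetical helper, and it assumes the softmax is taken over all elements of a 1-D input.

```python
import jax
import jax.numpy as jnp

def softmax_entropy(x):
    """Entropy (in nats) of softmax(x) for a 1-D array of logits.

    Hypothetical helper for illustration; not part of brainstate.
    """
    # log_softmax is numerically safer than log(softmax(x)).
    log_p = jax.nn.log_softmax(x)
    p = jnp.exp(log_p)
    return -jnp.sum(p * log_p)

uniform_logits = jnp.zeros(4)                      # softmax gives [0.25, 0.25, 0.25, 0.25]
peaked_logits = jnp.array([10.0, 0.0, 0.0, 0.0])   # softmax puts almost all mass on entry 0

h_uniform = softmax_entropy(uniform_logits)  # log(4), the maximum for 4 entries
h_peaked = softmax_entropy(peaked_logits)    # close to 0
```

Equal logits give the largest possible entropy for a given number of entries, while a single dominant logit drives the entropy toward zero. These are the two regimes that `maximize=True` and `maximize=False` push toward.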
- Parameters:
weight (float) – Regularization weight (lambda). Default is 1.0.
maximize (bool) – Whether to maximize entropy (True) or minimize it (False).
Default is True (maximize entropy).
fit_hyper (bool) – Whether to optimize weight as a trainable parameter.
Default is False.
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import EntropyReg
>>> reg = EntropyReg(weight=0.01, maximize=True)
>>> value = jnp.array([1.0, 2.0, 1.0])
>>> loss = reg.loss(value)
Notes
Entropy regularization is used in reinforcement learning to encourage
exploration, and in attention mechanisms to control how sharply the
attention weights concentrate.
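To illustrate how the `maximize` flag can steer a gradient-based optimizer, here is a sketch in plain JAX. It assumes one common sign convention (negate the weighted entropy when maximizing, so that minimizing the loss raises the entropy). How `EntropyReg` handles the sign internally is not stated in this section, and `entropy_penalty` and `softmax_entropy` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def softmax_entropy(x):
    """Entropy (in nats) of softmax(x). Hypothetical helper."""
    log_p = jax.nn.log_softmax(x)
    return -jnp.sum(jnp.exp(log_p) * log_p)

def entropy_penalty(x, weight=0.01, maximize=True):
    # Assumed convention: a minimizer lowers this value, so the entropy
    # is negated when the goal is to maximize it.
    sign = -1.0 if maximize else 1.0
    return sign * weight * softmax_entropy(x)

x = jnp.array([1.0, 2.0, 1.0])
step = 10.0  # large step so the effect of one update is visible

# One gradient-descent step on the penalty, for each setting of `maximize`.
x_up = x - step * jax.grad(entropy_penalty)(x, 0.01, True)
x_down = x - step * jax.grad(entropy_penalty)(x, 0.01, False)

h_before = softmax_entropy(x)
h_up = softmax_entropy(x_up)      # larger than h_before: logits move closer together
h_down = softmax_entropy(x_down)  # smaller than h_before: the largest logit pulls ahead
```

Under this convention, a descent step with `maximize=True` flattens the distribution and a step with `maximize=False` sharpens it.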
- loss(value)
Calculate the entropy regularization loss.
- Parameters:
value (Array | ndarray | bool | number | int | float | complex | Quantity) – Parameter values (will be softmax-normalized).
- Returns:
Entropy-based loss.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
- reset_value()
Return zero.
- Returns:
Zero.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
- sample_init(shape)
Sample initial values.
- Parameters:
shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.
- Returns:
Initial values, uniform when maximize is True and concentrated when maximize is False.
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity