EntropyReg

class brainstate.nn.EntropyReg(weight=1.0, maximize=True, fit_hyper=False)

Entropy regularization.

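Regularizes based on the entropy of the softmax-normalized values:

\[H(p) = -\sum_i p_i \log(p_i), \qquad L = \mp \lambda H(p)\]

where \(p = \text{softmax}(x)\). Because the loss is minimized, the sign is negative when maximize=True and positive when maximize=False.

When maximize=True, the regularizer maximizes entropy, encouraging a uniform distribution. When maximize=False, it minimizes entropy, encouraging a concentrated (peaked) distribution.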
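A minimal sketch of this computation in JAX. The names `softmax_entropy` and `entropy_penalty` are hypothetical and not part of brainstate. Flattening the input into a single distribution and using `jax.nn.log_softmax` for numerical stability are assumptions about the library's behavior, and the sign convention follows the description of maximize.

```python
import jax
import jax.numpy as jnp


def softmax_entropy(x):
    """Entropy H(p) = -sum_i p_i log p_i of p = softmax(x), in nats."""
    # Assumption: the input is flattened so one distribution covers all entries.
    # log_softmax avoids evaluating log(0) when some probabilities underflow.
    logp = jax.nn.log_softmax(jnp.ravel(x))
    return -jnp.sum(jnp.exp(logp) * logp)


def entropy_penalty(x, weight=1.0, maximize=True):
    """Hypothetical stand-in for EntropyReg.loss (not the library's code)."""
    h = softmax_entropy(x)
    # Minimizing -weight*H raises the entropy; minimizing +weight*H lowers it.
    return -weight * h if maximize else weight * h


x = jnp.array([1.0, 2.0, 1.0])
print(softmax_entropy(x))                                # ≈ 0.975 (maximum for 3 entries is log 3 ≈ 1.099)
print(entropy_penalty(x, weight=0.01, maximize=True))    # ≈ -0.00975
print(entropy_penalty(x, weight=0.01, maximize=False))   # ≈ 0.00975
```

Computing the entropy from log-probabilities rather than `jnp.log(softmax(x))` keeps the result finite even when one entry dominates and the others underflow to zero.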

Parameters:
  • weight (float) – Regularization weight \(\lambda\). Default is 1.0.

  • maximize (bool) – Whether to maximize entropy (True) or minimize it (False). Default is True.

  • fit_hyper (bool) – Whether to treat weight as a trainable parameter that is optimized during training. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import EntropyReg
>>> reg = EntropyReg(weight=0.01, maximize=True)
>>> value = jnp.array([1.0, 2.0, 1.0])
>>> loss = reg.loss(value)

Notes

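Entropy regularization is common in reinforcement learning, where maximizing the entropy of a policy encourages exploration, and in attention mechanisms, where it controls how spread out or how sharply focused the attention weights are.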
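The spreading and sharpening effects can be checked with plain gradient descent on a small logit vector, with the regularizer as the only loss term. This is a hypothetical sketch, not brainstate code: `softmax_entropy`, `entropy_loss`, `descend`, the step size, and the step count are illustrative choices, and the signed loss assumes the convention that maximize=True maximizes entropy.

```python
import jax
import jax.numpy as jnp


def softmax_entropy(logits):
    # H(p) = -sum_i p_i log p_i for p = softmax(logits)
    logp = jax.nn.log_softmax(logits)
    return -jnp.sum(jnp.exp(logp) * logp)


def entropy_loss(logits, weight, maximize):
    # Hypothetical signed entropy term: -weight*H to maximize, +weight*H to minimize.
    h = softmax_entropy(logits)
    return -weight * h if maximize else weight * h


# Gradient w.r.t. logits only; maximize is a Python bool, so it is static under jit.
grad_fn = jax.jit(jax.grad(entropy_loss), static_argnums=2)


def descend(logits, maximize, lr=0.5, steps=200, weight=1.0):
    # Plain gradient descent on the regularizer alone.
    for _ in range(steps):
        logits = logits - lr * grad_fn(logits, weight, maximize)
    return logits


start = jnp.array([3.0, 0.0, 0.0])       # a peaked distribution, H ≈ 0.37
spread = descend(start, maximize=True)    # entropy rises toward log 3
sharpened = descend(start, maximize=False)  # entropy falls below its start value
print(softmax_entropy(start), softmax_entropy(spread), softmax_entropy(sharpened))
```

In a reinforcement-learning setting the same signed term would be added to the policy loss, so the entropy bonus competes with the task objective instead of acting alone.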

loss(value)

Calculate the entropy regularization loss.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Values to regularize; they are softmax-normalized before the entropy is computed.

Returns:

Entropy-based loss.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

reset_value()

Return zero.

Returns:

Zero.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

sample_init(shape)

Sample initial values.

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

Initial values of the given shape: a uniform initialization when maximize=True, a concentrated one when maximize=False.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
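The docstring does not define "uniform" and "concentrated" precisely. One plausible reading, in terms of the softmax distribution, is sketched below with a hypothetical helper (`sample_init_sketch` is not a brainstate function): all-equal values give a uniform softmax with maximal entropy \(\log n\), while a single dominant value gives a concentrated softmax with entropy near zero. The magnitude `scale=10.0` is an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp


def softmax_entropy(x):
    # H(p) = -sum_i p_i log p_i for p = softmax(x) over all entries
    logp = jax.nn.log_softmax(jnp.ravel(x))
    return -jnp.sum(jnp.exp(logp) * logp)


def sample_init_sketch(shape, maximize, scale=10.0):
    """Hypothetical initializer, not brainstate's sample_init."""
    if maximize:
        # All-equal values: softmax is uniform, entropy is log(n).
        return jnp.zeros(shape)
    # One large entry: softmax puts almost all mass on it.
    return jnp.zeros(shape).ravel().at[0].set(scale).reshape(shape)


n = 5
print(softmax_entropy(sample_init_sketch((n,), maximize=True)))   # log(5) ≈ 1.609
print(softmax_entropy(sample_init_sketch((n,), maximize=False)))  # close to 0
```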