UniformReg

class brainstate.nn.UniformReg(weight=1.0, lower=-1.0, upper=1.0, fit_hyper=False)

Uniform prior regularization (soft bounded constraint).

Implements a soft constraint that encourages parameters to stay within a specified interval [lower, upper]:

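\[L = \lambda \sum_i \left(\text{relu}(l - x_i)^2 + \text{relu}(x_i - u)^2\right)\]

where \(x_i\) are the parameter values, \(l\) is lower, \(u\) is upper, \(\lambda\) is weight, and \(\text{relu}(z) = \max(z, 0)\).
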
Parameters:
  • weight (float) – Regularization weight (lambda). Default is 1.0.

  • lower (float) – Lower bound. Default is -1.0.

  • upper (float) – Upper bound. Default is 1.0.

  • fit_hyper (bool) – Whether to optimize hyperparameters. Default is False.

Examples

>>> import jax.numpy as jnp
>>> from brainstate.nn import UniformReg
>>> reg = UniformReg(weight=1.0, lower=-1.0, upper=1.0)
>>> value = jnp.array([0.5, 1.5, -0.5])  # 1.5 is out of bounds
>>> loss = reg.loss(value)
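The doctest above does not show the resulting value. As a cross-check, the formula can be evaluated directly. The following is a minimal NumPy sketch of the penalty exactly as written in the formula (a sum over elements), not the library's implementation; uniform_penalty is a hypothetical helper name.

```python
import numpy as np


def uniform_penalty(x, lower=-1.0, upper=1.0, weight=1.0):
    """Hypothetical helper: soft box penalty following the formula above."""
    x = np.asarray(x, dtype=float)
    below = np.maximum(lower - x, 0.0)  # relu(l - x_i): distance below the lower bound
    above = np.maximum(x - upper, 0.0)  # relu(x_i - u): distance above the upper bound
    return weight * np.sum(below**2 + above**2)


value = np.array([0.5, 1.5, -0.5])
print(uniform_penalty(value))  # 0.25: only 1.5 is outside, and (1.5 - 1.0)**2 = 0.25
```

Comparing this number with reg.loss(value) is a quick way to confirm whether the library reduces over elements by a sum, as the formula states.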

Notes

This is a soft constraint; values outside the bounds are penalized but not strictly prohibited.
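To make the soft/hard distinction concrete: a hard constraint would clip values into the interval, whereas the soft penalty only contributes a restoring gradient, 2λ(relu(x − u) − relu(l − x)), which is zero inside the bounds. The sketch below uses plain NumPy, a hypothetical helper name, and a learning rate of 0.1 chosen purely for illustration; it takes a few gradient steps on the penalty alone.

```python
import numpy as np


def uniform_penalty_grad(x, lower=-1.0, upper=1.0, weight=1.0):
    """Hypothetical helper: gradient of the soft box penalty with respect to x."""
    x = np.asarray(x, dtype=float)
    # d/dx relu(x - u)^2 = 2 relu(x - u);  d/dx relu(l - x)^2 = -2 relu(l - x)
    return 2.0 * weight * (np.maximum(x - upper, 0.0) - np.maximum(lower - x, 0.0))


x = np.array([0.5, 1.5, -3.0])
hard = np.clip(x, -1.0, 1.0)  # hard constraint: projected straight onto [-1, 1]

soft = x.copy()
for _ in range(5):  # five plain gradient-descent steps on the penalty only
    soft -= 0.1 * uniform_penalty_grad(soft)

print(hard.tolist())  # [0.5, 1.0, -1.0]
print(soft)  # roughly [0.5, 1.164, -1.655]: pulled toward [-1, 1] but still outside
```

In a real training loop this gradient is added to the task-loss gradient, so parameters can settle outside the bounds when the task loss pushes hard enough.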

loss(value)

Compute the uniform-prior regularization loss for the given values.

Parameters:

value (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Parameter values.

Returns:

Weighted squared penalty for values outside [lower, upper]; zero when every value lies within the bounds.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

reset_value()

Return the midpoint of the interval.

Returns:

(lower + upper) / 2.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
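The midpoint is the point of the interval farthest from both bounds, so it lies deep inside the zero-penalty region. A small NumPy sketch of the value and of filling a parameter with it; the (2, 3) shape and the use of np.full are illustrative assumptions, since the documented method only returns the midpoint.

```python
import numpy as np

lower, upper = -1.0, 3.0
midpoint = (lower + upper) / 2  # the value reset_value() is documented to return
param = np.full((2, 3), midpoint)  # e.g. re-initialising a (2, 3) parameter to the midpoint

# Distance to the nearest bound is the same on both sides: half the interval width.
margin = min(midpoint - lower, upper - midpoint)
print(midpoint, margin)  # 1.0 2.0
```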

sample_init(shape)

Draw a sample from the uniform distribution on [lower, upper].

Parameters:

shape (int | Sequence[int] | integer | Sequence[integer]) – Shape of the sample.

Returns:

An array of the given shape sampled from Uniform(lower, upper).

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
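A minimal sketch of the documented sampling behaviour using NumPy's Generator.uniform. The library presumably draws from its own random state, so this is an equivalent illustration rather than the actual code; sample_uniform_init and its seed argument are hypothetical. NumPy samples the half-open interval [low, high).

```python
import numpy as np


def sample_uniform_init(shape, lower=-1.0, upper=1.0, seed=0):
    """Hypothetical helper: draw initial values from Uniform(lower, upper)."""
    rng = np.random.default_rng(seed)
    # size accepts an int or a tuple of ints, mirroring the documented shape argument
    return rng.uniform(lower, upper, size=shape)


w = sample_uniform_init((4, 3), lower=-0.5, upper=0.5)
print(w.shape)  # (4, 3)
# Every sample lies inside the bounds, so the soft penalty is zero at initialisation.
print(bool(np.all((w >= -0.5) & (w <= 0.5))))  # True
```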