UniformReg
class brainstate.nn.UniformReg(weight=1.0, lower=-1.0, upper=1.0, fit_hyper=False)
Uniform prior regularization (soft bounded constraint).
Implements a soft constraint that encourages parameters to stay within a specified interval [lower, upper]:
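\[L = \lambda \sum_i \left(\text{relu}(l - x_i)^2 + \text{relu}(x_i - u)^2\right)\]
where \(\lambda\) is weight, \(l\) is lower, \(u\) is upper, and \(\text{relu}(z) = \max(z, 0)\). Only entries \(x_i\) outside [lower, upper] contribute, each by its squared distance to the nearest bound.
Parameters:
- weight – regularization strength \(\lambda\) (default 1.0).
- lower – lower bound \(l\) of the interval (default -1.0).
- upper – upper bound \(u\) of the interval (default 1.0).
- fit_hyper – boolean flag (default False).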
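To make the formula concrete, here is a minimal JAX sketch of the penalty as a plain function. uniform_reg_loss is a hypothetical helper written for illustration, assuming weight, lower and upper play the roles of \(\lambda\), \(l\) and \(u\); it is not brainstate's implementation, and UniformReg.loss may differ in details such as dtype handling or how fit_hyper treats the hyperparameters.

```python
import jax
import jax.numpy as jnp


def uniform_reg_loss(x, weight=1.0, lower=-1.0, upper=1.0):
    """Hypothetical standalone version of the UniformReg penalty (not the brainstate API).

    L = weight * sum_i (relu(lower - x_i)**2 + relu(x_i - upper)**2)
    """
    below = jax.nn.relu(lower - x)  # > 0 only where x_i < lower
    above = jax.nn.relu(x - upper)  # > 0 only where x_i > upper
    return weight * jnp.sum(below ** 2 + above ** 2)


value = jnp.array([0.5, 1.5, -0.5])

# Only 1.5 lies outside [-1, 1]: 1.0 * (1.5 - 1.0)**2 = 0.25
print(uniform_reg_loss(value))

# Tighter lower bound: -0.5 now also violates it, giving 2.0 * (0.25**2 + 0.5**2) = 0.625
print(uniform_reg_loss(value, weight=2.0, lower=-0.25, upper=1.0))
```

Because both terms are zero inside the interval, the loss (and its gradient) vanishes for any parameter vector that already satisfies the bounds.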
Examples
>>> import jax.numpy as jnp
>>> from brainstate.nn import UniformReg
>>> reg = UniformReg(weight=1.0, lower=-1.0, upper=1.0)
>>> value = jnp.array([0.5, 1.5, -0.5])  # 1.5 is out of bounds
>>> loss = reg.loss(value)
Notes
This is a soft constraint: values outside the bounds are penalized in proportion to their squared distance to the nearest bound, but they are not clipped or strictly prohibited.
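The "soft" behavior shows up in the gradient: it is zero for in-bounds entries and equals \(2\lambda\) times the signed violation for out-of-bounds entries, so optimization pulls values back toward the interval without ever forcing them onto it. The sketch below assumes a hypothetical standalone helper, uniform_penalty, mirroring the formula above (not part of brainstate), and runs a few plain gradient-descent steps on the penalty alone.

```python
import jax
import jax.numpy as jnp


def uniform_penalty(x, weight=1.0, lower=-1.0, upper=1.0):
    # Hypothetical helper mirroring the UniformReg formula; not the brainstate API.
    below = jax.nn.relu(lower - x)  # distance below the lower bound (0 if inside)
    above = jax.nn.relu(x - upper)  # distance above the upper bound (0 if inside)
    return weight * jnp.sum(below ** 2 + above ** 2)


x0 = jnp.array([0.5, 1.5, -1.5])
g = jax.grad(uniform_penalty)(x0)
print(g)  # gradient: 0 for the in-bounds entry, +1 above the upper bound, -1 below the lower bound

# Gradient descent on the penalty alone: with lr = 0.25 each step halves the
# distance to the violated bound, so values approach the interval but are never clipped.
lr = 0.25
x = x0
for _ in range(3):
    x = x - lr * jax.grad(uniform_penalty)(x)
print(x)  # 0.5 stays put; 1.5 -> 1.0625 and -1.5 -> -1.0625 after three steps
```

In a real training loop this penalty would be added to the task loss, so the final parameters can still sit slightly outside [lower, upper] when the task loss pushes against the bound.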