Constrain and Regularize Parameters
A bare brainstate.ParamState holds an unconstrained array. Many models need more: a rate that
must stay positive, a mixing weight bounded to [0, 1], a probability vector that sums to one,
or an L2 penalty pulling weights toward zero. brainstate.nn.Param adds two declarative
facilities on top of ParamState:
- a constraint transform (t=): the optimizer updates the parameter in an unconstrained space, while .value() returns it mapped into the valid domain. Gradients flow cleanly because the mapping is a smooth bijection, not a hard clip.
- a regularization prior (reg=): .reg_loss() returns a penalty you add to the loss.
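To see why a smooth bijection is preferable to a hard clip, it helps to compare the derivatives of the two mappings once the unconstrained value falls below the bound. The sketch below is plain Python and independent of brainstate; `clip_positive`, `softplus`, and their gradient helpers are illustrative names, not library functions.

```python
import math


def clip_positive(x):
    """Hard clip onto [0, inf)."""
    return max(x, 0.0)


def clip_positive_grad(x):
    """Derivative of the clip: zero whenever the clip is active."""
    return 1.0 if x > 0.0 else 0.0


def softplus(x):
    """Smooth map from the real line onto (0, inf), written in a stable form."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_grad(x):
    """Derivative of softplus: the logistic sigmoid, positive for every x."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


for x in (-3.0, 0.5):
    print(
        f"x={x:+.1f}  "
        f"clip: value={clip_positive(x):.4f} grad={clip_positive_grad(x):.4f}  "
        f"softplus: value={softplus(x):.4f} grad={softplus_grad(x):.4f}"
    )
```

At x = -3 the clip reports a zero gradient, so a gradient-based optimizer receives no signal to move the parameter back into the valid region. The softplus gradient there is small but non-zero.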
brainstate.nn.Const is the companion for values that should not be trained.
import jax.numpy as jnp
import brainunit as u
import brainstate
from brainstate import nn
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
.value() returns a brainunit.Quantity (dimensionless in these examples) that behaves like a
JAX array in arithmetic. u.get_magnitude(...) extracts the plain array when we want to print
or compare it.
A positive-only parameter
SoftplusT(lower) maps the whole real line onto (lower, ∞). We store the parameter
unconstrained and read it back through the transform, so it is positive no matter what value the
optimizer lands on — even a large negative one.
rate = nn.Param(jnp.array(0.5), t=nn.SoftplusT(lower=0.0))
print('initial value :', float(u.get_magnitude(rate.value())))
# Simulate an aggressive optimizer step that drives the *unconstrained* value negative.
rate.val.value = jnp.array(-10.0)
print('after a large negative update:', float(u.get_magnitude(rate.value())))
initial value : 0.4999999701976776
after a large negative update: 4.5398901420412585e-05
rate.val is the underlying ParamState the optimizer mutates; rate.value() is the
constrained view the model should use in its forward pass. The constrained value stays just
above the lower bound instead of going negative.
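The printed numbers are consistent with a shifted softplus, value = lower + log(1 + exp(x)). The sketch below assumes that form (the library's implementation may differ in detail) and uses two hypothetical helpers, `softplus_forward` and `softplus_inverse`, to show the round trip between a constrained value and the unconstrained number that would be stored for it.

```python
import math


def softplus_forward(x, lower=0.0):
    """Map an unconstrained real x into (lower, inf)."""
    return lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y, lower=0.0):
    """Recover the unconstrained value for a constrained y > lower."""
    z = y - lower
    if z <= 0.0:
        raise ValueError("value must be strictly greater than lower")
    # log(exp(z) - 1), rewritten so it does not overflow for large z.
    return z + math.log(-math.expm1(-z))


stored = softplus_inverse(0.5)            # what would be stored for value 0.5
print("stored unconstrained:", stored)
print("read back           :", softplus_forward(stored))
print("after x = -10       :", softplus_forward(-10.0))
```

Under this assumption the stored number for an initial value of 0.5 is negative, and reading it back through the forward map recovers 0.5 up to rounding. The small deviation in the printed initial value above is what a single-precision round trip of this kind would produce.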
Bounding to an interval
SigmoidT(lower, upper) constrains a parameter to an open interval. The transform catalogue
covers the common cases — a few of the most useful:
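| Transform | Domain of .value() |
|---|---|
| SoftplusT(lower) | (lower, ∞) |
| SigmoidT(lower, upper) | (lower, upper) |
| SimplexT | non-negative vector summing to 1 (probabilities) |
| | linear rescale |
| ChainT | compose transforms |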
mix = nn.Param(jnp.array(0.5), t=nn.SigmoidT(lower=0.0, upper=1.0))
for unconstrained in (0.0, 100.0, -100.0):
    mix.val.value = jnp.array(unconstrained)
    print(f'unconstrained {unconstrained:>7} -> value {float(u.get_magnitude(mix.value())):.4f}')
unconstrained     0.0 -> value 0.5000
unconstrained   100.0 -> value 1.0000
unconstrained  -100.0 -> value 0.0000
The midpoint of the unconstrained axis maps to the middle of the interval, and large magnitudes saturate toward the bounds without ever crossing them.
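A minimal sketch of an interval transform, assuming it computes lower + (upper - lower) * sigmoid(x); the helper names `interval_forward` and `interval_grad` are hypothetical. Besides the value, it prints the derivative with respect to the unconstrained input, which shows the cost of saturation.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, evaluated without overflow for large |x|."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def interval_forward(x, lower, upper):
    """Map an unconstrained real x into the interval between lower and upper."""
    return lower + (upper - lower) * sigmoid(x)


def interval_grad(x, lower, upper):
    """Derivative of interval_forward with respect to x."""
    s = sigmoid(x)
    return (upper - lower) * s * (1.0 - s)


for x in (0.0, 5.0, -5.0, 100.0):
    value = interval_forward(x, 0.0, 1.0)
    grad = interval_grad(x, 0.0, 1.0)
    print(f"x={x:>6}  value={value:.6f}  grad={grad:.3e}")
```

The gradient is largest at the midpoint and shrinks rapidly as the value approaches either bound, so a parameter that has saturated moves very slowly afterwards. The interval is open mathematically, but in floating point an extreme input can round to the bound itself, which is why 1.0000 and 0.0000 appear in the output above.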
A SimplexT parameter is handy for a learned categorical distribution: whatever the optimizer
does to the unconstrained values, .value() is always a valid probability vector.
probs = nn.Param(jnp.zeros(3), t=nn.SimplexT())
for unconstrained in ([0.0, 0.0, 0.0], [2.0, -1.0, 0.5]):
    probs.val.value = jnp.array(unconstrained)
    p = u.get_magnitude(probs.value())
    print('probabilities', [round(float(x), 4) for x in p], 'sum', round(float(p.sum()), 6))
probabilities [0.5, 0.25, 0.125, 0.125] sum 1.0
probabilities [0.8808, 0.0321, 0.0542, 0.0329] sum 1.0
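Note that three unconstrained values produce four probabilities. That pattern, and the printed numbers themselves, match a stick-breaking construction: each unconstrained value decides, through a sigmoid, what fraction of the remaining probability mass the next entry takes, and the last entry receives whatever is left. Whether SimplexT is implemented exactly this way is an inference from the output, not something stated here; `stick_breaking` below is an illustrative helper in plain Python.

```python
import math


def sigmoid(x):
    """Logistic sigmoid, evaluated without overflow for large |x|."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def stick_breaking(unconstrained):
    """Map n real numbers to n + 1 non-negative weights that sum to one."""
    remaining = 1.0
    weights = []
    for x in unconstrained:
        piece = remaining * sigmoid(x)   # fraction of what is still unassigned
        weights.append(piece)
        remaining -= piece
    weights.append(remaining)            # last entry takes the leftover mass
    return weights


for raw in ([0.0, 0.0, 0.0], [2.0, -1.0, 0.5]):
    p = stick_breaking(raw)
    print("probabilities", [round(v, 4) for v in p], "sum", round(sum(p), 6))
```

Because only n numbers are free for n + 1 probabilities, the mapping stays one-to-one, which a plain softmax over n + 1 logits would not be.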
Adding a regularization prior
Pass reg= to attach a penalty. .reg_loss() returns the scalar contribution for that
parameter, which you add to the data loss. The built-in choices include L1Reg (sparsity),
L2Reg (weight decay), and ElasticNetReg (a blend).
weights = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L2Reg(weight=0.1))
print('L2 penalty:', float(u.get_magnitude(weights.reg_loss()))) # 0.1 * sum(w**2)
sparse = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L1Reg(weight=0.1))
print('L1 penalty:', float(u.get_magnitude(sparse.reg_loss()))) # 0.1 * sum(|w|)
L2 penalty: 2.5
L1 penalty: 0.699999988079071
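The two penalties can be checked by hand from the formulas in the comments above. The sketch below does so in plain Python; the elastic-net function is an assumed parameterization (a separate weight for each term) chosen for illustration, and the library's ElasticNetReg may expose the blend differently.

```python
def l1_penalty(weights, weight):
    """weight * sum(|w|): encourages exact zeros (sparsity)."""
    return weight * sum(abs(w) for w in weights)


def l2_penalty(weights, weight):
    """weight * sum(w**2): shrinks all weights toward zero (weight decay)."""
    return weight * sum(w * w for w in weights)


def elastic_net_penalty(weights, l1_weight, l2_weight):
    """Assumed blend: the sum of an L1 term and an L2 term."""
    return l1_penalty(weights, l1_weight) + l2_penalty(weights, l2_weight)


w = [3.0, -4.0]
print("L2 penalty:", l2_penalty(w, 0.1))
print("L1 penalty:", l1_penalty(w, 0.1))
print("elastic net:", elastic_net_penalty(w, 0.1, 0.1))
```

The L1 value printed by the library, 0.699999988079071, is simply 0.7 stored in single precision.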
Marking a value constant with Const
Const wraps a value that participates in the forward pass but is never trained. It is
deliberately excluded from the ParamState collection, so optimizers and grad skip it.
class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Param(jnp.ones(3))      # trainable
        self.gain = nn.Const(jnp.array(2.0))     # fixed

    def __call__(self, x):
        return x * self.weight.value() * self.gain.value()

model = Scaler()
trainable = model.states(brainstate.ParamState)
print('trainable parameters:', list(trainable.keys())) # gain is absent
trainable parameters: [('weight', 'val')]
Putting it together in a training step
Constrained parameters and regularization compose with the ordinary
brainstate.transform.grad workflow. The gradient is taken with respect to the unconstrained
ParamStates, so updates can be applied directly: the constraint is reapplied every time
.value() is read, and the penalty enters through the .reg_loss() term added to the loss.
class ConstrainedLinear(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.w = nn.Param(brainstate.random.randn(din, dout) * 0.1, reg=nn.L2Reg(1e-3))
        self.gain = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=0.0))

    def __call__(self, x):
        return (x @ self.w.value()) * self.gain.value()

model = ConstrainedLinear(4, 2)
params = model.states(brainstate.ParamState)
x = brainstate.random.randn(16, 4)
y = brainstate.random.randn(16, 2)
def loss_fn():
    mse = jnp.mean((model(x) - y) ** 2)
    penalty = model.w.reg_loss()
    return mse + u.get_magnitude(penalty)

@brainstate.transform.jit
def train_step():
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    for key in params:
        params[key].value -= 0.1 * grads[key]
    return loss

losses = [float(train_step()) for _ in range(5)]
print('loss trajectory:', [round(v, 4) for v in losses])
print('gain stays positive:', float(u.get_magnitude(model.gain.value())) > 0)
loss trajectory: [1.2443, 1.109, 1.0126, 0.9417, 0.8884]
gain stays positive: True
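What "gradient with respect to the unconstrained parameter" means can be spelled out by hand for a single positive gain. The sketch below is plain Python with a manually derived chain rule rather than brainstate's autodiff, and assumes a softplus constraint; `fit_gain` and the toy data are hypothetical. It fits y ≈ gain * x twice: once with a reachable target (gain 2) and once with an unreachable one (gain -2).

```python
import math


def softplus(theta):
    """Constrained gain for an unconstrained theta; always positive."""
    return max(theta, 0.0) + math.log1p(math.exp(-abs(theta)))


def sigmoid(theta):
    """Derivative of softplus with respect to theta."""
    if theta >= 0.0:
        return 1.0 / (1.0 + math.exp(-theta))
    e = math.exp(theta)
    return e / (1.0 + e)


def fit_gain(xs, ys, theta=0.0, lr=0.05, steps=50):
    """Gradient descent on theta for the loss mean((softplus(theta) * x - y)**2)."""
    n = len(xs)
    losses = []
    for _ in range(steps):
        gain = softplus(theta)
        residuals = [gain * x - y for x, y in zip(xs, ys)]
        losses.append(sum(r * r for r in residuals) / n)
        dloss_dgain = sum(2.0 * r * x for r, x in zip(residuals, xs)) / n
        # Chain rule: d loss / d theta = d loss / d gain * d gain / d theta.
        theta -= lr * dloss_dgain * sigmoid(theta)
    return softplus(theta), losses


xs = [1.0, 2.0, 3.0]
gain_ok, losses_ok = fit_gain(xs, [2.0, 4.0, 6.0])       # best gain is +2
gain_neg, losses_neg = fit_gain(xs, [-2.0, -4.0, -6.0])  # best gain would be -2

print("reachable target  : gain", round(gain_ok, 4))
print("unreachable target: gain", gain_neg, "still positive:", gain_neg > 0.0)
```

In the second run the optimizer keeps pushing the unconstrained value downward, yet the gain read through the transform only approaches zero from above and never changes sign.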
Summary
- nn.Param(value, t=..., reg=...) extends ParamState with a constraint transform and a regularization prior.
- .value() returns the constrained value (apply this in the forward pass); .val is the underlying ParamState the optimizer updates in unconstrained space.
- Transforms (SoftplusT, SigmoidT, SimplexT, ChainT, …) keep a parameter in its valid domain through a smooth bijection, so gradients flow.
- .reg_loss() returns the penalty for a regularized parameter (L1Reg, L2Reg, ElasticNetReg, …); add it to the data loss.
- nn.Const marks a value as non-trainable: it is excluded from the ParamState collection.
See also
Observe and intercept state access with hooks — an imperative alternative for enforcing invariants on writes.
Training and metrics — the full optimization loop these parameters slot into.