Constrain and Regularize Parameters

A bare brainstate.ParamState holds an unconstrained array. Many models need more: a rate that must stay positive, a mixing weight bounded to [0, 1], a probability vector that sums to one, or an L2 penalty pulling weights toward zero. brainstate.nn.Param adds two declarative facilities on top of ParamState:

  • a constraint transform (t=) — the optimizer updates the parameter in an unconstrained space, while .value() returns it mapped into the valid domain. Gradients flow cleanly because the mapping is a smooth bijection, not a hard clip.

  • a regularization prior (reg=) — .reg_loss() returns a penalty you add to the loss.
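
To see why a smooth mapping matters for gradients, the sketch below compares a hard clip with softplus at an unconstrained value that has drifted below the bound. It uses plain Python and central finite differences, not brainstate. The helper names are hypothetical, and softplus is the textbook log(1 + exp(x)), which we assume matches SoftplusT with lower=0.

```python
import math

def hard_clip(x, lower=0.0):
    # Projection onto [lower, inf): completely flat below the bound
    return max(x, lower)

def softplus(x, lower=0.0):
    # Numerically stable log(1 + exp(x)), shifted onto (lower, inf)
    return lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def numerical_grad(f, x, eps=1e-6):
    # Central finite-difference estimate of df/dx
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

raw = -3.0  # an unconstrained value below the bound
print('clip gradient    :', numerical_grad(hard_clip, raw))
print('softplus gradient:', numerical_grad(softplus, raw))
```

The clip's derivative is exactly zero below the bound, so an optimizer receives no signal to move the value back. Softplus still passes a small, nonzero gradient, since its derivative is the logistic sigmoid.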

brainstate.nn.Const is the companion for values that should not be trained.

import jax.numpy as jnp
import brainunit as u

import brainstate
from brainstate import nn

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

.value() returns a brainunit.Quantity (dimensionless in these examples) that behaves like a JAX array in arithmetic. u.get_magnitude(...) extracts the plain array when we want to print or compare it.

A positive-only parameter

SoftplusT(lower) maps the whole real line onto (lower, ∞). We store the parameter unconstrained and read it back through the transform, so it is positive no matter what value the optimizer lands on — even a large negative one.

rate = nn.Param(jnp.array(0.5), t=nn.SoftplusT(lower=0.0))
print('initial value :', float(u.get_magnitude(rate.value())))

# Simulate an aggressive optimizer step that drives the *unconstrained* value negative.
rate.val.value = jnp.array(-10.0)
print('after a large negative update:', float(u.get_magnitude(rate.value())))
initial value : 0.4999999701976776
after a large negative update: 4.5398901420412585e-05

rate.val is the underlying ParamState the optimizer mutates; rate.value() is the constrained view the model should use in its forward pass. The constrained value stays just above the lower bound instead of going negative.
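
The 0.49999997 printed above suggests that Param stores the inverse-transformed initial value and re-applies the transform on every read, with float32 round-off accounting for the small error. The plain-Python sketch below models that round trip under this assumption; softplus and inv_softplus are illustrative helpers, not brainstate functions.

```python
import math

def softplus(x, lower=0.0):
    # Numerically stable log(1 + exp(x)), shifted onto (lower, inf)
    return lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def inv_softplus(y, lower=0.0):
    # Inverse on (lower, inf): x = log(exp(y - lower) - 1)
    return math.log(math.expm1(y - lower))

stored = inv_softplus(0.5)   # unconstrained value the optimizer would update
print(stored)                # about -0.4328
print(softplus(stored))      # back to about 0.5
print(softplus(-10.0))       # about 4.54e-05, still above the lower bound
```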

Bounding to an interval

SigmoidT(lower, upper) constrains a parameter to an open interval. The transform catalogue covers the common cases — a few of the most useful:

| Transform | Domain of .value() / use |
|---|---|
| SoftplusT(lower) / ExpT(lower) | (lower, ∞): positive quantities |
| SigmoidT(lower, upper) | (lower, upper): bounded scalars |
| SimplexT() | non-negative vector summing to 1: probabilities |
| AffineT(scale, shift) | linear rescale |
| ChainT(t1, t2, ...) | composition of transforms |
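
One way to read ChainT is as ordinary function composition. The sketch below builds a bounded map onto (1, 5) by chaining a logistic sigmoid with an affine rescale, in plain Python. chain and affine are hypothetical helpers, and the left-to-right application order is an assumption rather than documented ChainT behavior.

```python
import math

def sigmoid(x):
    # Logistic function: real line -> (0, 1)
    return 1.0 / (1.0 + math.exp(-x))

def affine(scale, shift):
    # Linear rescale, similar in spirit to AffineT(scale, shift)
    return lambda x: scale * x + shift

def chain(*fns):
    # Compose functions, applying them left to right (order is an assumption)
    def composed(x):
        for f in fns:
            x = f(x)
        return x
    return composed

bounded = chain(sigmoid, affine(scale=4.0, shift=1.0))  # real line -> (1, 5)
for z in (-5.0, 0.0, 5.0):
    print(z, round(bounded(z), 4))
```

Under the assumption that SigmoidT computes lower + (upper - lower) * sigmoid(z), this chain behaves like SigmoidT(lower=1.0, upper=5.0).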

mix = nn.Param(jnp.array(0.5), t=nn.SigmoidT(lower=0.0, upper=1.0))

for unconstrained in (0.0, 100.0, -100.0):
    mix.val.value = jnp.array(unconstrained)
    print(f'unconstrained {unconstrained:>7} -> value {float(u.get_magnitude(mix.value())):.4f}')
unconstrained     0.0 -> value 0.5000
unconstrained   100.0 -> value 1.0000
unconstrained  -100.0 -> value 0.0000

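An unconstrained value of 0 maps to the middle of the interval, and large magnitudes saturate toward the bounds. Mathematically the bounds are never reached: the 1.0000 and 0.0000 above are rounded to four decimals, and in floating point an extreme unconstrained value can round onto a bound exactly.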
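
Both the saturation and the rounding are easy to reproduce. The plain-Python sketch below assumes SigmoidT computes lower + (upper - lower) * sigmoid(z); sigmoid_t and sigmoid_t_inverse are illustrative names. In float64 an input of 100 already rounds onto the upper bound, while -100 stays a tiny positive number.

```python
import math

def sigmoid_t(z, lower, upper):
    # Map the real line onto the open interval (lower, upper)
    return lower + (upper - lower) / (1.0 + math.exp(-z))

def sigmoid_t_inverse(y, lower, upper):
    # Logit of the relative position of y inside (lower, upper)
    p = (y - lower) / (upper - lower)
    return math.log(p / (1.0 - p))

print(sigmoid_t(0.0, 0.0, 1.0))     # midpoint of the interval
print(sigmoid_t(100.0, 0.0, 1.0))   # rounds onto the upper bound in float64
print(sigmoid_t(-100.0, 0.0, 1.0))  # about 3.7e-44, still above the lower bound
```

One practical consequence in this sketch: a value exactly on a bound has no finite unconstrained preimage, and sigmoid_t_inverse raises an exception at either bound, so bounded parameters are best initialized strictly inside the interval.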

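A SimplexT parameter is handy for a learned categorical distribution: whatever the optimizer does to the unconstrained values, .value() is always a valid probability vector. Note the shapes in the output below. Three unconstrained values produce four probabilities, because a K-element probability vector has only K-1 degrees of freedom. Also, an all-zero unconstrained vector does not map to the uniform distribution.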

probs = nn.Param(jnp.zeros(3), t=nn.SimplexT())

for unconstrained in ([0.0, 0.0, 0.0], [2.0, -1.0, 0.5]):
    probs.val.value = jnp.array(unconstrained)
    p = u.get_magnitude(probs.value())
    print('probabilities', [round(float(x), 4) for x in p], 'sum', round(float(p.sum()), 6))
probabilities [0.5, 0.25, 0.125, 0.125] sum 1.0
probabilities [0.8808, 0.0321, 0.0542, 0.0329] sum 1.0
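
The printed probabilities are reproduced, to four decimals, by a plain stick-breaking construction. Each unconstrained entry, passed through a logistic sigmoid, decides what fraction of the remaining probability mass to break off, and the last category takes whatever is left. The sketch below is a plain-Python model consistent with that output, not SimplexT's source; stick_breaking is a hypothetical helper.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def stick_breaking(z):
    # Map K-1 unconstrained reals to K non-negative weights that sum to 1
    probs, remaining = [], 1.0
    for zi in z:
        piece = remaining * sigmoid(zi)  # break off a fraction of what is left
        probs.append(piece)
        remaining -= piece
    probs.append(remaining)              # the last category takes the remainder
    return probs

print([round(p, 4) for p in stick_breaking([0.0, 0.0, 0.0])])
print([round(p, 4) for p in stick_breaking([2.0, -1.0, 0.5])])
```

Because sigmoid(0) = 0.5, an all-zero input halves the remaining mass at every step, which is why it yields [0.5, 0.25, 0.125, 0.125] rather than a uniform vector.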

Adding a regularization prior

Pass reg= to attach a penalty. .reg_loss() returns the scalar contribution for that parameter, which you add to the data loss. The built-in choices include L1Reg (sparsity), L2Reg (weight decay), and ElasticNetReg (a blend).

weights = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L2Reg(weight=0.1))
print('L2 penalty:', float(u.get_magnitude(weights.reg_loss())))   # 0.1 * sum(w**2)

sparse = nn.Param(jnp.array([3.0, -4.0]), reg=nn.L1Reg(weight=0.1))
print('L1 penalty:', float(u.get_magnitude(sparse.reg_loss())))    # 0.1 * sum(|w|)
L2 penalty: 2.5
L1 penalty: 0.699999988079071
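
The printed numbers follow the formulas in the code comments above: weight * sum(w**2) for L2 and weight * sum(|w|) for L1 (0.699999988 is simply 0.7 in float32). The sketch below recomputes them in plain Python. The elastic-net blend uses a hypothetical l1_ratio mixing parameter; ElasticNetReg's actual arguments may differ, so check its signature before relying on this form.

```python
def l2_penalty(w, weight):
    # weight * sum(w**2), as in the L2Reg example
    return weight * sum(x * x for x in w)

def l1_penalty(w, weight):
    # weight * sum(|w|), as in the L1Reg example
    return weight * sum(abs(x) for x in w)

def elastic_net_penalty(w, weight, l1_ratio):
    # Hypothetical blend: l1_ratio of the L1 term, the rest as the L2 term
    l1 = sum(abs(x) for x in w)
    l2 = sum(x * x for x in w)
    return weight * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)

w = [3.0, -4.0]
print(l2_penalty(w, 0.1))                # about 2.5
print(l1_penalty(w, 0.1))                # about 0.7
print(elastic_net_penalty(w, 0.1, 0.5))  # about 1.6
```

The L2 term contributes 2 * weight * w to the gradient, which shrinks every weight in proportion to its size; that is why it behaves as weight decay.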

Marking a value constant with Const

Const wraps a value that participates in the forward pass but is never trained. It is deliberately not a ParamState, so model.states(brainstate.ParamState) does not return it, and any optimizer or grad built from that collection leaves it untouched.

class Scaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Param(jnp.ones(3))      # trainable
        self.gain = nn.Const(jnp.array(2.0))     # fixed

    def __call__(self, x):
        return x * self.weight.value() * self.gain.value()

model = Scaler()
trainable = model.states(brainstate.ParamState)
print('trainable parameters:', list(trainable.keys()))   # gain is absent
trainable parameters: [('weight', 'val')]
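
The mechanism is type-based filtering. The toy below models it in plain Python: ParamState, Const and collect are hypothetical stand-ins written for illustration, not brainstate's classes, and collect only roughly mirrors what model.states(kind) does.

```python
class ParamState:
    # Hypothetical stand-in for a trainable state
    def __init__(self, value):
        self.value = value

class Const:
    # Hypothetical stand-in: holds a value but is not a ParamState
    def __init__(self, value):
        self.value = value

def collect(attrs, kind):
    # Keep only attributes of the requested type
    return {name: obj for name, obj in attrs.items() if isinstance(obj, kind)}

attrs = {'weight': ParamState([1.0, 1.0, 1.0]), 'gain': Const(2.0)}
trainable = collect(attrs, ParamState)
print(list(trainable))

# A plain SGD step only sees what was collected, so 'gain' is never modified
grads = {'weight': [0.5, 0.5, 0.5]}
for name, state in trainable.items():
    state.value = [v - 0.1 * g for v, g in zip(state.value, grads[name])]
print(attrs['weight'].value, attrs['gain'].value)
```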

Putting it together in a training step

Constrained parameters and regularization compose with the ordinary brainstate.transform.grad workflow. The gradient is taken with respect to the unconstrained ParamStates, so updates can be applied to them directly. The transforms keep each .value() in its domain automatically, whereas a penalty only takes effect once its reg_loss() is added to the loss, as loss_fn does below.

class ConstrainedLinear(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.w = nn.Param(brainstate.random.randn(din, dout) * 0.1, reg=nn.L2Reg(1e-3))
        self.gain = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=0.0))

    def __call__(self, x):
        return (x @ self.w.value()) * self.gain.value()

model = ConstrainedLinear(4, 2)
params = model.states(brainstate.ParamState)
x = brainstate.random.randn(16, 4)
y = brainstate.random.randn(16, 2)

def loss_fn():
    mse = jnp.mean((model(x) - y) ** 2)
    penalty = model.w.reg_loss()
    return mse + u.get_magnitude(penalty)

@brainstate.transform.jit
def train_step():
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    for key in params:
        params[key].value -= 0.1 * grads[key]
    return loss

losses = [float(train_step()) for _ in range(5)]
print('loss trajectory:', [round(v, 4) for v in losses])
print('gain stays positive:', float(u.get_magnitude(model.gain.value())) > 0)
loss trajectory: [1.2443, 1.109, 1.0126, 0.9417, 0.8884]
gain stays positive: True
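
Because the update acts on the unconstrained value, the chain rule supplies the transform's derivative. The scalar toy below is plain Python with hand-written derivatives, no brainstate. It fits gain * x to a target that would need a negative gain, and gradient descent on the unconstrained u drives softplus(u) toward zero without ever crossing it. All names are illustrative.

```python
import math

def softplus(u):
    # Stable log(1 + exp(u)): real line -> (0, inf)
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def sigmoid(u):
    # Derivative of softplus
    return 1.0 / (1.0 + math.exp(-u))

x, y = 2.0, -1.0               # an exact fit would need gain = -0.5
u = math.log(math.expm1(1.0))  # unconstrained value with softplus(u) about 1.0

for _ in range(200):
    gain = softplus(u)
    residual = gain * x - y
    dloss_dgain = 2.0 * residual * x     # derivative of residual ** 2
    dloss_du = dloss_dgain * sigmoid(u)  # chain rule through the transform
    u -= 0.1 * dloss_du                  # update the unconstrained value directly

print(u, softplus(u))
```

Near the bound the gradient shrinks along with sigmoid(u), so progress slows instead of overshooting into the invalid region. That is the trade-off of a smooth reparameterization compared with clipping.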

Summary

  • nn.Param(value, t=..., reg=...) extends ParamState with a constraint transform and a regularization prior.

  • .value() returns the constrained value (apply this in the forward pass); .val is the underlying ParamState the optimizer updates in unconstrained space.

  • Transforms (SoftplusT, SigmoidT, SimplexT, ChainT, …) keep a parameter in its valid domain through a smooth bijection, so gradients flow.

  • .reg_loss() returns the penalty for a regularized parameter (L1Reg, L2Reg, ElasticNetReg, …); add it to the data loss.

  • nn.Const marks a value as non-trainable — it is excluded from the ParamState collection.

See also