Parameters, Transforms, and Regularization#
A ParamState is a bare trainable array. That is all most layers need. But some parameters
carry constraints — a time constant must be positive, a mixing weight must lie in [0, 1], a
probability vector must sum to one — and some training objectives want a penalty on the
parameters themselves.
brainstate.nn.Param is a richer container that adds two orthogonal capabilities on top of a
ParamState:
a bijective transform that maps an unconstrained array (what the optimizer sees) to a constrained value (what the model uses);
a regularization term that contributes a penalty to the loss.
This tutorial covers Param, its fixed counterpart Const, the transform catalog, and the
regularization catalog, then ties them together in a single trained model.
import jax.numpy as jnp
import brainstate
import braintools
import brainstate.nn as nn
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
From ParamState to Param#
Construct a Param from an initial value. Two attributes matter:
param.value() returns the constrained value, the number your model should use. It is a method because it may run a transform each time it is read.
param.val is the underlying ParamState holding the unconstrained array that the optimizer updates.
With the default IdentityT transform the two coincide.
w = nn.Param(jnp.array([0.5, 1.0, 2.0]))
print('value() :', w.value())
print('val :', w.val)
print('trainable:', w.fit)
value() : [0.5 1. 2. ]
val : ParamState(
value=ShapedArray(float32[3])
)
trainable: True
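The split between stored and exposed values can be illustrated without the library. The following is only a sketch: TinyParam is a hypothetical stand-in written for this example, not the real nn.Param, and it assumes that the initial value is supplied in constrained space and inverted for storage.

```python
import math


class TinyParam:
    """Hypothetical stand-in for a Param (illustration only, not the library class)."""

    def __init__(self, constrained, forward=lambda x: x, inverse=lambda y: y):
        self.forward = forward            # unconstrained -> constrained
        self.inverse = inverse            # constrained -> unconstrained
        # Assumption: the initial value is given in constrained space.
        self.val = inverse(constrained)   # what an optimizer would update

    def value(self):
        # Recomputed on every read, which is why this is a method.
        return self.forward(self.val)


plain = TinyParam(2.0)                                          # identity transform
positive = TinyParam(2.0, forward=math.exp, inverse=math.log)   # exp keeps value() > 0

print(plain.val, plain.value())        # stored and exposed values coincide
print(positive.val, positive.value())  # stored log(2), exposed approximately 2

positive.val = -5.0                    # any stored number is legal
print(positive.value())                # the exposed value is still positive
```

With the identity map the two views agree, which matches the IdentityT default described above; with any other map they differ, and only value() is meaningful to the model.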
Fixed values with Const#
Const is a Param that is never trained (fit=False). It is not collected when you gather
ParamStates, so optimizers and grad leave it alone — ideal for buffers, lookup tables, or
hyperparameters you want to keep inside the module tree.
class Mix(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Param(jnp.array([1.0, -1.0]))  # trainable
        self.scale = nn.Const(jnp.array(10.0))          # fixed

    def __call__(self, x):
        return self.scale.value() * (x @ self.weight.value())
m = Mix()
print('trainable ParamStates:', list(m.states(brainstate.ParamState).keys()))
print('weight.fit =', m.weight.fit, '| scale.fit =', m.scale.fit)
trainable ParamStates: [('weight', 'val')]
weight.fit = True | scale.fit = False
Only weight appears among the trainable states — scale is held constant.
Constrained parameters: Transforms#
A transform is a bijector: an invertible map between an unconstrained space (all of ℝ, where
gradient descent is well behaved) and a constrained space (positives, an interval, a simplex).
The optimizer updates the unconstrained array; value() applies the forward map to produce the
constrained value; set_value() applies the inverse to store a constrained value back.
SoftplusT(lower=L) maps ℝ to (L, ∞), so the parameter is guaranteed to stay above L no
matter what the optimizer does.
t = nn.SoftplusT(lower=0.0)
raw = jnp.array([-2.0, 0.0, 3.0])
constrained = t.forward(raw)
print('forward (R -> positive):', constrained)
print('inverse (round-trip) :', t.inverse(constrained))
forward (R -> positive): [0.126928 0.69314718 3.04858732]
inverse (round-trip) : [-1.9999996 0. 3. ]
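The numbers printed above can be reproduced by hand. This sketch assumes SoftplusT(lower=L) computes L + log(1 + exp(x)), which is consistent with the printed output but is not confirmed by the text; the function names are illustrative, not library API.

```python
import math


def softplus_forward(x, lower=0.0):
    # lower + log(1 + exp(x)), written in the overflow-safe form
    # max(x, 0) + log1p(exp(-|x|)).
    return lower + max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inverse(y, lower=0.0):
    # Solve y = lower + log(1 + exp(x)) for x:
    # x = log(exp(z) - 1) = z + log(1 - exp(-z)), with z = y - lower > 0.
    z = y - lower
    return z + math.log(-math.expm1(-z))


raw = [-2.0, 0.0, 3.0]
constrained = [softplus_forward(x) for x in raw]
recovered = [softplus_inverse(y) for y in constrained]

print(constrained)
print(recovered)
```

The small round-trip error in the tutorial output (-1.9999996 instead of -2) is ordinary float32 rounding; in double precision the inverse recovers the inputs much more closely.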
The catalog of built-in transforms covers the constraints that arise in practice:
| Transform | Constrained space |
|---|---|
|  | ℝ (no constraint) |
|  | strictly greater than a lower bound |
|  | a bounded interval |
|  | a symmetric bounded range |
|  | non-negative entries summing to one |
|  | unit L2 norm |
|  | monotonically increasing entries |
|  | reparameterisations |
Compose several with ChainT; the transforms apply in the order given, so in the example below AffineT runs before SoftplusT.
chained = nn.ChainT(nn.AffineT(scale=2.0, shift=1.0), nn.SoftplusT(lower=0.0))
chained.forward(jnp.array(0.0))
Quantity(1.3132616)
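The composition rule can be checked with plain functions. In this sketch, chain is a hypothetical helper (not ChainT itself) that applies forward maps in the listed order and inverses in reverse order; the affine and softplus formulas are the standard ones and are assumed to match the library's.

```python
import math


def chain(*steps):
    """Compose (forward, inverse) pairs; hypothetical helper for illustration."""
    def forward(x):
        for fwd, _ in steps:            # first listed step runs first
            x = fwd(x)
        return x

    def inverse(y):
        for _, inv in reversed(steps):  # undo the last step first
            y = inv(y)
        return y

    return forward, inverse


affine = (lambda x: 2.0 * x + 1.0, lambda y: (y - 1.0) / 2.0)  # scale=2, shift=1
softplus = (lambda x: max(x, 0.0) + math.log1p(math.exp(-abs(x))),
            lambda y: math.log(math.expm1(y)))

forward, inverse = chain(affine, softplus)
swapped_forward, _ = chain(softplus, affine)

print(forward(0.0))           # softplus(2*0 + 1) = softplus(1)
print(inverse(forward(0.0)))  # round-trips back to (approximately) 0
print(swapped_forward(0.0))   # 2*softplus(0) + 1, a different number
```

Swapping the two steps changes the result, so the printed 1.3132616 identifies the order: the affine map is applied first, then softplus.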
Attach a transform when constructing a Param. Here a strictly positive time constant is
learned: the stored value roams over ℝ, but value() is always positive.
tau = nn.Param(jnp.array(2.0), t=nn.SoftplusT(lower=0.1))
params = {'tau': tau.val}
opt = braintools.optim.Adam(lr=1e-1)
opt.register_trainable_weights(params)
@brainstate.transform.jit
def step(target):
    def loss_fn():
        return (tau.value() - target) ** 2
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    opt.update(grads)
    return loss

for _ in range(100):
    step(0.05)  # push tau toward 0.05, below the floor of 0.1
print('constrained tau:', float(tau.value()), '(never drops below the 0.1 floor)')
constrained tau: 0.2543546259403229 (never drops below the 0.1 floor)
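Why the floor holds can be seen by writing the update out by hand. The sketch below uses plain gradient descent rather than Adam, so its final number will not match the 0.2543 printed above; it also assumes the constrained value is lower + softplus(raw). The point is only that the optimizer moves the raw value, and every raw value maps above the floor.

```python
import math

LOWER = 0.1    # floor, matching SoftplusT(lower=0.1) in the text
TARGET = 0.05  # deliberately below the floor
LR = 0.1       # plain gradient-descent step size (the tutorial uses Adam)


def softplus(x):
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # Derivative of softplus.
    return 1.0 / (1.0 + math.exp(-x))


def constrained(raw):
    return LOWER + softplus(raw)


# Start at tau = 2.0 by inverting the transform: raw = log(exp(2.0 - LOWER) - 1).
raw = math.log(math.expm1(2.0 - LOWER))
history = []
for _ in range(100):
    tau = constrained(raw)
    # Chain rule: d/d(raw) (tau - TARGET)^2 = 2 (tau - TARGET) * sigmoid(raw)
    grad = 2.0 * (tau - TARGET) * sigmoid(raw)
    raw -= LR * grad          # the update acts on the unconstrained value only
    history.append(constrained(raw))

tau_final = constrained(raw)
print(f"raw = {raw:.4f}, tau = {tau_final:.4f}")
```

The gradient shrinks as the raw value becomes more negative, because sigmoid(raw) tends to zero there. The constrained value therefore approaches the floor ever more slowly instead of crossing it.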
Penalising parameters: Regularization#
A regularization object contributes a scalar penalty derived from a parameter’s value. Read it
directly with reg.loss(value), or — once attached to a Param — with param.reg_loss(),
which applies the regularizer to that parameter’s current value.
weights = jnp.array([3.0, -4.0])
print('L1 penalty :', float(nn.L1Reg(0.1).loss(weights))) # 0.1 * (|3| + |-4|)
print('L2 penalty :', float(nn.L2Reg(0.1).loss(weights))) # 0.1 * (3^2 + 4^2)
p = nn.Param(weights, reg=nn.ElasticNetReg(l1_weight=1.0, l2_weight=1.0, alpha=0.5))
print('elastic-net via reg_loss():', float(p.reg_loss()))
L1 penalty : 0.699999988079071
L2 penalty : 2.5
elastic-net via reg_loss(): 16.0
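The three printed penalties can be checked by hand. The L1 and L2 formulas follow the comments in the code above. The elastic-net mix used here, alpha * L1 + (1 - alpha) * L2, is an assumption: at alpha=0.5 with unit weights, other common conventions give the same 16.0, so the output does not reveal which one the library uses.

```python
import math

weights = [3.0, -4.0]


def l1_penalty(w, weight):
    # weight * sum of absolute values
    return weight * sum(abs(x) for x in w)


def l2_penalty(w, weight):
    # weight * sum of squares
    return weight * sum(x * x for x in w)


def elastic_net(w, alpha, l1_weight=1.0, l2_weight=1.0):
    # Assumed convention: convex mix of the two penalties.
    return alpha * l1_penalty(w, l1_weight) + (1.0 - alpha) * l2_penalty(w, l2_weight)


l1 = l1_penalty(weights, 0.1)      # 0.1 * (3 + 4)
l2 = l2_penalty(weights, 0.1)      # 0.1 * (9 + 16)
en = elastic_net(weights, alpha=0.5)

print(l1, l2, en)
```

The tutorial prints 0.699999988079071 for the L1 penalty because the library computes in float32, where that is the nearest representable value to 0.7.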
The catalog spans classical penalties and Bayesian priors (a prior contributes its negative log-density as the penalty):
| Family | Members |
|---|---|
| Sparsity / shrinkage |  |
| Structure |  |
| Bayesian priors |  |
Combine penalties with ChainedReg. Parameters carrying a prior can be re-drawn from it with
param.reset_to_prior().
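The link between priors and classical penalties can be made concrete with a zero-mean Gaussian. This sketch is not library code: gaussian_nll is a hypothetical helper, and whether the library's prior classes keep or drop the normalising constant is not stated in the text. It shows that the negative log-density of N(0, sigma^2) equals an L2 penalty with weight 1 / (2 sigma^2) plus a constant that does not depend on the parameters.

```python
import math


def gaussian_nll(w, sigma):
    # Negative log-density of independent N(0, sigma^2) entries:
    # sum_i [ x_i^2 / (2 sigma^2) + log(sigma) + 0.5 * log(2 pi) ]
    const = math.log(sigma) + 0.5 * math.log(2.0 * math.pi)
    return sum(0.5 * (x / sigma) ** 2 + const for x in w)


def l2_penalty(w, weight):
    return weight * sum(x * x for x in w)


weights = [3.0, -4.0]
sigma = 2.0

# Subtracting the value at zero removes the parameter-independent constant.
nll_shifted = gaussian_nll(weights, sigma) - gaussian_nll([0.0, 0.0], sigma)
l2_equiv = l2_penalty(weights, 1.0 / (2.0 * sigma ** 2))

print(nll_shifted, l2_equiv)
```

Because the constant has zero gradient, minimising either expression moves the parameters identically; a tighter prior (smaller sigma) corresponds to a larger L2 weight.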
A constrained, regularized model end-to-end#
This linear model uses an L2-penalised weight vector and a strictly positive output scale. The
training objective is the data loss plus the summed regularization penalties, collected by
walking the module tree with model.nodes(nn.Param).
class RegLinear(nn.Module):
    def __init__(self, din):
        super().__init__()
        self.w = nn.Param(brainstate.random.randn(din) * 0.1, reg=nn.L2Reg(1e-2))
        self.b = nn.Param(jnp.zeros(()))
        self.scale = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=1e-3))

    def __call__(self, x):
        return self.scale.value() * (x @ self.w.value() + self.b.value())

    def reg_penalty(self):
        return sum(p.reg_loss() for p in self.nodes(nn.Param).values())
model = RegLinear(4)
x = brainstate.random.randn(128, 4)
y = x @ jnp.array([1.0, -2.0, 0.5, 3.0]) + 0.3
params = model.states(brainstate.ParamState)
opt = braintools.optim.Adam(lr=5e-2)
opt.register_trainable_weights(params)
@brainstate.transform.jit
def train_step():
    def loss_fn():
        data_loss = jnp.mean((model(x) - y) ** 2)
        return data_loss + model.reg_penalty()
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    opt.update(grads)
    return loss

for epoch in range(300):
    loss = train_step()
print(f'final loss : {float(loss):.4f}')
print(f'scale > 0 : {float(model.scale.value()):.4f}')
print(f'reg penalty: {float(model.reg_penalty()):.6f}')
final loss : 0.0378
scale > 0 : 1.9410
reg penalty: 0.037739
Summary#
Param wraps a ParamState with an optional transform and regularizer; read the usable value with value() and the trainable array with val.
Const is a non-trainable Param, excluded from ParamState collection.
Transforms are bijectors that keep a parameter inside its constrained space while the optimizer works in unconstrained ℝ. Compose them with ChainT.
Regularization adds a penalty via reg.loss(value) or param.reg_loss(); sum penalties across a model by iterating model.nodes(nn.Param).
See also#
Training and metrics: the optimizer and loss machinery used here.
Constrain and regularize parameters: a focused how-to recipe.
The parameter model: the design rationale behind Param.