Parameters, Transforms, and Regularization

A ParamState is a bare trainable array. That is all most layers need. But some parameters carry constraints — a time constant must be positive, a mixing weight must lie in [0, 1], a probability vector must sum to one — and some training objectives want a penalty on the parameters themselves.

brainstate.nn.Param is a richer container that adds two orthogonal capabilities on top of a ParamState:

  • a bijective transform that maps an unconstrained array (what the optimizer sees) to a constrained value (what the model uses);

  • a regularization term that contributes a penalty to the loss.

This tutorial covers Param, its fixed counterpart Const, the transform catalog, and the regularization catalog, then ties them together in a single trained model.

import jax.numpy as jnp

import brainstate
import braintools
import brainstate.nn as nn

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

From ParamState to Param

Construct a Param from an initial value. Two members matter, one a method and one an attribute:

  • param.value() returns the constrained value — the number your model should use. It is a method, because it may run a transform each time it is read.

  • param.val is the underlying ParamState holding the unconstrained array that the optimizer updates.

With the default IdentityT transform the two coincide.

w = nn.Param(jnp.array([0.5, 1.0, 2.0]))
print('value() :', w.value())
print('val     :', w.val)
print('trainable:', w.fit)
value() : [0.5 1.  2. ]
val     : ParamState(
  value=ShapedArray(float32[3])
)
trainable: True
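
To illustrate why value() and val are kept separate, here is a minimal pure-Python sketch of the idea. It is not brainstate's implementation: the class TinyParam and its constructor arguments are hypothetical, and it handles a single float rather than an array. It stores an unconstrained number and applies a forward map on every read.

```python
import math


class TinyParam:
    """Hypothetical stand-in for the Param idea: store raw, expose transformed."""

    def __init__(self, init, forward=lambda x: x, inverse=lambda y: y):
        self.forward = forward
        self.inverse = inverse
        # Store the unconstrained number that an optimizer would update.
        self.val = inverse(init)

    def value(self):
        # Apply the forward map on every read, which is why this is a method.
        return self.forward(self.val)


# Identity maps: the stored and the usable value coincide.
plain = TinyParam(2.0)
print(plain.val, plain.value())

# Exponential map: the stored value is log(2), the usable value is 2.
positive = TinyParam(2.0, forward=math.exp, inverse=math.log)
print(positive.val, positive.value())
```

With identity maps the two readings agree, matching the IdentityT case above; with any other map they differ, and only value() is meaningful to the model.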

Fixed values with Const

Const is a Param that is never trained (fit=False). It is not collected when you gather ParamStates, so optimizers and grad leave it alone — ideal for buffers, lookup tables, or hyperparameters you want to keep inside the module tree.

class Mix(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Param(jnp.array([1.0, -1.0]))   # trainable
        self.scale = nn.Const(jnp.array(10.0))           # fixed

    def __call__(self, x):
        return self.scale.value() * (x @ self.weight.value())

m = Mix()
print('trainable ParamStates:', list(m.states(brainstate.ParamState).keys()))
print('weight.fit =', m.weight.fit, '| scale.fit =', m.scale.fit)
trainable ParamStates: [('weight', 'val')]
weight.fit = True | scale.fit = False

Only weight appears among the trainable states — scale is held constant.

Constrained parameters: Transforms

A transform is a bijector: an invertible map between an unconstrained space (all of ℝ, where gradient descent is well behaved) and a constrained space (positives, an interval, a simplex). The optimizer updates the unconstrained array; value() applies the forward map to produce the constrained value; set_value() applies the inverse to store a constrained value back.

SoftplusT(lower=L) maps ℝ to (L, ∞), so the parameter is guaranteed to stay above L no matter what the optimizer does.

t = nn.SoftplusT(lower=0.0)
raw = jnp.array([-2.0, 0.0, 3.0])
constrained = t.forward(raw)
print('forward (R -> positive):', constrained)
print('inverse (round-trip)   :', t.inverse(constrained))
forward (R -> positive): [0.126928  0.69314718 3.04858732]
inverse (round-trip)   : [-1.9999996  0.         3.       ]
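
The numbers printed above are consistent with the textbook softplus map, forward(x) = lower + log(1 + exp(x)) and inverse(y) = log(exp(y - lower) - 1). The sketch below reproduces them with the standard library only. The function names are hypothetical, and the assumption that SoftplusT uses exactly this formula is inferred from the printed values, not from the library source.

```python
import math


def softplus_forward(x, lower=0.0):
    # Map any real x to a value strictly above `lower`.
    return lower + math.log1p(math.exp(x))


def softplus_inverse(y, lower=0.0):
    # Undo the forward map; only defined for y > lower.
    return math.log(math.expm1(y - lower))


raw = [-2.0, 0.0, 3.0]
constrained = [softplus_forward(x) for x in raw]
recovered = [softplus_inverse(y) for y in constrained]
print(constrained)
print(recovered)
```

This naive form overflows in math.exp for very large inputs; a production implementation would typically switch to a numerically stable variant.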

The catalog of built-in transforms covers the constraints that arise in practice:

Transform                                     Constrained space
--------------------------------------------  -----------------------------------
IdentityT                                     ℝ (no constraint)
SoftplusT(lower), ExpT, PositiveT             strictly greater than a lower bound
SigmoidT(lower, upper), ClipT(lower, upper)   a bounded interval
TanhT, SoftsignT, ScaledSigmoidT              a symmetric bounded range
SimplexT                                      non-negative entries summing to one
UnitVectorT                                   unit L2 norm
OrderedT                                      monotonically increasing entries
AffineT(scale, shift), PowerT, MaskedT        reparameterisations
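
As an illustration of the bounded-interval row, a common way to map ℝ onto an open interval is a rescaled logistic sigmoid, with the logit as its inverse. The sketch below assumes that form; the function names are hypothetical, and whether SigmoidT is implemented this way is not stated in this tutorial.

```python
import math


def sigmoid_forward(x, lower, upper):
    # Squash any real x into the open interval (lower, upper).
    return lower + (upper - lower) / (1.0 + math.exp(-x))


def sigmoid_inverse(y, lower, upper):
    # Logit of the rescaled value; only defined for lower < y < upper.
    u = (y - lower) / (upper - lower)
    return math.log(u / (1.0 - u))


for x in (-5.0, 0.0, 5.0):
    y = sigmoid_forward(x, 0.0, 1.0)
    print(x, y, sigmoid_inverse(y, 0.0, 1.0))
```

The same pattern suits a mixing weight in [0, 1]: the optimizer moves x freely while the model only ever sees a value inside the interval.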

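Compose several with ChainT; the transforms apply in the order given. In the example below the affine map runs first and the softplus second, so forward(0) = softplus(2·0 + 1) ≈ 1.3133.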

chained = nn.ChainT(nn.AffineT(scale=2.0, shift=1.0), nn.SoftplusT(lower=0.0))
chained.forward(jnp.array(0.0))
Quantity(1.3132616)
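
The composition rule can be sketched in a few lines: forward maps run first to last, and the inverse undoes them last to first. This is a stand-alone illustration with hypothetical helper names (chain_forward, chain_inverse), not ChainT itself; the affine and softplus formulas are the standard ones and are assumed to match the library.

```python
import math

# Each transform is a (forward, inverse) pair of plain functions.
affine = (lambda x: 2.0 * x + 1.0, lambda y: (y - 1.0) / 2.0)
softplus = (lambda x: math.log1p(math.exp(x)), lambda y: math.log(math.expm1(y)))


def chain_forward(transforms, x):
    # Apply the forward maps first to last.
    for forward, _ in transforms:
        x = forward(x)
    return x


def chain_inverse(transforms, y):
    # Undo them last to first.
    for _, inverse in reversed(transforms):
        y = inverse(y)
    return y


y = chain_forward([affine, softplus], 0.0)
print(y)
print(chain_inverse([affine, softplus], y))
```

The forward result agrees with the 1.3132616 shown above, which is what confirms that the affine step is applied before the softplus.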

Attach a transform when constructing a Param. Here a time constant with a floor of 0.1 is learned: the stored value roams over ℝ, but value() always stays above 0.1.

tau = nn.Param(jnp.array(2.0), t=nn.SoftplusT(lower=0.1))

params = {'tau': tau.val}
opt = braintools.optim.Adam(lr=1e-1)
opt.register_trainable_weights(params)

@brainstate.transform.jit
def step(target):
    def loss_fn():
        return (tau.value() - target) ** 2
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    opt.update(grads)
    return loss

for _ in range(100):
    step(0.05)   # push tau toward 0.05, below the floor of 0.1

print('constrained tau:', float(tau.value()), '(never drops below the 0.1 floor)')
constrained tau: 0.2543546259403229 (never drops below the 0.1 floor)
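
Why the floor holds can be seen by writing the update out by hand. The sketch below is a simplified stand-in for the run above: it uses plain gradient descent in place of Adam, a single float in place of an array, and the standard softplus formula (an assumption about SoftplusT). Its numbers will therefore differ from the tutorial's, but the mechanism is the same: the gradient is taken with respect to the raw value via the chain rule, and the constrained value is always recomputed from it.

```python
import math

LOWER = 0.1      # floor, as in SoftplusT(lower=0.1)
TARGET = 0.05    # target below the floor
LR = 0.1         # plain gradient-descent step size (the tutorial uses Adam)


def tau_of(raw):
    # Constrained value: floor plus softplus of the raw value.
    return LOWER + math.log1p(math.exp(raw))


# Start from the raw value whose constrained image is 2.0.
raw = math.log(math.expm1(2.0 - LOWER))

for _ in range(100):
    tau = tau_of(raw)
    dloss_dtau = 2.0 * (tau - TARGET)            # derivative of (tau - target)^2
    dtau_draw = 1.0 / (1.0 + math.exp(-raw))     # derivative of softplus is sigmoid
    raw -= LR * dloss_dtau * dtau_draw           # update the unconstrained value only

tau = tau_of(raw)
print(raw, tau)
```

Because the sigmoid factor shrinks as raw becomes more negative, progress toward the floor slows down, and tau approaches 0.1 from above without crossing it.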

Penalising parameters: Regularization

A regularization object contributes a scalar penalty derived from a parameter’s value. Read it directly with reg.loss(value), or — once attached to a Param — with param.reg_loss(), which applies the regularizer to that parameter’s current value.

weights = jnp.array([3.0, -4.0])
print('L1 penalty :', float(nn.L1Reg(0.1).loss(weights)))   # 0.1 * (|3| + |-4|)
print('L2 penalty :', float(nn.L2Reg(0.1).loss(weights)))   # 0.1 * (3^2 + 4^2)

p = nn.Param(weights, reg=nn.ElasticNetReg(l1_weight=1.0, l2_weight=1.0, alpha=0.5))
print('elastic-net via reg_loss():', float(p.reg_loss()))
L1 penalty : 0.699999988079071
L2 penalty : 2.5
elastic-net via reg_loss(): 16.0
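
The three printed values can be checked by hand. The L1 and L2 formulas below follow the comments in the code above. The elastic-net formula is an assumption: a convex combination of the two penalties reproduces 16.0 for these inputs, but other definitions would as well, and the library's exact form is not stated here. All function names are hypothetical.

```python
def l1_penalty(weights, weight):
    # weight * sum of absolute values
    return weight * sum(abs(w) for w in weights)


def l2_penalty(weights, weight):
    # weight * sum of squares
    return weight * sum(w * w for w in weights)


def elastic_net_penalty(weights, l1_weight, l2_weight, alpha):
    # Assumed convex-combination form; not confirmed against the library.
    return (alpha * l1_penalty(weights, l1_weight)
            + (1.0 - alpha) * l2_penalty(weights, l2_weight))


weights = [3.0, -4.0]
print(l1_penalty(weights, 0.1))
print(l2_penalty(weights, 0.1))
print(elastic_net_penalty(weights, 1.0, 1.0, 0.5))
```

The printed L1 value of 0.699999988… above is simply 0.7 rounded to float32.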

The catalog spans classical penalties and Bayesian priors (a prior contributes its negative log-density as the penalty):

Family                 Members
---------------------  ----------------------------------------------------------
Sparsity / shrinkage   L1Reg, L2Reg, ElasticNetReg, GroupLassoReg, MaxNormReg
Structure              OrthogonalReg, SpectralNormReg, TotalVariationReg, EntropyReg
Bayesian priors        GaussianReg, LogNormalReg, StudentTReg, CauchyReg, HorseshoeReg, SpikeAndSlabReg, DirichletReg, …

Combine penalties with ChainedReg. Parameters carrying a prior can be re-drawn from it with param.reset_to_prior().
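
The remark that a prior contributes its negative log-density can be made concrete for a zero-mean Gaussian: up to an additive constant, the penalty is an L2 term with weight 1 / (2σ²). The sketch below shows this for scalars. The function name is hypothetical, and whether GaussianReg keeps or drops the constant is not stated in this tutorial.

```python
import math


def gaussian_neg_log_density(w, sigma):
    # -log N(w; 0, sigma^2) = w^2 / (2 sigma^2) + log(sigma * sqrt(2 pi))
    return w * w / (2.0 * sigma * sigma) + math.log(sigma * math.sqrt(2.0 * math.pi))


sigma = 2.0
baseline = gaussian_neg_log_density(0.0, sigma)
for w in (0.0, 1.0, 3.0):
    nll = gaussian_neg_log_density(w, sigma)
    # Subtracting the value at w = 0 leaves a pure L2 term.
    print(w, nll - baseline, w * w / (2.0 * sigma * sigma))
```

The constant does not depend on w, so it has no effect on gradients; a tighter prior (smaller σ) corresponds to a stronger L2 penalty.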

A constrained, regularized model end-to-end

This linear model uses an L2-penalised weight vector and a strictly positive output scale. The training objective is the data loss plus the summed regularization penalties, collected by walking the module tree with model.nodes(nn.Param).

class RegLinear(nn.Module):
    def __init__(self, din):
        super().__init__()
        self.w = nn.Param(brainstate.random.randn(din) * 0.1, reg=nn.L2Reg(1e-2))
        self.b = nn.Param(jnp.zeros(()))
        self.scale = nn.Param(jnp.array(1.0), t=nn.SoftplusT(lower=1e-3))

    def __call__(self, x):
        return self.scale.value() * (x @ self.w.value() + self.b.value())

    def reg_penalty(self):
        return sum(p.reg_loss() for p in self.nodes(nn.Param).values())

model = RegLinear(4)
x = brainstate.random.randn(128, 4)
y = x @ jnp.array([1.0, -2.0, 0.5, 3.0]) + 0.3
params = model.states(brainstate.ParamState)
opt = braintools.optim.Adam(lr=5e-2)
opt.register_trainable_weights(params)

@brainstate.transform.jit
def train_step():
    def loss_fn():
        data_loss = jnp.mean((model(x) - y) ** 2)
        return data_loss + model.reg_penalty()
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    opt.update(grads)
    return loss

for epoch in range(300):
    loss = train_step()

print(f'final loss : {float(loss):.4f}')
print(f'scale > 0  : {float(model.scale.value()):.4f}')
print(f'reg penalty: {float(model.reg_penalty()):.6f}')
final loss : 0.0378
scale > 0  : 1.9410
reg penalty: 0.037739
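
The effect of adding an L2 penalty to a mean-squared-error loss can be worked out in closed form for a single weight. The sketch below is a simplified one-dimensional analogue of the objective above, with no bias and no scale parameter; the function name and the toy data are hypothetical.

```python
def fit_slope(xs, ys, lam):
    # Minimise mean((w*x - y)^2) + lam * w^2 over a single weight w.
    # Setting the derivative to zero gives w = mean(x*y) / (mean(x^2) + lam).
    n = len(xs)
    mean_xy = sum(x * y for x, y in zip(xs, ys)) / n
    mean_xx = sum(x * x for x in xs) / n
    return mean_xy / (mean_xx + lam)


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]          # generated by the true slope 2
w_plain = fit_slope(xs, ys, 0.0)
w_ridge = fit_slope(xs, ys, 0.5)
print(w_plain, w_ridge)
```

With lam = 0 the true slope is recovered; any positive lam pulls the weight toward zero, which is the trade-off the summed penalty introduces into the training objective.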

Summary

  • Param wraps a ParamState with an optional transform and regularizer; read the usable value with value() and the trainable array with val.

  • Const is a non-trainable Param, excluded from ParamState collection.

  • Transforms are bijectors that keep a parameter inside its constrained space while the optimizer works in unconstrained ℝ. Compose them with ChainT.

  • Regularization adds a penalty via reg.loss(value) or param.reg_loss(); sum penalties across a model by iterating model.nodes(nn.Param).

See also