Quickstart#
This page builds and trains a small model end to end. It assumes BrainState is already
installed (see Installation) and takes about five minutes. The goal is not to
explain every idea — the Core track does that — but to show the
shape of a BrainState program: wrap mutable arrays in State, build a Module, and train it with
state-aware transformations.
import jax
import jax.numpy as jnp
import brainstate
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
A model built from state#
A model is a brainstate.nn.Module whose trainable arrays are ParamState objects. Here is a
two-layer perceptron for regression. The parameters live as attributes; BrainState discovers them
automatically.
class MLP(brainstate.nn.Module):
    def __init__(self, din, dhidden, dout):
        super().__init__()
        self.hidden = brainstate.nn.Linear(din, dhidden)
        self.out = brainstate.nn.Linear(dhidden, dout)

    def __call__(self, x):
        return self.out(jnp.tanh(self.hidden(x)))
model = MLP(4, 32, 1)
model
MLP(
  hidden=Linear(
    in_size=(4,),
    out_size=(32,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[32]),
        'weight': ShapedArray(float32[4,32])
      }
    )
  ),
  out=Linear(
    in_size=(32,),
    out_size=(1,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[1]),
        'weight': ShapedArray(float32[32,1])
      }
    )
  )
)
model.states(brainstate.ParamState) returns every trainable parameter, keyed by its path in the
module tree. This collection is what we will differentiate and update.
params = model.states(brainstate.ParamState)
list(params.keys())
[('hidden', 'weight'), ('out', 'weight')]
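As a rough illustration of what such a path-keyed collection looks like, the sketch below builds a stand-in dictionary with the same keys and array shapes as the printed model and counts its scalar parameters. It uses plain JAX arrays rather than real ParamState objects, and the name toy_params is made up for this example.

```python
import jax
import jax.numpy as jnp

# Stand-in for model.states(brainstate.ParamState): same path keys and the
# same shapes as the printed model, but plain arrays instead of ParamState objects.
toy_params = {
    ('hidden', 'weight'): {'weight': jnp.zeros((4, 32)), 'bias': jnp.zeros((32,))},
    ('out', 'weight'): {'weight': jnp.zeros((32, 1)), 'bias': jnp.zeros((1,))},
}

# Each value is a small PyTree, so flatten it to reach the individual arrays.
n_scalars = sum(
    leaf.size
    for value in toy_params.values()
    for leaf in jax.tree_util.tree_leaves(value)
)
print(n_scalars)  # 4*32 + 32 + 32*1 + 1 = 193
```

With the real collection, each entry is a state object, so the arrays would be reached through its value attribute first.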
A forward pass#
Calling the model runs the forward computation on a batch of inputs.
x = brainstate.random.randn(128, 4)
y = jnp.sum(x ** 2, axis=-1, keepdims=True) # a simple nonlinear target
model(x).shape
(128, 1)
Gradients with respect to parameters#
brainstate.transform.grad differentiates a function with respect to a collection of states —
the parameters — rather than its positional arguments. It returns a dictionary of gradients keyed
the same way as params. Pass return_value=True to get the loss alongside the gradients.
def loss_fn():
    return jnp.mean((model(x) - y) ** 2)
grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
print('initial loss:', float(loss))
print('gradient keys match params:', set(grads) == set(params))
initial loss: 23.490819931030273
gradient keys match params: True
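For contrast, plain JAX differentiates with respect to positional arguments, so the parameters have to be passed into the loss explicitly. The sketch below shows that style on a toy linear model. The names plain_loss, plain_params, xs, and ys are invented for this example and are not part of BrainState.

```python
import jax
import jax.numpy as jnp

def plain_loss(p, xs, ys):
    # The parameters arrive as the first positional argument.
    pred = xs @ p['weight'] + p['bias']
    return jnp.mean((pred - ys) ** 2)

plain_params = {'weight': jnp.zeros((4, 1)), 'bias': jnp.zeros((1,))}
xs = jnp.ones((8, 4))
ys = jnp.ones((8, 1))

# value_and_grad differentiates with respect to argument 0 by default and
# returns (value, grads); the gradients mirror the structure of plain_params.
plain_value, plain_grads = jax.value_and_grad(plain_loss)(plain_params, xs, ys)
print(float(plain_value))   # 1.0
print(sorted(plain_grads))  # ['bias', 'weight']
```

Note the ordering: jax.value_and_grad returns the value first, whereas the BrainState call above unpacks the gradients first.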
A training loop#
A training step computes gradients and applies a gradient-descent update in place. Each parameter
value can be a small PyTree (a Linear holds its weight and bias together), so the update walks it
with jax.tree.map. Wrapping the step in brainstate.transform.jit compiles it; because the
transform is state-aware, the parameter updates persist across calls automatically.
@brainstate.transform.jit
def train_step():
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    for key in params:
        params[key].value = jax.tree.map(lambda p, g: p - 0.05 * g, params[key].value, grads[key])
    return loss

for step in range(201):
    loss = train_step()
    if step % 50 == 0:
        print(f'step {step:>3} loss {float(loss):.4f}')
step 0 loss 23.4908
step 50 loss 4.9425
step 100 loss 2.3319
step 150 loss 1.1466
step 200 loss 0.6898
The loss falls steadily. For real training you would reach for an optimizer such as
braintools.optim.Adam rather than hand-written gradient descent, but the structure is
identical: differentiate with respect to the parameter states, then update them.
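To see what the state-aware jit saves you from, here is a sketch of the same kind of step written in plain JAX, where a jitted function must be pure: the parameters go in as an argument and the updated parameters come back as a return value. The toy linear model and the names toy_loss and functional_step are assumptions made for this example only. It uses jax.tree_util.tree_map, the longer-standing spelling of jax.tree.map.

```python
import jax
import jax.numpy as jnp

def toy_loss(p, xs, ys):
    return jnp.mean((xs @ p['weight'] + p['bias'] - ys) ** 2)

@jax.jit
def functional_step(p, xs, ys):
    loss, grads = jax.value_and_grad(toy_loss)(p, xs, ys)
    # Pair every parameter leaf with its gradient leaf and take one SGD step.
    new_p = jax.tree_util.tree_map(lambda v, g: v - 0.05 * g, p, grads)
    return new_p, loss  # the caller must keep the returned parameters

p = {'weight': jnp.zeros((4, 1)), 'bias': jnp.zeros((1,))}
xs = jnp.ones((8, 4))
ys = jnp.ones((8, 1))

losses = []
for _ in range(20):
    p, loss = functional_step(p, xs, ys)  # thread the parameters through by hand
    losses.append(float(loss))

print(losses[0], losses[-1])  # the loss shrinks over the 20 steps
```

In the BrainState version, train_step takes no arguments and returns only the loss, because the parameter states are read and written inside the transformed function.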
Where to go next#
Thinking in BrainState - the mental model behind what you just wrote.
Core tutorials - State, modules, transformations, and training in depth.
Why state-based? - the design rationale.