Quickstart

This page builds and trains a small model end to end. It assumes BrainState is already installed (see Installation) and takes about five minutes. The goal is not to explain every idea — the Core track does that — but to show the shape of a BrainState program: wrap mutable arrays in State, build a Module, and train it with state-aware transformations.

import jax
import jax.numpy as jnp

import brainstate

brainstate.random.seed(0)
brainstate.__version__
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
'0.4.0'

A model built from state

A model is a brainstate.nn.Module whose trainable arrays are ParamState objects. Here is a two-layer perceptron for regression. The parameters live as attributes; BrainState discovers them automatically.

class MLP(brainstate.nn.Module):
    def __init__(self, din, dhidden, dout):
        super().__init__()
        self.hidden = brainstate.nn.Linear(din, dhidden)
        self.out = brainstate.nn.Linear(dhidden, dout)

    def __call__(self, x):
        return self.out(jnp.tanh(self.hidden(x)))

model = MLP(4, 32, 1)
model
MLP(
  hidden=Linear(
    in_size=(4,),
    out_size=(32,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[32]),
        'weight': ShapedArray(float32[4,32])
      }
    )
  ),
  out=Linear(
    in_size=(32,),
    out_size=(1,),
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[1]),
        'weight': ShapedArray(float32[32,1])
      }
    )
  )
)

model.states(brainstate.ParamState) returns every trainable parameter, keyed by its path in the module tree. This collection is what we will differentiate and update.

params = model.states(brainstate.ParamState)
list(params.keys())
[('hidden', 'weight'), ('out', 'weight')]
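
As a rough mental model of where those path keys come from, the sketch below walks nested attributes and records the path to each state it finds. ToyState, ToyModule, ToyLinear and ToyMLP are hypothetical stand-ins written for this illustration; this is not BrainState's implementation, only one simple way such discovery could work.

```python
# Hypothetical stand-ins, NOT BrainState classes: they only mimic the idea
# of states stored as attributes of nested modules.
class ToyState:
    def __init__(self, value):
        self.value = value


class ToyModule:
    def states(self, prefix=()):
        """Return {path: state} for every ToyState reachable from self."""
        found = {}
        for name, attr in vars(self).items():
            path = prefix + (name,)
            if isinstance(attr, ToyState):
                found[path] = attr
            elif isinstance(attr, ToyModule):
                # Recurse into child modules, extending the path.
                found.update(attr.states(path))
        return found


class ToyLinear(ToyModule):
    def __init__(self, din, dout):
        # One state holds both arrays, mirroring the printout above.
        self.weight = ToyState({
            'weight': [[0.0] * dout for _ in range(din)],
            'bias': [0.0] * dout,
        })


class ToyMLP(ToyModule):
    def __init__(self, din, dhidden, dout):
        self.hidden = ToyLinear(din, dhidden)
        self.out = ToyLinear(dhidden, dout)


toy_params = ToyMLP(4, 32, 1).states()
print(list(toy_params.keys()))  # [('hidden', 'weight'), ('out', 'weight')]
```

Each key is the chain of attribute names leading to the state, which is why a Linear that stores its weight and bias in a single state contributes one entry rather than two.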

A forward pass

Calling the model runs the forward computation on a batch of inputs.

x = brainstate.random.randn(128, 4)
y = jnp.sum(x ** 2, axis=-1, keepdims=True)   # a simple nonlinear target

model(x).shape
(128, 1)

Gradients with respect to parameters

brainstate.transform.grad differentiates a function with respect to a collection of states (here, the parameters) rather than its positional arguments. Calling the transformed function returns a dictionary of gradients keyed the same way as params. With return_value=True it also returns the loss, as the pair (grads, loss).

def loss_fn():
    return jnp.mean((model(x) - y) ** 2)

grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
print('initial loss:', float(loss))
print('gradient keys match params:', set(grads) == set(params))
initial loss: 23.490819931030273
gradient keys match params: True
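
For comparison, here is a sketch of the same idea in plain JAX, where the parameters must be passed in as an explicit argument instead of being picked up from states. The dictionary layout, the 0.1 initialization scale and the names plain_params, forward and plain_loss are assumptions made for this illustration; they do not reproduce BrainState's initializers or internals.

```python
import jax
import jax.numpy as jnp

key_x, key_h, key_o = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (128, 4))
y = jnp.sum(x ** 2, axis=-1, keepdims=True)

# A plain dict of arrays, keyed like the params collection above
# (illustrative layout and init scale, not BrainState's).
plain_params = {
    ('hidden', 'weight'): {
        'weight': 0.1 * jax.random.normal(key_h, (4, 32)),
        'bias': jnp.zeros(32),
    },
    ('out', 'weight'): {
        'weight': 0.1 * jax.random.normal(key_o, (32, 1)),
        'bias': jnp.zeros(1),
    },
}


def forward(p, x):
    hidden, out = p[('hidden', 'weight')], p[('out', 'weight')]
    h = jnp.tanh(x @ hidden['weight'] + hidden['bias'])
    return h @ out['weight'] + out['bias']


def plain_loss(p):
    return jnp.mean((forward(p, x) - y) ** 2)


# jax.value_and_grad differentiates with respect to the first argument
# and returns (value, grads); note the order is the reverse of the
# (grads, loss) pair in the BrainState call above.
loss, grads = jax.value_and_grad(plain_loss)(plain_params)
print('gradient keys match params:', set(grads) == set(plain_params))
```

The gradient comes back as a PyTree with the same structure as the parameters in both styles. The difference is only in how the parameters reach the function: as an argument here, as states in the BrainState version.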

A training loop

A training step computes gradients and applies a gradient-descent update in place. Each parameter value can be a small PyTree (a Linear holds its weight and bias together), so the update walks it with jax.tree.map. Wrapping the step in brainstate.transform.jit compiles it; because the transform is state-aware, the parameter updates persist across calls automatically.
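
The per-parameter update can be tried in isolation. This small sketch uses made-up values for one Linear-style PyTree and its gradient, and calls jax.tree_util.tree_map, the long-standing spelling of jax.tree.map that is also available in older JAX releases.

```python
import jax
import jax.numpy as jnp

# Made-up values standing in for one parameter PyTree and its gradient.
value = {'weight': jnp.ones((2, 2)), 'bias': jnp.zeros(2)}
grad = {'weight': jnp.full((2, 2), 2.0), 'bias': jnp.ones(2)}

# Apply p - lr * g leaf by leaf; both trees must share the same structure.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, value, grad)

print(updated['weight'])  # every entry is 1 - 0.05 * 2 = 0.9
print(updated['bias'])    # every entry is 0 - 0.05 * 1 = -0.05
```

The training step applies this same leaf-wise update to every entry of params.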

@brainstate.transform.jit
def train_step():
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    for key in params:
        params[key].value = jax.tree.map(lambda p, g: p - 0.05 * g, params[key].value, grads[key])
    return loss

for step in range(201):
    loss = train_step()
    if step % 50 == 0:
        print(f'step {step:>3}  loss {float(loss):.4f}')
step   0  loss 23.4908
step  50  loss 4.9425
step 100  loss 2.3319
step 150  loss 1.1466
step 200  loss 0.6898

The loss falls steadily. For real training you would reach for an optimizer such as braintools.optim.Adam rather than hand-written gradient descent, but the structure is identical: differentiate with respect to the parameter states, then update them.
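
To see what state-awareness saves, it can help to compare with a plain jax.jit version of the same loop. The sketch below is an assumption-laden analogue, not BrainState code: the init helper, the initialization scales and the names plain_loss and plain_train_step are invented for this illustration. Because jax.jit expects pure functions, the parameters go in as an argument and the updated parameters come back as a return value that the caller must rebind.

```python
import jax
import jax.numpy as jnp


def init(key):
    # Illustrative initialization; BrainState's defaults may differ.
    k1, k2 = jax.random.split(key)
    return {
        'hidden': {'weight': 0.5 * jax.random.normal(k1, (4, 32)),
                   'bias': jnp.zeros(32)},
        'out': {'weight': 0.1 * jax.random.normal(k2, (32, 1)),
                'bias': jnp.zeros(1)},
    }


def plain_loss(p, x, y):
    h = jnp.tanh(x @ p['hidden']['weight'] + p['hidden']['bias'])
    pred = h @ p['out']['weight'] + p['out']['bias']
    return jnp.mean((pred - y) ** 2)


@jax.jit
def plain_train_step(p, x, y):
    loss, grads = jax.value_and_grad(plain_loss)(p, x, y)
    # Build a new parameter tree; nothing is mutated in place.
    new_p = jax.tree_util.tree_map(lambda w, g: w - 0.05 * g, p, grads)
    return new_p, loss


key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (128, 4))
y = jnp.sum(x ** 2, axis=-1, keepdims=True)

p = init(key_p)
losses = []
for step in range(201):
    # The caller threads the parameters through every call by hand.
    p, loss = plain_train_step(p, x, y)
    losses.append(float(loss))

print(f'first loss {losses[0]:.4f}  last loss {losses[-1]:.4f}')
```

In the BrainState version above, train_step takes no arguments and returns only the loss; the bookkeeping of passing parameters in and out is what the state-aware jit handles on your behalf.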

Where to go next