Transformations, the Essentials
brainstate.transform mirrors the JAX transformation API — jit, grad, vmap — but every
transform is state-aware: it tracks the State objects your model reads and writes, so you
never thread parameters and buffers through function arguments by hand.
This tutorial is the gateway. It shows the three transforms you reach for daily and how they compose. The dedicated transformations track then covers each in depth — compilation internals, advanced autodiff, batched ensembles, control flow, error checking, and debugging.
import jax
import jax.numpy as jnp
import brainstate
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
Why state-aware transformations
JAX transformations operate on pure functions — outputs depend only on inputs, with no side
effects. A BrainState model is the opposite: it keeps mutable State, and calling it reads and
writes that state. Hand such a model to raw jax.jit and the State write is silently
discarded — the counter is recomputed from its initial value on every call and never advances:
class Counter(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = brainstate.State(jnp.array(0))

    def __call__(self):
        self.n.value += 1
        return self.n.value
broken = jax.jit(Counter())
print('raw jax.jit returns:', [int(broken()) for _ in range(4)]) # 1, 1, 1, 1 - update lost
raw jax.jit returns: [1, 1, 1, 1]
brainstate.transform.jit understands State. It captures every read and write, so mutations
persist correctly across calls:
counter = Counter()
fast_counter = brainstate.transform.jit(counter)
print([int(fast_counter()) for _ in range(4)]) # 1, 2, 3, 4 — state survives
[1, 2, 3, 4]
This is the rule that makes the rest of BrainState work: wrap a model in a brainstate
transform and its state is handled for you.
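For contrast, here is a minimal sketch of the manual state threading that plain JAX requires for the same counter. It uses only jax.jit; the name counter_step is hypothetical and is not part of BrainState. The state goes in as an argument and comes back as a return value, which is the bookkeeping a state-aware transform is meant to take off your hands.

```python
import jax
import jax.numpy as jnp


@jax.jit
def counter_step(n):
    # Pure function: the "state" is an explicit input and an explicit output.
    new_n = n + 1
    return new_n, new_n  # (updated state, value returned to the caller)


n = jnp.array(0)
outputs = []
for _ in range(4):
    # The caller is responsible for carrying the state between calls.
    n, out = counter_step(n)
    outputs.append(int(out))

print(outputs)  # [1, 2, 3, 4]
```

Forgetting to rebind n in the loop reproduces the stuck counter shown earlier, which is why threading state by hand is error-prone.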
jit: compile once, run fast
jit traces a function the first time it is called, compiles it with XLA, and reuses the
compiled version on later calls with the same input shapes and dtypes. Use it on whole steps,
such as a forward pass or a training step, not on tiny operations. We will reuse this small linear model throughout.
class Linear(brainstate.nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.w = brainstate.ParamState(brainstate.random.randn(din, dout) * 0.1)
        self.b = brainstate.ParamState(jnp.zeros(dout))

    def __call__(self, x):
        return x @ self.w.value + self.b.value
model = Linear(3, 1)
x = brainstate.random.randn(64, 3)
y = brainstate.random.randn(64, 1)
forward = brainstate.transform.jit(model)
forward(x).shape
(64, 1)
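The trace-once behavior can be observed with a Python side effect, which runs only while the function is being traced. The sketch below uses plain jax.jit with explicit weights rather than the Linear module; that brainstate.transform.jit caches in the same way is an assumption here, and forward_pure and trace_log are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function


@jax.jit
def forward_pure(x, w, b):
    # This append is a Python side effect: it runs at trace time,
    # not when the compiled version is reused.
    trace_log.append(x.shape)
    return x @ w + b


w = jnp.ones((3, 1))
b = jnp.zeros(1)

forward_pure(jnp.ones((64, 3)), w, b)        # first call: trace and compile
forward_pure(jnp.ones((64, 3)), w, b)        # same shapes: compiled version reused
out = forward_pure(jnp.ones((8, 3)), w, b)   # new input shape: traced again

print(trace_log)  # [(64, 3), (8, 3)]
```

Two entries for three calls shows that the second call skipped tracing, while the change in batch size triggered a fresh trace.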
grad: differentiate with respect to states
grad differentiates a function with respect to a collection of States rather than with respect to
its positional arguments, which is how plain JAX works. You pass the states to differentiate, and it returns a dictionary of
gradients keyed by each state’s path in the module tree.
params = model.states(brainstate.ParamState)
def loss_fn():
    return jnp.mean((model(x) - y) ** 2)
grads = brainstate.transform.grad(loss_fn, params)()
{key: g.shape for key, g in grads.items()}
{('b',): (1,), ('w',): (3, 1)}
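For comparison, plain JAX produces a similar structure when the parameters are passed in as a dictionary: jax.grad returns the gradients in a dictionary with matching keys. This is a sketch in plain JAX, not BrainState's implementation; the synthetic data, the params dict, and loss_fn_pure are stand-ins invented for illustration.

```python
import jax
import jax.numpy as jnp

# Synthetic data and parameters (assumed shapes, mirroring the tutorial's model).
k_x, k_y, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (64, 3))
y = jax.random.normal(k_y, (64, 1))
params = {"w": 0.1 * jax.random.normal(k_w, (3, 1)), "b": jnp.zeros(1)}


def loss_fn_pure(params):
    # In plain JAX the parameters are a positional argument.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Differentiates with respect to the first positional argument (the dict).
grads = jax.grad(loss_fn_pure)(params)
grad_shapes = {name: g.shape for name, g in grads.items()}
print(grad_shapes)
```

The difference from the call above is where the parameters live: here they are an argument of the loss function, whereas the state-aware grad is handed the states separately and the loss function takes no arguments.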
The gradient keys match the parameter keys exactly, so applying an update is a simple loop. Pass
return_value=True to also get the loss in the same pass, or has_aux=True to return extra
diagnostics from the loss function.
grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
print('loss:', float(loss))
for key in params:
    params[key].value -= 0.1 * grads[key]
print('loss after one step:', float(loss_fn()))
loss: 1.0843076705932617
loss after one step: 1.023072600364685
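In plain JAX the analogous options live on jax.value_and_grad, which may help when reading code written against either library. Note the ordering in this sketch: JAX returns the value first, as ((loss, aux), grads), whereas the call above returns (grads, loss). How BrainState orders its outputs when has_aux=True is not shown in this tutorial, so only the JAX side is illustrated; all names below are hypothetical.

```python
import jax
import jax.numpy as jnp

# Synthetic data and parameters (assumptions for illustration only).
k_x, k_y, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (64, 3))
y = jax.random.normal(k_y, (64, 1))
params = {"w": 0.1 * jax.random.normal(k_w, (3, 1)), "b": jnp.zeros(1)}


def loss_with_aux(params):
    err = x @ params["w"] + params["b"] - y
    # Extra diagnostics returned next to the loss; they are not differentiated.
    aux = {"max_abs_err": jnp.max(jnp.abs(err))}
    return jnp.mean(err ** 2), aux


# has_aux=True tells JAX the function returns (loss, aux).
(loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params)

# One gradient-descent step applied leaf by leaf.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
new_loss, _ = loss_with_aux(new_params)
print(float(loss), float(new_loss), float(aux["max_abs_err"]))
```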
vmap: vectorize over a batch
vmap adds a batch dimension to a function written for a single example, turning Python-level
looping into a single vectorized call. Here predict_one is written for one input row;
vmap runs it across the whole batch at once.
def predict_one(x_row):
    return jnp.tanh(model(x_row[None, :]))[0]
predict_batch = brainstate.transform.vmap(predict_one)
predict_batch(x).shape
(64, 1)
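One way to see what vmap does is to compare it with the explicit Python loop it replaces. This sketch uses plain jax.vmap and explicit weights (w and b are stand-ins for the model parameters, and predict_row is a hypothetical name); it checks that both routes give the same batch of predictions.

```python
import jax
import jax.numpy as jnp

k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (64, 3))
w = 0.1 * jax.random.normal(k_w, (3, 1))
b = jnp.zeros(1)


def predict_row(x_row):
    # Written for a single example of shape (3,); returns shape (1,).
    return jnp.tanh(x_row @ w + b)


# Vectorized: maps over the leading axis of x in one call.
batched = jax.vmap(predict_row)(x)

# Reference: the Python-level loop that vmap stands in for.
looped = jnp.stack([predict_row(row) for row in x])

print(batched.shape, bool(jnp.allclose(batched, looped, atol=1e-5)))
```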
Because it is state-aware, vmap can also map over the states themselves, for example to run
an ensemble of models in parallel, via its in_states / out_states arguments. That, along
with vmap2, pmap2, and shard_map, is covered in vectorization and advanced batching.
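The ensemble idea can be sketched in plain JAX by stacking each member's parameters along a leading axis and mapping over that axis while sharing the input. This does not use in_states / out_states, whose exact usage is left to the linked pages; n_models, forward_one, ensemble_forward, and the stacked arrays are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

n_models = 5
keys = jax.random.split(jax.random.PRNGKey(0), n_models + 1)

x = jax.random.normal(keys[0], (64, 3))  # one shared input batch
# One weight matrix per ensemble member, stacked on a new leading axis.
ws = jnp.stack([0.1 * jax.random.normal(k, (3, 1)) for k in keys[1:]])  # (5, 3, 1)
bs = jnp.zeros((n_models, 1))                                           # (5, 1)


def forward_one(w, b, x):
    # Forward pass of a single linear model.
    return x @ w + b


# Map over axis 0 of the weights and biases; broadcast x to every member.
ensemble_forward = jax.vmap(forward_one, in_axes=(0, 0, None))
preds = ensemble_forward(ws, bs, x)
print(preds.shape)  # (5, 64, 1)
```

Setting in_axes to None for x is what keeps the input shared instead of split across members.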
Composing transformations
Transforms compose. The common pattern is jit(grad(...)): differentiate, then compile the
whole gradient computation so XLA can optimize it as a single program.
@brainstate.transform.jit
def train_step():
    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()
    for key in params:
        params[key].value -= 0.1 * grads[key]
    return loss
losses = [float(train_step()) for _ in range(5)]
print('loss trajectory:', [round(v, 4) for v in losses])
loss trajectory: [1.0231, 0.9841, 0.9592, 0.9431, 0.9327]
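The same composition in plain JAX looks like the sketch below, with the parameters threaded through explicitly. It is shown only to make the jit(grad(...)) structure concrete; train_step_pure, loss_fn_pure, and the synthetic data are assumptions, and the numbers will not match the trajectory above.

```python
import jax
import jax.numpy as jnp

# Synthetic data and parameters (assumptions for illustration only).
k_x, k_y, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (64, 3))
y = jax.random.normal(k_y, (64, 1))
params = {"w": 0.1 * jax.random.normal(k_w, (3, 1)), "b": jnp.zeros(1)}


def loss_fn_pure(params):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # outer transform: compile the whole step
def train_step_pure(params):
    # Inner transform: loss and gradients in one pass.
    loss, grads = jax.value_and_grad(loss_fn_pure)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


losses = []
for _ in range(5):
    # Without state-aware transforms the caller rebinds params each step.
    params, loss = train_step_pure(params)
    losses.append(float(loss))

print([round(v, 4) for v in losses])
```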
Summary
BrainState transforms are state-aware drop-ins for their JAX counterparts: they track State reads and writes so you never thread state through arguments.
jit compiles a function once and reuses it; apply it to whole steps.
grad differentiates with respect to a collection of states and returns a gradient dict keyed by state path; return_value and has_aux carry extra outputs.
vmap vectorizes a single-example function over a batch, and can map over states too.
Transforms compose: jit(grad(...)) is the backbone of every training loop.
See also
Training and metrics: these transforms assembled into a full training loop.
The transformations track: jit, autodiff, vectorization, control flow, error handling, and debugging in depth.
Transformation semantics: how state threading works under the hood.