Batch and Accelerate#
Goal: make brainmass runs fast — compile with jit, run an ensemble or a
parameter grid in one shot with vmap, and choose the right loop primitive.
brainmass is built on JAX, so the same code runs on CPU, GPU, or TPU. The
speed-ups all come from three brainstate.transform primitives. This recipe
shows when to reach for each, with a before/after timing.
Note
Never drive a model with a bare Python for/while loop when it runs
repeatedly. A Python loop executes op-by-op (dispatch overhead, no fusion) and
re-runs the Python body on every step; the brainstate.transform primitives lower the
whole loop into one compiled XLA program, tracing the body only once. This guide
uses the raw primitives directly to teach them; in normal use,
brainmass.Simulator wraps for_loop for you.
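The "traced only once" claim can be checked in plain JAX. The sketch below is an illustration, not brainstate code: it uses jax.lax.scan as a stand-in for the brainstate.transform loop primitives, and a Python-side counter (a hypothetical helper) that only ticks when Python actually executes the body.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: counts how often Python executes the loop body.
counter = {"body_calls": 0}

def body(carry, _):
    counter["body_calls"] += 1      # Python side effect: happens only while tracing
    new = 0.9 * carry + 0.1         # toy leaky update standing in for a model step
    return new, new                 # (new carry, per-step output)

# 100 steps run inside one compiled loop; the Python body is not called 100 times.
final, ys = jax.lax.scan(body, jnp.float32(0.0), None, length=100)

print("steps executed   :", ys.shape[0])
print("Python body calls:", counter["body_calls"])
```

The body runs for 100 steps, yet Python enters it only a small, fixed number of times (typically once), because the loop itself lives inside the compiled program.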
Which primitive?#
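| Shape of the work | Primitive |
|---|---|
| Single step / one-shot call | brainstate.transform.jit |
| Many steps, collect outputs | brainstate.transform.for_loop |
| Many steps with an explicit carry | brainstate.transform.scan |
| Batch over inputs / parameters | brainstate.transform.vmap |
| Long rollout under autograd (BPTT) | brainstate.transform.checkpointed_for_loop / checkpointed_scan |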
brainmass.Simulator already composes jit + for_loop internally, so most of
the time you just call Simulator(model, dt).run(...). Reach for the raw
primitives when you need a custom step or an explicit carry.
jit: compile a step once#
A bare Python loop calls the model op-by-op and pays Python dispatch overhead
on every iteration. Wrapping the step in brainstate.transform.jit compiles it
once; subsequent calls reuse the compiled program. Here we time a hand-written
rollout with and without jit.
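import time

import brainmass
import brainstate
import braintools
import brainunit as u  # assumed source of the `u` alias used throughout
import jax
import jax.numpy as jnp

node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)
node.init_all_states()

def step_uncompiled(i):
    # expose the step index and simulation time to the model
    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):
        node.update()
    return node.x.value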
step_jit = brainstate.transform.jit(step_uncompiled)
n = 300
# warm up the compiled version (first call traces + compiles)
_ = step_jit(0); jax.block_until_ready(node.x.value)
# uncompiled: every call dispatches op-by-op
node.init_all_states()
t0 = time.perf_counter()
for i in range(n):
    step_uncompiled(i)
jax.block_until_ready(node.x.value)  # wait for async work before stopping the clock
t_uncompiled = time.perf_counter() - t0
# compiled: each call reuses the compiled program
node.init_all_states()
t0 = time.perf_counter()
for i in range(n):
    step_jit(i)
jax.block_until_ready(node.x.value)  # wait for async work before stopping the clock
t_jit = time.perf_counter() - t0
print(f"uncompiled : {t_uncompiled * 1e3:7.1f} ms")
print(f"jit : {t_jit * 1e3:7.1f} ms")
print(f"speed-up : {t_uncompiled / t_jit:6.1f}x")
uncompiled : 4357.3 ms
jit : 91.4 ms
speed-up : 47.7x
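"Compiles once" has one caveat worth seeing in isolation: the compiled program is cached per input shape and dtype. The sketch below uses plain jax.jit (not brainstate.transform.jit) and a hypothetical trace counter to show a cache hit and a cache miss.

```python
import jax
import jax.numpy as jnp

counter = {"traces": 0}   # hypothetical helper: ticks only when JAX traces `step`

@jax.jit
def step(x):
    counter["traces"] += 1
    return x + 0.1 * (x - x ** 3)   # toy update standing in for a model step

step(jnp.ones(4))
step(jnp.zeros(4))                  # same shape and dtype: cached program is reused
traces_same_shape = counter["traces"]

step(jnp.ones(8))                   # new shape: traced and compiled again
traces_new_shape = counter["traces"]

print("traces after two same-shape calls:", traces_same_shape)
print("traces after a new shape         :", traces_new_shape)
```

Keeping shapes fixed across calls is therefore part of getting the speed-up; a step whose input shape changes on every call pays the compile cost each time.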
for_loop: fuse the whole rollout#
Even a jit-ed step is still dispatched once per iteration from Python.
brainstate.transform.for_loop lowers the entire loop into one XLA program —
it traces the body once and stacks the per-step outputs for you. State (the
model’s hidden variables) is carried automatically. This is exactly what
Simulator.run does under the hood.
node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)
node.init_all_states()
def step(i):
    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):
        node.update()
    return node.x.value
run = brainstate.transform.jit(lambda: brainstate.transform.for_loop(step, jnp.arange(n)))
xs = run(); jax.block_until_ready(xs) # warm up
t0 = time.perf_counter()
xs = run(); jax.block_until_ready(xs)
t_forloop = time.perf_counter() - t0
print("for_loop output:", xs.shape, "(time, regions)")
print(f"for_loop : {t_forloop * 1e3:7.1f} ms ({t_uncompiled / t_forloop:.0f}x vs uncompiled)")
for_loop output: (300, 64) (time, regions)
for_loop : 79.6 ms (55x vs uncompiled)
In practice you never write that loop: brainmass.Simulator is the same
jit + for_loop, validated and with monitors / transient / units handled.
The timing below should be of the same order as the hand-written for_loop above.
sim = brainmass.Simulator(node, dt=0.1 * u.ms)
res = sim.run(n * 0.1 * u.ms, monitors=['x']) # warm up + compile
t0 = time.perf_counter()
res = sim.run(n * 0.1 * u.ms, monitors=['x'])
jax.block_until_ready(res['x'])
t_sim = time.perf_counter() - t0  # stop the clock before printing
print("Simulator output:", res['x'].shape)
print(f"Simulator : {t_sim * 1e3:7.1f} ms")
Simulator output: (300, 64)
Simulator : 103.6 ms
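For readers who want to see the "whole rollout in one program" idea without brainmass, here is a rough plain-JAX analogue. It is a sketch under assumptions: the hopf_step function is a hand-written Euler step of the Hopf normal form, not brainmass.HopfStep's actual implementation, and the state is threaded explicitly because plain JAX has no brainstate State objects.

```python
import jax
import jax.numpy as jnp

def hopf_step(state, _, a=0.25, w=0.3, dt=0.1):
    # Illustrative Euler step of the Hopf normal form (assumed dynamics).
    x, y = state
    r2 = x ** 2 + y ** 2
    x_new = x + dt * ((a - r2) * x - w * y)
    y_new = y + dt * ((a - r2) * y + w * x)
    return (x_new, y_new), x_new          # carry the state, emit x each step

@jax.jit
def rollout(x0, y0):
    # scan traces hopf_step once and stacks the per-step outputs along axis 0
    _, xs = jax.lax.scan(hopf_step, (x0, y0), None, length=300)
    return xs

xs = rollout(jnp.full(64, 0.1), jnp.zeros(64))
print("stacked output:", xs.shape, "(time, regions)")
```

The explicit (x, y) carry here is the bookkeeping that for_loop and Simulator do for you when the model's hidden variables are State objects.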
scan: thread an explicit carry#
When you need to carry a value alongside the model’s State
(f(carry, x) -> (carry, y)), use brainstate.transform.scan. A typical use is
feeding a time-varying external drive into the model and accumulating a running
statistic. Here the carry is a running sum of the output.
node = brainmass.HopfStep(in_size=8, a=0.25, w=0.3)
node.init_all_states()
drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None] # (time, 1)
def body(carry, inp):
    running_sum = carry
    with brainstate.environ.context(t=0. * u.ms):
        node.update(inp)  # external drive into x
    x = node.x.value
    return running_sum + x, x  # (new carry, per-step output)
total, xs = brainstate.transform.scan(body, jnp.zeros(8), drive)
print("scan stacked outputs:", xs.shape)
print("carried running sum :", total.shape)
scan stacked outputs: (300, 8)
carried running sum : (8,)
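The f(carry, x) -> (carry, y) contract can be verified on a toy problem. The sketch below uses jax.lax.scan and a leaky integrator (both stand-ins, not the brainmass model) and checks that the carried running sum equals the sum of the stacked outputs.

```python
import jax
import jax.numpy as jnp

n = 300
drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None]   # (time, 1)

def body(carry, inp):
    x, running_sum = carry
    x = x + 0.1 * (-x + inp)              # toy leaky integrator; inp broadcasts to (8,)
    return (x, running_sum + x), x        # (new carry, per-step output)

init = (jnp.zeros(8), jnp.zeros(8))       # (model state, running sum)
(x_final, total), xs = jax.lax.scan(body, init, drive)

print("stacked outputs:", xs.shape)
print("running sum    :", total.shape)
print("carry matches sum of outputs:", bool(jnp.allclose(total, xs.sum(axis=0), atol=1e-4)))
```

In plain JAX the model state has to ride in the carry next to the statistic; with brainstate.transform.scan only the extra value needs to be carried, since State is handled automatically.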
vmap: batch an ensemble in one run#
brainstate.transform.vmap adds a batch axis to a whole computation. There are
two common patterns.
Batched initial conditions — pass batch_size=B to
Simulator.run; it calls init_all_states(batch_size=B) and the outputs gain a
leading batch axis. This runs B independent trajectories (e.g. a noise
ensemble) in a single compiled program.
node = brainmass.HopfStep(
    in_size=4, a=0.1, w=0.3,
    noise_x=brainmass.OUProcess(4, sigma=0.1, tau=10 * u.ms),
)
brainstate.random.seed(0)
res = brainmass.Simulator(node, dt=0.1 * u.ms).run(
    200 * u.ms, monitors=['x'], batch_size=16,
)
print("batched trajectory:", res['x'].shape, "(time, batch, regions)")
batched trajectory: (2000, 16, 4) (time, batch, regions)
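The same batching idea can be sketched in plain JAX with jax.vmap over initial conditions. This is an illustration with a toy deterministic model, not what Simulator does internally; out_axes=1 is chosen here only to reproduce the (time, batch, regions) layout shown above.

```python
import jax
import jax.numpy as jnp

def rollout(x0, n_steps=200, dt=0.1):
    # Toy bistable dynamics standing in for a model; returns (time, regions).
    def step(x, _):
        x = x + dt * (x - x ** 3)
        return x, x
    _, xs = jax.lax.scan(step, x0, None, length=n_steps)
    return xs

key = jax.random.PRNGKey(0)
x0_batch = 0.1 * jax.random.normal(key, (16, 4))      # (batch, regions)

# Map over the leading axis of x0_batch; put the batch axis second in the output.
batched = jax.jit(jax.vmap(rollout, out_axes=1))
xs = batched(x0_batch)
print("batched trajectory:", xs.shape, "(time, batch, regions)")
```

Each batch member evolves independently, so slicing xs[:, k] gives the same trajectory as running rollout on x0_batch[k] alone.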
Batched parameters — vmap a function that builds the model inside
itself over a parameter array. This runs one simulation per parameter value, all
fused. (See Run Parameter Sweeps for the full grid-sweep recipe.)
a_values = jnp.linspace(0.1, 1.5, 8)
def amplitude_for(a):
    # build the model inside the function so vmap can batch over `a`
    node = brainmass.HopfStep(in_size=1, a=a, w=0.3,
                              init_x=braintools.init.Constant(0.5))
    r = brainmass.Simulator(node, dt=0.1 * u.ms).run(
        150 * u.ms, monitors=['x'], transient=50 * u.ms)
    x = u.get_magnitude(r['x'])[:, 0]
    # sqrt(2) * RMS recovers the peak amplitude of a sinusoid
    return jnp.sqrt(jnp.mean(x ** 2)) * jnp.sqrt(2.0)
amps = brainstate.transform.vmap(amplitude_for)(a_values)
for a, amp in zip(a_values, amps):
    print(f"a = {float(a):.2f} -> amplitude {float(amp):.3f}")
a = 0.10 -> amplitude 0.329
a = 0.30 -> amplitude 0.561
a = 0.50 -> amplitude 0.721
a = 0.70 -> amplitude 0.849
a = 0.90 -> amplitude 0.958
a = 1.10 -> amplitude 1.055
a = 1.30 -> amplitude 1.142
a = 1.50 -> amplitude 1.223
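The amplitudes above track the square root of a, which is the limit-cycle radius of the Hopf normal form. The plain-JAX sketch below shows the same parameter sweep pattern with jax.vmap. It assumes a hand-written Euler step of the normal form rather than brainmass.HopfStep, so its numbers are not expected to reproduce the table exactly.

```python
import jax
import jax.numpy as jnp

def amplitude_for(a, w=0.3, dt=0.1, n_steps=1500):
    # Illustrative Euler integration of the Hopf normal form (assumed dynamics).
    def step(state, _):
        x, y = state
        r2 = x ** 2 + y ** 2
        x_new = x + dt * ((a - r2) * x - w * y)
        y_new = y + dt * ((a - r2) * y + w * x)
        return (x_new, y_new), None
    init = (jnp.float32(0.5), jnp.float32(0.0))
    (x, y), _ = jax.lax.scan(step, init, None, length=n_steps)
    return jnp.sqrt(x ** 2 + y ** 2)      # radius after the transient has decayed

a_values = jnp.linspace(0.1, 1.5, 8)
amps = jax.jit(jax.vmap(amplitude_for))(a_values)   # one rollout per value of a

for a, amp in zip(a_values, amps):
    print(f"a = {float(a):.2f} -> radius {float(amp):.3f}, sqrt(a) = {float(a) ** 0.5:.3f}")
```

Because `a` is an ordinary traced argument, all eight rollouts share one compiled program; nothing is rebuilt per parameter value.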
GPU / TPU#
The code above is device-agnostic — JAX dispatches to whatever backend jaxlib
was built for. To use an accelerator:
- GPU: install a CUDA build of jaxlib (pip install brainmass, then pip install -U "jax[cuda12]"). No code change.
- TPU: install jax[tpu]. No code change.
Check the active backend at runtime:
print("JAX default backend:", jax.default_backend())
print("devices:", jax.devices())
JAX default backend: cpu
devices: [CpuDevice(id=0)]
Two practical tips for accelerators:
- vmap is how you fill a GPU. A single small network barely uses a GPU; a vmap-ed ensemble or parameter grid keeps it busy and amortises the launch cost.
- block_until_ready for honest timing. JAX dispatches asynchronously, so without it you time the dispatch, not the compute. The timings above all call it before stopping the clock.
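The timing tip can be wrapped in a small helper. This is a minimal sketch in plain JAX; the names timed and work are hypothetical, introduced only for illustration.

```python
import time

import jax
import jax.numpy as jnp

def timed(fn, *args, repeats=5):
    """Return (last output, mean seconds per call), excluding compile time."""
    jax.block_until_ready(fn(*args))      # warm-up: first call traces and compiles
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    jax.block_until_ready(out)            # wait for the device before reading the clock
    return out, (time.perf_counter() - t0) / repeats

@jax.jit
def work(x):
    return jnp.tanh(x @ x.T).sum()

out, seconds = timed(work, jnp.ones((256, 256)))
print(f"mean time per call: {seconds * 1e3:.3f} ms")
```

Dropping the second block_until_ready would make the helper return as soon as the calls are queued, which on an accelerator can understate the real cost considerably.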
Long rollouts under autograd#
If you backpropagate through a long simulation (backprop-through-time), storing
every step’s activations can exhaust memory. Swap for_loop / scan for
brainstate.transform.checkpointed_for_loop / checkpointed_scan: same
semantics, but activations are rematerialised on the backward pass (tune base)
to bound peak memory at the cost of recomputation. Reach for these only when a
reverse-mode gradient through a long run would otherwise run out of memory —
otherwise prefer plain for_loop / scan.
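The trade described above (recompute on the backward pass instead of storing activations) can be illustrated with plain JAX's jax.checkpoint applied to a scan body. This is a sketch of the general rematerialisation idea only; it does not use brainstate's checkpointed_for_loop / checkpointed_scan and says nothing about how their base argument works.

```python
import jax
import jax.numpy as jnp

def step(x, _):
    x = x + 0.1 * jnp.tanh(x)             # toy step standing in for a model update
    return x, None

def make_loss(body):
    def loss(x0):
        x_final, _ = jax.lax.scan(body, x0, None, length=200)
        return jnp.sum(x_final ** 2)
    return loss

x0 = jnp.linspace(-1.0, 1.0, 8)

g_plain = jax.grad(make_loss(step))(x0)                   # stores per-step residuals
g_remat = jax.grad(make_loss(jax.checkpoint(step)))(x0)   # recomputes them on the backward pass

print("gradients agree:", bool(jnp.allclose(g_plain, g_remat, rtol=1e-4)))
```

The gradients are the same either way; only the memory/compute balance of the backward pass changes, which is why checkpointing is worth its cost only on long rollouts.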
Next steps#
- Run Parameter Sweeps: sweep a parameter grid with vmap.
- Fitting with Gradients: gradients through a Simulator.
- Concepts: how the run loop and transforms fit together.