Batch and Accelerate

Goal: make brainmass runs fast — compile with jit, run an ensemble or a parameter grid in one shot with vmap, and choose the right loop primitive.

brainmass is built on JAX, so the same code runs on CPU, GPU, or TPU. The speed-ups come from a few brainstate.transform primitives: jit to compile, for_loop / scan to fuse loops, and vmap to batch. This recipe shows when to reach for each, with before/after timings.

Note

Never drive a model with a bare Python for/while loop when it runs repeatedly. A Python loop executes op-by-op, paying dispatch overhead on every operation with no fusion across them; the brainstate.transform primitives lower the whole loop into one compiled XLA program and trace the body only once. This guide uses the raw primitives directly to teach them; in normal use, brainmass.Simulator wraps for_loop for you.

Which primitive?

Shape of the work                     Primitive
------------------------------------  -----------------------------------------
Single step / one-shot call           brainstate.transform.jit
Many steps, collect outputs           brainstate.transform.for_loop
Many steps with an explicit carry     brainstate.transform.scan
Batch over inputs / parameters        brainstate.transform.vmap
Long rollout under autograd (BPTT)    checkpointed_for_loop / checkpointed_scan

brainmass.Simulator already composes jit + for_loop internally, so most of the time you just call Simulator(model, dt).run(...). Reach for the raw primitives when you need a custom step or an explicit carry.

jit: compile a step once

A bare Python loop runs the model eagerly, dispatching every operation separately on every iteration. Wrapping the step in brainstate.transform.jit compiles it on the first call; subsequent calls with inputs of the same shape and dtype reuse the compiled program. Here we time a hand-written rollout with and without jit.

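import time

import brainmass
import brainstate
import braintools
import brainunit as u
import jax
import jax.numpy as jnp

brainstate.environ.set(dt=0.1 * u.ms)  # global time step for the hand-written loops below

node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)
node.init_all_states()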

def step_uncompiled(i):
    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):
        node.update()
    return node.x.value

step_jit = brainstate.transform.jit(step_uncompiled)

n = 300
# warm up the compiled version (first call traces + compiles)
_ = step_jit(0); jax.block_until_ready(node.x.value)

# uncompiled: every call dispatches op-by-op
node.init_all_states()
t0 = time.perf_counter()
for i in range(n):
    step_uncompiled(i)
jax.block_until_ready(node.x.value)
t_uncompiled = time.perf_counter() - t0

# compiled: each call reuses the compiled program
node.init_all_states()
t0 = time.perf_counter()
for i in range(n):
    step_jit(i)
jax.block_until_ready(node.x.value)
t_jit = time.perf_counter() - t0

print(f"uncompiled : {t_uncompiled * 1e3:7.1f} ms")
print(f"jit        : {t_jit * 1e3:7.1f} ms")
print(f"speed-up   : {t_uncompiled / t_jit:6.1f}x")
uncompiled :  4357.3 ms
jit        :    91.4 ms
speed-up   :   47.7x
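To see the "compile once, reuse" behaviour in isolation, here is a minimal plain-JAX sketch. It uses jax.jit directly rather than brainstate.transform.jit, and a hand-written real-valued Hopf normal form as an illustrative stand-in for brainmass.HopfStep (its exact equations are an assumption). The Python counter inside the function only runs while JAX traces it, so it stays at 1 across 300 calls with the same input shapes and dtypes.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Python side effect: runs only while JAX traces the function

def hopf_step(x, y):
    """One forward-Euler step of a real-valued Hopf normal form (illustrative)."""
    global trace_count
    trace_count += 1
    a, w, dt = 0.25, 0.3, 0.1
    r2 = x ** 2 + y ** 2
    return x + dt * ((a - r2) * x - w * y), y + dt * ((a - r2) * y + w * x)

hopf_step_jit = jax.jit(hopf_step)

x = jnp.full((64,), 0.1, dtype=jnp.float32)
y = jnp.zeros((64,), dtype=jnp.float32)
for _ in range(300):
    x, y = hopf_step_jit(x, y)  # first call traces + compiles; the rest hit the cache
x.block_until_ready()

print("traces:", trace_count)  # traces: 1
```

Calling the same function with a different input shape or dtype (for example 32 regions instead of 64) would trigger one more trace and compile for that new signature.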

for_loop: fuse the whole rollout

Even a jit-ed step is still dispatched once per iteration from Python. brainstate.transform.for_loop lowers the entire loop into one XLA program — it traces the body once and stacks the per-step outputs for you. State (the model’s hidden variables) is carried automatically. This is exactly what Simulator.run does under the hood.

node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)
node.init_all_states()

def step(i):
    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):
        node.update()
    return node.x.value

run = brainstate.transform.jit(lambda: brainstate.transform.for_loop(step, jnp.arange(n)))
xs = run(); jax.block_until_ready(xs)  # warm up

t0 = time.perf_counter()
xs = run(); jax.block_until_ready(xs)
t_forloop = time.perf_counter() - t0

print("for_loop output:", xs.shape, "(time, regions)")
print(f"for_loop   : {t_forloop * 1e3:7.1f} ms  ({t_uncompiled / t_forloop:.0f}x vs uncompiled)")
for_loop output: (300, 64) (time, regions)
for_loop   :    79.6 ms  (55x vs uncompiled)

In practice you rarely write that loop yourself: brainmass.Simulator is the same jit + for_loop, with input validation and with monitors, transients, and units handled for you. Its timing should be in the same range as the hand-written for_loop above.

sim = brainmass.Simulator(node, dt=0.1 * u.ms)
res = sim.run(n * 0.1 * u.ms, monitors=['x'])  # warm up + compile
jax.block_until_ready(res['x'])

t0 = time.perf_counter()
res = sim.run(n * 0.1 * u.ms, monitors=['x'])
jax.block_until_ready(res['x'])
t_sim = time.perf_counter() - t0  # stop the clock before printing
print("Simulator output:", res['x'].shape)
print(f"Simulator  : {t_sim * 1e3:7.1f} ms")
Simulator output: (300, 64)
Simulator  :   103.6 ms
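For intuition about what for_loop does under the hood, here is the plain-JAX analogue: jax.lax.scan with no per-step inputs, which also traces the body once and stacks the per-step outputs. The dynamics are a hand-written real-valued Hopf normal form, an illustrative assumption rather than brainmass's own code. Inspecting the traced program with jax.make_jaxpr shows the entire 300-step rollout as a single scan equation.

```python
import jax
import jax.numpy as jnp

def hopf_step(carry, _):
    """One Euler step; carry = (x, y), emits x so scan can stack it."""
    x, y = carry
    a, w, dt = 0.25, 0.3, 0.1
    r2 = x ** 2 + y ** 2
    x_new = x + dt * ((a - r2) * x - w * y)
    y_new = y + dt * ((a - r2) * y + w * x)
    return (x_new, y_new), x_new

def rollout(x0, y0):
    _, xs = jax.lax.scan(hopf_step, (x0, y0), xs=None, length=300)
    return xs  # stacked per-step outputs, shape (time, regions)

x0 = jnp.full((64,), 0.1, dtype=jnp.float32)
y0 = jnp.zeros((64,), dtype=jnp.float32)

xs = jax.jit(rollout)(x0, y0).block_until_ready()
print("rollout output:", xs.shape)  # rollout output: (300, 64)

# The traced program: the whole loop is one `scan` primitive, the body appears once inside it.
prims = [eqn.primitive.name for eqn in jax.make_jaxpr(rollout)(x0, y0).jaxpr.eqns]
print("top-level primitives:", prims)
```

brainstate.transform.for_loop additionally threads the model's State objects through the loop for you, which is why the brainmass version needs no explicit carry.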

scan: thread an explicit carry

When you need to carry a value alongside the model’s State (f(carry, x) -> (carry, y)), use brainstate.transform.scan. A typical use is feeding a time-varying external drive into the model and accumulating a running statistic. Here the carry is a running sum of the output.

node = brainmass.HopfStep(in_size=8, a=0.25, w=0.3)
node.init_all_states()

drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None]  # (time, 1)

def body(carry, inputs):
    running_sum = carry
    i, inp = inputs               # step index and this step's drive
    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):  # advance the clock each step
        node.update(inp)          # external drive into x
    x = node.x.value
    return running_sum + x, x     # (new carry, per-step output)

total, xs = brainstate.transform.scan(body, jnp.zeros(8), (jnp.arange(n), drive))
print("scan stacked outputs:", xs.shape)
print("carried running sum  :", total.shape)
scan stacked outputs: (300, 8)
carried running sum  : (8,)
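The same pattern in plain JAX makes the carry fully explicit: jax.lax.scan has no implicit State, so both the model variables (x, y) and the running sum live in the carry tuple, and the drive is scanned over as the per-step input. The Hopf equations and the way the drive enters the x equation are illustrative assumptions, not brainmass's HopfStep.

```python
import jax
import jax.numpy as jnp

n, dt, a, w = 300, 0.1, 0.25, 0.3
drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None]  # (time, 1), broadcast over regions

def body(carry, inp):
    (x, y), running_sum = carry                          # model state + an extra statistic
    r2 = x ** 2 + y ** 2
    x_new = x + dt * ((a - r2) * x - w * y + inp)        # drive enters the x equation
    y_new = y + dt * ((a - r2) * y + w * x)
    return ((x_new, y_new), running_sum + x_new), x_new  # (new carry, per-step output)

init = ((jnp.full((8,), 0.1, dtype=jnp.float32), jnp.zeros(8, dtype=jnp.float32)),
        jnp.zeros(8, dtype=jnp.float32))
((x_end, y_end), total), xs = jax.lax.scan(body, init, drive)
print("stacked outputs:", xs.shape)    # stacked outputs: (300, 8)
print("running sum    :", total.shape)  # running sum    : (8,)
```

The running sum in the final carry should agree with summing the stacked outputs afterwards; carrying it is useful when you do not want to keep the full (time, regions) history.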

vmap: batch an ensemble in one run

brainstate.transform.vmap adds a batch axis to a whole computation. There are two common patterns.

Batched initial conditions: pass batch_size=B to Simulator.run; it calls init_all_states(batch_size=B), and the outputs gain a batch axis right after the time axis, giving (time, batch, regions). This runs B independent trajectories (e.g. a noise ensemble) in a single compiled program.

node = brainmass.HopfStep(
    in_size=4, a=0.1, w=0.3,
    noise_x=brainmass.OUProcess(4, sigma=0.1, tau=10 * u.ms),
)
brainstate.random.seed(0)
res = brainmass.Simulator(node, dt=0.1 * u.ms).run(
    200 * u.ms, monitors=['x'], batch_size=16,
)
print("batched trajectory:", res['x'].shape, "(time, batch, regions)")
batched trajectory: (2000, 16, 4) (time, batch, regions)
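Outside brainmass, a noise ensemble is a vmap over random keys. This plain-JAX sketch assumes a hand-written Hopf normal form with additive white noise integrated by Euler-Maruyama, a simpler swap-in for the OUProcess noise used above. Each batch member gets its own key from jax.random.split, and out_axes=1 places the batch axis after time so the result has the same (time, batch, regions) layout as the Simulator output.

```python
import jax
import jax.numpy as jnp

n_steps, n_regions, batch, dt, sigma = 2000, 4, 16, 0.1, 0.1

def trajectory(key):
    """One noisy Hopf trajectory (illustrative); returns x with shape (time, regions)."""
    noise = jax.random.normal(key, (n_steps, n_regions))  # one standard-normal draw per step

    def body(state, eps):
        x, y = state
        r2 = x ** 2 + y ** 2
        x_new = x + dt * ((0.1 - r2) * x - 0.3 * y) + sigma * jnp.sqrt(dt) * eps
        y_new = y + dt * ((0.1 - r2) * y + 0.3 * x)
        return (x_new, y_new), x_new

    init = (jnp.zeros(n_regions), jnp.zeros(n_regions))
    _, xs = jax.lax.scan(body, init, noise)
    return xs

keys = jax.random.split(jax.random.PRNGKey(0), batch)  # one independent key per member
xs = jax.jit(jax.vmap(trajectory, out_axes=1))(keys)    # batch axis placed after time
print("batched trajectory:", xs.shape)  # batched trajectory: (2000, 16, 4)
```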

Batched parameters: write a function that builds the model from a parameter value, then vmap it over an array of values. This runs one simulation per parameter value, all fused into a single compiled program. (See Run Parameter Sweeps for the full grid-sweep recipe.)

a_values = jnp.linspace(0.1, 1.5, 8)

def amplitude_for(a):
    node = brainmass.HopfStep(in_size=1, a=a, w=0.3,
                              init_x=braintools.init.Constant(0.5))
    r = brainmass.Simulator(node, dt=0.1 * u.ms).run(
        150 * u.ms, monitors=['x'], transient=50 * u.ms)
    x = u.get_magnitude(r['x'])[:, 0]
    return jnp.sqrt(jnp.mean(x ** 2)) * jnp.sqrt(2.0)  # RMS x sqrt(2): peak amplitude of a sinusoid

amps = brainstate.transform.vmap(amplitude_for)(a_values)
for a, amp in zip(a_values, amps):
    print(f"a = {float(a):.2f}  ->  amplitude {float(amp):.3f}")
a = 0.10  ->  amplitude 0.329
a = 0.30  ->  amplitude 0.561
a = 0.50  ->  amplitude 0.721
a = 0.70  ->  amplitude 0.849
a = 0.90  ->  amplitude 0.958
a = 1.10  ->  amplitude 1.055
a = 1.30  ->  amplitude 1.142
a = 1.50  ->  amplitude 1.223
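The printed amplitudes sit within a few percent of √a (√0.1 ≈ 0.316, √1.5 ≈ 1.225), which is the limit-cycle radius a supercritical Hopf normal form predicts; this suggests HopfStep behaves like that normal form, although this page does not state its equations. The plain-JAX sketch below reproduces the sweep under that assumption, using a hand-written normal form dx/dt = (a − x² − y²)x − w·y, dy/dt = (a − x² − y²)y + w·x, vmapped over a and compared against √a.

```python
import jax
import jax.numpy as jnp

dt, w, n_steps, n_transient = 0.01, 0.3, 15000, 5000

def amplitude_for(a):
    """Oscillation amplitude of x for one value of a (hand-written normal form, assumed)."""
    def body(state, _):
        x, y = state
        r2 = x ** 2 + y ** 2
        return (x + dt * ((a - r2) * x - w * y),
                y + dt * ((a - r2) * y + w * x)), x

    init = (jnp.float32(0.5), jnp.float32(0.0))
    _, xs = jax.lax.scan(body, init, None, length=n_steps)
    x = xs[n_transient:]                               # drop the transient
    return jnp.sqrt(jnp.mean(x ** 2)) * jnp.sqrt(2.0)  # RMS x sqrt(2): peak amplitude of a sinusoid

a_values = jnp.linspace(0.1, 1.5, 8)
amps = jax.jit(jax.vmap(amplitude_for))(a_values)      # one rollout per a, all in one program
for a, amp in zip(a_values, amps):
    print(f"a = {float(a):.2f}  ->  amplitude {float(amp):.3f}   sqrt(a) = {float(jnp.sqrt(a)):.3f}")
```

Because the whole sweep is one vmapped scan, adding parameter values widens the batch rather than adding Python-level loop iterations.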

GPU / TPU

The code above is device-agnostic: JAX runs it on its default backend, which is an installed accelerator when one is available and the CPU otherwise. To use an accelerator:

  • GPU: install a CUDA build of jaxlib (pip install brainmass then pip install -U "jax[cuda12]"). No code change.

  • TPU: install jax[tpu]. No code change.

Check the active backend at runtime:

print("JAX default backend:", jax.default_backend())
print("devices:", jax.devices())
JAX default backend: cpu
devices: [CpuDevice(id=0)]

Two practical tips for accelerators:

  • vmap is how you fill a GPU. A single small network barely uses a GPU; a vmap-ed ensemble or parameter grid keeps it busy and amortises the launch cost.

  • block_until_ready for honest timing. JAX dispatches work asynchronously; without it you time the dispatch, not the compute. That is why every timing above calls jax.block_until_ready before reading the clock.
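A small helper can package both habits: warm up once so tracing and compilation are excluded, then wrap each timed call in jax.block_until_ready. time_compiled is a hypothetical name for this sketch, not a brainmass or brainstate function.

```python
import time

import jax
import jax.numpy as jnp

def time_compiled(fn, *args, repeats=5):
    """Best wall-clock time (seconds) of fn(*args), excluding the first (compiling) call."""
    jax.block_until_ready(fn(*args))      # warm-up: trace + compile, not timed
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for the device, not just the dispatch
        best = min(best, time.perf_counter() - t0)
    return best

matmul = jax.jit(lambda m: m @ m)
m = jnp.ones((512, 512))
print(f"best of 5: {time_compiled(matmul, m) * 1e3:.2f} ms")
```

Taking the best of several runs reduces noise from other processes; the mean is the better choice if you care about typical rather than best-case latency.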

Long rollouts under autograd

If you backpropagate through a long simulation (backpropagation through time), storing every step’s activations can exhaust memory. Swap for_loop / scan for brainstate.transform.checkpointed_for_loop / checkpointed_scan: same semantics, but activations are rematerialised on the backward pass to bound peak memory at the cost of recomputation (tune the base argument to adjust that trade-off). Use them only when a reverse-mode gradient through a long run would otherwise run out of memory; otherwise prefer plain for_loop / scan.
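The rematerialisation idea can be illustrated in plain JAX by wrapping the scan body in jax.checkpoint: the backward pass then keeps only each step's carry and recomputes the body's intermediates, while the gradient itself is unchanged. This is only the simplest per-step form and an assumption-level sketch, not a description of how checkpointed_for_loop arranges its checkpoints; the Hopf normal form is again a hand-written stand-in.

```python
import jax
import jax.numpy as jnp

dt, w, n_steps = 0.1, 0.3, 1000

def make_loss(remat):
    """Loss = mean x^2 over a long rollout, as a function of the bifurcation parameter a."""
    def loss(a):
        def body(state, _):
            x, y = state
            r2 = x ** 2 + y ** 2
            new = (x + dt * ((a - r2) * x - w * y),
                   y + dt * ((a - r2) * y + w * x))
            return new, new[0]

        if remat:
            body = jax.checkpoint(body)  # recompute body internals on the backward pass
        init = (jnp.float32(0.5), jnp.float32(0.0))
        _, xs = jax.lax.scan(body, init, None, length=n_steps)
        return jnp.mean(xs ** 2)
    return loss

g_plain = jax.grad(make_loss(remat=False))(jnp.float32(0.25))
g_remat = jax.grad(make_loss(remat=True))(jnp.float32(0.25))
print("plain gradient :", float(g_plain))
print("remat gradient :", float(g_remat))  # same value, lower peak memory for long rollouts
```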

Next steps