Batching Strategies
Online learning algorithms need to handle batched data efficiently. In braintrace, there are two main batching strategies:
Vmap-based batching (recommended): Compile the computation graph for a single sample, then use vmap to automatically vectorize across the batch dimension.
Single-sample mode: Process one sample at a time, without any batching.
The choice of strategy affects how model states are initialized and how the online learning algorithm is called.
This tutorial walks through each strategy with concrete examples and shows how to build a full training loop using vmap-based batching.
Vmap-Based Batching (Recommended)
The recommended approach is to compile the online learning graph for a single sample, then leverage JAX’s vmap to parallelize across the batch. The key steps are:
Create the online learning algorithm (e.g., D_RTRL).
Use brainstate.transform.vmap_new_states to initialize per-sample states and compile the graph with a single-sample input shape.
Wrap the algorithm with brainstate.nn.Vmap for parallel execution across the batch.
Call the vmapped algorithm on batched inputs.
This pattern keeps the model definition simple (single-sample logic) while gaining efficient batch parallelism automatically.
import jax
import jax.numpy as jnp
import brainstate
import braintools
import braintrace
class SimpleGRU(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.rnn = braintrace.nn.GRUCell(n_in, n_rec)
        self.out = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.out(self.rnn(x))
model = SimpleGRU(10, 64, 5)
batch_size = 16
# Step 1: Create the algorithm
algo = braintrace.D_RTRL(model)
# Step 2: Initialize per-sample states via vmap
@brainstate.transform.vmap_new_states(state_tag='new', axis_size=batch_size)
def init():
    brainstate.nn.init_all_states(model)
    algo.compile_graph(jnp.zeros(10))  # single sample shape

init()
# Step 3: Wrap for parallel execution
algo_vmapped = brainstate.nn.Vmap(algo, vmap_states='new')
# Step 4: Run on batched input
x_batch = jnp.ones((batch_size, 10))
out = algo_vmapped(x_batch)
print("Output shape:", out.shape) # (16, 5)
How it works:
vmap_new_states runs the initialization function once but creates axis_size independent copies of all model states. The state_tag='new' labels these states so they can be identified later.
brainstate.nn.Vmap(algo, vmap_states='new') wraps the algorithm so that each call automatically splits the batch input across the per-sample states, runs the forward pass independently for each sample, and stacks the outputs.
The model itself only ever sees single-sample inputs; all batch handling is transparent.
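The state replication described above can be illustrated without braintrace. The following is a minimal sketch in plain JAX, not the library's implementation: the names step, W, U, h_batch, and step_vmapped are hypothetical, and jax.vmap stands in for what brainstate.nn.Vmap is described as doing. Each sample gets its own hidden state, while the weights are shared.

```python
import jax
import jax.numpy as jnp

n_in, n_rec, batch_size = 10, 4, 16

# Shared parameters: one copy serves every sample in the batch.
key_w, key_u = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (n_in, n_rec)) * 0.1
U = jax.random.normal(key_u, (n_rec, n_rec)) * 0.1


def step(h, x):
    """Single-sample recurrent step: h has shape (n_rec,), x has shape (n_in,)."""
    return jnp.tanh(x @ W + h @ U)


# One hidden state per sample (assumed analogue of axis_size=batch_size).
h_batch = jnp.zeros((batch_size, n_rec))
x_batch = jnp.ones((batch_size, n_in))

# Map over axis 0 of both states and inputs; outputs are stacked on axis 0.
step_vmapped = jax.vmap(step, in_axes=(0, 0))
h_next = step_vmapped(h_batch, x_batch)
print(h_next.shape)  # (16, 4)
```

The function step never sees a batch axis, which mirrors the point that the model only handles single-sample inputs.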
Single-Sample Mode
For debugging or situations where batch processing is unnecessary, you can compile and run the algorithm on individual samples directly. No vmap or state replication is needed.
model2 = SimpleGRU(10, 64, 5)
brainstate.nn.init_all_states(model2)
algo2 = braintrace.D_RTRL(model2)
algo2.compile_graph(jnp.zeros(10))
# Process one sample at a time
x_single = jnp.ones(10)
out = algo2(x_single)
print("Single sample output shape:", out.shape) # (5,)
This mode is straightforward: initialize the model, compile the graph, and call the algorithm. It is useful for step-by-step debugging or when processing a single stream of data.
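To see how the two strategies relate, the sketch below runs the same single-sample function once in a Python loop and once under jax.vmap, then compares the results. It uses plain JAX only; forward, W, and x_batch are illustrative names, and the toy function is an assumption standing in for a compiled algorithm.

```python
import jax
import jax.numpy as jnp

# Fixed toy weights so the comparison is deterministic.
W = jnp.arange(50, dtype=jnp.float32).reshape(10, 5) / 50.0


def forward(x):
    """Single-sample forward pass: x has shape (10,), output has shape (5,)."""
    return jnp.tanh(x @ W)


x_batch = jnp.linspace(0.0, 1.0, 16 * 10).reshape(16, 10)

# Single-sample mode: a plain Python loop, easy to step through in a debugger.
out_loop = jnp.stack([forward(x) for x in x_batch])

# Vmap-based batching: the same function, vectorized over the leading axis.
out_vmap = jax.vmap(forward)(x_batch)

print(out_loop.shape, out_vmap.shape)  # (16, 5) (16, 5)
print(bool(jnp.allclose(out_loop, out_vmap, atol=1e-5)))
```

Running the loop version first and switching to vmap once the per-sample logic is verified is one reasonable debugging workflow.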
Multi-Step Data
braintrace provides SingleStepData and MultiStepData wrappers to control how the algorithm processes input along the time dimension.
SingleStepData: Wraps data for a single time step. The algorithm processes it as one forward pass.
MultiStepData: Wraps a sequence of time steps. The algorithm internally scans over all steps in the sequence.
This is useful when you want to pass an entire sequence to the algorithm and have it handle the temporal loop internally, rather than manually iterating over time steps.
# Single-step: process one time step at a time
x_single = braintrace.SingleStepData(jnp.ones(10))
# Multi-step: process a sequence
sequence = jnp.ones((20, 10)) # 20 time steps, 10 features
x_multi = braintrace.MultiStepData(sequence)
When a MultiStepData object is passed to the algorithm, it will iterate over the first axis (time steps) internally. When a SingleStepData object (or a plain array) is passed, the algorithm processes it as a single forward step.
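The difference between the two wrappers can be pictured with plain JAX. This is a rough sketch under stated assumptions, not braintrace's internals: jax.lax.scan plays the role of the internal loop over the first axis, and the manual for loop plays the role of a caller feeding single steps. The names step, W, U, and h0 are hypothetical.

```python
import jax
import jax.numpy as jnp

n_steps, n_in, n_rec = 20, 10, 4
W = jnp.full((n_in, n_rec), 0.05)
U = jnp.eye(n_rec) * 0.5


def step(h, x):
    """One time step; returns (carry, per-step output) as lax.scan expects."""
    h_new = jnp.tanh(x @ W + h @ U)
    return h_new, h_new


sequence = jnp.ones((n_steps, n_in))  # 20 time steps, 10 features
h0 = jnp.zeros(n_rec)

# Multi-step style: the temporal loop runs inside scan, over axis 0.
h_final, h_all = jax.lax.scan(step, h0, sequence)

# Single-step style: the caller iterates and passes one step at a time.
h = h0
for x_t in sequence:
    h, _ = step(h, x_t)

print(h_all.shape)  # (20, 4)
print(bool(jnp.allclose(h, h_final, atol=1e-5)))
```

Both paths reach the same final state; the scan version is traced once and tends to compile faster for long sequences.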
Full Training Loop with Vmap Batching
Below is a complete example that combines vmap-based batching with a temporal training loop. The pattern is:
Initialize model states and compile the graph for a single sample.
Vmap the algorithm across the batch dimension.
Scan over time steps, accumulating gradients at each step.
Update parameters with the accumulated gradients (train_step returns them; the optimizer update is left to the caller).
@brainstate.transform.jit
def train_step(inputs, targets):
    """inputs: (n_steps, batch_size, n_in), targets: (batch_size, n_out)"""
    weights = model.states(brainstate.ParamState)
    algo = braintrace.D_RTRL(model)

    # Initialize per-sample states and compile the graph for one sample
    @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
    def init():
        brainstate.nn.init_all_states(model)
        algo.compile_graph(inputs[0, 0])

    init()
    vmapped_algo = brainstate.nn.Vmap(algo, vmap_states='new')

    def step_fn(prev_grads, inp):
        def loss_fn(inp):
            out = vmapped_algo(inp)
            return jnp.mean((out - targets) ** 2)

        # Gradients for this time step, added to the running total
        cur_grads = brainstate.transform.grad(loss_fn, weights)(inp)
        return jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads), None

    grads = jax.tree.map(jnp.zeros_like, weights.to_dict_values())
    grads, _ = brainstate.transform.scan(step_fn, grads, inputs)
    return grads
# Example usage
model = SimpleGRU(10, 64, 5)
inputs = jnp.ones((20, 16, 10)) # 20 steps, batch 16, 10 features
targets = jnp.zeros((16, 5))
grads = train_step(inputs, targets)
print("Gradient keys:", list(grads.keys()))
What happens in train_step:
weights collects all ParamState objects from the model.
The init function, decorated with vmap_new_states, initializes per-sample hidden states and compiles the computation graph using a single-sample input.
vmapped_algo wraps the algorithm for batch-parallel execution.
brainstate.transform.scan iterates over the time dimension (inputs has shape (n_steps, batch_size, n_in)). At each step, step_fn computes the loss and its gradients with respect to weights, then accumulates them.
The returned grads dictionary can be passed to an optimizer (e.g., braintools.optim.Adam) for a parameter update.
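The accumulate-then-update pattern can also be sketched in plain JAX, including the final parameter update that train_step leaves to the caller. This is an illustration under several assumptions: a stateless linear readout replaces the GRU, ordinary jax.grad replaces D_RTRL, and a hand-written SGD step with a hypothetical learning rate LR stands in for an optimizer such as braintools.optim.Adam. All function and variable names are made up for the example.

```python
import jax
import jax.numpy as jnp

n_steps, batch_size, n_in, n_out = 20, 16, 10, 5
LR = 1e-3  # hypothetical learning rate for the plain SGD update

params = {"w": jnp.zeros((n_in, n_out)), "b": jnp.zeros(n_out)}


def loss_fn(params, x, y):
    """Mean squared error of a linear readout on one time step."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulate_grads(params, inputs, targets):
    """Scan over time (axis 0 of inputs), summing per-step gradients."""
    def step_fn(acc, x_t):
        g = jax.grad(loss_fn)(params, x_t, targets)
        acc = jax.tree_util.tree_map(lambda a, b: a + b, acc, g)
        return acc, None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    acc, _ = jax.lax.scan(step_fn, zeros, inputs)
    return acc


@jax.jit
def train_step(params, inputs, targets):
    grads = accumulate_grads(params, inputs, targets)
    # Plain SGD: move each parameter against its accumulated gradient.
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


inputs = jnp.ones((n_steps, batch_size, n_in))
targets = jnp.ones((batch_size, n_out))

loss_before = loss_fn(params, inputs[0], targets)
params = train_step(params, inputs, targets)
loss_after = loss_fn(params, inputs[0], targets)
print(float(loss_before), float(loss_after))
```

Summing gradients over all time steps before a single update matches the order given in the pattern above; updating after every step is a different design choice with different dynamics.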
Summary
Vmap-based batching is recommended for most use cases. It keeps model code simple (single-sample logic) while achieving efficient batch parallelism.
The workflow is: compile for a single sample, then vmap across the batch.
SingleStepData and MultiStepData control whether the algorithm processes one time step or scans over an entire sequence internally.
The typical training pattern is: init states -> compile graph -> vmap -> scan over time steps -> accumulate gradients -> update parameters.