Batching Strategies#

Online learning algorithms need to handle batched data efficiently. In braintrace, there are two main batching strategies:

  • Vmap-based batching (recommended): Compile the computation graph for a single sample, then use vmap to automatically vectorize across the batch dimension.

  • Single-sample mode: Process one sample at a time, without any batching.

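The choice of strategy affects how model states are initialized and how the online learning algorithm is called. With vmap-based batching, per-sample states are created inside a function decorated with brainstate.transform.vmap_new_states and the algorithm is wrapped in brainstate.nn.Vmap; in single-sample mode, states are initialized once with brainstate.nn.init_all_states and the algorithm is called directly.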

This tutorial walks through each strategy with concrete examples and shows how to build a full training loop using vmap-based batching.
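The core idea behind vmap-based batching can be sketched in plain JAX, independent of braintrace. The snippet below is only an illustration under stated assumptions: rnn_step, the weight shapes, and the sizes are hypothetical stand-ins, not part of the braintrace API. It writes the update for one sample and lets jax.vmap add the batch dimension, sharing the weights while batching the hidden state and the input.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, h, x):
    """Single-sample recurrent update; h and x are 1-D vectors (hypothetical toy model)."""
    w_in, w_rec = params
    return jnp.tanh(x @ w_in + h @ w_rec)


n_in, n_hidden, batch_size = 10, 64, 16
key_in, key_rec = jax.random.split(jax.random.PRNGKey(0))
params = (
    0.1 * jax.random.normal(key_in, (n_in, n_hidden)),
    0.1 * jax.random.normal(key_rec, (n_hidden, n_hidden)),
)

# Single-sample mode: call the function directly on unbatched arrays.
h_single = rnn_step(params, jnp.zeros(n_hidden), jnp.ones(n_in))

# Vmap-based batching: weights are shared (None), while the hidden state
# and the input each carry a leading batch axis (0).
batched_step = jax.vmap(rnn_step, in_axes=(None, 0, 0))
h_batch = batched_step(
    params,
    jnp.zeros((batch_size, n_hidden)),
    jnp.ones((batch_size, n_in)),
)

print(h_single.shape)  # (64,)
print(h_batch.shape)   # (16, 64)
```

The in_axes argument mirrors the distinction the tutorial draws: parameters are shared across the batch, whereas each sample needs its own copy of the hidden state.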

Single-Sample Mode#

For debugging or situations where batch processing is unnecessary, you can compile and run the algorithm on individual samples directly. No vmap or state replication is needed.

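import jax.numpy as jnp
import brainstate
import braintrace

# SimpleGRU is assumed to be defined elsewhere in this tutorial
# (input size 10, output size 5).
model2 = SimpleGRU(10, 64, 5)
brainstate.nn.init_all_states(model2)  # unbatched states, no vmap

# Wrap the model in the online learning algorithm and compile the graph
# with a dummy single-sample input of shape (10,).
algo2 = braintrace.D_RTRL(model2)
algo2.compile_graph(jnp.zeros(10))

# Process one sample at a time
x_single = jnp.ones(10)
out = algo2(x_single)
print("Single sample output shape:", out.shape)  # (5,)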

This mode is straightforward: initialize the model, compile the graph, and call the algorithm. It is useful for step-by-step debugging or when processing a single stream of data.

Multi-Step Data#

braintrace provides SingleStepData and MultiStepData wrappers to control how the algorithm processes input along the time dimension.

  • SingleStepData: Wraps data for a single time step. The algorithm processes it as one forward pass.

  • MultiStepData: Wraps a sequence of time steps. The algorithm internally scans over all steps in the sequence.

MultiStepData is useful when you want to pass an entire sequence to the algorithm and have it handle the temporal loop internally, rather than iterating over time steps yourself.

# Single-step: process one time step at a time
x_single = braintrace.SingleStepData(jnp.ones(10))

# Multi-step: process a sequence
sequence = jnp.ones((20, 10))  # 20 time steps, 10 features
x_multi = braintrace.MultiStepData(sequence)

When a MultiStepData object is passed to the algorithm, it will iterate over the first axis (time steps) internally. When a SingleStepData object (or a plain array) is passed, the algorithm processes it as a single forward step.
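To make the difference between stepping manually and scanning internally concrete, here is a small plain-JAX sketch. It does not use braintrace and is not how MultiStepData is implemented; the step function is a hypothetical toy recurrence. It only shows that scanning over the first axis of a sequence produces the same result as an explicit Python loop over time steps.

```python
import jax
import jax.numpy as jnp


def step(h, x):
    """One time step of a toy recurrence: returns (new carry, per-step output)."""
    h_new = jnp.tanh(0.5 * h + x)
    return h_new, h_new


sequence = jnp.ones((20, 10))  # 20 time steps, 10 features
h0 = jnp.zeros(10)

# Manual iteration: the caller owns the temporal loop (one step per call).
h = h0
manual_outputs = []
for x_t in sequence:
    h, out = step(h, x_t)
    manual_outputs.append(out)
manual_outputs = jnp.stack(manual_outputs)

# Internal iteration: scan walks over the first axis (time) in one call.
h_final, scanned_outputs = jax.lax.scan(step, h0, sequence)

print(scanned_outputs.shape)  # (20, 10)
```

Both paths compute the same values; the scanned version hands the whole sequence over in one call, which is the role MultiStepData plays for the algorithm.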

Full Training Loop with Vmap Batching#

Below is a complete example that combines vmap-based batching with a temporal training loop. The pattern is:

  1. Initialize model states and compile the graph for a single sample.

  2. Vmap the algorithm across the batch dimension.

  3. Scan over time steps, accumulating gradients at each step.

  4. Update parameters with the accumulated gradients (train_step below returns the gradients; the optimizer update is left to the caller).

import jax
import jax.numpy as jnp
import brainstate
import braintrace

@brainstate.transform.jit
def train_step(inputs, targets):
    """inputs: (n_steps, batch_size, n_in), targets: (batch_size, n_out)"""
    weights = model.states(brainstate.ParamState)
    algo = braintrace.D_RTRL(model)

    @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
    def init():
        brainstate.nn.init_all_states(model)
        algo.compile_graph(inputs[0, 0])
    init()
    vmapped_algo = brainstate.nn.Vmap(algo, vmap_states='new')

    def step_fn(prev_grads, inp):
        def loss_fn(inp):
            out = vmapped_algo(inp)
            return jnp.mean((out - targets) ** 2)
        cur_grads = brainstate.transform.grad(loss_fn, weights)(inp)
        return jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads), None

    grads = jax.tree.map(jnp.zeros_like, weights.to_dict_values())
    grads, _ = brainstate.transform.scan(step_fn, grads, inputs)
    return grads

# Example usage
model = SimpleGRU(10, 64, 5)
inputs = jnp.ones((20, 16, 10))  # 20 steps, batch 16, 10 features
targets = jnp.zeros((16, 5))
grads = train_step(inputs, targets)
print("Gradient keys:", list(grads.keys()))

What happens in train_step:

  1. weights collects all ParamState objects from the model.

  2. The init function, decorated with vmap_new_states, initializes per-sample hidden states and compiles the computation graph using a single-sample input.

  3. vmapped_algo wraps the algorithm for batch-parallel execution.

  4. brainstate.transform.scan iterates over the time dimension (inputs has shape (n_steps, batch_size, n_in)). At each step, step_fn computes the loss and its gradients with respect to weights, then accumulates them.

  5. The returned grads dictionary can be passed to an optimizer (e.g., braintools.optim.Adam) for a parameter update.
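The accumulate-then-update pattern in steps 4 and 5 can be illustrated with a plain-JAX sketch. This is a simplified stand-in under explicit assumptions: the model is a stateless linear readout (so ordinary jax.grad suffices, unlike a recurrent model that needs the online learning algorithm), the parameter names "w" and "b" are hypothetical, and a hand-written SGD step replaces the optimizer because the optimizer's call signature is not shown in this section.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a linear readout for one (batched) time step."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


n_steps, batch_size, n_in, n_out = 20, 16, 10, 5
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w": 0.1 * jax.random.normal(key_w, (n_in, n_out)),
    "b": jnp.zeros(n_out),
}
inputs = jax.random.normal(key_x, (n_steps, batch_size, n_in))
targets = jnp.zeros((batch_size, n_out))


def step_fn(acc_grads, x_t):
    # Gradient of this time step's loss, added to the running total.
    cur_grads = jax.grad(loss_fn)(params, x_t, targets)
    return jax.tree.map(jnp.add, acc_grads, cur_grads), None


# Start from zeros with the same structure as the parameters, then scan over time.
zero_grads = jax.tree.map(jnp.zeros_like, params)
grads, _ = jax.lax.scan(step_fn, zero_grads, inputs)

# Plain SGD update as a stand-in for an optimizer (learning rate is an arbitrary choice).
learning_rate = 1e-2
new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)

print({name: g.shape for name, g in grads.items()})
```

Because the accumulated gradients have the same tree structure as the parameters, any update rule that maps over matching (parameter, gradient) pairs can consume them.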

Summary#

  • Vmap-based batching is recommended for most use cases. It keeps model code simple (single-sample logic) while achieving efficient batch parallelism.

  • The workflow is: compile for a single sample, then vmap across the batch.

  • SingleStepData and MultiStepData control whether the algorithm processes one time step or scans over an entire sequence internally.

  • The typical training pattern is: init states -> compile graph -> vmap -> scan over time steps -> accumulate gradients -> update parameters.