How to train through long rollouts without exhausting memory

Task. Backpropagate through many time steps (BPTT) without the activation memory growing linearly with the sequence length, using gradient checkpointing.

Audience. Training. Assumes Tutorial 4 · Train a spiking network.

Backprop through time stores every step’s activations for the backward pass, so memory grows with the rollout length T. For long sequences this becomes the binding constraint. brainstate.transform.checkpointed_for_loop (and checkpointed_scan) rematerialize activations on the backward pass instead of storing them all, trading extra recomputation for bounded memory. The base argument tunes the checkpoint granularity (roughly, memory scales with base + T/base).

Crucially, checkpointed_for_loop is a drop-in replacement for for_loop inside the loss: it is called the same way (plus the optional base argument) and returns the same result. Only the memory/compute profile differs.
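To make the store-versus-recompute trade concrete, here is a minimal pure-Python sketch on a toy scalar recurrence. It is an illustration under stated assumptions, not brainstate's implementation: the names step, grad_store_all, and grad_checkpointed are hypothetical, and the scheme shown is a simple single-level one that keeps one state every base steps and recomputes each segment during the backward pass.

```python
import math

def step(w, h, x):
    """One step of a toy scalar recurrence: h_next = tanh(w * h + x)."""
    return math.tanh(w * h + x)

def backward_segment(w, states, dh, dw):
    """Backpropagate through consecutive states, given dL/d(last state) = dh."""
    for k in reversed(range(len(states) - 1)):
        d_pre = dh * (1.0 - states[k + 1] ** 2)  # derivative through tanh
        dw += d_pre * states[k]                  # d(pre-activation)/dw = h_k
        dh = d_pre * w                           # d(pre-activation)/dh_k = w
    return dh, dw

def grad_store_all(w, xs, h0=0.0):
    """Plain BPTT: keep every state, then run one backward sweep."""
    states = [h0]
    for x in xs:
        states.append(step(w, states[-1], x))
    _, dw = backward_segment(w, states, 1.0, 0.0)  # loss L = final state
    return dw, len(states), len(xs)

def grad_checkpointed(w, xs, base, h0=0.0):
    """Checkpointed BPTT (illustrative): keep one state every `base` steps."""
    T = len(xs)
    checkpoints, h, n_forward = {}, h0, 0
    for t, x in enumerate(xs):
        if t % base == 0:
            checkpoints[t] = h          # state at the start of a segment
        h = step(w, h, x)
        n_forward += 1
    dh, dw, peak = 1.0, 0.0, 0
    for start in sorted(checkpoints, reverse=True):
        # Recompute this segment's states from its checkpoint.
        seg = [checkpoints[start]]
        for t in range(start, min(start + base, T)):
            seg.append(step(w, seg[-1], xs[t]))
            n_forward += 1
        # Upper bound on live states: all checkpoints plus one segment.
        peak = max(peak, len(checkpoints) + len(seg))
        dh, dw = backward_segment(w, seg, dh, dw)
    return dw, peak, n_forward

xs = [math.sin(0.1 * t) for t in range(400)]
g_full, mem_full, fwd_full = grad_store_all(0.5, xs)
g_ckpt, mem_ckpt, fwd_ckpt = grad_checkpointed(0.5, xs, base=16)
print("gradients match:", math.isclose(g_full, g_ckpt, rel_tol=1e-12, abs_tol=1e-15))
print("stored states  :", mem_full, "vs", mem_ckpt)
print("forward steps  :", fwd_full, "vs", fwd_ckpt)
```

In this toy setting the checkpointed version holds far fewer states at once and pays for it with roughly one extra forward pass, while the gradient is unchanged.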

import brainpy
import brainstate
import braintools
import brainunit as u
import matplotlib.pyplot as plt
import numpy as np
class SNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.i2r = brainstate.nn.Sequential(
            brainstate.nn.Linear(
                n_in, n_rec,
                w_init=braintools.init.KaimingNormal(unit=u.mA),
                b_init=braintools.init.ZeroInit(unit=u.mA)),
            brainpy.state.Expon(n_rec, tau=5. * u.ms,
                                g_initializer=braintools.init.Constant(0. * u.mA)))
        self.r = brainpy.state.LIF(
            n_rec, tau=20. * u.ms, V_rest=0. * u.mV, V_reset=0. * u.mV,
            V_th=1. * u.mV, spk_fun=braintools.surrogate.ReluGrad())
        self.r2o = brainstate.nn.Linear(
            n_rec, n_out, w_init=braintools.init.KaimingNormal())
        self.o = brainpy.state.Expon(
            n_out, tau=10. * u.ms, g_initializer=braintools.init.Constant(0.))

    def update(self, spike):
        return self.o(self.r2o(self.r(self.i2r(spike))))

Swap for_loop for checkpointed_for_loop#

The only change versus an ordinary BPTT loss is the loop function. We expose a use_checkpoint flag to make the difference explicit, and a base to control the checkpoint spacing.

def make_train_step(net, x_data, y_data, num_sample, use_checkpoint, base=16):
    optimizer = braintools.optim.Adam(lr=2e-3)
    optimizer.register_trainable_weights(net.states(brainstate.ParamState))

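    def loss_fn():
        # Run the network over the time axis; both loops return the stacked
        # per-step outputs, so the rest of the loss is unchanged.
        if use_checkpoint:
            preds = brainstate.transform.checkpointed_for_loop(
                net.update, x_data, base=base)
        else:
            preds = brainstate.transform.for_loop(net.update, x_data)
        # Average the readout over time, then score it against the labels.
        preds = u.math.mean(preds, axis=0)
        return braintools.metric.softmax_cross_entropy_with_integer_labels(
            preds, y_data).mean()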

    @brainstate.transform.jit
    def train_step():
        brainstate.nn.init_all_states(net, batch_size=num_sample)
        grads, l = brainstate.transform.grad(
            loss_fn, net.states(brainstate.ParamState), return_value=True)()
        optimizer.update(grads)
        return l

    return train_step

Train over a long sequence#

We use a long rollout (T = 400) so checkpointing matters. The training curve is the same whether or not we checkpoint — only the peak memory differs.

with brainstate.environ.context(dt=1.0 * u.ms):
    n_in, n_rec, n_out = 80, 8, 2
    num_step, num_sample = 400, 64
    x_data = (brainstate.random.rand(num_step, num_sample, n_in)
              < 5. * u.Hz * brainstate.environ.get_dt()).astype(float)
    y_data = u.math.asarray(brainstate.random.rand(num_sample) < 0.5, dtype=int)

    net = SNN(n_in, n_rec, n_out)
    step = make_train_step(net, x_data, y_data, num_sample,
                           use_checkpoint=True, base=16)
    losses = [float(step()) for _ in range(120)]

print('checkpointed training: first %.4f  last %.4f' % (losses[0], losses[-1]))
checkpointed training: first 0.6931  last 0.6510
plt.figure(figsize=(6, 3))
plt.plot(np.asarray(losses))
plt.xlabel('Epoch'); plt.ylabel('Training loss')
plt.title('BPTT over a 400-step rollout with gradient checkpointing')
plt.show()
[Figure: training loss versus epoch for BPTT over a 400-step rollout with gradient checkpointing.]

When to reach for this#

  • Use plain for_loop/scan by default.

  • Switch to checkpointed_for_loop/checkpointed_scan only when reverse-mode gradients through a long simulation would otherwise exhaust memory.

  • Tune base to trade recomputation against peak memory.
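As a rough guide for picking base, the heuristic quoted earlier (memory scaling roughly with base + T/base) can be tabulated for the rollout length used here. The helper approx_stored_steps is a hypothetical name introduced for this sketch. The numbers are an estimate from that heuristic only, not a measurement of brainstate's actual memory use, so treat them as a starting point and profile on your own hardware.

```python
def approx_stored_steps(T, base):
    """Heuristic from the text: stored activations scale like base + T / base."""
    return base + T / base

T = 400  # rollout length used in this guide
for base in (1, 4, 16, 20, 64, 400):
    print(f"base={base:>3}  ~stored steps={approx_stored_steps(T, base):6.1f}")

# Under this heuristic the estimate is smallest near sqrt(T).
best_base = min(range(1, T + 1), key=lambda b: approx_stored_steps(T, b))
print("estimate is minimized at base =", best_base)
```

Under this estimate the minimum for T = 400 sits at base = 20 (about 40 stored steps), and the default base = 16 is close to it at 41, compared with roughly 400 when every step is stored.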

See also#