How to train through long rollouts without exhausting memory
Task. Backpropagate through many time steps (BPTT) without the activation memory growing linearly with the sequence length, using gradient checkpointing.
Audience. Training. Assumes Tutorial 4 · Train a spiking network.
Backprop through time stores every step’s activations for the backward pass, so
memory grows with the rollout length T. For long sequences this becomes the
binding constraint. brainstate.transform.checkpointed_for_loop (and
checkpointed_scan) rematerialize activations on the backward pass instead of
storing them all, trading extra recomputation for bounded memory. The base
argument tunes the checkpoint granularity (roughly, memory scales with
base + T/base).
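To make the mechanism concrete, here is a minimal pure-Python sketch of checkpointed BPTT on a toy scalar recurrence. It illustrates the general technique only and is not brainstate's implementation: the recurrence h_next = tanh(w * h + x), the function names, and the use of the final hidden state as the loss are all assumptions made for this example.

```python
import math

def step(w, h, x):
    """One recurrent step (hypothetical toy model): h_next = tanh(w * h + x)."""
    return math.tanh(w * h + x)

def grad_full(w, xs, h0=0.0):
    """Plain BPTT: store every hidden state, so memory grows with T."""
    hs = [h0]
    for x in xs:
        hs.append(step(w, hs[-1], x))
    dw, dh = 0.0, 1.0  # loss = final hidden state, so dL/dh_T = 1
    for t in reversed(range(len(xs))):
        local = (1.0 - hs[t + 1] ** 2) * dh  # back through tanh
        dw += local * hs[t]
        dh = local * w
    return dw

def grad_checkpointed(w, xs, base, h0=0.0):
    """Checkpointed BPTT: keep every `base`-th state, recompute the rest."""
    T = len(xs)
    ckpts = {0: h0}
    h = h0
    for t, x in enumerate(xs):  # forward pass keeps only sparse checkpoints
        h = step(w, h, x)
        if (t + 1) % base == 0 and (t + 1) < T:
            ckpts[t + 1] = h
    dw, dh = 0.0, 1.0
    peak = 0  # largest number of hidden states held at once
    for start in sorted(ckpts, reverse=True):
        end = min(start + base, T)
        seg = [ckpts[start]]  # rematerialize one segment from its checkpoint
        for t in range(start, end):
            seg.append(step(w, seg[-1], xs[t]))
        peak = max(peak, len(ckpts) + len(seg))
        for t in reversed(range(start, end)):  # backward through the segment
            local = (1.0 - seg[t - start + 1] ** 2) * dh
            dw += local * seg[t - start]
            dh = local * w
    return dw, peak

xs = [math.sin(0.1 * t) for t in range(400)]  # deterministic toy input, T = 400
dw_full = grad_full(0.5, xs)
dw_ck, peak = grad_checkpointed(0.5, xs, base=16)
print("gradients match:", math.isclose(dw_full, dw_ck, rel_tol=1e-12))
print("states held at once:", peak, "versus", len(xs) + 1, "for plain BPTT")
```

With T = 400 and base = 16, the checkpointed version holds 25 checkpoints plus one 17-state segment at a time, in line with the rough base + T/base estimate, and it pays for that with one extra forward pass over each segment.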
Crucially, it is a drop-in replacement for for_loop inside the loss: the call
takes the same step function and inputs (plus base) and returns the same
result; only the memory/compute profile differs.
import brainpy
import brainstate
import braintools
import brainunit as u
import matplotlib.pyplot as plt
import numpy as np
class SNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.i2r = brainstate.nn.Sequential(
            brainstate.nn.Linear(
                n_in, n_rec,
                w_init=braintools.init.KaimingNormal(unit=u.mA),
                b_init=braintools.init.ZeroInit(unit=u.mA)),
            brainpy.state.Expon(n_rec, tau=5. * u.ms,
                                g_initializer=braintools.init.Constant(0. * u.mA)))
        self.r = brainpy.state.LIF(
            n_rec, tau=20. * u.ms, V_rest=0. * u.mV, V_reset=0. * u.mV,
            V_th=1. * u.mV, spk_fun=braintools.surrogate.ReluGrad())
        self.r2o = brainstate.nn.Linear(
            n_rec, n_out, w_init=braintools.init.KaimingNormal())
        self.o = brainpy.state.Expon(
            n_out, tau=10. * u.ms, g_initializer=braintools.init.Constant(0.))

    def update(self, spike):
        return self.o(self.r2o(self.r(self.i2r(spike))))
Swap for_loop for checkpointed_for_loop
The only change versus an ordinary BPTT loss is the loop function. We expose a
use_checkpoint flag to make the difference explicit, and a base to control
the checkpoint spacing.
def make_train_step(net, x_data, y_data, num_sample, use_checkpoint, base=16):
    optimizer = braintools.optim.Adam(lr=2e-3)
    optimizer.register_trainable_weights(net.states(brainstate.ParamState))

    def loss_fn():
        if use_checkpoint:
            preds = brainstate.transform.checkpointed_for_loop(
                net.update, x_data, base=base)
        else:
            preds = brainstate.transform.for_loop(net.update, x_data)
        preds = u.math.mean(preds, axis=0)
        return braintools.metric.softmax_cross_entropy_with_integer_labels(
            preds, y_data).mean()

    @brainstate.transform.jit
    def train_step():
        brainstate.nn.init_all_states(net, batch_size=num_sample)
        grads, l = brainstate.transform.grad(
            loss_fn, net.states(brainstate.ParamState), return_value=True)()
        optimizer.update(grads)
        return l

    return train_step
Train over a long sequence
We use a long rollout (T = 400) so checkpointing matters. The training curve
is the same whether or not we checkpoint — only the peak memory differs.
with brainstate.environ.context(dt=1.0 * u.ms):
    n_in, n_rec, n_out = 80, 8, 2
    num_step, num_sample = 400, 64
    x_data = (brainstate.random.rand(num_step, num_sample, n_in)
              < 5. * u.Hz * brainstate.environ.get_dt()).astype(float)
    y_data = u.math.asarray(brainstate.random.rand(num_sample) < 0.5, dtype=int)
    net = SNN(n_in, n_rec, n_out)
    step = make_train_step(net, x_data, y_data, num_sample,
                           use_checkpoint=True, base=16)
    losses = [float(step()) for _ in range(120)]
    print('checkpointed training: first %.4f last %.4f' % (losses[0], losses[-1]))
checkpointed training: first 0.6931 last 0.6510
plt.figure(figsize=(6, 3))
plt.plot(np.asarray(losses))
plt.xlabel('Epoch'); plt.ylabel('Training loss')
plt.title('BPTT over a 400-step rollout with gradient checkpointing')
plt.show()
When to reach for this
Use plain for_loop / scan by default.
Switch to checkpointed_for_loop / checkpointed_scan only when reverse-mode gradients through a long simulation would otherwise exhaust memory.
Tune base to trade recomputation against peak memory.
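As a starting point for tuning, the rough base + T/base estimate can be tabulated directly. The sketch below is plain arithmetic on that estimate, not a measurement of brainstate's actual memory use, and the helper name approx_stored_steps is hypothetical. By the AM-GM inequality the estimate is smallest near base = sqrt(T).

```python
def approx_stored_steps(T, base):
    """Rough number of time steps whose activations are held at once,
    using the base + T/base estimate (an approximation, not a measurement)."""
    return base + T / base

T = 400  # rollout length used in this guide
for base in (1, 4, 16, 20, 100, 400):
    stored = approx_stored_steps(T, base)
    print(f"base={base:>3}  stored ~ {stored:6.1f}  (plain BPTT: {T})")

# Integer base that minimizes the estimate; lands at sqrt(T) for T = 400.
best_base = min(range(1, T + 1), key=lambda b: approx_stored_steps(T, b))
print("estimate is smallest at base =", best_base)
```

The estimate is flat near its minimum, so base = 16 and base = 20 give almost the same figure for T = 400. In practice the best value also depends on how expensive each recomputed step is, so it is worth measuring peak memory and step time on the real model.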
See also
The state paradigm — the transform primitives and when to use each.
Differentiability — BPTT through the transform loops.
How to use surrogate gradients — the surrogate on the hidden layer.