RNN Online Learning with BrainTrace
Train a GRU network on the copying task using D-RTRL
This quickstart tutorial demonstrates how to train a Gated Recurrent Unit (GRU) network using online learning with braintrace. We will:
Define the copying task, a standard benchmark for testing sequential memory in RNNs.
Build a GRU model using braintrace.nn components.
Train the model with D-RTRL (Decoupled Real-Time Recurrent Learning), an online learning algorithm that computes approximate gradients without storing the full computation graph.
Compare the online learning approach with standard Backpropagation Through Time (BPTT).
Online learning is especially useful when:
Memory is limited and storing the full unrolled computation graph is prohibitive.
You need to update parameters on-the-fly as data arrives.
You want biologically plausible learning rules for recurrent networks.
1. Setup
First, we import the required packages.
import jax
import jax.numpy as jnp
import brainstate
import braintools
import braintrace
import numpy as np
import matplotlib.pyplot as plt
2. The Copying Task
The copying task is a classic benchmark for evaluating whether an RNN can memorize and recall information over a delay period.
How it works:
The model receives a sequence of 10 random digits (values 1-8) encoded as one-hot vectors.
This is followed by a delay period filled with zeros (the “wait” phase).
A special trigger symbol (value 9) signals the model to reproduce the original 10 digits.
Input:  [3 7 1 5 2 8 4 6 1 3]  [0 0 ... 0 0]  [9 9 9 9 9 9 9 9 9 9]
        memorize               wait/delay     recall trigger
Target:                                       [3 7 1 5 2 8 4 6 1 3]
The longer the delay (time_lag), the harder the task. The model must retain information in its hidden state across the entire delay period.
class CopyDataset:
    """Data generator for the copying task.

    Args:
        time_lag: Number of delay steps between memorization and recall.
        batch_size: Number of samples per batch.
    """

    def __init__(self, time_lag: int, batch_size: int):
        self.seq_length = time_lag + 20
        self.batch_size = batch_size

    def __iter__(self):
        while True:
            ids = np.zeros([self.batch_size, self.seq_length], dtype=int)
            # First 10 positions: random digits 1-8
            ids[..., :10] = np.random.randint(1, 9, (self.batch_size, 10))
            # Last 10 positions: trigger symbol (9)
            ids[..., -10:] = np.ones([self.batch_size, 10], dtype=int) * 9
            # One-hot encode the input sequence
            x = np.zeros([self.batch_size, self.seq_length, 10])
            for i in range(self.batch_size):
                x[i, range(self.seq_length), ids[i]] = 1
            # Target: the original 10 digits to recall
            yield x, ids[..., :10]
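The per-sample loop in the class above can also be written as a single indexing operation. The following is a minimal sketch, not part of the tutorial's code: it rebuilds the same id layout with a small batch and checks that indexing an identity matrix gives the same one-hot tensor. The variable names and the small sizes are chosen for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, time_lag = 4, 5          # small illustrative sizes
seq_length = time_lag + 20           # same rule as CopyDataset

# Same id layout as CopyDataset: digits, zeros for the delay, then triggers.
ids = np.zeros((batch_size, seq_length), dtype=int)
ids[:, :10] = rng.integers(1, 9, size=(batch_size, 10))  # digits 1-8
ids[:, -10:] = 9                                          # trigger symbol

# Loop-based encoding, as in the class above.
x_loop = np.zeros((batch_size, seq_length, 10))
for i in range(batch_size):
    x_loop[i, range(seq_length), ids[i]] = 1

# Vectorized alternative: row k of the identity matrix is the one-hot for k.
x_eye = np.eye(10)[ids]

print(x_eye.shape)                    # (4, 25, 10)
print(np.array_equal(x_loop, x_eye))  # True
```

Note that the delay symbol 0 is itself a class: during the wait phase the input is the one-hot vector for index 0, not an all-zero vector.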
3. Model Definition
We define a GRU network using braintrace.nn.GRUCell for the recurrent layer and braintrace.nn.Linear for the output layer. These modules are designed to work with braintrace’s online learning algorithms – they expose the internal structure needed for eligibility trace computation.
class GRUNet(brainstate.nn.Module):
    """A multi-layer GRU network with a linear readout.

    Args:
        n_in: Input feature dimension.
        n_rec: Hidden state dimension.
        n_out: Output dimension.
        n_layer: Number of stacked GRU layers.
    """

    def __init__(self, n_in, n_rec, n_out, n_layer=1):
        super().__init__()
        layers = []
        for _ in range(n_layer):
            layers.append(braintrace.nn.GRUCell(n_in, n_rec))
            n_in = n_rec
        self.rnn = brainstate.nn.Sequential(*layers)
        self.readout = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.readout(self.rnn(x))
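For readers who want to see what a GRU layer computes at each step, here is a plain NumPy sketch of the textbook GRU update. It is an illustration only: the parameter names (`Wr`, `Wz`, `Wh`, and the biases), the initialization, and the gate convention are assumptions and may differ from how `braintrace.nn.GRUCell` is parameterized.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h, p):
    """One step of a textbook GRU; x has shape (n_in,), h has shape (n_rec,)."""
    xh = np.concatenate([x, h])
    r = sigmoid(p["Wr"] @ xh + p["br"])       # reset gate
    z = sigmoid(p["Wz"] @ xh + p["bz"])       # update gate
    xrh = np.concatenate([x, r * h])
    cand = np.tanh(p["Wh"] @ xrh + p["bh"])   # candidate state
    return (1.0 - z) * h + z * cand           # blend old state and candidate

rng = np.random.default_rng(0)
n_in, n_rec = 10, 128
# Hypothetical parameter container, not braintrace's state layout.
p = {name: rng.normal(0, 0.1, (n_rec, n_in + n_rec)) for name in ("Wr", "Wz", "Wh")}
p.update({name: np.zeros(n_rec) for name in ("br", "bz", "bh")})

h = np.zeros(n_rec)
for x in np.eye(n_in)[[3, 7, 1]]:   # three one-hot inputs
    h = gru_step(x, h, p)
print(h.shape)  # (128,)
```

The gates are what let the hidden state carry the ten digits across the delay: when the update gate `z` stays near zero, `h` is passed through almost unchanged.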
4. Online Training with D-RTRL
D-RTRL (Decoupled Real-Time Recurrent Learning) is an online learning algorithm provided by braintrace. Unlike BPTT, which requires storing the entire computation graph across all time steps, D-RTRL computes gradients incrementally at each time step using eligibility traces.
The key steps in the online training loop are:
Initialize states: Reset the model’s hidden states and compile the eligibility trace graph.
Warm-up phase: Run the model forward (without learning) to let the hidden states and eligibility traces stabilize.
Learning phase: At each time step, compute the gradient of the current loss with respect to the parameters, and accumulate gradients over time.
Parameter update: After processing the full sequence, apply the accumulated gradients to update the parameters.
The D_RTRL class wraps the model and handles the eligibility trace bookkeeping automatically.
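The idea of an eligibility trace can be shown on a toy case. The sketch below is not braintrace's implementation and not D-RTRL itself: it uses a single scalar recurrent unit, for which the trace recursion is exact, and checks the online gradient against a finite-difference estimate. The function names and the numbers are made up for illustration.

```python
import math

def forward_loss(w, a, xs, ys):
    """Total squared-error loss of h_t = tanh(a*h_{t-1} + w*x_t)."""
    h, total = 0.0, 0.0
    for x, y in zip(xs, ys):
        h = math.tanh(a * h + w * x)
        total += 0.5 * (h - y) ** 2
    return total

def online_grad(w, a, xs, ys):
    """dLoss/dw accumulated step by step with an eligibility trace."""
    h, trace, grad = 0.0, 0.0, 0.0
    for x, y in zip(xs, ys):
        h_new = math.tanh(a * h + w * x)
        # trace_t = dh_t/dw, computed from the previous trace only
        trace = (1.0 - h_new ** 2) * (a * trace + x)
        grad += (h_new - y) * trace   # per-step gradient, summed over time
        h = h_new
    return grad

xs = [0.5, -1.0, 0.3, 0.8, -0.2]
ys = [0.1, -0.3, 0.0, 0.4, 0.2]
w, a, eps = 0.7, 0.9, 1e-6

g_online = online_grad(w, a, xs, ys)
g_fd = (forward_loss(w + eps, a, xs, ys) - forward_loss(w - eps, a, xs, ys)) / (2 * eps)
print(abs(g_online - g_fd) < 1e-6)  # True
```

The loop keeps only the current state and the current trace, never the history of past states. This is the property the online training loop relies on; for full networks the trace is approximated rather than exact.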
def train_online(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3):
    """Train a GRU on the copying task using D-RTRL online learning.

    Args:
        n_epochs: Number of training iterations.
        n_seq: Length of the delay period in the copying task.
        batch_size: Number of samples per batch.
        lr: Learning rate for the Adam optimizer.

    Returns:
        List of loss values over training.
    """
    model = GRUNet(10, 128, 10)
    opt = braintools.optim.Adam(lr)
    weights = model.states().subset(brainstate.ParamState)
    opt.register_trainable_weights(weights)

    @brainstate.transform.jit
    def train_step(inputs, targets):
        # Create the online learning algorithm wrapper
        algo = braintrace.D_RTRL(model)

        # Initialize hidden states and compile eligibility trace graph
        # using vmap to handle the batch dimension
        @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
        def init():
            brainstate.nn.init_all_states(model)
            algo.compile_graph(inputs[0, 0])

        init()
        algo = brainstate.nn.Vmap(algo, vmap_states='new')

        def etrace_loss(inp, tar):
            out = algo(inp)
            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(out, tar).mean()
            return loss, out

        def step(prev_grads, x):
            inp, tar = x
            f_grad = brainstate.transform.grad(
                etrace_loss, weights, has_aux=True, return_value=True
            )
            cur_grads, loss, out = f_grad(inp, tar)
            next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)
            return next_grads, loss

        # Warm-up: run the model forward to stabilize hidden states
        # and eligibility traces before computing learning gradients
        n_sim = n_seq + 10
        brainstate.transform.for_loop(lambda inp: algo(inp), inputs[:n_sim])

        # Learning phase: accumulate gradients over the recall period
        grads = jax.tree.map(jnp.zeros_like, {k: v.value for k, v in weights.items()})
        grads, losses = brainstate.transform.scan(step, grads, (inputs[n_sim:], targets))
        opt.update(grads)
        return losses.mean()

    # Training loop
    dataloader = CopyDataset(n_seq, batch_size)
    losses = []
    for i, (x, y) in enumerate(dataloader):
        if i >= n_epochs:
            break
        # Transpose from (batch, time, features) to (time, batch, features)
        x = jnp.asarray(np.transpose(x, (1, 0, 2)))
        y = jnp.asarray(np.transpose(y, (1, 0)))
        loss = train_step(x, y)
        losses.append(float(loss))
        if i % 50 == 0:
            print(f"Step {i}, Loss: {loss:.4f}")
    return losses
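The warm-up length `n_sim = n_seq + 10` splits the sequence so that the learning phase lines up with the targets. A small shape check with placeholder arrays, using the default sizes from the function above, makes the split visible. The arrays here are zeros and stand in for real batches.

```python
import numpy as np

n_seq, batch_size = 50, 64
seq_length = n_seq + 20                      # same rule as CopyDataset

x = np.zeros((batch_size, seq_length, 10))   # (batch, time, features)
y = np.zeros((batch_size, 10), dtype=int)    # (batch, recall steps)

x_t = np.transpose(x, (1, 0, 2))             # (time, batch, features)
y_t = np.transpose(y, (1, 0))                # (recall steps, batch)

n_sim = n_seq + 10                           # 10 digit steps + delay steps
print(x_t[:n_sim].shape)   # (60, 64, 10)  warm-up: digits and delay
print(x_t[n_sim:].shape)   # (10, 64, 10)  learning phase: trigger steps
print(y_t.shape)           # (10, 64)      one target per trigger step
```

The last ten time steps are the trigger inputs, so each of them is paired with one row of the transposed targets when the two are scanned together.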
5. BPTT Baseline (for Comparison)
To appreciate the advantages of online learning, we also implement a standard BPTT trainer. BPTT unrolls the full computation graph across all time steps, computes the loss, and backpropagates through the entire sequence. This requires storing all intermediate activations, resulting in memory usage that scales linearly with sequence length.
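A back-of-the-envelope count shows how the two memory costs behave as the sequence grows. This is a rough sketch under stated assumptions, not a measurement: it counts only one stored hidden state per time step for BPTT (a lower bound, since gate activations are stored too), it reads the O(B * theta) figure quoted in the summary as one trace value per parameter per sample, it assumes 4-byte floats, and it uses the textbook GRU parameter count, which may differ from the model defined above.

```python
def bptt_hidden_bytes(T, B, H, bytes_per_float=4):
    """Bytes to keep one hidden state per time step (lower bound for BPTT)."""
    return T * B * H * bytes_per_float

def trace_bytes(B, n_params, bytes_per_float=4):
    """Bytes for one trace value per parameter per sample (independent of T)."""
    return B * n_params * bytes_per_float

B, H, n_in = 64, 128, 10
n_params = 3 * (n_in + H + 1) * H   # textbook GRU parameter count (assumption)

for T in (70, 700, 7000):
    print(T, bptt_hidden_bytes(T, B, H), trace_bytes(B, n_params))
```

Under these assumptions the trace cost is the larger of the two at the tutorial's sequence length of 70 steps and stays fixed, while the BPTT lower bound grows in proportion to T and overtakes it for longer sequences. The memory advantage of online learning therefore shows up on long sequences rather than on this short benchmark.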
def train_bptt(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3):
    """Train a GRU on the copying task using BPTT.

    Args:
        n_epochs: Number of training iterations.
        n_seq: Length of the delay period in the copying task.
        batch_size: Number of samples per batch.
        lr: Learning rate for the Adam optimizer.

    Returns:
        List of loss values over training.
    """
    model = GRUNet(10, 128, 10)
    opt = braintools.optim.Adam(lr)
    weights = model.states().subset(brainstate.ParamState)
    opt.register_trainable_weights(weights)

    @brainstate.transform.jit
    def train_step(inputs, targets):
        # Initialize hidden states with vmap for batch dimension
        @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
        def init():
            brainstate.nn.init_all_states(model)

        init()
        vmapped_model = brainstate.nn.Vmap(model, vmap_states='new')

        def run_step(inp, tar):
            out = vmapped_model(inp)
            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(out, tar).mean()
            return out, loss

        def bptt_forward():
            # Warm-up: run forward without computing loss
            n_sim = n_seq + 10
            brainstate.transform.for_loop(vmapped_model, inputs[:n_sim])
            # Compute loss over the recall period
            outs, losses = brainstate.transform.for_loop(run_step, inputs[n_sim:], targets)
            return losses.mean(), outs

        # Backpropagate through time to get gradients
        grads, loss, outs = brainstate.transform.grad(
            bptt_forward, weights, has_aux=True, return_value=True
        )()
        opt.update(grads)
        return loss

    # Training loop
    dataloader = CopyDataset(n_seq, batch_size)
    losses = []
    for i, (x, y) in enumerate(dataloader):
        if i >= n_epochs:
            break
        # Transpose from (batch, time, features) to (time, batch, features)
        x = jnp.asarray(np.transpose(x, (1, 0, 2)))
        y = jnp.asarray(np.transpose(y, (1, 0)))
        loss = train_step(x, y)
        losses.append(float(loss))
        if i % 50 == 0:
            print(f"Step {i}, Loss: {loss:.4f}")
    return losses
6. Run Training
Let us train both the online (D-RTRL) and offline (BPTT) models. We use a delay of 50 time steps, a moderately difficult setting for the copying task.
online_losses = train_online(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3)
/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('rnn', 'layers', 0, 'Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mv). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.
_emit_no_relation_diag(
/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('readout', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.
_emit_no_relation_diag(
Step 0, Loss: 2.2919
Step 50, Loss: 2.0887
Step 100, Loss: 2.0790
Step 150, Loss: 2.0802
bptt_losses = train_bptt(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3)
Step 0, Loss: 2.2838
Step 50, Loss: 2.0851
Step 100, Loss: 2.0819
Step 150, Loss: 2.0830
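A useful reference point for reading these logs is the cross-entropy of an uninformed guess. The short calculation below gives two such baselines: a uniform guess over all 10 output classes, and a uniform guess over the 8 digits that can actually appear in the target.

```python
import math

chance_10 = math.log(10)  # uniform over all 10 output classes
chance_8 = math.log(8)    # uniform over the 8 possible digits (1-8)
print(round(chance_10, 4))  # 2.3026
print(round(chance_8, 4))   # 2.0794
```

Both runs start near the first value and settle near the second. This suggests that, within 200 steps at these settings, both models have mainly learned which classes can occur rather than the digits themselves. The two curves are therefore best read as a check that D-RTRL and BPTT behave similarly here, not as evidence that the task is solved; longer runs or other hyperparameters may behave differently.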
7. Visualization
Plot the training loss curves to compare online learning (D-RTRL) with BPTT.
plt.figure(figsize=(8, 4))
plt.plot(online_losses, label='D-RTRL (Online)')
plt.plot(bptt_losses, label='BPTT (Offline)')
plt.xlabel('Training Step')
plt.ylabel('Loss')
plt.title('GRU on Copying Task: Online vs. Offline Learning')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
8. Summary
In this tutorial, we demonstrated how to use braintrace for online learning of a GRU network on the copying task.
Key takeaways:
D-RTRL provides approximate online gradients with O(B * theta) complexity, where B is the batch size and theta is the number of parameters. Unlike BPTT, it does not need to store the full unrolled computation graph.
The online training loop uses braintrace.D_RTRL to automatically manage eligibility traces. You only need to:
Wrap your model with D_RTRL.
Call compile_graph() to set up the trace computation.
Use standard brainstate.transform.grad to compute per-step gradients.
Online learning works with standard JAX gradient APIs and transformations (jit, vmap, scan).
braintrace is particularly effective for RNN models with gating mechanisms (GRU, LSTM), where the internal dynamics naturally support eligibility trace propagation.
For more details, see:
Key Concepts for the theoretical background.
SNN Online Learning for applying the same approach to spiking neural networks.