SNN Online Learning with BrainTrace
Train a spiking neural network using ES-D-RTRL
Introduction
Spiking Neural Networks (SNNs) process information through discrete spike events, mimicking the communication mechanism of biological neurons. Unlike traditional artificial neural networks that operate on continuous activations, SNNs emphasize the timing and frequency of spikes, making them inherently suited for temporal data processing.
Online learning is a natural fit for SNNs because they process inputs sequentially, one time step at a time. Instead of storing the entire computation graph for backpropagation through time (BPTT), online learning algorithms update weight gradients incrementally at each time step. This eliminates the need to unroll the network over the full sequence length, resulting in constant memory usage with respect to sequence length.
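A scalar toy problem makes the memory argument concrete. The sketch below is plain NumPy, not braintrace. It assumes a single linear leaky unit h_t = a*h_{t-1} + w*x_t with a squared-error loss. BPTT stores the whole trajectory for a backward pass, while the online version carries an eligibility trace e_t = dh_t/dw forward in time. For this linear scalar case the two gradients agree exactly. For real networks, exact forward-mode tracking becomes expensive, and the algorithms in this tutorial address that cost.

```python
import numpy as np

# Toy model (assumption for illustration): h_t = a*h_{t-1} + w*x_t,
# loss L = sum_t 0.5 * (h_t - y_t)**2.


def bptt_grad(w, a, xs, ys):
    """dL/dw via backpropagation through time.

    The forward pass stores every hidden state, so memory grows with len(xs).
    """
    hs = []
    h = 0.0
    for x in xs:
        h = a * h + w * x
        hs.append(h)
    grad, delta = 0.0, 0.0
    for t in reversed(range(len(xs))):
        # Total dL/dh_t: local error plus the error flowing back from h_{t+1}.
        delta = (hs[t] - ys[t]) + a * delta
        grad += delta * xs[t]
    return grad


def online_grad(w, a, xs, ys):
    """The same gradient computed forward in time with an eligibility trace.

    Only h, the trace e = dh_t/dw and the running gradient are kept,
    so memory does not depend on len(xs).
    """
    h, e, grad = 0.0, 0.0, 0.0
    for x, y in zip(xs, ys):
        h = a * h + w * x
        e = a * e + x          # dh_t/dw = a * dh_{t-1}/dw + x_t
        grad += (h - y) * e    # local error times the trace
    return grad


rng = np.random.default_rng(0)
xs, ys = rng.normal(size=50), rng.normal(size=50)
print(np.isclose(bptt_grad(0.3, 0.9, xs, ys), online_grad(0.3, 0.9, xs, ys)))  # True
```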
In this tutorial, we will use ES-D-RTRL (Eligibility-trace Scalable Decoupled Real-Time Recurrent Learning), an efficient online learning algorithm provided by braintrace. ES-D-RTRL factorizes the eligibility trace into input and output components, achieving O(B(I+O)) memory complexity (where B is batch size, I is input dimension, and O is output dimension). This makes it highly scalable for large spiking networks.
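The following NumPy sketch shows why an input-side trace can stand in for the full per-weight trace. It illustrates the factorization idea under a strong simplifying assumption and is not braintrace's implementation. The assumption is a layer of leaky linear units that all share one leak factor alpha, h_t = alpha*h_{t-1} + W @ x_t. For this model, every row of the O x I trace equals the same smoothed input vector, so I numbers replace O*I. With spiking, nonlinear dynamics the rows are no longer identical, and ES-D-RTRL replaces this exact identity with an approximation.

```python
import numpy as np

# Assumption: leaky *linear* units sharing a single leak alpha.


def full_trace(alpha, xs, n_out):
    """Full per-weight trace E[i, j] = dh_i/dW_ij for h_t = alpha*h_{t-1} + W @ x_t.

    W has shape (n_out, n_in), so the trace needs n_out * n_in numbers.
    """
    E = np.zeros((n_out, xs.shape[1]))
    for x in xs:
        E = alpha * E + x[None, :]   # every output row receives the same input
    return E


def input_trace(alpha, xs):
    """Factored form: a single exponentially smoothed copy of the input (n_in numbers)."""
    x_bar = np.zeros(xs.shape[1])
    for x in xs:
        x_bar = alpha * x_bar + x
    return x_bar


rng = np.random.default_rng(1)
xs = (rng.random((100, 6)) < 0.1).astype(float)   # 100 steps of sparse binary input
E = full_trace(0.9, xs, n_out=4)
x_bar = input_trace(0.9, xs)
print(np.allclose(E, x_bar[None, :]), E.size, x_bar.size)   # True 24 6

# A weight gradient is the outer product of a per-output learning signal and the trace.
g = rng.normal(size=4)
print(np.allclose(g[:, None] * E, np.outer(g, x_bar)))       # True
```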
What you will learn:
- How to build an SNN model using `brainpy.state` neurons and `braintrace.nn` layers
- How to set up online learning with `braintrace.ES_D_RTRL` (ES-D-RTRL)
- How to train the SNN on random spike data
- The key differences between D-RTRL and ES-D-RTRL
1. Setup
First, let us import the required packages. The key components are:
- `brainstate`: state management and JAX-based transformations (`jit`, `vmap_new_states`, `grad`, `scan`)
- `brainpy.state`: spiking neuron models such as `LIF`
- `braintrace`: online learning algorithms and ETP-aware neural network layers
- `braintools`: initializers, optimizers, surrogate gradient functions, and metrics
- `saiunit`: physical units (`ms`, `mV`, etc.) for biologically meaningful parameters
```python
import jax
import jax.numpy as jnp
import brainstate
import braintools
import braintrace
import saiunit as u
import brainpy.state
import numpy as np
```
2. SNN Model
We build a simple recurrent SNN with the following architecture:
- Input + Recurrent Projection: A `braintrace.nn.Linear` layer that projects the concatenation of input spikes and recurrent spikes into the hidden layer. Using `braintrace.nn.Linear` (instead of a plain matrix multiply) marks this projection for participation in online learning via ETP primitives.
- LIF Neuron: A Leaky Integrate-and-Fire neuron from `brainpy.state.LIF`. The LIF neuron integrates its input current, fires a spike when the membrane potential exceeds a threshold, and then resets (the model below uses a soft reset). Because the spike itself is a non-differentiable step function, we use `braintools.surrogate.ReluGrad()` as a surrogate gradient so that gradients can flow through it.
- Readout: A `braintrace.nn.LeakyRateReadout` that applies leaky integration to the recurrent spikes and produces a continuous output signal for classification. This layer is also ETP-aware.
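The forward dynamics of the LIF neuron described above can be sketched in NumPy. This sketch assumes a forward-Euler step with the parameter values used in the model below (tau = 20 ms, R = 1 ohm, V_th = 1 mV, V_reset = V_rest = 0 mV, dt = 1 ms). `brainpy.state.LIF` may use a different integration scheme. The surrogate gradient only affects the backward pass, so it is omitted here.

```python
import numpy as np


def simulate_lif(current, tau=20.0, dt=1.0, R=1.0, v_th=1.0, v_reset=0.0, v_rest=0.0):
    """Forward-Euler LIF with soft reset (ms, mA, ohm, mV as plain floats).

    Assumed dynamics: dV/dt = (-(V - V_rest) + R * I) / tau.
    A spike is emitted when V >= V_th; the soft reset subtracts (V_th - V_reset)
    so any overshoot above threshold is kept.
    """
    v = v_rest
    spikes = np.zeros(len(current))
    for t, i_t in enumerate(current):
        v = v + dt / tau * (-(v - v_rest) + R * i_t)
        if v >= v_th:
            spikes[t] = 1.0
            v -= v_th - v_reset
    return spikes


strong = simulate_lif(np.full(200, 2.0))   # R*I = 2 mV settles above V_th = 1 mV
weak = simulate_lif(np.full(200, 0.5))     # R*I = 0.5 mV settles below V_th
print(int(strong.sum()), int(weak.sum()))
```

A constant input only produces spikes when its steady-state voltage R*I exceeds V_th. This is one reason the weight scale of the input projection matters for whether the network fires at all.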
The recurrent connectivity is achieved by concatenating the neuron’s own spike output with the external input at each time step.
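Concatenating the inputs is equivalent to using separate input and recurrent weight matrices. The NumPy check below does not use braintrace. It assumes the projection computes `x @ W` with `W` of shape `(n_in + n_rec, n_rec)`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_rec = 50, 128

# One weight matrix acting on [input spikes, recurrent spikes].
W = rng.normal(size=(n_in + n_rec, n_rec))
x = (rng.random(n_in) < 0.1).astype(float)     # external input spikes at step t
s = (rng.random(n_rec) < 0.05).astype(float)   # recurrent spikes from step t-1
combined = np.concatenate([x, s]) @ W

# Equivalent form with separate input and recurrent weight blocks.
W_in, W_rec = W[:n_in], W[n_in:]
separate = x @ W_in + s @ W_rec
print(combined.shape, np.allclose(combined, separate))   # (128,) True
```

Using a single layer keeps one ETP-aware weight matrix, and therefore one trace, for both the feedforward and recurrent paths.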
```python
class LIF_SNN(brainstate.nn.Module):
    """A simple recurrent SNN with LIF neurons for online learning."""

    def __init__(self, n_in, n_rec, n_out, tau_mem=20. * u.ms, tau_out=20. * u.ms):
        super().__init__()
        # Input + recurrent projection (ETP-aware: participates in online learning).
        # Weights are in current units so that ``I * R`` lands in mV inside the LIF
        # neuron (LIF integrates ``-V + I*R``; ``mA * ohm = mV`` matches V_th below).
        self.linear = braintrace.nn.Linear(
            n_in + n_rec, n_rec,
            w_init=braintools.init.KaimingNormal(unit=u.mA),
            b_init=braintools.init.ZeroInit(unit=u.mA),
        )
        # LIF neuron with surrogate gradient for differentiability.
        self.neuron = brainpy.state.LIF(
            n_rec,
            tau=tau_mem,
            R=1. * u.ohm,
            V_th=1. * u.mV,
            V_reset=0. * u.mV,
            V_rest=0. * u.mV,
            spk_fun=braintools.surrogate.ReluGrad(),
            spk_reset='soft',
        )
        # Readout layer (ETP-aware: participates in online learning).
        self.readout = braintrace.nn.LeakyRateReadout(
            n_rec, n_out,
            tau=tau_out,
            w_init=braintools.init.KaimingNormal(),
        )

    def update(self, spike_input):
        # Concatenate input spikes with the recurrent spikes from the previous step.
        rec_spk = self.neuron.get_spike()
        x = jnp.concatenate([spike_input, rec_spk], axis=-1)
        # Linear projection -> one LIF update (returns this step's spikes) -> readout.
        # Reuse the returned spikes: calling self.neuron() again would advance the
        # neuron a second time with zero input.
        spk = self.neuron(self.linear(x))
        return self.readout(spk)
```
Let us verify that the model can be instantiated and produce output for a single sample.
```python
with brainstate.environ.context(dt=1. * u.ms):
    model = LIF_SNN(n_in=50, n_rec=128, n_out=10)
    brainstate.nn.init_all_states(model)
    # Single time step with random spike input
    test_input = jnp.array(np.random.binomial(1, 0.1, (50,)).astype(np.float32))
    output = model(test_input)
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output}")
```
Output shape: (10,)
Output values: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
3. Training with ES-D-RTRL
We set up online learning using braintrace.ES_D_RTRL (also exposed as braintrace.pp_prop and the lower-level braintrace.IODimVjpAlgorithm). The key steps are:
1. Wrap the model with `ES_D_RTRL`, supplying the `decay_or_rank` parameter:
   - `decay_or_rank=float` in (0, 1]: exponentially smoothed trace. The value is the decay factor applied per step; `0.99` is a common choice. Memory cost: O(B * (I + O)) per layer. Approximation: the trace is a leaky moving average of recent activity.
   - `decay_or_rank=int >= 1`: low-rank trace. The integer is the rank used to factorize the trace. Memory cost: O(B * rank * (I + O)). Approximation: the trace is projected onto the top `rank` modes.

   Pick the decay form when you want a single hyper-parameter that you can sweep cheaply, and the rank form when you want a tunable accuracy/memory trade-off independent of any time-scale assumption.
2. Initialize per-sample states using `vmap_new_states` so each sample in the batch has independent hidden states and eligibility traces.
3. Compile the graph by calling `algo.compile_graph(sample_input)`.
4. Define the per-step gradient function using `brainstate.transform.grad`.
5. Sum the per-step gradients over the sequence with `brainstate.transform.scan`, clip them, and apply them with the optimizer.
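The memory costs listed for the two forms of `decay_or_rank` can be turned into a small calculator. The helper names are hypothetical. The calculation only applies the stated complexity formulas and is not a braintrace API. `effective_horizon` uses the standard 1/(1 - decay) time constant of an exponential moving average to show how far back a decay of 0.99 reaches.

```python
import math


def trace_memory(batch, n_in, n_out, decay_or_rank):
    """Hypothetical helper: trace entries per layer for the two decay_or_rank forms.

    Follows the O(B * (I + O)) and O(B * rank * (I + O)) costs listed above,
    with constant factors ignored.
    """
    if isinstance(decay_or_rank, bool):
        raise TypeError("decay_or_rank must be a float or an int, not bool")
    if isinstance(decay_or_rank, float):
        if not 0.0 < decay_or_rank <= 1.0:
            raise ValueError("decay factor must lie in (0, 1]")
        return batch * (n_in + n_out)
    if isinstance(decay_or_rank, int) and decay_or_rank >= 1:
        return batch * decay_or_rank * (n_in + n_out)
    raise ValueError("decay_or_rank must be a float in (0, 1] or an int >= 1")


def effective_horizon(decay):
    """Approximate number of past steps an exponential moving average 'remembers'."""
    return math.inf if decay == 1.0 else 1.0 / (1.0 - decay)


# The tutorial's recurrent projection: I = 50 inputs + 128 recurrent, O = 128.
print(trace_memory(32, 50 + 128, 128, 0.99))   # 9792
print(trace_memory(32, 50 + 128, 128, 4))      # 39168
print(round(effective_horizon(0.99)))          # 100
```

With 50 time steps per sample, a decay of 0.99 lets the trace span roughly the whole sequence.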
braintrace.D_RTRL is the alternative algorithm; it stores the full parameter-dimension trace (O(B * theta)) and is exact rather than approximate. Use D_RTRL when memory permits and ES_D_RTRL for larger networks.
```python
def train_snn(n_steps=100, n_epochs=50, batch_size=32, n_in=50, n_rec=128, n_out=10, lr=1e-3):
    """Train a recurrent SNN using ES-D-RTRL online learning."""
    with brainstate.environ.context(dt=1. * u.ms):
        # Create model and optimizer
        model = LIF_SNN(n_in, n_rec, n_out)
        opt = braintools.optim.Adam(lr)
        weights = model.states(brainstate.ParamState)
        opt.register_trainable_weights(weights)

        @brainstate.transform.jit
        def train_step(inputs, targets):
            # inputs: (n_steps, batch, n_in) spikes; targets: (batch,) integer labels.
            # Wrap model with ES-D-RTRL (decay_or_rank=0.99 means decay factor of 0.99)
            algo = braintrace.ES_D_RTRL(model, decay_or_rank=0.99)

            # Initialize per-sample states (each sample in the batch gets independent states)
            @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])
            def init():
                brainstate.nn.init_all_states(model)
                algo.compile_graph(inputs[0, 0])

            init()
            vmapped_algo = brainstate.nn.Vmap(algo, vmap_states='new')

            def loss_fn(inp):
                out = vmapped_algo(inp)
                loss = braintools.metric.softmax_cross_entropy_with_integer_labels(
                    out, targets
                ).mean()
                return loss, out

            def scan_step(prev_grads, inp):
                # Online gradient for one time step, added to the running total.
                f_grad = brainstate.transform.grad(
                    loss_fn, weights, has_aux=True, return_value=True
                )
                cur_grads, cur_loss, out = f_grad(inp)
                next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)
                return next_grads, cur_loss

            # Accumulate gradients over all time steps
            grads = jax.tree.map(jnp.zeros_like, weights.to_dict_values())
            grads, losses = brainstate.transform.scan(scan_step, grads, inputs)

            # Clip gradients and update weights
            # (brainstate.functional.clip_grad_norm is deprecated in favor of brainstate.nn)
            grads = brainstate.nn.clip_grad_norm(grads, 1.0)
            opt.update(grads)
            return losses.mean()

        # Training loop with random spike data (kept inside the dt context,
        # because train_step is traced on its first call)
        losses = []
        for epoch in range(n_epochs):
            # Generate random spike inputs (Bernoulli with firing probability 0.1)
            inputs = np.random.binomial(1, 0.1, (n_steps, batch_size, n_in)).astype(np.float32)
            targets = np.random.randint(0, n_out, batch_size)
            loss = train_step(jnp.array(inputs), jnp.array(targets))
            losses.append(float(loss))
            if epoch % 10 == 0:
                print(f"Epoch {epoch:3d}, Loss: {loss:.4f}")
    return losses
```
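The last two gradient operations in `train_step` can be mirrored in plain NumPy: summing the per-step gradients produced by the scan, then clipping before the optimizer update. This sketch assumes that `clip_grad_norm` rescales by the global L2 norm over all parameters, which is the common definition. The dict-of-arrays layout, the helper names, and the 1e-6 epsilon are illustrative choices, not brainstate's code.

```python
import numpy as np


def accumulate(per_step_grads):
    """Sum per-time-step gradient dicts, as the scan in train_step does."""
    total = {k: np.zeros_like(v) for k, v in per_step_grads[0].items()}
    for grads in per_step_grads:
        for k, v in grads.items():
            total[k] = total[k] + v
    return total


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients jointly so their combined L2 norm is at most max_norm."""
    norm = np.sqrt(sum(np.sum(v ** 2) for v in grads.values()))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return {k: v * scale for k, v in grads.items()}


steps = [
    {"w": np.array([3.0, 0.0]), "b": np.array([0.0])},
    {"w": np.array([0.0, 4.0]), "b": np.array([0.0])},
]
total = accumulate(steps)                 # w = [3, 4], global norm 5
clipped = clip_by_global_norm(total, 1.0)
print(np.round(clipped["w"], 4))          # [0.6 0.8]
```

Clipping the joint norm keeps the direction of the summed gradient and only shrinks its length. This matters here because summing over 50 time steps can make the total large.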
Let us run the training loop. Note that the first epoch will be slower due to JAX’s JIT compilation.
```python
losses = train_snn(n_steps=50, n_epochs=50, batch_size=32, n_in=50, n_rec=128, n_out=10, lr=1e-3)
```
Epoch 0, Loss: 2.3026
Epoch 10, Loss: 2.3026
Epoch 20, Loss: 2.3026
Epoch 30, Loss: 2.3026
Epoch 40, Loss: 2.3026
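We can plot the training loss curve. In this run the loss stays at 2.3026 for every logged epoch. That value is ln(10), the cross-entropy of a uniform prediction over 10 classes. This is expected here: the targets are drawn at random, independently of the inputs, so there is no structure to learn. The run demonstrates the online-learning pipeline rather than convergence. A loss of exactly ln(10) also suggests that the readout produced identical logits for all classes, which is consistent with the all-zero output in the single-step check above.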
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss (ES-D-RTRL Online Learning)')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```
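The plateau value of 2.3026 can be checked directly. The following minimal NumPy sketch of softmax cross-entropy is independent of braintools. It shows that identical logits for all 10 classes give a loss of ln(10), regardless of the label.

```python
import numpy as np


def softmax_cross_entropy(logits, label):
    """Cross-entropy of an integer label under softmax(logits), computed stably."""
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]


# An all-zero readout (like the single-step output above) is a uniform prediction.
loss = softmax_cross_entropy(np.zeros(10), label=3)
print(f"{loss:.4f}", f"{np.log(10):.4f}")   # 2.3026 2.3026
```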
4. Key Differences: D-RTRL vs ES-D-RTRL
BrainTrace provides two main online learning algorithms. Understanding their trade-offs helps you choose the right one for your application.
| Aspect | D-RTRL (`braintrace.D_RTRL`) | ES-D-RTRL (`braintrace.ES_D_RTRL`) |
|---|---|---|
| Eligibility trace | Full trace per weight parameter | Factorized into input/output components |
| Memory complexity | O(B * theta), where theta = total parameters | O(B * (I + O)) for the decay form, where I = input dim, O = output dim; O(B * rank * (I + O)) for the rank form |
| Computation | Exact gradient computation | Approximation via a factorized (smoothed or low-rank) trace |
| Scalability | Suitable for small networks | Scales to large networks (hundreds/thousands of neurons) |
| Use cases | Research requiring exact gradients | Practical SNN training, large-scale networks |
| BrainTrace API | `braintrace.D_RTRL`, `braintrace.ParamDimVjpAlgorithm` | `braintrace.ES_D_RTRL`, `braintrace.pp_prop`, `braintrace.IODimVjpAlgorithm` |
When to use which?
D-RTRL stores the full eligibility trace for each weight, giving exact online gradients. However, this requires O(B * theta) memory, which grows linearly with the number of parameters. For a network with N hidden neurons and an N x N recurrent weight matrix, the trace therefore holds on the order of N^2 entries per sample (one per weight), so memory grows quadratically with network width. This limits D-RTRL to small networks (typically fewer than 100 neurons).
ES-D-RTRL factorizes the eligibility trace into input and output components, reducing memory to O(B * (I + O)). The `decay_or_rank` parameter controls the approximation: a float value (e.g., 0.99) sets the trace decay factor, while an integer value sets the rank of the low-rank approximation. ES-D-RTRL is the recommended choice for SNNs, where networks often have hundreds or thousands of neurons.
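To put the two memory formulas in perspective, the snippet below plugs in this tutorial's sizes: batch 32, 50 inputs, 128 recurrent neurons, and 10 outputs. It counts trace entries only and ignores biases, hidden states, and constant factors. The numbers come from the O(·) expressions and are not a measurement of braintrace's actual memory use. The helper name is hypothetical.

```python
def trace_entries(n_in, n_out, batch):
    """Trace entries per layer: (D-RTRL full trace, ES-D-RTRL factored trace)."""
    full = batch * n_in * n_out        # one entry per weight and sample
    factored = batch * (n_in + n_out)  # input factor + output factor per sample
    return full, factored


# The tutorial's two ETP-aware layers (biases ignored), batch size 32.
layers = {"linear": (50 + 128, 128), "readout": (128, 10)}
for name, (n_in, n_out) in layers.items():
    full, factored = trace_entries(n_in, n_out, batch=32)
    print(f"{name:8s} D-RTRL {full:>8,d}  ES-D-RTRL {factored:>6,d}  ratio {full / factored:.0f}x")
```

The gap is largest for the recurrent projection. Doubling `n_rec` roughly quadruples the full trace but only doubles the factored one.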
Both algorithms use the same braintrace.nn.Linear and braintrace.nn.LeakyRateReadout layers. Switching between them requires only changing the algorithm wrapper:
```python
# D-RTRL (exact, high memory)
algo = braintrace.ParamDimVjpAlgorithm(model)

# ES-D-RTRL (approximate, scalable)
algo = braintrace.ES_D_RTRL(model, decay_or_rank=0.99)
```
5. Summary
In this tutorial, we demonstrated how to train a spiking neural network with online learning using BrainTrace. Here are the key takeaways:
- Model Construction: Use `braintrace.nn.Linear` and `braintrace.nn.LeakyRateReadout` for layers that should participate in online learning (ETP-aware). Combine them with spiking neuron models from `brainpy.state` (e.g., `LIF`).
- Online Learning Setup: Wrap the model with `braintrace.ES_D_RTRL` (ES-D-RTRL), call `compile_graph()` to trace the computation graph, and use `brainstate.transform.grad` to compute gradients at each time step.
- Scalability: ES-D-RTRL achieves O(B(I+O)) memory complexity, making it practical for large spiking networks. The `decay_or_rank` parameter controls the trace approximation quality.
- Batching: Use `brainstate.transform.vmap_new_states` and `brainstate.nn.Vmap` to process multiple samples in parallel, with each sample maintaining independent hidden states and eligibility traces.
For more advanced topics, including training on real neuromorphic datasets (N-MNIST) and comparing online learning with BPTT, see the detailed tutorials: