# Core Concepts of BrainTrace
Welcome to BrainTrace! This notebook introduces the core concepts you need to understand before using the library for online learning in recurrent and spiking neural networks.
BrainTrace is built on JAX and brainstate, providing memory-efficient online learning through eligibility trace propagation.
## 1. What is Online Learning?
Training recurrent neural networks (RNNs) typically relies on Backpropagation Through Time (BPTT). BPTT unrolls the full computation graph over all time steps before computing gradients:
- BPTT stores the entire computation graph across \(T\) time steps, requiring \(O(T)\) memory.
- As sequence length grows, memory usage becomes a bottleneck.
Online learning takes a different approach:
- Weights are updated at each time step using eligibility traces that summarize the gradient history.
- Memory cost is \(O(1)\) per time step, independent of sequence length.
- Eligibility traces accumulate the information needed for gradient computation incrementally.
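Here is a minimal pure-JAX sketch of the contrast. It does not use BrainTrace. The model is a hypothetical one-parameter linear RNN, \(h_t = a\,h_{t-1} + w\,x_t\), with loss \(L = \tfrac{1}{2} h_T^2\).

- BPTT differentiates through the unrolled loop.
- The online version carries only \(h_t\) and the eligibility trace \(e_t = \partial h_t / \partial w\), updated by \(e_t = a\,e_{t-1} + x_t\).

Because the toy model has a single parameter, the trace is exact here and the two gradients agree. The coefficient `A`, the inputs, and the function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

A = 0.9  # fixed recurrent coefficient of the toy RNN (hypothetical value)


def bptt_loss(w, xs):
    # Unrolled recurrence h_t = A * h_{t-1} + w * x_t; jax.grad through this
    # keeps the intermediate steps for the backward pass (O(T) memory)
    def step(h, x):
        return A * h + w * x, None

    h_T, _ = jax.lax.scan(step, jnp.float32(0.0), xs)
    return 0.5 * h_T ** 2


def online_grad(w, xs):
    # Online: carry only h_t and the trace e_t = dh_t/dw (memory independent of T)
    h = jnp.float32(0.0)
    e = jnp.float32(0.0)
    for x in xs:
        e = A * e + x      # dh_t/dw = A * dh_{t-1}/dw + x_t
        h = A * h + w * x
    return h * e           # dL/dw = (dL/dh_T) * (dh_T/dw) for L = 0.5 * h_T**2


xs = jnp.array([1.0, -0.5, 2.0, 0.3])
w = jnp.float32(0.7)
g_bptt = jax.grad(bptt_loss)(w, xs)
g_online = online_grad(w, xs)
print(g_bptt, g_online)
```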
BrainTrace implements online learning via JAX custom primitives. Instead of relying on string-matching or special parameter wrappers, BrainTrace identifies which operations participate in online learning by their primitive type at the JAX IR level. This gives a clean, composable, and JIT-friendly design.
## 2. Architecture Overview
BrainTrace is organized as a 4-layer system. Each layer builds on the one below it:
```
+--------------------------------------------------------------+
| Algorithms   braintrace.D_RTRL / braintrace.ES_D_RTRL        |
|              trace update + custom_vjp for jax.grad          |
+--------------------------------------------------------------+
| Executor     ETraceGraphExecutor                             |
|              forward pass + Jacobian computation             |
+--------------------------------------------------------------+
| Compiler     compile_etrace_graph()                          |
|              jaxpr walk -> find primitives -> connect to     |
|              hidden states                                   |
+--------------------------------------------------------------+
| Primitives   braintrace.matmul / element_wise / conv         |
| & Functions  JAX custom primitives (thin markers)            |
+--------------------------------------------------------------+
```
How it works:
- Primitives & Functions (bottom layer): You call `braintrace.matmul(x, w)` in your model. Under the hood, this binds a JAX custom primitive that acts as a marker; the actual computation is standard JAX (`x @ w`).
- Compiler: When you call `compile_graph()`, BrainTrace walks the JAX intermediate representation (jaxpr), finds all ETP (eligibility trace propagation) primitives, and connects each one to its associated hidden states and parameters.
- Executor: During the forward pass, the executor computes the model output and the Jacobians needed for eligibility trace updates.
- Algorithms (top layer): `D_RTRL` or `ES_D_RTRL` use the executor outputs to maintain eligibility traces and provide correct gradients via `custom_vjp`.
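The compiler step relies on a simple fact: a jaxpr is a flat list of equations, and each equation is tagged with the primitive it binds.

The sketch below uses only public JAX APIs (`jax.make_jaxpr`, `eqn.primitive.name`, `eqn.invars`). It finds the matmul equations in an RNN step and works out which function input is the weight of each one. JAX's built-in `dot_general` stands in for BrainTrace's own ETP primitives, so this illustrates the idea rather than BrainTrace's actual compiler. The function `rnn_step` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def rnn_step(h, x, w_in, w_rec):
    # lax ops bind one primitive each, so the jaxpr stays flat and easy to read
    return lax.tanh(lax.add(lax.dot(x, w_in), lax.dot(h, w_rec)))


closed = jax.make_jaxpr(rnn_step)(
    jnp.zeros(4), jnp.ones(3), jnp.ones((3, 4)), jnp.ones((4, 4))
)
jaxpr = closed.jaxpr

# Walk the equations and record which primitive each one binds
prim_names = [eqn.primitive.name for eqn in jaxpr.eqns]

# Select the matmul equations (stand-ins for ETP primitives) and map each one's
# second operand back to its position among the inputs (h, x, w_in, w_rec)
matmuls = [eqn for eqn in jaxpr.eqns if eqn.primitive.name == "dot_general"]
weight_positions = [jaxpr.invars.index(eqn.invars[1]) for eqn in matmuls]

print(prim_names)
print(weight_positions)
```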
## 3. Key Concept: Primitive-Based Parameter Selection
The central design idea of BrainTrace is that the operation you use determines whether a parameter participates in online learning:
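| What you write | Effect |
|---|---|
| `braintrace.matmul(h, w)` (or another ETP primitive) | `w` participates in online learning through eligibility traces |
| `h @ w` (or any other regular JAX op) | `w` is excluded from online learning and receives only instantaneous gradients |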
There is no need for special parameter classes. All weights are plain brainstate.ParamState. The choice of operation is what matters.
Here is a concrete example:
```python
import jax
import jax.numpy as jnp
import brainstate
import braintrace
```
```python
class SimpleRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.w_in = brainstate.ParamState(brainstate.random.randn(n_in, n_rec) * 0.01)
        self.w_rec = brainstate.ParamState(brainstate.random.randn(n_rec, n_rec) * 0.01)
        self.w_out = brainstate.ParamState(brainstate.random.randn(n_rec, n_out) * 0.01)
        # Recurrent hidden state (braintrace.nn.GRUCell also stores its h as a HiddenState)
        self.h = brainstate.HiddenState(jnp.zeros(n_rec))

    def update(self, x):
        # Regular matmul: w_in excluded from online learning
        inp = x @ self.w_in.value
        # ETP matmul: w_rec included in online learning
        rec = braintrace.matmul(self.h.value, self.w_rec.value)
        self.h.value = jax.nn.tanh(inp + rec)
        # Regular matmul: w_out excluded from online learning
        return self.h.value @ self.w_out.value
```
In the model above:
- `w_in` and `w_out` use standard `x @ w`, so they receive only instantaneous gradients (no temporal credit assignment through eligibility traces).
- `w_rec` uses `braintrace.matmul(h, w_rec)`, so the compiler will automatically maintain eligibility traces for this weight, enabling gradient computation that accounts for temporal dependencies.
All three weights are the same type (brainstate.ParamState). The operation is the only difference.
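The phrase "instantaneous gradient" can be made concrete in plain JAX. Cutting the path through the previous hidden state with `jax.lax.stop_gradient` leaves only the current step's contribution.

This is a conceptual sketch under that assumption, not how BrainTrace computes gradients internally. The toy model, the shapes, and the `cut_history` flag are hypothetical.

```python
import jax
import jax.numpy as jnp


def unroll(w_in, w_rec, xs, cut_history):
    # Toy tanh RNN with a loss on the final hidden state (hypothetical model)
    h = jnp.zeros(w_rec.shape[0])
    for x in xs:
        if cut_history:
            # Treat the previous hidden state as a constant: no gradient flows
            # into earlier time steps ("instantaneous" gradient)
            h = jax.lax.stop_gradient(h)
        h = jnp.tanh(x @ w_in + h @ w_rec)
    return jnp.sum(h)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w_in = 0.5 * jax.random.normal(k1, (3, 5))
w_rec = 0.5 * jax.random.normal(k2, (5, 5))
xs = jax.random.normal(k3, (6, 3))

g_full = jax.grad(unroll)(w_in, w_rec, xs, False)  # full temporal credit assignment
g_inst = jax.grad(unroll)(w_in, w_rec, xs, True)   # current step only

# Check: the instantaneous gradient equals the gradient of the last step alone,
# with the previous hidden state computed beforehand and held fixed
h_prev = jnp.zeros(5)
for x in xs[:-1]:
    h_prev = jnp.tanh(x @ w_in + h_prev @ w_rec)
g_last = jax.grad(lambda w: jnp.sum(jnp.tanh(xs[-1] @ w + h_prev @ w_rec)))(w_in)
```

An ETP weight such as `w_rec` in the model above keeps the history information that `stop_gradient` throws away, by carrying it forward in its eligibility trace.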
## 4. Using `braintrace.nn` Modules
While you can use primitives directly (as shown above), BrainTrace provides pre-built layers in the braintrace.nn module that already use ETP primitives internally. These are drop-in replacements for standard brainstate.nn layers:
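| Module | Description |
|---|---|
| `braintrace.nn.Linear` | Dense linear layer using `braintrace.matmul` |
| | Linear layer with sign-constrained weights (E/I networks) |
| | Weight-standardized linear layer |
| | Linear layer with sparse connectivity (uses `braintrace.sparse_matmul`) |
| | Low-rank adapter layer (uses `braintrace.lora_matmul`) |
| | Convolutional layers using `braintrace.conv` |
| `braintrace.nn.GRUCell`, `braintrace.nn.LSTMCell` | Recurrent cells with ETP-aware gates |
| | Rate-coded SNN readout |
| | Normalisation layers |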
For the matching low-level API, the user-facing primitive functions are:

- `braintrace.matmul(x, w, bias=None)`: dense matrix multiplication
- `braintrace.element_wise(weight, fn=...)`: element-wise weight ops (gating, learnable thresholds)
- `braintrace.conv(x, kernel, bias, ...)`: convolution
- `braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None)`: sparse matmul
- `braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None)`: LoRA decomposition
Use the braintrace.nn layers when you can; reach for the primitive functions when you need a custom layer that participates in online learning.
```python
class GRUNet(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.rnn = braintrace.nn.GRUCell(n_in, n_rec)
        self.readout = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.readout(self.rnn(x))
```
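The `GRUCell` internally uses `braintrace.matmul` for its weight operations, so its recurrent parameters are candidates for online learning. The compile step in Section 5 shows one exception: a weight whose output reaches the hidden state only through another ETP primitive (here the reset-gate weight `Wr`) is excluded. The `Linear` readout also uses ETP primitives, but the compiler detects that it is not connected to any hidden state and treats its weight as a non-temporal parameter.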
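The model printout in Section 5 shows that each `GRUCell` gate (`Wz`, `Wr`, `Wh`) is a `Linear` layer mapping 74 = 10 + 64 inputs to 64 outputs. In other words, each gate sees the concatenation of the input and the hidden state.

The pure-JAX sketch below writes out a standard GRU step with those shapes. The gate equations and the parameter layout are assumptions based on the common GRU formulation, not a copy of `braintrace.nn.GRUCell`. In this formulation the reset gate multiplies \(h\) before the `Wh` matmul, so `Wr` reaches the new hidden state only through another weight operation. That is consistent with the compiler warning about `Wr` in Section 5.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x):
    # Every gate reads the concatenation [x, h], so its input size is n_in + n_rec
    xh = jnp.concatenate([x, h])
    z = jax.nn.sigmoid(xh @ params["Wz"]["weight"] + params["Wz"]["bias"])  # update gate
    r = jax.nn.sigmoid(xh @ params["Wr"]["weight"] + params["Wr"]["bias"])  # reset gate
    # The reset gate scales h before the Wh matmul, so Wr influences the new
    # hidden state only through another weight operation (Wh)
    xrh = jnp.concatenate([x, r * h])
    h_cand = jnp.tanh(xrh @ params["Wh"]["weight"] + params["Wh"]["bias"])
    return (1.0 - z) * h + z * h_cand  # interpolate old state and candidate


n_in, n_rec = 10, 64
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    name: {
        "weight": 0.1 * jax.random.normal(k, (n_in + n_rec, n_rec)),
        "bias": jnp.zeros(n_rec),
    }
    for name, k in zip(["Wz", "Wr", "Wh"], keys)
}

h_new = gru_step(params, jnp.zeros(n_rec), jnp.ones(n_in))
print(h_new.shape)
```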
## 5. Online Learning in 3 Steps
Using BrainTrace for online learning follows a simple three-step workflow:
1. Define the model using `braintrace.nn` modules or manual ETP primitives.
2. Wrap with an algorithm and compile: choose `D_RTRL` or `ES_D_RTRL` and call `compile_graph()`.
3. Train with standard JAX gradient computation: eligibility traces are updated inside the wrapped model call.
Here is the complete workflow:
```python
# Step 1: Define model
model = GRUNet(10, 64, 10)
brainstate.nn.init_all_states(model)
```

```
GRUNet(
  rnn=GRUCell(
    in_size=(10,),
    out_size=(64,),
    state_initializer=ZeroInit(unit=1),
    activation=<function tanh at 0x78140b136980>,
    Wz=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    Wr=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    Wh=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    h=HiddenState(
      value=ShapedArray(float32[64], weak_type=True)
    )
  ),
  readout=Linear(
    in_size=(64,),
    out_size=(10,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[10]),
        'weight': ShapedArray(float32[64,10])
      }
    )
  )
)
```
```python
# Step 2: Wrap with D-RTRL and compile
trainer = braintrace.D_RTRL(model)
trainer.compile_graph(jnp.zeros(10))  # provide an example input for shape inference
```

```
/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('rnn', 'Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mv). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.
  _emit_no_relation_diag(
/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('readout', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.
  _emit_no_relation_diag(
```
```python
# Step 3: Use standard JAX gradient computation
# The eligibility traces are updated inside trainer(x)
weights = model.states(brainstate.ParamState)

def loss_fn(x):
    out = trainer(x)
    return jnp.mean(out ** 2)

grad_fn = brainstate.transform.grad(loss_fn, weights)
grads = grad_fn(jnp.ones(10))
```
What happens under the hood:
- `compile_graph()` traces the model through JAX, identifies all ETP primitives, and builds the eligibility trace computation graph.
- Each call to `trainer(x)` runs the model forward pass and updates all eligibility traces.
- When you compute `grad(loss_fn, weights)`, the algorithm uses `custom_vjp` to provide gradients that incorporate the eligibility trace information. This gives you temporally-aware gradients with \(O(1)\) memory per step.
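The `custom_vjp` mechanism can be sketched in plain JAX on a hypothetical one-parameter linear RNN, \(h_t = a\,h_{t-1} + w\,x_t\).

- The forward rule saves the eligibility trace \(e_t = \partial h_t / \partial w\) as its residual.
- The backward rule returns \(g\,e_t\) for \(w\) and zero cotangents for everything from the past, so `jax.grad` never walks back through time.

The per-step gradients then match what BPTT gives for each per-step loss \(L_t = \tfrac{1}{2} h_t^2\). This illustrates the idea only. `rec_step`, `A`, and the loss are illustrative assumptions, and BrainTrace's own backward rules operate on its compiled graph.

```python
import jax
import jax.numpy as jnp

A = 0.9  # fixed recurrent coefficient of the toy RNN (hypothetical value)


@jax.custom_vjp
def rec_step(w, x, h_prev, e):
    # Forward: h_t = A * h_{t-1} + w * x_t; e is the already-updated trace dh_t/dw
    return A * h_prev + w * x


def rec_step_fwd(w, x, h_prev, e):
    return rec_step(w, x, h_prev, e), e  # keep the trace as the residual


def rec_step_bwd(e, g):
    # dL/dw = g * e; the past gets zero cotangents, so autodiff never
    # travels back into earlier time steps
    zero = jnp.zeros_like(e)
    return g * e, zero, zero, zero


rec_step.defvjp(rec_step_fwd, rec_step_bwd)


def online_grads(w, xs):
    h = jnp.float32(0.0)
    e = jnp.float32(0.0)
    grads = []
    for x in xs:
        e = A * e + x  # trace update, done outside autodiff

        def step_loss(w_):
            return 0.5 * rec_step(w_, x, h, e) ** 2  # per-step loss L_t

        grads.append(jax.grad(step_loss)(w))
        h = rec_step(w, x, h, e)
    return jnp.stack(grads)


def bptt_step_loss(w, xs_prefix):
    # Reference: full unroll up to step t, differentiated by ordinary BPTT
    h = jnp.float32(0.0)
    for x in xs_prefix:
        h = A * h + w * x
    return 0.5 * h ** 2


xs = jnp.array([1.0, -0.5, 2.0, 0.3])
w = jnp.float32(0.7)
g_online = online_grads(w, xs)
g_bptt = jnp.stack([jax.grad(bptt_step_loss)(w, xs[: t + 1]) for t in range(len(xs))])
print(g_online, g_bptt)
```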
## 6. Available Algorithms
BrainTrace provides two main online learning algorithms:
| Algorithm | Class | Memory complexity | Compute complexity | Best for |
|---|---|---|---|---|
| D-RTRL | `braintrace.D_RTRL` | \(O(B \cdot \theta)\) | \(O(B \cdot I \cdot O)\) per layer | RNNs (GRU, LSTM) with moderate hidden sizes |
| ES-D-RTRL | `braintrace.ES_D_RTRL` | \(O(B \cdot N)\) | \(O(B \cdot N)\) per layer | Large-scale SNNs |
Here \(B\) is the batch size, \(\theta\) is the number of parameters, \(I\) and \(O\) are a layer's input and output dimensions, and \(N\) is the number of neurons.
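To get a feel for the gap, the snippet below plugs the `GRUNet(10, 64, 10)` model from Section 5 into the two memory bounds in the table. It uses a hypothetical batch size of 32 and counts every parameter as trace-carrying, which is an overestimate because some weights may be excluded at compile time. Big-O notation hides constant factors, so these numbers indicate scaling only, not actual memory use.

```python
# GRUNet(10, 64, 10) from Section 5
n_in, n_rec, n_out = 10, 64, 10
batch = 32  # hypothetical batch size

# Wz, Wr, Wh: each a Linear(n_in + n_rec -> n_rec) with weight and bias
gru_params = 3 * ((n_in + n_rec) * n_rec + n_rec)
readout_params = n_rec * n_out + n_out
theta = gru_params + readout_params   # total parameter count
n_neurons = n_rec                     # hidden units play the role of N

drtrl_scale = batch * theta           # O(B * theta)
es_drtrl_scale = batch * n_neurons    # O(B * N)

print(f"theta = {theta}")
print(f"B * theta = {drtrl_scale}, B * N = {es_drtrl_scale}")
print(f"ratio = {drtrl_scale / es_drtrl_scale:.0f}x")
```

The ratio of the two bounds is \(\theta / N\), the number of parameters per neuron. It therefore grows as layers get wider, which is why the table points large networks toward ES-D-RTRL.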
### When to use which?
- `D_RTRL` (also `ParamDimVjpAlgorithm`): Use for rate-based RNNs (GRU, LSTM, etc.) where you need accurate temporal gradient propagation. Its \(O(\theta)\) memory cost scales with parameter count, which is acceptable for typical RNN hidden sizes.
- `ES_D_RTRL` (also `IODimVjpAlgorithm`): Use for spiking neural networks (SNNs) or very large recurrent networks. It achieves \(O(N)\) complexity by exploiting the element-wise nature of neuronal dynamics, which makes it much more efficient for networks with many neurons.
Both algorithms are used in the same way — just swap the class name:
```python
# D-RTRL for RNNs
trainer = braintrace.D_RTRL(model)

# ES-D-RTRL for SNNs
trainer = braintrace.ES_D_RTRL(model)
```
## 7. Summary
Here is a quick recap of the core concepts:
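| Concept | Description |
|---|---|
| Online learning | Update weights at each time step using eligibility traces, achieving \(O(1)\) memory per step |
| ETP primitives | `braintrace.matmul`, `braintrace.element_wise`, `braintrace.conv`, and related JAX custom primitives that mark operations for online learning |
| Primitive-based selection | Use an ETP primitive to include a weight; use regular JAX ops to exclude it |
| `braintrace.nn` | Pre-built layers (`Linear`, `GRUCell`, `LSTMCell`, Conv) that use ETP primitives internally |
| Compile step | `trainer.compile_graph(example_input)` walks the jaxpr and connects ETP primitives to hidden states |
| `D_RTRL` | \(O(\theta)\) algorithm for RNNs |
| `ES_D_RTRL` | \(O(N)\) algorithm for SNNs |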
## 8. Next Steps
Now that you understand the core concepts, explore the following tutorials:
- RNN Online Learning: A complete example of training a GRU on the copying task using `D_RTRL`, including a comparison with BPTT.
- SNN Online Learning: Training spiking neural networks with `ES_D_RTRL` on neuromorphic datasets.
- ETP Primitives Deep Dive: Detailed guide on using and extending ETP primitives for custom operations.
- Batching: How to handle batched inputs with online learning.
- Visualizing the Computation Graph: Inspect the compiled eligibility trace graph for debugging.