Core Concepts of BrainTrace

Welcome to BrainTrace! This notebook introduces the core concepts you need to understand before using the library for online learning in recurrent and spiking neural networks.

BrainTrace is built on JAX and brainstate, providing memory-efficient online learning through eligibility trace propagation.

1. What is Online Learning?

Training recurrent neural networks (RNNs) typically relies on Backpropagation Through Time (BPTT). BPTT unrolls the full computation graph over all time steps before computing gradients:

  • BPTT stores the entire computation graph across \(T\) time steps, requiring \(O(T)\) memory.

  • As sequence length grows, memory usage becomes a bottleneck.

Online learning takes a different approach:

  • Weights are updated at each time step using eligibility traces that summarize the gradient history.

  • Memory cost is \(O(1)\) in the sequence length: the stored traces have a fixed size no matter how many time steps are processed.

  • Eligibility traces accumulate the information needed for gradient computation incrementally.
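Here is a minimal sketch of this idea in plain Python. It is independent of BrainTrace and uses illustrative constants. A scalar leaky unit \(h_t = a h_{t-1} + w x_t\) learns \(w\) online from a teacher that uses \(w = 2\). The trace \(e_t = a e_{t-1} + x_t\) equals \(\partial h_t / \partial w\) when \(w\) is held fixed. Because \(w\) changes at every step here, the trace is the usual RTRL-style approximation. Only a few scalars are kept, however long the loop runs.

```python
a = 0.5        # fixed leak factor
w_true = 2.0   # teacher weight that generates the targets
lr = 0.02      # learning rate

w = 0.0        # learnable weight
h = y = 0.0    # student and teacher hidden states
e = 0.0        # eligibility trace, a running estimate of dh_t/dw

for t in range(500):
    x = 1.0                      # constant input keeps the example simple
    y = a * y + w_true * x       # target at step t
    h = a * h + w * x            # student forward step
    e = a * e + x                # trace update: e_t = a * e_{t-1} + x_t
    grad = (h - y) * e           # dL_t/dw for L_t = 0.5 * (h_t - y_t) ** 2
    w -= lr * grad               # update now; no past activations are stored

print(round(w, 4))  # 2.0
```

Memory use is the same whether the loop runs for 500 steps or 5 million.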

BrainTrace implements online learning via JAX custom primitives. Instead of relying on string-matching or special parameter wrappers, BrainTrace identifies which operations participate in online learning by their primitive type at the JAX IR level. This gives a clean, composable, and JIT-friendly design.

2. Architecture Overview

BrainTrace is organized as a 4-layer system. Each layer builds on the one below it:

```text
+--------------------------------------------------------------+
|  Algorithms    braintrace.D_RTRL / braintrace.ES_D_RTRL      |
|                trace update + custom_vjp for jax.grad        |
+--------------------------------------------------------------+
|  Executor      ETraceGraphExecutor                           |
|                forward pass + Jacobian computation           |
+--------------------------------------------------------------+
|  Compiler      compile_etrace_graph()                        |
|                jaxpr walk -> find primitives -> connect to   |
|                hidden states                                 |
+--------------------------------------------------------------+
|  Primitives    braintrace.matmul / element_wise / conv       |
|  & Functions   JAX custom primitives (thin markers)          |
+--------------------------------------------------------------+
```

How it works:

  1. Primitives & Functions (bottom layer): You call braintrace.matmul(x, w) in your model. Under the hood, this binds a JAX custom primitive that acts as a marker — the actual computation is standard JAX (x @ w).

  2. Compiler: When you call compile_graph() on an algorithm object, the compiler layer (compile_etrace_graph() in the diagram) walks the JAX intermediate representation (jaxpr). It finds all ETP primitives and connects each one to its associated hidden states and parameters.

  3. Executor: During the forward pass, the executor computes the model output and the Jacobians needed for eligibility trace updates.

  4. Algorithms (top layer): D_RTRL or ES_D_RTRL use the executor outputs to maintain eligibility traces and provide correct gradients via custom_vjp.
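BrainTrace's compiler searches the jaxpr for its own custom primitives; its internals are not reproduced here. The sketch below shows the general technique using only public JAX APIs. The hypothetical helper matmul_weight_sources builds a jaxpr with jax.make_jaxpr and visits every equation. For each ordinary dot_general (the primitive that x @ w lowers to), it reports which named inputs feed the weight operand.

```python
import jax
import jax.numpy as jnp


def rnn_step(h, x, w_in, w_rec):
    # A vanilla RNN step: both matmuls lower to dot_general in the jaxpr.
    return jnp.tanh(x @ w_in + h @ w_rec)


def matmul_weight_sources(fn, names, *args):
    """Walk fn's jaxpr and report which named inputs feed the right-hand
    (weight) operand of every dot_general equation."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    # For every variable, track the set of named inputs it depends on.
    # Keys are id(var), so literals (constants) simply map to an empty set.
    deps = {id(v): {name} for v, name in zip(jaxpr.invars, names)}
    found = []
    for eqn in jaxpr.eqns:
        in_deps = [deps.get(id(v), set()) for v in eqn.invars]
        if eqn.primitive.name == "dot_general":
            found.append(sorted(in_deps[1]))   # rhs operand holds the weight
        merged = set().union(*in_deps)
        for out in eqn.outvars:
            deps[id(out)] = merged
    return found


args = (jnp.zeros(4), jnp.zeros(3), jnp.zeros((3, 4)), jnp.zeros((4, 4)))
print(matmul_weight_sources(rnn_step, ["h", "x", "w_in", "w_rec"], *args))
# [['w_in'], ['w_rec']]
```

A real compiler would match a dedicated marker primitive instead of every dot_general. It would also follow each output forward to the hidden states, but the traversal pattern is the same.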

3. Key Concept: Primitive-Based Parameter Selection

The central design idea of BrainTrace is that the operation you use determines whether a parameter participates in online learning:

| What you write | Effect |
| --- | --- |
| braintrace.matmul(x, w) | w is included in online learning (eligibility traces are maintained) |
| x @ w (regular JAX matmul) | w is excluded from online learning (only instantaneous gradients) |

There is no need for special parameter classes. All weights are plain brainstate.ParamState. The choice of operation is what matters.

Here is a concrete example:

```python
import jax
import jax.numpy as jnp
import brainstate
import braintrace
```
```python
class SimpleRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.w_in = brainstate.ParamState(brainstate.random.randn(n_in, n_rec) * 0.01)
        self.w_rec = brainstate.ParamState(brainstate.random.randn(n_rec, n_rec) * 0.01)
        self.w_out = brainstate.ParamState(brainstate.random.randn(n_rec, n_out) * 0.01)
        # Declared as a HiddenState (as GRUCell does) so the compiler can
        # connect ETP weights to this recurrent state
        self.h = brainstate.HiddenState(jnp.zeros(n_rec))

    def update(self, x):
        # Regular matmul: w_in excluded from online learning
        inp = x @ self.w_in.value

        # ETP matmul: w_rec included in online learning
        rec = braintrace.matmul(self.h.value, self.w_rec.value)

        self.h.value = jnp.tanh(inp + rec)

        # Regular matmul: w_out excluded from online learning
        return self.h.value @ self.w_out.value
```

In the model above:

  • w_in and w_out use standard x @ w — they receive only instantaneous gradients (no temporal credit assignment through eligibility traces).

  • w_rec uses braintrace.matmul(h, w_rec) — the compiler will automatically maintain eligibility traces for this weight, enabling gradient computation that accounts for temporal dependencies.

All three weights are the same type (brainstate.ParamState). The operation is the only difference.
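The sketch below shows numerically what the Effect column means. It uses plain JAX with no BrainTrace, and all values are illustrative. The model is a scalar tanh RNN \(h_t = \tanh(w_{in} x_t + w_{rec} h_{t-1})\). Carrying the trace \(e_t = \partial h_t / \partial w_{rec}\) forward reproduces the full BPTT derivative of \(h_T\). The instantaneous gradient, which is what w_rec would get if it used a regular matmul, ignores how \(h_{t-1}\) itself depends on \(w_{rec}\). With a single hidden unit this recursion is exact RTRL.

```python
import jax
import jax.numpy as jnp

xs = jnp.array([1.0, 0.5, -0.3, 0.8, 0.2])   # illustrative input sequence
w_in, w_rec = 0.7, 0.5                        # illustrative weights


def final_state(w_rec):
    """Unrolled forward pass; jax.grad of this is the full BPTT derivative."""
    h = 0.0
    for x in xs:
        h = jnp.tanh(w_in * x + w_rec * h)
    return h


exact = jax.grad(final_state)(w_rec)          # dh_T / dw_rec via BPTT

# Online pass: carry e_t = dh_t/dw_rec forward instead of storing the history.
h, e = 0.0, 0.0
for x in xs:
    h_prev = h
    h = jnp.tanh(w_in * x + w_rec * h_prev)
    e = (1.0 - h**2) * (h_prev + w_rec * e)   # chain rule through h_{t-1}

# If w_rec were excluded (regular x @ w), only this local derivative remains.
instantaneous = (1.0 - h**2) * h_prev

print(float(exact), float(e), float(instantaneous))
```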

4. Using braintrace.nn Modules

While you can use primitives directly (as shown above), BrainTrace provides pre-built layers in the braintrace.nn module that already use ETP primitives internally. These are drop-in replacements for standard brainstate.nn layers:

| Module | Description |
| --- | --- |
| braintrace.nn.Linear | Dense linear layer using braintrace.matmul |
| braintrace.nn.SignedWLinear | Linear layer with sign-constrained weights (E/I networks) |
| braintrace.nn.ScaledWSLinear | Weight-standardized linear layer |
| braintrace.nn.SparseLinear | Linear layer with sparse connectivity (uses sparse_matmul) |
| braintrace.nn.LoRA | Low-rank adapter layer (uses lora_matmul) |
| braintrace.nn.Conv1d / Conv2d / Conv3d | Convolutional layers using braintrace.conv |
| braintrace.nn.GRUCell / LSTMCell / ValinaRNNCell | Recurrent cells with ETP-aware gates |
| braintrace.nn.LeakyRateReadout | Rate-coded SNN readout |
| braintrace.nn.BatchNorm1d / LayerNorm | Normalisation layers |

For the matching low-level API, the user-facing primitive functions are:

  • braintrace.matmul(x, w, bias=None) – dense matrix multiplication

  • braintrace.element_wise(weight, fn=...) – element-wise weight ops (gating, learnable thresholds)

  • braintrace.conv(x, kernel, bias, ...) – convolution

  • braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None) – sparse matmul

  • braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None) – LoRA decomposition

Use the braintrace.nn layers when you can; reach for the primitive functions when you need a custom layer that participates in online learning.

```python
class GRUNet(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.rnn = braintrace.nn.GRUCell(n_in, n_rec)
        self.readout = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.readout(self.rnn(x))
```

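The GRUCell builds its gates from braintrace.nn.Linear layers (Wz, Wr and Wh in the printed model in Step 1 below). Its weights therefore use ETP primitives and are candidates for online learning. The compiler still decides weight by weight, as the compile-time warnings in Step 2 show:

  • The weight of Wr is excluded, because its output reaches the hidden state only through another ETP matmul.

  • The Linear readout uses ETP primitives but is not connected to any hidden state, so it is treated as a non-temporal parameter.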
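To see why the compiler can treat the reset-gate weight differently, here is one step of a standard GRU in plain JAX. We assume braintrace.nn.GRUCell follows this common formulation. The printed model's in_size=(74,) for Wz, Wr and Wh is consistent with concatenating the 10-dimensional input and the 64-dimensional hidden state. Biases are omitted, and gru_step is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x):
    """One step of a standard GRU (biases omitted for brevity)."""
    xh = jnp.concatenate([x, h])                 # (n_in + n_rec,)
    z = jax.nn.sigmoid(xh @ params["Wz"])        # update gate
    r = jax.nn.sigmoid(xh @ params["Wr"])        # reset gate
    # The reset gate scales h *before* the Wh matmul, so Wr influences the
    # new hidden state only by passing through another weight matmul (Wh).
    xrh = jnp.concatenate([x, r * h])
    h_cand = jnp.tanh(xrh @ params["Wh"])        # candidate state
    # z and h_cand reach the new state through element-wise ops only.
    return (1.0 - z) * h + z * h_cand


n_in, n_rec = 10, 64
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {name: 0.1 * jax.random.normal(k, (n_in + n_rec, n_rec))
          for name, k in zip(["Wz", "Wr", "Wh"], keys)}
h_new = gru_step(params, jnp.zeros(n_rec), jnp.ones(n_in))
print(params["Wz"].shape, h_new.shape)  # (74, 64) (64,)
```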

5. Online Learning in 3 Steps

Using BrainTrace for online learning follows a simple three-step workflow:

  1. Define the model using braintrace.nn modules or manual ETP primitives.

  2. Wrap with an algorithm and compile — choose D_RTRL or ES_D_RTRL and call compile_graph().

  3. Train with standard JAX gradient computation — eligibility traces are updated inside the wrapped model call.

Here is the complete workflow:

```python
# Step 1: Define model
model = GRUNet(10, 64, 10)
brainstate.nn.init_all_states(model)
```

```text
GRUNet(
  rnn=GRUCell(
    in_size=(10,),
    out_size=(64,),
    state_initializer=ZeroInit(unit=1),
    activation=<function tanh at 0x78140b136980>,
    Wz=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    Wr=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    Wh=Linear(
      in_size=(74,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64], weak_type=True),
          'weight': ShapedArray(float32[74,64])
        }
      )
    ),
    h=HiddenState(
      value=ShapedArray(float32[64], weak_type=True)
    )
  ),
  readout=Linear(
    in_size=(64,),
    out_size=(10,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[10]),
        'weight': ShapedArray(float32[64,10])
      }
    )
  )
)
```
```python
# Step 2: Wrap with D-RTRL and compile
trainer = braintrace.D_RTRL(model)
trainer.compile_graph(jnp.zeros(10))  # provide an example input for shape inference
```
```text
UserWarning: ETP primitive etp_mv (weight=('rnn', 'Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mv). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.
UserWarning: ETP primitive etp_mv (weight=('readout', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.
```
```python
# Step 3: Use standard JAX gradient computation
# The eligibility traces are updated inside trainer(x)
weights = model.states(brainstate.ParamState)

def loss_fn(x):
    out = trainer(x)
    return jnp.mean(out ** 2)

grad_fn = brainstate.transform.grad(loss_fn, weights)
grads = grad_fn(jnp.ones(10))
```

What happens under the hood:

  • compile_graph() traces the model through JAX, identifies all ETP primitives, and builds the eligibility trace computation graph.

  • Each call to trainer(x) runs the model forward pass and updates all eligibility traces.

  • When you compute grad(loss_fn, weights), the algorithm uses custom_vjp to provide gradients that incorporate the eligibility trace information. This yields temporally aware gradients whose memory cost is \(O(1)\) in the sequence length.
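The custom_vjp mechanism can be sketched in plain JAX. The sketch assumes a scalar linear leaky unit. The hypothetical inject_trace returns the hidden state unchanged in the forward pass. Its backward rule sends \(\partial L / \partial h\) to \(w\) through the eligibility trace. This illustrates the idea and is not BrainTrace's implementation. For this linear unit the result matches BPTT exactly.

```python
import jax
import jax.numpy as jnp

A = 0.8                                   # fixed leak factor (illustrative)
TARGET = 1.0
XS = jnp.array([0.5, -1.0, 2.0, 0.3])     # a short input sequence


@jax.custom_vjp
def inject_trace(w, h, e):
    """Forward: pass h through. Backward: route dL/dh to w via e ~ dh/dw."""
    return h


def _inject_fwd(w, h, e):
    return h, e                           # keep the trace as the residual


def _inject_bwd(e, g):
    # Cotangents for (w, h, e): all temporal credit for w comes from the trace.
    return g * e, jnp.zeros_like(g), jnp.zeros_like(e)


inject_trace.defvjp(_inject_fwd, _inject_bwd)


def online_loss(w):
    h = e = jnp.float32(0.0)
    for x in XS:                          # forward pass, trace updated per step
        h = A * h + w * x
        e = A * e + x                     # e_t = dh_t/dw for this linear unit
    # Cut the autodiff path through time: h and e become constants.
    h, e = jax.lax.stop_gradient(h), jax.lax.stop_gradient(e)
    return 0.5 * (inject_trace(w, h, e) - TARGET) ** 2


def bptt_loss(w):
    h = jnp.float32(0.0)
    for x in XS:
        h = A * h + w * x
    return 0.5 * (h - TARGET) ** 2


w0 = jnp.float32(0.7)
print(jax.grad(online_loss)(w0), jax.grad(bptt_loss)(w0))
```

Both losses have the same forward value; only the backward rule differs. The online version never differentiates through the time loop.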

6. Available Algorithms

BrainTrace provides two main online learning algorithms:

| Algorithm | Class | Memory Complexity | Compute Complexity | Best For |
| --- | --- | --- | --- | --- |
| D-RTRL | braintrace.D_RTRL | \(O(B \cdot \theta)\) | \(O(B \cdot I \cdot O)\) per layer | RNNs (GRU, LSTM) with moderate hidden sizes |
| ES-D-RTRL | braintrace.ES_D_RTRL | \(O(B \cdot N)\) | \(O(B \cdot N)\) per layer | Large-scale SNNs where \(N\) is the number of neurons |

Where \(B\) is the batch size, \(\theta\) is the number of parameters, \(I\) and \(O\) are input/output dimensions, and \(N\) is the number of neurons.
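To make the table concrete, here is rough arithmetic for one dense layer with \(B = 32\) and \(I = O = 1000\). The D-RTRL count follows the table: one trace entry per parameter per sample. For ES-D-RTRL we assume, as a simplification, one input-side and one output-side vector per sample, so \(N = I + O\). The numbers are order-of-magnitude estimates, not measured BrainTrace memory, and trace_sizes is a hypothetical helper.

```python
def trace_sizes(batch, n_in, n_out):
    """Approximate eligibility-trace entries for one dense layer."""
    d_rtrl = batch * n_in * n_out          # O(B * theta): one entry per weight and sample
    es_d_rtrl = batch * (n_in + n_out)     # assumed factorization: O(B * N), N = n_in + n_out
    return d_rtrl, es_d_rtrl


d, es = trace_sizes(batch=32, n_in=1000, n_out=1000)
print(d, es)                               # 32000000 64000
print(f"{d * 4 / 2**20:.1f} MiB vs {es * 4 / 2**20:.2f} MiB in float32")
# 122.1 MiB vs 0.24 MiB in float32
```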

When to use which?

  • D_RTRL (also ParamDimVjpAlgorithm): Use for rate-based RNNs (GRU, LSTM, etc.) where you need accurate temporal gradient propagation. Its \(O(\theta)\) memory cost scales with parameter count, which is acceptable for typical RNN hidden sizes.

  • ES_D_RTRL (also IODimVjpAlgorithm): Use for spiking neural networks (SNNs) or very large recurrent networks. It achieves \(O(N)\) complexity by exploiting the element-wise nature of neuronal dynamics, making it much more efficient for networks with many neurons.

Both algorithms are used in the same way — just swap the class name:

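```python
# D-RTRL for RNNs
trainer = braintrace.D_RTRL(model)

# ES-D-RTRL for SNNs
trainer = braintrace.ES_D_RTRL(model)

# Whichever you pick, compile it before the first call:
# trainer.compile_graph(example_input)
```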

7. Summary

Here is a quick recap of the core concepts:

| Concept | Description |
| --- | --- |
| Online learning | Update weights at each time step using eligibility traces, with \(O(1)\) memory in the sequence length |
| ETP primitives | braintrace.matmul, braintrace.element_wise, braintrace.conv: JAX custom primitives that mark operations for online learning |
| Primitive-based selection | Use an ETP primitive to include a weight; use regular JAX ops to exclude it |
| braintrace.nn | Pre-built layers (Linear, GRUCell, LSTMCell, Conv) that use ETP primitives internally |
| Compile step | trainer.compile_graph(example_input) analyzes the jaxpr to build the trace computation graph |
| D_RTRL | \(O(\theta)\) algorithm for RNNs |
| ES_D_RTRL | \(O(N)\) algorithm for SNNs |

8. Next Steps

Now that you understand the core concepts, explore the following tutorials: