Custom Algorithm Development

This advanced tutorial covers how to develop custom online learning algorithms with braintrace.

braintrace ships with two built-in algorithms:

  • D_RTRL (ParamDimVjpAlgorithm): Full eligibility traces with parameter-dimension complexity.

  • ES_D_RTRL (IODimVjpAlgorithm): Factorized eligibility traces with input/output-dimension complexity.

For research purposes, you may want to implement your own online learning algorithm, for example one that clips traces, applies spectral normalization to Jacobians, or uses an entirely different trace update rule. braintrace’s algorithm hierarchy makes this straightforward.

The algorithm class hierarchy:

ETraceAlgorithm              # Base: model wrapping, graph compilation, state separation
  └─ ETraceVjpAlgorithm       # Adds VJP-based gradient computation, implements update()
       ├─ ParamDimVjpAlgorithm # D-RTRL: traces stored per weight parameter
       └─ IODimVjpAlgorithm    # ES-D-RTRL: traces factorized into input/output components

All algorithms share the same graph compilation infrastructure: ModuleInfo, HiddenGroup, HiddenPerturbation, and ETraceGraph. You only need to customize the trace update and gradient computation logic.

Algorithm Architecture

ETraceAlgorithm (Base Class)

The root of the hierarchy. It handles:

  • Model wrapping: Stores the target brainstate.nn.Module and its graph executor.

  • Graph compilation: Calls compile_graph(*args) to build the eligibility trace computation graph.

  • State separation: Splits model states into param_states (weights), hidden_states (recurrent states), and other_states.

  • Running index: Tracks the current time step via self.running_index.

Key abstract methods: init_etrace_state(), update(), get_etrace_of().

ETraceVjpAlgorithm

Extends ETraceAlgorithm with VJP-based (reverse-mode) gradient computation. It:

  • Defines a custom_vjp function that wraps the forward pass and eligibility trace update.

  • Implements the update() method, which extracts state values, calls the forward+trace update, and writes states back.

  • Provides the backward pass (_update_fn_bwd) that computes weight gradients from eligibility traces and loss gradients.

Key protocol methods (to be overridden by subclasses):

  • _update_etrace_data(): Core trace update logic.

  • _solve_weight_gradients(): Compute final weight gradients from traces and loss gradients.

  • _get_etrace_data(): Retrieve current trace values from states.

  • _assign_etrace_data(): Write trace values back to states.
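The split between update() and the protocol hooks is a template-method pattern: the base class fixes the order of operations, and a subclass overrides individual steps, often by calling super() and post-processing the result. The sketch below illustrates that idea in plain Python. The class names, the scalar trace, and the placeholder recurrence are all hypothetical; they only mirror the read, update, write-back sequence described above and are not braintrace's implementation.

```python
class ToyVjpAlgorithm:
    """Hypothetical stand-in for the update() -> protocol-hook flow (not braintrace code)."""

    def __init__(self):
        self._trace = 0.0  # a single scalar "trace" kept on the algorithm

    def _get_etrace_data(self):
        return self._trace

    def _assign_etrace_data(self, vals):
        self._trace = vals

    def _update_etrace_data(self, old_trace, jac, x):
        # Placeholder recurrence: eps^t = jac * eps^{t-1} + x^t
        return jac * old_trace + x

    def update(self, jac, x):
        # Fixed order: read the traces, run the overridable recurrence, write back.
        old = self._get_etrace_data()
        new = self._update_etrace_data(old, jac, x)
        self._assign_etrace_data(new)
        return new


class ToyClipped(ToyVjpAlgorithm):
    """Overrides only the recurrence hook and post-processes the parent's result."""

    def __init__(self, clip_value):
        super().__init__()
        self.clip_value = clip_value

    def _update_etrace_data(self, old_trace, jac, x):
        new = super()._update_etrace_data(old_trace, jac, x)
        return max(-self.clip_value, min(self.clip_value, new))


plain, clipped = ToyVjpAlgorithm(), ToyClipped(clip_value=2.0)
for _ in range(4):
    plain.update(jac=0.9, x=1.0)
    clipped.update(jac=0.9, x=1.0)
print(plain._trace, clipped._trace)
```

The subclass never touches update() itself, which is the same division of labour the real hierarchy relies on.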

ParamDimVjpAlgorithm (D-RTRL)

Stores one eligibility trace tensor per (weight, hidden group) pair. The trace has the same shape as the weight parameter (times the number of hidden states). Memory complexity: \(O(B\theta)\) where \(\theta\) is the parameter count and \(B\) the batch size.

IODimVjpAlgorithm (ES-D-RTRL)

Factorizes the eligibility trace into separate input traces (\(\epsilon_x\)) and output/transition traces (\(\epsilon_f\)). This reduces memory to \(O(B(I+O))\) where \(I\) and \(O\) are the input and output dimensions. Controlled by a decay_or_rank parameter.
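To make the two memory bounds concrete, here is a back-of-the-envelope element count. It assumes a single weight of shape (I, O) = (42, 32) driving one hidden state (matching the trace shape printed later for ValinaRNNCell(in_size=10, out_size=32)) and a hypothetical batch size of 64. It counts trace entries only and ignores any bookkeeping the library may keep; the helper names are illustrative.

```python
def drtrl_trace_elems(batch, in_dim, out_dim):
    # O(B * theta): one trace entry per weight element, per batch item
    return batch * in_dim * out_dim


def esdrtrl_trace_elems(batch, in_dim, out_dim):
    # O(B * (I + O)): an input trace of size I plus an output trace of size O
    return batch * (in_dim + out_dim)


B, I, O = 64, 42, 32  # assumed batch size and weight shape
full = drtrl_trace_elems(B, I, O)
factored = esdrtrl_trace_elems(B, I, O)
print(full, factored, round(full / factored, 1))  # 86016 4736 18.2
```

The ratio grows with layer width: D-RTRL scales with the product I·O, ES-D-RTRL with the sum I + O.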

D-RTRL Mathematical Foundation

The D-RTRL algorithm maintains an eligibility trace \(\epsilon^t\) that approximates the full Jacobian \(\partial h^t / \partial \theta\) using a diagonal approximation of the hidden-to-hidden Jacobian:

\[ \epsilon^t \approx D^t \, \epsilon^{t-1} + \text{diag}(D_f^t) \otimes x^t \]

where:

  • \(D^t = \text{diag}(\partial h^t / \partial h^{t-1})\): the hidden-to-hidden Jacobian (diagonal approximation)

  • \(D_f^t = \partial h^t / \partial y^t\): the transition Jacobian, where \(y^t\) is the output of the weight operation

  • \(x^t\): the input to the weight operation at time \(t\)

  • \(\otimes\): the outer product

The weight gradient is then computed by combining the eligibility traces with the loss gradient:

\[ \nabla_\theta \mathcal{L} = \sum_{t' \in \mathcal{T}} \frac{\partial \mathcal{L}^{t'}}{\partial h^{t'}} \circ \epsilon^{t'} \]

where \(\mathcal{T}\) is the set of time steps at which a loss is evaluated, and \(\circ\) denotes the contraction over hidden dimensions that produces the weight-shaped gradient.
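The recurrence and the gradient readout can be checked numerically on a toy problem. The sketch below is plain JAX, not braintrace's implementation, and assumes a hypothetical element-wise cell \(h^t = \tanh(\lambda \odot h^{t-1} + W x^t)\), for which \(\partial h^t / \partial h^{t-1}\) is exactly diagonal. In that special case D-RTRL is exact, so the online gradient should match jax.grad through the unrolled sequence (BPTT). For cells with dense recurrent weights the diagonal approximation makes the two differ.

```python
import jax
import jax.numpy as jnp


def step(W, lam, h_prev, x):
    # Hypothetical element-wise cell (not a braintrace model): h^t = tanh(lam * h^{t-1} + W x^t)
    return jnp.tanh(lam * h_prev + W @ x)


def drtrl_grad(W, lam, xs, targets):
    """Accumulate dL/dW online with a D-RTRL trace (no backprop through time)."""
    n_out, n_in = W.shape
    h = jnp.zeros(n_out)
    eps = jnp.zeros((n_out, n_in))  # eps[o, i] = dh[o]/dW[o, i]; all other entries are zero
    grad = jnp.zeros_like(W)
    for x, target in zip(xs, targets):
        h_new = step(W, lam, h, x)
        d_act = 1.0 - h_new ** 2           # tanh'(pre-activation)
        D = d_act * lam                    # diagonal of dh^t/dh^{t-1}
        D_f = d_act                        # diagonal of dh^t/dy^t, with y^t = W x^t
        eps = D[:, None] * eps + D_f[:, None] * x[None, :]
        dl_dh = h_new - target             # dL^t/dh^t for L^t = 0.5 * ||h^t - target^t||^2
        grad = grad + dl_dh[:, None] * eps # contraction over hidden units reduces to a row scale
        h = h_new
    return grad


def bptt_loss(W, lam, xs, targets):
    h = jnp.zeros(W.shape[0])
    loss = 0.0
    for x, target in zip(xs, targets):
        h = step(W, lam, h, x)
        loss = loss + 0.5 * jnp.sum((h - target) ** 2)
    return loss


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W = 0.3 * jax.random.normal(k1, (4, 3))
lam = jax.random.uniform(k2, (4,), minval=0.2, maxval=0.9)
xs = jax.random.normal(k3, (6, 3))
targets = jax.random.normal(k4, (6, 4))

g_online = drtrl_grad(W, lam, xs, targets)
g_bptt = jax.grad(bptt_loss)(W, lam, xs, targets)
print(jnp.max(jnp.abs(g_online - g_bptt)))  # float32 round-off only
```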

ES-D-RTRL Mathematical Foundation

ES-D-RTRL further approximates the D-RTRL trace by factorizing it into separate input and output components:

\[ \epsilon^t \approx \epsilon_f^t \otimes \epsilon_x^t \]

The two components are updated with exponential smoothing controlled by a decay factor \(\alpha\):

Input trace:

\[ \epsilon_x^t = \alpha \, \epsilon_x^{t-1} + x^t \]

Output trace:

\[ \epsilon_f^t = \alpha \, D^t \circ \epsilon_f^{t-1} + (1 - \alpha) \, D_f^t \]

where:

  • \(\alpha \in (0, 1)\): exponential smoothing decay factor

  • \(D^t\): hidden-to-hidden diagonal Jacobian (same as in D-RTRL)

  • \(D_f^t\): transition Jacobian

  • \(x^t\): input to the weight operation

The decay factor \(\alpha\) is controlled by the decay_or_rank parameter:

  • If a float in \((0, 1)\): used directly as \(\alpha\).

  • If an int \(> 0\) (the approximation rank): \(\alpha = (\text{rank} - 1) / (\text{rank} + 1)\).

The weight gradient formula is the same as D-RTRL, but uses the factorized traces.
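The factorized update can be sketched in plain JAX (not braintrace) on a hypothetical element-wise cell \(h^t = \tanh(\lambda \odot h^{t-1} + W x^t)\). The helper decay_from mirrors the decay_or_rank rule stated above; the cell, the function names, and the choice rank = 4 are illustrative assumptions. Because the factorization is an approximation, the example checks only the decay value, shapes, and trace size, not agreement with BPTT.

```python
import jax
import jax.numpy as jnp


def decay_from(decay_or_rank):
    """alpha from decay_or_rank: a float in (0, 1) is used as-is; an int rank r > 0 gives (r - 1) / (r + 1)."""
    if isinstance(decay_or_rank, int):
        if decay_or_rank <= 0:
            raise ValueError("rank must be a positive int")
        return (decay_or_rank - 1) / (decay_or_rank + 1)
    if not 0.0 < decay_or_rank < 1.0:
        raise ValueError("decay must lie in (0, 1)")
    return float(decay_or_rank)


def esdrtrl_grad(W, lam, xs, targets, decay_or_rank=4):
    """Online dL/dW with factorized traces eps ~ eps_f (outer) eps_x (toy sketch)."""
    alpha = decay_from(decay_or_rank)
    n_out, n_in = W.shape
    h = jnp.zeros(n_out)
    eps_x = jnp.zeros(n_in)    # input trace, size I
    eps_f = jnp.zeros(n_out)   # output/transition trace, size O
    grad = jnp.zeros_like(W)
    for x, target in zip(xs, targets):
        h_new = jnp.tanh(lam * h + W @ x)      # hypothetical element-wise cell
        d_act = 1.0 - h_new ** 2
        D, D_f = d_act * lam, d_act            # diagonal h2h and transition Jacobians
        eps_x = alpha * eps_x + x
        eps_f = alpha * D * eps_f + (1.0 - alpha) * D_f
        dl_dh = h_new - target                 # dL^t/dh^t for a squared-error loss
        grad = grad + jnp.outer(dl_dh * eps_f, eps_x)  # dL/dW[o, i] ~ dl_dh[o] * eps_f[o] * eps_x[i]
        h = h_new
    return grad, eps_x, eps_f


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W = 0.3 * jax.random.normal(k1, (4, 3))
lam = jax.random.uniform(k2, (4,), minval=0.2, maxval=0.9)
xs = jax.random.normal(k3, (6, 3))
targets = jax.random.normal(k4, (6, 4))

grad, eps_x, eps_f = esdrtrl_grad(W, lam, xs, targets, decay_or_rank=4)
print(decay_from(4), grad.shape, eps_x.size + eps_f.size)  # 0.6 (4, 3) 7
```

Only I + O = 7 trace numbers are carried between steps here, versus O·I = 12 for the full D-RTRL trace of the same weight.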

Key Methods to Override

When implementing a custom algorithm, you typically subclass ParamDimVjpAlgorithm or IODimVjpAlgorithm and override one or more of these methods:

Method                        Purpose                                                      Defined in
init_etrace_state(*args)      Initialize trace storage (called during compile_graph)       ETraceAlgorithm
_get_etrace_data()            Retrieve current trace values from internal states           ETraceVjpAlgorithm
_assign_etrace_data(vals)     Write trace values back to internal states                   ETraceVjpAlgorithm
_update_etrace_data(...)      Core trace update logic (the trace recurrence equation)      ETraceVjpAlgorithm
_solve_weight_gradients(...)  Compute final weight gradients from traces + loss gradients  ETraceVjpAlgorithm
reset_state(batch_size=None)  Reset traces between epochs/episodes                         ParamDimVjpAlgorithm / IODimVjpAlgorithm

Method signatures

def _update_etrace_data(
    self,
    running_index,            # int: current time step
    etrace_vals_util_t_1,     # ETraceVals: traces accumulated until t-1
    hid2weight_jac,           # Hid2WeightJacobian: current Jacobians
    hid2hid_jac,              # Sequence[jax.Array]: hidden-to-hidden Jacobians
    weight_vals,              # Dict[Path, PyTree]: current weight values
    input_is_multi_step,      # bool: whether input spans multiple steps
) -> ETraceVals:
    ...

def _solve_weight_gradients(
    self,
    running_index,            # int: current time step
    etrace_h2w_at_t,          # eligibility trace data at time t
    dl_to_hidden_groups,      # Sequence[jax.Array]: dL/dh per hidden group
    weight_vals,              # Dict[WeightID, PyTree]: current weight values
    dl_to_nonetws_at_t,       # gradients of non-etrace parameters
    dl_to_etws_at_t,          # optional gradients of etrace parameters
) -> Dict[Path, PyTree]:
    ...

Example: Implementing a Clipped D-RTRL

A common issue with eligibility traces in deep recurrent networks is trace explosion: the trace magnitudes grow unboundedly over time. One practical mitigation is to clip the trace values after each update.

Below we implement ClippedDRTRL, which inherits from ParamDimVjpAlgorithm and applies element-wise clipping to the updated traces.

import jax
import jax.numpy as jnp
import brainstate
import braintrace
from braintrace._etrace_algorithms.d_rtrl import ParamDimVjpAlgorithm


class ClippedDRTRL(ParamDimVjpAlgorithm):
    """D-RTRL with trace clipping for stability.

    After each trace update, all trace values are clipped to
    [-clip_value, +clip_value] to prevent trace explosion in
    deep or long-horizon recurrent networks.
    """

    def __init__(self, model, clip_value=1.0, **kwargs):
        super().__init__(model, **kwargs)
        self.clip_value = clip_value

    def _update_etrace_data(
        self,
        running_index,
        hist_etrace_vals,
        hid2weight_jac,
        hid2hid_jac,
        weight_vals,
        input_is_multi_step,
    ):
        # Call parent's trace update (standard D-RTRL recurrence)
        new_traces = super()._update_etrace_data(
            running_index,
            hist_etrace_vals,
            hid2weight_jac,
            hid2hid_jac,
            weight_vals,
            input_is_multi_step,
        )
        # Clip the traces element-wise
        new_traces = jax.tree.map(
            lambda t: jnp.clip(t, -self.clip_value, self.clip_value),
            new_traces,
        )
        return new_traces

Now we can use ClippedDRTRL exactly like the built-in D_RTRL:

# Create a simple RNN model
model = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)
brainstate.nn.init_all_states(model)

# Instantiate our custom algorithm with clip_value=5.0
algo = ClippedDRTRL(model, clip_value=5.0)

# Compile the computation graph with a dummy input
algo.compile_graph(jnp.zeros(10))

print("ClippedDRTRL compiled successfully.")
print(f"Number of param states: {len(algo.param_states)}")
print(f"Number of hidden states: {len(algo.hidden_states)}")
ClippedDRTRL compiled successfully.
Number of param states: 1
Number of hidden states: 1
# Run a few forward steps
for step in range(5):
    out = algo(jnp.ones(10))
    print(f"Step {step}: output shape = {out.shape}")
Step 0: output shape = (32,)
Step 1: output shape = (32,)
Step 2: output shape = (32,)
Step 3: output shape = (32,)
Step 4: output shape = (32,)
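The smoke test above only checks output shapes. To see what the clip in ClippedDRTRL does to a growing trace, the sketch below runs a hypothetical scalar recurrence \(\epsilon^t = D \, \epsilon^{t-1} + x^t\) in plain JAX with \(D = 1.1\), a value chosen to force geometric growth. It illustrates element-wise clipping in isolation and is not braintrace code.

```python
import jax
import jax.numpy as jnp


def run_trace(D, xs, clip_value=None):
    """Toy scalar trace eps^t = D * eps^{t-1} + x^t, optionally clipped after each update."""
    def body(eps, x):
        eps = D * eps + x
        if clip_value is not None:  # Python-level branch, fixed at trace time
            eps = jnp.clip(eps, -clip_value, clip_value)
        return eps, eps
    _, history = jax.lax.scan(body, jnp.float32(0.0), xs)
    return history


xs = jnp.ones(50, dtype=jnp.float32)
raw = run_trace(1.1, xs)                        # |D| > 1: the trace grows geometrically
clipped = run_trace(1.1, xs, clip_value=5.0)    # same recurrence, bounded at +/- 5
print(float(raw[-1]), float(clipped.max()))
```

Clipping bounds the trace, but it also biases the resulting gradient once the bound is active, so clip_value is a tuning knob rather than a free fix.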

Understanding the Update Flow

At each time step, the algorithm performs the following sequence:

  1. Forward pass: The graph executor runs the model to produce the output, new hidden states, and new other states.

  2. Jacobian computation: The executor also computes:

    • h2w Jacobians: \(\partial h^t / \partial y^t\) together with the input \(x^t\). The first describes how the hidden state changes with the weight operation's output \(y^t\); the second determines how \(y^t\) changes with the weight.

    • h2h Jacobians: \(\partial h^t / \partial h^{t-1}\) (the recurrent Jacobian, diagonal approximation).

  3. Trace update: _update_etrace_data() uses the Jacobians and the previous traces to compute the new eligibility traces via the chosen recurrence equation.

  4. Gradient computation (on backward pass): When jax.grad or brainstate.transform.grad is applied, the custom_vjp backward pass calls _solve_weight_gradients() to combine the traces with the loss-to-hidden gradient \(\partial L / \partial h^t\) to produce parameter gradients.

The key insight is that ETraceVjpAlgorithm.update() wraps the forward pass + trace update inside a jax.custom_vjp function. This means:

  • Forward: The model runs normally, and traces are updated as a side effect.

  • Backward: Instead of backpropagating through the entire recurrent history (as BPTT would), the custom VJP uses the eligibility traces to directly produce weight gradients at the current time step.

This is what makes online learning possible: gradients are computed on-the-fly without storing the full computation history.

                 forward pass
               ┌──────────────────────────────────────────────┐
  x^t      ──▶ │                                              │
  h^{t-1}  ──▶ │  graph_executor.solve_h2w_h2h_jacobian()     │ ──▶ output, h^t, h2w_jac, h2h_jac
  weights  ──▶ │                                              │
               └──────────────────────┬───────────────────────┘
                                      ▼
                 trace update
               ┌──────────────────────────────────────────────┐
  ε^{t-1}  ──▶ │                                              │
  h2w_jac  ──▶ │  _update_etrace_data()                       │ ──▶ ε^t
  h2h_jac  ──▶ │                                              │
               └──────────────────────┬───────────────────────┘
                                      ▼
                 gradient solve (on backward pass)
               ┌──────────────────────────────────────────────┐
  dL/dh^t  ──▶ │  _solve_weight_gradients()                   │ ──▶ dL/dθ
  ε^t      ──▶ │                                              │
               └──────────────────────────────────────────────┘
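The custom_vjp mechanism itself can be demonstrated without braintrace. In the sketch below (plain JAX; the scalar cell, the constant A, and all function names are hypothetical), the forward rule advances both the hidden state and an exact trace \(\epsilon^t = \partial h^t / \partial w\), and the backward rule returns \(\partial L / \partial h^t \cdot \epsilon^t\) for the weight while sending no gradient back into \(h^{t-1}\). Summing one such backward per step recovers the BPTT gradient here only because the toy cell has a single scalar state, so its trace is exact.

```python
import jax
import jax.numpy as jnp

A = 0.8  # hypothetical fixed recurrent gain of the toy scalar cell


def _cell(w, h_prev, eps_prev, x):
    # Toy cell h^t = tanh(A * h^{t-1} + w * x^t) and its exact trace dh^t/dw.
    h = jnp.tanh(A * h_prev + w * x)
    eps = (1.0 - h ** 2) * (A * eps_prev + x)
    return h, eps


@jax.custom_vjp
def online_step(w, h_prev, eps_prev, x):
    return _cell(w, h_prev, eps_prev, x)


def _online_step_fwd(w, h_prev, eps_prev, x):
    h, eps = _cell(w, h_prev, eps_prev, x)
    return (h, eps), eps  # keep the new trace as the residual


def _online_step_bwd(eps, cotangents):
    g_h, _ = cotangents  # dL/dh^t arriving from the loss
    zero = jnp.zeros_like(eps)
    # Weight gradient read from the trace; nothing flows into h^{t-1} (no BPTT).
    return (g_h * eps, zero, zero, zero)


online_step.defvjp(_online_step_fwd, _online_step_bwd)


def step_loss(w, h_prev, eps_prev, x, y):
    h, _ = online_step(w, h_prev, eps_prev, x)
    return 0.5 * (h - y) ** 2


def bptt_loss(w, xs, ys):
    h, loss = jnp.float32(0.0), jnp.float32(0.0)
    for x, y in zip(xs, ys):
        h = jnp.tanh(A * h + w * x)
        loss = loss + 0.5 * (h - y) ** 2
    return loss


w = jnp.float32(0.5)
xs = jnp.array([1.0, -0.5, 0.3, 0.9], dtype=jnp.float32)
ys = jnp.array([0.2, 0.1, -0.3, 0.4], dtype=jnp.float32)

h, eps, g_online = jnp.float32(0.0), jnp.float32(0.0), jnp.float32(0.0)
for x, y in zip(xs, ys):
    g_online = g_online + jax.grad(step_loss)(w, h, eps, x, y)  # one backward per step
    h, eps = online_step(w, h, eps, x)                          # advance state and trace

g_bptt = jax.grad(bptt_loss)(w, xs, ys)
print(g_online, g_bptt)
```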

Accessing Eligibility Traces

After running the algorithm for a few steps, you can inspect the eligibility traces for any weight parameter. This is useful for debugging and analysis.

# Create a fresh model and algorithm
model2 = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)
brainstate.nn.init_all_states(model2)

algo2 = braintrace.D_RTRL(model2)
algo2.compile_graph(jnp.zeros(10))

# Run a few steps to build up traces
for _ in range(5):
    algo2(jnp.ones(10))

# Inspect eligibility traces for each weight parameter
weights = model2.states(brainstate.ParamState)
for path, w in weights.items():
    traces = algo2.get_etrace_of(w)
    for key, trace_val in traces.items():
        print(f"Weight {path}, trace key: {key}")
        print(f"  trace shape: {trace_val.shape}")
        print(f"  trace abs max: {jnp.abs(trace_val).max():.6f}")
Weight ('W', 'weight'), trace key: (('W', 'weight'), 130753575862208, 'hidden_group_0')
  trace shape: (42, 32, 1)
  trace abs max: 5.605252

Customizing the Gradient Solve

The _update_etrace_data() override above shapes how traces are built up. To shape how the final weight gradients are read out of those traces, override _solve_weight_gradients(). A common use is global gradient norm clipping so that one outlier weight cannot blow up the optimizer step.

The override pattern is the same as before: delegate to super()._solve_weight_gradients(...) for the standard contraction, then transform the dict it returns. The signature is:

def _solve_weight_gradients(
    self,
    running_index: int,
    etrace_h2w_at_t,                # dict[(weight_id, hidden_id), pytree]
    dl_to_hidden_groups,            # Sequence[jax.Array]: dL/dh per group
    weight_vals,                    # dict[weight_id, pytree]: current weights
    dl_to_nonetws_at_t,             # dict[path, pytree]: non-ETP grads
    dl_to_etws_at_t,                # Optional[dict[path, pytree]]: ETP shortcut grads
) -> dict[Path, PyTree]:
    ...

The returned dict maps every parameter path (ETP and non-ETP alike) to its gradient pytree.

class GradClippedDRTRL(braintrace.D_RTRL):
    """D-RTRL that clips the global gradient norm after solving."""

    def __init__(self, model, max_norm: float = 1.0, **kwargs):
        super().__init__(model, **kwargs)
        self.max_norm = max_norm

    def _solve_weight_gradients(
        self,
        running_index,
        etrace_h2w_at_t,
        dl_to_hidden_groups,
        weight_vals,
        dl_to_nonetws_at_t,
        dl_to_etws_at_t,
    ):
        # Standard contraction: traces x dL/dh -> per-weight gradients.
        grads = super()._solve_weight_gradients(
            running_index,
            etrace_h2w_at_t,
            dl_to_hidden_groups,
            weight_vals,
            dl_to_nonetws_at_t,
            dl_to_etws_at_t,
        )
        # Compute global L2 norm across all leaves of the gradient dict.
        sq = jax.tree.reduce(
            lambda acc, g: acc + jnp.sum(g * g),
            grads,
            initializer=jnp.zeros((), dtype=jnp.float32),
        )
        norm = jnp.sqrt(sq)
        scale = jnp.minimum(1.0, self.max_norm / (norm + 1e-12))
        return jax.tree.map(lambda g: g * scale, grads)


# Smoke-test that the new algorithm compiles with a dummy input.
clipped_model = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)
brainstate.nn.init_all_states(clipped_model)
clipped_algo = GradClippedDRTRL(clipped_model, max_norm=2.0)
clipped_algo.compile_graph(jnp.zeros(10))
print("GradClippedDRTRL compiled OK; max_norm =", clipped_algo.max_norm)
GradClippedDRTRL compiled OK; max_norm = 2.0
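The smoke test above only confirms compilation. The clipping arithmetic in GradClippedDRTRL does not depend on braintrace, so it can be checked in isolation on a hand-made gradient dict. The sketch below copies the same reduce-and-rescale steps into a standalone function; the function names and the dict keys are hypothetical. Gradients whose global norm is already below max_norm come back unchanged because the scale is capped at 1.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Rescale every leaf so the global L2 norm of the pytree is at most max_norm."""
    sq = jax.tree.reduce(
        lambda acc, g: acc + jnp.sum(g * g),
        grads,
        initializer=jnp.zeros((), dtype=jnp.float32),
    )
    norm = jnp.sqrt(sq)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree.map(lambda g: g * scale, grads)


def global_norm(tree):
    # L2 norm over all leaves, used here only to check the result
    return jnp.sqrt(sum(jnp.sum(g * g) for g in jax.tree.leaves(tree)))


big = {"W": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}    # global norm 5
small = {"W": jnp.array([0.3, 0.0]), "b": jnp.array([0.4])}  # global norm 0.5

clipped_big = clip_by_global_norm(big, max_norm=2.0)      # scaled by 2 / 5 = 0.4
clipped_small = clip_by_global_norm(small, max_norm=2.0)  # left unchanged
print(clipped_big["W"], clipped_big["b"], global_norm(clipped_big))
```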

Resetting Trace State Between Episodes

Eligibility traces are stored on the algorithm as EligibilityTrace instances (a thin subclass of brainstate.ShortTermState). Between epochs, sequences, or evaluation runs, you typically want to zero them so the next sequence starts from a clean state. Both D_RTRL and ES_D_RTRL expose reset_state(batch_size=None) for this.

reset_state does two things: it resets the algorithm’s running_index counter to 0, and it zeros every EligibilityTrace (re-broadcasting to the requested batch_size if given). It does not touch the model’s hidden states; call brainstate.nn.reset_all_states(model) for those. Override reset_state only when your custom algorithm carries extra state (e.g. a momentum accumulator) that must also be cleared.

# Run a few steps so the traces are non-zero, reset, then verify.
for _ in range(3):
    clipped_algo(jnp.ones(10))

# Show that traces are non-zero before reset.
weights = clipped_model.states(brainstate.ParamState)
sample_weight = next(iter(weights.values()))
sample_trace = next(iter(clipped_algo.get_etrace_of(sample_weight).values()))
print(f"Trace abs-max before reset: {jnp.abs(sample_trace).max():.6f}")

# Reset and verify the traces are now zero (and running_index == 0).
clipped_algo.reset_state()
sample_trace = next(iter(clipped_algo.get_etrace_of(sample_weight).values()))
print(f"Trace abs-max after  reset: {jnp.abs(sample_trace).max():.6f}")
print(f"running_index after  reset: {int(clipped_algo.running_index.value)}")
Trace abs-max before reset: 2.913480
Trace abs-max after  reset: 0.000000
running_index after  reset: 0

Summary

This tutorial covered the architecture and extension points for developing custom online learning algorithms in braintrace.

Key takeaways:

  • Extend ParamDimVjpAlgorithm for custom algorithms that maintain full-dimensional traces (like D-RTRL). Extend IODimVjpAlgorithm for custom algorithms that use factorized traces (like ES-D-RTRL).

  • Override _update_etrace_data() to implement custom trace dynamics (e.g., clipping, normalization, decay schedules).

  • Override _solve_weight_gradients() to transform the gradient dict produced by the standard contraction, e.g. for global gradient-norm clipping (GradClippedDRTRL above), per-layer scaling, or momentum.

  • The graph compilation infrastructure is shared across all algorithms. You do not need to re-implement model tracing, Jacobian computation, or state management, only the trace update and gradient computation.

  • reset_state(batch_size=None) zeros every EligibilityTrace and resets running_index to 0; override it only when your algorithm carries extra state that must also be cleared. Override init_etrace_state() if your algorithm needs new trace-storage shapes at compile time.

  • Use get_etrace_of(weight) to inspect trace values at any point during training, which is valuable for debugging and research analysis.