Limitations & Workarounds#

Introduction#

braintrace analyzes the model’s Jaxpr (JAX’s intermediate representation) at compile time to automatically derive eligibility trace update rules. This compilation process walks through the traced computation graph to identify the relationships between hidden states, parameters, and the operations that connect them.

However, some JAX operations create sub-Jaxprs – separate, nested computation graphs – that the braintrace compiler cannot traverse. When such operations appear inside the model’s update() method, the compiler loses visibility into the computation and cannot correctly construct the eligibility trace graph.

Understanding these limitations helps you design models that are fully compatible with braintrace’s online learning compilation. This tutorial covers the known limitations and provides practical workarounds for each.

Unsupported JAX Primitives Inside the Model#

The following JAX control flow primitives are NOT supported inside the model’s update() method:

Primitive            Description                            Why it fails
-------------------  -------------------------------------  ------------------------------
jax.lax.cond         Conditional execution (if/else)        Creates two branch sub-Jaxprs
jax.lax.scan         Loop with carry state                  Creates a body sub-Jaxpr
jax.lax.while_loop   General loops                          Creates cond + body sub-Jaxprs
jax.vmap             Vectorized map (nested inside model)   Creates a mapped sub-Jaxpr

Each of these constructs introduces a sub-Jaxpr that the braintrace compiler cannot analyze. When the compiler encounters one of these primitives during graph construction, it will raise a NotSupportedError or CompilationError.

Important note: These primitives can still be used outside of the model’s update() method. For example, using jax.lax.scan to unroll the model over time steps is perfectly fine – the restriction only applies to operations within the traced computation that connects hidden states to parameters.
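As a rough illustration of the note above, the sketch below keeps the loop outside the per-step update. The step function and the weight w are hypothetical stand-ins for a model's update(); braintrace itself is not used, so this only shows where the scan lives, not how a braintrace algorithm object is driven.

```python
import jax
import jax.numpy as jnp

# Hypothetical recurrent weight; stands in for a model parameter.
w = jnp.eye(3) * 0.5


def step(h, x):
    # Flat single-step update (no control-flow primitives inside).
    h_new = jnp.tanh(h @ w + x)
    return h_new, h_new  # (carry, per-step output)


xs = jnp.ones((5, 3))  # 5 time steps, 3 features
h0 = jnp.zeros(3)

# The loop over time lives OUTSIDE the step function.
h_final, hs = jax.lax.scan(step, h0, xs)

# Reference: the same unrolling written as a plain Python loop.
h_loop = h0
for t in range(xs.shape[0]):
    h_loop, _ = step(h_loop, xs[t])

print(hs.shape)
```

The scan and the Python loop compute the same final state; only the inside of step would be subject to the restrictions in the table.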

Example of Unsupported Code#

The following model uses jax.lax.cond inside its update() method. This will cause a compilation error because the conditional branches create sub-Jaxprs that the compiler cannot traverse.

import jax
import jax.numpy as jnp
import brainstate
import braintrace


# THIS WILL NOT WORK: using jax.lax.cond inside update()
class BadModel(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 10)))
        self.h = brainstate.HiddenState(jnp.zeros(10))

    def update(self, x):
        # BAD: jax.lax.cond creates a sub-Jaxpr that the compiler cannot analyze
        self.h.value = jax.lax.cond(
            jnp.sum(x) > 0,
            lambda: jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x),
            lambda: self.h.value,
        )
        return self.h.value

The compiler fails because when it traces update(), it sees a cond primitive whose true/false branches are opaque sub-Jaxprs. The braintrace.matmul call is hidden inside one of those branches, so the compiler cannot discover the relationship between self.w and self.h.
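This opacity can be observed with plain JAX, without braintrace. The sketch below (function names are illustrative) lists the top-level primitives of two functions that compute the same result: in the cond version the tanh is hidden inside a branch, while in the branch-free version it appears at the top level of the Jaxpr.

```python
import jax
import jax.numpy as jnp


def with_cond(x):
    # tanh lives inside the true branch, i.e. inside a sub-Jaxpr.
    return jax.lax.cond(jnp.sum(x) > 0, jax.lax.tanh, lambda v: v, x)


def with_select(x):
    # Both candidates are computed up front; select_n picks one.
    return jax.lax.select(jnp.sum(x) > 0, jax.lax.tanh(x), x)


def top_level_primitives(fn, *args):
    # Names of the primitives in the outermost Jaxpr only.
    closed = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


x = jnp.ones(4)
cond_prims = top_level_primitives(with_cond, x)
select_prims = top_level_primitives(with_select, x)
print(cond_prims)
print(select_prims)
```

A pass that only walks the outermost equations sees a cond in the first case and never encounters the tanh; in the second case every operation is visible.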

Workarounds for Conditional Logic#

When you need branch-like behaviour without a cond primitive, the goal is to choose between values without producing a sub-Jaxpr that the compiler will see in the hidden-state path.

Strategy 1: jax.lax.select#

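jax.lax.select(predicate, on_true, on_false) is the lowest-level branch-free selection operator. It compiles directly to the select_n primitive – no jit, no cond, no sub-Jaxpr. Use it whenever the body of update() needs to pick between two precomputed values. Unlike jnp.where, it does not broadcast: the predicate must be a scalar or have the same shape as the operands, and the two operands must share shape and dtype.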

Note: in current JAX versions, jnp.where is wrapped in a jit of _where and the compiler treats that as a forbidden sub-Jaxpr when the result feeds a hidden state. Prefer jax.lax.select inside update().

# CORRECT: use jax.lax.select (no sub-Jaxpr) instead of jnp.where (which now jits).
class GoodModel(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 10)))
        self.h = brainstate.HiddenState(jnp.zeros(10))

    def update(self, x):
        new_h = jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x)
        # jax.lax.select compiles to a single select_n primitive -- the
        # compiler can trace right through it.
        self.h.value = jax.lax.select(jnp.sum(x) > 0, new_h, self.h.value)
        return self.h.value


model = GoodModel()
brainstate.nn.init_all_states(model)
algo = braintrace.D_RTRL(model)
algo.compile_graph(jnp.zeros(10))  # works
print("Compilation successful.")
Compilation successful.

Strategy 2: Multiplication by a mask#

For gating-style conditional logic, you can multiply by a binary mask instead of branching. This is particularly natural for spiking neural networks where spike masks are already available.

# Instead of: jax.lax.cond(spike, lambda: reset_value, lambda: current_value)
# Use:        current_value * (1 - spike) + reset_value * spike
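A minimal sketch of the mask form, assuming a membrane potential v, a threshold of 1.0 and a reset value of 0.0 (all illustrative, not taken from braintrace). It also checks that the result matches jax.lax.select on the same data.

```python
import jax
import jax.numpy as jnp


def reset_with_mask(v, spike, v_reset):
    # spike is a 0/1 mask with the same dtype as v.
    return v * (1.0 - spike) + v_reset * spike


v = jnp.array([0.2, 1.3, 0.7, 1.1])
spike = (v >= 1.0).astype(v.dtype)  # neurons 1 and 3 spike

v_masked = reset_with_mask(v, spike, 0.0)
v_select = jax.lax.select(spike > 0, jnp.zeros_like(v), v)
print(v_masked)
```

One design caveat: the mask form multiplies the unselected value by zero rather than discarding it, so an inf or NaN in the unselected operand still propagates (0 * inf is NaN). jax.lax.select does not have this problem in the forward pass.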

Shape Compatibility Requirements#

The braintrace compiler requires that the output of an ETP primitive (e.g., braintrace.matmul) be shape-compatible with the target hidden state. “Compatible” means the shapes must match exactly or be broadcastable to each other.

The compiler checks this during relation construction: after identifying an ETP primitive and its associated weight, it traces forward through the Jaxpr to find reachable hidden-state output variables and filters by shape compatibility.

If the output of an ETP primitive passes through a shape-changing operation (such as slicing, indexing, or reshaping to an incompatible shape) before reaching the hidden state, the compiler will not be able to establish the connection.

# Shape mismatch example -- the weight won't be connected to the hidden state
class ShapeMismatch(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 20)))  # outputs dim 20
        self.h = brainstate.HiddenState(jnp.zeros(10))       # hidden dim 10

    def update(self, x):
        # The output shape (20,) doesn't match hidden shape (10,)
        # This weight won't be connected to the hidden state
        y = braintrace.matmul(x, self.w.value)
        self.h.value = y[:10]  # slicing breaks the connection
        return self.h.value

In this example, braintrace.matmul(x, self.w.value) produces a vector of dimension 20, but the hidden state self.h has dimension 10. The slicing operation y[:10] is not a simple broadcast – it fundamentally changes the shape, breaking the connection between the weight and the hidden state in the compiled graph.

Fix: Ensure that the weight matrix dimensions produce outputs that match the hidden state dimensions directly:

self.w = brainstate.ParamState(jnp.ones((10, 10)))  # outputs dim 10 to match hidden dim 10
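The "match exactly or broadcast" rule can be approximated before compiling with a small pre-flight check. This is only a sketch of the stated rule using NumPy's broadcasting semantics; shapes_compatible is a hypothetical helper, not the compiler's actual implementation.

```python
import numpy as np


def shapes_compatible(output_shape, hidden_shape):
    """Return True if the two shapes are equal or mutually broadcastable."""
    try:
        np.broadcast_shapes(output_shape, hidden_shape)
    except ValueError:
        return False
    return True


print(shapes_compatible((20,), (10,)))      # False: the ShapeMismatch case
print(shapes_compatible((10,), (10,)))      # True: the fixed (10, 10) weight
print(shapes_compatible((1, 10), (4, 10)))  # True: broadcastable
```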

The “Weight -> Weight -> Hidden” Invariant#

ETP rules are local to a single primitive: each primitive’s xy_to_dw rule assumes its x is externally-supplied data, and its yw_to_w rule assumes the path from this primitive’s output y to a hidden state h contains no other trainable ETP weights. If the output of a primitive W1 flows through another non-gradient-enabled ETP primitive W2 before reaching h, the per-primitive rules cannot soundly account for W1: including it would either double-count the contribution from W2 or treat W2’s x as raw data when it is actually a function of W1.

The compiler enforces this by excluding W1 from the relation list whenever its only path to h passes through another non-gradient-enabled ETP primitive. The diagnostic kind is RELATION_EXCLUDED_WEIGHT_TO_WEIGHT and a UserWarning is emitted at compile time. The excluded weight is still trainable – but only via BPTT, not via online learning.

The classic example is braintrace.nn.GRUCell. It has three internal Linear layers (Wz, Wr, Wh), but the compiler records only two ETP relations:

  • Wz – output flows directly into the new hidden state. Included.

  • Wh – output flows directly into the new hidden state. Included.

  • Wr – output is consumed by Wh’s matmul (it gates r * old_h). Excluded with a RELATION_EXCLUDED_WEIGHT_TO_WEIGHT warning.

This is correct: Wr’s contribution to dL/dh is already implicit in Wh’s gradient (because Wh’s input depends on Wr), so adding Wr separately would double-count. To learn Wr online with this architecture, you would need to bundle Wr and Wh together – something per-primitive ETP cannot express.

When to use gradient_enabled=True#

The single exception is etp_elemwise_p – the only built-in primitive registered with gradient_enabled=True. Element-wise ops (gating biases, learnable thresholds, learnable time constants) are identity-like enough that they may sit on the tail of the y -> h walk without breaking the per-primitive assumption: an upstream ETP weight whose output passes through an element-wise op is still recorded as a relation.

When registering a custom primitive, leave gradient_enabled at its default False. Set it to True only if your primitive is genuinely identity-like (single weight, no input multiplication, the xy_to_dw rule is essentially a passthrough). Setting gradient_enabled=True on a trainable op – one with both an x and a w – silently re-enables the unsound double-counting and will produce wrong gradients.

# Concrete demonstration: GRUCell yields 2 ETP relations, not 3.
import jax.numpy as jnp
import brainstate
import braintrace

cell = braintrace.nn.GRUCell(in_size=4, out_size=8)
brainstate.nn.init_all_states(cell, batch_size=2)

graph = braintrace.compile_etrace_graph(cell, jnp.zeros((2, 4)))

print(f"GRUCell has {len(list(cell.states(brainstate.ParamState)))} ParamStates")
print(f"but the compiler recorded only {len(graph.hidden_param_op_relations)} ETP relations:")
for r in graph.hidden_param_op_relations:
    print(f"  - {r.weight_path}  (primitive: {r.primitive.name})")

# The third weight (Wr) shows up as a RELATION_EXCLUDED_WEIGHT_TO_WEIGHT diagnostic.
from braintrace import DiagnosticKind
excluded = [
    d for d in graph.diagnostics
    if d.kind == DiagnosticKind.RELATION_EXCLUDED_WEIGHT_TO_WEIGHT
]
print(f"\nWeight->weight exclusions: {len(excluded)}")
for d in excluded:
    print(f"  - {d.weight_path}: {d.message}")
GRUCell has 3 ParamStates
but the compiler recorded only 2 ETP relations:
  - ('Wz', 'weight')  (primitive: etp_mm)
  - ('Wh', 'weight')  (primitive: etp_mm)

Weight->weight exclusions: 1
  - ('Wr', 'weight'): ETP primitive etp_mm (weight=('Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mm). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.

Performance Considerations#

Different online learning algorithms in braintrace have different memory and computational requirements. Choosing the right algorithm is important for scaling to larger models.

Memory complexity comparison#

Algorithm  Memory per weight                          Total memory        Description
---------  -----------------------------------------  ------------------  --------------------------------
D_RTRL     O(B * weight_size * hidden_size)           O(B * |theta| * H)  Full eligibility traces
ES_D_RTRL  O(B * (in_size + out_size) * hidden_size)  O(B * (I+O) * H)    Factored eligibility traces
BPTT       O(T * model_size)                          O(T * N)            Stores all activations over time

Where B = batch size, H = hidden state dimension, T = sequence length, N = total model size, I = input size, O = output size, and |theta| = total number of trainable parameters.
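To get a feel for the gap between the first two rows, the following back-of-the-envelope sketch treats the big-O expressions as exact element counts for a single weight matrix. The sizes are arbitrary example values and the helper names are hypothetical; real usage will differ by constant factors.

```python
def d_rtrl_trace_elems(batch, in_size, out_size, hidden_size):
    # Full trace: B * weight_size * hidden_size
    weight_size = in_size * out_size
    return batch * weight_size * hidden_size


def es_d_rtrl_trace_elems(batch, in_size, out_size, hidden_size):
    # Factored trace: B * (in_size + out_size) * hidden_size
    return batch * (in_size + out_size) * hidden_size


B, I, O, H = 16, 256, 256, 256  # example sizes only
BYTES_PER_FLOAT32 = 4

full = d_rtrl_trace_elems(B, I, O, H)
factored = es_d_rtrl_trace_elems(B, I, O, H)

print(f"D_RTRL:    {full:,} elements, {full * BYTES_PER_FLOAT32 / 2**20:.0f} MiB")
print(f"ES_D_RTRL: {factored:,} elements, {factored * BYTES_PER_FLOAT32 / 2**20:.0f} MiB")
print(f"ratio:     {full // factored}x")
```

For these sizes the estimate is 1024 MiB against 8 MiB in float32, a factor of 128 for one 256 x 256 weight.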

Key tradeoffs#

  • D_RTRL provides exact online gradients but can be memory-intensive for large weight matrices. The eligibility trace for each weight matrix has shape (weight_size, hidden_size), so a single square H x H recurrent weight already needs H^3 trace entries per batch element.

  • ES_D_RTRL (factored / IO-dimension algorithm) trades gradient accuracy for memory efficiency. Instead of storing the full eligibility trace, it factors the trace into input-dimension and output-dimension components, reducing memory from O(weight_size * hidden_size) to O((in_size + out_size) * hidden_size).

  • BPTT (Backpropagation Through Time) stores all intermediate activations over the unrolled time steps. Memory grows linearly with sequence length T, which can be prohibitive for long sequences.

Recommendations for large models#

  • Use ES_D_RTRL instead of D_RTRL when weight matrices are large

  • Reduce hidden state dimensions where possible

  • Use sparse operations (braintrace.sparse_matmul) to reduce the number of parameters

  • Consider using braintrace.lora_matmul for low-rank weight updates

Compilation Time#

The braintrace compiler performs several steps when compile_graph() is called:

  1. Jaxpr tracing: JAX traces the model’s update() method to produce a Jaxpr

  2. Relation discovery: The compiler walks the Jaxpr to find ETP primitives, trace weight origins, and connect them to hidden states

  3. Graph construction: The eligibility trace computation graph is built from the discovered relations

This compilation can be slow for complex models, especially on the first call. However:

  • Subsequent calls with the same input shapes reuse the compiled graph. The compilation result is cached, so you only pay the cost once.

  • compile_graph() should be called once before the training loop, not inside it. Calling it repeatedly with the same shapes is harmless (it detects the cache hit), but calling it inside a loop adds unnecessary overhead.

# Good: compile once, then run many steps
algo = braintrace.D_RTRL(model)
algo.compile_graph(example_input)

for step in range(num_steps):
    output = algo(input_data[step])  # uses cached compilation

What CAN Be Used Inside update()#

The braintrace compiler works with all standard JAX mathematical operations that do not create sub-Jaxprs. These include:

Standard math operations:

  • jnp.add, jnp.subtract, jnp.multiply, jnp.divide

  • Element-wise operators: +, -, *, /

Matrix operations:

  • @ (matrix multiply operator)

  • jnp.dot, jnp.matmul, jnp.einsum

Activation functions:

  • jax.nn.tanh, jax.nn.relu, jax.nn.sigmoid, jax.nn.softmax

  • jax.nn.silu, jax.nn.gelu, jax.nn.leaky_relu

Shape manipulation:

  • jnp.reshape, jnp.transpose, jnp.concatenate

  • jnp.expand_dims, jnp.squeeze

Selection and masking:

  • jax.lax.select(predicate, on_true, on_false) (preferred over jnp.where inside update(); see Workarounds above)

  • jnp.clip, jnp.maximum, jnp.minimum

Gradient control:

  • jax.lax.stop_gradient – useful for detaching parts of the computation

braintrace ETP primitives:

  • braintrace.matmul – matrix multiplication with ETP tracking

  • braintrace.element_wise – element-wise parameter operations with ETP tracking

  • braintrace.conv – convolution with ETP tracking

  • braintrace.sparse_matmul – sparse matrix multiplication with ETP tracking

  • braintrace.lora_matmul – LoRA-style low-rank multiplication with ETP tracking

In general, if a JAX operation compiles to a flat sequence of primitives in the Jaxpr (no nested sub-Jaxprs), it is compatible with braintrace.
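One way to check the "flat sequence of primitives" condition for a candidate function is to inspect its Jaxpr directly. The helper below is a sketch in plain JAX: nested_primitives is a hypothetical name, and it uses duck typing to spot equation parameters that carry a Jaxpr. It reports what JAX produces; which of those primitives braintrace actually rejects is decided by braintrace, not by this check.

```python
import jax
import jax.numpy as jnp


def nested_primitives(fn, *args):
    """Names of top-level primitives whose params contain a sub-Jaxpr."""
    closed = jax.make_jaxpr(fn)(*args)
    found = []
    for eqn in closed.jaxpr.eqns:
        for value in eqn.params.values():
            items = value if isinstance(value, (tuple, list)) else (value,)
            # Jaxpr objects expose .eqns; ClosedJaxpr objects expose .jaxpr.
            if any(hasattr(v, "eqns") or hasattr(v, "jaxpr") for v in items):
                found.append(eqn.primitive.name)
                break
    return found


def flat_fn(x):
    return jax.lax.mul(jax.lax.tanh(x), x)


def scan_fn(x):
    def body(carry, _):
        return jax.lax.tanh(carry), None
    out, _ = jax.lax.scan(body, x, None, length=3)
    return out


def while_fn(x):
    return jax.lax.while_loop(lambda v: jnp.sum(v) < 10.0, lambda v: v + 1.0, x)


x = jnp.ones(4)
print(nested_primitives(flat_fn, x))
print(nested_primitives(scan_fn, x))
print(nested_primitives(while_fn, x))
```

An empty list means the outermost Jaxpr has no nested computation graphs; any reported name marks an operation worth rewriting or moving outside update().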

Summary#

The key limitations and their workarounds are:

  1. Avoid cond, scan, while_loop, and nested vmap inside the model’s update() method. These create sub-Jaxprs that the compiler cannot traverse. Use them freely outside the model (e.g., for time-step unrolling).

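  2. Use jax.lax.select and masks as alternatives to conditional logic. Element-wise selection is fully supported and produces equivalent results for most use cases. Avoid jnp.where inside update(), since current JAX versions wrap it in a jit that the compiler treats as a sub-Jaxpr.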

  3. Ensure shape compatibility between ETP primitive outputs and hidden states. The compiler filters connections by shape – if shapes don’t match or broadcast, the connection won’t be established.

  4. Per-primitive ETP rules are local. A weight whose only path to a hidden state passes through another trainable ETP primitive is excluded with a RELATION_EXCLUDED_WEIGHT_TO_WEIGHT warning – it must be learned via BPTT or the architecture must be rewired. etp_elemwise_p (the only gradient_enabled=True built-in) is the sole exception.

  5. Choose the right algorithm based on memory/accuracy tradeoffs. Use D_RTRL for exact gradients with moderate model sizes, and ES_D_RTRL for memory-efficient approximate gradients with larger models.

  6. Call compile_graph() once before training, not inside the training loop. The compiled graph is cached and reused for inputs of the same shape.

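  7. The compiler works with all standard JAX mathematical operations that do not create sub-Jaxprs. Avoiding the unsupported control flow primitives listed above is what lets the model compile; the shape-compatibility and weight-to-weight rules then determine which weights are actually connected to hidden states.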