Limitations & Workarounds
Introduction
braintrace analyzes the model’s Jaxpr (JAX’s intermediate representation) at compile time to automatically derive eligibility trace update rules. This compilation process walks through the traced computation graph to identify the relationships between hidden states, parameters, and the operations that connect them.
However, some JAX operations create sub-Jaxprs – separate, nested computation graphs – that the braintrace compiler cannot traverse. When such operations appear inside the model’s update() method, the compiler loses visibility into the computation and cannot correctly construct the eligibility trace graph.
Understanding these limitations helps you design models that are fully compatible with braintrace’s online learning compilation. This tutorial covers the known limitations and provides practical workarounds for each.
Unsupported JAX Primitives Inside the Model
The following JAX control flow primitives are NOT supported inside the model’s update() method:
| Primitive | Description | Why it fails |
|---|---|---|
| `jax.lax.cond` | Conditional execution (if/else) | Creates two branch sub-Jaxprs |
| `jax.lax.scan` | Loop with carry state | Creates a body sub-Jaxpr |
| `jax.lax.while_loop` | General loops | Creates cond + body sub-Jaxprs |
| `jax.vmap` | Vectorized map (nested inside model) | Creates a mapped sub-Jaxpr |
Each of these constructs introduces a sub-Jaxpr that the braintrace compiler cannot analyze. When the compiler encounters one of these primitives during graph construction, it will raise a NotSupportedError or CompilationError.
Important note: These primitives can still be used outside of the model’s update() method. For example, using jax.lax.scan to unroll the model over time steps is perfectly fine – the restriction only applies to operations within the traced computation that connects hidden states to parameters.
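The inside/outside distinction can be checked with plain JAX. In the sketch below, a hypothetical `step` function stands in for a model's `update()` and `jax.lax.dot` stands in for `braintrace.matmul`; it does not call braintrace at all, it only lists the primitives that `jax.make_jaxpr` records at the top level of each function.

```python
import jax
import jax.numpy as jnp


def step(h, w, x):
    # Stand-in for update(): flat arithmetic only, no control flow.
    return jax.lax.tanh(jax.lax.dot(h, w) + x)


def run(h0, w, xs):
    # Time-step unrolling with scan happens OUTSIDE the step function.
    def body(h, x):
        h = step(h, w, x)
        return h, h

    _, hs = jax.lax.scan(body, h0, xs)
    return hs


def primitive_names(fn, *args):
    # Names of the primitives recorded at the top level of fn's Jaxpr.
    return [eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns]


h0 = jnp.zeros(10)
w = jnp.ones((10, 10))
xs = jnp.ones((5, 10))  # 5 time steps of 10-dimensional input

print(primitive_names(step, h0, w, xs[0]))  # no 'scan', 'cond', or 'while'
print(primitive_names(run, h0, w, xs))      # contains 'scan'
```

The `scan` primitive appears only in `run`, the time-step driver; the Jaxpr of `step` itself is a flat sequence of primitives, which is the form the restriction asks for.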
Example of Unsupported Code
The following model uses jax.lax.cond inside its update() method. This will cause a compilation error because the conditional branches create sub-Jaxprs that the compiler cannot traverse.
```python
import jax
import jax.numpy as jnp
import brainstate
import braintrace


# THIS WILL NOT WORK: using jax.lax.cond inside update()
class BadModel(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 10)))
        self.h = brainstate.HiddenState(jnp.zeros(10))

    def update(self, x):
        # BAD: jax.lax.cond creates a sub-Jaxpr that the compiler cannot analyze
        self.h.value = jax.lax.cond(
            jnp.sum(x) > 0,
            lambda: jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x),
            lambda: self.h.value,
        )
        return self.h.value


# The error surfaces when the graph is compiled:
# model = BadModel()
# brainstate.nn.init_all_states(model)
# braintrace.D_RTRL(model).compile_graph(jnp.zeros(10))
```
The compiler fails because when it traces update(), it sees a cond primitive whose true/false branches are opaque sub-Jaxprs. The braintrace.matmul call is hidden inside one of those branches, so the compiler cannot discover the relationship between self.w and self.h.
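The same hiding effect is visible in plain JAX. The sketch below is a pure-JAX analogue of `BadModel.update`, not braintrace itself, with `jax.lax.dot` standing in for `braintrace.matmul`; it compares the primitives at the top level of the Jaxpr with those inside the `cond` branches.

```python
import jax
import jax.numpy as jnp


def update_with_cond(h, w, x):
    # Pure-JAX analogue of BadModel.update; jax.lax.dot stands in for braintrace.matmul.
    return jax.lax.cond(
        jnp.sum(x) > 0,
        lambda: jax.lax.tanh(jax.lax.dot(h, w) + x),
        lambda: h,
    )


closed = jax.make_jaxpr(update_with_cond)(jnp.zeros(10), jnp.ones((10, 10)), jnp.ones(10))
top_level = [eqn.primitive.name for eqn in closed.jaxpr.eqns]

# The cond equation stores each branch as its own ClosedJaxpr in params["branches"].
cond_eqn = next(eqn for eqn in closed.jaxpr.eqns if eqn.primitive.name == "cond")
branch_prims = [[e.primitive.name for e in b.jaxpr.eqns] for b in cond_eqn.params["branches"]]

print(top_level)     # includes 'cond' but not 'dot_general'
print(branch_prims)  # 'dot_general' appears only inside a branch
```

A compiler that scans only the top-level equations never sees the matrix multiply, so it has nothing to attach the weight to.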
Workarounds for Conditional Logic
When you need branch-like behaviour without a cond primitive, the goal is to choose between values without producing a sub-Jaxpr that the compiler will see in the hidden-state path.
Strategy 1: jax.lax.select
jax.lax.select(predicate, on_true, on_false) is the lowest-level branch-free selection operator. It compiles directly to the select_n primitive – no jit, no cond, no sub-Jaxpr. Use it whenever the body of update() needs to pick between two precomputed values.
Note: in current JAX versions, `jnp.where` is wrapped in a `jit` of `_where`, and the compiler treats that as a forbidden sub-Jaxpr when the result feeds a hidden state. Prefer `jax.lax.select` inside `update()`.
```python
# CORRECT: use jax.lax.select (no sub-Jaxpr) instead of jnp.where (which now jits).
class GoodModel(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 10)))
        self.h = brainstate.HiddenState(jnp.zeros(10))

    def update(self, x):
        new_h = jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x)
        # jax.lax.select compiles to a single select_n primitive -- the
        # compiler can trace right through it.
        self.h.value = jax.lax.select(jnp.sum(x) > 0, new_h, self.h.value)
        return self.h.value


model = GoodModel()
brainstate.nn.init_all_states(model)
algo = braintrace.D_RTRL(model)
algo.compile_graph(jnp.zeros(10))  # works
print("Compilation successful.")
```
Compilation successful.
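For contrast, a pure-JAX analogue of `GoodModel.update` (again with `jax.lax.dot` as a hypothetical stand-in for `braintrace.matmul`) keeps both the matrix multiply and the selection at the top level of the Jaxpr, with no `cond` equation.

```python
import jax
import jax.numpy as jnp


def update_with_select(h, w, x):
    # Pure-JAX analogue of GoodModel.update; jax.lax.dot stands in for braintrace.matmul.
    new_h = jax.lax.tanh(jax.lax.dot(h, w) + x)
    return jax.lax.select(jnp.sum(x) > 0, new_h, h)


closed = jax.make_jaxpr(update_with_select)(jnp.zeros(10), jnp.ones((10, 10)), jnp.ones(10))
top_level = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(top_level)  # 'dot_general', 'tanh', and 'select_n' all sit at the top level; no 'cond'
```

Both candidate values are computed unconditionally and `select_n` only picks between them, which is why the path from the weight to the state stays visible.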
Strategy 2: Multiplication by a mask
For gating-style conditional logic, you can multiply by a binary mask instead of branching. This is particularly natural for spiking neural networks where spike masks are already available.
```python
# Instead of: jax.lax.cond(spike, lambda: reset_value, lambda: current_value)
# Use:        current_value * (1 - spike) + reset_value * spike
```
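As a concrete sketch of the mask form, the reset rule below assumes a membrane-like value `v`, a float 0/1 `spike` mask, and a scalar `v_reset` (all hypothetical names), and checks that it agrees with the `jax.lax.select` form.

```python
import jax
import jax.numpy as jnp


def reset_with_mask(v, spike, v_reset):
    # spike is a float 0/1 mask: keep v where spike == 0, take v_reset where spike == 1
    return v * (1.0 - spike) + v_reset * spike


def reset_with_select(v, spike, v_reset):
    # Same gating expressed with jax.lax.select (operands must share shape and dtype)
    return jax.lax.select(spike > 0.5, jnp.full_like(v, v_reset), v)


v = jnp.array([0.2, 1.3, -0.4, 2.0])
spike = (v > 1.0).astype(v.dtype)  # [0., 1., 0., 1.]

print(reset_with_mask(v, spike, 0.0))
print(reset_with_select(v, spike, 0.0))
```

One behavioural difference is worth knowing: if `v` can hold `inf` or `NaN`, the mask form yields `NaN` at the reset positions (because `inf * 0` is `NaN`), whereas `select` simply discards the unchosen operand.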
Shape Compatibility Requirements
The braintrace compiler requires that the output of an ETP primitive (e.g., braintrace.matmul) be shape-compatible with the target hidden state. “Compatible” means the shapes must match exactly or be broadcastable to each other.
The compiler checks this during relation construction: after identifying an ETP primitive and its associated weight, it traces forward through the Jaxpr to find reachable hidden-state output variables and filters by shape compatibility.
If the output of an ETP primitive passes through a shape-changing operation (such as slicing, indexing, or reshaping to an incompatible shape) before reaching the hidden state, the compiler will not be able to establish the connection.
```python
# Shape mismatch example -- the weight won't be connected to the hidden state
class ShapeMismatch(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = brainstate.ParamState(jnp.ones((10, 20)))  # outputs dim 20
        self.h = brainstate.HiddenState(jnp.zeros(10))       # hidden dim 10

    def update(self, x):
        # The output shape (20,) doesn't match hidden shape (10,)
        # This weight won't be connected to the hidden state
        y = braintrace.matmul(x, self.w.value)
        self.h.value = y[:10]  # slicing breaks the connection
        return self.h.value
```
In this example, braintrace.matmul(x, self.w.value) produces a vector of dimension 20, but the hidden state self.h has dimension 10. The slicing operation y[:10] is not a simple broadcast – it fundamentally changes the shape, breaking the connection between the weight and the hidden state in the compiled graph.
Fix: Ensure that the weight matrix dimensions produce outputs that match the hidden state dimensions directly:
```python
self.w = brainstate.ParamState(jnp.ones((10, 10)))  # outputs dim 10 to match hidden dim 10
```
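The relation-construction check described in this section can be imitated on a plain Jaxpr. The sketch below is a toy, not braintrace's compiler: it treats `dot_general` (produced by `jax.lax.dot`) as a hypothetical stand-in for an ETP primitive, treats function outputs as hidden states, follows each ETP output forward through later equations, and keeps only outputs whose shapes pass the NumPy broadcasting rule.

```python
import jax
import jax.numpy as jnp


def broadcast_compatible(a, b):
    """NumPy-style broadcast rule: aligned trailing dims must be equal or 1."""
    for da, db in zip(reversed(a), reversed(b)):
        if da != db and 1 not in (da, db):
            return False
    return True


def connected_outputs(fn, *args, etp_prim="dot_general"):
    """Toy relation discovery: for each etp_prim equation, return the indices
    of fn's outputs that its result reaches AND that are shape-compatible."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    relations = []
    for i, eqn in enumerate(jaxpr.eqns):
        if eqn.primitive.name != etp_prim:
            continue
        etp_shape = eqn.outvars[0].aval.shape
        reached = {id(v) for v in eqn.outvars}  # track variables by identity
        for later in jaxpr.eqns[i + 1:]:
            if any(id(v) in reached for v in later.invars):
                reached.update(id(v) for v in later.outvars)
        relations.append([
            k for k, out in enumerate(jaxpr.outvars)
            if id(out) in reached and broadcast_compatible(etp_shape, out.aval.shape)
        ])
    return relations


x = jnp.ones(10)


def sliced(x, w):  # mirrors ShapeMismatch: a (20,) result sliced down to (10,)
    return jax.lax.dot(x, w)[:10]


def matched(x, w):  # mirrors the fix: a (10,) result feeding a (10,) state
    return jax.lax.tanh(jax.lax.dot(x, w) + x)


print(connected_outputs(sliced, x, jnp.ones((10, 20))))   # [[]]: no relation kept
print(connected_outputs(matched, x, jnp.ones((10, 10))))  # [[0]]: output 0 is connected
```

The sliced output is still reachable from the `dot_general` result, but the (20,) and (10,) shapes fail the broadcast test, so the toy drops the relation, mirroring how `ShapeMismatch` loses its weight-to-state connection.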
Performance Considerations
Different online learning algorithms in braintrace have different memory and computational requirements. Choosing the right algorithm is important for scaling to larger models.
Memory complexity comparison
| Algorithm | Memory per weight | Total memory | Description |
|---|---|---|---|
| `D_RTRL` | O(B * weight_size * hidden_size) | O(B * \|theta\| * H) | Full eligibility traces |
| `ES_D_RTRL` | O(B * (in_size + out_size) * hidden_size) | O(B * (I+O) * H) | Factored eligibility traces |
| BPTT | O(T * model_size) | O(T * N) | Stores all activations over time |
Where B = batch size, H = hidden state dimension, T = sequence length, N = total model size, I = input size, O = output size, and |theta| = total number of trainable parameters.
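Plugging hypothetical sizes into the per-weight formulas from the table shows how quickly the gap opens. The sketch counts trace elements for a single 256-by-256 weight with batch 32 and a 256-dimensional hidden state; these are raw products of the table's terms, and the constant factors hidden by the O(·) notation in braintrace's actual buffers may differ.

```python
def trace_elements(batch, in_size, out_size, hidden_size):
    """Raw element counts for one weight matrix, using the table's per-weight terms."""
    weight_size = in_size * out_size
    full = batch * weight_size * hidden_size               # D_RTRL row
    factored = batch * (in_size + out_size) * hidden_size  # ES_D_RTRL row
    return full, factored


# Hypothetical sizes: batch 32, a 256-by-256 weight, 256-dimensional hidden state.
full, factored = trace_elements(batch=32, in_size=256, out_size=256, hidden_size=256)
bytes_per_float32 = 4
print(f"D_RTRL:    {full:,} elements, about {full * bytes_per_float32 / 2**30:.2f} GiB")
print(f"ES_D_RTRL: {factored:,} elements, about {factored * bytes_per_float32 / 2**20:.2f} MiB")
print(f"ratio:     {full // factored}x")  # weight_size / (in_size + out_size) = 128
```

The ratio between the two rows is simply weight_size / (in_size + out_size), so it grows with the layer width.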
Key tradeoffs
`D_RTRL` provides exact online gradients but can be memory-intensive for large weight matrices. The eligibility trace for each weight matrix has shape (weight_size, hidden_size), which grows quadratically with model size.
`ES_D_RTRL` (factored / IO-dimension algorithm) trades gradient accuracy for memory efficiency. Instead of storing the full eligibility trace, it factors the trace into input-dimension and output-dimension components, reducing memory from O(weight_size * hidden_size) to O((in_size + out_size) * hidden_size).
BPTT (Backpropagation Through Time) stores all intermediate activations over the unrolled time steps. Memory grows linearly with sequence length T, which can be prohibitive for long sequences.
Recommendations for large models
- Use `ES_D_RTRL` instead of `D_RTRL` when weight matrices are large.
- Reduce hidden state dimensions where possible.
- Use sparse operations (`braintrace.sparse_matmul`) to reduce the number of parameters.
- Consider using `braintrace.lora_matmul` for low-rank weight updates.
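To see why a low-rank update shrinks the parameter count (and with it the |theta| term in the table), the arithmetic below compares a dense update with a generic rank-r factorization ΔW = A @ B. The sizes are hypothetical, and this is the textbook LoRA factorization written in plain JAX, not `braintrace.lora_matmul`'s implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer sizes and rank.
in_size, out_size, rank = 512, 512, 8

dense_params = in_size * out_size              # trainable entries in a full update
low_rank_params = rank * (in_size + out_size)  # entries in A (in, r) plus B (r, out)
print(dense_params, low_rank_params, dense_params // low_rank_params)  # 262144 8192 32

key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (in_size, rank))
B = jax.random.normal(key_b, (rank, out_size))
x = jax.random.normal(key_x, (in_size,))

# Multiplying through the factors never materializes the (in_size, out_size) product.
y_factored = (x @ A) @ B
y_dense = x @ (A @ B)
```

Both orderings compute the same vector; applying the factors one at a time is what keeps the cost proportional to rank * (in_size + out_size).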
Compilation Time
The braintrace compiler performs several steps when compile_graph() is called:
1. Jaxpr tracing: JAX traces the model’s `update()` method to produce a Jaxpr.
2. Relation discovery: The compiler walks the Jaxpr to find ETP primitives, trace weight origins, and connect them to hidden states.
3. Graph construction: The eligibility trace computation graph is built from the discovered relations.
This compilation can be slow for complex models, especially on the first call. However:
Subsequent calls with the same input shapes reuse the compiled graph. The compilation result is cached, so you only pay the cost once.
`compile_graph()` should be called once before the training loop, not inside it. Calling it repeatedly with the same shapes is harmless (it detects the cache hit), but calling it inside a loop adds unnecessary overhead.
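```python
# Good: compile once, then run many steps (GoodModel is defined above)
model = GoodModel()
brainstate.nn.init_all_states(model)
algo = braintrace.D_RTRL(model)
algo.compile_graph(jnp.zeros(10))  # example input with the per-step shape

num_steps = 100
input_data = jnp.ones((num_steps, 10))  # placeholder inputs, one row per step
for step in range(num_steps):
    output = algo(input_data[step])  # uses cached compilation
```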
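braintrace keeps its own compilation cache, and its internals are not shown here. As an analogy only, the same shape-keyed reuse is easy to observe with `jax.jit`, whose Python body runs only when JAX has to trace for a new input shape or dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces f
    return jnp.tanh(x) * 2.0


f(jnp.zeros(10))
f(jnp.ones(10))   # same shape and dtype: cache hit, no retrace
f(jnp.zeros(5))   # new shape: traced again
print(trace_count)  # 2
```

The practical lesson carries over: keep the per-step input shape fixed inside the loop so that every call after the first one reuses the cached result.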
What CAN Be Used Inside update()
The braintrace compiler works with all standard JAX mathematical operations that do not create sub-Jaxprs. These include:
- Standard math operations: `jnp.add`, `jnp.subtract`, `jnp.multiply`, `jnp.divide`
- Element-wise operators: `+`, `-`, `*`, `/`
- Matrix operations: `@` (matrix multiply operator), `jnp.dot`, `jnp.matmul`, `jnp.einsum`
- Activation functions: `jax.nn.tanh`, `jax.nn.relu`, `jax.nn.sigmoid`, `jax.nn.softmax`, `jax.nn.silu`, `jax.nn.gelu`, `jax.nn.leaky_relu`
- Shape manipulation: `jnp.reshape`, `jnp.transpose`, `jnp.concatenate`, `jnp.expand_dims`, `jnp.squeeze`
- Selection and masking: `jax.lax.select(predicate, on_true, on_false)` (preferred over `jnp.where` inside `update()`; see Workarounds above), `jnp.clip`, `jnp.maximum`, `jnp.minimum`
- Gradient control: `jax.lax.stop_gradient` – useful for detaching parts of the computation
- braintrace ETP primitives:
  - `braintrace.matmul` – matrix multiplication with ETP tracking
  - `braintrace.element_wise` – element-wise parameter operations with ETP tracking
  - `braintrace.conv` – convolution with ETP tracking
  - `braintrace.sparse_matmul` – sparse matrix multiplication with ETP tracking
  - `braintrace.lora_matmul` – LoRA-style low-rank multiplication with ETP tracking
In general, if a JAX operation compiles to a flat sequence of primitives in the Jaxpr (no nested sub-Jaxprs), it is compatible with braintrace.
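The rule of thumb above can be checked mechanically. The hypothetical helper below, which is not part of braintrace, lists the top-level primitives in a function's Jaxpr whose parameters contain a nested Jaxpr; that is how `cond`, `while`, `scan`, and nested `jit` calls carry their sub-graphs.

```python
import jax
import jax.numpy as jnp


def _is_jaxpr_like(obj):
    # A Jaxpr exposes .eqns; a ClosedJaxpr wraps one under .jaxpr.
    return hasattr(obj, "eqns") or hasattr(getattr(obj, "jaxpr", None), "eqns")


def nested_primitives(fn, *args):
    """Top-level primitives of fn's Jaxpr whose params hold a nested Jaxpr."""
    found = []
    for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns:
        for value in eqn.params.values():
            items = value if isinstance(value, (tuple, list)) else (value,)
            if any(_is_jaxpr_like(item) for item in items):
                found.append(eqn.primitive.name)
                break
    return found


def flat(h, w, x):
    return jax.lax.tanh(jax.lax.dot(h, w) + x) * 0.5


def branched(h, w, x):
    return jax.lax.cond(x[0] > 0, lambda: h + x, lambda: h)


def looped(h, w, x):
    return jax.lax.while_loop(lambda c: c[0] < 3.0, lambda c: c + 1.0, h)


h, w, x = jnp.zeros(4), jnp.ones((4, 4)), jnp.ones(4)
print(nested_primitives(flat, h, w, x))      # []
print(nested_primitives(branched, h, w, x))  # includes 'cond'
print(nested_primitives(looped, h, w, x))    # includes 'while'
```

Running a candidate function through a check like this before placing it inside `update()` is a quick way to spot hidden wrappers; it inspects only the top level, which is where a nested sub-Jaxpr would first appear.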
Summary
The key limitations and their workarounds are:
- Avoid `cond`, `scan`, `while_loop`, and nested `vmap` inside the model’s `update()` method. These create sub-Jaxprs that the compiler cannot traverse. Use them freely outside the model (e.g., for time-step unrolling).
- Use `jax.lax.select` and masks as alternatives to conditional logic. Element-wise selection operations are fully supported and produce equivalent results for most use cases.
- Ensure shape compatibility between ETP primitive outputs and hidden states. The compiler filters connections by shape – if shapes don’t match or broadcast, the connection won’t be established.
- Per-primitive ETP rules are local. A weight whose only path to a hidden state passes through another trainable ETP primitive is excluded with a `RELATION_EXCLUDED_WEIGHT_TO_WEIGHT` warning – it must be learned via BPTT or the architecture must be rewired. `etp_elemwise_p` (the only `gradient_enabled=True` built-in) is the sole exception.
- Choose the right algorithm based on memory/accuracy tradeoffs. Use `D_RTRL` for exact gradients with moderate model sizes, and `ES_D_RTRL` for memory-efficient approximate gradients with larger models.
- Call `compile_graph()` once before training, not inside the training loop. The compiled graph is cached and reused for inputs of the same shape.
- The compiler works with all standard JAX mathematical operations that do not create sub-Jaxprs. As long as you avoid the unsupported control flow primitives listed above, your model will compile successfully.