Compiler Internals#
The braintrace compiler transforms a brainstate.nn.Module into an ETraceGraph – a structured representation that captures all the relationships between weight parameters, hidden states, and ETP primitives needed for online learning.
The compilation pipeline consists of four stages:
1. extract_module_info: Trace the model and extract its Jaxpr (JAX's intermediate representation).
2. find_hidden_groups_from_minfo: Identify groups of recurrent hidden states that are mutually connected.
3. find_hidden_param_op_relations_from_minfo: Discover how ETP primitives connect weight parameters to hidden states.
4. add_hidden_perturbation_from_minfo: Build the perturbation structure for computing hidden-to-hidden Jacobians.
Understanding these internals helps you debug compilation issues, inspect the computational graph, and customize behavior when working with non-standard model architectures.
Setup#
We will use a two-layer vanilla RNN as a running example throughout this notebook. This model is simple enough to inspect manually, yet complex enough to demonstrate multi-group hidden state discovery.
import jax
import jax.numpy as jnp
import brainstate
import braintrace
class TwoLayerRNN(brainstate.nn.Module):
    """A two-layer vanilla RNN with a linear readout."""

    def __init__(self):
        super().__init__()
        self.rnn1 = braintrace.nn.ValinaRNNCell(10, 32)
        self.rnn2 = braintrace.nn.ValinaRNNCell(32, 16)
        self.out = braintrace.nn.Linear(16, 5)

    def update(self, x):
        return self.out(self.rnn2(self.rnn1(x)))

model = TwoLayerRNN()
brainstate.nn.init_all_states(model)
# A dummy input matching the first layer's input dimension
dummy_input = jnp.zeros(10)
print("Model created and states initialized.")
Model created and states initialized.
Step 1: ModuleInfo Extraction#
extract_module_info(model, *args) is the first stage of the compiler. It:
Wraps the model in a StatefulFunction and traces it with JAX to produce a Jaxpr.
Collects all states from the module hierarchy via brainstate.graph.states(model).
Classifies each state as a hidden state (brainstate.HiddenState) or a weight parameter (brainstate.ParamState).
Builds bidirectional mappings between Jaxpr variables and their module paths.
The result is a ModuleInfo named tuple containing the Jaxpr, state mappings, and variable-to-path dictionaries.
minfo = braintrace.extract_module_info(model, dummy_input)
print(f"Jaxpr equations: {len(minfo.jaxpr.eqns)}")
print(f"Compiled model states: {len(minfo.compiled_model_states)}")
print(f"Hidden states: {len(minfo.hidden_path_to_invar)}")
print(f"Weight parameters: {len(minfo.weight_path_to_invars)}")
Jaxpr equations: 12
Compiled model states: 5
Hidden states: 2
Weight parameters: 3
Inspecting State Mappings#
The ModuleInfo maintains separate mappings for hidden states and weight parameters. Each mapping connects a module path (a tuple of attribute names) to a Jaxpr variable.
print("=== Hidden State Paths ===")
for path, var in minfo.hidden_path_to_invar.items():
    print(f" {path} -> {var}")
print()
print("=== Weight Parameter Paths ===")
for path, invars in minfo.weight_path_to_invars.items():
    print(f" {path} -> {len(invars)} variable(s)")
=== Hidden State Paths ===
('rnn1', 'h') -> Var(id=132835946131648):float32[32]
('rnn2', 'h') -> Var(id=132832132601216):float32[16]
=== Weight Parameter Paths ===
('rnn1', 'W', 'weight') -> 2 variable(s)
('rnn2', 'W', 'weight') -> 2 variable(s)
('out', 'weight') -> 2 variable(s)
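The classification and mapping steps described above can be pictured with a small standalone sketch. Everything in it is a toy stand-in: the `State`, `HiddenState`, and `ParamState` classes are hypothetical placeholders rather than the brainstate classes, and the `v0`, `v1`, ... strings stand in for Jaxpr variables. It only illustrates the idea of splitting states by type and keeping lookups in both directions.

```python
# Toy stand-ins for illustration only; these are NOT the brainstate classes.
class State:
    def __init__(self, shape):
        self.shape = shape


class HiddenState(State):
    pass


class ParamState(State):
    pass


# Hypothetical flat view of a module hierarchy: path tuple -> state object.
states = {
    ("rnn1", "h"): HiddenState((32,)),
    ("rnn1", "W", "weight"): ParamState((42, 32)),
    ("rnn2", "h"): HiddenState((16,)),
    ("rnn2", "W", "weight"): ParamState((48, 16)),
    ("out", "weight"): ParamState((16, 5)),
}

hidden_path_to_var = {}
weight_path_to_var = {}
for i, (path, state) in enumerate(states.items()):
    var = f"v{i}"  # placeholder for a traced Jaxpr variable
    if isinstance(state, HiddenState):
        hidden_path_to_var[path] = var
    elif isinstance(state, ParamState):
        weight_path_to_var[path] = var

# Reverse lookups make the mapping bidirectional.
var_to_hidden_path = {v: p for p, v in hidden_path_to_var.items()}
var_to_weight_path = {v: p for p, v in weight_path_to_var.items()}

print(f"hidden states: {len(hidden_path_to_var)}")
print(f"weight parameters: {len(weight_path_to_var)}")
print(f"v1 belongs to: {var_to_weight_path['v1']}")
```

The sketch maps each weight path to a single variable for brevity. In the real output above, each weight path maps to two variables.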
Step 3: Finding ETP Relations#
find_hidden_param_op_relations_from_minfo(minfo, hid_path_to_group) connects ETP primitives to their weight parameters and the hidden states they influence.
Algorithm:
For each equation in the Jaxpr:
1. Primitive identification: Check eqn.primitive in ETP_PRIMITIVES (type identity, not string matching). This is robust: renaming a function or wrapping it in jax.jit does not break identification.
2. Weight extraction: Extract the weight variable from eqn.invars (index 1 for matmul/conv, index 0 for element-wise).
3. Backward tracing: Trace the weight variable backward through the Jaxpr (following producer equations) to find the originating ParamState. This handles cases where weight transformations (e.g., weight_fn, masking) are applied before the primitive.
4. Forward BFS: From the primitive's output variable, perform a breadth-first search forward through the Jaxpr to find reachable hidden-state outvars.
5. Shape compatibility: Filter out hidden outvars whose shapes are not broadcast-compatible with the primitive output.
6. Transition Jaxpr: Build a sub-Jaxpr mapping y -> h for each connected hidden group, used later for computing dh/dy.
The result is a sequence of HiddenParamOpRelation named tuples.
relations = braintrace.find_hidden_param_op_relations_from_minfo(minfo, hid_path_to_group)
print(f"Number of ETP relations discovered: {len(relations)}")
print()
for i, r in enumerate(relations):
    print(f"Relation {i}:")
    print(f" Primitive: {r.primitive.name}")
    print(f" Weight path: {r.weight_path}")
    print(f" x_var: {r.x_var}")
    print(f" y_var: {r.y_var}")
    print(f" Connected hidden groups: {[g.index for g in r.hidden_groups]}")
    print(f" Connected hidden paths:")
    for path in r.connected_hidden_paths:
        print(f" - {path}")
    print(f" Equation params: {r.eqn_params}")
    print()
Number of ETP relations discovered: 2
Relation 0:
Primitive: etp_mv
Weight path: ('rnn1', 'W', 'weight')
x_var: Var(id=132831597789632):float32[42]
y_var: Var(id=132835947956608):float32[32]
Connected hidden groups: [0]
Connected hidden paths:
- ('rnn1', 'h')
Equation params: {'has_bias': True}
Relation 1:
Primitive: etp_mv
Weight path: ('rnn2', 'W', 'weight')
x_var: Var(id=132831595499392):float32[48]
y_var: Var(id=132831595499648):float32[16]
Connected hidden groups: [1]
Connected hidden paths:
- ('rnn2', 'h')
Equation params: {'has_bias': True}
/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.
_emit_no_relation_diag(
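The forward BFS step of the algorithm can be sketched on a toy equation list. This is a minimal illustration, not braintrace's implementation: the `Eqn` tuple, the variable names, and the `reachable_hidden` helper are all hypothetical, and the sketch assumes the search stops expanding once it reaches a hidden-state outvar. That assumption is consistent with the output above, where each recurrent weight connects only to its own layer's hidden state.

```python
from collections import deque, namedtuple

# Hypothetical, simplified equation record (not a real JaxprEqn).
Eqn = namedtuple("Eqn", ["primitive", "invars", "outvars"])

# Toy program loosely mirroring the two-layer RNN; variable names are made up.
eqns = [
    Eqn("concatenate", ["x", "h1"], ["xh1"]),
    Eqn("etp_mv", ["xh1", "W1", "b1"], ["y1"]),
    Eqn("tanh", ["y1"], ["h1_new"]),
    Eqn("concatenate", ["h1_new", "h2"], ["xh2"]),
    Eqn("etp_mv", ["xh2", "W2", "b2"], ["y2"]),
    Eqn("tanh", ["y2"], ["h2_new"]),
    Eqn("etp_mv", ["h2_new", "W3", "b3"], ["out"]),
]
hidden_outvars = {"h1_new", "h2_new"}


def reachable_hidden(y_var, eqns, hidden_outvars):
    """Breadth-first search forward from y_var to hidden-state outvars."""
    # Map each variable to the equations that consume it.
    consumers = {}
    for eqn in eqns:
        for v in eqn.invars:
            consumers.setdefault(v, []).append(eqn)

    found, seen, queue = [], {y_var}, deque([y_var])
    while queue:
        var = queue.popleft()
        if var in hidden_outvars:
            found.append(var)
            continue  # assumption: do not search past a hidden-state boundary
        for eqn in consumers.get(var, []):
            for out in eqn.outvars:
                if out not in seen:
                    seen.add(out)
                    queue.append(out)
    return found


for y in ("y1", "y2", "out"):
    print(y, "->", reachable_hidden(y, eqns, hidden_outvars))
```

In this toy program the readout output `out` reaches no hidden state, which is the situation the warning above reports for the ('out', 'weight') parameter.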
The Complete Pipeline#
compile_etrace_graph(model, *args) runs all four steps in sequence and returns an ETraceGraph named tuple containing everything the online learning algorithms need.
The function also performs an additional step: it rewrites the Jaxpr to return extra intermediate variables (weight inputs, transition constants, etc.) that the graph executor needs at runtime.
graph = braintrace.compile_etrace_graph(model, jnp.zeros(10))
print(f"Hidden groups: {len(graph.hidden_groups)}")
print(f"ETP relations: {len(graph.hidden_param_op_relations)}")
print(f"Module info available: {graph.module_info is not None}")
print(f"Hidden perturbation available: {graph.hidden_perturb is not None}")
print()
print(f"ETraceGraph fields: {list(graph._asdict().keys())}")
Hidden groups: 2
ETP relations: 2
Module info available: True
Hidden perturbation available: True
ETraceGraph fields: ['module_info', 'hidden_groups', 'hid_path_to_group', 'hidden_param_op_relations', 'hidden_perturb', 'diagnostics']
How Primitive Identification Works#
A key design decision in braintrace is type-based primitive identification rather than name-based matching.
Old system (string matching)#
The old ETraceOp system identified weight operations by matching JIT function names as strings. This was fragile:
Renaming a function broke the match.
Wrapping a function in jax.jit changed the name.
Name collisions could cause false matches.
New system (type identity)#
The ETP system registers custom JAX primitives (etp_mm_p, etp_mv_p, etp_elemwise_p, etp_conv_p, etp_sp_mm_p, etp_sp_mv_p, etp_lora_mm_p, etp_lora_mv_p) and identifies them by checking:
eqn.primitive in ETP_PRIMITIVES # set membership, O(1)
This is robust because:
Primitive identity is a Python object reference, not a string.
Wrapping in jax.jit does not change the primitive type.
No name collisions are possible.
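The difference between the two approaches can be shown with plain Python objects. The `Primitive` class below is a hypothetical stand-in, not a JAX primitive, and the sketch assumes primitives are compared by object identity, which is the default for Python objects that do not override `__eq__` and `__hash__`.

```python
class Primitive:
    """Toy stand-in for a primitive: an object that carries a display name."""

    def __init__(self, name):
        self.name = name


# The registered primitive and the registry that holds it.
etp_mv_p = Primitive("etp_mv")
ETP_PRIMITIVES = {etp_mv_p}

# An unrelated primitive object that happens to share the same name.
impostor = Primitive("etp_mv")

# Name-based matching is fooled by the collision.
matched_by_name = impostor.name in {p.name for p in ETP_PRIMITIVES}

# Identity-based set membership is not.
matched_by_identity = impostor in ETP_PRIMITIVES

print(f"name match: {matched_by_name}")
print(f"identity match: {matched_by_identity}")
print(f"registered primitive found: {etp_mv_p in ETP_PRIMITIVES}")
```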
Weight variable extraction#
Once an ETP equation is found, the weight variable is extracted from eqn.invars at a position that depends on the primitive type. The same indices are recorded by ETPPrimitiveSpec (see weight_invar_index / x_invar_index) and can be queried at runtime through braintrace.get_primitive_spec(prim):
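| Primitive | invars[0] | invars[1] | invars[2] | Notes |
|---|---|---|---|---|
| etp_mm_p, etp_mv_p | input | weight | bias | |
| etp_elemwise_p | processed weight | – | – | |
| etp_conv_p | input | kernel | bias | |
| etp_sp_mm_p, etp_sp_mv_p | input | sparse | bias | |
| etp_lora_mm_p, etp_lora_mv_p | input | LoRA factor | LoRA factor | |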
The weight variable is then traced backward through the Jaxpr’s producer map to find the originating ParamState, handling intermediate transformations like masking, weight standardization, or sign constraints.
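Backward tracing through a producer map can be sketched as follows. The helpers `build_producer_map` and `trace_to_param` are simplified, hypothetical versions written for this example; they are not the internal braintrace functions, and the equation record and variable names are invented. The toy program applies a mask and a sign constraint to a weight before the ETP primitive consumes it.

```python
from collections import namedtuple

# Hypothetical, simplified equation record (not a real JaxprEqn).
Eqn = namedtuple("Eqn", ["primitive", "invars", "outvars"])

# The raw weight is masked, then made non-negative, then used by the primitive.
eqns = [
    Eqn("mul", ["W_raw", "mask"], ["W_masked"]),
    Eqn("abs", ["W_masked"], ["W_pos"]),
    Eqn("etp_mv", ["x", "W_pos", "b"], ["y"]),
]

# Input variables that correspond to parameters, keyed to their module paths.
param_invars = {"W_raw": ("layer", "weight"), "b": ("layer", "bias")}


def build_producer_map(eqns):
    """Map each output variable to the equation that produces it."""
    return {out: eqn for eqn in eqns for out in eqn.outvars}


def trace_to_param(var, producers, param_invars):
    """Walk producer equations backward until a parameter input is found."""
    stack, seen = [var], set()
    while stack:
        v = stack.pop()
        if v in seen:
            continue
        seen.add(v)
        if v in param_invars:
            return param_invars[v]
        eqn = producers.get(v)
        if eqn is not None:
            stack.extend(eqn.invars)
    return None  # no parameter feeds this variable


producers = build_producer_map(eqns)
weight_var = eqns[-1].invars[1]  # index 1 holds the weight for this primitive
print(weight_var, "->", trace_to_param(weight_var, producers, param_invars))
```

A return value of `None` in this sketch corresponds to a variable that no parameter feeds, such as the input `x`.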
Debugging Compilation Issues#
When compilation produces unexpected results (e.g., missing relations, wrong group assignments), inspecting the raw Jaxpr is the most effective debugging tool.
Common issues#
Missing ETP relations: A weight parameter uses a regular JAX op (e.g., x @ w) instead of an ETP primitive (e.g., braintrace.matmul(x, w)). The compiler only recognizes ETP primitives.
Shape mismatches: The output of an ETP primitive is not broadcast-compatible with the target hidden state. The compiler will warn and skip the connection.
Hidden states in control flow: Hidden states computed inside jax.lax.scan, jax.lax.while_loop, or jax.lax.cond are currently unsupported and will raise an error.
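When a shape mismatch is suspected, it can help to check by hand whether two shapes are broadcast-compatible. The helper below is a hypothetical sketch that assumes NumPy-style broadcasting rules (trailing dimensions must be equal, or one of them must be 1). The exact rule the compiler applies may differ.

```python
def broadcast_compatible(shape_a, shape_b):
    """Return True if two shapes can be broadcast together, NumPy-style."""
    # Compare dimensions from the trailing axis backward; missing leading
    # axes on the shorter shape are treated as compatible.
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        if a != b and a != 1 and b != 1:
            return False
    return True


# Primitive output vs. hidden-state shapes taken from the running example.
print(broadcast_compatible((32,), (32,)))    # rnn1 output vs. rnn1 hidden
print(broadcast_compatible((5,), (16,)))     # readout output vs. rnn2 hidden
print(broadcast_compatible((8, 32), (32,)))  # batched output vs. unbatched hidden
```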
Inspecting the Jaxpr#
You can iterate over the Jaxpr equations and flag ETP primitives to verify the compiler sees what you expect.
from braintrace._etrace_op import is_etp_primitive
print("Jaxpr equations (ETP primitives marked with **):\n")
for i, eqn in enumerate(minfo.jaxpr.eqns):
    primitive_name = eqn.primitive.name
    in_shapes = [
        v.aval.shape if hasattr(v, 'aval') else 'literal'
        for v in eqn.invars
    ]
    out_shapes = [v.aval.shape for v in eqn.outvars]
    is_etp = is_etp_primitive(eqn.primitive)
    marker = "**ETP**" if is_etp else " "
    print(f" [{i:2d}] {marker} {primitive_name}: {in_shapes} -> {out_shapes}")
Jaxpr equations (ETP primitives marked with **):
[ 0] convert_element_type: [(32,)] -> [(32,)]
[ 1] concatenate: [(10,), (32,)] -> [(42,)]
[ 2] **ETP** etp_mv: [(42,), (42, 32), (32,)] -> [(32,)]
[ 3] mul: [(32,), ()] -> [(32,)]
[ 4] custom_jvp_call: [(32,)] -> [(32,)]
[ 5] convert_element_type: [(16,)] -> [(16,)]
[ 6] concatenate: [(32,), (16,)] -> [(48,)]
[ 7] **ETP** etp_mv: [(48,), (48, 16), (16,)] -> [(16,)]
[ 8] mul: [(16,), ()] -> [(16,)]
[ 9] custom_jvp_call: [(16,)] -> [(16,)]
[10] **ETP** etp_mv: [(16,), (16, 5), (5,)] -> [(5,)]
[11] mul: [(5,), ()] -> [(5,)]
Verifying backward tracing#
If an ETP relation is missing, you can manually check whether the weight variable can be traced back to a ParamState by building the producer map and calling the internal tracing function.
from braintrace._etrace_compiler.hid_param_op import _build_producer_map, _trace_var_to_param
from braintrace import get_primitive_spec
producers = _build_producer_map(minfo.jaxpr)
for eqn in minfo.jaxpr.eqns:
    if not is_etp_primitive(eqn.primitive):
        continue
    # Look up the spec rather than hard-coding the weight index per primitive.
    spec = get_primitive_spec(eqn.primitive)
    weight_var = eqn.invars[spec.weight_invar_index]
    path = _trace_var_to_param(
        weight_var, producers, minfo.invar_to_weight_path
    )
    print(f"Primitive: {eqn.primitive.name}")
    print(f" weight_invar_index: {spec.weight_invar_index}")
    print(f" Weight var: {weight_var}")
    print(f" Traced to ParamState: {path}")
    print()
Primitive: etp_mv
weight_invar_index: 1
Weight var: Var(id=132831598399488):float32[42,32]
Traced to ParamState: (('rnn1', 'W', 'weight'), ())
Primitive: etp_mv
weight_invar_index: 1
Weight var: Var(id=132831595498880):float32[48,16]
Traced to ParamState: (('rnn2', 'W', 'weight'), ())
Primitive: etp_mv
weight_invar_index: 1
Weight var: Var(id=132831595549632):float32[16,5]
Traced to ParamState: (('out', 'weight'), ())
Compiler Diagnostics#
Every call to compile_etrace_graph annotates each weight/primitive decision with a CompilationRecord. The full list lives at graph.diagnostics. When the compiler skips a weight or merges hidden groups in a way you did not expect, this list is the first place to look – it tells you which weight, which primitive, and why.
Each CompilationRecord has these fields:
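| Field | Type | What it carries |
|---|---|---|
| kind | DiagnosticKind | Decision category, e.g. RELATION_INCLUDED or RELATION_EXCLUDED_NON_TEMPORAL. |
| level | DiagnosticLevel | Severity, e.g. INFO, WARNING, or ERROR. |
| message | | Human-readable summary, including the weight path and primitive name. |
| primitive | | The ETP primitive involved (if any). |
| weight_path | | Dotted path to the weight parameter. |
| hidden_paths | | Hidden-state paths the relation reaches. |
| context | | Free-form extra info: group indices, classification tags, etc. |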
The same diagnostics are emitted as UserWarnings during compile_graph() (so you see them on stderr without doing anything), but querying graph.diagnostics lets you filter, log, or assert on them programmatically. graph.explain() is a convenience that prints the records grouped by kind.
from braintrace import DiagnosticKind, DiagnosticLevel
print(f"Diagnostics: {len(graph.diagnostics)}")
for d in graph.diagnostics:
    print(f" [{d.level.name:7s}] {d.kind.name}: {d.message}")
# Common queries -- did anything error? was any weight excluded as a tail boundary?
errors = [d for d in graph.diagnostics if d.level == DiagnosticLevel.ERROR]
weight_to_weight = [
    d for d in graph.diagnostics
    if d.kind == DiagnosticKind.RELATION_EXCLUDED_WEIGHT_TO_WEIGHT
]
print(f"\nerrors: {len(errors)}, weight->weight exclusions: {len(weight_to_weight)}")
# `graph.explain()` prints the same information grouped by kind for quick scanning.
graph.explain()
Diagnostics: 3
[INFO ] RELATION_INCLUDED: etp_mv(('rnn1', 'W', 'weight')) -> [0]
[INFO ] RELATION_INCLUDED: etp_mv(('rnn2', 'W', 'weight')) -> [1]
[WARNING] RELATION_EXCLUDED_NON_TEMPORAL: ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.
errors: 0, weight->weight exclusions: 0
(CompilationRecord(kind=relation_included, level=info, primitive='etp_mv', weight_path=('rnn1', 'W', 'weight'), hidden_paths=[('rnn1', 'h')], message="etp_mv(('rnn1', 'W', 'weight')) -> [0]", context={'hidden_group_indices': (0,), 'path_classification': {('rnn1', 'h'): 'all_direct'}}),
CompilationRecord(kind=relation_included, level=info, primitive='etp_mv', weight_path=('rnn2', 'W', 'weight'), hidden_paths=[('rnn2', 'h')], message="etp_mv(('rnn2', 'W', 'weight')) -> [1]", context={'hidden_group_indices': (1,), 'path_classification': {('rnn2', 'h'): 'all_direct'}}),
CompilationRecord(kind=relation_excluded_non_temporal, level=warning, primitive='etp_mv', weight_path=('out', 'weight'), message="ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter."))
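Filtering and grouping records like these is straightforward to script. The sketch below uses hypothetical stand-ins (`Kind`, `Level`, `Record`) that mimic only the fields visible in the output above, so it runs without braintrace installed. With a real graph, the same list comprehensions would presumably apply to graph.diagnostics, but that is an assumption rather than documented behavior.

```python
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum


# Hypothetical stand-ins that mirror the fields printed above.
class Kind(Enum):
    RELATION_INCLUDED = "relation_included"
    RELATION_EXCLUDED_NON_TEMPORAL = "relation_excluded_non_temporal"


class Level(Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"


@dataclass(frozen=True)
class Record:
    kind: Kind
    level: Level
    weight_path: tuple
    message: str


records = [
    Record(Kind.RELATION_INCLUDED, Level.INFO,
           ("rnn1", "W", "weight"), "etp_mv -> [0]"),
    Record(Kind.RELATION_INCLUDED, Level.INFO,
           ("rnn2", "W", "weight"), "etp_mv -> [1]"),
    Record(Kind.RELATION_EXCLUDED_NON_TEMPORAL, Level.WARNING,
           ("out", "weight"), "no connected hidden states"),
]


def group_by_kind(records):
    """Bucket records by their decision category."""
    grouped = defaultdict(list)
    for record in records:
        grouped[record.kind].append(record)
    return dict(grouped)


def flagged_weights(records):
    """Weight paths whose records are more severe than INFO."""
    return [r.weight_path for r in records if r.level is not Level.INFO]


for kind, items in group_by_kind(records).items():
    print(f"{kind.name}: {len(items)}")
print("flagged:", flagged_weights(records))
```

A check such as `assert not flagged_weights(records)` could then serve as a regression test for a model whose weights are all expected to be temporal.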
Summary#
The braintrace compiler follows a four-step pipeline to transform a neural network module into an ETraceGraph for online learning. Along the way it records a CompilationRecord for each weight/primitive decision, which you can inspect via graph.diagnostics (or graph.explain()) to debug missing or misplaced relations.
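| Step | Function | Output | Purpose |
|---|---|---|---|
| 1 | extract_module_info | ModuleInfo | Trace model, extract Jaxpr, classify states |
| 2 | find_hidden_groups_from_minfo | | Identify connected recurrent state groups |
| 3 | find_hidden_param_op_relations_from_minfo | Sequence of HiddenParamOpRelation | Connect ETP primitives to weights and hidden states |
| 4 | add_hidden_perturbation_from_minfo | | Build perturbation Jaxpr for Jacobian computation |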
Key design decisions:
Type-based primitive identification (eqn.primitive in ETP_PRIMITIVES) is robust and extensible, replacing the old fragile string-matching approach.
Backward tracing from weight variables to ParamState handles weight transformations transparently.
Forward BFS from primitive outputs to hidden outvars with shape compatibility filtering ensures correct connectivity.
Perturbation rewriting of the Jaxpr enables efficient hidden-to-hidden Jacobian computation via automatic differentiation.
Understanding this pipeline is essential for debugging compilation failures, extending braintrace with new primitives, and reasoning about the structure of online learning graphs.