Graph Compilation & Visualization
In braintrace, models are compiled into an ETraceGraph – an intermediate representation that captures the
structural relationships between weight parameters, ETP primitives (the operations that connect inputs to
hidden states), and hidden state groups. This compilation step is what enables efficient online learning:
by analyzing the computation graph, braintrace can automatically determine which weights influence which
hidden states, and how eligibility traces should propagate.
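The idea of "which weights influence which hidden states" can be made concrete with a minimal sketch. It rests on stated assumptions. braintrace analyzes the model's Jaxpr, but this toy instead uses a hypothetical list of (output, inputs, weight) records and plain forward reachability, and it stops at the first hidden state on each path. The names OPS, HIDDEN and weights_to_hidden are illustrative only and are not braintrace API.

```python
from collections import deque

# Hypothetical toy IR (NOT braintrace's internal format):
# each record is (output variable, input variables, weight name or None).
OPS = [
    ("a", ["x", "h"], "rnn.W"),   # pre-activation of the recurrent cell
    ("h_new", ["a"], None),       # nonlinearity, no weight involved
    ("y", ["h_new"], "out.W"),    # linear readout
]
# Variables that become a hidden state at the end of the step.
HIDDEN = {"h_new": "rnn.h"}


def weights_to_hidden(ops, hidden):
    """Map each weight to the hidden states its output flows into."""
    # Build "variable -> ops that consume it" edges.
    consumers = {}
    for out, ins, _ in ops:
        for v in ins:
            consumers.setdefault(v, []).append(out)

    result = {}
    for out, _, w in ops:
        if w is None:
            continue
        reached, queue, seen = set(), deque([out]), {out}
        while queue:  # breadth-first search from the weight's output
            v = queue.popleft()
            if v in hidden:
                reached.add(hidden[v])
                continue  # toy rule: stop at the first hidden state on a path
            for nxt in consumers.get(v, []):
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
        result[w] = sorted(reached)
    return result


print(weights_to_hidden(OPS, HIDDEN))
# {'rnn.W': ['rnn.h'], 'out.W': []}
```

In this toy, the readout weight out.W reaches no hidden state, because its output y never feeds back into the recurrence.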
The show_graph() method visualizes these relationships, providing a human-readable summary of:
Hidden groups: clusters of hidden states that evolve together (e.g., the membrane potential and adaptation current of a neuron population)
Weight-primitive-hidden connections: which weight parameters are associated with which hidden groups through which ETP primitives
Non-ETP weights: parameters that exist in the model but do not participate in online learning
Understanding the compiled graph is essential for debugging model structure, verifying that the correct parameters are included in online learning, and optimizing model design.
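A "cluster of hidden states that evolve together" can be illustrated with a hypothetical sketch. It assumes that two hidden states belong together when the update of one reads the other within a step, and it takes connected components with union-find. This shows the concept under that assumption only. It is not braintrace's actual grouping rule, and the state names are made up.

```python
def group_hidden_states(states, reads):
    """Cluster hidden states into groups (hypothetical rule, for illustration).

    states: list of state names
    reads:  pairs (a, b) meaning "the update of state a reads state b"
    """
    parent = {s: s for s in states}

    def find(s):
        while parent[s] != s:
            parent[s] = parent[parent[s]]  # path halving
            s = parent[s]
        return s

    # Treat every "reads" link as undirected and merge the two clusters.
    for a, b in reads:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    groups = {}
    for s in states:  # keep the input order inside each group
        groups.setdefault(find(s), []).append(s)
    return list(groups.values())


# A neuron population whose membrane potential V and adaptation current a
# depend on each other, plus an unrelated RNN hidden state h.
states = ["pop.V", "pop.a", "rnn.h"]
reads = [("pop.V", "pop.a"), ("pop.a", "pop.V"), ("rnn.h", "rnn.h")]
print(group_hidden_states(states, reads))
# [['pop.V', 'pop.a'], ['rnn.h']]
```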
Single-Layer RNN
We start with the simplest case: a single recurrent layer followed by a linear readout.
The ValinaRNNCell contains one hidden state and one recurrent weight, and the Linear
readout has its own weight that feeds into the output.
```python
import jax
import jax.numpy as jnp
import brainstate
import braintrace


class SingleLayerRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        # Recurrent cell: one hidden state and one recurrent weight
        self.rnn = braintrace.nn.ValinaRNNCell(n_in, n_rec)
        # Linear readout from the hidden state to the output
        self.out = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.out(self.rnn(x))


model = SingleLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model)

# Compile the graph from a single example input, then visualize it
algo = braintrace.D_RTRL(model)
algo.compile_graph(jnp.zeros(10))
algo.show_graph()
```
The output shows:
Hidden Group 0: the hidden state of the ValinaRNNCell (path('rnn', 'h'))
Weight 0: the recurrent weight inside the RNN cell, associated with Hidden Group 0
Weight 1: the readout weight, which may or may not appear depending on whether the readout layer uses ETP primitives
This tells us that D_RTRL will maintain an eligibility trace for the recurrent weight,
tracking how it influences the hidden state over time.
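An eligibility trace is essentially a running record of how a weight has influenced a hidden state up to the current step. The following minimal numpy sketch assumes a hypothetical leaky linear unit, h_t = alpha * h_{t-1} + W @ x_t, rather than ValinaRNNCell. This unit's hidden-to-hidden Jacobian is exactly diagonal, so the trace e_t[i, j] = d h_t[i] / d W[i, j] can be updated element-wise and online. The sketch then checks the trace against finite differences. It is a hand-written illustration and not braintrace's D_RTRL implementation.

```python
import numpy as np

# Hypothetical leaky linear unit (chosen because alpha * I is diagonal):
#     h_t = alpha * h_{t-1} + W @ x_t,   h_0 = 0
# h_t[i] depends only on row i of W, so the trace
#     e_t[i, j] = d h_t[i] / d W[i, j]
# obeys e_t[i, j] = alpha * e_{t-1}[i, j] + x_t[j].

def run_with_trace(W, xs, alpha):
    n_rec, n_in = W.shape
    h = np.zeros(n_rec)
    e = np.zeros((n_rec, n_in))
    for x in xs:
        h = alpha * h + W @ x
        e = alpha * e + x[None, :]  # same x_t[j] added to every row i
    return h, e


rng = np.random.default_rng(0)
n_in, n_rec, T, alpha = 3, 4, 6, 0.9
W = rng.normal(size=(n_rec, n_in))
xs = rng.normal(size=(T, n_in))
c = rng.normal(size=n_rec)  # scalar loss L = c @ h_T

h_T, e_T = run_with_trace(W, xs, alpha)
grad_online = c[:, None] * e_T  # dL/dW[i, j] = c[i] * e_T[i, j]

# Reference gradient from central finite differences over the whole sequence.
eps = 1e-6
grad_fd = np.zeros_like(W)
for i in range(n_rec):
    for j in range(n_in):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        lp = c @ run_with_trace(Wp, xs, alpha)[0]
        lm = c @ run_with_trace(Wm, xs, alpha)[0]
        grad_fd[i, j] = (lp - lm) / (2 * eps)

print(np.allclose(grad_online, grad_fd, atol=1e-6))  # True
```

For a general recurrent cell, the hidden-to-hidden Jacobian is not diagonal. An exact online method must then carry much larger sensitivity tensors or approximate them.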
Understanding ETraceGraph
The compiled graph is an ETraceGraph named tuple with several key fields:
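| Field | Type | Description |
|---|---|---|
| | | Jaxpr and state mappings extracted from the model |
| hidden_groups | sequence of HiddenGroup | Discovered hidden state groups |
| | | Mapping from hidden state path to its group |
| hidden_param_op_relations | sequence of HiddenParamOpRelation | Weight-primitive-hidden connections |
| hidden_perturb | | Perturbation structure for Jacobian computation (None when absent) |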
Each HiddenGroup records a cluster of hidden states that are updated together in one
recurrent step. Each HiddenParamOpRelation records the connection between a weight
parameter and the hidden groups it feeds into through an ETP primitive.
Let’s inspect these programmatically:
```python
graph = algo.graph  # the ETraceGraph compiled by algo.compile_graph()

print("=== Hidden Groups ===")
for g in graph.hidden_groups:
    print(f"  Group {g.index}: {g.num_state} state(s), shape {g.varshape}")
    print(f"    Paths: {g.hidden_paths}")

print("\n=== Weight-Primitive-Hidden Relations ===")
for i, r in enumerate(graph.hidden_param_op_relations):
    print(f"  Relation {i}:")
    print(f"    Weight path: {r.weight_path}")
    print(f"    Primitive: {r.primitive}")
    print(f"    Hidden groups: {[g.index for g in r.hidden_groups]}")

print("\n=== Perturbation ===")
print(f"  Has perturbation: {graph.hidden_perturb is not None}")
```
The HiddenGroup.num_state property returns the total number of state variables in the group,
and HiddenGroup.varshape returns the shape of each state variable. The
HiddenParamOpRelation.primitive field identifies which ETP primitive (e.g., etp_matmul_p)
connects the weight to the hidden state.
Two-Layer RNN
With multiple recurrent layers, the graph becomes richer. Each layer introduces its own hidden group, and the compiler discovers which weights feed into which hidden groups. In a stacked RNN, each layer’s recurrent weight is associated with only its own hidden group – the layers are structurally independent from the perspective of eligibility trace propagation.
```python
class TwoLayerRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        # Two stacked GRU layers, each with its own hidden state
        self.rnn1 = braintrace.nn.GRUCell(n_in, n_rec)
        self.rnn2 = braintrace.nn.GRUCell(n_rec, n_rec)
        self.out = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        h1 = self.rnn1(x)   # layer 1 hidden state
        h2 = self.rnn2(h1)  # layer 2 reads layer 1's output from the same step
        return self.out(h2)


model2 = TwoLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model2)

algo2 = braintrace.D_RTRL(model2)
algo2.compile_graph(jnp.zeros(10))
algo2.show_graph()
```
Notice that:
Each GRU layer creates its own hidden group (the GRU hidden state h)
Each layer's recurrent and input weights are associated with that layer's hidden group
The readout weight forms its own relation if it uses an ETP primitive
This structural analysis is what allows D_RTRL to maintain separate eligibility traces
for each layer, avoiding the need to backpropagate through time across the entire network.
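The per-layer separation described above depends on where dependency analysis stops. The hypothetical toy reachability sketch that follows is not braintrace's algorithm. It makes that rule an explicit flag: when traversal stops at the first hidden state on each path, each layer's weight reaches only its own hidden state. Turning the rule off shows that rnn1.W would otherwise also reach layer 2's state through h1. The IR records and the stop_at_hidden flag are assumptions made for illustration.

```python
from collections import deque

# Hypothetical toy IR for a two-layer stack (not braintrace's format):
# (output variable, input variables, weight name or None)
OPS = [
    ("a1", ["x", "h1"], "rnn1.W"),
    ("h1_new", ["a1"], None),
    ("a2", ["h1_new", "h2"], "rnn2.W"),  # layer 2 reads layer 1's new state
    ("h2_new", ["a2"], None),
    ("y", ["h2_new"], "out.W"),
]
HIDDEN = {"h1_new": "rnn1.h", "h2_new": "rnn2.h"}


def weights_to_hidden(ops, hidden, stop_at_hidden=True):
    """Map each weight to the hidden states reachable from its output."""
    consumers = {}
    for out, ins, _ in ops:
        for v in ins:
            consumers.setdefault(v, []).append(out)

    result = {}
    for out, _, w in ops:
        if w is None:
            continue
        reached, queue, seen = set(), deque([out]), {out}
        while queue:
            v = queue.popleft()
            if v in hidden:
                reached.add(hidden[v])
                if stop_at_hidden:
                    continue  # treat each hidden state as a boundary
            for nxt in consumers.get(v, []):
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append(nxt)
        result[w] = sorted(reached)
    return result


print(weights_to_hidden(OPS, HIDDEN))
# {'rnn1.W': ['rnn1.h'], 'rnn2.W': ['rnn2.h'], 'out.W': []}
print(weights_to_hidden(OPS, HIDDEN, stop_at_hidden=False))
# {'rnn1.W': ['rnn1.h', 'rnn2.h'], 'rnn2.W': ['rnn2.h'], 'out.W': []}
```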
```python
# Inspect the two-layer graph programmatically
graph2 = algo2.graph
print(f"Number of hidden groups: {len(graph2.hidden_groups)}")
print(f"Number of weight-hidden relations: {len(graph2.hidden_param_op_relations)}")

print("\nHidden groups:")
for g in graph2.hidden_groups:
    print(f"  Group {g.index}: {g.hidden_paths}")

print("\nRelations:")
for i, r in enumerate(graph2.hidden_param_op_relations):
    groups = [g.index for g in r.hidden_groups]
    print(f"  Weight {i}: {r.weight_path} -> hidden group(s) {groups}")
```
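To check the per-layer claim at a glance, the relations can be inverted into a "hidden group index -> weight paths" index. The helper weights_by_group below is hypothetical and not part of braintrace. It relies only on attributes this tutorial already uses: graph.hidden_param_op_relations, relation.weight_path, relation.hidden_groups and group.index. So that it runs without braintrace, it is exercised on stand-in objects with made-up paths.

```python
from collections import defaultdict
from types import SimpleNamespace


def weights_by_group(graph):
    """Invert graph.hidden_param_op_relations into {group index: [weight paths]}."""
    index = defaultdict(list)
    for rel in graph.hidden_param_op_relations:
        for g in rel.hidden_groups:
            index[g.index].append(rel.weight_path)
    return dict(index)


# Stand-in objects with the same attribute names; the paths are made up.
g0, g1 = SimpleNamespace(index=0), SimpleNamespace(index=1)
fake_graph = SimpleNamespace(hidden_param_op_relations=[
    SimpleNamespace(weight_path=("rnn1", "W"), hidden_groups=[g0]),
    SimpleNamespace(weight_path=("rnn2", "W"), hidden_groups=[g1]),
])
print(weights_by_group(fake_graph))
# {0: [('rnn1', 'W')], 1: [('rnn2', 'W')]}
```

Against the real compiled graph, the same call would be weights_by_group(algo2.graph). If the layers are separated as described, each group index should list only weight paths from its own layer.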
Convolutional Network
ETP primitives also support convolutional operations via braintrace.nn.Conv2d. When a
convolutional layer feeds into a recurrent layer, the compiler discovers the connection
between the convolution kernel and the downstream hidden state. This demonstrates the
generality of the graph compilation – it works with any ETP primitive, not just matrix
multiplication.
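```python
class ConvRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumes braintrace.nn.Conv2d follows brainstate's convention: the
        # first argument is the input shape (height, width, channels), channels last
        self.conv = braintrace.nn.Conv2d((28, 28, 1), 8, kernel_size=3, padding='SAME')
        # 'SAME' padding keeps the 28x28 spatial size, so there are 28 * 28 * 8 features
        self.rnn = braintrace.nn.ValinaRNNCell(28 * 28 * 8, 64)
        self.out = braintrace.nn.Linear(64, 10)

    def update(self, x):
        # x: (28, 28, 1), a single-channel 28x28 image in channels-last layout
        features = self.conv(x).reshape(-1)
        return self.out(self.rnn(features))


model3 = ConvRNN()
brainstate.nn.init_all_states(model3)

algo3 = braintrace.D_RTRL(model3)
algo3.compile_graph(jnp.zeros((28, 28, 1)))
algo3.show_graph()
```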
In this model:
The Conv2d kernel weight is discovered as an ETP parameter because braintrace.nn.Conv2d uses the etp_conv primitive internally
The RNN's recurrent weight uses etp_matmul
Both are associated with the RNN's hidden group, since the convolution output flows into the recurrent computation
The readout Linear layer also uses an ETP primitive
This shows how the compiler traces data flow across different layer types to build the complete eligibility trace graph.
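One way to see why the same graph machinery can extend from matrix multiplication to convolution is that a convolution is linear in its kernel, just as a matmul is linear in its weight. A 1-D numpy illustration of that linearity follows. It stacks the input patches into a matrix P (the "im2col" trick), so y = P @ k, and P plays the role the input vector plays for a matmul weight. This sketch does not use braintrace, and braintrace's convolution primitive may be implemented differently.

```python
import numpy as np

# A 1-D "valid" convolution, written as cross-correlation (as deep-learning
# libraries usually do): y[i] = sum_n x[i + n] * k[n].
def conv1d_valid(x, k):
    return np.correlate(x, k, mode="valid")


def im2col_1d(x, ksize):
    """Stack every length-ksize window of x as a row."""
    n_out = len(x) - ksize + 1
    return np.stack([x[i:i + ksize] for i in range(n_out)])


rng = np.random.default_rng(0)
x = rng.normal(size=10)
k = rng.normal(size=3)

P = im2col_1d(x, len(k))   # shape (8, 3); dy/dk is exactly this matrix
y_direct = conv1d_valid(x, k)
y_matmul = P @ k
print(P.shape, np.allclose(y_direct, y_matmul))  # (8, 3) True
```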
Using compile_etrace_graph Directly
For advanced users who want to inspect the graph without wrapping the model in an algorithm
like D_RTRL, braintrace exposes the compile_etrace_graph() function directly. This
is useful for:
Debugging model structure before training
Verifying that ETP primitives are correctly placed
Building custom online learning algorithms on top of the graph
```python
model_direct = SingleLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model_direct)

# Compile the graph without constructing an online learning algorithm
graph_direct = braintrace.compile_etrace_graph(model_direct, jnp.zeros(10))

print(f"Number of hidden groups: {len(graph_direct.hidden_groups)}")
print(f"Number of relations: {len(graph_direct.hidden_param_op_relations)}")
print(f"Has perturbation: {graph_direct.hidden_perturb is not None}")

# ETraceGraph is a named tuple, so its field names are listed in _fields
print("\nGraph fields:")
for key in graph_direct._fields:
    print(f"  {key}")
```
The compile_etrace_graph() function returns the same ETraceGraph named tuple that is
stored internally by D_RTRL and other algorithms. You can use it to build custom
training loops or to programmatically analyze model structure.
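For the "verifying that ETP primitives are correctly placed" use case, a small check can be written on top of the returned graph. The helper missing_etrace_weights is hypothetical. It reads only graph.hidden_param_op_relations and relation.weight_path, and it compares paths with ==, so expected paths should be copied in the same form the graph prints them. Stand-in objects with made-up paths let it run without braintrace.

```python
from types import SimpleNamespace


def missing_etrace_weights(graph, expected_paths):
    """Return the expected weight paths that have no relation in the graph."""
    found = [r.weight_path for r in graph.hidden_param_op_relations]
    # List membership (not a set) so unhashable path objects also work.
    return [p for p in expected_paths if p not in found]


# Stand-in graph so the check runs without braintrace; paths are made up.
fake_graph = SimpleNamespace(hidden_param_op_relations=[
    SimpleNamespace(weight_path=("rnn", "W")),
])
print(missing_etrace_weights(fake_graph, [("rnn", "W"), ("out", "W")]))
# [('out', 'W')]
```

With braintrace, the same call would be made on graph_direct, using weight paths taken from the printed relations. A non-empty result flags weights that the compiler did not connect to any hidden group.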
Summary
In this tutorial, we covered the graph compilation and visualization tools in braintrace:
compile_graph() (on algorithm objects) and compile_etrace_graph() (standalone function) analyze the model's computation graph to discover the structural relationships between weights, ETP primitives, and hidden states
show_graph() provides a human-readable summary of the compiled graph, showing hidden groups, weight-hidden associations, and non-ETP parameters
The compiled graph reveals which weights participate in online learning: only weights used through ETP primitives (braintrace.nn.Linear, braintrace.nn.Conv2d, etc.) are included
Multi-layer and convolutional models create richer graph structures with multiple hidden groups and cross-layer relationships
The ETraceGraph named tuple can be inspected programmatically for custom analysis or to build custom online learning algorithms
Understanding the compiled graph is a key step in verifying that your model is correctly structured
for online learning with braintrace.