Graph Compilation & Visualization

In braintrace, models are compiled into an ETraceGraph – an intermediate representation that captures the structural relationships between weight parameters, ETP primitives (the operations that connect inputs to hidden states), and hidden state groups. This compilation step is what enables efficient online learning: by analyzing the computation graph, braintrace can automatically determine which weights influence which hidden states, and how eligibility traces should propagate.

The show_graph() method visualizes these relationships, providing a human-readable summary of:

  • Hidden groups: clusters of hidden states that evolve together (e.g., the membrane potential and adaptation current of a neuron population)

  • Weight-primitive-hidden connections: which weight parameters are associated with which hidden groups through which ETP primitives

  • Non-ETP weights: parameters that exist in the model but do not participate in online learning

Understanding the compiled graph is essential for debugging model structure, verifying that the correct parameters are included in online learning, and optimizing model design.
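The idea of "analyzing the computation graph to find which weights influence which hidden states" can be sketched with a toy dataflow graph in plain Python. This is only an illustration, not braintrace's compiler: the op list, the variable names (`rnn.W`, `out.W`, `scale`, `pre`, `y_lin`), and the `is_etp` flag are all hypothetical.

```python
from collections import deque

# Toy dataflow graph: each op is (name, inputs, output, is_etp).
# Every name here is hypothetical and exists only for this illustration.
OPS = [
    ("etp_matmul", ["x", "h_prev", "rnn.W"], "pre", True),
    ("tanh", ["pre"], "rnn.h", False),
    ("etp_matmul", ["rnn.h", "out.W"], "y_lin", True),
    ("mul", ["y_lin", "scale"], "y", False),
]
WEIGHTS = ["rnn.W", "out.W", "scale"]
HIDDEN = {"rnn.h"}


def reachable_hidden(start_var):
    """Return the hidden states found downstream of `start_var`."""
    found, seen, queue = [], {start_var}, deque([start_var])
    while queue:
        var = queue.popleft()
        for _, inputs, output, _ in OPS:
            if var in inputs and output not in seen:
                seen.add(output)
                if output in HIDDEN:
                    found.append(output)  # stop at a hidden state
                else:
                    queue.append(output)
    return found


def analyze():
    """Split weights into ETP weights (with their hidden states) and the rest."""
    etp, non_etp = {}, []
    for w in WEIGHTS:
        etp_ops = [op for op in OPS if w in op[1] and op[3]]
        if not etp_ops:
            non_etp.append(w)  # never used through an ETP op
            continue
        hits = []
        for op in etp_ops:
            out = op[2]
            hits += [out] if out in HIDDEN else reachable_hidden(out)
        etp[w] = hits
    return etp, non_etp


etp_relations, non_etp_weights = analyze()
print(etp_relations)
print(non_etp_weights)
```

In this toy, the recurrent weight reaches the hidden state, the readout weight goes through an ETP op but reaches no hidden state, and `scale` is classified as a non-ETP parameter.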

Single-Layer RNN

We start with the simplest case: a single recurrent layer followed by a linear readout. The ValinaRNNCell contains one hidden state and one recurrent weight, and the Linear readout has its own weight that feeds into the output.

import jax
import jax.numpy as jnp
import brainstate
import braintrace


class SingleLayerRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.rnn = braintrace.nn.ValinaRNNCell(n_in, n_rec)
        self.out = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        return self.out(self.rnn(x))


model = SingleLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model)

# Compile the graph and visualize
algo = braintrace.D_RTRL(model)
algo.compile_graph(jnp.zeros(10))
algo.show_graph()

The output shows:

  • Hidden Group 0: the hidden state of the ValinaRNNCell (path ('rnn', 'h'))

  • Weight 0: the recurrent weight inside the RNN cell, associated with Hidden Group 0

  • Weight 1: the readout weight, which may or may not appear depending on whether the readout layer uses ETP primitives

This tells us that D_RTRL will maintain an eligibility trace for the recurrent weight, tracking how it influences the hidden state over time.
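To make "tracking how a recurrent weight influences the hidden state over time" concrete, here is a scalar toy in plain Python. It is a sketch under stated assumptions, not braintrace's D_RTRL implementation: the recurrence `h_t = tanh(w * x_t + u * h_{t-1})`, the function names, and the numbers are chosen only for illustration. The trace `e_t = dh_t/du` is updated forward in time and then compared with a finite-difference estimate.

```python
import math


def run_with_trace(u, w, xs, h0=0.0):
    """Run h_t = tanh(w*x_t + u*h_{t-1}) and track e_t = dh_t/du online."""
    h, e = h0, 0.0
    for x in xs:
        h_new = math.tanh(w * x + u * h)
        # Chain rule: dh_t/du = (1 - h_t^2) * (h_{t-1} + u * dh_{t-1}/du)
        e = (1.0 - h_new ** 2) * (h + u * e)
        h = h_new
    return h, e


def final_state(u, w, xs, h0=0.0):
    """Same recurrence without the trace, used for the numerical check."""
    h = h0
    for x in xs:
        h = math.tanh(w * x + u * h)
    return h


xs = [0.5, -0.3, 0.8, 0.1]
u, w = 0.7, 1.2
h_T, trace = run_with_trace(u, w, xs)

# Central finite difference of the final state with respect to u.
eps = 1e-6
fd = (final_state(u + eps, w, xs) - final_state(u - eps, w, xs)) / (2 * eps)
print(abs(trace - fd) < 1e-6)
```

The trace is carried forward one step at a time, so no past inputs or states need to be stored to obtain the derivative at the final step.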

Understanding ETraceGraph

The compiled graph is an ETraceGraph named tuple with several key fields:

| Field | Type | Description |
| --- | --- | --- |
| module_info | ModuleInfo | Jaxpr and state mappings extracted from the model |
| hidden_groups | Sequence[HiddenGroup] | Discovered hidden state groups |
| hid_path_to_group | Dict[Path, HiddenGroup] | Mapping from hidden state path to its group |
| hidden_param_op_relations | Sequence[HiddenParamOpRelation] | Weight-primitive-hidden connections |
| hidden_perturb | HiddenPerturbation or None | Perturbation structure for Jacobian computation |

Each HiddenGroup records a cluster of hidden states that are updated together in one recurrent step. Each HiddenParamOpRelation records the connection between a weight parameter and the hidden groups it feeds into through an ETP primitive.

Let’s inspect these programmatically:

graph = algo.graph

print("=== Hidden Groups ===")
for g in graph.hidden_groups:
    print(f"  Group {g.index}: {g.num_state} state(s), shape {g.varshape}")
    print(f"    Paths: {g.hidden_paths}")

print("\n=== Weight-Primitive-Hidden Relations ===")
for i, r in enumerate(graph.hidden_param_op_relations):
    print(f"  Relation {i}:")
    print(f"    Weight path: {r.weight_path}")
    print(f"    Primitive: {r.primitive}")
    print(f"    Hidden groups: {[g.index for g in r.hidden_groups]}")

print("\n=== Perturbation ===")
print(f"  Has perturbation: {graph.hidden_perturb is not None}")

The HiddenGroup.num_state property returns the total number of state variables in the group, and HiddenGroup.varshape returns the shape of each state variable. The HiddenParamOpRelation.primitive field identifies which ETP primitive (e.g., etp_matmul_p) connects the weight to the hidden state.
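When a model has many relations, it can help to invert them into a "hidden group -> weights" index. The sketch below relies only on the attribute names used in the inspection code above (`weight_path`, `hidden_groups`, `index`). `ToyGroup`, `ToyRelation`, and the weight path `("rnn", "W")` are hypothetical stand-ins so that the example runs without braintrace.

```python
from collections import defaultdict
from typing import NamedTuple, Sequence, Tuple


class ToyGroup(NamedTuple):  # hypothetical stand-in for HiddenGroup
    index: int
    hidden_paths: Sequence[Tuple[str, ...]]


class ToyRelation(NamedTuple):  # hypothetical stand-in for HiddenParamOpRelation
    weight_path: Tuple[str, ...]
    hidden_groups: Sequence[ToyGroup]


def weights_by_group(relations):
    """Map each hidden group index to the weight paths that feed it."""
    index = defaultdict(list)
    for rel in relations:
        for group in rel.hidden_groups:
            index[group.index].append(rel.weight_path)
    return dict(index)


g0 = ToyGroup(0, [("rnn", "h")])
rels = [ToyRelation(("rnn", "W"), [g0])]
by_group = weights_by_group(rels)
print(by_group)
```

With a compiled graph, the same function could presumably be called as `weights_by_group(graph.hidden_param_op_relations)`, provided the real objects expose these attributes as the inspection code assumes.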

Two-Layer RNN

With multiple recurrent layers, the graph becomes richer. Each layer introduces its own hidden group, and the compiler discovers which weights feed into which hidden groups. In a stacked RNN, each layer’s recurrent weight is associated with only its own hidden group – the layers are structurally independent from the perspective of eligibility trace propagation.

class TwoLayerRNN(brainstate.nn.Module):
    def __init__(self, n_in, n_rec, n_out):
        super().__init__()
        self.rnn1 = braintrace.nn.GRUCell(n_in, n_rec)
        self.rnn2 = braintrace.nn.GRUCell(n_rec, n_rec)
        self.out = braintrace.nn.Linear(n_rec, n_out)

    def update(self, x):
        h1 = self.rnn1(x)
        h2 = self.rnn2(h1)
        return self.out(h2)


model2 = TwoLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model2)

algo2 = braintrace.D_RTRL(model2)
algo2.compile_graph(jnp.zeros(10))
algo2.show_graph()

Notice that:

  • Each GRU layer creates its own hidden group (the GRU hidden state h)

  • Each layer’s recurrent and input weights are associated with that layer’s hidden group

  • The readout weight forms its own relation if it uses an ETP primitive

This structural analysis is what allows D_RTRL to maintain separate eligibility traces for each layer, avoiding the need to backpropagate through time across the entire network.
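One way to check the per-layer association described above is to compare the first component of each weight path with the first component of the hidden paths it is related to. This is a minimal sketch. It assumes that paths are tuples whose first element is the submodule name (as in `('rnn', 'h')`), and it uses hypothetical `SimpleNamespace` stand-ins instead of real braintrace objects.

```python
from types import SimpleNamespace


def cross_layer_pairs(relations):
    """Return (weight_path, hidden_path) pairs whose first path component differs."""
    pairs = []
    for rel in relations:
        for group in rel.hidden_groups:
            for hid_path in group.hidden_paths:
                if rel.weight_path[0] != hid_path[0]:
                    pairs.append((rel.weight_path, hid_path))
    return pairs


# Hypothetical stand-ins shaped like the two-layer model's relations.
g1 = SimpleNamespace(index=0, hidden_paths=[("rnn1", "h")])
g2 = SimpleNamespace(index=1, hidden_paths=[("rnn2", "h")])
toy_relations = [
    SimpleNamespace(weight_path=("rnn1", "W"), hidden_groups=[g1]),
    SimpleNamespace(weight_path=("rnn2", "W"), hidden_groups=[g2]),
]
pairs = cross_layer_pairs(toy_relations)
print(pairs)
```

An empty result means every weight is tied only to hidden states of its own submodule. A non-empty result is not necessarily an error: a feed-forward layer placed in front of a recurrent layer is expected to be related to that recurrent layer's hidden group.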

# Inspect the two-layer graph programmatically
graph2 = algo2.graph

print(f"Number of hidden groups: {len(graph2.hidden_groups)}")
print(f"Number of weight-hidden relations: {len(graph2.hidden_param_op_relations)}")

print("\nHidden groups:")
for g in graph2.hidden_groups:
    print(f"  Group {g.index}: {g.hidden_paths}")

print("\nRelations:")
for i, r in enumerate(graph2.hidden_param_op_relations):
    groups = [g.index for g in r.hidden_groups]
    print(f"  Weight {i}: {r.weight_path} -> hidden group(s) {groups}")

Convolutional Network

ETP primitives also support convolutional operations via braintrace.nn.Conv2d. When a convolutional layer feeds into a recurrent layer, the compiler discovers the connection between the convolution kernel and the downstream hidden state. This demonstrates the generality of the graph compilation – it works with any ETP primitive, not just matrix multiplication.

class ConvRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = braintrace.nn.Conv2d(1, 8, kernel_size=3, padding='SAME')
        self.rnn = braintrace.nn.ValinaRNNCell(8 * 28 * 28, 64)
        self.out = braintrace.nn.Linear(64, 10)

    def update(self, x):
        # x: (1, 28, 28) -- single-channel 28x28 image
        features = self.conv(x).reshape(-1)
        return self.out(self.rnn(features))


model3 = ConvRNN()
brainstate.nn.init_all_states(model3)

algo3 = braintrace.D_RTRL(model3)
algo3.compile_graph(jnp.zeros((1, 28, 28)))
algo3.show_graph()

In this model:

  • The Conv2d kernel weight is discovered as an ETP parameter because braintrace.nn.Conv2d uses the etp_conv primitive internally

  • The RNN’s recurrent weight uses etp_matmul

  • Both are associated with the RNN’s hidden group, since the convolution output flows into the recurrent computation

  • The readout Linear layer also uses an ETP primitive

This shows how the compiler traces data flow across different layer types to build the complete eligibility trace graph.
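The RNN input size `8 * 28 * 28` in the model above follows from the convolution's output shape. The helper below is a small worked calculation, assuming that `padding='SAME'` follows the usual convention in which each spatial dimension becomes `ceil(size / stride)`, and that the layer uses stride 1. The function name is hypothetical.

```python
import math


def conv2d_same_flat_size(out_channels, height, width, stride=1):
    """Flattened feature count of a 2D convolution with 'SAME' padding."""
    out_h = math.ceil(height / stride)
    out_w = math.ceil(width / stride)
    return out_channels * out_h * out_w


n_features = conv2d_same_flat_size(8, 28, 28)
print(n_features)  # 6272
```

If the kernel size, stride, or padding mode changes, the first argument of the RNN cell has to change accordingly, otherwise the reshape and the recurrent layer will disagree on the feature dimension.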

Using compile_etrace_graph Directly

For advanced users who want to inspect the graph without wrapping the model in an algorithm like D_RTRL, braintrace exposes the compile_etrace_graph() function directly. This is useful for:

  • Debugging model structure before training

  • Verifying that ETP primitives are correctly placed

  • Building custom online learning algorithms on top of the graph

model_direct = SingleLayerRNN(10, 32, 5)
brainstate.nn.init_all_states(model_direct)

graph_direct = braintrace.compile_etrace_graph(model_direct, jnp.zeros(10))

print(f"Number of hidden groups: {len(graph_direct.hidden_groups)}")
print(f"Number of relations: {len(graph_direct.hidden_param_op_relations)}")
print(f"Has perturbation: {graph_direct.hidden_perturb is not None}")

print("\nGraph fields:")
for key in graph_direct.dict().keys():
    print(f"  {key}")

The compile_etrace_graph() function returns the same ETraceGraph named tuple that is stored internally by D_RTRL and other algorithms. You can use it to build custom training loops or to programmatically analyze model structure.
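As an example of analyzing model structure before training, the sketch below reports which expected weights do not appear in any relation, which would suggest they are not used through an ETP primitive. It relies only on the `hidden_param_op_relations` and `weight_path` names used in this tutorial. The helper name, the toy graph, and the weight paths are hypothetical.

```python
from types import SimpleNamespace


def missing_etp_weights(graph, expected_paths):
    """Return the expected weight paths that appear in no relation of `graph`."""
    found = {tuple(rel.weight_path) for rel in graph.hidden_param_op_relations}
    return [p for p in expected_paths if tuple(p) not in found]


# Hypothetical stand-in for a compiled graph with a single relation.
toy_graph = SimpleNamespace(
    hidden_param_op_relations=[SimpleNamespace(weight_path=("rnn", "W"))]
)
missing = missing_etp_weights(toy_graph, [("rnn", "W"), ("out", "W")])
print(missing)
```

A check like this could run right after compilation, so that a weight accidentally left out of online learning is noticed before any training step is taken.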

Summary

In this tutorial, we covered the graph compilation and visualization tools in braintrace:

  • compile_graph() (on algorithm objects) and compile_etrace_graph() (standalone function) analyze the model’s computation graph to discover the structural relationships between weights, ETP primitives, and hidden states

  • show_graph() provides a human-readable summary of the compiled graph, showing hidden groups, weight-hidden associations, and non-ETP parameters

  • The compiled graph reveals which weights participate in online learning – only weights used through ETP primitives (braintrace.nn.Linear, braintrace.nn.Conv2d, etc.) are included

  • Multi-layer and convolutional models create richer graph structures with multiple hidden groups and cross-layer relationships

  • The ETraceGraph named tuple can be inspected programmatically for custom analysis or to build custom online learning algorithms

Understanding the compiled graph is a key step in verifying that your model is correctly structured for online learning with braintrace.