Compiler

The compiler analyzes the model’s JAX intermediate representation (Jaxpr) to discover relationships between ETP primitives, weight parameters, and hidden states. This page documents the compiler pipeline and its data structures.

Graph Compilation

The main entry point for compiling a model into an eligibility trace graph.

compile_etrace_graph

Constructs the eligibility trace graph for a given model based on the provided inputs.

ETraceGraph

The overall compiled graph for the eligibility trace.

Module Info

Extracts the Jaxpr and state information from a brainstate.nn.Module.

extract_module_info

Extracts the model information for the etrace compiler.

ModuleInfo

The model information for the etrace compiler.
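The sketch below illustrates what this stage has to produce. It assumes a toy IR in which each equation is a (primitive, inputs, outputs) tuple standing in for a Jaxpr equation. The class ToyModuleInfo and its fields are hypothetical. They show one way to record the traced program together with the role of each variable (weight, previous hidden value, new hidden value). They are not the fields of the actual ModuleInfo.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Toy stand-in for a Jaxpr equation: (primitive, input vars, output vars).
Eqn = Tuple[str, Tuple[str, ...], Tuple[str, ...]]


@dataclass
class ToyModuleInfo:
    """Hypothetical analogue of ModuleInfo: a traced program plus state roles."""
    eqns: List[Eqn]                   # traced computation, in execution order
    weight_vars: Dict[str, str]       # program var -> weight state name
    hidden_in_vars: Dict[str, str]    # program var -> hidden state (value at t-1)
    hidden_out_vars: Dict[str, str]   # program var -> hidden state (value at t)

    def role_of(self, var: str) -> str:
        """Classify a program variable by the kind of state it carries."""
        if var in self.weight_vars:
            return "weight"
        if var in self.hidden_in_vars:
            return "hidden_in"
        if var in self.hidden_out_vars:
            return "hidden_out"
        return "other"  # inputs and intermediates


# Toy RNN step: h_new = tanh(x @ W_in + h @ W_rec)
minfo = ToyModuleInfo(
    eqns=[
        ("dot", ("x", "a"), ("u",)),
        ("dot", ("h", "b"), ("v",)),
        ("add", ("u", "v"), ("z",)),
        ("tanh", ("z",), ("h_new",)),
    ],
    weight_vars={"a": "W_in", "b": "W_rec"},
    hidden_in_vars={"h": "rnn.h"},
    hidden_out_vars={"h_new": "rnn.h"},
)
print([minfo.role_of(v) for v in ("x", "a", "h", "z", "h_new")])
# ['other', 'weight', 'hidden_in', 'other', 'hidden_out']
```

The sketch records a hidden state's incoming (t-1) and outgoing (t) variables separately because the same state appears on both sides of a recurrent step.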

Hidden Groups

Groups of hidden states that are updated together in the recurrent computation.

HiddenGroup

The data structure for recording the hidden group relation.

find_hidden_groups_from_minfo

Finds the hidden groups from a ModuleInfo.

find_hidden_groups_from_module

Finds the hidden groups from a model.
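The grouping idea can be sketched on a toy IR in which each equation is a (primitive, inputs, outputs) tuple standing in for a Jaxpr equation. The grouping rule is an assumption, not taken from the library: two hidden states share a group whenever the new value of one is computed from the previous value of the other. A forward dependency pass plus union-find then yields the groups. find_groups_sketch is a hypothetical name. The sketch also ignores how a real compiler would treat dependencies that pass through ETP primitives such as a recurrent weight multiplication.

```python
from typing import Dict, List, Set, Tuple

# Toy stand-in for a Jaxpr equation: (primitive, input vars, output vars).
Eqn = Tuple[str, Tuple[str, ...], Tuple[str, ...]]


def find_groups_sketch(eqns: List[Eqn],
                       hidden_in: Dict[str, str],
                       hidden_out: Dict[str, str]) -> List[List[str]]:
    """Group hidden states whose new values read each other's previous values."""
    # Forward pass: which hidden states (at t-1) does each variable depend on?
    deps: Dict[str, Set[str]] = {var: {name} for var, name in hidden_in.items()}
    for _prim, ins, outs in eqns:
        reached: Set[str] = set()
        for var in ins:
            reached |= deps.get(var, set())
        for out in outs:
            deps[out] = set(reached)

    # Union-find over hidden state names.
    names = set(hidden_in.values()) | set(hidden_out.values())
    parent = {name: name for name in names}

    def find(name: str) -> str:
        while parent[name] != name:
            name = parent[name]
        return name

    # Merge each updated state with every previous-step state its update reads.
    for var, name in hidden_out.items():
        for src in deps.get(var, set()):
            parent[find(src)] = find(name)

    groups: Dict[str, List[str]] = {}
    for name in names:
        groups.setdefault(find(name), []).append(name)
    return sorted(sorted(g) for g in groups.values())


# Toy adaptive neuron (V and A are coupled) plus an independent state S.
eqns = [
    ("mul", ("v0", "alpha"), ("v_decay",)),
    ("sub", ("v_decay", "a0"), ("v_tmp",)),
    ("add", ("v_tmp", "I"), ("v1",)),
    ("mul", ("a0", "rho"), ("a_decay",)),
    ("add", ("a_decay", "v1"), ("a1",)),
    ("mul", ("s0", "beta"), ("s1",)),
]
groups = find_groups_sketch(
    eqns,
    hidden_in={"v0": "V", "a0": "A", "s0": "S"},
    hidden_out={"v1": "V", "a1": "A", "s1": "S"},
)
print(groups)  # [['A', 'V'], ['S']]
```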

Hidden-Parameter-Operation Relations

Connections between ETP primitives, weight parameters, and hidden states. Each relation describes: “weight W is used through ETP primitive P, and the output feeds into hidden group H.”

HiddenParamOpRelation

Connection between an ETP primitive, its trainable parameters, and hidden states.

find_hidden_param_op_relations_from_minfo

Finds ETP relations from a ModuleInfo.

find_hidden_param_op_relations_from_module

Finds ETP relations from a model.
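The relation search can be illustrated with a minimal sketch on a toy IR of (primitive, inputs, outputs) tuples in place of a Jaxpr. The sketch looks for equations whose primitive it treats as an ETP primitive (hypothetically, just dot) and that consume a weight. For each one, it propagates forward through the later equations and records which hidden states the output reaches. ToyRelation, find_relations_sketch, and ETP_PRIMITIVES are hypothetical names. The sketch also maps weights to individual hidden states rather than to hidden groups.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, List, Set, Tuple

# Toy stand-in for a Jaxpr equation: (primitive, input vars, output vars).
Eqn = Tuple[str, Tuple[str, ...], Tuple[str, ...]]

# Hypothetical: the primitives this sketch treats as ETP primitives.
ETP_PRIMITIVES = {"dot"}


@dataclass(frozen=True)
class ToyRelation:
    weight: str             # W: the weight state consumed by the primitive
    eqn_index: int          # P: position of the ETP primitive in the program
    hidden: FrozenSet[str]  # H: hidden states the primitive's output reaches


def find_relations_sketch(eqns: List[Eqn],
                          weights: Dict[str, str],
                          hidden_out: Dict[str, str]) -> List[ToyRelation]:
    relations = []
    for i, (prim, ins, outs) in enumerate(eqns):
        if prim not in ETP_PRIMITIVES:
            continue
        for var in ins:
            if var not in weights:
                continue
            # Propagate forward from this primitive's outputs
            # (equations are assumed to be in topological order).
            reached: Set[str] = set(outs)
            for _p, later_ins, later_outs in eqns[i + 1:]:
                if reached.intersection(later_ins):
                    reached.update(later_outs)
            hidden = frozenset(hidden_out[v] for v in reached if v in hidden_out)
            if hidden:  # weights that never reach a hidden state get no relation
                relations.append(ToyRelation(weights[var], i, hidden))
    return relations


eqns = [
    ("dot", ("x", "w_in"), ("u",)),
    ("dot", ("h0", "w_rec"), ("r",)),
    ("add", ("u", "r"), ("z",)),
    ("tanh", ("z",), ("h1",)),
    ("dot", ("h1", "w_out"), ("y",)),  # readout: does not feed a hidden state
]
rels = find_relations_sketch(
    eqns,
    weights={"w_in": "W_in", "w_rec": "W_rec", "w_out": "W_out"},
    hidden_out={"h1": "h"},
)
for r in rels:
    print(r.weight, r.eqn_index, sorted(r.hidden))
# W_in 0 ['h']
# W_rec 1 ['h']
```

The readout weight W_out produces no relation because its output never flows into a hidden state. This matches the description of a relation as a weight whose ETP output feeds a hidden group.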

Hidden Perturbation

Perturbation structures for computing hidden-to-hidden Jacobians (the diagonal approximation of \(\partial h^t / \partial h^{t-1}\)).

HiddenPerturbation

The hidden perturbation information.

add_hidden_perturbation_from_minfo

Adds perturbations to the hidden states recorded in a ModuleInfo and replaces the hidden states with the perturbed states.

add_hidden_perturbation_in_module

Adds perturbations to the hidden states in the given module and replaces the hidden states with the perturbed states.
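The sketch below shows what a diagonal hidden-to-hidden Jacobian is. It uses a toy update \(h^t = \tanh(W h^{t-1} + x)\) with a full matrix W and perturbs one element of \(h^{t-1}\) at a time. It then compares the central-difference estimate with the exact diagonal \((1 - (h^t_i)^2) W_{ii}\). Finite differences are a swapped-in technique chosen to keep the example dependency-free; an autodiff-based implementation would obtain the same quantities exactly. The sketch does not describe how HiddenPerturbation works, and the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def step(h_prev: Vector, x: Vector, W: Matrix) -> Vector:
    """Toy recurrent update h_t = tanh(W @ h_{t-1} + x), with a full (non-diagonal) W."""
    n = len(h_prev)
    return [math.tanh(sum(W[i][j] * h_prev[j] for j in range(n)) + x[i]) for i in range(n)]


def diag_by_perturbation(h_prev: Vector, x: Vector, W: Matrix, eps: float = 1e-5) -> Vector:
    """Estimate diag(dh_t/dh_{t-1}) by perturbing one hidden element at a time."""
    diag = []
    for i in range(len(h_prev)):
        up, down = list(h_prev), list(h_prev)
        up[i] += eps
        down[i] -= eps
        # Central difference of output i w.r.t. input i only: keeps J[i][i], drops J[i][j != i].
        diag.append((step(up, x, W)[i] - step(down, x, W)[i]) / (2 * eps))
    return diag


def diag_analytic(h_prev: Vector, x: Vector, W: Matrix) -> Vector:
    """Exact diagonal entries: (1 - h_t[i]**2) * W[i][i]."""
    h_t = step(h_prev, x, W)
    return [(1.0 - h_t[i] ** 2) * W[i][i] for i in range(len(h_t))]


W = [[0.5, -0.3], [0.2, 0.8]]
h_prev = [0.1, -0.4]
x = [0.3, 0.05]
print(diag_by_perturbation(h_prev, x, W))
print(diag_analytic(h_prev, x, W))
```

The sketch perturbs one element at a time because shifting every element of \(h^{t-1}\) together would give the row sums \(\sum_j J_{ij}\), which mix in the off-diagonal terms. When the update is elementwise, the Jacobian is itself diagonal and a single joint perturbation recovers it.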

Graph Executor

Executes the compiled graph: runs the forward pass and computes the hidden-to-weight and hidden-to-hidden Jacobians.

ETraceGraphExecutor

The eligibility trace graph executor.

ETraceVjpGraphExecutor

The eligibility trace graph executor for the VJP-based online learning algorithms.
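One standard way to use the two Jacobians an executor produces is the eligibility-trace recursion \(e^t = D^t \odot e^{t-1} + \partial h^t / \partial w\). Here \(D^t\) is the diagonal of \(\partial h^t / \partial h^{t-1}\) and the second term is the single-step hidden-to-weight Jacobian. The sketch below assumes a toy model with elementwise recurrence \(h^t = \tanh(w \odot h^{t-1} + x^t)\) and per-unit weights w. It computes both Jacobians in each forward step, carries the trace, and checks it against finite differences. The names executor_step and run_with_trace are hypothetical. The sketch is not the internal logic of ETraceGraphExecutor or ETraceVjpGraphExecutor.

```python
import math
from typing import List, Tuple

Vector = List[float]


def executor_step(h_prev: Vector, x: Vector, w: Vector) -> Tuple[Vector, Vector, Vector]:
    """Forward step h_t = tanh(w * h_{t-1} + x_t), applied elementwise.

    Returns h_t plus the (diagonal) hidden-to-hidden and hidden-to-weight Jacobians.
    """
    h = [math.tanh(wi * hp + xi) for wi, hp, xi in zip(w, h_prev, x)]
    d_hh = [(1.0 - hi ** 2) * wi for hi, wi in zip(h, w)]       # dh_t/dh_{t-1}
    d_hw = [(1.0 - hi ** 2) * hp for hi, hp in zip(h, h_prev)]  # dh_t/dw, this step only
    return h, d_hh, d_hw


def run_with_trace(xs: List[Vector], w: Vector) -> Tuple[Vector, Vector]:
    """Run the sequence forward, updating the trace e_t = d_hh * e_{t-1} + d_hw."""
    h = [0.0] * len(w)
    e = [0.0] * len(w)
    for x in xs:
        h, d_hh, d_hw = executor_step(h, x, w)
        e = [jh * ep + jw for jh, ep, jw in zip(d_hh, e, d_hw)]
    return h, e


xs = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.6], [0.2, -0.1]]
w = [0.7, -0.5]
h_T, trace = run_with_trace(xs, w)

# Check: the trace should equal dh_T[i]/dw[i], estimated here by central differences.
eps = 1e-6
fd = []
for i in range(len(w)):
    w_up, w_down = list(w), list(w)
    w_up[i] += eps
    w_down[i] -= eps
    fd.append((run_with_trace(xs, w_up)[0][i] - run_with_trace(xs, w_down)[0][i]) / (2 * eps))
print(trace)
print(fd)
```

The trace agrees with the finite-difference gradient because the toy recurrence is elementwise, so its hidden-to-hidden Jacobian really is diagonal. With a full recurrent matrix, keeping only the diagonal makes the trace an approximation.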