ModuleInfo

class braintrace.ModuleInfo(stateful_model: StatefulFunction, closed_jaxpr: ClosedJaxpr, retrieved_model_states: FlattedDict[Tuple[str, ...], State], compiled_model_states: Sequence[State], state_id_to_path: Dict[int, Tuple[str, ...]], state_tree_invars: PyTree[jax._src.core.Var], state_tree_outvars: PyTree[jax._src.core.Var], hidden_path_to_invar: Dict[Tuple[str, ...], Var], hidden_path_to_outvar: Dict[Tuple[str, ...], Var], invar_to_hidden_path: Dict[Var, Tuple[str, ...]], outvar_to_hidden_path: Dict[Var, Tuple[str, ...]], hidden_outvar_to_invar: Dict[Var, Var], weight_invars: List[Var], weight_path_to_invars: Dict[Tuple[str, ...], List[Var]], invar_to_weight_path: Dict[Var, Tuple[str, ...]], num_var_out: int, num_var_state: int)

Model information used by the etrace compiler.

The model information contains at least five categories of information:

  1. The stateful model.

    • stateful_model: The stateful function that compiles the model into its abstract jaxpr representation.

  2. The jaxpr.

    The jaxpr is the abstract representation of the model.

    • closed_jaxpr: The closed jaxpr of the model, i.e. its jaxpr together with the constant values it closes over.

  3. The states.

    • retrieved_model_states: The model states retrieved through the model.states() function, which have well-defined paths and structures.

    • compiled_model_states: The model states collected when compiling the stateful model. They are accurate and consistent with the model jaxpr, but lose the path information.

    • state_id_to_path: The mapping from the state id to the state path.

  4. The hidden states.

    • hidden_path_to_invar: The mapping from the hidden path to the input variable.

    • hidden_path_to_outvar: The mapping from the hidden path to the output variable.

    • invar_to_hidden_path: The mapping from the input variable to the hidden path.

    • outvar_to_hidden_path: The mapping from the output variable to the hidden path.

    • hidden_outvar_to_invar: The mapping from each hidden-state output variable to the corresponding hidden-state input variable.

  5. The parameter weights.

    • weight_invars: The weight input variables.

    • weight_path_to_invars: The mapping from the weight path to the input variables.

    • invar_to_weight_path: The mapping from the input variable to the weight path.
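The relationships among these fields can be sketched with plain Python containers. This is an illustration under stated assumptions, not braintrace code: strings stand in for jax Var objects, the hypothetical FakeState class stands in for a brainstate State, and we assume that state_id_to_path is keyed by the Python id() of each state object.

```python
from typing import Dict, List, Tuple

Path = Tuple[str, ...]


class FakeState:
    # Hypothetical stand-in for a brainstate State object
    def __init__(self, name: str):
        self.name = name


h = FakeState("h")
w = FakeState("w")

# retrieved_model_states: path -> state, as obtained from model.states()
retrieved_model_states: Dict[Path, FakeState] = {
    ("gru", "h"): h,
    ("gru", "Wz", "weight"): w,
}

# compiled_model_states: the same state objects in jaxpr order, without paths
compiled_model_states: List[FakeState] = [w, h]

# state_id_to_path: ASSUMED here to be keyed by id() of each state
state_id_to_path: Dict[int, Path] = {
    id(st): path for path, st in retrieved_model_states.items()
}

# Recover a path for every compiled state by object identity
compiled_paths = [state_id_to_path[id(st)] for st in compiled_model_states]

# Hidden-state maps: path -> var (strings stand in for jax Vars), plus inverses
hidden_path_to_invar: Dict[Path, str] = {("gru", "h"): "a"}
hidden_path_to_outvar: Dict[Path, str] = {("gru", "h"): "b"}
invar_to_hidden_path = {v: p for p, v in hidden_path_to_invar.items()}
outvar_to_hidden_path = {v: p for p, v in hidden_path_to_outvar.items()}

# Link each hidden output variable to the input variable of the same state
hidden_outvar_to_invar = {
    hidden_path_to_outvar[p]: hidden_path_to_invar[p] for p in hidden_path_to_invar
}

print(compiled_paths)
print(hidden_outvar_to_invar)
```

The last lookup shows why a path map is useful: compiled_model_states carries no paths, so paths are recovered by matching state identity.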

Example:

>>> import braintrace
>>> import brainstate
>>> gru = braintrace.nn.GRUCell(10, 20)
>>> gru.init_state()
>>> inputs = brainstate.random.randn(10)
>>> module_info = braintrace.extract_module_info(gru, inputs)
add_jaxpr_outs(jax_vars)

Add the given jax variables to the outputs of the model jaxpr, so that it also returns the additional variables needed by the etrace compiler.

Return type:

ModuleInfo
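Because add_jaxpr_outs returns a ModuleInfo, one plausible reading is that it produces a new record with an extended jaxpr rather than mutating the existing one. The toy sketch below illustrates that pattern with NamedTuple._replace; ToyJaxpr, ToyInfo and add_outs are hypothetical names, and this is not the braintrace implementation.

```python
from typing import List, NamedTuple


class ToyJaxpr(NamedTuple):
    # Hypothetical stand-in for a jaxpr: only the output variables matter here
    outvars: List[str]


class ToyInfo(NamedTuple):
    # Hypothetical one-field stand-in for ModuleInfo
    jaxpr: ToyJaxpr


def add_outs(info: ToyInfo, extra_vars: List[str]) -> ToyInfo:
    # Build a new jaxpr whose outputs are the original ones followed by
    # extra_vars, then return a new record; the input record is not modified.
    new_jaxpr = info.jaxpr._replace(outvars=list(info.jaxpr.outvars) + list(extra_vars))
    return info._replace(jaxpr=new_jaxpr)


info = ToyInfo(ToyJaxpr(["y", "h_new"]))
extended = add_outs(info, ["t1", "t2"])
print(extended.jaxpr.outvars)  # ['y', 'h_new', 't1', 't2']
print(info.jaxpr.outvars)      # ['y', 'h_new']
```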

closed_jaxpr: ClosedJaxpr

Alias for field number 1
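The "Alias for field number N" entries are the docstrings Python generates for named-tuple fields, which suggests ModuleInfo is a NamedTuple whose N-th positional element is the named field. The stand-in Record below is hypothetical (not ModuleInfo itself) and only shows that both access styles return the same object.

```python
from typing import NamedTuple


class Record(NamedTuple):
    # Hypothetical record mirroring the first three ModuleInfo field names
    stateful_model: str
    closed_jaxpr: str
    retrieved_model_states: dict


rec = Record("model", "jaxpr", {})

# Field number 1 can be read by position or by name
print(rec[1], rec.closed_jaxpr)        # jaxpr jaxpr
print(Record.closed_jaxpr.__doc__)     # Alias for field number 1
```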

compiled_model_states: Sequence[State]

Alias for field number 3

hidden_outvar_to_invar: Dict[Var, Var]

Alias for field number 11

hidden_path_to_invar: Dict[Tuple[str, ...], Var]

Alias for field number 7

hidden_path_to_outvar: Dict[Tuple[str, ...], Var]

Alias for field number 8

invar_to_hidden_path: Dict[Var, Tuple[str, ...]]

Alias for field number 9

invar_to_weight_path: Dict[Var, Tuple[str, ...]]

Alias for field number 14

property jaxpr: Jaxpr

The jaxpr of the model.

jaxpr_call(*args, old_state_vals=None)

Evaluate the model on the given inputs and parameters using the compiled jaxpr.

Parameters:
  • args (Any) – The inputs of the model.

  • old_state_vals (Sequence[Array] | None) – The old state values.

Returns:

A 4-tuple containing, in order:

  • The output of the model.

  • etrace_vals – The values of the etrace states.

  • oth_state_vals – The values of the other states.

  • temps – The temporary intermediate values.

Return type:

Tuple[Any, Dict[Tuple[str, ...], Any], Dict[Tuple[str, ...], Any], Dict[Var, Array]]
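The four return values are typically unpacked positionally. With a real ModuleInfo the call would presumably be module_info.jaxpr_call(inputs); to keep the sketch runnable without braintrace, the hypothetical fake_jaxpr_call below only mimics the documented return layout, with a string standing in for the jax Var keys of temps.

```python
from typing import Any, Dict, Tuple

Path = Tuple[str, ...]


def fake_jaxpr_call(x: float) -> Tuple[Any, Dict[Path, Any], Dict[Path, Any], Dict[str, Any]]:
    # Hypothetical stand-in with the same 4-tuple layout as jaxpr_call:
    # (model output, etrace state values, other state values, temporaries)
    h_new = 0.5 * x
    return h_new, {("gru", "h"): h_new}, {("counter",): 1}, {"t0": x * x}


out, etrace_vals, oth_state_vals, temps = fake_jaxpr_call(2.0)

# State values are keyed by their path tuples
for path, val in etrace_vals.items():
    print("/".join(path), val)  # gru/h 1.0
```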

num_var_out: int

Alias for field number 15

num_var_state: int

Alias for field number 16

outvar_to_hidden_path: Dict[Var, Tuple[str, ...]]

Alias for field number 10

retrieved_model_states: FlattedDict[Tuple[str, ...], State]

Alias for field number 2

split_state_outvars()

Split the state output variables into three parts: weight, hidden, and other states.

Returns:

A 3-tuple containing, in order:

  • weight_jaxvar_tree – The weight tree of jax Var.

  • hidden_jaxvar – The hidden tree of jax Var.

  • other_state_jaxvar_tree – The other state tree of jax Var.
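The three-way split can be pictured as a partition of state paths. This sketch makes simplifying assumptions: it works on a flat dict from path to output variable (strings stand in for jax Var) instead of a tree, and it assumes weight and hidden paths are known up front, for example from the keys of weight_path_to_invars and hidden_path_to_outvar. split_by_kind is a hypothetical helper, not part of braintrace.

```python
from typing import Dict, Set, Tuple

Path = Tuple[str, ...]


def split_by_kind(
    state_outvars: Dict[Path, str],
    weight_paths: Set[Path],
    hidden_paths: Set[Path],
) -> Tuple[Dict[Path, str], Dict[Path, str], Dict[Path, str]]:
    # Hypothetical partition: every state path lands in exactly one bucket,
    # with weights checked first, then hidden states, then everything else.
    weights: Dict[Path, str] = {}
    hiddens: Dict[Path, str] = {}
    others: Dict[Path, str] = {}
    for path, var in state_outvars.items():
        if path in weight_paths:
            weights[path] = var
        elif path in hidden_paths:
            hiddens[path] = var
        else:
            others[path] = var
    return weights, hiddens, others


state_outvars = {("gru", "W"): "a", ("gru", "h"): "b", ("step",): "c"}
w, h, o = split_by_kind(state_outvars, {("gru", "W")}, {("gru", "h")})
print(w, h, o)
```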

state_id_to_path: Dict[int, Tuple[str, ...]]

Alias for field number 4

state_tree_invars: PyTree[jax._src.core.Var]

Alias for field number 5

state_tree_outvars: PyTree[jax._src.core.Var]

Alias for field number 6

stateful_model: StatefulFunction

Alias for field number 0

weight_invars: List[Var]

Alias for field number 12

weight_path_to_invars: Dict[Tuple[str, ...], List[Var]]

Alias for field number 13