ModuleInfo
- class braintrace.ModuleInfo(stateful_model: StatefulFunction, closed_jaxpr: ClosedJaxpr, retrieved_model_states: FlattedDict[Tuple[str, ...], State], compiled_model_states: Sequence[State], state_id_to_path: Dict[int, Tuple[str, ...]], state_tree_invars: PyTree[jax._src.core.Var], state_tree_outvars: PyTree[jax._src.core.Var], hidden_path_to_invar: Dict[Tuple[str, ...], Var], hidden_path_to_outvar: Dict[Tuple[str, ...], Var], invar_to_hidden_path: Dict[Var, Tuple[str, ...]], outvar_to_hidden_path: Dict[Var, Tuple[str, ...]], hidden_outvar_to_invar: Dict[Var, Var], weight_invars: List[Var], weight_path_to_invars: Dict[Tuple[str, ...], List[Var]], invar_to_weight_path: Dict[Var, Tuple[str, ...]], num_var_out: int, num_var_state: int)
The model information used by the etrace compiler.
It contains at least five categories of information:
The stateful model.
  - stateful_model: The StatefulFunction that compiles the model into its abstract jaxpr representation.
The jaxpr, i.e. the abstract representation of the model.
  - closed_jaxpr: The closed jaxpr representation of the model.
The states.
  - retrieved_model_states: The model states retrieved from the model.states() function, which have well-defined paths and structures.
  - compiled_model_states: The model states obtained by compiling the stateful model. They are accurate and consistent with the model jaxpr, but lose the path information.
  - state_id_to_path: The mapping from the state id to the state path.
The hidden states.
  - hidden_path_to_invar: The mapping from the hidden-state path to its input variable.
  - hidden_path_to_outvar: The mapping from the hidden-state path to its output variable.
  - invar_to_hidden_path: The mapping from the input variable to the hidden-state path.
  - outvar_to_hidden_path: The mapping from the output variable to the hidden-state path.
  - hidden_outvar_to_invar: The mapping from each hidden-state output variable to the input variable of the same hidden state.
The parameter weights.
  - weight_invars: The weight input variables.
  - weight_path_to_invars: The mapping from the weight path to the input variables.
  - invar_to_weight_path: The mapping from the input variable to the weight path.
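The hidden-state mappings above are plain dictionaries between state paths and jaxpr variables. The sketch below builds the same kind of mappings by hand for a toy step function traced with jax.make_jaxpr. The function `step`, its state paths ("w",) and ("h",), and the rule "every state except ("w",) is hidden" are assumptions for illustration only. This is not how braintrace classifies states.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: states are keyed by path tuples, like the paths above.
def step(states, x):
    w, h = states[("w",)], states[("h",)]
    # The weight is passed through unchanged; the hidden state is updated.
    return {("w",): w, ("h",): jnp.tanh(w @ h + x)}

states = {("w",): jnp.ones((3, 3)), ("h",): jnp.zeros(3)}
x = jnp.ones(3)
closed_jaxpr = jax.make_jaxpr(step)(states, x)

# JAX flattens dicts in sorted-key order, so the sorted paths line up with the
# leading invars (the input states) and with the outvars (the returned states).
paths = sorted(states)
state_invars = closed_jaxpr.jaxpr.invars[: len(paths)]
state_outvars = closed_jaxpr.jaxpr.outvars

# Assumption for this sketch: every state except the weight is a hidden state.
hidden_paths = [p for p in paths if p != ("w",)]

hidden_path_to_invar = {p: v for p, v in zip(paths, state_invars) if p in hidden_paths}
hidden_path_to_outvar = {p: v for p, v in zip(paths, state_outvars) if p in hidden_paths}
invar_to_hidden_path = {v: p for p, v in hidden_path_to_invar.items()}
outvar_to_hidden_path = {v: p for p, v in hidden_path_to_outvar.items()}
# Pair each hidden output variable with the input variable of the same state.
hidden_outvar_to_invar = {
    hidden_path_to_outvar[p]: hidden_path_to_invar[p] for p in hidden_paths
}
```

The key assumption is the sorted-key flattening order of dict pytrees. With it, path information is recovered without inspecting the jaxpr equations.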
Example:
>>> import braintrace
>>> import brainstate
>>> gru = braintrace.nn.GRUCell(10, 20)
>>> gru.init_state()
>>> inputs = brainstate.random.randn(10)
>>> module_info = braintrace.extract_module_info(gru, inputs)
- add_jaxpr_outs(jax_vars)
Add the given jax variables to the outputs of the model jaxpr, so that it also returns the additional variables needed by the etrace compiler.
- Return type:
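add_jaxpr_outs exposes intermediate variables as extra outputs of the model jaxpr. The sketch below gets a similar effect without rewriting the jaxpr. It interprets the jaxpr equation by equation, following the pattern of JAX's "Writing custom Jaxpr interpreters" tutorial, and then reads the requested variables out of the environment. `eval_with_extra_outs` and `cell` are hypothetical names. The sketch assumes a flat jaxpr whose equations can be re-bound eagerly. It is not braintrace's implementation.

```python
import jax
import jax.numpy as jnp

def eval_with_extra_outs(closed_jaxpr, args, extra_vars):
    """Evaluate a ClosedJaxpr and also return the values of `extra_vars` (hypothetical helper)."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value inline; Vars are looked up in the environment.
        return v.val if hasattr(v, "val") else env[v]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val
    for eqn in jaxpr.eqns:
        # Re-apply each primitive to concrete values, as in the JAX interpreter tutorial.
        outs = eqn.primitive.bind(*[read(v) for v in eqn.invars], **eqn.params)
        if not eqn.primitive.multiple_results:
            outs = [outs]
        for var, val in zip(eqn.outvars, outs):
            env[var] = val

    # Regular outputs first, then the extra variables that were requested.
    return [read(v) for v in jaxpr.outvars], [read(v) for v in extra_vars]

# Usage: expose the pre-activation of a toy cell as an additional output.
def cell(w, h, x):
    pre = w @ h + x
    return jnp.tanh(pre)

w, h, x = jnp.eye(2), jnp.array([1.0, -1.0]), jnp.array([0.5, 0.5])
closed = jax.make_jaxpr(cell)(w, h, x)
pre_var = closed.jaxpr.eqns[-1].invars[0]  # the input of the final tanh equation
outs, extras = eval_with_extra_outs(closed, (w, h, x), [pre_var])
```

Rewriting the jaxpr's outvars, as the method name suggests, lets the extended model be compiled once and reused. The interpreter above trades that efficiency for simplicity.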
- closed_jaxpr: ClosedJaxpr
Alias for field number 1
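- hidden_outvar_to_invar: Dict[Var, Var]
Alias for field number 11
- hidden_path_to_invar: Dict[Tuple[str, ...], Var]
Alias for field number 7
- hidden_path_to_outvar: Dict[Tuple[str, ...], Var]
Alias for field number 8
- invar_to_hidden_path: Dict[Var, Tuple[str, ...]]
Alias for field number 9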
- property jaxpr: Jaxpr
The jaxpr of the model.
- jaxpr_call(*args, old_state_vals=None)
Compute the model on the given inputs and parameters using the compiled jaxpr.
- outvar_to_hidden_path: Dict[Var, Tuple[str, ...]]
Alias for field number 10
- split_state_outvars()
Split the state outvars into three parts: weight, hidden, and other states.
- Returns:
weight_jaxvar_tree: The weight tree of jax Var.
hidden_jaxvar: The hidden tree of jax Var.
other_state_jaxvar_tree: The other state tree of jax Var.
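split_state_outvars partitions the state output variables into weight, hidden and other groups. The sketch below assumes each state path has already been classified, here by two hypothetical path sets passed to a hypothetical helper `split_by_path`. Under that assumption the split reduces to a dictionary partition over the pairing of paths and outvars. The real method returns trees of jax Var rather than flat dicts.

```python
import jax
import jax.numpy as jnp

def split_by_path(path_to_outvar, weight_paths, hidden_paths):
    """Partition a {path: outvar} mapping into weight, hidden and other groups (hypothetical helper)."""
    weight, hidden, other = {}, {}, {}
    for path, var in path_to_outvar.items():
        if path in weight_paths:
            weight[path] = var
        elif path in hidden_paths:
            hidden[path] = var
        else:
            other[path] = var
    return weight, hidden, other

# Toy step with one weight, one hidden state and one step counter (assumed paths).
def step(states):
    return {
        ("w",): states[("w",)],
        ("h",): jnp.tanh(states[("w",)] @ states[("h",)]),
        ("count",): states[("count",)] + 1,
    }

states = {("w",): jnp.eye(2), ("h",): jnp.ones(2), ("count",): jnp.array(0)}
closed_jaxpr = jax.make_jaxpr(step)(states)
# The returned dict is flattened in sorted-key order, which fixes the outvar order.
path_to_outvar = dict(zip(sorted(states), closed_jaxpr.jaxpr.outvars))
weight, hidden, other = split_by_path(path_to_outvar, {("w",)}, {("h",)})
```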
- state_tree_invars: PyTree[jax._src.core.Var]
Alias for field number 5
- state_tree_outvars: PyTree[jax._src.core.Var]
Alias for field number 6
- stateful_model: StatefulFunction
Alias for field number 0
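"Alias for field number N" is the docstring that Python's collections.namedtuple, and therefore typing.NamedTuple, generates for each field. This suggests that ModuleInfo is a named tuple whose field numbers follow the order of the constructor signature. The sketch below shows what that implies, using a hypothetical, reduced stand-in named `MiniModuleInfo` that has only the first three fields: attribute access and positional indexing return the same object.

```python
from typing import Any, NamedTuple

class MiniModuleInfo(NamedTuple):
    """Hypothetical stand-in for ModuleInfo with only its first three fields."""
    stateful_model: Any
    closed_jaxpr: Any
    retrieved_model_states: Any

info = MiniModuleInfo("model", "jaxpr", {})

# Field 1 is closed_jaxpr, so both access styles return the same object.
same = info.closed_jaxpr is info[1]

# Named tuples are immutable; _replace returns an updated copy instead.
updated = info._replace(closed_jaxpr="new jaxpr")
```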