HiddenParamOpRelation
- class braintrace.HiddenParamOpRelation(primitive: Primitive, x_var: Var | None, y_var: Var, hidden_groups: List[HiddenGroup], y_to_hidden_group_jaxprs: List[Jaxpr], connected_hidden_paths: List[Tuple[str, ...]], eqn_params: dict, path_classification: Dict[Tuple[str, ...], str] = {}, trainable_vars: Dict[str, Var] = {}, trainable_paths: Dict[str, Tuple[str, ...]] = {}, trainable_leaf_indices: Dict[str, int] = {}, trainable_param_states: Dict[str, ParamState] = {}, trainable_processing_chains: Dict[str, Tuple[Primitive, ...]] = {})
Connection between an ETP primitive, its trainable parameters, and hidden states.
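It records the structural relationship
\[h^t = f(y), \quad y = \text{primitive}(x, \theta),\]
where \(\theta\) denotes the trainable parameters and \(f\) is the transition from the primitive output \(y\) to the hidden state \(h^t\).
- primitive
The JAX primitive (etp_mm_p, etp_mv_p, etc.).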
- x_var
Jaxpr Var for the input (None for element-wise ops).
- y_var
Jaxpr Var for the primitive output.
- hidden_groups
Hidden groups that this op feeds into.
- y_to_hidden_group_jaxprs
Transition Jaxprs from y to the hidden groups, one per hidden group.
- connected_hidden_paths
Hidden-state paths connected to this op.
- eqn_params
Static parameters of the primitive equation.
- path_classification
A {hidden_path: PathClassification.*} mapping with one entry for each connected hidden state. Populated by the path-classification pass.
- trainable_vars
Per-key dict mapping a primitive-chosen key name (e.g. 'weight', 'bias', 'lora_b', 'lora_a') to its jaxpr Var. Populated by the compiler with one entry per declared trainable input.
- trainable_paths
Per-key dict mapping each key to the module path of the owning ParamState. When two of the primitive's keys have invars that trace to the same ParamState (e.g. a merged {weight, bias} Linear), their entries share a path.
- trainable_leaf_indices
Per-key dict mapping each key to its leaf index in jax.tree.leaves of the owning ParamState.
- trainable_param_states
Per-key dict mapping each key to the actual ParamState object.
- trainable_processing_chains
Per-key dict mapping each key to its backward-trace processing chain: the primitives traversed from the trainable invar back to the originating ParamState invar.
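The relationship above and the per-key trainable dicts can be illustrated with a plain-Python sketch. It assumes a merged Linear whose single ParamState value is a dict {'weight', 'bias'}, uses a matmul-plus-bias function as a stand-in for the primitive and tanh as a stand-in for the transition f, and replaces JAX arrays with lists. The module path ('layer',) and the helpers primitive, f and leaf_index are hypothetical; trainable_vars and trainable_processing_chains are not modeled. Leaf indices follow JAX's convention of flattening dicts in sorted-key order.

```python
import math

# Hypothetical ParamState value for a merged Linear: one state holding both leaves.
param_value = {
    "weight": [[0.5, -1.0], [0.25, 0.75]],  # shape (in=2, out=2)
    "bias": [0.1, -0.2],
}
module_path = ("layer",)  # hypothetical module path of the owning ParamState


def leaf_index(tree: dict, key: str) -> int:
    """Position of `key` among a flat dict's leaves, mimicking JAX's sorted-key flattening."""
    return sorted(tree).index(key)


def primitive(x, theta):
    """Stand-in for a matmul-style primitive: y = x @ W + b."""
    w, b = theta["weight"], theta["bias"]
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def f(y):
    """Stand-in element-wise transition from the primitive output y to h^t."""
    return [math.tanh(v) for v in y]


x = [1.0, 2.0]
y = primitive(x, param_value)  # y = primitive(x, theta)
h = f(y)                       # h^t = f(y)

# Both keys trace back to the same ParamState, so they share one module path
# but occupy different leaf positions inside it.
trainable_paths = {k: module_path for k in ("weight", "bias")}
trainable_leaf_indices = {k: leaf_index(param_value, k) for k in ("weight", "bias")}
print(trainable_leaf_indices)  # {'weight': 1, 'bias': 0}
```

Keying everything by name rather than by position lets a primitive with several trainable inputs (e.g. LoRA's 'lora_a' and 'lora_b') record each input's origin independently, even when several keys resolve to one ParamState.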
- eqn_params: dict
Alias for field number 6
- hidden_groups: List[HiddenGroup]
Alias for field number 3
- primitive: Primitive
Alias for field number 0
- y_to_hidden_groups(y_val, const_vals, concat_hidden_vals=True)
Evaluate the transition jaxprs, mapping y to the hidden-group values.
- y_var: Var
Alias for field number 2
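The "Alias for field number N" entries are the docstrings Python generates for named-tuple fields, so each attribute is also reachable by position. The sketch below uses a hypothetical stand-in class, OpRelationSketch, that keeps only the first seven documented fields in their documented order, with strings in place of Primitive and Var objects. Its y_to_hidden_groups method is a simplified stand-in in which each transition Jaxpr is replaced by a Python callable; const_vals and concat_hidden_vals are not modeled.

```python
from typing import Callable, List, NamedTuple, Optional


class OpRelationSketch(NamedTuple):
    """Hypothetical stand-in mirroring the first seven documented fields, in order."""
    primitive: str                                           # field 0
    x_var: Optional[str]                                     # field 1
    y_var: str                                               # field 2
    hidden_groups: List[str]                                 # field 3
    y_to_hidden_group_jaxprs: List[Callable[[float], float]] # field 4
    connected_hidden_paths: List[tuple]                      # field 5
    eqn_params: dict                                         # field 6

    def y_to_hidden_groups(self, y_val: float) -> List[float]:
        # Apply one transition per hidden group to the same primitive output y.
        return [fn(y_val) for fn in self.y_to_hidden_group_jaxprs]


rel = OpRelationSketch(
    primitive="etp_mm_p",
    x_var="x",
    y_var="y",
    hidden_groups=["group0", "group1"],
    y_to_hidden_group_jaxprs=[lambda y: 2.0 * y, lambda y: y + 1.0],
    connected_hidden_paths=[("cell", "h"), ("cell", "c")],
    eqn_params={},
)

print(rel[6] is rel.eqn_params)     # True: positional and named access agree
print(rel.y_to_hidden_groups(3.0))  # [6.0, 4.0]
```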