HiddenParamOpRelation

class braintrace.HiddenParamOpRelation(primitive: Primitive, x_var: Var | None, y_var: Var, hidden_groups: List[HiddenGroup], y_to_hidden_group_jaxprs: List[Jaxpr], connected_hidden_paths: List[Tuple[str, ...]], eqn_params: dict, path_classification: Dict[Tuple[str, ...], str] = {}, trainable_vars: Dict[str, Var] = {}, trainable_paths: Dict[str, Tuple[str, ...]] = {}, trainable_leaf_indices: Dict[str, int] = {}, trainable_param_states: Dict[str, ParamState] = {}, trainable_processing_chains: Dict[str, Tuple[Primitive, ...]] = {})

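Describes the connection between an ETP primitive, its trainable parameters, and the hidden states it feeds.

Records the structural relationship below, where \(x\) is the primitive's input, \(\theta\) its trainable parameters, \(y\) its output, and \(f\) the transition from \(y\) to the hidden state \(h^t\):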

\[h^t = f(y), \quad y = \text{primitive}(x, \theta)\]
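To see how the fields below line up with a Jaxpr, here is a minimal sketch in plain JAX. It assumes `lax.dot` (which produces the `dot_general` primitive) as a stand-in for an ETP primitive such as etp_mm_p, and `lax.tanh` as the transition \(f\); the function `step` is hypothetical and none of braintrace's own machinery is used.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x, w):
    # y = primitive(x, theta): lax.dot stands in for an ETP primitive (assumption)
    y = lax.dot(x, w)
    # h^t = f(y): the kind of transition a y_to_hidden_group_jaxprs entry would hold
    return lax.tanh(y)

closed = jax.make_jaxpr(step)(jnp.ones((2, 3)), jnp.ones((3, 4)))

# The equation playing the role of `primitive`.
eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "dot_general")
x_var, w_var = eqn.invars   # analogues of x_var and a trainable input ('weight')
y_var = eqn.outvars[0]      # analogue of y_var
eqn_params = eqn.params     # analogue of eqn_params (static equation parameters)

print(eqn.primitive.name, y_var.aval.shape)   # dot_general (2, 4)
print("dimension_numbers" in eqn_params)      # True
```

Because `x_var` and `y_var` are the same Var objects that other equations reference, the relation can point at exact positions in the traced program rather than at values.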

primitive

The JAX primitive for this op (e.g. etp_mm_p, etp_mv_p).

x_var

Jaxpr Var for the primitive's input (None for element-wise ops).

y_var

Jaxpr Var for the primitive's output.

hidden_groups

Hidden groups that this op feeds into.

y_to_hidden_group_jaxprs

Transition Jaxprs from y to each hidden group (one per group).

connected_hidden_paths

Hidden-state paths connected to this op.

eqn_params

Static parameters of the primitive equation.

path_classification

Maps each connected hidden-state path to its classification, i.e. {hidden_path: PathClassification.*}. Populated by the path-classification pass.

trainable_vars

Per-key dict mapping a primitive-chosen key name (e.g. 'weight', 'bias', 'lora_b', 'lora_a') to its Jaxpr Var. Populated by the compiler with one entry per declared trainable input.

trainable_paths

Per-key dict mapping each key to the module path of the owning ParamState. When two of the primitive's keys have invars that trace back to the same ParamState (e.g. a Linear layer whose weight and bias are merged into one ParamState), those entries share the same path.

trainable_leaf_indices

Per-key dict mapping each key to the leaf index in jax.tree.leaves of the owning ParamState.
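As a sketch of how these per-key dicts fit together in the merged {weight, bias} Linear case, the example below uses a plain dict as a hypothetical stand-in for the ParamState's value, assuming the ParamState flattens to the leaves of that value. Dict pytrees are flattened in sorted key order, so 'bias' lands at index 0 and 'weight' at index 1 (`jax.tree_util.tree_leaves` behaves like `jax.tree.leaves` here). The module path and the helper `leaf_index` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical merged Linear: one ParamState-like value holding both arrays.
linear_param = {"weight": jnp.ones((3, 4)), "bias": jnp.zeros((4,))}
module_path = ("layer", "linear")  # hypothetical module path

leaves = jax.tree_util.tree_leaves(linear_param)  # dict keys are sorted: bias, weight

def leaf_index(tree_leaves, value):
    # Identity lookup: tree_leaves returns the leaf objects themselves.
    return next(i for i, leaf in enumerate(tree_leaves) if leaf is value)

keys = ("weight", "bias")
trainable_paths = {k: module_path for k in keys}  # both keys share one path
trainable_leaf_indices = {k: leaf_index(leaves, linear_param[k]) for k in keys}
print(trainable_leaf_indices)  # {'weight': 1, 'bias': 0}
```

The shared path tells a consumer that both keys live in one ParamState, and the leaf index says which leaf of that state each key refers to.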

trainable_param_states

Per-key dict mapping each key to the actual ParamState object.

trainable_processing_chains

Per-key dict mapping each key to the backward-trace processing chain (primitives traversed from the trainable invar back to the originating ParamState invar).
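A processing chain can be pictured as a backward walk over the equations that produced a trainable invar. The sketch below is a hypothetical illustration, not braintrace's compiler pass: the stored parameter `w_raw` goes through `neg` and `transpose` before reaching `dot_general`, and the walk assumes each intermediate step has a single relevant input.

```python
import jax
import jax.numpy as jnp
from jax import lax

def forward(x, w_raw):
    # Hypothetical preprocessing between the stored parameter and the primitive.
    w = lax.transpose(lax.neg(w_raw), (1, 0))
    return lax.dot(x, w)

jaxpr = jax.make_jaxpr(forward)(jnp.ones((2, 3)), jnp.ones((4, 3))).jaxpr

# Map each Var to the equation that produced it.
producer = {out: eqn for eqn in jaxpr.eqns for out in eqn.outvars}

def processing_chain(var, source_vars):
    """Walk producers backward from `var` until a source Var is reached."""
    chain = []
    while var not in source_vars:
        eqn = producer[var]
        chain.append(eqn.primitive)
        var = eqn.invars[0]  # assumes one relevant input per step
    return tuple(chain), var

dot_eqn = next(e for e in jaxpr.eqns if e.primitive.name == "dot_general")
chain, origin = processing_chain(dot_eqn.invars[1], set(jaxpr.invars))
print([p.name for p in chain])  # ['transpose', 'neg']
```

The chain is ordered from the trainable invar back toward its origin, matching the direction described above.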

connected_hidden_paths: List[Tuple[str, ...]]

Alias for field number 5

eqn_params: dict

Alias for field number 6

hidden_groups: List[HiddenGroup]

Alias for field number 3

path_classification: Dict[Tuple[str, ...], str]

Alias for field number 7

primitive: Primitive

Alias for field number 0

trainable_leaf_indices: Dict[str, int]

Alias for field number 10

trainable_param_states: Dict[str, ParamState]

Alias for field number 11

trainable_paths: Dict[str, Tuple[str, ...]]

Alias for field number 9

trainable_processing_chains: Dict[str, Tuple[Primitive, ...]]

Alias for field number 12

trainable_vars: Dict[str, Var]

Alias for field number 8

x_var: Var | None

Alias for field number 1

y_to_hidden_group_jaxprs: List[Jaxpr]

Alias for field number 4

y_to_hidden_groups(y_val, const_vals, concat_hidden_vals=True)

Evaluate transition jaxprs: y -> hidden group values.
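The sketch below shows one way to evaluate a list of transition jaxprs on a value of y, using a small environment-based interpreter in the style of JAX's custom-interpreter tutorial. The helpers `eval_transition` and `y_to_hidden_groups_sketch` and the two example transitions are hypothetical, constants are taken from each closed jaxpr rather than from a `const_vals` argument, and `concat_hidden_vals` is left out because its exact behavior isn't described here. This is not braintrace's implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax

def eval_transition(jaxpr, consts, *args):
    """Evaluate a jaxpr by binding each equation's primitive (assumes no Literal inputs)."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        outs = eqn.primitive.bind(*(env[v] for v in eqn.invars), **eqn.params)
        if not eqn.primitive.multiple_results:
            outs = [outs]
        env.update(zip(eqn.outvars, outs))
    return [env[v] for v in jaxpr.outvars]

def y_to_hidden_groups_sketch(transitions, y_val):
    # One closed jaxpr per hidden group; each returns that group's hidden values.
    return [eval_transition(t.jaxpr, t.consts, y_val) for t in transitions]

y = jnp.linspace(-1.0, 1.0, 4)
transitions = [
    jax.make_jaxpr(lambda v: (lax.tanh(v),))(y),            # group 0: one hidden state
    jax.make_jaxpr(lambda v: (lax.neg(v), lax.exp(v)))(y),  # group 1: two hidden states
]
groups = y_to_hidden_groups_sketch(transitions, y)
print([len(g) for g in groups])  # [1, 2]
```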

y_var: Var

Alias for field number 2
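The "Alias for field number N" entries indicate that HiddenParamOpRelation is a named tuple, so each field is reachable by attribute or by index. The sketch below uses a hypothetical four-field `RelationSketch` (its `path_classification` sits at index 3, unlike index 7 in the real class) to show index access, `_replace`, and one consequence of the `= {}` defaults in the signature: if they are ordinary NamedTuple defaults, every instance that omits the field shares the same dict object.

```python
from typing import Dict, NamedTuple, Optional, Tuple

class RelationSketch(NamedTuple):
    """Hypothetical stand-in mirroring the first fields of HiddenParamOpRelation."""
    primitive: str
    x_var: Optional[str]
    y_var: str
    path_classification: Dict[Tuple[str, ...], str] = {}

rel = RelationSketch("etp_mm_p", "x", "y")

# Attribute access and index access refer to the same field.
assert rel.y_var is rel[2]
print(RelationSketch._fields.index("y_var"))  # 2

# Named tuples are immutable; _replace returns a modified copy.
rel2 = rel._replace(x_var=None)
print(rel.x_var, rel2.x_var)  # x None

# The `{}` default is created once, so instances that omit it share one dict.
other = RelationSketch("etp_mv_p", None, "y2")
print(rel.path_classification is other.path_classification)  # True
```

For this reason, code that builds relations is safer passing its own dicts than mutating a defaulted one.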