ETPPrimitiveSpec
- class braintrace.ETPPrimitiveSpec(name, impl, yw_to_w, xy_to_dw, init_drtrl, init_pp, trainable_invars_fn, x_invar_index=0, y_outvar_index=0, batched=False, gradient_enabled=False)
Declarative specification of an ETP primitive.
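The fields listed below can be pictured with a plain-Python stand-in. This is a hypothetical sketch, not braintrace's API: SpecSketch, linear_layout, the has_bias parameter key and the (x, weight, bias) invar ordering are assumptions made for illustration. The rule callables are stubs because their signatures are not documented here. The sketch only shows how x_invar_index and trainable_invars_fn could work together to pick the input and the trainable operands out of an equation's invars.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Sequence


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in mirroring the documented ETPPrimitiveSpec fields."""
    name: str
    impl: Callable[..., Any]
    yw_to_w: Callable[..., Any]
    xy_to_dw: Callable[..., Any]
    init_drtrl: Callable[..., Any]
    init_pp: Callable[..., Any]
    trainable_invars_fn: Callable[[Mapping[str, Any]], Dict[str, int]]
    x_invar_index: Optional[int] = 0
    y_outvar_index: int = 0
    batched: bool = False
    gradient_enabled: bool = False

    def split_invars(self, params: Mapping[str, Any], invars: Sequence[Any]):
        """Return (x, {key: trainable invar}) using the declared layout."""
        # x_invar_index=None means the primitive has no external input x.
        x = None if self.x_invar_index is None else invars[self.x_invar_index]
        layout = self.trainable_invars_fn(params)
        return x, {key: invars[i] for key, i in layout.items()}


def linear_layout(params: Mapping[str, Any]) -> Dict[str, int]:
    """Hypothetical Linear layout; invars assumed to be (x, weight[, bias])."""
    layout = {"weight": 1}
    if params.get("has_bias", False):  # 'has_bias' is an assumed param key
        layout["bias"] = 2
    return layout


def _unused(*args, **kwargs):
    # Rule bodies are out of scope for this sketch.
    raise NotImplementedError("rule bodies are not modelled here")


spec = SpecSketch(
    name="etp_mm_sketch",  # illustrative name only
    impl=_unused, yw_to_w=_unused, xy_to_dw=_unused,
    init_drtrl=_unused, init_pp=_unused,
    trainable_invars_fn=linear_layout,
)
x, trainables = spec.split_invars({"has_bias": True}, ["x", "W", "b"])
print(x, trainables)  # x {'weight': 'W', 'bias': 'b'}
```

Returning a dict keyed by role (weight, bias, B, A, ...) rather than a fixed tuple is what lets one code path handle primitives with any number of trainable inputs.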
- name
Primitive name (e.g. 'etp_mm').
- impl
Implementation function. All standard JAX rules (abstract_eval, lowering, JVP, transpose, batching) are derived automatically from it.
- yw_to_w
D-RTRL trace propagation rule.
- xy_to_dw
Weight-gradient rule.
- init_drtrl
D-RTRL parameter-dimension trace initialiser.
- init_pp
pp_prop IO-dimension df trace initialiser.
- trainable_invars_fn
Function eqn.params -> {key: invar_index} declaring the primitive's full trainable-input layout. The compiler and executors use it to support primitives with N trainable inputs (e.g. {weight, bias} for Linear, {B, A, bias} for LoRA).
- x_invar_index
Position of the input x in eqn.invars, or None for primitives that have no external input (currently only etp_elemwise_p).
- y_outvar_index
Position of the output y in eqn.outvars. All current primitives have a single output at index 0.
- batched
Whether the primitive operates on batched inputs.
- gradient_enabled
If True, the compiler may traverse this primitive when walking y -> h, treating it like an identity-like op. If False (the default for any trainable op), the primitive acts as a tail boundary: a preceding ETP weight whose only path to h passes through this primitive is excluded from ETP.
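The tail-boundary rule for gradient_enabled amounts to a reachability check: starting from a weight's output y, the walk may only cross primitives flagged as traversable, and the weight stays in ETP only if h is still reachable. The sketch below illustrates that idea on a hypothetical toy graph of (src_var, dst_var, primitive_name) edges. The graph format, the function name reaches_h, and the flag values for reshape and convert_element_type are assumptions; braintrace's compiler walks real jaxpr equations and may differ in detail.

```python
from collections import deque
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: each edge is (src_var, dst_var, primitive_name).
Edge = Tuple[str, str, str]


def reaches_h(start: str, h: str, edges: List[Edge],
              gradient_enabled: Dict[str, bool]) -> bool:
    """BFS from start to h, only crossing primitives flagged as traversable."""
    adjacency: Dict[str, List[Tuple[str, str]]] = {}
    for src, dst, prim in edges:
        adjacency.setdefault(src, []).append((dst, prim))
    seen: Set[str] = {start}
    queue = deque([start])
    while queue:
        var = queue.popleft()
        if var == h:
            return True
        for dst, prim in adjacency.get(var, []):
            # Unknown or gradient_enabled=False primitives are tail boundaries.
            if gradient_enabled.get(prim, False) and dst not in seen:
                seen.add(dst)
                queue.append(dst)
    return False


# Assumed flags: identity-like ops traversable, trainable op etp_mm not.
gradient_enabled = {"reshape": True, "convert_element_type": True, "etp_mm": False}
edges = [
    ("y1", "y1r", "reshape"), ("y1r", "h", "convert_element_type"),  # clean path
    ("y2", "z", "etp_mm"), ("z", "h", "convert_element_type"),       # blocked
    ("y3", "z3", "etp_mm"), ("y3", "h", "reshape"),                  # one clean path
]
eligible = {y: reaches_h(y, "h", edges, gradient_enabled) for y in ("y1", "y2", "y3")}
print(eligible)  # {'y1': True, 'y2': False, 'y3': True}
```

The weight behind y3 is kept because only its blocked path passes through etp_mm; the exclusion applies when every path to h crosses a boundary, as with y2.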