ETPPrimitiveSpec

class braintrace.ETPPrimitiveSpec(name, impl, yw_to_w, xy_to_dw, init_drtrl, init_pp, trainable_invars_fn, x_invar_index=0, y_outvar_index=0, batched=False, gradient_enabled=False)

Declarative specification of an ETP primitive.

name

Primitive name (e.g. 'etp_mm').

impl

Implementation function. All standard JAX rules (abstract_eval, lowering, JVP, transpose, batching) are derived automatically from it.

yw_to_w

D-RTRL trace propagation rule.

xy_to_dw

Weight-gradient rule.
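As an illustration of what a weight-gradient rule computes, assume a matrix-multiply primitive such as 'etp_mm' computes y = x @ W for a single unbatched sample. Then y[j] = sum_i x[i] * W[i][j], so dL/dW[i][j] = x[i] * dL/dy[j], the outer product of x and dy. The sketch below writes this out in plain Python. The function name linear_xy_to_dw and its (x, dy) arguments are hypothetical and are not the signature braintrace expects for xy_to_dw.

```python
from typing import List


def linear_xy_to_dw(x: List[float], dy: List[float]) -> List[List[float]]:
    """Hypothetical weight-gradient rule for y = x @ W (single sample).

    Because y[j] = sum_i x[i] * W[i][j], the gradient of a loss L with
    respect to W[i][j] is x[i] * dL/dy[j]: the outer product of x and dy.
    """
    return [[xi * dyj for dyj in dy] for xi in x]


# Usage: a 2-dimensional input and a 3-dimensional output gradient.
dw = linear_xy_to_dw([1.0, 2.0], [3.0, 4.0, 5.0])
```

The result has the same shape as W (here 2 x 3), which is what a weight-gradient rule must return.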

init_drtrl

D-RTRL parameter-dimension trace initialiser.

init_pp

pp_prop IO-dimension df trace initialiser.

trainable_invars_fn

Function eqn.params -> {key: invar_index} declaring the primitive’s full trainable-input layout. Used by the compiler and executors to support N-trainable-input primitives (e.g. {weight, bias} for Linear, {B, A, bias} for LoRA).
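A trainable_invars_fn maps an equation's params to the positions of its trainable inputs. The sketch below shows what such functions might look like for the Linear and LoRA cases named above. The param key "has_bias" and the invar orderings ([x, weight, bias] and [x, B, A, bias]) are assumptions made for illustration; the real layouts are defined by each primitive.

```python
from typing import Any, Dict, Mapping


def linear_trainable_invars(params: Mapping[str, Any]) -> Dict[str, int]:
    # Hypothetical invar layout: [x, weight] or [x, weight, bias].
    layout = {"weight": 1}
    if params.get("has_bias", False):
        layout["bias"] = 2
    return layout


def lora_trainable_invars(params: Mapping[str, Any]) -> Dict[str, int]:
    # Hypothetical invar layout: [x, B, A] or [x, B, A, bias].
    layout = {"B": 1, "A": 2}
    if params.get("has_bias", False):
        layout["bias"] = 3
    return layout
```

Making the layout a function of eqn.params, rather than a fixed dict, lets one primitive cover both the with-bias and without-bias variants.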

x_invar_index

Position of the input x in eqn.invars, or None for primitives that have no external input (currently only etp_elemwise_p).

y_outvar_index

Position of the output y in eqn.outvars. All current primitives have a single output at index 0.
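Together, x_invar_index, y_outvar_index and the trainable layout tell a consumer where to find each role in an equation's variables. The helper below is a hypothetical sketch of that lookup, using plain lists and strings in place of JAX equation variables; it treats x_invar_index=None as "no external input", as described for etp_elemwise_p.

```python
from typing import Any, Dict, List, Optional, Tuple


def split_eqn_vars(
    invars: List[Any],
    outvars: List[Any],
    x_invar_index: Optional[int],
    y_outvar_index: int,
    trainable: Dict[str, int],
) -> Tuple[Optional[Any], Any, Dict[str, Any]]:
    """Hypothetical helper: return (x, y, {key: trainable invar})."""
    # Primitives without an external input report x_invar_index=None.
    x = None if x_invar_index is None else invars[x_invar_index]
    y = outvars[y_outvar_index]
    params = {key: invars[idx] for key, idx in trainable.items()}
    return x, y, params


# Linear-style layout: invars [x, W, b], single output at index 0.
x, y, params = split_eqn_vars(["x", "W", "b"], ["y"], 0, 0, {"weight": 1, "bias": 2})
```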

batched

Whether the primitive operates on batched inputs.

gradient_enabled

If True, the compiler may traverse this primitive when walking from y to h (appropriate for identity-like ops). If False (the default, and the expected setting for any trainable op), the primitive acts as a tail boundary: a preceding ETP weight whose only path to h passes through this primitive is excluded from ETP.
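The tail-boundary rule amounts to a reachability check: starting from a weight's output y, follow only primitives with gradient_enabled=True and ask whether h is reached. The sketch below is a hypothetical model of that check, not braintrace's compiler. Equations are plain (primitive_name, invars, outvars) tuples, and the primitive name "reshape" is only an example of an identity-like op.

```python
from collections import deque
from typing import Dict, List, Tuple

# Hypothetical equation model: (primitive_name, input_vars, output_vars).
Eqn = Tuple[str, List[str], List[str]]


def y_reaches_h(y: str, h: str, eqns: List[Eqn], gradient_enabled: Dict[str, bool]) -> bool:
    """Return True if h is reachable from y through pass-through primitives only."""
    # Map each variable to the equations that consume it.
    consumers: Dict[str, List[Eqn]] = {}
    for eqn in eqns:
        for var in eqn[1]:
            consumers.setdefault(var, []).append(eqn)

    seen = {y}
    queue = deque([y])
    while queue:
        var = queue.popleft()
        if var == h:
            return True
        for prim, _, outs in consumers.get(var, []):
            if not gradient_enabled.get(prim, False):
                continue  # tail boundary: do not walk through this primitive
            for out in outs:
                if out not in seen:
                    seen.add(out)
                    queue.append(out)
    return False


flags = {"etp_mm": False, "reshape": True}
# w1 feeds h through an identity-like op, so it stays eligible.
via_reshape = [("etp_mm", ["x", "w1"], ["y1"]), ("reshape", ["y1"], ["h"])]
# w1 can only reach h through another trainable op, so it is excluded.
via_mm = [("etp_mm", ["x", "w1"], ["y1"]), ("etp_mm", ["y1", "w2"], ["h"])]
```

In the second graph the second etp_mm (weight w2) still reaches h directly, because its own output is h.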

resolve_trainable_invars(eqn_params)

Return the {key: invar_index} trainable-input layout for the equation whose params are eqn_params.

Return type:

Dict[str, int]
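A minimal sketch of how such a method could be built on top of trainable_invars_fn is shown below. SpecSketch is a hypothetical stand-in, not ETPPrimitiveSpec, and the duplicate-index and x-overlap checks are assumptions added for illustration; the documentation above only states the return type.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in holding only the fields this sketch needs."""

    trainable_invars_fn: Callable[[Mapping[str, Any]], Dict[str, int]]
    x_invar_index: Optional[int] = 0

    def resolve_trainable_invars(self, eqn_params: Mapping[str, Any]) -> Dict[str, int]:
        layout = dict(self.trainable_invars_fn(eqn_params))
        # Hypothetical sanity checks: indices must be unique and must not
        # collide with the position of the input x.
        indices = list(layout.values())
        if len(set(indices)) != len(indices):
            raise ValueError(f"duplicate invar indices in {layout}")
        if self.x_invar_index is not None and self.x_invar_index in indices:
            raise ValueError("a trainable input overlaps x_invar_index")
        return layout


spec = SpecSketch(lambda p: {"weight": 1, "bias": 2} if p.get("has_bias") else {"weight": 1})
```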