Custom Primitives
This page documents how to register a new ETP primitive and the four
ETP-specific rules every primitive must provide. The built-in primitives
(etp_mm, etp_mv, etp_conv, etp_elemwise, etp_sp_mm,
etp_sp_mv, etp_lora_mm, etp_lora_mv) are themselves registered
through this same machinery — adding a custom op uses the same surface.
Two registration styles
There are two equivalent ways to register a primitive. Pick whichever fits your codebase:
- Class-based: call register_primitive() to obtain an ETPPrimitive, then attach the four rules via its register_* methods. Best for incremental development, where you want to register each rule close to where it is defined.
- Spec-based: declare an ETPPrimitiveSpec (a single dataclass-like value), then pass it to register_primitive_spec(). Best when you want a single object that fully describes the primitive (useful for testing and for the compiler's get_primitive_spec query).
Both styles populate the same four global registries
(ETP_RULES_YW_TO_W, ETP_RULES_XY_TO_DW,
ETP_RULES_INIT_DRTRL, ETP_RULES_INIT_PP) and result in a
fully-functional ETPPrimitive.
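To make the relationship between the two styles concrete, here is a toy model in plain Python. Everything in it is hypothetical: ToyPrimitive, ToySpec, register_toy_spec and the four dicts only mirror the shape of the braintrace machinery described above, and none of them is braintrace's API. The spec-based entry point is modelled as a thin wrapper that calls the same per-rule register methods, so both styles end up writing to the same four registries keyed by primitive.

```python
from dataclasses import dataclass
from typing import Callable, Dict

# Hypothetical stand-ins for ETP_RULES_YW_TO_W, ETP_RULES_XY_TO_DW,
# ETP_RULES_INIT_DRTRL and ETP_RULES_INIT_PP: global dicts keyed by primitive.
YW_TO_W: Dict["ToyPrimitive", Callable] = {}
XY_TO_DW: Dict["ToyPrimitive", Callable] = {}
INIT_DRTRL: Dict["ToyPrimitive", Callable] = {}
INIT_PP: Dict["ToyPrimitive", Callable] = {}


class ToyPrimitive:
    """Hypothetical primitive; hashed by identity so it can key the dicts."""

    def __init__(self, name: str, impl: Callable):
        self.name = name
        self.impl = impl

    # Class-based style: one method per rule, each writing one registry.
    def register_yw_to_w(self, rule: Callable) -> None:
        YW_TO_W[self] = rule

    def register_xy_to_dw(self, rule: Callable) -> None:
        XY_TO_DW[self] = rule

    def register_init_drtrl(self, rule: Callable) -> None:
        INIT_DRTRL[self] = rule

    def register_init_pp(self, rule: Callable) -> None:
        INIT_PP[self] = rule


@dataclass(frozen=True)
class ToySpec:
    """Hypothetical declarative description holding all four rules."""
    name: str
    impl: Callable
    yw_to_w: Callable
    xy_to_dw: Callable
    init_drtrl: Callable
    init_pp: Callable


def register_toy_spec(spec: ToySpec) -> ToyPrimitive:
    # Spec-based style, modelled here as sugar over the class-based methods.
    prim = ToyPrimitive(spec.name, spec.impl)
    prim.register_yw_to_w(spec.yw_to_w)
    prim.register_xy_to_dw(spec.xy_to_dw)
    prim.register_init_drtrl(spec.init_drtrl)
    prim.register_init_pp(spec.init_pp)
    return prim
```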
Registration entry points
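| Function | Description |
|---|---|
| register_primitive() | Create an ETPPrimitive from an impl function; the four ETP rules are then attached via its register_* methods. |
| register_primitive_spec() | Create an ETPPrimitive from an ETPPrimitiveSpec, registering all four rules at once. |
| get_primitive_spec() | Return the ETPPrimitiveSpec a primitive was built from. |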
Primitive class & spec
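| Class | Description |
|---|---|
| ETPPrimitive | A JAX primitive that carries the four ETP rules; it supports bind() and the standard JAX transformations. |
| ETPPrimitiveSpec | Declarative specification of an ETP primitive. |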
The four ETP rules
Every ETP primitive must supply four rules. They are stored in four
global dict registries keyed by primitive — the compiler and the
online-learning algorithms look up the rule for a primitive at compile
time.
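| Registry | Signature | Purpose |
|---|---|---|
| ETP_RULES_YW_TO_W | yw_to_w(hidden, trace, **params) | D-RTRL trace propagation: combine an upstream hidden-state Jacobian factor with the trace through the current weight. |
| ETP_RULES_XY_TO_DW | xy_to_dw(x, hidden, w, **params) | Weight-gradient rule: produce the gradient with respect to w (the examples on this page compute it as a VJP of impl with respect to w). |
| ETP_RULES_INIT_DRTRL | init_drtrl(x_var, y_var, w, ns) | D-RTRL trace initialiser. Returns a zero array (or pytree of arrays) shaped to hold the parameter-dim trace. |
| ETP_RULES_INIT_PP | init_pp(x_var, y_var, w, ns) | pp_prop / ES-D-RTRL trace initialiser. Returns a zero array shaped to hold the IO-dim trace. |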
The four registries live in braintrace._etrace_op and are
populated by both registration styles.
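As a concrete check of the xy_to_dw row, consider a plain matmul y = scale * (x @ w) for a single unbatched sample. Pulling a hidden-side vector back through impl with respect to w (which is what the examples below do with jax.vjp) gives scale * outer(x, hidden), an array shaped like w. This sketch uses JAX only and does not touch braintrace; the variable values are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp


def impl(x, w, *, scale=1.0):
    # Same matmul impl as the examples on this page.
    return scale * (x @ w)


x = jnp.arange(3.0)        # input, shape (3,)
w = jnp.ones((3, 5))       # weight, shape (3, 5)
hidden = jnp.arange(5.0)   # hidden-side signal, same shape as y: (5,)
scale = 2.0

# xy_to_dw in the style of the examples: VJP of impl w.r.t. w, cotangent = hidden.
_, pullback = jax.vjp(lambda w_: impl(x, w_, scale=scale), w)
dw = pullback(hidden)[0]

# Closed form for a matmul: d(sum_k hidden_k * y_k) / d w_jk = scale * x_j * hidden_k.
dw_closed = scale * jnp.outer(x, hidden)
```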
Class-based example
```python
import jax
import jax.numpy as jnp
import braintrace


def _my_impl(x, w, *, scale=1.0):
    return scale * (x @ w)


my_p = braintrace.register_primitive('etp_my_op', _my_impl, batched=True)

# Rules can be registered one-by-one, or in a single call via
# ``register_etp_rules(yw_to_w=..., xy_to_dw=..., ...)``.

# D-RTRL trace propagation through the current weight.
my_p.register_yw_to_w(
    lambda hidden, trace, **params: trace * hidden[None, :]
)
# Weight gradient: VJP of the impl with respect to w.
my_p.register_xy_to_dw(
    lambda x, hidden, w, **params:
        jax.vjp(lambda w_: _my_impl(x, w_, **params), w)[1](hidden)[0]
)
# Zero-initialised parameter-dim trace for D-RTRL.
my_p.register_init_drtrl(
    lambda x_var, y_var, w, ns:
        jnp.zeros((x_var.aval.shape[0], *jnp.shape(w.value), ns))
)
# Zero-initialised IO-dim trace for pp_prop / ES-D-RTRL.
my_p.register_init_pp(
    lambda x_var, y_var, w, ns:
        jnp.zeros((*y_var.aval.shape, ns), dtype=y_var.aval.dtype)
)
```
After registration the primitive is ready to use:
```python
x = jnp.ones((4, 3))
w = jnp.ones((3, 5))
y = my_p.bind(x, w, scale=0.5)  # all standard JAX rules work
gw = jax.grad(lambda w_: my_p.bind(x, w_, scale=0.5).sum())(w)
```
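The primitive's JVP and transpose are derived from _my_impl (as described later on this page), so y and gw above should agree with the same computations run on the plain function. The sketch below computes those reference values with JAX alone and makes no further assumptions about braintrace.

```python
import jax
import jax.numpy as jnp


def _my_impl(x, w, *, scale=1.0):
    return scale * (x @ w)


x = jnp.ones((4, 3))
w = jnp.ones((3, 5))

# Reference forward pass: each entry is 0.5 * (sum of three 1 * 1 products) = 1.5.
y_ref = _my_impl(x, w, scale=0.5)

# Reference gradient of y.sum() w.r.t. w: 0.5 * x.T @ ones((4, 5)), i.e. 0.5 * 4 = 2.0 per entry.
gw_ref = jax.grad(lambda w_: _my_impl(x, w_, scale=0.5).sum())(w)
```

If the derived rules behave as described, jnp.allclose(y, y_ref) and jnp.allclose(gw, gw_ref) make a cheap smoke test for a newly registered primitive.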
Spec-based example
```python
import jax
import jax.numpy as jnp
import braintrace


def _impl(x, w, *, scale=1.0):
    return scale * (x @ w)


spec = braintrace.ETPPrimitiveSpec(
    name='etp_my_op',
    impl=_impl,
    yw_to_w=lambda hidden, trace, **p: trace * hidden[None, :],
    xy_to_dw=lambda x, hidden, w, **p:
        jax.vjp(lambda w_: _impl(x, w_, **p), w)[1](hidden)[0],
    init_drtrl=lambda x_var, y_var, w, ns:
        jnp.zeros((x_var.aval.shape[0], *jnp.shape(w.value), ns)),
    init_pp=lambda x_var, y_var, w, ns:
        jnp.zeros((*y_var.aval.shape, ns), dtype=y_var.aval.dtype),
    weight_invar_index=1,
    x_invar_index=0,
    batched=True,
)
my_p = braintrace.register_primitive_spec(spec)

# Later, the compiler can query the spec it was built from:
assert braintrace.get_primitive_spec(my_p) is spec
```
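The spec sets x_invar_index=0 and weight_invar_index=1, which matches the positions of x and w in _impl(x, w) and in my_p.bind(x, w, ...). Assume that each index picks one operand out of the primitive's positional inputs (the equation's invars). Under that assumption, a minimal helper could look like the one below. pick_x_and_weight is hypothetical and not part of braintrace.

```python
from typing import Any, Sequence, Tuple


def pick_x_and_weight(
    invars: Sequence[Any], x_invar_index: int, weight_invar_index: int
) -> Tuple[Any, Any]:
    """Hypothetical helper: select the input and weight operands by index.

    Assumes both indices refer to positions in the primitive's positional
    inputs, as in ``my_p.bind(x, w, ...)``.
    """
    if x_invar_index == weight_invar_index:
        raise ValueError("x and weight must be different operands")
    n = len(invars)
    for idx in (x_invar_index, weight_invar_index):
        if not 0 <= idx < n:
            raise IndexError(f"operand index {idx} out of range for {n} inputs")
    return invars[x_invar_index], invars[weight_invar_index]
```

A primitive whose weight comes first, e.g. bind(w, x), would simply swap the two indices.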
Auto-derived JAX rules
You only need to write the four ETP rules above. All standard JAX
machinery is derived automatically from your impl:
- abstract_eval: via jax.eval_shape(impl)
- MLIR lowering: via mlir.lower_fun(impl)
- JVP: via jax.jvp(impl)
- transpose: derived by JAX from the JVP
- batching: via jax.vmap(impl)
This means a custom primitive immediately works under
jit/grad/vmap/jvp without any extra code.
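These derivations can be tried directly on a plain function with public JAX transforms. Because the primitive's rules come from impl, this is one quick way to check that an impl behaves well before registering it. The sketch covers abstract evaluation, JVP and batching only; it skips MLIR lowering and transposition and does not call braintrace.

```python
import jax
import jax.numpy as jnp


def impl(x, w, *, scale=1.0):
    return scale * (x @ w)


x = jnp.ones((4, 3))
w = jnp.ones((3, 5))

# Abstract evaluation: shape and dtype only, nothing is computed.
out_aval = jax.eval_shape(lambda x_, w_: impl(x_, w_, scale=0.5), x, w)

# JVP: push a tangent for w through impl while x is held fixed.
y, y_dot = jax.jvp(lambda w_: impl(x, w_, scale=0.5), (w,), (jnp.ones_like(w),))

# Batching: map impl over the leading axis of x, one row at a time.
y_batched = jax.vmap(lambda xi: impl(xi, w, scale=0.5))(x)
```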
The gradient_enabled flag
Pass gradient_enabled=True only for identity-like primitives
that may sit on the tail of the y -> h walk (the only built-in
primitive that does so is etp_elemwise_p).
For trainable ops, leave the default gradient_enabled=False. The
compiler treats such a primitive as a tail boundary: a preceding ETP
weight whose only path to h passes through it is correctly excluded
from ETP, because per-primitive ETP rules cannot express the
“weight-then-weight-then-hidden” composition.
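The exclusion rule can be illustrated with a deliberately simplified model. Treat a weight's single path from its output y to the hidden state h as the list of ETP primitives it crosses, and keep the weight only if every one of them has gradient_enabled=True. The Op record and the etp_eligible function are hypothetical illustrations, not compiler code.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class Op:
    """Hypothetical record of an ETP primitive on the y -> h walk."""
    name: str
    gradient_enabled: bool = False  # the default, as for trainable ops


def etp_eligible(tail: Sequence[Op]) -> bool:
    # A weight keeps its ETP treatment only if its single path to h crosses
    # nothing but identity-like primitives (gradient_enabled=True).
    # Any trainable primitive on that path acts as a tail boundary.
    return all(op.gradient_enabled for op in tail)


elemwise = Op("etp_elemwise", gradient_enabled=True)
dense2 = Op("etp_mm")  # trainable: gradient_enabled stays False
```

In this model, etp_eligible([elemwise]) keeps the weight, while etp_eligible([dense2]) drops it. The second case would need a weight-then-weight-then-hidden rule, which a per-primitive ETP rule cannot express.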
See /tutorial/etp_primitives for a full walk-through and Compiler Internals for the underlying invariant.