braintrace.register_primitive
- braintrace.register_primitive(name, impl_fn, *, batched=False, gradient_enabled=False)
Create an ETPPrimitive with all JAX rules auto-derived. The following rules are installed automatically:
- impl: eager execution
- abstract_eval: via jax.eval_shape(impl)
- lowering: via mlir.lower_fun(impl)
- JVP: via jax.jvp(impl)
- transpose: derived by JAX from the JVP
- batching: via jax.vmap(impl)
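The auto-derivation listed above maps onto standard JAX extension hooks. The sketch below is a minimal illustration under stated assumptions, not braintrace's actual implementation: make_auto_primitive, sin_mul_p, and sin_mul are hypothetical names, impl_fn is assumed to be written with plain jax.numpy ops on floating-point arrays, and the four ETP-specific rules as well as the batched and gradient_enabled options are not modeled. It uses jax.extend.core.Primitive (falling back to jax.core.Primitive on older JAX) and the ad, batching, and mlir registries from jax.interpreters.

```python
import functools

import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir

try:  # newer JAX exposes Primitive under jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # older JAX releases
    from jax.core import Primitive


def make_auto_primitive(name, impl_fn):
    """Hypothetical helper: build a primitive whose rules all reuse impl_fn.

    Assumes impl_fn uses ordinary jax.numpy ops (it must not bind the
    primitive itself) and takes floating-point array inputs.
    """
    prim = Primitive(name)

    # impl: eager execution simply calls the Python implementation.
    prim.def_impl(impl_fn)

    # abstract_eval: let jax.eval_shape infer the output shape and dtype.
    def abstract_eval(*avals, **params):
        specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals]
        out = jax.eval_shape(functools.partial(impl_fn, **params), *specs)
        return jax.core.ShapedArray(out.shape, out.dtype)

    prim.def_abstract_eval(abstract_eval)

    # lowering: trace impl_fn into MLIR instead of writing a custom lowering.
    mlir.register_lowering(prim, mlir.lower_fun(impl_fn, multiple_results=False))

    # JVP: differentiate impl_fn with jax.jvp. Symbolic-zero tangents (inputs
    # not being differentiated) are materialized because jax.jvp needs arrays.
    def jvp_rule(primals, tangents, **params):
        tangents = tuple(
            jnp.zeros(t.aval.shape, t.aval.dtype) if type(t) is ad.Zero else t
            for t in tangents
        )
        return jax.jvp(functools.partial(impl_fn, **params), tuple(primals), tangents)

    ad.primitive_jvps[prim] = jvp_rule

    # batching: vectorize impl_fn with jax.vmap; unmapped inputs have dim None.
    def batch_rule(args, dims, **params):
        fn = functools.partial(impl_fn, **params)
        out = jax.vmap(fn, in_axes=tuple(dims))(*args)
        return out, 0  # the vmapped output carries its batch axis first

    batching.primitive_batchers[prim] = batch_rule
    return prim


# Usage with a toy primitive computing sin(x) * y.
sin_mul_p = make_auto_primitive("sin_mul", lambda x, y: jnp.sin(x) * y)


def sin_mul(x, y):
    return sin_mul_p.bind(x, y)


x = jnp.arange(3.0)
y = jnp.ones(3) * 2.0
print(sin_mul(x, y))                                # eager: impl
print(jax.jit(sin_mul)(x, y))                       # abstract_eval + lowering
print(jax.grad(lambda a: sin_mul(a, y).sum())(x))   # JVP + transpose
print(jax.vmap(sin_mul, in_axes=(0, None))(jnp.stack([x, x + 1.0]), y))  # batching
```

No transpose rule is registered in this sketch: because the JVP rule only runs jax.jvp over ordinary jax.numpy code, its tangent output is built from primitives JAX already knows how to transpose, so jax.grad works through the JVP alone.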
Only the four ETP-specific rules need to be written by hand: call the returned primitive's register_* methods.
- Parameters:
  - name – Primitive name (e.g. 'etp_mm').
  - impl_fn – Implementation function.
  - batched – Whether this primitive operates on batched inputs.
  - gradient_enabled – If True, the compiler will evaluate this primitive when walking y -> h (identity-like ops such as etp_elemwise_p).
- Returns:
  The registered primitive.
- Return type: