braintrace.register_primitive

braintrace.register_primitive(name, impl_fn, *, batched=False, gradient_enabled=False)

Create and register an ETPPrimitive whose standard JAX rules are derived automatically from impl_fn.

The following rules are installed automatically:

  • impl — eager execution

  • abstract_eval — via jax.eval_shape(impl)

  • lowering — via mlir.lower_fun(impl)

  • JVP — via jax.jvp(impl)

  • transpose — derived by JAX from the JVP

  • batching — via jax.vmap(impl)

Only the four ETP-specific rules must be written by hand. Register them by calling the returned primitive's register_* methods.
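The auto-derivation listed above can be illustrated with plain JAX. What follows is a minimal, hypothetical sketch, not braintrace's implementation. It uses JAX's own primitive-extension hooks (def_impl, def_abstract_eval, mlir.register_lowering, ad.primitive_jvps, batching.primitive_batchers) instead of the ETPPrimitive class. The names make_auto_primitive, toy_impl, and toy_scaled_mul are illustrative only. The sketch assumes a single-output primitive with floating-point inputs and a JAX version that exposes jax.extend.core.Primitive, falling back to jax.core.Primitive on older releases.

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX releases
    from jax.core import Primitive


def make_auto_primitive(name, impl_fn):
    """Hypothetical helper (not part of braintrace): derive all rules from impl_fn."""
    prim = Primitive(name)

    # impl: eager execution simply calls impl_fn.
    prim.def_impl(impl_fn)

    # abstract_eval: infer output shape/dtype by tracing impl_fn with jax.eval_shape.
    def abstract_eval(*avals, **params):
        specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals]
        out = jax.eval_shape(lambda *xs: impl_fn(*xs, **params), *specs)
        return ShapedArray(out.shape, out.dtype)

    prim.def_abstract_eval(abstract_eval)

    # lowering: under jit, lower impl_fn itself to MLIR.
    mlir.register_lowering(prim, mlir.lower_fun(impl_fn, multiple_results=False))

    # JVP: differentiate impl_fn with jax.jvp. Symbolic-zero tangents are
    # materialized first because jax.jvp expects dense tangent arrays.
    def jvp_rule(primals, tangents, **params):
        tangents = [ad.instantiate_zeros(t) for t in tangents]
        f = lambda *xs: impl_fn(*xs, **params)
        return jax.jvp(f, tuple(primals), tuple(tangents))

    ad.primitive_jvps[prim] = jvp_rule
    # No transpose rule is registered: the JVP is built from ordinary jax.numpy
    # ops, whose own transpose rules let JAX run reverse mode (jax.grad).

    # batching: vectorize impl_fn with jax.vmap; the output batch axis is 0.
    # Unbatched operands arrive with dim None (batching.not_mapped).
    def batch_rule(args, dims, **params):
        f = lambda *xs: impl_fn(*xs, **params)
        return jax.vmap(f, in_axes=tuple(dims))(*args), 0

    batching.primitive_batchers[prim] = batch_rule
    return prim


# Usage with a toy, hypothetical primitive computing 2 * x * y.
def toy_impl(x, y):
    return 2.0 * x * y


toy_p = make_auto_primitive("toy_scaled_mul", toy_impl)


def toy_scaled_mul(x, y):
    return toy_p.bind(x, y)


x = jnp.array(3.0, dtype=jnp.float32)
y = jnp.array(4.0, dtype=jnp.float32)
eager_out = toy_scaled_mul(x, y)             # impl rule
jit_out = jax.jit(toy_scaled_mul)(x, y)      # abstract_eval + lowering
grad_x = jax.grad(toy_scaled_mul)(x, y)      # JVP, then JAX-derived transpose
batched_out = jax.vmap(toy_scaled_mul)(jnp.arange(3.0), jnp.ones(3))  # batching
```

Deriving every rule from one implementation function keeps them consistent by construction. The trade-off is that each transformation re-traces impl_fn, so a hand-written rule may still be preferable when a cheaper formulation exists.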

Parameters:
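  • name – Primitive name (e.g. 'etp_mm').

  • impl_fn – Implementation function. The impl, abstract_eval, lowering, JVP, and batching rules are all derived from it, so it must be traceable by JAX.

  • batched – Whether this primitive operates on batched inputs. Defaults to False.

  • gradient_enabled – If True, the compiler evaluates this primitive when walking y -> h. Intended for identity-like ops such as etp_elemwise_p. Defaults to False.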

Returns:

The registered primitive.

Return type:

ETPPrimitive