brainevent.defjvp

brainevent.defjvp(primitive, *jvp_rules)

Define per-input JVP rules for a JAX primitive.

This function defines Jacobian-vector product (JVP) rules for JAX primitives, extending jax.interpreters.ad.defjvp. The standard JAX function primarily supports primitives that return a single result (multiple_results=False); this implementation lets you define an independent JVP rule for each input, whether the primitive returns a single result or multiple results.

This is particularly useful for custom operations or primitives where different inputs have different differentiation rules, or where the primitive produces multiple outputs that need distinct handling in automatic differentiation.

Parameters:
  • primitive (Primitive or XLACustomKernel) – The JAX Primitive object (or an XLACustomKernel instance, from which the underlying Primitive is extracted) for which the JVP rule is being defined.

  • *jvp_rules (callable or None) – One callable per input primal. Each callable has the signature rule(tangent, *primals, **params) -> tangent_out and computes the tangent contribution from the corresponding input. Pass None for inputs whose JVP contribution is zero.
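
The rule signature above can be sketched in plain Python. This is a minimal illustration, not brainevent's or JAX's implementation: combine_jvp is a hypothetical helper, and it assumes the per-input contributions are simply summed, with a None rule contributing zero.

```python
# Illustrative only: a plain-Python model of per-input JVP rules.
# `combine_jvp` is a hypothetical helper, not a brainevent or JAX function.

def combine_jvp(rules, primals, tangents, **params):
    """Sum the tangent contributions of every input that has a rule."""
    total = None
    for rule, tangent in zip(rules, tangents):
        # A rule of None (or a missing tangent) contributes zero.
        if rule is None or tangent is None:
            continue
        contribution = rule(tangent, *primals, **params)
        total = contribution if total is None else total + contribution
    return 0.0 if total is None else total


# Rules for the example function f(x, y) = x * y, one per input.
def rule_x(tangent, x, y, **params):
    return tangent * y  # d(x*y)/dx = y

def rule_y(tangent, x, y, **params):
    return tangent * x  # d(x*y)/dy = x


primals = (3.0, 4.0)
tangents = (1.0, 2.0)

both = combine_jvp((rule_x, rule_y), primals, tangents)
only_x = combine_jvp((rule_x, None), primals, tangents)
print(both, only_x)
```

Passing None for the second rule drops the contribution from y, which matches the description of inputs whose JVP contribution is zero.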

Raises:

AssertionError – If primitive (after unwrapping) is not a Primitive instance.

See also

XLACustomKernel.def_jvp_rule – Register a single monolithic JVP rule for an XLACustomKernel.

XLACustomKernel.def_jvp_rule2 – Convenience wrapper around this function.

general_batching_rule – General batching rule for custom primitives.

Notes

When the primitive has multiple_results=True, a custom internal JVP implementation (_standard_jvp) is used that correctly handles tuple outputs. For single-result primitives, the standard jax.interpreters.ad.standard_jvp is used.
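
To make the tuple-output case concrete, here is a rough plain-Python model. It is not the internal _standard_jvp. It assumes, purely for illustration, that each rule of a multi-result primitive returns one tangent per output and that contributions are summed output by output. The helper combine_jvp_multi and the two-output operation are hypothetical.

```python
# Illustrative only: a plain-Python model, not brainevent's `_standard_jvp`.
# Assumption: for a multi-result primitive, each rule returns a tuple with
# one tangent per output.

def combine_jvp_multi(rules, primals, tangents, num_outputs, **params):
    """Accumulate per-input contributions separately for each output."""
    totals = [0.0] * num_outputs
    for rule, tangent in zip(rules, tangents):
        # Inputs without a rule (or without a tangent) contribute zero.
        if rule is None or tangent is None:
            continue
        contributions = rule(tangent, *primals, **params)
        for i, contribution in enumerate(contributions):
            totals[i] += contribution
    return tuple(totals)


# Hypothetical two-output operation g(x, y) = (x + y, x * y).
def sum_prod_rule_x(tangent, x, y, **params):
    return (tangent, tangent * y)

def sum_prod_rule_y(tangent, x, y, **params):
    return (tangent, tangent * x)


out_tangents = combine_jvp_multi(
    (sum_prod_rule_x, sum_prod_rule_y),
    primals=(3.0, 4.0),
    tangents=(1.0, 2.0),
    num_outputs=2,
)
print(out_tangents)
```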

Examples

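>>> from brainevent import defjvp
>>> # Assume `my_prim` is an existing JAX Primitive with two inputs, x and y.
>>> def jvp_rule_input0(tangent, x, y, **kw):
...     return tangent * y  # tangent contribution from input 0 (x)
>>> def jvp_rule_input1(tangent, x, y, **kw):
...     return tangent * x  # tangent contribution from input 1 (y)
>>> defjvp(my_prim, jvp_rule_input0, jvp_rule_input1)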