brainevent.defjvp
- brainevent.defjvp(primitive, *jvp_rules)
Define per-input JVP rules for a JAX primitive.
This function defines Jacobian-vector product (JVP) rules for JAX primitives, extending the functionality of jax.interpreters.ad.defjvp. The standard JAX function is designed for primitives that return a single result (multiple_results=False). This implementation instead lets you define an independent JVP rule for each input parameter, whether the primitive returns a single result or multiple results. This is useful for custom operations whose inputs follow different differentiation rules, and for primitives that produce several outputs requiring distinct handling in automatic differentiation.
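In JAX's own per-input convention, each rule returns the tangent contribution of one input, and the output tangent is the sum of those contributions; the single-result path described in the Notes presumably follows the same convention. The dependency-free sketch below models that summation for f(x, y) = x * y and checks it against a finite difference. The helper `combine_jvp` is hypothetical: it only illustrates the semantics and is not part of brainevent or JAX.

```python
def combine_jvp(rules, primals, tangents, **params):
    """Sum the tangent contributions of all inputs (single-output case).

    Hypothetical helper: a rule of None, or a tangent of None (standing in
    for a symbolic zero), contributes nothing. JAX would return a symbolic
    zero when nothing contributes; this sketch simply returns 0.0.
    """
    total = 0.0
    for rule, tangent in zip(rules, tangents):
        if rule is None or tangent is None:
            continue
        total += rule(tangent, *primals, **params)
    return total


def f(x, y):
    return x * y


# One rule per input, each with the signature rule(tangent, *primals, **params).
def jvp_rule_x(tangent, x, y, **params):
    return tangent * y  # df/dx = y


def jvp_rule_y(tangent, x, y, **params):
    return tangent * x  # df/dy = x


x, y = 3.0, 5.0
dx, dy = 0.1, -0.2
tangent_out = combine_jvp((jvp_rule_x, jvp_rule_y), (x, y), (dx, dy))

# Central finite difference along the direction (dx, dy) for comparison.
eps = 1e-6
fd = (f(x + eps * dx, y + eps * dy) - f(x - eps * dx, y - eps * dy)) / (2 * eps)
print(tangent_out, fd)
```

Because every input has its own rule, an input whose tangent is a symbolic zero can be skipped entirely instead of being multiplied through.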
- Parameters:
primitive (Primitive or XLACustomKernel) – The JAX Primitive object (or an XLACustomKernel instance, from which the underlying Primitive is extracted) for which the JVP rule is being defined.
*jvp_rules (callable or None) – One callable per input primal. Each callable has the signature rule(tangent, *primals, **params) -> tangent_out and computes the tangent contribution from the corresponding input. Pass None for inputs whose JVP contribution is zero.
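To illustrate the rule signature, consider a hypothetical primitive computing `scale * x ** n`, where `n` is an integer exponent and `scale` is a primitive parameter delivered to the rule as a keyword argument. Only `x` contributes a tangent, so the entry for `n` is `None`. All names here (`power_fn`, `power_jvp_x`, `scale`) are illustrative assumptions, not brainevent API; the rule is checked against a finite difference in plain Python.

```python
def power_fn(x, n, *, scale):
    """Reference forward computation of the hypothetical primitive."""
    return scale * x ** n


def power_jvp_x(tangent, x, n, *, scale):
    # Receives the tangent of x, all primals (x, n), and the primitive's
    # parameters as keyword arguments:
    # d/dx [scale * x**n] = scale * n * x**(n - 1)
    return tangent * scale * n * x ** (n - 1)


# One entry per input primal; None marks the integer exponent n, whose
# JVP contribution is zero.
jvp_rules = (power_jvp_x, None)

x, n, scale = 2.0, 3, 0.5
dx = 0.25
tangent_out = jvp_rules[0](dx, x, n, scale=scale)

# Finite-difference check along x only, since n carries no tangent.
eps = 1e-6
fd = (power_fn(x + eps * dx, n, scale=scale)
      - power_fn(x - eps * dx, n, scale=scale)) / (2 * eps)
print(tangent_out, fd)
```

With a real primitive, the tuple `jvp_rules` would be unpacked as the trailing arguments, e.g. `defjvp(power_prim, *jvp_rules)`, where `power_prim` is likewise hypothetical.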
- Raises:
AssertionError – If primitive (after unwrapping) is not a Primitive instance.
See also
XLACustomKernel.def_jvp_rule – Register a single monolithic JVP rule for an XLACustomKernel.
XLACustomKernel.def_jvp_rule2 – Convenience wrapper around this function.
general_batching_rule – General batching rule for custom primitives.
Notes
When the primitive has multiple_results=True, a custom internal JVP implementation (_standard_jvp) that correctly handles tuple outputs is used. For single-result primitives, the standard jax.interpreters.ad.standard_jvp is used.
Examples
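>>> from brainevent import defjvp
>>> # Assume `my_prim` is a JAX Primitive with two inputs whose output is x * y.
>>> def jvp_rule_input0(tangent, x, y, **kw):
...     return tangent * y  # contribution of x: d(x * y)/dx = y
>>> def jvp_rule_input1(tangent, x, y, **kw):
...     return tangent * x  # contribution of y: d(x * y)/dy = x
>>> defjvp(my_prim, jvp_rule_input0, jvp_rule_input1)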
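For a primitive with multiple_results=True, each rule has to account for every output. The dependency-free sketch below assumes, for illustration only, that each rule returns one tangent per output and that contributions are summed output by output, mirroring the single-result case; the exact convention of the internal _standard_jvp should be confirmed against the brainevent source. The primitive, returning `(x * y, x + y)`, and the helper `combine_jvp_multi` are hypothetical.

```python
def combine_jvp_multi(rules, primals, tangents, num_outputs, **params):
    """Sum per-input contributions separately for each output.

    Hypothetical helper (assumption): each rule returns a tuple with one
    tangent per output; None rules or None tangents contribute nothing.
    """
    totals = [0.0] * num_outputs
    for rule, tangent in zip(rules, tangents):
        if rule is None or tangent is None:
            continue
        contributions = rule(tangent, *primals, **params)
        for i, contribution in enumerate(contributions):
            totals[i] += contribution
    return tuple(totals)


def prod_and_sum(x, y):
    # Forward computation of the hypothetical two-output primitive.
    return x * y, x + y


def jvp_rule_x(tangent, x, y, **params):
    # d/dx of (x * y, x + y) is (y, 1)
    return tangent * y, tangent


def jvp_rule_y(tangent, x, y, **params):
    # d/dy of (x * y, x + y) is (x, 1)
    return tangent * x, tangent


out_tangents = combine_jvp_multi(
    (jvp_rule_x, jvp_rule_y), (3.0, 5.0), (0.1, -0.2), num_outputs=2
)
print(out_tangents)
```

Keeping one rule per input means each rule only has to know its own partial derivatives, even when the primitive returns a tuple.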