brainevent.general_batching_rule


brainevent.general_batching_rule(prim, args, axes, **kwargs)

General-purpose batching rule for custom JAX primitives.

Implements batching by separating batched and non-batched arguments, moving all batch dimensions to axis 0, and then applying the primitive to each element in the batch via jax.lax.scan.

This function is registered as the default batching rule for every XLACustomKernel during initialization.

Parameters:
  • prim (Primitive) – The JAX primitive operation to be batched.

  • args (sequence of array_like) – Input arguments to the primitive.

  • axes (sequence of int or None) – Batch dimension index for each argument. None indicates that the corresponding argument is not batched.

  • **kwargs – Additional keyword arguments forwarded to the primitive.

Returns:

  • outs (pytree) – The batched outputs from applying the primitive.

  • out_dim (pytree) – A pytree with the same structure as outs, where every leaf is 0, indicating that the batch dimension is the leading axis of each output.

Notes

All batch dimensions are moved to axis 0 before scanning. The scan carry is unused (always 0); only the stacked scan outputs are returned.
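The procedure described above can be sketched in a few lines of JAX. This is only an illustration of the technique, not the library's implementation: `sketch_batching_rule` is a hypothetical name, and a plain Python callable `fn` stands in for binding the primitive (keyword arguments are omitted for brevity).

```python
import jax
import jax.numpy as jnp


def sketch_batching_rule(fn, args, axes):
    """Hypothetical stand-in: `fn` plays the role of the primitive."""
    # Keep only the batched arguments, with their batch axis moved to 0.
    batched = [jnp.moveaxis(a, ax, 0) for a, ax in zip(args, axes) if ax is not None]

    def body(carry, xs):
        xs = iter(xs)
        # Rebuild the full argument list for one batch element:
        # batched slots take the current slice, the rest are reused as-is.
        full = [next(xs) if ax is not None else a for a, ax in zip(args, axes)]
        return carry, fn(*full)

    # The carry is a dummy 0; only the stacked per-element outputs are used.
    _, outs = jax.lax.scan(body, 0, batched)
    # Every output leaf carries its batch dimension on axis 0.
    out_dim = jax.tree_util.tree_map(lambda _: 0, outs)
    return outs, out_dim


x = jnp.arange(6.0).reshape(2, 3)  # batched along axis 1 (batch size 3)
w = jnp.array([1.0, 10.0])         # not batched
outs, out_dim = sketch_batching_rule(lambda a, b: a * b, (x, w), (1, None))
```

For this input, `outs` has the batch on its leading axis and should agree with `jax.vmap(lambda a, b: a * b, in_axes=(1, None))(x, w)`.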

See also

XLACustomKernel.register_general_batching

Registers this function as the batching rule for a primitive.

XLACustomKernel.def_batching_rule

Overrides the default with a custom batching rule.

Examples

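>>> import functools
>>> from jax.interpreters import batching
>>> from brainevent import general_batching_rule
>>> # my_prim is a placeholder for an existing JAX Primitive instance.
>>> batching.primitive_batchers[my_prim] = functools.partial(
...     general_batching_rule, my_prim
... )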