brainevent.general_batching_rule
- brainevent.general_batching_rule(prim, args, axes, **kwargs)
General-purpose batching rule for custom JAX primitives.
Implements batching by separating batched and non-batched arguments, moving all batch dimensions to axis 0, and then applying the primitive to each element in the batch via jax.lax.scan.
This function is registered as the default batching rule for every XLACustomKernel during initialization.
- Parameters:
prim (Primitive) – The JAX primitive operation to be batched.
args (sequence of array_like) – Input arguments to the primitive.
axes (sequence of int or None) – Batch dimension index for each argument. None indicates that the corresponding argument is not batched.
**kwargs – Additional keyword arguments forwarded to the primitive.
- Returns:
outs (pytree) – The batched outputs from applying the primitive.
out_dim (pytree) – A pytree with the same structure as outs, where every leaf is
0, indicating that the batch dimension is the leading axis of each output.
Notes
All batch dimensions are moved to axis 0 before scanning. The scan carry is unused (always 0); only the stacked scan outputs are returned.
See also
XLACustomKernel.register_general_batching – Registers this function as the batching rule for a primitive.
XLACustomKernel.def_batching_rule – Overrides it with a custom batching rule.
Examples
>>> import functools
>>> from jax.interpreters import batching
>>> from brainevent import general_batching_rule
>>> batching.primitive_batchers[my_prim] = functools.partial(
...     general_batching_rule, my_prim
... )
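The scan-based strategy described in the Notes can be illustrated with a minimal sketch. This is not the library's implementation: `scan_batching_sketch` is a hypothetical name, and a plain callable `fn` stands in for binding a real primitive. It assumes at least one argument is batched, which is the only situation in which a batching rule is invoked.

```python
import jax
import jax.numpy as jnp


def scan_batching_sketch(fn, args, axes, **kwargs):
    """Hypothetical sketch; `fn` plays the role of the primitive's bind call."""
    # Separate batched arguments (batch dimension moved to axis 0)
    # from non-batched ones, keyed by their position in `args`.
    batched = {i: jnp.moveaxis(a, ax, 0)
               for i, (a, ax) in enumerate(zip(args, axes)) if ax is not None}
    fixed = {i: a for i, (a, ax) in enumerate(zip(args, axes)) if ax is None}

    def step(carry, xs):
        # Rebuild the full positional argument list for one batch element.
        full = [xs[i] if i in xs else fixed[i] for i in range(len(args))]
        return carry, fn(*full, **kwargs)

    # The carry is a dummy 0; only the stacked per-element outputs are kept.
    _, outs = jax.lax.scan(step, 0, batched)
    # Every output leaf carries its batch dimension on axis 0.
    out_dim = jax.tree_util.tree_map(lambda _: 0, outs)
    return outs, out_dim


# Usage: batch `x` along axis 1 while `w` is shared across the batch.
x = jnp.arange(6.0).reshape(2, 3)
w = jnp.array([1.0, 2.0])
outs, out_dim = scan_batching_sketch(lambda a, b: a * b, (x, w), (1, None))
reference = jax.vmap(lambda a, b: a * b, in_axes=(1, None))(x, w)
```

Because the function is applied one batch element at a time, this approach works for any operation that has no vectorized form, at the cost of sequential execution. A hand-written rule may be faster when the operation can process the whole batch at once.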