XLACustomKernel

class brainevent.XLACustomKernel(name, doc=None)

Creates and manages a custom JAX primitive for XLA custom calls.

This class provides a high-level interface to define custom operations that can be executed efficiently on different backends (CPU, GPU, TPU) via XLA custom calls. It handles the registration of the JAX primitive, its abstract evaluation rule, backend-specific kernel implementations, and JAX transformation rules like batching, JVP (forward-mode AD), and transpose (reverse-mode AD).

Supported backends by platform:

  • CPU: Numba, CUDA

  • GPU: Pallas, CUDA, Numba CUDA, Warp, Triton

  • TPU: Pallas

The workflow for using this class is:

  1. Create an instance with a unique primitive name

  2. Register kernel implementations using def_kernel or convenience methods like def_numba_kernel, def_pallas_kernel, etc.

  3. Optionally set default backends using set_default or asdefault=True

  4. Define JAX transformation rules (batching, JVP, transpose) as needed

  5. Call the instance with input arrays and output specifications

The first kernel registered for a platform automatically becomes the default. You can override this by calling set_default(platform, backend) or by passing asdefault=True when registering a kernel.
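The default-selection rules above fit in a few lines of plain Python. The sketch below is a hypothetical model, not brainevent's implementation. `BackendRegistry` and the placeholder `lambda` generators are illustrative. It reproduces only the documented behavior: the first registration on a platform becomes the default, `asdefault=True` overrides it, `set_default` accepts only registered backends, and `defaults` returns a copy.

```python
from typing import Callable, Dict, Optional


class BackendRegistry:
    """Hypothetical model of the default-backend rules (not brainevent's implementation)."""

    def __init__(self) -> None:
        self._kernels: Dict[str, Dict[str, Callable]] = {}  # platform -> backend -> generator
        self._defaults: Dict[str, str] = {}                  # platform -> default backend

    def def_kernel(self, backend: str, platform: str, kg: Callable, asdefault: bool = False) -> None:
        self._kernels.setdefault(platform, {})[backend] = kg
        # The first kernel on a platform becomes its default; asdefault=True overrides it.
        if platform not in self._defaults or asdefault:
            self._defaults[platform] = backend

    def set_default(self, platform: str, backend: str) -> None:
        # Only backends already registered for this platform are accepted.
        if backend not in self._kernels.get(platform, {}):
            raise ValueError(f"{backend!r} is not registered for platform {platform!r}")
        self._defaults[platform] = backend

    def get_default(self, platform: str) -> Optional[str]:
        return self._defaults.get(platform)

    @property
    def defaults(self) -> Dict[str, str]:
        return dict(self._defaults)  # a copy: mutating it leaves the registry unchanged


reg = BackendRegistry()
reg.def_kernel("numba", "cpu", lambda **kw: None)
reg.def_kernel("pallas", "gpu", lambda **kw: None, asdefault=True)
reg.def_kernel("warp", "gpu", lambda **kw: None)  # registered, but not the GPU default
print(reg.defaults)  # {'cpu': 'numba', 'gpu': 'pallas'}
reg.set_default("gpu", "warp")
print(reg.get_default("gpu"))  # warp
```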

If a kernel fails, the error message shows alternative backends available for the platform and how to switch to them.
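The following sketch shows one way such an error message could be assembled from a dictionary of backend callables. The `dispatch` helper and the message wording are hypothetical. brainevent's actual message text may differ.

```python
from typing import Callable, Dict


def dispatch(kernels: Dict[str, Callable], platform: str, backend: str, *args):
    """Hypothetical dispatcher: run one backend; on failure, name the alternatives."""
    try:
        return kernels[backend](*args)
    except Exception as err:
        others = sorted(name for name in kernels if name != backend)
        hints = " or ".join(f"set_default({platform!r}, {name!r})" for name in others)
        raise RuntimeError(
            f"{backend!r} kernel failed on {platform!r}: {err}. "
            f"Other registered backends: {others}. Switch with {hints}."
        ) from err


def failing_kernel(x):
    raise NotImplementedError("unsupported dtype")


gpu_kernels = {"pallas": failing_kernel, "warp": lambda x: 2 * x, "triton": lambda x: 2 * x}
print(dispatch(gpu_kernels, "gpu", "warp", 3))  # 6
try:
    dispatch(gpu_kernels, "gpu", "pallas", 3)
except RuntimeError as e:
    print(e)  # message lists 'triton' and 'warp' as alternatives
```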

Instance attributes:

  • primitive: The underlying JAX primitive created by this instance.

  • name: The name assigned to the primitive.

Parameters:

name (str) – The unique name for the custom JAX primitive. This name is used to identify the primitive in JAX’s internal registry and in error messages.

See also

KernelEntry

Data class representing a single registered kernel.

defjvp

Utility to define JVP rules for primitives with multiple results.

general_batching_rule

Default batching rule applied to new XLACustomKernel instances.

Examples

>>> kernel = XLACustomKernel('my_custom_op')
>>> kernel.def_numba_kernel(numba_kernel_generator)  # CPU default
>>> kernel.def_pallas_kernel('gpu', pallas_kernel_generator, asdefault=True)
>>> kernel.def_warp_kernel(warp_kernel_generator)  # Alternative GPU backend
>>> print(kernel.defaults)  # {'cpu': 'numba', 'gpu': 'pallas'}
>>> kernel.set_default('gpu', 'warp')  # Change GPU default
>>> result = kernel(input_array, outs=[jax.ShapeDtypeStruct((10,), jnp.float32)])
available_backends(platform)

Return the list of registered backend names for a platform.

Parameters:

platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').

Returns:

Backend names registered for platform. Returns an empty list if no kernels are registered for the platform.

Return type:

List[str]

See also

defaults

Property returning the default backend for each platform.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.available_backends('cpu')
['numba']
benchmark(*, platform, n_warmup=5, n_runs=20, n_batch_per_run=1, compare_results=True, rtol=0.001, atol=0.001, verbose=False, catch_errors=True, backends=None)

Benchmark all registered backends for a platform across every benchmark configuration.

Iterates over the BenchmarkConfig instances produced by the registered benchmark data function, runs the call function on every registered backend for the given platform, collects timing statistics, and returns a unified BenchmarkResult.

Parameters:
  • platform (str) – Target platform ('cpu', 'gpu', or 'tpu').

  • n_warmup (int) – Number of warmup runs before timing. Default is 5.

  • n_runs (int) – Number of timed runs per backend per config. Default is 20.

  • n_batch_per_run (int) – Number of back-to-back kernel calls issued within each timed interval before blocking. Default is 1 (per-call latency). Higher values amortise blocking overhead, useful for measuring throughput on asynchronous GPU/TPU execution. Reported times are always per-call values.

  • compare_results (bool) – If True (default), verify that outputs match across backends for each config using jnp.allclose. Mismatches are printed as warnings.

  • rtol (float) – Relative tolerance for output comparison when compare_results is True. Default is 1e-3.

  • atol (float) – Absolute tolerance for output comparison when compare_results is True. Default is 1e-3.

  • verbose (bool) – If True, print a one-line timing summary after each (config, backend) pair completes. Useful for monitoring progress when the config list is large. Default is False.

  • catch_errors (bool) – If True (default), runtime errors raised by a backend during warmup or timed runs are caught and stored in the returned BenchmarkRecord as success=False with the exception message in error. The benchmark continues with the remaining backends. Set to False to let exceptions propagate immediately, which is useful for debugging a specific backend failure.
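A plain timing loop shows how n_warmup, n_runs and n_batch_per_run interact. This is an illustration built on assumptions, not brainevent's benchmark code. `time_backend` is hypothetical, `fn` stands for one kernel call, and `block` stands for a synchronization step such as `jax.block_until_ready`, which is a no-op for synchronous Python code. Blocking only on the last output assumes calls complete in issue order, which is typically the case for work queued on a single device.

```python
import statistics
import time
from typing import Callable, Dict


def time_backend(fn: Callable[[], object], block: Callable[[object], None],
                 n_warmup: int = 5, n_runs: int = 20, n_batch_per_run: int = 1) -> Dict[str, float]:
    """Hypothetical timing loop mirroring n_warmup / n_runs / n_batch_per_run."""
    for _ in range(n_warmup):  # untimed warmup runs (e.g. to trigger JIT compilation)
        block(fn())
    per_call = []
    for _ in range(n_runs):
        start = time.perf_counter()
        out = None
        for _ in range(n_batch_per_run):  # back-to-back calls with no blocking in between
            out = fn()
        block(out)  # wait once, for the last result
        elapsed = time.perf_counter() - start
        per_call.append(elapsed / n_batch_per_run)  # always report per-call time
    return {"mean_s": statistics.mean(per_call), "min_s": min(per_call), "max_s": max(per_call)}


stats = time_backend(lambda: sum(range(10_000)), block=lambda out: None, n_batch_per_run=4)
print(sorted(stats))  # ['max_s', 'mean_s', 'min_s']
```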

Returns:

A BenchmarkResult containing one BenchmarkRecord per (config × backend) pair. Failed runs (when catch_errors is True) are included with success=False.

Return type:

BenchmarkResult

Raises:
  • BenchmarkDataFnNotProvidedError – If no benchmark data function has been registered. Use def_benchmark_data() first.

  • ValueError – If no call function has been registered (use def_call() first), or if no backends are registered for platform.

  • Exception – Any backend runtime error, when catch_errors is False.
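The next sketch models the per-(config × backend) bookkeeping, and is hypothetical. Each backend runs on each configuration. When catch_errors is True, failures become success=False records. Otherwise they propagate. Successful outputs are compared with the same elementwise test that jnp.allclose uses, |a - b| <= atol + rtol * |b|. `Record` only approximates BenchmarkRecord, whose real fields may differ.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence


@dataclass
class Record:
    """Hypothetical stand-in for BenchmarkRecord; the real fields may differ."""
    config: str
    backend: str
    success: bool
    output: Optional[List[float]] = None
    error: Optional[str] = None


def allclose(a: Sequence[float], b: Sequence[float], rtol: float, atol: float) -> bool:
    # Same elementwise test as jnp.allclose: |a - b| <= atol + rtol * |b|
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def run_all(configs: Dict[str, List[float]], backends: Dict[str, Callable],
            catch_errors: bool = True, rtol: float = 1e-3, atol: float = 1e-3) -> List[Record]:
    """Hypothetical driver: one Record per (config x backend) pair."""
    records: List[Record] = []
    for cfg, data in configs.items():
        current = []
        for name, fn in backends.items():
            try:
                rec = Record(cfg, name, True, output=fn(data))
            except Exception as err:
                if not catch_errors:
                    raise  # let the failure propagate, as with catch_errors=False
                rec = Record(cfg, name, False, error=str(err))
            current.append(rec)
        ok = [r for r in current if r.success]
        for r in ok[1:]:  # compare every successful backend with the first one
            if not allclose(r.output, ok[0].output, rtol, atol):
                print(f"warning: {r.backend} differs from {ok[0].backend} on {cfg}")
        records.extend(current)
    return records


def out_of_memory(data):
    raise RuntimeError("out of shared memory")


backends = {"a": lambda d: [2.0 * v for v in d], "b": lambda d: [v + v for v in d], "c": out_of_memory}
records = run_all({"n=3": [1.0, 2.0, 3.0]}, backends)
print([(r.backend, r.success) for r in records])  # [('a', True), ('b', True), ('c', False)]
```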

See also

def_call

Register the call function to benchmark.

def_benchmark_data

Register a data generator for automated benchmarking.

BenchmarkResult

Unified result container.

call(*args, **kwargs)

Invoke the registered call function.

Parameters:
  • *args – Positional arguments forwarded to the call function.

  • **kwargs – Keyword arguments forwarded to the call function.

Returns:

The return value of the call function.

Return type:

Any

Raises:

ValueError – If no call function has been registered via def_call().

See also

def_call

Register a call function.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_call(my_call_fn)
>>> result = kernel.call(x, y, backend='pallas')
def_batching_rule(fun)

Define a custom batching rule for the primitive.

The batching rule specifies how the primitive behaves when applied to batched inputs, i.e., inputs that carry a batch dimension introduced by jax.vmap. That dimension is not necessarily the leading axis.

Parameters:

fun (Callable) – A function implementing the batching logic. It receives batched arguments and per-argument batch dimension indices, and must return (batched_outputs, output_batch_dims). See JAX documentation for jax.interpreters.batching.primitive_batchers.

See also

register_general_batching

Register the default general-purpose batching rule.

general_batching_rule

The general-purpose batching implementation used by default.

Examples

>>> def my_batching(args, axes, **kwargs):
...     # Custom batching logic
...     return batched_out, out_dims
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_batching_rule(my_batching)
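Here is a minimal batching rule for a hypothetical elementwise product. It follows the standard JAX contract: it receives the batched arguments and their batch axes (None for unbatched arguments), and returns the batched output together with its batch axis. The sketch calls the rule directly and checks it against jax.vmap. In a real registration, the rule would re-bind the primitive (e.g. kernel.primitive.bind(...)) instead of calling `elementwise_mul`.

```python
import jax
import jax.numpy as jnp


def elementwise_mul(x, y):
    """Stand-in for the primitive's semantics on unbatched inputs (hypothetical op)."""
    return x * y


def mul_batching_rule(args, axes, **params):
    """Batching rule following JAX's (batched_args, batch_dims) -> (out, out_dim) contract."""
    x, y = args
    x_ax, y_ax = axes
    size = x.shape[x_ax] if x_ax is not None else y.shape[y_ax]
    # Move each batch axis to the front; broadcast arguments that are not batched (axis None).
    x = jnp.moveaxis(x, x_ax, 0) if x_ax is not None else jnp.broadcast_to(x, (size,) + x.shape)
    y = jnp.moveaxis(y, y_ax, 0) if y_ax is not None else jnp.broadcast_to(y, (size,) + y.shape)
    # An elementwise op already handles the stacked arrays, so a single call suffices.
    # A real rule would re-bind the primitive here, e.g. primitive.bind(x, y, **params).
    return elementwise_mul(x, y), 0


x = jnp.arange(12.0).reshape(4, 3)  # batched along axis 1 (batch size 3)
y = jnp.arange(4.0)                 # not batched
out, out_axis = mul_batching_rule((x, y), (1, None))
expected = jax.vmap(elementwise_mul, in_axes=(1, None))(x, y)
print(out.shape, out_axis)  # (3, 4) 0
```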
def_benchmark_data(fn)

Register a benchmark data generator function.

The generator produces a list of BenchmarkConfig instances that define the parameter combinations to benchmark for this primitive.

Parameters:

fn (Callable) – A callable with signature fn(*, platform: str) -> List[BenchmarkConfig].

See also

benchmark

Run benchmarks using the registered call function.

def_tags

Set categorization tags for filtering.

Examples

>>> def my_data_gen(*, platform):
...     return [BenchmarkConfig(n=100), BenchmarkConfig(n=1000)]
>>> kernel.def_benchmark_data(my_data_gen)
def_call(fn)

Associate a high-level call function with this primitive.

The call function is the user-facing Python function that prepares arguments and invokes the primitive. It is stored so that call() and benchmark() can use it.

Parameters:

fn (Callable) – The call function (e.g., binary_csrmv_p_call). It should accept the same positional and keyword arguments that the end user would pass.

See also

call

Invoke the registered call function.

benchmark

Benchmark the registered call function.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_call(my_call_fn)
>>> kernel.call(x, y)  # delegates to my_call_fn
def_cuda_raw_kernel(kg, asdefault=False)

Register a cuda_raw (nvcc-compiled) kernel for the CPU or GPU platform.

Convenience wrapper around def_kernel() with backend='cuda_raw'. The kernel generator function should call brainevent.load_cuda_file() or brainevent.load_cuda_inline() to compile and register the CUDA kernel, then return a closure that calls it via jax.ffi.ffi_call.

Parameters:
  • platform (str) – Target platform. Must be 'cpu' or 'gpu'.

  • kg (Callable[..., Callable]) – A callable that compiles and returns the kernel function.

  • asdefault (bool) – If True, set cuda_raw as the default backend for the given platform. Default is False.

See also

def_kernel

General kernel registration method.

def_jvp_rule(fun)

Define a custom JVP (Jacobian-vector product) rule.

This rule is used for forward-mode automatic differentiation. It specifies how to compute the directional derivative of the primitive’s outputs with respect to its inputs.

Parameters:

fun (Callable) – A function implementing the JVP logic. See JAX documentation for jax.interpreters.ad.primitive_jvps.

See also

def_jvp_rule2

Convenience method for defining per-input JVP rules.

def_transpose_rule

Define a transpose rule for reverse-mode AD.

Examples

>>> def my_jvp(primals, tangents, **params):
...     val_out = kernel.primitive.bind(*primals, **params)
...     tangent_out = ...  # compute tangent
...     return val_out, tangent_out
>>> kernel.def_jvp_rule(my_jvp)
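The following is a complete JVP rule for a hypothetical elementwise product x * y. It follows the (primals, tangents) -> (primal_out, tangent_out) contract and is checked against jax.jvp on the equivalent jnp expression. The sketch assumes both tangents are dense arrays. JAX may also pass symbolic zero tangents to a primitive's JVP rule, and a production rule has to handle those.

```python
import jax
import jax.numpy as jnp


def mul_jvp_rule(primals, tangents, **params):
    """JVP rule for a hypothetical elementwise product, assuming dense tangents."""
    x, y = primals
    tx, ty = tangents
    primal_out = x * y             # a real rule would call kernel.primitive.bind(x, y, **params)
    tangent_out = tx * y + x * ty  # product rule: d(xy) = dx * y + x * dy
    return primal_out, tangent_out


x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
tx, ty = jnp.array([0.5, 0.0]), jnp.array([1.0, -1.0])
out, t_out = mul_jvp_rule((x, y), (tx, ty))
ref_out, ref_t = jax.jvp(lambda a, b: a * b, (x, y), (tx, ty))
```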
def_jvp_rule2(*jvp_rules)

Define per-input JVP rules for the primitive.

This is a convenience method similar to jax.interpreters.ad.defjvp but adapted for primitives that return multiple results. Each rule corresponds to one input primal and computes the tangent contribution from that input.

Parameters:

*jvp_rules (callable or None) – One callable per input primal. Each callable has the signature rule(tangent, *primals, **params) -> tangent_out. Pass None for inputs whose JVP contribution is zero.

See also

def_jvp_rule

Define a single monolithic JVP rule.

defjvp

The underlying utility function.

Examples

>>> def jvp_input0(t, x, y, **kw):
...     return t * y
>>> def jvp_input1(t, x, y, **kw):
...     return t * x
>>> kernel.def_jvp_rule2(jvp_input0, jvp_input1)
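Plain floats are enough to check how per-input rules combine. `combined_jvp` below is a hypothetical single-output simplification. It sums the contribution of every input whose rule is not None, which is how per-input JVP rules in the style of jax.interpreters.ad.defjvp assemble the output tangent. The result is compared against a finite-difference estimate for f(x, y) = x * y.

```python
def combined_jvp(rules, primals, tangents, **params):
    """Hypothetical single-output combiner: sum the tangent contribution of each
    input whose rule is not None (None means that input contributes nothing)."""
    total = 0.0
    for rule, t in zip(rules, tangents):
        if rule is not None:
            total += rule(t, *primals, **params)
    return total


def jvp_input0(t, x, y, **kw):
    return t * y  # d(x * y) / dx = y


def jvp_input1(t, x, y, **kw):
    return t * x  # d(x * y) / dy = x


x, y, tx, ty = 3.0, 5.0, 0.1, -0.2
tangent = combined_jvp([jvp_input0, jvp_input1], (x, y), (tx, ty))

# Finite-difference check of the directional derivative of f(x, y) = x * y.
eps = 1e-6
fd = ((x + eps * tx) * (y + eps * ty) - x * y) / eps
print(tangent, fd)
```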
def_kernel(backend, platform, kg, asdefault=False)

Register a kernel implementation for a specific backend and platform.

Creates a KernelEntry and stores it in the internal kernel registry. If this is the first kernel registered for the given platform, it automatically becomes the default. Pass asdefault=True to override an existing default.

A JAX lowering rule is registered for the platform the first time any kernel targets it.

Parameters:
  • backend (str) – The backend name (e.g., 'numba', 'warp', 'pallas', 'triton', 'cuda_raw', 'numba_cuda').

  • platform (str) – The hardware platform (e.g., 'cpu', 'gpu', 'tpu').

  • kg (Callable[..., Callable]) – A callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.

  • asdefault (bool) – If True, set this backend as the default for platform even if a default already exists. Default is False.

Raises:

AssertionError – If backend or platform is not a string, or if kg is not callable.

See also

def_numba_kernel

Shorthand for def_kernel('numba', 'cpu', ...).

def_warp_kernel

Shorthand for def_kernel('warp', 'gpu', ...).

def_pallas_kernel

Shorthand for Pallas kernels on GPU or TPU.

set_default

Change the default backend for a platform after registration.
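The kernel-generator contract of the kg parameter can be sketched in pure Python: the generator receives keyword arguments and returns a function that does the actual work. `scale_kernel_generator` and its `scale` parameter are hypothetical. A real generator would typically return a Numba-, Warp-, Pallas- or Triton-compiled function. The exact keyword arguments and the kernel's calling convention are set by brainevent, not by this sketch.

```python
from typing import Callable, List


def scale_kernel_generator(scale: float = 1.0, **kwargs) -> Callable[[List[float]], List[float]]:
    """Hypothetical kernel generator: receives keyword arguments and returns the kernel."""
    # Parameters are fixed here, once, so the returned kernel can be specialized for them.
    def kernel(values: List[float]) -> List[float]:
        return [scale * v for v in values]

    return kernel


kernel_fn = scale_kernel_generator(scale=2.0, other_param="ignored")
print(kernel_fn([1.0, 2.0]))  # [2.0, 4.0]
```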

def_numba_cuda_kernel(kg, asdefault=False)

Register a Numba CUDA kernel for the GPU platform.

Convenience wrapper around def_kernel() with backend='numba_cuda' and platform='gpu'.

Parameters:
  • kg (Callable[..., Callable]) – A callable that generates the Numba CUDA kernel function.

  • asdefault (bool) – If True, set Numba CUDA as the default GPU backend. Default is False.

See also

def_kernel

General kernel registration method.

def_warp_kernel

Register a Warp kernel for GPU.

numba_cuda_kernel

Standalone function for wrapping a single Numba CUDA kernel.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_cuda_kernel(my_numba_cuda_gen)
def_numba_kernel(kg, asdefault=False)

Register a Numba kernel for the CPU platform.

Convenience wrapper around def_kernel() with backend='numba' and platform='cpu'.

Parameters:
  • kg (Callable[..., Callable]) – A callable that generates the Numba kernel function.

  • asdefault (bool) – If True, set Numba as the default CPU backend. Default is False (the first registered CPU kernel becomes the default automatically).

See also

def_kernel

General kernel registration method.

def_warp_kernel

Register a Warp kernel for GPU.

def_pallas_kernel

Register a Pallas kernel for GPU or TPU.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(my_numba_kernel_gen)
def_pallas_kernel(platform, kg, asdefault=False)

Register a Pallas kernel for the GPU or TPU platform.

Convenience wrapper around def_kernel() with backend='pallas'.

Parameters:
  • platform (str) – Target platform. Must be 'gpu' or 'tpu'.

  • kg (Callable[..., Callable]) – A callable that generates the Pallas kernel function.

  • asdefault (bool) – If True, set Pallas as the default backend for the given platform. Default is False.

Raises:

AssertionError – If platform is not 'gpu' or 'tpu'.

See also

def_kernel

General kernel registration method.

def_warp_kernel

Register a Warp kernel for GPU.

def_numba_kernel

Register a Numba kernel for CPU.

Notes

Pallas kernels require JAX >= 0.7.1. The version check is performed lazily at dispatch time, not at registration time.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_pallas_kernel('gpu', my_pallas_gen, asdefault=True)
def_tags(*tags)

Set categorization tags for this primitive.

Tags are used by the CLI and the global primitive registry to filter primitives by sparse format (e.g., 'csr', 'coo') and value type (e.g., 'binary', 'float').

Parameters:

*tags (str) – Tag strings (e.g., 'csr', 'binary').

See also

def_benchmark_data

Register a benchmark data generator.

Examples

>>> kernel = XLACustomKernel('binary_csrmv')
>>> kernel.def_tags('csr', 'binary', 'mv')
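Here is a sketch of tag-based filtering, assuming a query keeps only primitives that carry every requested tag. `filter_by_tags` and the registry entries (apart from the binary_csrmv example above) are hypothetical. The CLI's actual matching rules may differ.

```python
from typing import Dict, Iterable, List, Set


def filter_by_tags(registry: Dict[str, Set[str]], required: Iterable[str]) -> List[str]:
    """Hypothetical filter: names of primitives carrying every required tag."""
    required = set(required)
    return sorted(name for name, tags in registry.items() if required <= tags)


# Hypothetical registry contents, keyed by primitive name.
registry = {
    "binary_csrmv": {"csr", "binary", "mv"},
    "float_csrmv": {"csr", "float", "mv"},
    "binary_coomv": {"coo", "binary", "mv"},
}
print(filter_by_tags(registry, ["csr"]))           # ['binary_csrmv', 'float_csrmv']
print(filter_by_tags(registry, ["binary", "mv"]))  # ['binary_coomv', 'binary_csrmv']
```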
def_transpose_rule(fun)

Define a custom transpose rule for reverse-mode AD.

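The transpose rule is invoked whenever JAX transposes a linear computation involving the primitive. This happens during reverse-mode differentiation (jax.grad, jax.vjp) as well as in jax.linear_transpose. The rule defines how cotangent vectors (gradients) propagate backward through the primitive.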

Parameters:

fun (Callable) – A function implementing the transpose logic. See JAX documentation for jax.interpreters.ad.primitive_transposes.

See also

def_jvp_rule

Define a JVP rule for forward-mode AD.

def_jvp_rule2

Define per-input JVP rules.

Examples

>>> def my_transpose(ct, *args, **params):
...     # Propagate cotangent backward
...     return (ct_input0, ct_input1)
>>> kernel.def_transpose_rule(my_transpose)
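For a primitive that is linear in one input, the transpose rule maps an output cotangent back to that input. The sketch below works through a hypothetical matrix-vector product y = A @ x, linear in x only. The cotangent for x is A.T @ ct, and the constant A receives None. The result is checked against jax.linear_transpose. In a real rule, x would arrive as an undefined primal (jax.interpreters.ad.UndefinedPrimal), which is why its value is never read.

```python
import jax
import jax.numpy as jnp


def matvec_transpose_rule(ct, A, x, **params):
    """Transpose rule for a hypothetical y = A @ x that is linear in x only.

    Returns one cotangent per input: None for the constant A, A.T @ ct for x.
    The value of x is never used (in JAX it would be an undefined primal).
    """
    return None, A.T @ ct


A = jnp.arange(6.0).reshape(2, 3)
x = jnp.ones(3)
ct = jnp.array([1.0, -1.0])
_, ct_x = matvec_transpose_rule(ct, A, x)
(ref,) = jax.linear_transpose(lambda v: A @ v, x)(ct)
print(ct_x)  # cotangent for x
```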
def_triton_kernel(kg, asdefault=False)

Register a Triton kernel for the GPU platform.

Convenience wrapper around def_kernel() with backend='triton' and platform='gpu'.

Parameters:
  • kg (Callable[..., Callable]) – A callable that generates the Triton kernel function.

  • asdefault (bool) – If True, set Triton as the default GPU backend. Default is False.

See also

def_kernel

General kernel registration method.

def_warp_kernel

Register a Warp kernel for GPU.

def_pallas_kernel

Register a Pallas kernel for GPU or TPU.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_triton_kernel(my_triton_kernel_gen)
def_warp_kernel(kg, asdefault=False)

Register a Warp kernel for the GPU platform.

Convenience wrapper around def_kernel() with backend='warp' and platform='gpu'.

Parameters:
  • kg (Callable[..., Callable]) – A callable that generates the Warp kernel function.

  • asdefault (bool) – If True, set Warp as the default GPU backend. Default is False.

See also

def_kernel

General kernel registration method.

def_numba_kernel

Register a Numba kernel for CPU.

def_pallas_kernel

Register a Pallas kernel for GPU or TPU.

def_triton_kernel

Register a Triton kernel for GPU.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_warp_kernel(my_warp_kernel_gen)
property defaults: Dict[str, str]

Return a copy of all default backends.

Returns:

A dictionary mapping platform names to their default backend names. Modifying the returned dictionary does not affect the internal state.

Return type:

dict of str to str

See also

get_default

Retrieve the default for a single platform.

set_default

Change the default backend for a platform.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.def_pallas_kernel('gpu', pallas_gen)
>>> kernel.defaults
{'cpu': 'numba', 'gpu': 'pallas'}
get_default(platform)

Get the current default backend for a platform.

Parameters:

platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').

Returns:

The default backend name, or None if no default is set for the given platform.

Return type:

str | None

See also

set_default

Set the default backend for a platform.

defaults

Property returning all default backends.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.get_default('cpu')
'numba'
register_general_batching()

Register the default general-purpose batching rule.

This method applies a common batching pattern that handles most custom operators by using jax.lax.scan to map the kernel over the batch dimension. It is called automatically during __init__; call it again to restore the default after overriding with def_batching_rule().

See also

def_batching_rule

Override with a custom batching rule.

general_batching_rule

The underlying batching implementation.

Notes

The general batching rule moves all batch dimensions to axis 0 and uses jax.lax.scan to iterate over the batch. This is correct but may be slower than a hand-written batching rule for operations that can natively handle batched inputs.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_batching_rule(custom_rule)
>>> kernel.register_general_batching()  # Restore default
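The scan-based pattern described in the Notes can be sketched with public JAX APIs. `scan_batching` is a hypothetical illustration, not general_batching_rule itself. It moves every batch axis to the front and keeps unbatched arguments fixed. It then runs `fn` (a stand-in for binding the primitive) once per batch element inside jax.lax.scan and returns the outputs stacked on axis 0. The result is compared with jax.vmap.

```python
import jax
import jax.numpy as jnp


def scan_batching(fn, args, axes):
    """Hypothetical sketch of scan-based batching (not brainevent's general_batching_rule)."""
    batched = [i for i, ax in enumerate(axes) if ax is not None]
    # Move every batch axis to position 0 so scan can slice along the leading axis.
    stacked = [jnp.moveaxis(args[i], axes[i], 0) for i in batched]

    def body(carry, slices):
        call_args = list(args)
        for i, s in zip(batched, slices):
            call_args[i] = s  # swap in one batch element; unbatched args stay as they are
        return carry, fn(*call_args)

    _, out = jax.lax.scan(body, None, stacked)
    return out, 0  # per-element results are stacked along axis 0


def matvec(m, v):
    return m @ v  # stand-in for binding the primitive on unbatched inputs


m = jnp.arange(24.0).reshape(2, 3, 4)  # batched along axis 2 (batch size 4)
v = jnp.array([1.0, 0.0, -1.0])        # not batched
out, out_axis = scan_batching(matvec, (m, v), (2, None))
expected = jax.vmap(matvec, in_axes=(2, None))(m, v)
print(out.shape)  # (4, 2)
```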
set_default(platform, backend)

Set the default backend for a platform.

After this call, all subsequent dispatches on the given platform will use the specified backend unless overridden by an explicit backend= keyword argument at call time.

Parameters:
  • platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').

  • backend (str) – The backend name to set as default. Must already be registered for the given platform.

Raises:

ValueError – If no kernels are registered for platform, or if backend is not registered for platform.

See also

get_default

Retrieve the current default backend for a platform.

defaults

Property returning all default backends.

Examples

>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_warp_kernel(warp_gen)
>>> kernel.def_pallas_kernel('gpu', pallas_gen)
>>> kernel.set_default('gpu', 'pallas')