XLACustomKernel
- class brainevent.XLACustomKernel(name, doc=None)
Creates and manages a custom JAX primitive for XLA custom calls.
This class provides a high-level interface to define custom operations that can be executed efficiently on different backends (CPU, GPU, TPU) via XLA custom calls. It handles the registration of the JAX primitive, its abstract evaluation rule, backend-specific kernel implementations, and JAX transformation rules like batching, JVP (forward-mode AD), and transpose (reverse-mode AD).
Supported backends by platform:
CPU: Numba, CUDA
GPU: Pallas, CUDA, Numba CUDA, Warp, Triton
TPU: Pallas
The workflow for using this class is:
1. Create an instance with a unique primitive name.
2. Register kernel implementations using def_kernel or convenience methods such as def_numba_kernel, def_pallas_kernel, etc.
3. Optionally set default backends using set_default or asdefault=True.
4. Define JAX transformation rules (batching, JVP, transpose) as needed.
5. Call the instance with input arrays and output specifications.
The first kernel registered for a platform automatically becomes the default. You can override this by calling set_default(platform, backend) or by passing asdefault=True when registering a kernel.
If a kernel fails, the error message lists the alternative backends available for the platform and how to switch to them.
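To make the default-selection rules concrete, here is a minimal pure-Python sketch of the bookkeeping they imply. BackendRegistrySketch and noop_generator are hypothetical names written for illustration only; this is not brainevent's implementation and it registers no JAX primitive. It assumes only the behaviour stated on this page: the first kernel per platform becomes the default, asdefault=True overrides it, set_default raises ValueError for unknown platforms or backends, and defaults returns a copy.

```python
from typing import Callable, Dict, List, Optional


class BackendRegistrySketch:
    """Hypothetical model of the documented default-backend rules.

    Illustration only: this is not brainevent's implementation.
    """

    def __init__(self) -> None:
        self._kernels: Dict[str, Dict[str, Callable]] = {}  # platform -> {backend: kg}
        self._defaults: Dict[str, str] = {}                  # platform -> default backend

    def def_kernel(self, backend: str, platform: str, kg: Callable,
                   asdefault: bool = False) -> None:
        self._kernels.setdefault(platform, {})[backend] = kg
        # First kernel for a platform becomes its default; asdefault=True overrides.
        if platform not in self._defaults or asdefault:
            self._defaults[platform] = backend

    def available_backends(self, platform: str) -> List[str]:
        return list(self._kernels.get(platform, {}))

    def set_default(self, platform: str, backend: str) -> None:
        if platform not in self._kernels:
            raise ValueError(f"no kernels registered for platform {platform!r}")
        if backend not in self._kernels[platform]:
            raise ValueError(
                f"backend {backend!r} is not registered for {platform!r}; "
                f"available: {self.available_backends(platform)}"
            )
        self._defaults[platform] = backend

    def get_default(self, platform: str) -> Optional[str]:
        return self._defaults.get(platform)

    @property
    def defaults(self) -> Dict[str, str]:
        return dict(self._defaults)  # copy: mutating it leaves internal state alone


def noop_generator(**params):
    """Placeholder kernel generator used only to fill the registry."""
    return None


reg = BackendRegistrySketch()
reg.def_kernel("numba", "cpu", noop_generator)
reg.def_kernel("pallas", "gpu", noop_generator)
reg.def_kernel("warp", "gpu", noop_generator)  # second GPU backend: default unchanged
print(reg.defaults)  # {'cpu': 'numba', 'gpu': 'pallas'}
reg.set_default("gpu", "warp")
print(reg.get_default("gpu"))  # warp
```

Keeping defaults separate from the kernel table is what lets set_default validate against registered backends before switching.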
Instance attributes:
primitive: The underlying JAX primitive.
name: The name assigned to the primitive.
- Parameters:
name (str) – The unique name for the custom JAX primitive. This name identifies the primitive in JAX’s internal registry and in error messages.
See also
KernelEntry: Data class representing a single registered kernel.
defjvp: Utility to define JVP rules for primitives with multiple results.
general_batching_rule: Default batching rule applied to new XLACustomKernel instances.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import XLACustomKernel
>>> kernel = XLACustomKernel('my_custom_op')
>>> kernel.def_numba_kernel(numba_kernel_generator)  # CPU default
>>> kernel.def_pallas_kernel('gpu', pallas_kernel_generator, asdefault=True)
>>> kernel.def_warp_kernel(warp_kernel_generator)  # Alternative GPU backend
>>> kernel.defaults
{'cpu': 'numba', 'gpu': 'pallas'}
>>> kernel.set_default('gpu', 'warp')  # Change GPU default
>>> result = kernel(input_array, outs=[jax.ShapeDtypeStruct((10,), jnp.float32)])
- available_backends(platform)
Return the list of registered backend names for a platform.
- Parameters:
platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').
- Returns:
Backend names registered for platform. Returns an empty list if no kernels are registered for the platform.
- Return type:
list of str
See also
defaults: Property returning the default backend for each platform.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.available_backends('cpu')
['numba']
- benchmark(*, platform, n_warmup=5, n_runs=20, n_batch_per_run=1, compare_results=True, rtol=0.001, atol=0.001, verbose=False, catch_errors=True, backends=None)
Benchmark all registered backends across every configured data config.
Iterates over the BenchmarkConfig instances produced by the registered benchmark data function, runs the call function on every registered backend for the given platform, collects timing statistics, and returns a unified BenchmarkResult.
- Parameters:
platform (str) – Target platform ('cpu', 'gpu', or 'tpu').
n_warmup (int) – Number of warmup runs before timing. Default is 5.
n_runs (int) – Number of timed runs per backend per config. Default is 20.
n_batch_per_run (int) – Number of back-to-back kernel calls issued within each timed interval before blocking. Default is 1 (per-call latency). Higher values amortise blocking overhead, which is useful for measuring throughput on asynchronous GPU/TPU execution. Reported times are always per-call values.
compare_results (bool) – If True (default), verify that outputs match across backends for each config using jnp.allclose. Mismatches are printed as warnings.
rtol (float) – Relative tolerance for output comparison when compare_results is True. Default is 1e-3.
atol (float) – Absolute tolerance for output comparison when compare_results is True. Default is 1e-3.
verbose (bool) – If True, print a one-line timing summary after each (config, backend) pair completes. Useful for monitoring progress when the config list is large. Default is False.
catch_errors (bool) – If True (default), runtime errors raised by a backend during warmup or timed runs are caught and stored in the returned BenchmarkRecord as success=False, with the exception message in error. The benchmark continues with the remaining backends. Set to False to let exceptions propagate immediately, which is useful for debugging a specific backend failure.
- Returns:
A BenchmarkResult containing one BenchmarkRecord per (config × backend) pair. Failed runs (when catch_errors is True) are included with success=False.
- Return type:
BenchmarkResult
- Raises:
BenchmarkDataFnNotProvidedError – If no benchmark data function has been registered. Use def_benchmark_data() first.
ValueError – If no call function has been registered (use def_call() first), or if no backends are registered for platform.
Exception – Any backend runtime error, when catch_errors is False.
See also
def_call: Register the call function to benchmark.
def_benchmark_data: Register a data generator for automated benchmarking.
BenchmarkResult: Unified result container.
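The n_warmup / n_runs / n_batch_per_run timing scheme and the compare_results check can be sketched with plain JAX. time_per_call and outputs_match are hypothetical helpers, not part of brainevent; the sketch assumes that blocking on the last output of a timed interval is sufficient because computations on a single device complete in dispatch order. The real benchmark() may differ in detail.

```python
import time

import jax
import jax.numpy as jnp


def time_per_call(fn, *args, n_warmup=5, n_runs=20, n_batch_per_run=1):
    """Hypothetical helper: per-call times (seconds), one entry per timed run."""
    for _ in range(n_warmup):                 # warmup: compile and fill caches
        jax.block_until_ready(fn(*args))
    per_call = []
    for _ in range(n_runs):
        start = time.perf_counter()
        for _ in range(n_batch_per_run):      # back-to-back asynchronous dispatches
            out = fn(*args)
        jax.block_until_ready(out)            # block once per timed interval
        # Divide so the reported value is always a per-call time.
        per_call.append((time.perf_counter() - start) / n_batch_per_run)
    return per_call


def outputs_match(out_a, out_b, rtol=1e-3, atol=1e-3):
    """Compare two (possibly nested) outputs leaf by leaf with jnp.allclose."""
    leaves_a = jax.tree_util.tree_leaves(out_a)
    leaves_b = jax.tree_util.tree_leaves(out_b)
    return len(leaves_a) == len(leaves_b) and all(
        bool(jnp.allclose(a, b, rtol=rtol, atol=atol))
        for a, b in zip(leaves_a, leaves_b)
    )


f = jax.jit(lambda v: jnp.sin(v) @ v.T)
x = jnp.ones((256, 256), jnp.float32)
times = time_per_call(f, x, n_batch_per_run=10)
print(f"median per-call time: {sorted(times)[len(times) // 2] * 1e6:.1f} us")
print(outputs_match(f(x), f(x)))  # True
```

Raising n_batch_per_run spreads the fixed cost of one blocking call over several dispatches, which is why it measures throughput rather than latency.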
- call(*args, **kwargs)
Invoke the registered call function.
- Parameters:
*args – Positional arguments forwarded to the call function.
**kwargs – Keyword arguments forwarded to the call function.
- Returns:
The return value of the call function.
- Return type:
result
- Raises:
ValueError – If no call function has been registered via def_call().
See also
def_call: Register a call function.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_call(my_call_fn)
>>> result = kernel.call(x, y, backend='pallas')
- def_batching_rule(fun)
Define a custom batching rule for the primitive.
The batching rule specifies how the primitive should behave when applied to batched inputs, i.e. inputs carrying a batch dimension introduced by jax.vmap. The batch dimension need not be the leading axis; its position is passed to the rule.
- Parameters:
fun (Callable) – A function implementing the batching logic. It receives the batched arguments and the per-argument batch dimension indices, and must return (batched_outputs, output_batch_dims). See the JAX documentation for jax.interpreters.batching.primitive_batchers.
See also
register_general_batching: Register the default general-purpose batching rule.
general_batching_rule: The general-purpose batching implementation used by default.
Examples
>>> def my_batching(args, axes, **kwargs):
...     # Custom batching logic
...     return batched_out, out_dims
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_batching_rule(my_batching)
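As a fuller illustration of the (args, axes, **params) contract, the sketch below writes a hand-made batching rule for a hypothetical single-output primitive matvec (y = A @ x). matvec and matvec_batching_rule are illustrative names; a primitive with several results would return sequences of outputs and batch dimensions instead. The rule is checked directly against jax.vmap of the reference semantics rather than registered on a real XLACustomKernel.

```python
import jax
import jax.numpy as jnp


def matvec(A, x):
    """Reference semantics of a hypothetical single-output primitive: y = A @ x."""
    return A @ x


def matvec_batching_rule(args, axes, **params):
    """Hand-written batching rule for the hypothetical ``matvec`` primitive.

    ``axes`` holds one batch-dimension index per argument (None if unbatched).
    Returns ``(batched_output, output_batch_dim)``.
    """
    A, x = args
    a_axis, x_axis = axes
    if a_axis is None:
        # Only x is batched: B matrix-vector products become one matrix-matrix product.
        xb = jnp.moveaxis(x, x_axis, 0)     # (B, n)
        return xb @ A.T, 0                  # (B, m), batch dim at axis 0
    Ab = jnp.moveaxis(A, a_axis, 0)         # (B, m, n)
    if x_axis is None:
        # Only A is batched: matmul broadcasts the shared vector over the batch.
        return Ab @ x, 0                    # (B, m)
    # Both batched: contract each matrix with its own vector.
    xb = jnp.moveaxis(x, x_axis, 0)         # (B, n)
    return jnp.einsum("bmn,bn->bm", Ab, xb), 0


# Registration would use the documented API, e.g.:
# kernel.def_batching_rule(matvec_batching_rule)

A = jnp.arange(12.0).reshape(3, 4)
xs = jnp.arange(20.0).reshape(4, 5)  # batch of 5 vectors stored along axis 1
out, out_dim = matvec_batching_rule((A, xs), (None, 1))
expected = jax.vmap(matvec, in_axes=(None, 1))(A, xs)
print(out.shape, out_dim, bool(jnp.allclose(out, expected)))  # (5, 3) 0 True
```

The gain over a generic rule is the first branch: a batch of matrix-vector products runs as a single matrix-matrix product instead of one call per batch element.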
- def_benchmark_data(fn)
Register a benchmark data generator function.
The generator produces a list of BenchmarkConfig instances that define the parameter combinations to benchmark for this primitive.
- Parameters:
fn (Callable) – A callable with signature fn(*, platform: str) -> List[BenchmarkConfig].
Examples
>>> def my_data_gen(*, platform):
...     return [BenchmarkConfig(n=100), BenchmarkConfig(n=1000)]
>>> kernel.def_benchmark_data(my_data_gen)
- def_call(fn)
Associate a high-level call function with this primitive.
The call function is the user-facing Python function that prepares arguments and invokes the primitive. It is stored so that call() and benchmark() can use it.
- Parameters:
fn (Callable) – The call function (e.g., binary_csrmv_p_call). It should accept the same positional and keyword arguments that the end user would pass.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_call(my_call_fn)
>>> kernel.call(x, y)  # delegates to my_call_fn
- def_cuda_raw_kernel(kg, asdefault=False)
Register a cuda_raw (nvcc-compiled) kernel for the CPU or GPU platform.
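Convenience wrapper around def_kernel() with backend='cuda_raw'. The kernel generator function should call brainevent.load_cuda_file() or brainevent.load_cuda_inline() to compile and register the CUDA kernel, then return a closure that calls it via jax.ffi.ffi_call.
- Parameters:
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for the platform even if a default already exists. Default is False.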
See also
def_kernel: General kernel registration method.
- def_jvp_rule(fun)
Define a custom JVP (Jacobian-vector product) rule.
This rule is used for forward-mode automatic differentiation. It specifies how to compute the directional derivative of the primitive’s outputs with respect to its inputs.
- Parameters:
fun (Callable) – A function implementing the JVP logic. See the JAX documentation for jax.interpreters.ad.primitive_jvps.
See also
def_jvp_rule2: Convenience method for defining per-input JVP rules.
def_transpose_rule: Define a transpose rule for reverse-mode AD.
Examples
>>> def my_jvp(primals, tangents, **params):
...     val_out = kernel.primitive.bind(*primals, **params)
...     tangent_out = ...  # compute tangent
...     return val_out, tangent_out
>>> kernel.def_jvp_rule(my_jvp)
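A concrete version of the skeleton above, for a hypothetical elementwise-product primitive, is shown below. mul_impl stands in for kernel.primitive.bind, and mul_jvp assumes every tangent is a dense array; in a real primitive JVP rule an input without a tangent can arrive as a symbolic zero (jax.interpreters.ad.Zero), which this sketch does not handle.

```python
import jax
import jax.numpy as jnp


def mul_impl(x, y):
    """Reference semantics of a hypothetical elementwise-product primitive."""
    return x * y


def mul_jvp(primals, tangents, **params):
    """Monolithic JVP rule returning (primal_out, tangent_out).

    Product rule: d(x * y) = dx * y + x * dy. Assumes dense tangents.
    """
    x, y = primals
    dx, dy = tangents
    primal_out = mul_impl(x, y)  # a real rule would call kernel.primitive.bind(...)
    tangent_out = dx * y + x * dy
    return primal_out, tangent_out


x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, -1.0])
dx, dy = jnp.array([0.5, 1.0]), jnp.array([2.0, 0.0])
out, t_out = mul_jvp((x, y), (dx, dy))
_, t_ref = jax.jvp(mul_impl, (x, y), (dx, dy))
print(bool(jnp.allclose(t_out, t_ref)))  # True
```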
- def_jvp_rule2(*jvp_rules)
Define per-input JVP rules for the primitive.
This is a convenience method similar to jax.interpreters.ad.defjvp, but adapted for primitives that return multiple results. Each rule corresponds to one input primal and computes the tangent contribution from that input.
- Parameters:
*jvp_rules (callable or None) – One callable per input primal. Each callable has the signature rule(tangent, *primals, **params) -> tangent_out. Pass None for inputs whose JVP contribution is zero.
See also
def_jvp_rule: Define a single monolithic JVP rule.
defjvp: The underlying utility function.
Examples
>>> def jvp_input0(t, x, y, **kw):
...     return t * y
>>> def jvp_input1(t, x, y, **kw):
...     return t * x
>>> kernel.def_jvp_rule2(jvp_input0, jvp_input1)
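Per-input rules combine by summation, the same way jax.interpreters.ad.defjvp combines them for single-result primitives. combine_per_input_jvps below is a hypothetical sketch of that summation for one output, not the code of brainevent's defjvp; rules passed as None and missing tangents contribute nothing.

```python
import jax
import jax.numpy as jnp


def jvp_input0(t, x, y, **kw):  # tangent contribution from x for f(x, y) = x * y
    return t * y


def jvp_input1(t, x, y, **kw):  # tangent contribution from y
    return t * x


def combine_per_input_jvps(rules, primals, tangents, **params):
    """Hypothetical sketch: total tangent = sum of per-input contributions.

    Returns None when every contribution is skipped; a real implementation
    would produce a (symbolic or explicit) zero instead.
    """
    total = None
    for rule, t in zip(rules, tangents):
        if rule is None or t is None:
            continue
        contrib = rule(t, *primals, **params)
        total = contrib if total is None else total + contrib
    return total


x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, -1.0])
dx, dy = jnp.array([0.5, 1.0]), jnp.array([2.0, 0.0])
t_sum = combine_per_input_jvps((jvp_input0, jvp_input1), (x, y), (dx, dy))
_, t_ref = jax.jvp(lambda a, b: a * b, (x, y), (dx, dy))
print(bool(jnp.allclose(t_sum, t_ref)))  # True
```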
- def_kernel(backend, platform, kg, asdefault=False)
Register a kernel implementation for a specific backend and platform.
Creates a KernelEntry and stores it in the internal kernel registry. If this is the first kernel registered for the given platform, it automatically becomes the default. Pass asdefault=True to override an existing default.
A JAX lowering rule is registered for the platform the first time any kernel targets it.
- Parameters:
backend (str) – The backend name (e.g., 'numba', 'warp', 'pallas', 'triton', 'cuda_raw', 'numba_cuda').
platform (str) – The hardware platform (e.g., 'cpu', 'gpu', 'tpu').
kg (Callable[..., Callable]) – A callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for platform even if a default already exists. Default is False.
- Raises:
AssertionError – If backend or platform is not a string, or if kg is not callable.
See also
def_numba_kernel: Shorthand for def_kernel('numba', 'cpu', ...).
def_warp_kernel: Shorthand for def_kernel('warp', 'gpu', ...).
def_pallas_kernel: Shorthand for Pallas kernels on GPU or TPU.
set_default: Change the default backend for a platform after registration.
- def_numba_cuda_kernel(kg, asdefault=False)
Register a Numba CUDA kernel for the GPU platform.
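Convenience wrapper around def_kernel() with backend='numba_cuda' and platform='gpu'.
- Parameters:
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for the GPU platform even if a default already exists. Default is False.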
See also
def_kernel: General kernel registration method.
def_warp_kernel: Register a Warp kernel for GPU.
numba_cuda_kernel: Standalone function for wrapping a single Numba CUDA kernel.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_cuda_kernel(my_numba_cuda_gen)
- def_numba_kernel(kg, asdefault=False)
Register a Numba kernel for the CPU platform.
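Convenience wrapper around def_kernel() with backend='numba' and platform='cpu'.
- Parameters:
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for the CPU platform even if a default already exists. Default is False.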
See also
def_kernel: General kernel registration method.
def_warp_kernel: Register a Warp kernel for GPU.
def_pallas_kernel: Register a Pallas kernel for GPU or TPU.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(my_numba_kernel_gen)
- def_pallas_kernel(platform, kg, asdefault=False)
Register a Pallas kernel for the GPU or TPU platform.
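Convenience wrapper around def_kernel() with backend='pallas'.
- Parameters:
platform (str) – The target platform, either 'gpu' or 'tpu'.
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for platform even if a default already exists. Default is False.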
- Raises:
AssertionError – If platform is not 'gpu' or 'tpu'.
See also
def_kernel: General kernel registration method.
def_warp_kernel: Register a Warp kernel for GPU.
def_numba_kernel: Register a Numba kernel for CPU.
Notes
Pallas kernels require JAX >= 0.7.1. The version check is performed lazily at dispatch time, not at registration time.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_pallas_kernel('gpu', my_pallas_gen, asdefault=True)
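The kg argument is a generator that returns a concrete kernel function. The sketch below shows the general shape of such a generator for Pallas. add_scalar_pallas_generator and its scalar keyword are hypothetical; the exact keyword arguments brainevent forwards to a generator, and the calling convention it expects from the returned function, are assumptions here. interpret=True runs the kernel in Pallas interpreter mode so the sketch also works on a CPU; a real GPU or TPU kernel would drop it.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_scalar_pallas_generator(*, scalar=1.0, **unused_params):
    """Hypothetical Pallas kernel generator specialised on a static ``scalar``."""

    def kernel(x_ref, o_ref):
        # With no grid or BlockSpecs, each ref covers the whole array.
        o_ref[...] = x_ref[...] + scalar

    def run(x):
        return pl.pallas_call(
            kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
            interpret=True,  # interpreter mode: lets this sketch run on CPU
        )(x)

    return run


# Registration would use the documented API, e.g.:
# kernel.def_pallas_kernel('gpu', add_scalar_pallas_generator)

run = add_scalar_pallas_generator(scalar=2.0)
x = jnp.arange(8, dtype=jnp.float32)
print(run(x))
```

Capturing static parameters in the generator's closure means each distinct parameter set yields its own specialised kernel, which is the point of passing a generator rather than a fixed function.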
- def_tags(*tags)
Set categorization tags for this primitive.
Tags are used by the CLI and the global primitive registry to filter primitives by sparse format (e.g., 'csr', 'coo') and value type (e.g., 'binary', 'float').
- Parameters:
*tags (str) – Tag strings (e.g., 'csr', 'binary').
See also
def_benchmark_data: Register a benchmark data generator.
Examples
>>> kernel = XLACustomKernel('binary_csrmv')
>>> kernel.def_tags('csr', 'binary', 'mv')
- def_transpose_rule(fun)
Define a custom transpose rule for reverse-mode AD.
The transpose rule defines how cotangent vectors (gradients) are propagated backward through the primitive. JAX invokes it whenever it transposes a linear computation, as in reverse-mode differentiation (jax.grad, jax.vjp) and jax.linear_transpose.
- Parameters:
fun (Callable) – A function implementing the transpose logic. See the JAX documentation for jax.interpreters.ad.primitive_transposes.
See also
def_jvp_rule: Define a JVP rule for forward-mode AD.
def_jvp_rule2: Define per-input JVP rules.
Examples
>>> def my_transpose(ct, *args, **params):
...     # Propagate cotangent backward
...     return (ct_input0, ct_input1)
>>> kernel.def_transpose_rule(my_transpose)
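For a hypothetical primitive matvec(A, x) = A @ x that is linear in x, the transpose rule reduces to multiplying the cotangent by A.T. The sketch assumes A is treated as a constant (hence its None cotangent) and that x arrives as an ad.UndefinedPrimal placeholder whose value is never read; matvec and matvec_transpose are illustrative names, and the rule is checked against jax.vjp rather than registered on a kernel.

```python
import jax
import jax.numpy as jnp


def matvec(A, x):
    """Reference semantics of a hypothetical primitive y = A @ x (linear in x)."""
    return A @ x


def matvec_transpose(ct, A, x, **params):
    """Transpose rule for the hypothetical ``matvec`` w.r.t. its linear input x.

    One cotangent is returned per input: None for A (not linear here) and
    A.T @ ct for x. The value of x is never read.
    """
    return None, A.T @ ct


A = jnp.arange(6.0).reshape(2, 3)
x = jnp.ones(3)
ct = jnp.array([1.0, -2.0])
_, ct_x = matvec_transpose(ct, A, None)  # None stands in for the UndefinedPrimal
_, vjp_fn = jax.vjp(lambda v: matvec(A, v), x)
print(bool(jnp.allclose(ct_x, vjp_fn(ct)[0])))  # True
```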
- def_triton_kernel(kg, asdefault=False)
Register a Triton kernel for the GPU platform.
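Convenience wrapper around def_kernel() with backend='triton' and platform='gpu'.
- Parameters:
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for the GPU platform even if a default already exists. Default is False.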
See also
def_kernel: General kernel registration method.
def_warp_kernel: Register a Warp kernel for GPU.
def_pallas_kernel: Register a Pallas kernel for GPU or TPU.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_triton_kernel(my_triton_kernel_gen)
- def_warp_kernel(kg, asdefault=False)
Register a Warp kernel for the GPU platform.
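Convenience wrapper around def_kernel() with backend='warp' and platform='gpu'.
- Parameters:
kg (Callable[..., Callable]) – The kernel generator: a callable that accepts keyword arguments (from the primitive bind call) and returns a concrete kernel function.
asdefault (bool) – If True, set this backend as the default for the GPU platform even if a default already exists. Default is False.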
See also
def_kernel: General kernel registration method.
def_numba_kernel: Register a Numba kernel for CPU.
def_pallas_kernel: Register a Pallas kernel for GPU or TPU.
def_triton_kernel: Register a Triton kernel for GPU.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_warp_kernel(my_warp_kernel_gen)
- property defaults: Dict[str, str]
Return a copy of all default backends.
- Returns:
A dictionary mapping platform names to their default backend names. Modifying the returned dictionary does not affect the internal state.
- Return type:
dict of str to str
See also
get_default: Retrieve the default for a single platform.
set_default: Change the default backend for a platform.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.def_pallas_kernel('gpu', pallas_gen)
>>> kernel.defaults
{'cpu': 'numba', 'gpu': 'pallas'}
- get_default(platform)
Get the current default backend for a platform.
- Parameters:
platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').
- Returns:
The default backend name, or None if no default is set for the given platform.
- Return type:
str or None
See also
set_default: Set the default backend for a platform.
defaults: Property returning all default backends.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_numba_kernel(numba_gen)
>>> kernel.get_default('cpu')
'numba'
- register_general_batching()
Register the default general-purpose batching rule.
This method applies a common batching pattern that handles most custom operators by using jax.lax.scan to map the kernel over the batch dimension. It is called automatically during __init__; call it again to restore the default after overriding with def_batching_rule().
See also
def_batching_rule: Override with a custom batching rule.
general_batching_rule: The underlying batching implementation.
Notes
The general batching rule moves all batch dimensions to axis 0 and uses jax.lax.scan to iterate over the batch. This is correct but may be slower than a hand-written batching rule for operations that can natively handle batched inputs.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_batching_rule(custom_rule)
>>> kernel.register_general_batching()  # Restore default
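The scan-based strategy described in the notes can be sketched as follows. scan_batching_rule is a hypothetical, simplified stand-in for general_batching_rule, not its actual code: it takes the unbatched implementation explicitly, handles a single output, and assumes all batched arguments share one batch size.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_batching_rule(impl, args, axes, **params):
    """Hypothetical scan-based general batching rule (single output).

    Batched arguments are moved to axis 0 and scanned over; unbatched
    arguments are closed over unchanged. Outputs are stacked on axis 0.
    """
    is_batched = [ax is not None for ax in axes]
    moved = [jnp.moveaxis(a, ax, 0) if ax is not None else a
             for a, ax in zip(args, axes)]
    scanned = [a for a, b in zip(moved, is_batched) if b]

    def body(carry, xs):
        slices = iter(xs)
        # Rebuild the full argument list for one batch element.
        call_args = [next(slices) if b else a for a, b in zip(moved, is_batched)]
        return carry, impl(*call_args, **params)

    _, out = lax.scan(body, None, scanned)
    return out, 0


def matvec(A, x):
    """Unbatched reference kernel used for the demonstration."""
    return A @ x


A = jnp.arange(12.0).reshape(3, 4)
xs = jnp.arange(20.0).reshape(4, 5)  # batch of 5 vectors stored along axis 1
out, out_dim = scan_batching_rule(matvec, (A, xs), (None, 1))
expected = jax.vmap(matvec, in_axes=(None, 1))(A, xs)
print(out.shape, out_dim, bool(jnp.allclose(out, expected)))  # (5, 3) 0 True
```

Because lax.scan processes one batch element per step, the batch runs sequentially, which is why a hand-written rule can be faster for operations that accept batched inputs natively.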
- set_default(platform, backend)
Set the default backend for a platform.
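After this call, all subsequent dispatches on the given platform use the specified backend unless overridden by an explicit backend= keyword argument at call time.
- Parameters:
platform (str) – The platform name (e.g., 'cpu', 'gpu', 'tpu').
backend (str) – The name of a backend already registered for platform.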
- Raises:
ValueError – If no kernels are registered for platform, or if backend is not registered for platform.
See also
get_default: Retrieve the current default backend for a platform.
defaults: Property returning all default backends.
Examples
>>> kernel = XLACustomKernel('my_op')
>>> kernel.def_warp_kernel(warp_gen)
>>> kernel.def_pallas_kernel('gpu', pallas_gen)
>>> kernel.set_default('gpu', 'pallas')