brainevent.numba_cuda_kernel

brainevent.numba_cuda_kernel(kernel, outs, *, grid=None, block=None, launch_dims=None, threads_per_block=256, shared_mem=0, vmap_method=None, input_output_aliases=None)

Create a JAX-callable function from a single Numba CUDA kernel.

Wraps a Numba CUDA kernel (decorated with @cuda.jit) so that it can be called from JAX on GPU. The kernel operates on device memory directly with zero-copy access via XLA’s typed FFI protocol.

Either (grid, block) or launch_dims must be specified to configure the CUDA launch. When launch_dims is used, the grid and block dimensions are computed automatically.
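The automatic computation is not spelled out here. A minimal sketch, assuming a 1-D launch where the grid size is the ceiling of the total thread count divided by threads_per_block, might look like this. The helper name launch_config is hypothetical and is not part of brainevent.

```python
import math


def launch_config(launch_dims, threads_per_block=256):
    """Hypothetical helper: derive a 1-D (grid, block) pair from a thread count.

    Assumption: the grid is sized by ceiling division so that
    grid * block >= total threads. The library's actual rule may differ.
    """
    if isinstance(launch_dims, int):
        total = launch_dims
    else:
        # Treat a tuple as a multi-dimensional extent and flatten it.
        total = math.prod(launch_dims)
    grid = (total + threads_per_block - 1) // threads_per_block
    return grid, threads_per_block


grid, block = launch_config(1024)
print(grid, block)
```

Under this assumption, launch_dims=1024 with the default threads_per_block corresponds to grid=4, block=256. When the total is not a multiple of the block size, the last block contains surplus threads, which is why a kernel should guard its writes with a bounds check such as `if i < out.size`.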

Parameters:
  • kernel (numba.cuda.dispatcher.CUDADispatcher) – A Numba CUDA kernel function decorated with @cuda.jit.

  • outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification. A single jax.ShapeDtypeStruct or a sequence/pytree of them for multiple outputs.

  • grid (int | Tuple[int, ...] | None) – Grid dimensions for the kernel launch. Must be specified together with block. Mutually exclusive with launch_dims.

  • block (int | Tuple[int, ...] | None) – Block dimensions for the kernel launch. Must be specified together with grid. Mutually exclusive with launch_dims.

  • launch_dims (int | Tuple[int, ...] | None) – Total number of threads to launch. Grid and block are computed automatically. Mutually exclusive with (grid, block).

  • threads_per_block (int) – Number of threads per block when using launch_dims. Default is 256.

  • shared_mem (int) – Dynamic shared memory size in bytes. Default is 0.

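  • vmap_method (str | None) – Batching strategy applied when the returned function is transformed with jax.vmap (for example "sequential" or "broadcast_all"). Passed directly to jax.ffi.ffi_call.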

  • input_output_aliases (dict[int, int] | None) – Mapping from input index to output index for in-place operations. Passed directly to jax.ffi.ffi_call.

Returns:

A function that takes JAX arrays as inputs and returns JAX arrays as outputs. The function can be used inside jax.jit-compiled code.

Return type:

callable

Raises:
  • ImportError – If Numba with CUDA support is not available.

  • ValueError – If neither (grid, block) nor launch_dims is specified.

  • AssertionError – If kernel is not a numba.cuda.dispatcher.CUDADispatcher.

See also

numba_cuda_callable

Wrap an arbitrary Python callable that launches multiple Numba CUDA kernels.

XLACustomKernel.def_numba_cuda_kernel

Register a Numba CUDA kernel with an XLACustomKernel.

Notes

Each call to the returned function registers a new FFI target. For performance-critical inner loops, consider caching the returned callable.
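One way to follow the caching advice is to memoize the wrapper by the parameters that determine it. The sketch below assumes the cost lies in constructing the wrapper and uses a stand-in builder so it runs without a GPU. The names make_cached_factory and fake_build are hypothetical, and in real code the builder would call numba_cuda_kernel(...).

```python
from functools import lru_cache


def make_cached_factory(build):
    """Wrap a one-argument builder so each distinct key is built only once."""

    @lru_cache(maxsize=None)
    def cached(n):
        return build(n)

    return cached


# Stand-in builder for illustration only. A real builder would return
# numba_cuda_kernel(kernel, outs=..., launch_dims=n) for the given size n.
calls = []


def fake_build(n):
    calls.append(n)  # record how often the builder actually runs
    return lambda *args: ("kernel", n, args)


get_kernel = make_cached_factory(fake_build)
k1 = get_kernel(1024)
k2 = get_kernel(1024)  # served from the cache, builder not invoked again
k3 = get_kernel(2048)  # new key, builder invoked once more
```

Keying the cache on hashable values such as an integer size or a (shape, dtype) tuple keeps the lookup cheap and avoids rebuilding the wrapper inside a hot loop.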

Examples

>>> from numba import cuda
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_cuda_kernel
>>>
>>> @cuda.jit
... def add_kernel(x, y, out):
...     i = cuda.grid(1)
...     if i < out.size:
...         out[i] = x[i] + y[i]
>>>
>>> # Option 1: Explicit grid/block
>>> kernel_fn = numba_cuda_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
...     grid=4,
...     block=256,
... )
>>>
>>> # Option 2: Auto grid from launch_dims
>>> kernel_fn = numba_cuda_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
...     launch_dims=1024,
... )
>>>
>>> @jax.jit
... def f(a, b):
...     return kernel_fn(a, b)