brainevent.numba_cuda_kernel
- brainevent.numba_cuda_kernel(kernel, outs, *, grid=None, block=None, launch_dims=None, threads_per_block=256, shared_mem=0, vmap_method=None, input_output_aliases=None)
Create a JAX-callable function from a single Numba CUDA kernel.
Wraps a Numba CUDA kernel (decorated with @cuda.jit) so that it can be called from JAX on GPU. The kernel operates directly on device memory, with zero-copy access via XLA’s typed FFI protocol.
Either (grid, block) or launch_dims must be specified to configure the CUDA launch. When launch_dims is used, the grid and block dimensions are computed automatically.
- Parameters:
  - kernel (numba.cuda.compiler.CUDADispatcher) – A Numba CUDA kernel function decorated with @cuda.jit.
  - outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification: a single jax.ShapeDtypeStruct, or a sequence/pytree of them for multiple outputs.
  - grid (int | Tuple[int, ...] | None) – Grid dimensions for the kernel launch. Must be specified together with block. Mutually exclusive with launch_dims.
  - block (int | Tuple[int, ...] | None) – Block dimensions for the kernel launch. Must be specified together with grid.
  - launch_dims (int | Tuple[int, ...] | None) – Total number of threads to launch; grid and block are computed automatically. Mutually exclusive with (grid, block).
  - threads_per_block (int) – Number of threads per block when using launch_dims. Default is 256.
  - shared_mem (int) – Dynamic shared memory size in bytes. Default is 0.
  - vmap_method (str | None) – Method to use for jax.vmap. Passed directly to jax.ffi.ffi_call.
  - input_output_aliases (dict[int, int] | None) – Mapping from input index to output index for in-place operations. Passed directly to jax.ffi.ffi_call.
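The launch rules above can be illustrated with a small pure-Python sketch. The helper resolve_launch_config is hypothetical and not part of brainevent. It assumes that launch_dims is a 1-D thread count, that the automatic grid is obtained by ceiling division of launch_dims by threads_per_block (this page does not state the exact formula), and that passing both forms raises ValueError. Because rounding up can launch more threads than requested, kernels should keep a bounds check such as `if i < out.size`.

```python
from typing import Optional, Tuple, Union

Dims = Union[int, Tuple[int, ...]]


def resolve_launch_config(
    grid: Optional[Dims] = None,
    block: Optional[Dims] = None,
    launch_dims: Optional[int] = None,
    threads_per_block: int = 256,
) -> Tuple[Dims, Dims]:
    """Hypothetical helper mirroring the documented launch rules.

    Assumptions: launch_dims is a 1-D int, and the grid is rounded up
    with ceiling division so every requested thread is covered.
    """
    explicit = grid is not None or block is not None
    if explicit and launch_dims is not None:
        # Assumed behavior for the documented mutual exclusivity.
        raise ValueError("Specify either (grid, block) or launch_dims, not both.")
    if explicit:
        if grid is None or block is None:
            # grid and block must be given together.
            raise ValueError("grid and block must be specified together.")
        return grid, block
    if launch_dims is None:
        raise ValueError("Either (grid, block) or launch_dims must be specified.")
    if threads_per_block <= 0:
        raise ValueError("threads_per_block must be positive.")
    # Ceiling division: enough blocks to cover all launch_dims threads.
    num_blocks = (launch_dims + threads_per_block - 1) // threads_per_block
    return num_blocks, threads_per_block


# Matches the explicit grid=4, block=256 configuration used in the examples.
print(resolve_launch_config(launch_dims=1024))  # (4, 256)
# 1025 threads need a fifth block; 255 of its threads fail the bounds check.
print(resolve_launch_config(launch_dims=1025))  # (5, 256)
```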
- Returns:
  A function that takes JAX arrays as inputs and returns JAX arrays as outputs. The function can be used inside jax.jit-compiled code.
- Return type:
callable
- Raises:
  - ImportError – If Numba with CUDA support is not available.
  - ValueError – If neither (grid, block) nor launch_dims is specified.
  - AssertionError – If kernel is not a numba.cuda.dispatcher.CUDADispatcher.
See also
  - numba_cuda_callable – Wrap an arbitrary Python callable that launches multiple Numba CUDA kernels.
  - XLACustomKernel.def_numba_cuda_kernel – Register a Numba CUDA kernel with an XLACustomKernel.
Notes
Each call to the returned function registers a new FFI target. For performance-critical inner loops, consider caching the returned callable.
Examples
>>> from numba import cuda
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_cuda_kernel
>>>
>>> @cuda.jit
... def add_kernel(x, y, out):
...     i = cuda.grid(1)
...     if i < out.size:
...         out[i] = x[i] + y[i]
>>>
>>> # Option 1: Explicit grid/block
>>> kernel_fn = numba_cuda_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
...     grid=4,
...     block=256,
... )
>>>
>>> # Option 2: Auto grid from launch_dims
>>> kernel_fn = numba_cuda_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
...     launch_dims=1024,
... )
>>>
>>> @jax.jit
... def f(a, b):
...     return kernel_fn(a, b)
>>>
>>> a = jnp.arange(1024, dtype=jnp.float32)
>>> out = f(a, a)  # requires a CUDA-capable GPU