brainevent.numba_cuda_callable


brainevent.numba_cuda_callable(func, outs, *, vmap_method=None, input_output_aliases=None)

Create a JAX-callable from a Python function that launches Numba CUDA kernels.

Unlike numba_cuda_kernel() (which wraps a single @cuda.jit kernel), this function wraps an arbitrary Python callable. The callable receives Numba CUDA device arrays for inputs and outputs, plus a Numba CUDA stream, and may launch any number of kernels, allocate temporary device memory, or perform multi-step GPU computations.

The wrapped function must have the signature:

func(input_1, input_2, ..., output_1, output_2, ..., stream)

where every input_* and output_* is a Numba CUDA device array and stream is a Numba CUDA stream obtained from XLA.
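The sketch below makes the calling convention concrete on the CPU. NumPy arrays stand in for Numba device arrays, and `None` is passed for the stream. `my_op_cpu` is a hypothetical function. It mirrors only the argument order (inputs, then outputs, then stream) and the pattern used in the Examples section: results are written into the preallocated output arrays rather than returned.

```python
import numpy as np

def my_op_cpu(x, y, out, stream):
    # Argument order matches a wrapped func: inputs, outputs, then stream.
    # The stream is unused on the CPU; the result is written into `out` in place.
    out[:] = (x + y) * 2.0

x = np.arange(4, dtype=np.float32)
y = np.ones(4, dtype=np.float32)
out = np.empty(4, dtype=np.float32)  # preallocated, like an XLA output buffer
my_op_cpu(x, y, out, None)
print(out)  # [2. 4. 6. 8.]
```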

Parameters:
  • func (Callable) – A Python function with the signature described above.

  • outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification. A single jax.ShapeDtypeStruct or a sequence/pytree of them for multiple outputs.

  • vmap_method (str | None) – How to handle jax.vmap, for example "sequential", "expand_dims", or "broadcast_all". Passed directly to jax.ffi.ffi_call.

  • input_output_aliases (dict[int, int] | None) – Mapping from input index to output index for in-place operations. Passed directly to jax.ffi.ffi_call.
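Both the order of `outs` and `input_output_aliases` determine the argument list that `func` receives. The NumPy sketch below imitates this on the CPU. `call_like_ffi` is a hypothetical driver, not brainevent's implementation. It assumes that an aliased output shares its input's buffer, which is the usual meaning of XLA input/output aliasing.

```python
import numpy as np

def call_like_ffi(func, inputs, out_specs, aliases=None):
    # Hypothetical CPU driver (not brainevent's implementation).
    # out_specs: list of (shape, dtype), in the same order as `outs`.
    # aliases: {input_index: output_index}, like input_output_aliases.
    aliases = aliases or {}
    source_for_output = {o: i for i, o in aliases.items()}
    outputs = []
    for j, (shape, dtype) in enumerate(out_specs):
        if j in source_for_output:
            # Aliased output: reuse the input's buffer instead of allocating.
            outputs.append(inputs[source_for_output[j]])
        else:
            outputs.append(np.empty(shape, dtype=dtype))
    func(*inputs, *outputs, None)  # inputs, then outputs, then the stream
    return tuple(outputs)

def min_max(x, out_min, out_max, stream):
    # Two outputs arrive in the order listed in out_specs.
    out_min[0] = x.min()
    out_max[0] = x.max()

def double_in_place(x, out, stream):
    # With aliases={0: 0}, `out` is the very same buffer as `x`.
    out[:] = x * 2

x = np.array([3.0, -1.0, 5.0], dtype=np.float32)
mn, mx = call_like_ffi(min_max, [x], [((1,), np.float32), ((1,), np.float32)])

buf = np.array([1.0, 2.0], dtype=np.float32)
(doubled,) = call_like_ffi(double_in_place, [buf], [((2,), np.float32)], aliases={0: 0})
```

The outputs use shape `(1,)` rather than `()` to stay clear of the 0-d restriction described under Raises.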

Returns:

A function that takes JAX arrays as inputs and returns JAX arrays as outputs. The function can be used inside jax.jit-compiled code.

Return type:

callable

Raises:
  • ImportError – If Numba with CUDA support is not available.

  • TypeError – If func is not callable.

  • ValueError – If any input array is a 0-d (scalar) array, which is not supported by Numba CUDA device arrays.

See also

numba_cuda_kernel

Wrap a single @cuda.jit kernel with fixed grid/block configuration.

XLACustomKernel.def_numba_cuda_kernel

Register a Numba CUDA kernel with an XLACustomKernel.

Notes

Each call to numba_cuda_callable registers a new FFI target. In performance-critical code, create the callable once and reuse it rather than re-wrapping func inside a loop.
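Here is a pure-Python sketch of the caching advice, assuming the registration cost is paid when a wrapper is built. `build_wrapper` is a hypothetical stand-in for `numba_cuda_callable`, and `functools.lru_cache` guarantees it runs only once per key.

```python
import functools

registered = []  # hypothetical stand-in for the FFI target registry

def build_wrapper(name):
    # Hypothetical stand-in for numba_cuda_callable: every build
    # records one registration.
    registered.append(name)
    return lambda a, b: a + b

@functools.lru_cache(maxsize=None)
def get_wrapper(name):
    # Build once per key; later calls return the same cached object.
    return build_wrapper(name)

for step in range(100):
    fn = get_wrapper("my_op")
    fn(step, 1)

print(len(registered))  # 1
```

With the real function, the simplest form of this is to create the callable once at module level, as `fn` is in the Examples section, and reuse it.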

Scalar (0-d) inputs are not supported because Numba CUDA cannot create device arrays from 0-d buffers. Wrap scalar values in 1-d arrays (e.g., jnp.array([value])) before passing them.
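The next example is a minimal CPU illustration of the workaround. NumPy stands in for `jax.numpy`, so `np.array([value])` plays the role of `jnp.array([value])`. `scale_op` is a hypothetical function that reads the scalar back as element 0.

```python
import numpy as np

value = 2.0
scalar = np.asarray(value, dtype=np.float32)     # 0-d: would be rejected
wrapped = np.array([value], dtype=np.float32)    # shape (1,): accepted

def scale_op(x, s, out, stream):
    # `s` has shape (1,); the scalar is s[0].
    # Inside a @cuda.jit kernel the same indexing works: out[i] = x[i] * s[0].
    out[:] = x * s[0]

x = np.arange(3, dtype=np.float32)
out = np.empty_like(x)
scale_op(x, wrapped, out, None)
print(scalar.ndim, wrapped.shape)  # 0 (1,)
print(out)  # [0. 2. 4.]
```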

Examples

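>>> from numba import cuda
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_cuda_callable
>>>
>>> @cuda.jit
... def add_kernel(x, y, temp, n):
...     i = cuda.grid(1)
...     if i < n:  # guard threads launched past the end of the array
...         temp[i] = x[i] + y[i]
>>>
>>> @cuda.jit
... def scale_kernel(temp, out, scale, n):
...     i = cuda.grid(1)
...     if i < n:
...         out[i] = temp[i] * scale
>>>
>>> def my_op(x, y, out, stream):
...     # x, y and out are Numba device arrays; stream comes from XLA.
...     n = x.shape[0]
...     temp = cuda.device_array(n, dtype=x.dtype)  # scratch buffer
...     threads = 256
...     blocks = (n + threads - 1) // threads  # ceil(n / threads)
...     # Both kernels run on XLA's stream, so they execute in order.
...     add_kernel[blocks, threads, stream](x, y, temp, n)
...     scale_kernel[blocks, threads, stream](temp, out, 2.0, n)
>>>
>>> fn = numba_cuda_callable(
...     my_op,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
... )
>>>
>>> @jax.jit
... def f(a, b):
...     return fn(a, b)
>>>
>>> a = jnp.ones(1024, dtype=jnp.float32)
>>> b = jnp.ones(1024, dtype=jnp.float32)
>>> out = f(a, b)  # every element is (1 + 1) * 2 = 4.0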
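To check results or reason about the launch configuration without a GPU, the two-kernel pipeline above can be emulated on the CPU. `launch_config` and `emulate_my_op` are hypothetical helpers. They mirror the ceiling-division grid size and the `i < n` bounds guard by looping over every launched thread index.

```python
import numpy as np

def launch_config(n, threads=256):
    # Ceiling division: enough blocks that blocks * threads >= n.
    return (n + threads - 1) // threads, threads

def emulate_my_op(x, y, scale=2.0, threads=256):
    # CPU emulation of add_kernel followed by scale_kernel.
    n = x.shape[0]
    blocks, threads = launch_config(n, threads)
    temp = np.empty(n, dtype=x.dtype)  # mirrors cuda.device_array(n, ...)
    out = np.empty(n, dtype=x.dtype)
    for i in range(blocks * threads):  # one iteration per launched thread
        if i < n:                      # same bounds guard as the kernels
            temp[i] = x[i] + y[i]
    for i in range(blocks * threads):
        if i < n:
            out[i] = temp[i] * scale
    return out
```

On a CUDA machine, `np.testing.assert_allclose(np.asarray(f(a, b)), emulate_my_op(np.asarray(a), np.asarray(b)))` is one way to compare the GPU result against this reference.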