brainevent.numba_cuda_callable
- brainevent.numba_cuda_callable(func, outs, *, vmap_method=None, input_output_aliases=None)
Create a JAX-callable from a Python function that launches Numba CUDA kernels.
Unlike numba_cuda_kernel() (which wraps a single @cuda.jit kernel), this function wraps an arbitrary Python callable. The callable receives Numba CUDA device arrays for inputs and outputs, plus a Numba CUDA stream, and may launch any number of kernels, allocate temporary device memory, or perform multi-step GPU computations.
The wrapped function must have the signature:
func(input_1, input_2, ..., output_1, output_2, ..., stream)
where every input_* and output_* is a Numba CUDA device array and stream is a Numba CUDA stream obtained from XLA.
- Parameters:
func (Callable) – A Python function with the signature described above.
outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification. A single jax.ShapeDtypeStruct, or a sequence/pytree of them for multiple outputs.
vmap_method (str | None) – How to handle jax.vmap. Passed directly to jax.ffi.ffi_call.
input_output_aliases (dict[int, int] | None) – Mapping from input index to output index for in-place operations. Passed directly to jax.ffi.ffi_call.
- Returns:
A function that takes JAX arrays as inputs and returns JAX arrays as outputs. The function can be used inside jax.jit-compiled code.
- Return type:
callable
- Raises:
ImportError – If Numba with CUDA support is not available.
TypeError – If func is not callable.
ValueError – If any input array is a 0-d (scalar) array, which is not supported by Numba CUDA device arrays.
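The ValueError above can be avoided by promoting 0-d inputs to shape (1,) before the call. The following is a minimal sketch, assuming the inputs are array objects that expose .ndim and .reshape (true of both NumPy and JAX arrays). The helper name promote_scalars is hypothetical and not part of brainevent; NumPy is used only so the snippet runs without a GPU.

```python
import numpy as np


def promote_scalars(*arrays):
    """Return the inputs with every 0-d array reshaped to shape (1,)."""
    return tuple(a.reshape(1) if a.ndim == 0 else a for a in arrays)


x = np.arange(4, dtype=np.float32)
scale = np.array(2.0, dtype=np.float32)  # 0-d array: would be rejected as-is

x1, scale1 = promote_scalars(x, scale)
print(x1.shape, scale1.shape)  # (4,) (1,)
```

Inside the wrapped function the promoted value is then a one-element device array, so a kernel would read it as scale[0] rather than as a bare scalar.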
See also
numba_cuda_kernel – Wrap a single @cuda.jit kernel with fixed grid/block configuration.
XLACustomKernel.def_numba_cuda_kernel – Register a Numba CUDA kernel with an XLACustomKernel.
Notes
Each call to the returned function registers a new FFI target. For performance-critical inner loops, consider caching the returned callable.
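One way to follow the caching advice is to memoize the construction step on the output shape and dtype, so a callable is built once per configuration and reused afterwards. The sketch below is illustrative only: cached_builder and build are hypothetical names, and a counting stub stands in for the real numba_cuda_callable call so the snippet runs without a GPU.

```python
import functools


def cached_builder(build):
    """Memoize build(shape, dtype_name) so each configuration is built once."""

    @functools.lru_cache(maxsize=None)
    def get(shape, dtype_name):
        return build(shape, dtype_name)

    return get


# Stand-in for a function that would construct the real callable,
# e.g. via numba_cuda_callable(my_op, outs=...). It only records its calls.
build_calls = []


def build(shape, dtype_name):
    build_calls.append((shape, dtype_name))
    return lambda *args: ("result", shape, dtype_name)


get_fn = cached_builder(build)
fn_a = get_fn((1024,), "float32")
fn_b = get_fn((1024,), "float32")  # cache hit: same object, no rebuild
fn_c = get_fn((2048,), "float32")  # new shape: built once more
```

Shapes are passed as tuples and dtypes as strings so that the cache key is hashable.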
Scalar (0-d) inputs are not supported because Numba CUDA cannot create device arrays from 0-d buffers. Wrap scalar values in 1-d arrays (e.g., jnp.array([value])) before passing them.
Examples
>>> from numba import cuda
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_cuda_callable
>>>
>>> @cuda.jit
... def add_kernel(x, y, temp, n):
...     i = cuda.grid(1)
...     if i < n:  # guard threads past the end of the array
...         temp[i] = x[i] + y[i]
>>>
>>> @cuda.jit
... def scale_kernel(temp, out, scale, n):
...     i = cuda.grid(1)
...     if i < n:
...         out[i] = temp[i] * scale
>>>
>>> def my_op(x, y, out, stream):
...     n = x.shape[0]
...     # Temporary device buffer holding the intermediate sum.
...     temp = cuda.device_array(n, dtype=x.dtype)
...     threads = 256
...     blocks = (n + threads - 1) // threads
...     # Both kernels are launched on the stream provided by XLA.
...     add_kernel[blocks, threads, stream](x, y, temp, n)
...     scale_kernel[blocks, threads, stream](temp, out, 2.0, n)
>>>
>>> fn = numba_cuda_callable(
...     my_op,
...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
... )
>>>
>>> @jax.jit
... def f(a, b):
...     return fn(a, b)
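The block count computed in my_op is a ceiling division: it picks enough blocks of `threads` threads to cover all n elements, and the `if i < n` guard in each kernel masks the surplus threads in the last block. A small pure-Python sketch of that arithmetic follows; the helper name launch_config is hypothetical.

```python
def launch_config(n, threads=256):
    """Return (blocks, threads) such that blocks * threads >= n."""
    if n < 1 or threads < 1:
        raise ValueError("n and threads must be positive")
    blocks = (n + threads - 1) // threads  # ceiling of n / threads
    return blocks, threads


for n in (1, 256, 257, 1024):
    blocks, threads = launch_config(n)
    idle = blocks * threads - n  # threads masked out by the `i < n` guard
    print(n, blocks, idle)
```

For n = 1024 and 256 threads per block this gives 4 blocks with no idle threads, while n = 257 needs 2 blocks and leaves 255 threads idle.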