# Source code for brainevent._op.numba_cuda_ffi

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ctypes
import importlib.util
import threading
import traceback
from ctypes import c_void_p, c_size_t, POINTER, CFUNCTYPE, Structure
from typing import Callable, Dict, Tuple, Union

import jax
import numpy as np

from .numba_ffi import (
    XLA_FFI_Extension_Type,
    XLA_FFI_Extension_Base,
    XLA_FFI_Metadata_Extension,
    XLA_FFI_CallFrame,
    XLA_FFI_Buffer,
    _XLA_FFI_DTYPE_TO_NUMPY,
    _normalize_shapes_and_dtypes,
)
from .util import OutType, abstract_arguments

__all__ = [
    'numba_cuda_kernel',
    'numba_cuda_callable',
]

numba_cuda_installed = importlib.util.find_spec('numba') is not None

# Cached lazy import, initialized by import_numba_cuda() on first use.
cuda = None


def import_numba_cuda():
    """Import ``numba.cuda`` lazily and validate CUDA availability.

    The imported module is cached in the module-level ``cuda`` global, so
    only the first successful call pays the import cost.  Any failure
    clears ``numba_cuda_installed`` so that later calls fail fast without
    retrying the import.

    Returns
    -------
    module
        The ``numba.cuda`` module.

    Raises
    ------
    ImportError
        If Numba cannot be found, importing ``numba.cuda`` fails, or
        ``numba.cuda.is_available()`` reports that CUDA is unavailable.
        When the failure comes from an underlying exception, that
        exception is attached as ``__cause__``.
    """
    global cuda, numba_cuda_installed
    if cuda is not None:
        return cuda

    message = (
        'Numba with CUDA support is required. '
        'Please install numba and ensure CUDA is available.'
    )
    if not numba_cuda_installed:
        raise ImportError(message)

    # Keep the ``try`` limited to the import and the availability probe so
    # that the "CUDA not available" error raised below is not caught and
    # re-wrapped around an identical copy of itself.
    try:
        from numba import cuda as _cuda
        available = _cuda.is_available()
    except Exception as exc:
        numba_cuda_installed = False
        raise ImportError(message) from exc

    if not available:
        numba_cuda_installed = False
        raise ImportError(message)

    cuda = _cuda
    return cuda

# Module-level registry that keeps handler objects (and therefore their
# ctypes callbacks) alive for the lifetime of the process.
# Presumably keyed by FFI target name -- confirm against the registration code.
_NUMBA_CUDA_FFI_HANDLES: Dict[str, object] = {}
# Monotonic counter used to build unique FFI target names.
# NOTE(review): the visible registration code reads and increments this
# without holding _CUDA_FFI_CALLBACK_LOCK -- confirm whether that is intended.
_CUDA_FFI_CALLBACK_COUNTER = 0
# Held around the kernel launch inside the FFI callback, serializing launches
# issued from concurrent callbacks.
_CUDA_FFI_CALLBACK_LOCK = threading.Lock()

# The typed FFI callback signature: void* fn(XLA_FFI_CallFrame*)
# The handlers in this module return None from the callback to signal success.
_CUDA_FFI_CALLBACK_TYPE = CFUNCTYPE(c_void_p, POINTER(XLA_FFI_CallFrame))


# ---------------------------------------------------------------------------
# XLA FFI API structures for CUDA stream extraction
# (Based on XLA's C API: xla/ffi/api/c_api.h)
# ---------------------------------------------------------------------------

class XLA_FFI_Api_Version(Structure):
    """XLA FFI API version structure.

    Mirrors the ``XLA_FFI_Api_Version`` struct from XLA's C API header
    (``xla/ffi/api/c_api.h``).  Used internally to interpret version
    information from the XLA FFI API pointer.

    Notes
    -----
    Within this module the struct is only embedded by value inside
    :class:`XLA_FFI_Api`; its size therefore determines the offsets of
    every later field of that structure.  Field order and types must
    match the C header exactly.
    """
    _fields_ = [
        # Presumably sizeof(XLA_FFI_Api_Version), per the C API's
        # struct-versioning convention -- confirm against c_api.h.
        ("struct_size", c_size_t),
        # Head of an optional extension chain (may be a null pointer).
        ("extension_start", POINTER(XLA_FFI_Extension_Base)),
        ("major_version", ctypes.c_int),
        ("minor_version", ctypes.c_int),
    ]


class XLA_FFI_Stream_Get_Args(Structure):
    """Arguments for the ``XLA_FFI_Stream_Get`` function.

    Mirrors the ``XLA_FFI_Stream_Get_Args`` struct from XLA's C API.
    The ``ctx`` field is set to the execution context from the call
    frame, and ``stream`` is populated by XLA with the active CUDA
    stream pointer upon return.

    Notes
    -----
    ``_get_stream_from_callframe`` fills this structure as follows:
    ``struct_size`` is ``ctypes.sizeof`` of this class,
    ``extension_start`` is a null pointer, ``ctx`` is copied from the
    call frame, and ``stream`` is cleared before the call and read back
    afterwards.
    """
    _fields_ = [
        ("struct_size", c_size_t),  # Input: size of this struct in bytes
        ("extension_start", POINTER(XLA_FFI_Extension_Base)),  # Input: null when unused
        ("ctx", c_void_p),  # XLA_FFI_ExecutionContext*
        ("stream", c_void_p),  # Output: cudaStream_t
    ]


# Function pointer type: XLA_FFI_Error* (*XLA_FFI_Stream_Get)(XLA_FFI_Stream_Get_Args*)
# The XLA_FFI_Error* result is declared as c_void_p, so ctypes hands it back
# as None for a null pointer and as a Python int otherwise.
_XLA_FFI_Stream_Get_Fn = CFUNCTYPE(c_void_p, POINTER(XLA_FFI_Stream_Get_Args))


class XLA_FFI_Api(Structure):
    """Partial mirror of the ``XLA_FFI_Api`` structure from XLA's C API.

    Only fields up to and including ``XLA_FFI_Stream_Get`` are declared;
    later fields are not needed for CUDA stream extraction and are
    omitted.

    Notes
    -----
    The entries preceding ``XLA_FFI_Stream_Get`` are never called from
    this module and are declared as opaque ``c_void_p`` slots.  They
    exist only so that ``XLA_FFI_Stream_Get`` sits at the correct byte
    offset; their order and count must not be changed.

    See Also
    --------
    _get_stream_from_callframe : Uses this structure to extract the
        CUDA stream from an XLA call frame.
    """
    _fields_ = [
        ("struct_size", c_size_t),
        ("extension_start", POINTER(XLA_FFI_Extension_Base)),
        ("api_version", XLA_FFI_Api_Version),  # Embedded struct, not pointer
        ("internal_api", c_void_p),
        # Opaque function-pointer slots (unused here, kept for layout only).
        ("XLA_FFI_Error_Create", c_void_p),
        ("XLA_FFI_Error_GetMessage", c_void_p),
        ("XLA_FFI_Error_Destroy", c_void_p),
        ("XLA_FFI_Handler_Register", c_void_p),
        # The only entry that is actually invoked, hence the typed prototype.
        ("XLA_FFI_Stream_Get", _XLA_FFI_Stream_Get_Fn),
        # ... other fields not needed for stream extraction
    ]


def _get_stream_from_callframe(call_frame) -> int:
    """Extract the CUDA stream pointer from an XLA FFI call frame.

    Calls the ``XLA_FFI_Stream_Get`` function exposed by XLA's API
    pointer to obtain the ``cudaStream_t`` associated with the current
    execution context.

    Parameters
    ----------
    call_frame : XLA_FFI_CallFrame
        The call frame structure passed to the FFI callback by XLA.

    Returns
    -------
    int
        The CUDA stream pointer (``cudaStream_t``) as a Python integer.
        Because the field is a ``c_void_p``, ctypes reports a null
        stream handle as ``None`` rather than ``0``.

    Raises
    ------
    RuntimeError
        If ``XLA_FFI_Stream_Get`` returns a non-null ``XLA_FFI_Error*``.
        Previously the result was ignored, so a failed lookup silently
        produced a null stream.

    Notes
    -----
    This function is internal and should not be called directly.  It is
    used by :class:`NumbaCudaFfiHandler` and
    :class:`NumbaCudaCallableHandler` to obtain the XLA-managed CUDA
    stream for zero-copy kernel launches.
    """
    # Reinterpret the frame's API pointer as the partial XLA_FFI_Api mirror.
    api = ctypes.cast(call_frame.api, POINTER(XLA_FFI_Api)).contents

    # Prepare stream get arguments
    stream_args = XLA_FFI_Stream_Get_Args()
    stream_args.struct_size = ctypes.sizeof(XLA_FFI_Stream_Get_Args)
    stream_args.extension_start = POINTER(XLA_FFI_Extension_Base)()  # null: no extensions
    stream_args.ctx = call_frame.ctx
    stream_args.stream = None

    # Call XLA's stream getter.  ctypes passes the structure by reference
    # because the prototype declares a POINTER(XLA_FFI_Stream_Get_Args).
    error = api.XLA_FFI_Stream_Get(stream_args)
    if error:
        # NOTE(review): the XLA_FFI_Error object is not destroyed here; doing
        # so would require a typed prototype for XLA_FFI_Error_Destroy.
        raise RuntimeError(
            f'XLA_FFI_Stream_Get failed (XLA_FFI_Error* = {error:#x})'
        )

    return stream_args.stream


def _numba_stream_from_ptr(stream_ptr: int):
    """Wrap a raw ``cudaStream_t`` handle in a Numba CUDA stream object.

    Parameters
    ----------
    stream_ptr : int
        Raw ``cudaStream_t`` handle as a Python integer, typically the
        value returned by :func:`_get_stream_from_callframe`.

    Returns
    -------
    numba.cuda.cudadrv.driver.Stream
        The result of ``numba.cuda.external_stream(stream_ptr)``.
        Kernels launched on it are enqueued on the stream identified by
        ``stream_ptr``.
    """
    numba_cuda = import_numba_cuda()
    return numba_cuda.external_stream(stream_ptr)


class _DevicePointerWrapper:
    """Minimal object exposing ``__cuda_array_interface__`` for a raw pointer.

    The interface dictionary is built once at construction time and stored
    as an instance attribute, so repeated lookups do not rebuild it.
    """

    def __init__(self, ptr, arr_shape, arr_dtype):
        self.__cuda_array_interface__ = {
            'shape': tuple(arr_shape),
            # Accept anything ``np.dtype`` understands, not only dtype instances.
            'typestr': np.dtype(arr_dtype).str,
            # A NULL ``c_void_p`` field reads back as None in ctypes; the
            # interface expects an integer address.
            'data': (0 if ptr is None else ptr, False),  # (ptr, read_only)
            'version': 3,
        }


def _device_array_from_buffer(data_ptr: int, shape: Tuple[int, ...], dtype: np.dtype):
    """Create a Numba CUDA device array from a raw device memory pointer.

    Uses the ``__cuda_array_interface__`` protocol for zero-copy access
    to device memory owned by XLA.

    Parameters
    ----------
    data_ptr : int
        The device memory pointer as a Python integer.  ``None`` is
        treated as a null pointer (address ``0``).
    shape : tuple of int
        The shape of the array.  Any sequence of ints is accepted and
        converted to a tuple.
    dtype : numpy.dtype
        The element data type.  Anything accepted by ``numpy.dtype`` may
        be passed (for example ``np.float32`` or ``'float32'``).

    Returns
    -------
    numba.cuda.cudadrv.devicearray.DeviceNDArray
        A Numba CUDA device array that wraps the given device memory
        without copying.

    Raises
    ------
    ImportError
        If Numba with CUDA support is not available.

    Notes
    -----
    The returned array does **not** own the underlying memory.  The
    caller must ensure that the memory remains valid for the lifetime
    of the array.

    The wrapper class lives at module level rather than inside this
    function so that a new class object is not created for every buffer
    of every kernel launch.
    """
    wrapper = _DevicePointerWrapper(data_ptr, shape, dtype)
    return import_numba_cuda().as_cuda_array(wrapper)


def _compute_launch_config(
    launch_dims: Union[int, Tuple[int, ...]],
    threads_per_block: int = 256,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Compute CUDA grid and block dimensions from total launch dimensions.

    Automatically determines an appropriate grid/block decomposition
    for 1-D, 2-D, or 3-D kernel launches given the total number of
    threads desired along each axis.

    Parameters
    ----------
    launch_dims : int or tuple of int
        Total number of threads to launch along each axis.  An ``int``
        is treated as a 1-D launch.  Tuples of length 2 or 3 produce
        2-D or 3-D launches respectively.
    threads_per_block : int, optional
        Maximum number of threads per block for 1-D launches.  Default
        is ``256``.  For 2-D and 3-D launches, fixed block sizes are
        used (16x16 and 8x8x4 respectively).

    Returns
    -------
    grid : tuple of int
        Grid dimensions (number of blocks per axis).
    block : tuple of int
        Block dimensions (number of threads per block per axis).

    Raises
    ------
    ValueError
        If *launch_dims* has more than 3 dimensions.

    Examples
    --------
    .. code-block:: python

        >>> grid, block = _compute_launch_config(1024)
        >>> grid
        (4,)
        >>> block
        (256,)

        >>> grid, block = _compute_launch_config((64, 64))
        >>> grid
        (4, 4)
        >>> block
        (16, 16)
    """
    if isinstance(launch_dims, int):
        launch_dims = (launch_dims,)

    if len(launch_dims) == 1:
        total = launch_dims[0]
        block = (min(threads_per_block, total),)
        grid = ((total + block[0] - 1) // block[0],)
    elif len(launch_dims) == 2:
        # For 2D, use a square-ish block
        block_x = min(16, launch_dims[0])
        block_y = min(16, launch_dims[1])
        block = (block_x, block_y)
        grid = (
            (launch_dims[0] + block[0] - 1) // block[0],
            (launch_dims[1] + block[1] - 1) // block[1],
        )
    elif len(launch_dims) == 3:
        # For 3D
        block_x = min(8, launch_dims[0])
        block_y = min(8, launch_dims[1])
        block_z = min(4, launch_dims[2])
        block = (block_x, block_y, block_z)
        grid = (
            (launch_dims[0] + block[0] - 1) // block[0],
            (launch_dims[1] + block[1] - 1) // block[1],
            (launch_dims[2] + block[2] - 1) // block[2],
        )
    else:
        raise ValueError(f"launch_dims must have 1-3 dimensions, got {len(launch_dims)}")

    return grid, block


class NumbaCudaFfiHandler:
    """Adapter exposing one Numba CUDA kernel as an XLA typed-FFI target.

    On construction the handler wraps its own callback method in a
    ctypes function pointer and registers that pointer with JAX for the
    ``"CUDA"`` platform.  Each time XLA invokes the callback, the
    handler turns the call frame's argument and result buffers into
    Numba device arrays, fetches the CUDA stream from the frame, and
    launches the kernel with the fixed grid/block configuration given
    at construction time.

    Parameters
    ----------
    name : str
        Unique FFI target name passed to ``jax.ffi.register_ffi_target``.
    kernel : numba.cuda.compiler.CUDADispatcher
        The Numba CUDA kernel (created with ``@cuda.jit``) to launch.
    input_shapes : tuple of tuple of int
        Expected shapes of the input buffers.
    input_dtypes : tuple of numpy.dtype
        Expected dtypes of the input buffers; used as a fallback when a
        buffer's XLA dtype code has no NumPy mapping.
    output_shapes : tuple of tuple of int
        Expected shapes of the output buffers.
    output_dtypes : tuple of numpy.dtype
        Expected dtypes of the output buffers; used as a fallback in the
        same way as ``input_dtypes``.
    grid : tuple of int
        Grid dimensions for the kernel launch.
    block : tuple of int
        Block dimensions for the kernel launch.
    shared_mem : int, optional
        Dynamic shared memory size in bytes.  Default is ``0``.

    See Also
    --------
    numba_cuda_kernel : High-level API for creating a JAX-callable from
        a single Numba CUDA kernel.
    NumbaCudaCallableHandler : Handler for arbitrary multi-kernel Python
        callables.

    Notes
    -----
    The handler object must be kept alive (stored in a module-level
    dictionary) to prevent garbage collection of the ctypes callback,
    which would cause a segmentation fault when XLA tries to invoke it.
    """

    def __init__(
        self,
        name: str,
        kernel,
        input_shapes: Tuple[Tuple[int, ...], ...],
        input_dtypes: Tuple[np.dtype, ...],
        output_shapes: Tuple[Tuple[int, ...], ...],
        output_dtypes: Tuple[np.dtype, ...],
        grid: Tuple[int, ...],
        block: Tuple[int, ...],
        shared_mem: int = 0,
    ):
        # Identity and kernel.
        self.name = name
        self.kernel = kernel

        # Buffer signature.
        self.input_shapes = input_shapes
        self.input_dtypes = input_dtypes
        self.output_shapes = output_shapes
        self.output_dtypes = output_dtypes

        # Launch configuration.
        self.grid = grid
        self.block = block
        self.shared_mem = shared_mem

        # The ctypes thunk has to live on the instance; if it were collected,
        # XLA would be left holding a dangling function pointer.
        self._callback = _CUDA_FFI_CALLBACK_TYPE(self._ffi_callback)

        # Hand the raw function address to JAX as a CUDA-platform FFI target.
        callback_address = ctypes.cast(self._callback, c_void_p).value
        jax.ffi.register_ffi_target(
            name,
            jax.ffi.pycapsule(callback_address),
            platform="CUDA",
        )

    @staticmethod
    def _answer_metadata_query(frame) -> bool:
        """Fill in handler metadata when the frame carries a metadata query.

        Parameters
        ----------
        frame : XLA_FFI_CallFrame
            The dereferenced call frame.

        Returns
        -------
        bool
            ``True`` if the frame's first extension was a metadata
            request (now answered), ``False`` for a regular call.
        """
        ext_ptr = frame.extension_start
        if not ext_ptr:
            return False
        if ext_ptr.contents.type != int(XLA_FFI_Extension_Type.Metadata):
            return False

        metadata_ext = ctypes.cast(
            ext_ptr, POINTER(XLA_FFI_Metadata_Extension)
        ).contents
        metadata = metadata_ext.metadata.contents
        metadata.api_version.major_version = 0
        metadata.api_version.minor_version = 1
        metadata.traits = 0  # not command-buffer-compatible
        return True

    @staticmethod
    def _wrap_buffers(raw_buffers, count, fallback_dtypes):
        """Convert ``count`` XLA buffers into zero-copy Numba device arrays.

        Parameters
        ----------
        raw_buffers : ctypes pointer array
            Array of untyped pointers, each referring to an
            ``XLA_FFI_Buffer``.
        count : int
            Number of entries to read from ``raw_buffers``.
        fallback_dtypes : sequence of numpy.dtype
            Per-buffer dtype used when the buffer's XLA dtype code is
            absent from ``_XLA_FFI_DTYPE_TO_NUMPY``.

        Returns
        -------
        list
            One device array per buffer, in buffer order.
        """
        arrays = []
        for idx in range(count):
            buf = ctypes.cast(raw_buffers[idx], POINTER(XLA_FFI_Buffer)).contents
            dims = tuple(buf.dims[axis] for axis in range(buf.rank))
            np_dtype = _XLA_FFI_DTYPE_TO_NUMPY.get(buf.dtype, fallback_dtypes[idx])
            arrays.append(_device_array_from_buffer(buf.data, dims, np_dtype))
        return arrays

    def _ffi_callback(self, call_frame_ptr):
        """Entry point invoked by XLA through the registered function pointer.

        Answers metadata queries (API version and traits) when present;
        otherwise wraps the frame's input and output buffers as device
        arrays, obtains the CUDA stream from the frame, and launches
        the kernel on that stream.

        Parameters
        ----------
        call_frame_ptr : ctypes.POINTER(XLA_FFI_CallFrame)
            Pointer to the XLA FFI call frame.

        Returns
        -------
        None
            Always ``None``, which XLA interprets as success.  Any
            exception raised while handling the call is printed with
            ``traceback.print_exc()`` and not propagated.
        """
        try:
            frame = call_frame_ptr.contents

            if self._answer_metadata_query(frame):
                return None  # success

            # Inputs first, then outputs: the kernel receives them in this order.
            inputs = self._wrap_buffers(
                frame.args.args, frame.args.size, self.input_dtypes
            )
            outputs = self._wrap_buffers(
                frame.rets.rets, frame.rets.size, self.output_dtypes
            )

            # Launch on the stream reported by XLA for this call.
            xla_stream = _numba_stream_from_ptr(_get_stream_from_callframe(frame))
            with _CUDA_FFI_CALLBACK_LOCK:
                configured = self.kernel[
                    self.grid, self.block, xla_stream, self.shared_mem
                ]
                configured(*inputs, *outputs)

        except Exception:
            traceback.print_exc()

        return None  # success


def _register_numba_cuda_ffi_target(
    kernel,
    input_shapes: Tuple[Tuple[int, ...], ...],
    input_dtypes: Tuple[np.dtype, ...],
    output_shapes: Tuple[Tuple[int, ...], ...],
    output_dtypes: Tuple[np.dtype, ...],
    grid: Tuple[int, ...],
    block: Tuple[int, ...],
    shared_mem: int = 0,
):
    """Wrap a Numba CUDA kernel in a handler and give it a unique FFI name.

    A :class:`NumbaCudaFfiHandler` is built for the kernel under a freshly
    generated target name and stored in the module-level
    ``_NUMBA_CUDA_FFI_HANDLES`` dictionary, so that the handler (and the
    ctypes callback it owns) is not garbage collected.  The handler is
    expected to register itself with ``jax.ffi.register_ffi_target`` when
    it is constructed.

    Parameters
    ----------
    kernel : numba.cuda.compiler.CUDADispatcher
        Kernel produced by ``@cuda.jit``.
    input_shapes : tuple of tuple of int
        Shape of every input buffer.
    input_dtypes : tuple of numpy.dtype
        Data type of every input buffer.
    output_shapes : tuple of tuple of int
        Shape of every output buffer.
    output_dtypes : tuple of numpy.dtype
        Data type of every output buffer.
    grid : tuple of int
        Grid dimensions forwarded to the handler for the kernel launch.
    block : tuple of int
        Block dimensions forwarded to the handler for the kernel launch.
    shared_mem : int, optional
        Dynamic shared memory size in bytes.  Default is ``0``.

    Returns
    -------
    target_name : str
        Name under which the handler was stored; unique per call because it
        embeds the module-level registration counter.
    out_types : tuple of jax.ShapeDtypeStruct
        One ``ShapeDtypeStruct`` per ``(output_shape, output_dtype)`` pair,
        ready to be handed to ``jax.ffi.ffi_call``.

    Raises
    ------
    ImportError
        If Numba with CUDA support is not available.

    See Also
    --------
    NumbaCudaFfiHandler : The handler class instantiated here.
    numba_cuda_kernel : High-level user-facing API.
    """
    global _CUDA_FFI_CALLBACK_COUNTER

    # Bail out with ImportError before touching any state if numba.cuda
    # cannot be used.
    import_numba_cuda()

    # Take the current counter value for this target, then advance it so the
    # next registration receives a different name.
    serial = _CUDA_FFI_CALLBACK_COUNTER
    _CUDA_FFI_CALLBACK_COUNTER = serial + 1
    target_name = f'brainevent_numba_cuda_ffi_{serial}'

    # Storing the handler in the module-level dict keeps its ctypes callback
    # alive for as long as the module is loaded.
    _NUMBA_CUDA_FFI_HANDLES[target_name] = NumbaCudaFfiHandler(
        name=target_name,
        kernel=kernel,
        input_shapes=input_shapes,
        input_dtypes=input_dtypes,
        output_shapes=output_shapes,
        output_dtypes=output_dtypes,
        grid=grid,
        block=block,
        shared_mem=shared_mem,
    )

    out_specs = []
    for out_shape, out_dtype in zip(output_shapes, output_dtypes):
        out_specs.append(jax.ShapeDtypeStruct(out_shape, out_dtype))
    return target_name, tuple(out_specs)


def numba_cuda_kernel(
    kernel,
    outs: OutType,
    *,
    grid: Union[int, Tuple[int, ...], None] = None,
    block: Union[int, Tuple[int, ...], None] = None,
    launch_dims: Union[int, Tuple[int, ...], None] = None,
    threads_per_block: int = 256,
    shared_mem: int = 0,
    vmap_method: str | None = None,
    input_output_aliases: dict[int, int] | None = None,
):
    """Create a JAX-callable function from a single Numba CUDA kernel.

    Wraps a Numba CUDA kernel (decorated with ``@cuda.jit``) so that it can
    be called from JAX on GPU.  The kernel operates on device memory
    directly with zero-copy access via XLA's typed FFI protocol.

    Either ``(grid, block)`` or ``launch_dims`` must be specified to
    configure the CUDA launch.  When ``launch_dims`` is used, the grid and
    block dimensions are computed automatically.

    Parameters
    ----------
    kernel : numba.cuda.compiler.CUDADispatcher
        A Numba CUDA kernel function decorated with ``@cuda.jit``.
    outs : OutType
        Output specification.  A single ``jax.ShapeDtypeStruct`` or a
        sequence/pytree of them for multiple outputs.
    grid : int or tuple of int or None, optional
        Grid dimensions for the kernel launch.  Must be specified together
        with *block*.  Mutually exclusive with *launch_dims*.
    block : int or tuple of int or None, optional
        Block dimensions for the kernel launch.  Must be specified together
        with *grid*.
    launch_dims : int or tuple of int or None, optional
        Total number of threads to launch.  Grid and block are computed
        automatically.  Mutually exclusive with *(grid, block)*.
    threads_per_block : int, optional
        Number of threads per block when using *launch_dims*.  Default is
        ``256``.
    shared_mem : int, optional
        Dynamic shared memory size in bytes.  Default is ``0``.
    vmap_method : str or None, optional
        Method to use for ``jax.vmap``.  Passed directly to
        ``jax.ffi.ffi_call``.
    input_output_aliases : dict of int to int or None, optional
        Mapping from input index to output index for in-place operations.
        Passed directly to ``jax.ffi.ffi_call``.

    Returns
    -------
    callable
        A function that takes JAX arrays as inputs and returns JAX arrays as
        outputs.  The function can be used inside ``jax.jit``-compiled code.

    Raises
    ------
    ImportError
        If Numba with CUDA support is not available.
    ValueError
        If neither ``(grid, block)`` nor ``launch_dims`` is specified.
    AssertionError
        If *kernel* is not a ``numba.cuda.dispatcher.CUDADispatcher``.  The
        error is raised explicitly, so the check is still enforced when
        Python runs with ``-O``.

    See Also
    --------
    numba_cuda_callable : Wrap an arbitrary Python callable that launches
        multiple Numba CUDA kernels.
    XLACustomKernel.def_numba_cuda_kernel : Register a Numba CUDA kernel
        with an ``XLACustomKernel``.

    Notes
    -----
    Each call to the returned function registers a **new** FFI target.  For
    performance-critical inner loops, consider caching the returned
    callable.

    Launch-configuration precedence, as implemented: when *grid* and *block*
    are both given they are used and *launch_dims* is ignored.  When only
    one of *grid* / *block* is given it is ignored; *launch_dims* is used if
    present, otherwise ``ValueError`` is raised.

    Examples
    --------
    .. code-block:: python

        >>> from numba import cuda
        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> @cuda.jit
        ... def add_kernel(x, y, out):
        ...     i = cuda.grid(1)
        ...     if i < out.size:
        ...         out[i] = x[i] + y[i]
        >>>
        >>> # Option 1: Explicit grid/block
        >>> kernel_fn = numba_cuda_kernel(
        ...     add_kernel,
        ...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
        ...     grid=4,
        ...     block=256,
        ... )
        >>>
        >>> # Option 2: Auto grid from launch_dims
        >>> kernel_fn = numba_cuda_kernel(
        ...     add_kernel,
        ...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
        ...     launch_dims=1024,
        ... )
        >>>
        >>> @jax.jit
        ... def f(a, b):
        ...     return kernel_fn(a, b)
    """
    import_numba_cuda()
    from numba.cuda.dispatcher import CUDADispatcher

    # Validate the kernel type with an explicit raise rather than an
    # ``assert`` statement: asserts are removed under ``python -O``, which
    # would let a wrong object through until the FFI callback runs.  The
    # exception type stays AssertionError to keep the documented contract.
    if not isinstance(kernel, CUDADispatcher):
        raise AssertionError(
            f'The kernel must be a Numba CUDA JIT-compiled function (from @cuda.jit), '
            f'but got {type(kernel).__name__}.'
        )

    # Compute grid and block dimensions
    if grid is not None and block is not None:
        # Explicit grid/block specified; normalise bare ints to 1-tuples.
        if isinstance(grid, int):
            grid = (grid,)
        if isinstance(block, int):
            block = (block,)
        grid = tuple(grid)
        block = tuple(block)
    elif launch_dims is not None:
        # Derive the launch configuration from the total thread count.
        grid, block = _compute_launch_config(launch_dims, threads_per_block)
    else:
        raise ValueError(
            "Either (grid, block) or launch_dims must be specified for kernel launch configuration."
        )

    # Output information is fixed at wrap time and shared by every call.
    out_info, out_treedef = abstract_arguments(outs)
    output_shapes, output_dtypes = _normalize_shapes_and_dtypes(
        tuple(out.shape for out in out_info),
        tuple(out.dtype for out in out_info),
        'output',
    )

    def call(*ins):
        """Invoke the registered Numba CUDA kernel through XLA FFI.

        Parameters
        ----------
        *ins : jax.Array
            Input arrays on GPU device.

        Returns
        -------
        result
            Output array(s) matching the ``outs`` specification.
        """
        # Input information is only known once the arguments are seen.
        in_info, _ = abstract_arguments(ins)
        input_shapes, input_dtypes = _normalize_shapes_and_dtypes(
            tuple(inp.shape for inp in in_info),
            tuple(inp.dtype for inp in in_info),
            'input',
        )

        # Register FFI target (a new one on every invocation of ``call``).
        target_name, out_types = _register_numba_cuda_ffi_target(
            kernel,
            input_shapes,
            input_dtypes,
            output_shapes,
            output_dtypes,
            grid,
            block,
            shared_mem,
        )

        # Call FFI with typed FFI protocol
        result = jax.ffi.ffi_call(
            target_name,
            out_types,
            input_output_aliases=input_output_aliases,
            vmap_method=vmap_method,
        )(*ins)

        # Restore the pytree structure the caller described in ``outs``.
        return jax.tree.unflatten(out_treedef, result)

    return call


# ===========================================================================
# numba_cuda_callable: Multi-kernel callable wrapper
# ===========================================================================

# Handlers are kept here, keyed by FFI target name, so that they (and the
# ctypes callbacks they own) are never garbage collected.
_NUMBA_CUDA_CALLABLE_HANDLES: Dict[str, object] = {}
# Source of the numeric suffix in generated target names.
_CUDA_CALLABLE_CALLBACK_COUNTER = 0
# Held while the user function runs inside the FFI callback.
_CUDA_CALLABLE_LOCK = threading.Lock()

# The typed FFI callback signature: void* fn(XLA_FFI_CallFrame*)
_CUDA_CALLABLE_CALLBACK_TYPE = CFUNCTYPE(c_void_p, POINTER(XLA_FFI_CallFrame))


class NumbaCudaCallableHandler:
    """Typed FFI handler that forwards an XLA call to a plain Python function.

    :class:`NumbaCudaFfiHandler` wraps a **single** ``@cuda.jit`` kernel with
    a fixed grid/block.  This handler instead calls an ordinary Python
    function, handing it Numba device arrays plus a Numba CUDA stream, so
    the function is free to launch any number of kernels, allocate temporary
    device memory, and run multi-step GPU computations.

    Parameters
    ----------
    name : str
        Unique FFI target name used with ``jax.ffi.register_ffi_target``.
    func : callable
        The Python function to invoke.  It is called as
        ``func(in1, in2, ..., out1, out2, ..., stream)`` where every ``in*``
        and ``out*`` is a Numba CUDA device array and ``stream`` is a Numba
        CUDA stream.
    num_inputs : int
        Number of input buffers expected.
    num_outputs : int
        Number of output buffers expected.
    input_dtypes : tuple of numpy.dtype
        Expected data types of the input buffers; used as a fallback when a
        buffer's XLA dtype has no NumPy mapping.
    output_shapes : tuple of tuple of int
        Expected shapes of the output buffers.
    output_dtypes : tuple of numpy.dtype
        Expected data types of the output buffers; used as a fallback when a
        buffer's XLA dtype has no NumPy mapping.

    See Also
    --------
    numba_cuda_callable : High-level API for creating a JAX-callable from an
        arbitrary Python function.
    NumbaCudaFfiHandler : Handler for a single Numba CUDA kernel.

    Notes
    -----
    Instances must stay referenced (they are stored in a module-level
    dictionary) so that the ctypes callback is not garbage collected.
    """

    def __init__(
        self,
        name: str,
        func: Callable,
        num_inputs: int,
        num_outputs: int,
        input_dtypes: Tuple[np.dtype, ...],
        output_shapes: Tuple[Tuple[int, ...], ...],
        output_dtypes: Tuple[np.dtype, ...],
    ):
        self.name = name
        self.func = func
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.input_dtypes = input_dtypes
        self.output_shapes = output_shapes
        self.output_dtypes = output_dtypes

        # The ctypes function object must outlive the registration, hence it
        # is stored on the instance rather than in a local.
        self._callback = _CUDA_CALLABLE_CALLBACK_TYPE(self._ffi_callback)

        # Expose the callback's address to JAX as a CUDA-platform FFI target.
        callback_address = ctypes.cast(self._callback, c_void_p).value
        jax.ffi.register_ffi_target(
            name,
            jax.ffi.pycapsule(callback_address),
            platform="CUDA",
        )

    @staticmethod
    def _buffers_to_device_arrays(raw_buffers, count, fallback_dtypes):
        """Turn ``count`` XLA FFI buffers into a list of Numba device arrays.

        The dtype comes from ``_XLA_FFI_DTYPE_TO_NUMPY``; if the buffer's
        dtype is not in that table, ``fallback_dtypes[idx]`` is used when the
        index is in range, and ``float32`` otherwise.
        """
        device_arrays = []
        for idx in range(count):
            buffer = ctypes.cast(raw_buffers[idx], POINTER(XLA_FFI_Buffer)).contents
            dims = tuple(buffer.dims[axis] for axis in range(buffer.rank))
            np_dtype = _XLA_FFI_DTYPE_TO_NUMPY.get(buffer.dtype)
            if np_dtype is None:
                if idx < len(fallback_dtypes):
                    np_dtype = fallback_dtypes[idx]
                else:
                    np_dtype = np.dtype(np.float32)
            device_arrays.append(
                _device_array_from_buffer(buffer.data, dims, np_dtype)
            )
        return device_arrays

    def _ffi_callback(self, call_frame_ptr):
        """Typed FFI entry point called by XLA.

        Answers metadata extension queries; otherwise converts the call
        frame's argument and result buffers to device arrays, obtains the
        stream from the call frame, and calls the user function.

        Parameters
        ----------
        call_frame_ptr : ctypes.POINTER(XLA_FFI_CallFrame)
            Pointer to the XLA FFI call frame.

        Returns
        -------
        None
            ``None`` is returned on every path.  Exceptions raised while
            handling the call are printed with ``traceback.print_exc`` and
            are not propagated.
        """
        try:
            frame = call_frame_ptr.contents

            # Metadata extension query (API version / traits): fill in the
            # answer and return without running the user function.
            extension = frame.extension_start
            if extension and extension.contents.type == int(XLA_FFI_Extension_Type.Metadata):
                metadata = ctypes.cast(
                    extension, POINTER(XLA_FFI_Metadata_Extension)
                ).contents.metadata.contents
                metadata.api_version.major_version = 0
                metadata.api_version.minor_version = 1
                metadata.traits = 0  # not command-buffer-compatible
                return None  # success

            # Inputs first, then outputs, as Numba CUDA device arrays.
            in_arrays = self._buffers_to_device_arrays(
                frame.args.args, frame.args.size, self.input_dtypes
            )
            out_arrays = self._buffers_to_device_arrays(
                frame.rets.rets, frame.rets.size, self.output_dtypes
            )

            # Wrap the stream taken from the call frame as a Numba stream.
            numba_stream = _numba_stream_from_ptr(_get_stream_from_callframe(frame))

            # Signature: func(in1, in2, ..., out1, out2, ..., stream)
            with _CUDA_CALLABLE_LOCK:
                self.func(*in_arrays, *out_arrays, numba_stream)

        except Exception:
            # NOTE(review): the failure is only printed; the callback still
            # returns None, the same value used for success.
            traceback.print_exc()

        return None  # success


def _register_numba_cuda_callable_target(
    func: Callable,
    num_inputs: int,
    num_outputs: int,
    input_dtypes: Tuple[np.dtype, ...],
    output_shapes: Tuple[Tuple[int, ...], ...],
    output_dtypes: Tuple[np.dtype, ...],
):
    """Register a Python callable as an XLA typed FFI target for CUDA.

    Builds a :class:`NumbaCudaCallableHandler` under a freshly generated
    target name (the handler registers itself with
    ``jax.ffi.register_ffi_target`` in its constructor) and stores it in the
    module-level ``_NUMBA_CUDA_CALLABLE_HANDLES`` dictionary so it is not
    garbage collected.

    Parameters
    ----------
    func : callable
        The Python function to wrap, called as
        ``func(in1, ..., out1, ..., stream)``.
    num_inputs : int
        Number of input buffers.
    num_outputs : int
        Number of output buffers.
    input_dtypes : tuple of numpy.dtype
        Data types of the input buffers.
    output_shapes : tuple of tuple of int
        Shapes of the output buffers.
    output_dtypes : tuple of numpy.dtype
        Data types of the output buffers.

    Returns
    -------
    target_name : str
        The unique FFI target name assigned to this callable.
    out_types : tuple of jax.ShapeDtypeStruct
        Output type specifications for use with ``jax.ffi.ffi_call``.

    Raises
    ------
    ImportError
        If Numba with CUDA support is not available.

    See Also
    --------
    NumbaCudaCallableHandler : The handler class created by this function.
    numba_cuda_callable : High-level user-facing API.
    """
    global _CUDA_CALLABLE_CALLBACK_COUNTER

    import_numba_cuda()

    # Take the current counter value for this target, then advance it so the
    # next registration receives a different name.
    serial = _CUDA_CALLABLE_CALLBACK_COUNTER
    _CUDA_CALLABLE_CALLBACK_COUNTER = serial + 1
    target_name = f'brainevent_numba_cuda_callable_{serial}'

    # Storing the handler keeps its ctypes callback alive.
    _NUMBA_CUDA_CALLABLE_HANDLES[target_name] = NumbaCudaCallableHandler(
        name=target_name,
        func=func,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
        input_dtypes=input_dtypes,
        output_shapes=output_shapes,
        output_dtypes=output_dtypes,
    )

    out_specs = []
    for out_shape, out_dtype in zip(output_shapes, output_dtypes):
        out_specs.append(jax.ShapeDtypeStruct(out_shape, out_dtype))
    return target_name, tuple(out_specs)


def numba_cuda_callable(
    func: Callable,
    outs: OutType,
    *,
    vmap_method: str | None = None,
    input_output_aliases: dict[int, int] | None = None,
):
    """Turn a Python function that launches Numba CUDA kernels into a JAX callable.

    :func:`numba_cuda_kernel` wraps exactly one ``@cuda.jit`` kernel.  This
    function wraps an **arbitrary** Python callable instead.  The callable
    is given Numba CUDA device arrays for the inputs and outputs, followed
    by a Numba CUDA stream, and may launch any number of kernels, allocate
    temporary device memory, or perform multi-step GPU computations.

    The wrapped function must have the signature::

        func(input_1, input_2, ..., output_1, output_2, ..., stream)

    where every ``input_*`` and ``output_*`` is a Numba CUDA device array
    and ``stream`` is a Numba CUDA stream obtained from XLA.

    Parameters
    ----------
    func : callable
        A Python function with the signature described above.
    outs : OutType
        Output specification.  A single ``jax.ShapeDtypeStruct`` or a
        sequence/pytree of them for multiple outputs.
    vmap_method : str or None, optional
        How to handle ``jax.vmap``.  Passed directly to
        ``jax.ffi.ffi_call``.
    input_output_aliases : dict of int to int or None, optional
        Mapping from input index to output index for in-place operations.
        Passed directly to ``jax.ffi.ffi_call``.

    Returns
    -------
    callable
        A function that takes JAX arrays as inputs and returns JAX arrays as
        outputs.  The function can be used inside ``jax.jit``-compiled code.

    Raises
    ------
    ImportError
        If Numba with CUDA support is not available.
    TypeError
        If *func* is not callable.
    ValueError
        Raised by the returned function if any input array is a 0-d
        (scalar) array, which is not supported by Numba CUDA device arrays.

    See Also
    --------
    numba_cuda_kernel : Wrap a single ``@cuda.jit`` kernel with fixed
        grid/block configuration.
    XLACustomKernel.def_numba_cuda_kernel : Register a Numba CUDA kernel
        with an ``XLACustomKernel``.

    Notes
    -----
    Each call to the returned function registers a new FFI target.  For
    performance-critical inner loops, consider caching the returned
    callable.

    Scalar (0-d) inputs are not supported because Numba CUDA cannot create
    device arrays from 0-d buffers.  Wrap scalar values in 1-d arrays (e.g.,
    ``jnp.array([value])``) before passing them.

    Examples
    --------
    .. code-block:: python

        >>> from numba import cuda
        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> @cuda.jit
        ... def add_kernel(x, y, temp, n):
        ...     i = cuda.grid(1)
        ...     if i < n:
        ...         temp[i] = x[i] + y[i]
        >>>
        >>> @cuda.jit
        ... def scale_kernel(temp, out, scale, n):
        ...     i = cuda.grid(1)
        ...     if i < n:
        ...         out[i] = temp[i] * scale
        >>>
        >>> def my_op(x, y, out, stream):
        ...     n = x.shape[0]
        ...     temp = cuda.device_array(n, dtype=x.dtype)
        ...     threads = 256
        ...     blocks = (n + threads - 1) // threads
        ...     add_kernel[blocks, threads, stream](x, y, temp, n)
        ...     scale_kernel[blocks, threads, stream](temp, out, 2.0, n)
        >>>
        >>> fn = numba_cuda_callable(
        ...     my_op,
        ...     outs=jax.ShapeDtypeStruct((1024,), jnp.float32),
        ... )
        >>>
        >>> @jax.jit
        ... def f(a, b):
        ...     return fn(a, b)
    """
    # CUDA availability is checked before the argument itself.
    import_numba_cuda()

    if not callable(func):
        raise TypeError(
            f'func must be callable, but got {type(func).__name__}.'
        )

    # The output description is fixed at wrap time and reused by every call.
    out_info, out_treedef = abstract_arguments(outs)
    out_shapes_raw = tuple(spec.shape for spec in out_info)
    out_dtypes_raw = tuple(spec.dtype for spec in out_info)
    output_shapes, output_dtypes = _normalize_shapes_and_dtypes(
        out_shapes_raw, out_dtypes_raw, 'output',
    )
    num_outputs = len(out_info)

    def call(*inputs):
        """Run the wrapped callable through XLA FFI.

        Parameters
        ----------
        *inputs : jax.Array
            Input arrays on GPU device.

        Returns
        -------
        result
            Output array(s) matching the ``outs`` specification.
        """
        arrays = jax.tree.map(jax.numpy.array, inputs)

        # 0-d inputs are rejected up front: Numba CUDA kernels cannot
        # operate on 0-d device arrays.
        for i, leaf in enumerate(jax.tree.leaves(arrays)):
            if jax.numpy.ndim(leaf) != 0:
                continue
            raise ValueError(
                f"numba_cuda_callable does not support 0-d (scalar) array inputs, "
                f"but input {i} has shape (). "
                f"Wrap scalars in a 1-d array, e.g. jnp.array([value])."
            )

        # Input dtypes are only known once the arguments are seen.
        in_info, _ = abstract_arguments(arrays)
        input_dtypes = tuple(np.dtype(info.dtype) for info in in_info)

        # A fresh FFI target is registered for this invocation.
        target_name, out_types = _register_numba_cuda_callable_target(
            func,
            num_inputs=len(arrays),
            num_outputs=num_outputs,
            input_dtypes=input_dtypes,
            output_shapes=output_shapes,
            output_dtypes=output_dtypes,
        )

        ffi_fn = jax.ffi.ffi_call(
            target_name,
            out_types,
            input_output_aliases=input_output_aliases,
            vmap_method=vmap_method,
        )
        flat_result = ffi_fn(*arrays)

        # Restore the pytree structure the caller described in ``outs``.
        return jax.tree.unflatten(out_treedef, flat_result)

    return call