brainevent.numba_kernel

brainevent.numba_kernel(kernel, outs, *, vmap_method=None, input_output_aliases=None)

Create a JAX-callable function from a Numba CPU kernel.

Wraps a Numba JIT-compiled CPU kernel (decorated with @numba.njit) so that it can be called from JAX on CPU. Through the XLA FFI (Foreign Function Interface), the kernel reads its inputs from and writes its outputs to XLA's memory buffers directly.
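To make the calling convention concrete, the sketch below is a pure-NumPy model of what the wrapper does from the kernel's point of view: allocate one buffer per output spec, call the kernel with the inputs followed by those buffers, and hand the buffers back. emulate_numba_kernel and OutSpec are hypothetical names used only for illustration. The real function passes XLA buffers through the FFI rather than allocating NumPy arrays in Python, and the plain-Python kernels here stand in for @numba.njit functions (decorating them would not change the results).

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence, Tuple, Union

import numpy as np


@dataclass(frozen=True)
class OutSpec:
    # Hypothetical stand-in for jax.ShapeDtypeStruct: only a shape and a dtype.
    shape: Tuple[int, ...]
    dtype: Any


def emulate_numba_kernel(
    kernel: Callable, outs: Union[OutSpec, Sequence[OutSpec]]
) -> Callable:
    # Hypothetical NumPy model of the calling convention; no XLA or FFI involved.
    single = isinstance(outs, OutSpec)
    specs = [outs] if single else list(outs)

    def call(*inputs):
        # One buffer per output spec; np.empty leaves the contents unspecified.
        buffers = [np.empty(s.shape, dtype=s.dtype) for s in specs]
        # Inputs first, then outputs; the kernel fills the outputs in place.
        kernel(*(np.asarray(a) for a in inputs), *buffers)
        return buffers[0] if single else tuple(buffers)

    return call


def add_kernel(x, y, out):
    for i in range(out.size):
        out[i] = x[i] + y[i]


def split_kernel(x, out1, out2):
    for i in range(out1.size):
        out1[i] = x[i] * 2
        out2[i] = x[i] * 3


add = emulate_numba_kernel(add_kernel, OutSpec((4,), np.float32))
result = add(np.arange(4, dtype=np.float32), np.ones(4, dtype=np.float32))

split = emulate_numba_kernel(
    split_kernel, [OutSpec((4,), np.float32), OutSpec((4,), np.float32)]
)
doubled, tripled = split(np.arange(4, dtype=np.float32))
```

A single spec yields a single array and a sequence of specs yields one array per spec, mirroring how the examples on this page unpack their results.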

Parameters:
  • kernel (callable) – A Numba CPU kernel function decorated with @numba.njit.

  • outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification, given as jax.ShapeDtypeStruct objects in the examples. Pass a single struct for one output, or a sequence of structs for multiple outputs.

  • vmap_method (str | None) – The method used to batch this kernel under jax.vmap. See the JAX documentation for jax.ffi.ffi_call for the accepted values. Default is None.

  • input_output_aliases (dict[int, int] | None) – Mapping from input indices to output indices for in-place operations. Default is None.
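The aliasing sketch below assumes that an aliased output occupies the same memory as its input, which is what input_output_aliases permits XLA to do. Under that assumption the kernel's input and output arguments may refer to one buffer, so a loop that reads an element after overwriting it changes the result. Both kernels are hypothetical examples, and plain NumPy arrays stand in for the XLA buffers.

```python
import numpy as np


def neighbor_sum_kernel(x, out):
    # Forward loop: out[i] reads x[i - 1], which an aliased out already overwrote.
    out[0] = x[0]
    for i in range(1, out.size):
        out[i] = x[i] + x[i - 1]


def neighbor_sum_kernel_backward(x, out):
    # Backward loop: x[i - 1] is read before index i - 1 is written, so aliasing is safe.
    for i in range(out.size - 1, 0, -1):
        out[i] = x[i] + x[i - 1]
    out[0] = x[0]


x = np.array([1.0, 2.0, 3.0, 4.0])

# Separate input and output buffers: every read sees the original input.
separate = np.empty_like(x)
neighbor_sum_kernel(x, separate)  # [1, 3, 5, 7]

# One shared buffer, as with an alias such as {0: 0}.
aliased = x.copy()
neighbor_sum_kernel(aliased, aliased)  # [1, 3, 6, 10]: later reads saw updated values

aliased_safe = x.copy()
neighbor_sum_kernel_backward(aliased_safe, aliased_safe)  # [1, 3, 5, 7]
```

With separate buffers both loops agree; only the backward version stays correct when input and output share memory.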

Returns:

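A function that takes JAX arrays as inputs and returns JAX arrays as outputs: a single array for a single output spec, or one array per spec for a sequence. Compatible with jax.jit and other JAX transformations; batching under jax.vmap is governed by vmap_method.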

Return type:

callable
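According to the JAX documentation for jax.ffi.ffi_call, vmap_method="sequential" evaluates the call once per batch element, whereas "expand_dims" and "broadcast_all" issue a single call whose buffers carry an extra leading batch dimension. The example kernels on this page index flatly over out.size and therefore assume unbatched 1-D buffers. The NumPy sketch below models the sequential strategy for such a kernel. It assumes numba_kernel forwards vmap_method to jax.ffi.ffi_call unchanged, and sequential_batch is a hypothetical helper that is part of neither brainevent nor JAX.

```python
import numpy as np


def add_kernel(x, y, out):
    # Written for 1-D buffers: flat indexing over out.size.
    for i in range(out.size):
        out[i] = x[i] + y[i]


def sequential_batch(kernel, xs, ys):
    # Hypothetical model of vmap_method="sequential": one kernel call per batch row.
    outs = np.empty_like(xs)
    for b in range(xs.shape[0]):
        # outs[b] is a view into outs, so the kernel's in-place writes land there.
        kernel(xs[b], ys[b], outs[b])
    return outs


xs = np.arange(12, dtype=np.float32).reshape(3, 4)
ys = np.ones((3, 4), dtype=np.float32)
batched = sequential_batch(add_kernel, xs, ys)  # equals xs + ys
```

Passing the full (3, 4) buffers to add_kernel in a single call would instead index past the first axis, which is why a kernel written like this one needs "sequential" (or a rewrite that handles the batch axis) under jax.vmap.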

Examples

>>> import numba
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_kernel
>>>
>>> # Inputs come first, then outputs; results are written in place.
>>> @numba.njit
... def add_kernel(x, y, out):
...     for i in range(out.size):
...         out[i] = x[i] + y[i]
>>>
>>> add = numba_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((64,), jnp.float32),
... )
>>>
>>> a = jnp.arange(64, dtype=jnp.float32)
>>> b = jnp.ones(64, dtype=jnp.float32)
>>> result = add(a, b)
>>>
>>> # Multiple outputs: one ShapeDtypeStruct per output array
>>> @numba.njit
... def split_kernel(x, out1, out2):
...     for i in range(out1.size):
...         out1[i] = x[i] * 2
...         out2[i] = x[i] * 3
>>>
>>> split = numba_kernel(
...     split_kernel,
...     outs=[
...         jax.ShapeDtypeStruct((64,), jnp.float32),
...         jax.ShapeDtypeStruct((64,), jnp.float32),
...     ],
... )
>>> x = jnp.arange(64, dtype=jnp.float32)
>>> out1, out2 = split(x)
>>>
>>> # Use with jax.jit
>>> @jax.jit
... def f(a, b):
...     return add(a, b)
>>> result = f(a, b)
>>>
>>> # Use parallel Numba
>>> @numba.njit(parallel=True)
... def parallel_add_kernel(x, y, out):
...     for i in numba.prange(out.size):
...         out[i] = x[i] + y[i]
>>>
>>> parallel_add = numba_kernel(
...     parallel_add_kernel,
...     outs=jax.ShapeDtypeStruct((64,), jnp.float32),
... )
>>> result = parallel_add(a, b)

Note

The Numba kernel function should:

  • Accept input arrays followed by output arrays as arguments.

  • Write results directly to the output arrays.

  • Not return any values (outputs are written in place).
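Because results are written in place, a kernel should assign every element of every output. The sketch below assumes, without confirmation on this page, that output buffers arrive with unspecified contents (XLA does not in general zero-initialize result buffers); a NaN fill stands in for those stale contents. Both ReLU kernels are hypothetical examples.

```python
import numpy as np


def relu_kernel_partial(x, out):
    # Bug: only positive entries are written; the rest of out keeps stale contents.
    for i in range(out.size):
        if x[i] > 0:
            out[i] = x[i]


def relu_kernel(x, out):
    # Every element is written, so the result never depends on prior buffer contents.
    for i in range(out.size):
        out[i] = x[i] if x[i] > 0 else 0.0


x = np.array([-1.0, 2.0, -3.0, 4.0])

# NaN fill simulates an output buffer that was never zeroed.
stale = np.full_like(x, np.nan)
relu_kernel_partial(x, stale)  # [nan, 2, nan, 4]

fresh = np.full_like(x, np.nan)
relu_kernel(x, fresh)  # [0, 2, 0, 4]
```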