brainevent.numba_kernel
- brainevent.numba_kernel(kernel, outs, *, vmap_method=None, input_output_aliases=None)
Create a JAX-callable function from a Numba CPU kernel.
Wraps a Numba JIT-compiled CPU kernel (decorated with @numba.njit) so that it can be called from JAX on CPU. The kernel reads from and writes to the array buffers directly through the XLA FFI (Foreign Function Interface).
- Parameters:
kernel (callable) – A Numba CPU kernel function decorated with @numba.njit.
outs (ShapeDtype | Sequence[ShapeDtype]) – Output specification, such as jax.ShapeDtypeStruct: a single struct for one output, or a sequence of structs for multiple outputs.
vmap_method (str | None) – The method to use when vmapping this kernel; see the JAX documentation for jax.ffi.ffi_call. Default is None.
input_output_aliases (dict[int, int] | None) – Mapping from input indices to output indices for in-place operations. Default is None.
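The parameters above describe how outputs are specified and aliased, but not what the kernel actually receives. The block below is a NumPy-only model under stated assumptions, not brainevent's implementation; emulate_numba_call, Spec (a stand-in for jax.ShapeDtypeStruct) and relu_inplace are hypothetical names. It assumes each non-aliased output arrives as a fresh buffer with undefined contents, and that an aliased output shares one buffer with its input, so it starts out holding the input's values.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for jax.ShapeDtypeStruct; only .shape and .dtype are read.
Spec = namedtuple("Spec", ["shape", "dtype"])


def emulate_numba_call(kernel, outs, *inputs, input_output_aliases=None):
    """NumPy model of the calling convention (an assumption, not brainevent code)."""
    single = hasattr(outs, "shape")  # one spec vs. a sequence of specs
    specs = [outs] if single else list(outs)
    arrays = [np.asarray(x) for x in inputs]

    # Non-aliased outputs: fresh buffers whose initial contents are undefined.
    outputs = [np.empty(s.shape, dtype=s.dtype) for s in specs]
    for in_idx, out_idx in (input_output_aliases or {}).items():
        # Aliased pair: one buffer, initialised from the input, passed in both slots.
        buf = arrays[in_idx].astype(specs[out_idx].dtype, copy=True)
        arrays[in_idx] = buf
        outputs[out_idx] = buf

    result = kernel(*arrays, *outputs)  # inputs first, then outputs
    assert result is None, "kernels must write outputs in place, not return them"
    return outputs[0] if single else tuple(outputs)


def relu_inplace(x, out):
    # Correct only when `out` aliases `x`: entries this loop does not write
    # keep the input value already stored in the shared buffer.
    for i in range(out.size):
        if x[i] < 0.0:
            out[i] = 0.0


x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
y = emulate_numba_call(relu_inplace, Spec((4,), np.float32), x,
                       input_output_aliases={0: 0})
# y is [0., 2., 0., 4.]; x itself is left unchanged by this model.
```

Under this model, relu_inplace would be wrong without input_output_aliases={0: 0}, because the unwritten entries of a non-aliased output are never filled in.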
- Returns:
A function that takes JAX arrays as inputs and returns JAX arrays as outputs. Compatible with jax.jit and other transformations.
- Return type:
callable
Examples
>>> import numba
>>> import jax
>>> import jax.numpy as jnp
>>> from brainevent import numba_kernel
>>>
>>> @numba.njit
... def add_kernel(x, y, out):
...     for i in range(out.size):
...         out[i] = x[i] + y[i]
>>>
>>> kernel = numba_kernel(
...     add_kernel,
...     outs=jax.ShapeDtypeStruct((64,), jnp.float32),
... )
>>>
>>> a = jnp.arange(64, dtype=jnp.float32)
>>> b = jnp.ones(64, dtype=jnp.float32)
>>> result = kernel(a, b)
>>>
>>> # Multiple outputs
>>> @numba.njit
... def split_kernel(x, out1, out2):
...     for i in range(out1.size):
...         out1[i] = x[i] * 2
...         out2[i] = x[i] * 3
>>>
>>> split = numba_kernel(
...     split_kernel,
...     outs=[
...         jax.ShapeDtypeStruct((64,), jnp.float32),
...         jax.ShapeDtypeStruct((64,), jnp.float32),
...     ],
... )
>>> out1, out2 = split(a)
>>>
>>> # Use with jax.jit
>>> @jax.jit
... def f(a, b):
...     return kernel(a, b)
>>>
>>> # Use parallel Numba
>>> @numba.njit(parallel=True)
... def parallel_add_kernel(x, y, out):
...     for i in numba.prange(out.size):
...         out[i] = x[i] + y[i]
>>>
>>> parallel_kernel = numba_kernel(
...     parallel_add_kernel,
...     outs=jax.ShapeDtypeStruct((64,), jnp.float32),
... )
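How a wrapped kernel behaves under jax.vmap depends on vmap_method, which the parameter list says follows jax.ffi.ffi_call. In the JAX documentation, the "sequential" strategy calls the underlying op once per element of the batch axis. The sketch below models that strategy in NumPy; sequential_vmap and add_op are hypothetical names, and the sketch assumes every input is batched along axis 0 and the op has a single output.

```python
import numpy as np


def sequential_vmap(op):
    """NumPy model of vmap_method="sequential" (hypothetical helper, not JAX code).

    Assumes every input is batched along axis 0 and `op` has one output.
    """
    def batched(*xs):
        batch_size = xs[0].shape[0]
        # One unbatched call per batch element, then stack along a new axis 0.
        per_item = [op(*(x[i] for x in xs)) for i in range(batch_size)]
        return np.stack(per_item)
    return batched


def add_op(x, y):
    # Stands in for the wrapped `kernel(a, b)` from the Examples above.
    out = np.empty_like(x)
    for i in range(out.size):  # same flat 1-D loop as add_kernel
        out[i] = x[i] + y[i]
    return out


A = np.arange(12, dtype=np.float32).reshape(3, 4)
B = np.ones((3, 4), dtype=np.float32)
C = sequential_vmap(add_op)(A, B)  # shape (3, 4), equal to A + B
```

Assuming vmap_method is forwarded unchanged, the real counterpart would pass vmap_method="sequential" to numba_kernel and then call jax.vmap(kernel) on (batch, 64) arrays. Each per-element call then sees exactly the 1-D shapes that add_kernel was written for; a batched strategy such as "expand_dims" would likely require rewriting the kernel's flat indexing first.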
- Raises:
ImportError – If Numba is not installed.
AssertionError – If kernel is not a Numba CPU dispatcher.
Note
The Numba kernel function should:
- Accept the input arrays followed by the output arrays as arguments.
- Write results directly into the output arrays.
- Not return any values (outputs are written in place).
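The three requirements in the note can be seen in a kernel with one 2-D input and two outputs. The sketch below uses hypothetical names (row_stats_kernel, make_row_stats_op) and assumes float32 data with at least one column. The kernel body is left undecorated so it can be checked directly on preallocated NumPy buffers; make_row_stats_op applies numba.njit and wraps the result with numba_kernel only when it is called, so numba, jax and brainevent are needed only at that point.

```python
import numpy as np


def row_stats_kernel(x, sums, maxes):
    # Argument order: input (x) first, then the outputs (sums, maxes).
    n_rows, n_cols = x.shape
    for r in range(n_rows):
        s = 0.0
        m = x[r, 0]  # assumes n_cols >= 1
        for c in range(n_cols):
            s += x[r, c]
            if x[r, c] > m:
                m = x[r, c]
        sums[r] = s   # results are written into the output buffers...
        maxes[r] = m
    # ...and nothing is returned.


def make_row_stats_op(n_rows):
    """Wrap the kernel for JAX (imports deferred until this is called)."""
    import jax
    import jax.numpy as jnp
    import numba
    from brainevent import numba_kernel

    out_spec = jax.ShapeDtypeStruct((n_rows,), jnp.float32)
    return numba_kernel(numba.njit(row_stats_kernel), outs=[out_spec, out_spec])


# Check the kernel on plain NumPy buffers before wrapping it.
x = np.array([[1.0, 5.0, 2.0], [-3.0, -1.0, -2.0]], dtype=np.float32)
sums = np.empty(2, dtype=np.float32)
maxes = np.empty(2, dtype=np.float32)
row_stats_kernel(x, sums, maxes)  # sums is [8, -6], maxes is [5, -1]
```

Once wrapped, the op would be called with the input only, for example sums, maxes = make_row_stats_op(2)(jnp.asarray(x)), following the multiple-output pattern in the Examples.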