Custom CPU Operators with Numba

This tutorial shows how to write high-performance CPU kernels using Numba’s @njit decorator and integrate them into the BrainEvent / JAX ecosystem.

Numba compiles Python functions to native machine code via LLVM, often reaching speeds comparable to C/Fortran. BrainEvent’s numba_kernel function bridges Numba JIT-compiled functions into JAX via XLA’s Foreign Function Interface (FFI), so your Numba kernels become first-class JAX operations compatible with jax.jit, jax.vmap, and other transforms.

Contents

  1. Why Numba on CPU?

  2. Installation and Imports

  3. Writing Numba JIT Kernels (@numba.njit)

  4. numba_kernel – Wrapping for JAX

  5. Parallel Kernels with numba.prange

  6. Multiple Inputs and Outputs

  7. Registering with XLACustomKernel

  8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication

  9. Combining Numba CPU and Warp/Numba-CUDA Backends

  10. Summary

1. Why Numba on CPU?

JAX runs on CPU, GPU, and TPU, but some algorithms do not map well to the GPU’s massively parallel execution model:

  • Sparse / irregular access patterns – scattered memory accesses cannot be coalesced and waste GPU memory bandwidth

  • Sequential algorithms – recurrences that depend on previous iterations

  • Small to medium problem sizes – kernel-launch and host-to-device transfer overhead dominates for small arrays

  • CPU-only environments – laptops, CI servers, edge devices

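| Property              | JAX native (CPU) | Numba (@njit)    | C extension      |
|-----------------------|------------------|------------------|------------------|
| Compilation           | XLA JIT (fast)   | LLVM JIT (fast)  | Ahead of time    |
| Python overhead       | Yes              | Eliminated       | Eliminated       |
| Parallelism           | Limited          | prange / OpenMP  | pthread / OpenMP |
| Custom loop structure | No               | Yes              | Yes              |
| Write in Python       | Yes              | Yes              | No               |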

Numba @njit lets you write the inner loop in Python while achieving native performance, and brainevent.numba_kernel makes the result a proper JAX primitive.

Requirements: pip install numba (no GPU needed)

2. Installation and Imports

# Install if needed:
# !pip install numba -U
# !pip install brainevent -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, numba_kernel

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    import numba
    print(f"Numba version  : {numba.__version__}")
    NUMBA_AVAILABLE = True
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_AVAILABLE = False

3. Writing Numba JIT Kernels (@numba.njit)

Rules for Numba CPU kernels used with numba_kernel:

  • Decorate with @numba.njit (or @numba.njit(parallel=True) for parallelism)

  • Function signature: kernel(input1, input2, ..., output1, output2, ...) – inputs first, then outputs; all as NumPy arrays (zero-copy from JAX)

  • Write results into output arrays – no return values

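  • Standard Python math, NumPy indexing and slicing, and for loops all work

  • Keep import statements (such as import math) outside the @njit function: Numba’s nopython mode cannot compile an import inside a jitted function, so import at module level and refer to the module from the kernel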
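A kernel that follows these rules is an ordinary function of NumPy arrays, so you can check it directly in NumPy before wrapping it for JAX. The sketch below uses a hypothetical axpy_kernel, which is not part of BrainEvent. It passes the scalar alpha as a length-1 array, the same convention the later examples use for tau_inv and dt. The no-op njit fallback exists only so the snippet still runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: run the kernel as plain Python if Numba is absent
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit
def axpy_kernel(a, x, y, out):
    """Hypothetical kernel: inputs (a, x, y) first, output buffer last."""
    alpha = a[0]  # scalar parameters travel as length-1 arrays
    for i in range(out.size):
        out[i] = alpha * x[i] + y[i]
    # no return statement: the result lives in `out`


a = np.array([2.0], dtype=np.float32)
x = np.arange(5, dtype=np.float32)
y = np.ones(5, dtype=np.float32)
out = np.empty(5, dtype=np.float32)  # preallocated output buffer

axpy_kernel(a, x, y, out)
print(out)  # [1. 3. 5. 7. 9.]
```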

3.1 Simple Element-wise Kernels

if NUMBA_AVAILABLE:
    @numba.njit
    def add_kernel(x, y, out):
        """out[i] = x[i] + y[i]"""
        for i in range(out.size):
            out[i] = x[i] + y[i]

    @numba.njit
    def relu_kernel(x, out):
        """out[i] = max(x[i], 0.0)"""
        for i in range(out.size):
            v = x[i]
            out[i] = v if v > 0.0 else 0.0

    @numba.njit
    def matvec_kernel(A, x, out):
        """Dense matrix-vector product: out = A @ x"""
        rows, cols = A.shape
        for i in range(rows):
            total = A.dtype.type(0)
            for j in range(cols):
                total += A[i, j] * x[j]
            out[i] = total

    print("Numba kernels defined:", add_kernel, relu_kernel, matvec_kernel)

3.2 Reduction Kernels

if NUMBA_AVAILABLE:
    @numba.njit
    def sum_kernel(x, out):
        """out[0] = sum(x)."""
        total = x.dtype.type(0)
        for i in range(x.size):
            total += x[i]
        out[0] = total

    @numba.njit
    def max_kernel(x, out):
        """out[0] = max(x)."""
        m = x[0]
        for i in range(1, x.size):
            if x[i] > m:
                m = x[i]
        out[0] = m

    @numba.njit
    def running_stats_kernel(x, mean_out, std_out):
        """Compute mean and (population) std of x in two passes."""
        n = x.size
        s = x.dtype.type(0)
        for i in range(n):
            s += x[i]
        mean = s / n
        var = x.dtype.type(0)
        for i in range(n):
            d = x[i] - mean
            var += d * d
        mean_out[0] = mean
        std_out[0]  = (var / n) ** 0.5

    print("Reduction kernels defined.")
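running_stats_kernel reads x twice: once for the mean and once for the variance. If you need a true single pass, for instance when the data is streamed, Welford’s online algorithm is a standard substitute. The sketch below follows the same inputs-then-outputs convention. It is an illustration, not a BrainEvent kernel. It accumulates in float64 and assumes x is non-empty.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: plain Python if Numba is not installed
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit
def welford_stats_kernel(x, mean_out, std_out):
    """Mean and population std of x in a single pass (Welford's algorithm)."""
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for i in range(x.size):
        xi = float(x[i])            # accumulate in float64
        delta = xi - mean           # deviation from the old mean
        mean += delta / (i + 1)     # update the running mean
        m2 += delta * (xi - mean)   # old deviation times new deviation
    mean_out[0] = mean
    std_out[0] = math.sqrt(m2 / x.size)  # population std (ddof=0), like jnp.std


x = np.arange(10_000, dtype=np.float32)
mean_out = np.empty(1, dtype=np.float64)
std_out = np.empty(1, dtype=np.float64)
welford_stats_kernel(x, mean_out, std_out)
print(mean_out[0], std_out[0])
```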

4. numba_kernel – Wrapping for JAX

numba_kernel wraps a Numba CPU kernel so it can be called with JAX CPU arrays via XLA’s typed FFI protocol.

Signature:

numba_kernel(
    kernel,              # @numba.njit function
    outs,                # jax.ShapeDtypeStruct or list thereof
    *,
    vmap_method=None,
    input_output_aliases=None,
) -> callable

The returned callable accepts JAX arrays as inputs and returns JAX arrays as outputs. It is compatible with jax.jit.
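Conceptually, the wrapper allocates one output buffer per entry of outs and calls the kernel with the inputs followed by those buffers. The pure-NumPy model below is only a mental model of that calling convention; emulate_numba_kernel and Spec are hypothetical names. It is not how BrainEvent implements the wrapper, which goes through XLA’s FFI instead. The model always returns a tuple to match the unwrap pattern in the examples, and that is an assumption about the real return type.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for jax.ShapeDtypeStruct (only shape and dtype)."""
    shape: tuple
    dtype: type


def emulate_numba_kernel(kernel, outs):
    """Pure-NumPy model of the calling convention; NOT BrainEvent's implementation."""
    specs = list(outs) if isinstance(outs, (list, tuple)) else [outs]

    def call(*inputs):
        arrays = [np.asarray(a) for a in inputs]
        # Allocate one output buffer per spec, in the order given by `outs`
        buffers = [np.empty(s.shape, dtype=s.dtype) for s in specs]
        kernel(*arrays, *buffers)  # inputs first, then outputs
        return tuple(buffers)      # always a tuple, even for a single output

    return call


def add_kernel(x, y, out):
    for i in range(out.size):
        out[i] = x[i] + y[i]


add_fn = emulate_numba_kernel(add_kernel, outs=Spec((4,), np.float32))
(result,) = add_fn(np.arange(4, dtype=np.float32), np.ones(4, dtype=np.float32))
print(result)  # [1. 2. 3. 4.]
```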

if NUMBA_AVAILABLE:
    N = 512
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 3.0

    # Create the JAX-callable wrapper
    add_fn = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    result = add_fn(a, b)
    # numba_kernel returns a tuple; unwrap if needed
    result = result[0] if isinstance(result, tuple) else result

    expected = a + b
    print("Add max error  :", float(jnp.max(jnp.abs(result - expected))))

    # ---- ReLU ----
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)
    relu_fn = numba_kernel(
        relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )
    r = relu_fn(x)
    r = r[0] if isinstance(r, tuple) else r
    print("ReLU max error :", float(jnp.max(jnp.abs(r - jnp.maximum(x, 0.0)))))
if NUMBA_AVAILABLE:
    # ---- Reduction ----
    N = 10_000
    x = jnp.arange(N, dtype=jnp.float32)

    sum_fn = numba_kernel(
        sum_kernel,
        outs=jax.ShapeDtypeStruct((1,), jnp.float32),
    )

    s = sum_fn(x)
    s = s[0] if isinstance(s, tuple) else s
    print(f"Sum: {float(s[0]):.1f}  |  Expected: {float(jnp.sum(x)):.1f}")

    # ---- Multiple outputs (mean and std in one pass) ----
    stats_fn = numba_kernel(
        running_stats_kernel,
        outs=[
            jax.ShapeDtypeStruct((1,), jnp.float32),  # mean
            jax.ShapeDtypeStruct((1,), jnp.float32),  # std
        ],
    )

    mean_val, std_val = stats_fn(x)
    print(f"Mean: {float(mean_val[0]):.2f}  |  Std: {float(std_val[0]):.2f}")
    print(f"jnp.mean: {float(jnp.mean(x)):.2f}  |  jnp.std: {float(jnp.std(x)):.2f}")

4.1 JIT Compatibility

if NUMBA_AVAILABLE:
    N = 128
    add_fn_cached = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    @jax.jit
    def jitted_pipeline(a, b):
        # Mix Numba kernel with standard JAX operations
        temp = add_fn_cached(a, b)
        temp = temp[0] if isinstance(temp, tuple) else temp
        return jnp.sin(temp) * jnp.sqrt(jnp.abs(temp) + 1.0)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)

    r1 = jitted_pipeline(a, b)
    r2 = jitted_pipeline(a * 2, b * 0.5)   # second call reuses compiled code

    print("JIT pipeline output shape:", r1.shape)
    print("First 5 values           :", r1[:5])

5. Parallel Kernels with numba.prange

Adding parallel=True to @numba.njit and replacing range with numba.prange enables automatic parallelization across CPU cores using threading.

This is the easiest way to exploit multi-core CPUs without writing thread management code.
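prange only pays off when loop iterations are independent. Numba treats a scalar updated with += inside a prange loop as a reduction and combines the per-thread partial results. By contrast, if several iterations write the same array element (for example out[0] += x[i] inside the loop), you get a race condition. The two hypothetical kernels below show both safe patterns. Without Numba, the fallback simply runs the loops sequentially.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback: sequential plain Python if Numba is not installed
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@njit(parallel=True)
def parallel_sum_sq_kernel(x, out):
    total = 0.0
    for i in prange(x.size):
        # A scalar updated with `+=` inside prange is treated as a reduction:
        # each thread accumulates a private partial sum, combined at the end.
        total += x[i] * x[i]
    out[0] = total  # single write after the loop (not `out[0] += ...` inside it)


@njit(parallel=True)
def parallel_scale_rows_kernel(A, s, out):
    rows, cols = A.shape
    for i in prange(rows):      # iteration i writes only row i: no shared writes
        for j in range(cols):   # inner loop stays sequential
            out[i, j] = A[i, j] * s[i]


x = np.arange(1000, dtype=np.float64)
out_sum = np.empty(1, dtype=np.float64)
parallel_sum_sq_kernel(x, out_sum)

A = np.arange(6, dtype=np.float64).reshape(2, 3)
s = np.array([1.0, 10.0])
out_rows = np.empty_like(A)
parallel_scale_rows_kernel(A, s, out_rows)
print(out_sum[0], out_rows)
```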

if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def parallel_add_kernel(x, y, out):
        """Parallel element-wise add using prange."""
        for i in numba.prange(out.size):  # parallelized loop
            out[i] = x[i] + y[i]

    @numba.njit(parallel=True)
    def parallel_matvec_kernel(A, x, out):
        """Parallel matrix-vector product: rows are distributed across threads."""
        rows, cols = A.shape
        for i in numba.prange(rows):    # parallelize over rows
            total = A.dtype.type(0)
            for j in range(cols):       # inner loop stays sequential
                total += A[i, j] * x[j]
            out[i] = total

    import math  # keep imports outside @njit functions; nopython mode cannot compile them

    @numba.njit(parallel=True)
    def parallel_exp_decay_kernel(trace, spikes, tau_inv, out):
        """
        Exponential trace update used in STDP:
          out[i] = trace[i] * exp(-tau_inv) + spikes[i]
        """
        decay = math.exp(-tau_inv[0])   # computed once, outside the parallel loop
        for i in numba.prange(out.size):
            out[i] = trace[i] * decay + spikes[i]

    print("Parallel kernels defined.")
if NUMBA_AVAILABLE:
    import time

    N = 1_000_000
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)

    serial_fn   = numba_kernel(add_kernel,          outs=jax.ShapeDtypeStruct((N,), jnp.float32))
    parallel_fn = numba_kernel(parallel_add_kernel, outs=jax.ShapeDtypeStruct((N,), jnp.float32))

    # Warm up
    jax.block_until_ready(serial_fn(a, b))
    jax.block_until_ready(parallel_fn(a, b))

    N_TRIALS = 20

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(serial_fn(a, b))
    serial_time = (time.time() - t0) / N_TRIALS * 1000

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(parallel_fn(a, b))
    parallel_time = (time.time() - t0) / N_TRIALS * 1000

    import os
    n_cores = os.cpu_count()
    print(f"N = {N:,}  |  CPU cores: {n_cores}")
    print(f"Serial   : {serial_time:.2f} ms")
    print(f"Parallel : {parallel_time:.2f} ms")
    print(f"Speedup  : {serial_time / parallel_time:.2f}x")
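The timing loops in this tutorial average time.time() differences over a fixed number of trials. A somewhat more robust pattern uses the monotonic time.perf_counter() clock and reports the median, so a few slow iterations (OS scheduling, garbage collection) do not skew the result. The sketch below uses only the standard library; benchmark_ms is a hypothetical helper, not part of BrainEvent.

```python
import statistics
import time


def benchmark_ms(fn, *args, n_warmup=2, n_trials=20):
    """Median wall-clock time of fn(*args) in milliseconds."""
    for _ in range(n_warmup):
        fn(*args)  # exclude JIT compilation and first-call costs
    samples = []
    for _ in range(n_trials):
        t0 = time.perf_counter()  # monotonic, high-resolution clock
        fn(*args)
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)  # robust to occasional slow outliers


t_ms = benchmark_ms(sorted, list(range(10_000, 0, -1)))
print(f"sorted(): {t_ms:.3f} ms")
```

To time a JAX wrapper, pass a callable that blocks until the result is ready, for example benchmark_ms(lambda: jax.block_until_ready(parallel_fn(a, b))). Element-wise addition does very little arithmetic per byte moved, so at N = 1,000,000 it tends to be limited by memory bandwidth rather than core count. A speedup well below the number of cores is therefore not unusual.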

6. Multiple Inputs and Outputs

Numba kernels can take any number of inputs and outputs. The outs argument to numba_kernel mirrors the output buffers: a single ShapeDtypeStruct for one output, a list for multiple.

if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def lif_dynamics_kernel(
        V,         # membrane potentials  (N,)
        I_ext,     # external current     (N,)
        tau_inv,   # 1/tau_m              (1,)
        V_th,      # threshold            (1,)
        V_rest,    # reset potential      (1,)
        dt,        # time step            (1,)
        V_out,     # updated potentials   (N,)  – output
        spikes,    # spike vector         (N,)  – output
    ):
        """
        One Euler step of leaky integrate-and-fire dynamics:
          V_out[i] = V[i] + dt * (-(V[i] - V_rest) * tau_inv + I_ext[i])
        then threshold and reset.
        """
        th   = V_th[0]
        vr   = V_rest[0]
        ti   = tau_inv[0]
        step = dt[0]

        for i in numba.prange(V.size):
            v_new = V[i] + step * (-(V[i] - vr) * ti + I_ext[i])
            if v_new >= th:
                spikes[i]  = 1
                V_out[i]   = vr
            else:
                spikes[i]  = 0
                V_out[i]   = v_new

    print("LIF dynamics kernel defined.")
if NUMBA_AVAILABLE:
    N = 5_000
    rng = np.random.default_rng(0)

    V      = jnp.array(rng.uniform(-75.0, -50.0, N).astype(np.float32))
    I_ext  = jnp.array(rng.uniform(0.0,    5.0,  N).astype(np.float32))
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau_m = 20 ms
    V_th   = jnp.array([-55.0], dtype=jnp.float32)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)
    dt     = jnp.array([0.1],   dtype=jnp.float32)         # dt = 0.1 ms

    lif_fn = numba_kernel(
        lif_dynamics_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # V_out
            jax.ShapeDtypeStruct((N,), jnp.int32),    # spikes
        ],
    )

    V_new, spikes = lif_fn(V, I_ext, tau_inv, V_th, V_rest, dt)

    print(f"Neurons: {N}")
    print(f"Spikes : {int(spikes.sum())} ({100*float(spikes.mean()):.1f}%)")
    print(f"V range: [{float(V_new.min()):.2f}, {float(V_new.max()):.2f}] mV")

    # Verify against JAX reference
    V_ref = V + dt[0] * (-(V - V_rest[0]) * tau_inv[0] + I_ext)
    spk_ref = (V_ref >= V_th[0]).astype(jnp.int32)
    V_ref   = jnp.where(spk_ref, V_rest[0], V_ref)

    print(f"V max error    : {float(jnp.max(jnp.abs(V_new - V_ref))):.6f} mV")
    print(f"Spike mismatch : {int(jnp.sum(spikes != spk_ref))}")

7. Registering with XLACustomKernel

For production use, embed your Numba kernel inside a kernel generator and register it with XLACustomKernel. The generator receives shape/dtype information forwarded from primitive.bind and returns the concrete callable.
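The generators in this section read the output description from kwargs['outs'] and build a kernel specialized to that shape. The sketch below calls such a generator by hand, outside XLA, to show the idea: the size n is fixed when the kernel is built and is compiled into it as a constant. OutInfo and exp_trace_generator are hypothetical names. The sketch relies only on the outs keyword and does not model any other keyword arguments XLACustomKernel may forward.

```python
import math
from dataclasses import dataclass

import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback so the sketch runs without Numba
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


@dataclass(frozen=True)
class OutInfo:
    """Hypothetical stand-in for the jax.ShapeDtypeStruct found in kwargs['outs']."""
    shape: tuple
    dtype: type


def exp_trace_generator(**kwargs):
    """Hypothetical generator: builds a kernel specialized to the output shape."""
    n = kwargs["outs"][0].shape[0]  # static size, known before any data arrives

    @njit(parallel=True)
    def trace_kern(trace, spikes, tau_inv, out):
        decay = math.exp(-tau_inv[0])
        for i in prange(n):  # n is frozen into the compiled kernel as a constant
            out[i] = trace[i] * decay + spikes[i]

    return trace_kern


kern = exp_trace_generator(outs=[OutInfo((4,), np.float32)])

trace = np.zeros(4, dtype=np.float32)
spikes = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)
tau_inv = np.array([0.05], dtype=np.float32)
out = np.empty(4, dtype=np.float32)
kern(trace, spikes, tau_inv, out)
print(out)  # [1. 0. 1. 0.]
```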

if NUMBA_AVAILABLE:
    # -----------------------------------------------------------------------
    # Kernel generator: exponential trace update (STDP)
    #   out[i] = trace[i] * decay + spikes[i]
    # -----------------------------------------------------------------------

    import math  # keep imports outside @njit functions; nopython mode cannot compile them

    def exp_trace_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]   # ShapeDtypeStruct describing the output
        n        = out_info.shape[0]   # static size, frozen into the kernel below

        @numba.njit(parallel=True)
        def trace_kern(trace, spikes, tau_inv, out):
            decay = math.exp(-tau_inv[0])
            for i in numba.prange(n):
                out[i] = trace[i] * decay + spikes[i]

        # Build the JAX-callable wrapper once per generator call, then reuse it
        trace_fn = numba_kernel(trace_kern, outs=out_info)

        def kernel(trace, spikes, tau_inv):
            return trace_fn(trace, spikes, tau_inv)   # tuple of output arrays

        return kernel

    # Register the primitive
    trace_op = XLACustomKernel('tutorial_numba_exp_trace')
    trace_op.def_numba_kernel(exp_trace_numba_generator)

    print("Registered backends:", list(trace_op._kernels.keys()))
if NUMBA_AVAILABLE:
    N = 1000
    trace   = jnp.zeros(N, dtype=jnp.float32)
    spikes  = jnp.array(np.random.default_rng(1).random(N) < 0.1,
                        dtype=jnp.float32)
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau = 20 ms, assuming 1 ms per step

    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)

    @jax.jit
    def update_trace(trace, spikes, tau_inv):
        return trace_op(
            trace, spikes, tau_inv,
            outs=[out_spec],
        )[0]

    # Simulate 100 time steps of trace dynamics
    import math
    decay = math.exp(-float(tau_inv[0]))

    trace_history = []
    for step in range(100):
        spikes = jnp.array(
            np.random.default_rng(step).random(N) < 0.05,
            dtype=jnp.float32
        )
        trace = update_trace(trace, spikes, tau_inv)
        trace_history.append(float(trace.mean()))

    print("Trace stats after 100 steps:")
    print(f"  Mean  : {float(trace.mean()):.4f}")
    print(f"  Max   : {float(trace.max()):.4f}")
    print(f"  Steady-state (theory): {0.05 / (1 - decay):.4f}")
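The "theory" line printed above is the expected value of the linear recurrence the kernel implements. Reading decay = exp(-tau_inv) as exp(-Δt/τ) with Δt = 1 time unit is an interpretation, since the code does not name a step size. The derivation also assumes the trace starts at zero, as it does here:

```latex
x_{t+1} = d\,x_t + s_t, \qquad d = e^{-\Delta t/\tau}, \qquad \mathbb{E}[s_t] = p, \qquad x_0 = 0
\\[4pt]
\mathbb{E}[x_T] = p \sum_{k=0}^{T-1} d^{k} = p\,\frac{1 - d^{T}}{1 - d}
\;\xrightarrow{\;T \to \infty\;}\; \frac{p}{1 - d}
```

With p = 0.05 and τ = 20, d = e^{-0.05} ≈ 0.9512, so the steady state is p/(1 - d) ≈ 1.025. The expected mean after T = 100 steps is about 1.025 × (1 - e^{-5}) ≈ 1.018. The simulated mean should therefore sit slightly below the printed steady-state value, up to sampling noise from the random spikes.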

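8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication

A core operation in neural network simulation: given a synaptic weight matrix stored in CSR (compressed sparse row) format and a float input vector, compute the matrix-vector product.

Each output neuron (one CSR row) is computed by a short sequential loop over that row’s non-zeros, and no two rows write to the same output element. Rows can therefore be processed independently, which makes the operation a good fit for parallelizing over rows with numba.prange on CPU.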
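To make the CSR layout concrete, here is a 3 × 4 matrix stored as data, indices, and indptr. It is multiplied with the same row-parallel loop structure as csr_matvec_numba. The matrix and the kernel name csr_matvec_small are illustrative only, and the dense product is computed alongside as a check.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback: sequential plain Python if Numba is not installed
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f


# Illustrative 3 x 4 matrix (row 1 is empty):
#   [[0, 2, 0, 1],
#    [0, 0, 0, 0],
#    [3, 0, 4, 0]]
dense = np.array([[0, 2, 0, 1], [0, 0, 0, 0], [3, 0, 4, 0]], dtype=np.float32)
data = np.array([2.0, 1.0, 3.0, 4.0], dtype=np.float32)  # non-zeros, row by row
indices = np.array([1, 3, 0, 2], dtype=np.int32)         # column of each non-zero
indptr = np.array([0, 2, 2, 4], dtype=np.int32)          # row i owns data[indptr[i]:indptr[i + 1]]


@njit(parallel=True)
def csr_matvec_small(data, indices, indptr, x, out):
    for i in prange(indptr.size - 1):               # rows are independent
        total = 0.0
        for k in range(indptr[i], indptr[i + 1]):   # sequential walk over row i
            total += data[k] * x[indices[k]]
        out[i] = total


x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
out = np.zeros(3, dtype=np.float32)
csr_matvec_small(data, indices, indptr, x, out)
assert np.allclose(out, dense @ x)  # matches the dense product
print(out)  # [ 8.  0. 15.]
```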

if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def csr_matvec_numba(
        data,       # CSR non-zero values  (nnz,)
        indices,    # CSR column indices   (nnz,)
        indptr,     # CSR row pointers     (n_rows+1,)
        x,          # input vector         (n_cols,)
        out,        # output vector        (n_rows,)
    ):
        """
        Sparse matrix-vector product (CSR format).
        Each row is processed by one thread (parallel over rows).
        """
        n_rows = indptr.size - 1
        for i in numba.prange(n_rows):
            total = out.dtype.type(0)
            for k in range(indptr[i], indptr[i + 1]):
                total += data[k] * x[indices[k]]
            out[i] = total

    def csr_mv_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]   # ShapeDtypeStruct of the (n_rows,) output

        # Build the JAX-callable wrapper once per generator call, then reuse it
        csr_mv_fn = numba_kernel(csr_matvec_numba, outs=out_info)

        def kernel(data, indices, indptr, x):
            return csr_mv_fn(data, indices, indptr, x)   # tuple of output arrays

        return kernel

    csr_mv_op = XLACustomKernel('tutorial_numba_csr_matvec')
    csr_mv_op.def_numba_kernel(csr_mv_numba_generator)

    print("CSR MV operator registered.")
if NUMBA_AVAILABLE:
    import scipy.sparse as sp

    N_PRE  = 2000
    N_POST = 1000
    PROB   = 0.05

    rng   = np.random.default_rng(42)
    dense = (rng.random((N_POST, N_PRE)) < PROB).astype(np.float32)
    dense *= rng.uniform(0.01, 0.5, dense.shape).astype(np.float32)
    csr   = sp.csr_matrix(dense)

    data    = jnp.array(csr.data,    dtype=jnp.float32)
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)
    x       = jnp.array(rng.random(N_PRE).astype(np.float32))

    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)

    result = csr_mv_op(
        data, indices, indptr, x,
        outs=[out_spec],
    )[0]

    expected = jnp.array(dense) @ x
    print(f"Network: {N_PRE} pre -> {N_POST} post  (nnz={csr.nnz})")
    print(f"Max error vs dense: {float(jnp.max(jnp.abs(result - expected))):.6f}")
if NUMBA_AVAILABLE:
    import time

    @jax.jit
    def numba_csr_mv(data, indices, indptr, x):
        return csr_mv_op(
            data, indices, indptr, x,
            outs=[out_spec],
        )[0]

    @jax.jit
    def jax_dense_mv(A, x):
        return A @ x

    A_jnp = jnp.array(dense)

    # Warm up
    jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    jax.block_until_ready(jax_dense_mv(A_jnp, x))

    N_TRIALS = 50

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    numba_time = (time.time() - t0) / N_TRIALS * 1000

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_dense_mv(A_jnp, x))
    jax_time = (time.time() - t0) / N_TRIALS * 1000

    print(f"Numba CSR MV  : {numba_time:.2f} ms")
    print(f"JAX dense MV  : {jax_time:.2f} ms")
    print(f"Speedup       : {jax_time / numba_time:.2f}x  (sparsity: {1 - csr.nnz/(N_PRE*N_POST):.0%})")
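Counting operations alone, the CSR kernel performs roughly PROB times as many multiply-adds as the dense product. The arithmetic below gives that ratio for the network above. Treat it as a rough ceiling, not a prediction of the measured speedup. The dense path enjoys contiguous, vectorizable memory access, while the CSR path pays for indirect loads x[indices[k]].

```python
N_PRE, N_POST, PROB = 2000, 1000, 0.05

dense_macs = N_PRE * N_POST             # every entry multiplied, zeros included
expected_nnz = PROB * N_PRE * N_POST    # expected non-zeros under Bernoulli(PROB)
ratio = dense_macs / expected_nnz       # = 1 / PROB

print(f"dense MACs   : {dense_macs:,}")       # 2,000,000
print(f"expected nnz : {expected_nnz:,.0f}")  # 100,000
print(f"work ratio   : {ratio:.0f}x")         # 20x
```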

9. Combining Numba CPU and GPU Backends

The same XLACustomKernel primitive can have both a Numba CPU backend and a GPU backend (Warp or Numba CUDA). JAX automatically dispatches to the correct backend based on the device where the arrays live.

try:
    import warp
    from warp.jax_experimental import jax_kernel as warp_jax_kernel
    from brainevent import jaxinfo_to_warpinfo
    warp.config.quiet = True
    WARP_AVAILABLE = True
except ImportError:
    WARP_AVAILABLE = False

if NUMBA_AVAILABLE:
    # CPU backend (same generator pattern as above)
    @numba.njit(parallel=True)
    def scale_numba(x, s, out):
        for i in numba.prange(out.size):
            out[i] = x[i] * s[0]

    def scale_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]
        scale_fn = numba_kernel(scale_numba, outs=out_info)   # built once, reused

        def kernel(x, s):
            return scale_fn(x, s)   # tuple of output arrays

        return kernel

    scale_op = XLACustomKernel('tutorial_multi_backend_scale')
    scale_op.def_numba_kernel(scale_numba_generator)   # CPU backend

    if WARP_AVAILABLE:
        def scale_warp_generator(**kwargs):
            out_info = kwargs['outs'][0]
            n = out_info.shape[0]
            t = jaxinfo_to_warpinfo(out_info)
            s_type = warp.array(dtype=jaxinfo_to_warpinfo(out_info).dtype, ndim=1)

            @warp.kernel
            def kern(x: t, s: s_type, out: t):
                i = warp.tid()
                out[i] = x[i] * s[0]

            def kernel(x, s):
                fn = warp_jax_kernel(kern, launch_dims=[n], num_outputs=1,
                                     output_dims={'out': (n,)})
                return fn(x, s)

            return kernel

        scale_op.def_warp_kernel(scale_warp_generator)  # GPU backend

    print("Multi-backend scale op registered.")
    print("Backends:", {p: list(b.keys()) for p, b in scale_op._kernels.items()})

    # Use it
    N = 256
    x = jnp.arange(N, dtype=jnp.float32)
    s = jnp.array([3.14], dtype=jnp.float32)

    r = scale_op(x, s, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])[0]
    print(f"Result matches: {bool(jnp.allclose(r, x * 3.14, atol=1e-5))}")

10. Summary

In this tutorial we covered:

  1. @numba.njit – JIT-compile Python to native machine code. Kernel signature: kernel(input1, ..., output1, ...) – all NumPy arrays, no return values.

  2. @numba.njit(parallel=True) + numba.prange – Multi-threaded parallelism across CPU cores with minimal code changes.

  3. numba_kernel(kernel, outs=...) – Wrap a Numba kernel as a JAX-callable via XLA FFI. Returns a function compatible with jax.jit.

  4. Multiple outputs – Pass a list of jax.ShapeDtypeStruct to outs to get multiple return arrays from a single kernel call.

  5. XLACustomKernel.def_numba_kernel – Register a kernel generator as the CPU backend of a multi-backend custom JAX primitive.

  6. Neuroscience applications – LIF dynamics and sparse CSR matrix-vector product implemented with parallel Numba, demonstrating realistic use cases.

Key Guidelines

  • Cache the wrapped callable (do not call numba_kernel inside @jax.jit); create it once at definition time.

  • Use @njit(parallel=True) + prange for outer loops; keep inner loops sequential.

  • Prefer Numba on CPU for irregular / sparse access patterns; prefer GPU backends (Warp, Numba CUDA) for large-scale parallel workloads.
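Sometimes the output shape is only known at call time. One way to honor the first guideline in that case is to memoize wrapper construction per (shape, dtype) with functools.lru_cache. The sketch below uses a hypothetical make_wrapper in place of numba_kernel, so it runs without JAX or BrainEvent. It assumes that building a wrapper is the comparatively expensive step worth caching.

```python
import functools

import numpy as np

build_log = []  # records each (expensive) wrapper construction


def make_wrapper(kernel, shape, dtype):
    """Hypothetical stand-in for numba_kernel(kernel, outs=jax.ShapeDtypeStruct(shape, dtype))."""
    build_log.append((shape, dtype))

    def call(*inputs):
        out = np.empty(shape, dtype=dtype)
        kernel(*inputs, out)  # inputs first, output last
        return (out,)

    return call


def add_kernel(x, y, out):
    out[...] = x + y


@functools.lru_cache(maxsize=None)
def get_add_fn(shape, dtype_name):
    # lru_cache needs hashable arguments: a tuple shape and a dtype name string
    return make_wrapper(add_kernel, shape, np.dtype(dtype_name))


x = np.ones(4, dtype=np.float32)
for _ in range(3):
    (r,) = get_add_fn((4,), "float32")(x, x)  # built on the first call, reused afterwards

print(len(build_log), r)  # 1 [2. 2. 2. 2.]
```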

Next Steps

  • Tutorial 6: Custom GPU operators with Warp

  • Tutorial 7: Custom GPU operators with Numba CUDA

References