Custom CPU Operators with Numba#
This tutorial shows how to write high-performance CPU kernels using Numba’s @njit
decorator and integrate them into the BrainEvent / JAX ecosystem.
Numba compiles Python functions to native machine code
via LLVM, achieving speeds comparable to C/Fortran. BrainEvent’s numba_kernel function
bridges Numba JIT-compiled functions into JAX via XLA’s Foreign Function Interface (FFI),
so your Numba kernels become first-class JAX operations compatible with jax.jit, jax.vmap,
and other transforms.
Contents#
Why Numba on CPU?
Installation and Imports
Writing Numba JIT Kernels (@numba.njit)
numba_kernel – Wrapping for JAX
Parallel Kernels with numba.prange
Multiple Inputs and Outputs
Registering with XLACustomKernel
Neuroscience Example: Sparse CSR × Float-Vector Multiplication
Combining Numba CPU and Warp/Numba-CUDA Backends
Summary
1. Why Numba on CPU?#
JAX runs on CPU, GPU, and TPU, but some algorithms do not map well to the GPU’s massively parallel execution model:
Sparse / irregular access patterns – random memory accesses serialize on GPU
Sequential algorithms – recurrences that depend on previous iterations
Small to medium problem sizes – GPU overhead dominates for small arrays
CPU-only environments – laptops, CI servers, edge devices
| Property | JAX native (CPU) | Numba (@njit) | C extension |
|---|---|---|---|
| JIT speed | XLA (fast) | LLVM (fast) | Compiled ahead of time |
| Python overhead | Yes | Eliminated | Eliminated |
| Parallelism | Limited | prange | pthread / OpenMP |
| Custom loop structure | No | Yes | Yes |
| Write in Python | Yes | Yes | No |
Numba @njit lets you write the inner loop in Python while achieving native
performance, and brainevent.numba_kernel makes the result a proper JAX primitive.
Requirements: pip install numba (no GPU needed)
2. Installation and Imports#
# Install if needed:
# !pip install numba -U
# !pip install brainevent -U
import jax
import jax.numpy as jnp
import numpy as np
import brainevent
from brainevent import XLACustomKernel, numba_kernel
print(f"JAX version : {jax.__version__}")
print(f"JAX backend : {jax.default_backend()}")
print(f"BrainEvent : {brainevent.__version__}")
try:
    import numba
    print(f"Numba version : {numba.__version__}")
    NUMBA_AVAILABLE = True
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_AVAILABLE = False
3. Writing Numba JIT Kernels (@numba.njit)#
Rules for Numba CPU kernels used with numba_kernel:
Decorate with @numba.njit (or @numba.njit(parallel=True) for parallelism)
Function signature: kernel(input1, input2, ..., output1, output2, ...) – inputs first, then outputs; all as NumPy arrays (zero-copy from JAX)
Write results into output arrays – no return values
Standard Python math, NumPy slicing, and for loops all work
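The calling convention above can be sketched with a small kernel that is exercised directly on NumPy arrays, before any JAX wrapping. The name scale_kernel and the fallback decorator are illustrative assumptions, not part of BrainEvent; if Numba is not installed the sketch falls back to plain Python so it still runs.

```python
import numpy as np

try:
    import numba
    njit = numba.njit
except ImportError:
    # Assumption for this sketch only: run as plain Python when Numba is absent.
    def njit(func):
        return func


@njit
def scale_kernel(x, factor, out):
    """Inputs first (x, factor), output buffer last (out); nothing is returned."""
    for i in range(out.size):
        out[i] = x[i] * factor[0]


x = np.arange(4, dtype=np.float32)
factor = np.array([2.0], dtype=np.float32)  # scalars are passed as length-1 arrays
out = np.empty_like(x)                      # the caller allocates the output buffer
scale_kernel(x, factor, out)
print(out)
```

Testing a kernel this way, on plain NumPy buffers, is a cheap check of the loop logic before it is wrapped for JAX.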
3.1 Simple Element-wise Kernels#
if NUMBA_AVAILABLE:
    @numba.njit
    def add_kernel(x, y, out):
        """out[i] = x[i] + y[i]"""
        for i in range(out.size):
            out[i] = x[i] + y[i]

    @numba.njit
    def relu_kernel(x, out):
        """out[i] = max(x[i], 0.0)"""
        for i in range(out.size):
            v = x[i]
            out[i] = v if v > 0.0 else 0.0

    @numba.njit
    def matvec_kernel(A, x, out):
        """Dense matrix-vector product: out = A @ x"""
        rows, cols = A.shape
        for i in range(rows):
            total = A.dtype.type(0)
            for j in range(cols):
                total += A[i, j] * x[j]
            out[i] = total

    print("Numba kernels defined:", add_kernel, relu_kernel, matvec_kernel)
3.2 Reduction Kernels#
if NUMBA_AVAILABLE:
    @numba.njit
    def sum_kernel(x, out):
        """out[0] = sum(x)."""
        total = x.dtype.type(0)
        for i in range(x.size):
            total += x[i]
        out[0] = total

    @numba.njit
    def max_kernel(x, out):
        """out[0] = max(x)."""
        m = x[0]
        for i in range(1, x.size):
            if x[i] > m:
                m = x[i]
        out[0] = m

    @numba.njit
    def running_stats_kernel(x, mean_out, std_out):
        """Compute mean and std in one kernel call (two passes over x)."""
        n = x.size
        # First pass: accumulate the sum to obtain the mean.
        s = x.dtype.type(0)
        for i in range(n):
            s += x[i]
        mean = s / n
        # Second pass: accumulate squared deviations from the mean.
        var = x.dtype.type(0)
        for i in range(n):
            d = x[i] - mean
            var += d * d
        mean_out[0] = mean
        std_out[0] = (var / n) ** 0.5  # population standard deviation

    print("Reduction kernels defined.")
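running_stats_kernel above traverses x twice, once for the mean and once for the variance. If a genuinely single-pass computation is wanted, Welford's online algorithm is one commonly used option. The pure-Python sketch below (the name welford_mean_std is illustrative) uses only a loop and scalar arithmetic, so the same loop body could in principle be moved into an @numba.njit kernel that writes to mean_out and std_out.

```python
import math


def welford_mean_std(values):
    """Single-pass mean and population std using Welford's update."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for v in values:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    return mean, math.sqrt(m2 / count)


mean, std = welford_mean_std([1.0, 2.0, 3.0, 4.0])
print(mean, std)
```

Besides reading the data only once, this update tends to be more robust than accumulating raw sums when values are large relative to their spread.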
4. numba_kernel – Wrapping for JAX#
numba_kernel wraps a Numba CPU kernel so it can be called with JAX CPU arrays
via XLA’s typed FFI protocol.
Signature:
numba_kernel(
    kernel,                     # @numba.njit function
    outs,                       # jax.ShapeDtypeStruct or list thereof
    *,
    vmap_method=None,
    input_output_aliases=None,
) -> callable
The returned callable accepts JAX arrays as inputs and returns JAX arrays as outputs.
It is compatible with jax.jit.
if NUMBA_AVAILABLE:
    N = 512
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 3.0

    # Create the JAX-callable wrapper
    add_fn = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )
    result = add_fn(a, b)
    # numba_kernel returns a tuple; unwrap if needed
    result = result[0] if isinstance(result, tuple) else result
    expected = a + b
    print("Add max error :", float(jnp.max(jnp.abs(result - expected))))

    # ---- ReLU ----
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)
    relu_fn = numba_kernel(
        relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )
    r = relu_fn(x)
    r = r[0] if isinstance(r, tuple) else r
    print("ReLU max error :", float(jnp.max(jnp.abs(r - jnp.maximum(x, 0.0)))))
if NUMBA_AVAILABLE:
    # ---- Reduction ----
    N = 10_000
    x = jnp.arange(N, dtype=jnp.float32)
    sum_fn = numba_kernel(
        sum_kernel,
        outs=jax.ShapeDtypeStruct((1,), jnp.float32),
    )
    s = sum_fn(x)
    s = s[0] if isinstance(s, tuple) else s
    print(f"Sum: {float(s[0]):.1f} | Expected: {float(jnp.sum(x)):.1f}")

    # ---- Multiple outputs (mean and std from one kernel call) ----
    stats_fn = numba_kernel(
        running_stats_kernel,
        outs=[
            jax.ShapeDtypeStruct((1,), jnp.float32),  # mean
            jax.ShapeDtypeStruct((1,), jnp.float32),  # std
        ],
    )
    mean_val, std_val = stats_fn(x)
    print(f"Mean: {float(mean_val[0]):.2f} | Std: {float(std_val[0]):.2f}")
    print(f"jnp.mean: {float(jnp.mean(x)):.2f} | jnp.std: {float(jnp.std(x)):.2f}")
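As a mental model for what the wrapper does with outs, the sketch below allocates one buffer per output spec, calls the kernel with inputs followed by those buffers, and returns the buffers as a tuple. It is a hypothetical NumPy-only stand-in (wrap_out_param_kernel and the (shape, dtype) pairs are assumptions made for illustration); it is not BrainEvent's implementation and involves no XLA FFI.

```python
import numpy as np


def wrap_out_param_kernel(kernel, out_specs):
    """Hypothetical stand-in for the wrapping idea; NOT BrainEvent's numba_kernel.

    out_specs is a list of (shape, dtype) pairs that play the role of
    jax.ShapeDtypeStruct entries.
    """
    def call(*inputs):
        outputs = [np.zeros(shape, dtype=dtype) for shape, dtype in out_specs]
        kernel(*inputs, *outputs)  # inputs first, then output buffers
        return tuple(outputs)      # results come back as a tuple
    return call


def mean_std_kernel(x, mean_out, std_out):
    """Out-parameter kernel: writes results, returns nothing."""
    mean_out[0] = x.mean()
    std_out[0] = x.std()


stats = wrap_out_param_kernel(
    mean_std_kernel,
    [((1,), np.float64), ((1,), np.float64)],
)
mean_val, std_val = stats(np.array([1.0, 2.0, 3.0, 4.0]))
print(mean_val[0], std_val[0])
```

The point of the model is the ordering: the number and order of entries in outs must match the trailing output parameters of the kernel.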
4.1 JIT Compatibility#
if NUMBA_AVAILABLE:
    N = 128
    add_fn_cached = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    @jax.jit
    def jitted_pipeline(a, b):
        # Mix Numba kernel with standard JAX operations
        temp = add_fn_cached(a, b)
        temp = temp[0] if isinstance(temp, tuple) else temp
        return jnp.sin(temp) * jnp.sqrt(jnp.abs(temp) + 1.0)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)
    r1 = jitted_pipeline(a, b)
    r2 = jitted_pipeline(a * 2, b * 0.5)  # second call reuses compiled code
    print("JIT pipeline output shape:", r1.shape)
    print("First 5 values :", r1[:5])
5. Parallel Kernels with numba.prange#
Adding parallel=True to @numba.njit and replacing range with numba.prange
enables automatic parallelization across CPU cores using threading.
This is the easiest way to exploit multi-core CPUs without writing thread management code.
if NUMBA_AVAILABLE:
    import math  # imported outside the kernels; nopython mode does not compile import statements

    @numba.njit(parallel=True)
    def parallel_add_kernel(x, y, out):
        """Parallel element-wise add using prange."""
        for i in numba.prange(out.size):  # parallelized loop
            out[i] = x[i] + y[i]

    @numba.njit(parallel=True)
    def parallel_matvec_kernel(A, x, out):
        """Parallel matrix-vector product: rows are distributed across threads."""
        rows, cols = A.shape
        for i in numba.prange(rows):  # parallelize over rows
            total = A.dtype.type(0)
            for j in range(cols):  # inner loop stays sequential
                total += A[i, j] * x[j]
            out[i] = total

    @numba.njit(parallel=True)
    def parallel_exp_decay_kernel(trace, spikes, tau_inv, out):
        """
        Exponential trace update used in STDP:
        out[i] = trace[i] * exp(-tau_inv) + spikes[i]
        """
        decay = math.exp(-tau_inv[0])
        for i in numba.prange(out.size):
            out[i] = trace[i] * decay + spikes[i]

    print("Parallel kernels defined.")
if NUMBA_AVAILABLE:
    import os
    import time

    N = 1_000_000
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)
    serial_fn = numba_kernel(add_kernel, outs=jax.ShapeDtypeStruct((N,), jnp.float32))
    parallel_fn = numba_kernel(parallel_add_kernel, outs=jax.ShapeDtypeStruct((N,), jnp.float32))

    # Warm up
    jax.block_until_ready(serial_fn(a, b))
    jax.block_until_ready(parallel_fn(a, b))

    N_TRIALS = 20
    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(serial_fn(a, b))
    serial_time = (time.time() - t0) / N_TRIALS * 1000

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(parallel_fn(a, b))
    parallel_time = (time.time() - t0) / N_TRIALS * 1000

    n_cores = os.cpu_count()
    print(f"N = {N:,} | CPU cores: {n_cores}")
    print(f"Serial : {serial_time:.2f} ms")
    print(f"Parallel : {parallel_time:.2f} ms")
    print(f"Speedup : {serial_time / parallel_time:.2f}x")
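The kernels in this section parallelize independent iterations. Numba's prange also recognizes simple scalar reductions such as total += x[i] and combines per-thread partial results automatically. The sketch below shows a parallel sum written that way; parallel_sum_kernel and the serial fallback are assumptions added for illustration so the block still runs without Numba.

```python
import numpy as np

try:
    import numba
    parallel_njit = numba.njit(parallel=True)
    prange = numba.prange
except ImportError:
    # Assumption for this sketch only: run serially as plain Python without Numba.
    def parallel_njit(func):
        return func
    prange = range


@parallel_njit
def parallel_sum_kernel(x, out):
    """out[0] = sum(x), with the reduction variable handled by prange."""
    total = 0.0
    for i in prange(x.size):
        total += x[i]  # recognized as a reduction over the parallel loop
    out[0] = total


x = np.arange(1000, dtype=np.float64)
out = np.zeros(1, dtype=np.float64)
parallel_sum_kernel(x, out)
print(out[0])
```

Because partial sums are combined in a thread-dependent order, floating-point results of a parallel reduction may differ from the serial loop in the last few bits for general inputs.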
6. Multiple Inputs and Outputs#
Numba kernels can take any number of inputs and outputs.
The outs argument to numba_kernel mirrors the output buffers:
a single ShapeDtypeStruct for one output, a list for multiple.
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def lif_dynamics_kernel(
        V,        # membrane potentials (N,)
        I_ext,    # external current (N,)
        tau_inv,  # 1/tau_m (1,)
        V_th,     # threshold (1,)
        V_rest,   # reset potential (1,)
        dt,       # time step (1,)
        V_out,    # updated potentials (N,) – output
        spikes,   # spike vector (N,) – output
    ):
        """
        One Euler step of leaky integrate-and-fire dynamics:
        V_out[i] = V[i] + dt * (-(V[i] - V_rest) * tau_inv + I_ext[i])
        then threshold and reset.
        """
        th = V_th[0]
        vr = V_rest[0]
        ti = tau_inv[0]
        step = dt[0]
        for i in numba.prange(V.size):
            v_new = V[i] + step * (-(V[i] - vr) * ti + I_ext[i])
            if v_new >= th:
                spikes[i] = 1
                V_out[i] = vr
            else:
                spikes[i] = 0
                V_out[i] = v_new

    print("LIF dynamics kernel defined.")
if NUMBA_AVAILABLE:
    N = 5_000
    rng = np.random.default_rng(0)
    V = jnp.array(rng.uniform(-75.0, -50.0, N).astype(np.float32))
    I_ext = jnp.array(rng.uniform(0.0, 5.0, N).astype(np.float32))
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau_m = 20 ms
    V_th = jnp.array([-55.0], dtype=jnp.float32)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)
    dt = jnp.array([0.1], dtype=jnp.float32)  # dt = 0.1 ms

    lif_fn = numba_kernel(
        lif_dynamics_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # V_out
            jax.ShapeDtypeStruct((N,), jnp.int32),    # spikes
        ],
    )
    V_new, spikes = lif_fn(V, I_ext, tau_inv, V_th, V_rest, dt)
    print(f"Neurons: {N}")
    print(f"Spikes : {int(spikes.sum())} ({100*float(spikes.mean()):.1f}%)")
    print(f"V range: [{float(V_new.min()):.2f}, {float(V_new.max()):.2f}] mV")

    # Verify against JAX reference
    V_ref = V + dt[0] * (-(V - V_rest[0]) * tau_inv[0] + I_ext)
    spk_ref = (V_ref >= V_th[0]).astype(jnp.int32)
    V_ref = jnp.where(spk_ref, V_rest[0], V_ref)
    print(f"V max error : {float(jnp.max(jnp.abs(V_new - V_ref))):.6f} mV")
    print(f"Spike mismatch : {int(jnp.sum(spikes != spk_ref))}")
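To make the Euler update and the threshold/reset rule concrete, the following pure-Python sketch applies the same formula to two single neurons using the parameter values from the cell above (tau_m = 20 ms, threshold -55 mV, reset -70 mV, dt = 0.1 ms). The helper name lif_step and the two starting potentials are chosen here for illustration.

```python
def lif_step(v, i_ext, tau_inv, v_th, v_rest, dt):
    """One Euler step for a single neuron; returns (new potential, spike flag)."""
    v_new = v + dt * (-(v - v_rest) * tau_inv + i_ext)
    if v_new >= v_th:
        return v_rest, 1  # threshold crossed: emit a spike and reset
    return v_new, 0


params = dict(tau_inv=1.0 / 20.0, v_th=-55.0, v_rest=-70.0, dt=0.1)

# Starts at -56 mV: the step adds 0.1 * (-0.7 + 5.0) = 0.43 mV, still below threshold.
v_sub, spike_sub = lif_step(-56.0, 5.0, **params)

# Starts at -55.2 mV: the step adds 0.1 * (-0.74 + 5.0) = 0.426 mV, crossing -55 mV.
v_sup, spike_sup = lif_step(-55.2, 5.0, **params)

print(v_sub, spike_sub)
print(v_sup, spike_sup)
```

The first neuron ends just below threshold without spiking, while the second crosses threshold, so its potential is replaced by the reset value and its spike flag is set.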
7. Registering with XLACustomKernel#
For production use, embed your Numba kernel inside a kernel generator and
register it with XLACustomKernel. The generator receives shape/dtype
information forwarded from primitive.bind and returns the concrete callable.
if NUMBA_AVAILABLE:
    import math  # imported outside the kernel; nopython mode does not compile import statements

    # -----------------------------------------------------------------------
    # Kernel generator: exponential trace update (STDP)
    # out[i] = trace[i] * decay + spikes[i]
    # -----------------------------------------------------------------------
    def exp_trace_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]
        n = out_info.shape[0]

        @numba.njit(parallel=True)
        def trace_kern(trace, spikes, tau_inv, out):
            decay = math.exp(-tau_inv[0])
            for i in numba.prange(n):
                out[i] = trace[i] * decay + spikes[i]

        def kernel(trace, spikes, tau_inv):
            result = numba_kernel(
                trace_kern,
                outs=out_info,
            )(trace, spikes, tau_inv)
            return result

        return kernel

    # Register the primitive
    trace_op = XLACustomKernel('tutorial_numba_exp_trace')
    trace_op.def_numba_kernel(exp_trace_numba_generator)
    print("Registered backends:", list(trace_op._kernels.keys()))
if NUMBA_AVAILABLE:
    import math

    N = 1000
    trace = jnp.zeros(N, dtype=jnp.float32)
    spikes = jnp.array(np.random.default_rng(1).random(N) < 0.1,
                       dtype=jnp.float32)
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau = 20 ms
    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)

    @jax.jit
    def update_trace(trace, spikes, tau_inv):
        return trace_op(
            trace, spikes, tau_inv,
            outs=[out_spec],
        )[0]

    # Simulate 100 time steps of trace dynamics
    decay = math.exp(-float(tau_inv[0]))
    trace_history = []
    for step in range(100):
        spikes = jnp.array(
            np.random.default_rng(step).random(N) < 0.05,
            dtype=jnp.float32
        )
        trace = update_trace(trace, spikes, tau_inv)
        trace_history.append(float(trace.mean()))

    print("Trace stats after 100 steps:")
    print(f" Mean : {float(trace.mean()):.4f}")
    print(f" Max : {float(trace.max()):.4f}")
    print(f" Steady-state (theory): {0.05 / (1 - decay):.4f}")
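The theoretical value printed above follows from taking the expectation of the update out = trace * decay + spikes: with spike probability p per step, the mean obeys m[t+1] = m[t] * decay + p, whose fixed point is p / (1 - decay). Starting from zero, the mean after t steps is p * (1 - decay**t) / (1 - decay). The sketch below checks this numerically in plain Python, using the same p = 0.05 and tau = 20 as the simulation; the variable names are chosen here for illustration.

```python
import math

decay = math.exp(-1.0 / 20.0)  # same decay factor as the simulation (tau = 20)
p = 0.05                       # spike probability per step
n_steps = 100

# Iterate the recurrence for the expected trace value, starting from zero.
mean_trace = 0.0
for _ in range(n_steps):
    mean_trace = mean_trace * decay + p

steady_state = p / (1.0 - decay)
closed_form = p * (1.0 - decay ** n_steps) / (1.0 - decay)

print(mean_trace, closed_form, steady_state)
print(mean_trace / steady_state)
```

After 100 steps the expected mean has reached a fraction 1 - exp(-5), about 99.3%, of the steady-state value, so a simulated mean slightly below the theoretical figure is expected rather than a sign of error.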
8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication#
A core operation in neural network simulation: given a CSR weight matrix and a float input vector, compute the matrix-vector product.
Each output neuron (one CSR row) is a short sequential reduction that is independent of every other row, which makes the operation a good fit for row-parallel Numba on CPU.
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def csr_matvec_numba(
        data,     # CSR non-zero values (nnz,)
        indices,  # CSR column indices (nnz,)
        indptr,   # CSR row pointers (n_rows+1,)
        x,        # input vector (n_cols,)
        out,      # output vector (n_rows,)
    ):
        """
        Sparse matrix-vector product (CSR format).
        Rows are distributed across threads (parallel over rows).
        """
        n_rows = indptr.size - 1
        for i in numba.prange(n_rows):
            total = out.dtype.type(0)
            for k in range(indptr[i], indptr[i + 1]):
                total += data[k] * x[indices[k]]
            out[i] = total

    def csr_mv_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]

        def kernel(data, indices, indptr, x):
            result = numba_kernel(
                csr_matvec_numba,
                outs=out_info,
            )(data, indices, indptr, x)
            return result

        return kernel

    csr_mv_op = XLACustomKernel('tutorial_numba_csr_matvec')
    csr_mv_op.def_numba_kernel(csr_mv_numba_generator)
    print("CSR MV operator registered.")
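To make the data / indices / indptr layout concrete, here is the same row loop in pure Python on a hand-built 3 × 4 matrix. The matrix and the function name csr_matvec_reference are made up for illustration; the loop structure mirrors csr_matvec_numba above without the parallelism.

```python
def csr_matvec_reference(data, indices, indptr, x):
    """Pure-Python CSR matrix-vector product, one row at a time."""
    n_rows = len(indptr) - 1
    out = [0.0] * n_rows
    for i in range(n_rows):
        total = 0.0
        # Row i owns the slice data[indptr[i]:indptr[i + 1]].
        for k in range(indptr[i], indptr[i + 1]):
            total += data[k] * x[indices[k]]
        out[i] = total
    return out


# Dense form of the example matrix:
# [[1, 0, 2, 0],
#  [0, 0, 0, 0],
#  [0, 3, 0, 4]]
data = [1.0, 2.0, 3.0, 4.0]  # non-zero values, row by row
indices = [0, 2, 1, 3]       # column of each non-zero
indptr = [0, 2, 2, 4]        # row 1 is empty: indptr[1] == indptr[2]
x = [1.0, 2.0, 3.0, 4.0]

out = csr_matvec_reference(data, indices, indptr, x)
print(out)
```

An empty row simply produces an empty inner range, so its output stays zero, and no row ever writes to another row's output slot, which is why parallelizing the outer loop over rows is safe.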
if NUMBA_AVAILABLE:
    import scipy.sparse as sp

    N_PRE = 2000
    N_POST = 1000
    PROB = 0.05
    rng = np.random.default_rng(42)
    dense = (rng.random((N_POST, N_PRE)) < PROB).astype(np.float32)
    dense *= rng.uniform(0.01, 0.5, dense.shape).astype(np.float32)
    csr = sp.csr_matrix(dense)

    data = jnp.array(csr.data, dtype=jnp.float32)
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr = jnp.array(csr.indptr, dtype=jnp.int32)
    x = jnp.array(rng.random(N_PRE).astype(np.float32))
    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)

    result = csr_mv_op(
        data, indices, indptr, x,
        outs=[out_spec],
    )[0]
    expected = jnp.array(dense) @ x
    print(f"Network: {N_PRE} pre -> {N_POST} post (nnz={csr.nnz})")
    print(f"Max error vs dense: {float(jnp.max(jnp.abs(result - expected))):.6f}")
if NUMBA_AVAILABLE:
    import time

    @jax.jit
    def numba_csr_mv(data, indices, indptr, x):
        return csr_mv_op(
            data, indices, indptr, x,
            outs=[out_spec],
        )[0]

    @jax.jit
    def jax_dense_mv(A, x):
        return A @ x

    A_jnp = jnp.array(dense)

    # Warm up
    jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    jax.block_until_ready(jax_dense_mv(A_jnp, x))

    N_TRIALS = 50
    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    numba_time = (time.time() - t0) / N_TRIALS * 1000

    t0 = time.time()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_dense_mv(A_jnp, x))
    jax_time = (time.time() - t0) / N_TRIALS * 1000

    print(f"Numba CSR MV : {numba_time:.2f} ms")
    print(f"JAX dense MV : {jax_time:.2f} ms")
    print(f"Speedup : {jax_time / numba_time:.2f}x (sparsity: {1 - csr.nnz/(N_PRE*N_POST):.0%})")
9. Combining Numba CPU and GPU Backends#
The same XLACustomKernel primitive can have both a Numba CPU backend and a
GPU backend (Warp or Numba CUDA). JAX automatically dispatches to the correct
backend based on the device where the arrays live.
try:
    import warp
    from warp.jax_experimental import jax_kernel as warp_jax_kernel
    from brainevent import jaxinfo_to_warpinfo
    warp.config.quiet = True
    WARP_AVAILABLE = True
except ImportError:
    WARP_AVAILABLE = False

if NUMBA_AVAILABLE:
    # CPU backend (same pattern as the generators shown above)
    @numba.njit(parallel=True)
    def scale_numba(x, s, out):
        for i in numba.prange(out.size):
            out[i] = x[i] * s[0]

    def scale_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]

        def kernel(x, s):
            r = numba_kernel(scale_numba, outs=out_info)(x, s)
            return r

        return kernel

    scale_op = XLACustomKernel('tutorial_multi_backend_scale')
    scale_op.def_numba_kernel(scale_numba_generator)  # CPU backend

    if WARP_AVAILABLE:
        def scale_warp_generator(**kwargs):
            out_info = kwargs['outs'][0]
            n = out_info.shape[0]
            t = jaxinfo_to_warpinfo(out_info)
            s_type = warp.array(dtype=jaxinfo_to_warpinfo(out_info).dtype, ndim=1)

            @warp.kernel
            def kern(x: t, s: s_type, out: t):
                i = warp.tid()
                out[i] = x[i] * s[0]

            def kernel(x, s):
                fn = warp_jax_kernel(kern, launch_dims=[n], num_outputs=1,
                                     output_dims={'out': (n,)})
                return fn(x, s)

            return kernel

        scale_op.def_warp_kernel(scale_warp_generator)  # GPU backend

    print("Multi-backend scale op registered.")
    print("Backends:", {p: list(b.keys()) for p, b in scale_op._kernels.items()})

    # Use it
    N = 256
    x = jnp.arange(N, dtype=jnp.float32)
    s = jnp.array([3.14], dtype=jnp.float32)
    r = scale_op(x, s, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])[0]
    print(f"Result matches: {bool(jnp.allclose(r, x * 3.14, atol=1e-5))}")
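The dispatch idea behind a multi-backend primitive can be pictured as a registry keyed by platform. The sketch below is a deliberately simplified, hypothetical model (BackendRegistry is invented for illustration and is not how XLACustomKernel is implemented); in the real library the platform is determined by JAX from the device the arrays live on, whereas here it is passed explicitly.

```python
class BackendRegistry:
    """Hypothetical per-platform dispatch table; NOT BrainEvent's XLACustomKernel."""

    def __init__(self, name):
        self.name = name
        self._impls = {}

    def register(self, platform, fn):
        """Associate an implementation with a platform name such as 'cpu'."""
        self._impls[platform] = fn

    def __call__(self, platform, *args):
        try:
            impl = self._impls[platform]
        except KeyError:
            raise NotImplementedError(
                f"{self.name}: no backend registered for {platform!r}"
            ) from None
        return impl(*args)


scale = BackendRegistry("scale")
scale.register("cpu", lambda xs, s: [v * s for v in xs])

cpu_result = scale("cpu", [1.0, 2.0], 3.0)

# No GPU implementation was registered, so dispatching to it fails loudly.
try:
    scale("gpu", [1.0], 3.0)
    gpu_missing = False
except NotImplementedError:
    gpu_missing = True

print(cpu_result, gpu_missing)
```

The useful property of this pattern is that callers use one operator name while each platform keeps its own implementation, and a missing backend surfaces as a clear error rather than a silent fallback.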
10. Summary#
In this tutorial we covered:
@numba.njit – JIT-compile Python to native machine code. Kernel signature: kernel(input1, ..., output1, ...) – all NumPy arrays, no return values.
@numba.njit(parallel=True) + numba.prange – Multi-threaded parallelism on CPU cores with minimal code changes (one decorator argument and prange in place of range).
numba_kernel(kernel, outs=...) – Wrap a Numba kernel as a JAX-callable via XLA FFI. Returns a function compatible with jax.jit.
Multiple outputs – Pass a list of jax.ShapeDtypeStruct to outs to get multiple return arrays from a single kernel call.
XLACustomKernel.def_numba_kernel – Register a kernel generator as the CPU backend of a multi-backend custom JAX primitive.
Neuroscience applications – LIF dynamics and sparse CSR matrix-vector product implemented with parallel Numba, demonstrating realistic use cases.
Key Guidelines#
Cache the wrapped callable (do not call numba_kernel inside @jax.jit); create it once at definition time.
Use @njit(parallel=True) + prange for outer loops; keep inner loops sequential.
Prefer Numba on CPU for irregular / sparse access patterns; prefer GPU backends (Warp, Numba CUDA) for large-scale parallel workloads.
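One way to follow the caching guideline when shapes vary is to memoize wrapper construction on the shape and dtype. The sketch below uses functools.lru_cache with a stand-in builder; build_wrapper is a hypothetical placeholder for an expensive construction such as a numba_kernel(...) call, and the counter exists only to show how often the builder runs.

```python
from functools import lru_cache

build_counts = {"n": 0}  # counts how many times the expensive builder runs


def build_wrapper(n, dtype_name):
    """Hypothetical stand-in for an expensive wrapper construction."""
    build_counts["n"] += 1
    return lambda xs: [v + 1 for v in xs]


@lru_cache(maxsize=None)
def get_wrapper(n, dtype_name):
    """Return one wrapper per (n, dtype_name); repeated requests hit the cache."""
    return build_wrapper(n, dtype_name)


f1 = get_wrapper(4, "float32")
f2 = get_wrapper(4, "float32")  # same key: reuses the cached wrapper
f3 = get_wrapper(8, "float32")  # new shape: builds a second wrapper

print(f1 is f2, build_counts["n"])
```

The cache key must be hashable, which is why the sketch keys on a plain integer and a dtype name rather than on array objects.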
Next Steps#
Tutorial 6: Custom GPU operators with Warp
Tutorial 7: Custom GPU operators with Numba CUDA