Custom GPU Operators with Numba CUDA
This tutorial shows how to write custom GPU kernels using Numba CUDA and integrate them into the BrainEvent / JAX ecosystem.
Numba is a JIT compiler for Python; its numba.cuda subpackage compiles Python functions into CUDA kernels. Kernels are written in Python, compiled on first call, and run natively on the GPU. BrainEvent provides two functions that bridge Numba CUDA kernels
into JAX via XLA’s Foreign Function Interface (FFI):
- numba_cuda_kernel – wraps a single @cuda.jit kernel with a fixed launch configuration.
- numba_cuda_callable – wraps an arbitrary Python function that may launch multiple CUDA kernels, allocate temporary memory, and orchestrate multi-step GPU computation.
Contents

1. Why Numba CUDA?
2. Installation and Imports
3. Writing Numba CUDA Kernels (@cuda.jit)
4. numba_cuda_kernel – Single-Kernel Wrapper
5. Launch Configuration: grid / block vs. launch_dims
6. numba_cuda_callable – Multi-Kernel Wrapper
7. Registering with XLACustomKernel
8. Neuroscience Example: Parallel Spike Threshold Detection
9. Performance Tips
10. Summary
1. Why Numba CUDA?
| Feature | Numba CUDA | Warp | Raw CUDA C++ |
|---|---|---|---|
| Language | Python | Python-like | C++ |
| Low-level control | Full thread/block/shared-mem | Partial | Full |
| Shared memory | Yes | Yes | Yes |
| Atomic operations | Yes | Yes | Yes |
| Device-side allocation | Yes | Limited | Yes |
| Multi-kernel orchestration | Yes | No | Yes |
Choose Numba CUDA when you need:
- Fine-grained control over shared memory or thread synchronization
- Multi-kernel pipelines with temporary device allocations
- The familiar CUDA programming model in Python
Requirements:
- An NVIDIA GPU with CUDA
- pip install numba, plus the CUDA toolkit
- JAX with GPU support (pip install jax[cuda12])
2. Installation and Imports
```python
# Install if needed:
# !pip install numba -U
# !pip install brainevent[cuda12] -U
import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, numba_cuda_kernel, numba_cuda_callable

print(f"JAX version : {jax.__version__}")
print(f"JAX backend : {jax.default_backend()}")
print(f"BrainEvent : {brainevent.__version__}")

try:
    from numba import cuda

    NUMBA_CUDA_AVAILABLE = cuda.is_available()
    if NUMBA_CUDA_AVAILABLE:
        import numba

        print(f"Numba version : {numba.__version__}")
        print(f"CUDA available : {NUMBA_CUDA_AVAILABLE}")
        gpu = cuda.get_current_device()
        print(f"GPU : {gpu.name}")
    else:
        print("Numba installed but CUDA device not found.")
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_CUDA_AVAILABLE = False
```
3. Writing Numba CUDA Kernels (@cuda.jit)
Numba CUDA kernels follow standard CUDA programming conventions:
- They are decorated with @cuda.jit.
- Each thread identifies itself via cuda.grid(ndim) (equivalent to blockIdx * blockDim + threadIdx).
- Arrays are received as Numba device arrays (zero-copy from JAX GPU memory).
- Results are written in-place into output arrays (kernels have no return values).
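The indexing rule and the bounds check can be checked without a GPU. The sketch below is a sequential CPU model of a 1-D launch. It is not how CUDA schedules threads, and simulate_launch_1d and relu_body are hypothetical names. Each (block, thread) pair gets the index that cuda.grid(1) would return. When the launch contains more threads than elements, the extra threads must do nothing.

```python
# Sequential, CPU-only model of a 1-D CUDA launch (no GPU needed).
# Real threads run in parallel; here every "thread" is one loop iteration.

def simulate_launch_1d(kernel, grid, block, *args):
    """Call kernel(i, *args) once per thread of a grid x block launch."""
    for block_idx in range(grid):
        for thread_idx in range(block):
            # Same formula as cuda.grid(1): blockIdx.x * blockDim.x + threadIdx.x
            i = block_idx * block + thread_idx
            kernel(i, *args)


def relu_body(i, x, out):
    # Per-thread body of an element-wise ReLU kernel
    if i < len(out):  # bounds check: surplus threads must not touch memory
        out[i] = max(x[i], 0.0)


x = [-2.0, -1.0, 0.5, 3.0, -0.25]
out = [0.0] * len(x)
# 5 elements with 4 threads per block -> 2 blocks = 8 threads, 3 of them idle
grid = (len(x) + 4 - 1) // 4
simulate_launch_1d(relu_body, grid, 4, x, out)
print(out)  # [0.0, 0.0, 0.5, 3.0, 0.0]
```

Without the `if i < len(out)` guard, threads 5 to 7 would index past the end of the arrays. On a real GPU that is an out-of-bounds memory access, not a Python IndexError.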
3.1 Basic Element-wise Kernel
```python
import math

if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def elementwise_relu_kernel(x, out):
        """Element-wise ReLU: out[i] = max(x[i], 0)."""
        i = cuda.grid(1)      # global thread index
        if i < out.size:      # bounds check
            out[i] = max(x[i], 0.0)

    @cuda.jit
    def elementwise_sigmoid_kernel(x, out):
        """Element-wise sigmoid: out[i] = 1 / (1 + exp(-x[i]))."""
        i = cuda.grid(1)
        if i < out.size:
            # math is imported at module level; Numba CUDA supports math.exp
            out[i] = 1.0 / (1.0 + math.exp(-x[i]))

    print("Kernels defined:", elementwise_relu_kernel, elementwise_sigmoid_kernel)
```
4. numba_cuda_kernel – Single-Kernel Wrapper
numba_cuda_kernel wraps a single @cuda.jit kernel so it can be called with
JAX GPU arrays. The kernel receives Numba CUDA device arrays (zero-copy from JAX
device memory) and writes results into the output buffer.
Function signature:
```python
numba_cuda_kernel(
    kernel,                  # @cuda.jit decorated function
    outs,                    # jax.ShapeDtypeStruct or list thereof
    *,
    grid=None, block=None,   # explicit CUDA launch config
    launch_dims=None,        # OR total threads (auto grid/block)
    threads_per_block=256,   # only used with launch_dims
    shared_mem=0,            # dynamic shared memory bytes
) -> callable
```
The kernel function signature must be:
```python
kernel(input1, input2, ..., output1, output2, ...)
```
Inputs come first, then outputs, all as Numba device arrays.
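The calling convention can be illustrated with a small CPU-only model. It is not BrainEvent's implementation and ignores how buffers travel through XLA FFI. The model assumes three things: one buffer is allocated per entry in outs, the kernel receives the inputs followed by the outputs, and the wrapper returns the buffers the kernel filled in place. call_like_numba_cuda_kernel and split_body are hypothetical names.

```python
# Hypothetical, CPU-only model of the "inputs first, then outputs" convention.
# NOT BrainEvent's implementation: it only mirrors the argument order and the
# rule that kernels write results into pre-allocated outputs.

def call_like_numba_cuda_kernel(kernel, out_sizes, *inputs):
    outputs = [[0.0] * n for n in out_sizes]  # one buffer per declared output
    kernel(*inputs, *outputs)                 # the kernel itself returns nothing
    return outputs[0] if len(outputs) == 1 else tuple(outputs)


def split_body(x, pos_out, neg_out):
    # Sequential stand-in for a two-output kernel (a loop instead of threads)
    for i in range(len(x)):
        pos_out[i] = max(x[i], 0.0)
        neg_out[i] = min(x[i], 0.0)


pos, neg = call_like_numba_cuda_kernel(split_body, [3, 3], [-1.0, 0.0, 2.0])
print(pos, neg)  # [0.0, 0.0, 2.0] [-1.0, 0.0, 0.0]
```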
```python
if NUMBA_CUDA_AVAILABLE:
    N = 1024
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

    # Wrap the ReLU kernel with an explicit launch config: 4 blocks x 256 threads = N
    relu_fn = numba_cuda_kernel(
        elementwise_relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        grid=4,
        block=256,
    )
    result = relu_fn(x)
    expected = jnp.maximum(x, 0.0)
    print("ReLU max error:", float(jnp.max(jnp.abs(result - expected))))

    # Wrap the sigmoid kernel using launch_dims (auto grid/block)
    sigmoid_fn = numba_cuda_kernel(
        elementwise_sigmoid_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,          # N threads in total (here 8 blocks x 128)
        threads_per_block=128,
    )
    result_sig = sigmoid_fn(x)
    expected_sig = 1.0 / (1.0 + jnp.exp(-x))
    print("Sigmoid max error:", float(jnp.max(jnp.abs(result_sig - expected_sig))))
```
4.1 Multiple Outputs
```python
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def split_kernel(x, pos_out, neg_out):
        """Split positive and negative parts of x."""
        i = cuda.grid(1)
        if i < x.size:
            v = x[i]
            pos_out[i] = max(v, 0.0)
            neg_out[i] = min(v, 0.0)

    N = 512
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)
    split_fn = numba_cuda_kernel(
        split_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # pos_out
            jax.ShapeDtypeStruct((N,), jnp.float32),  # neg_out
        ],
        launch_dims=N,
    )
    pos, neg = split_fn(x)
    print("pos_out[:5]:", pos[:5])  # max(x, 0)
    print("neg_out[:5]:", neg[:5])  # min(x, 0)
    print("pos + neg == x:", bool(jnp.allclose(pos + neg, x)))
```
4.2 JIT Compatibility
```python
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def add_kernel(x, y, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] + y[i]

    N = 256
    add_fn = numba_cuda_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def jitted_add(a, b):
        return add_fn(a, b)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 2.0
    r = jitted_add(a, b)
    print("JIT add max error:", float(jnp.max(jnp.abs(r - (a + b)))))

    # Call multiple times (tracing and compilation happen only on the first call)
    for _ in range(5):
        r = jitted_add(a, b)
    print("Multiple JIT calls OK:", bool(jnp.allclose(r, a + b)))
```
5. Launch Configuration: grid / block vs. launch_dims
CUDA kernels need a grid/block decomposition specifying how many threads to launch.
Option A – explicit grid and block:
```python
numba_cuda_kernel(kernel, outs=..., grid=8, block=128)
# launches 8 blocks × 128 threads = 1024 total threads
```
Option B – launch_dims (auto-compute):
```python
numba_cuda_kernel(kernel, outs=..., launch_dims=1024, threads_per_block=256)
# auto: block=256, grid=ceil(1024/256)=4
```
2D / 3D launches:
```python
# 2D: launch M×N threads
numba_cuda_kernel(kernel, outs=..., launch_dims=(M, N))
# auto: block=(16,16), grid=(ceil(M/16), ceil(N/16))
```
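The auto-decomposition rules stated above reduce to ceiling division. The following is a hypothetical helper, auto_launch_config, that reproduces those 1-D and 2-D rules so the resulting grid can be checked by hand. BrainEvent's internal logic may differ in details, and the 3-D case is not modelled.

```python
# Hypothetical re-implementation of the documented launch_dims rules:
#   1-D: block = threads_per_block, grid = ceil(n / threads_per_block)
#   2-D: block = (16, 16),          grid = (ceil(M / 16), ceil(N / 16))
# Integer ceiling division avoids floating-point rounding.

def auto_launch_config(launch_dims, threads_per_block=256):
    if isinstance(launch_dims, int):
        block = threads_per_block
        grid = (launch_dims + block - 1) // block
        return grid, block
    if len(launch_dims) == 2:
        block = (16, 16)
        grid = tuple((d + b - 1) // b for d, b in zip(launch_dims, block))
        return grid, block
    raise NotImplementedError("the 3-D rule is not covered by this sketch")


print(auto_launch_config(1024))        # (4, 256)
print(auto_launch_config(1000, 128))   # (8, 128): 1024 threads, 24 idle
print(auto_launch_config((64, 64)))    # ((4, 4), (16, 16))
print(auto_launch_config((100, 30)))   # ((7, 2), (16, 16))
```

Whenever the problem size is not a multiple of the block size, some threads are idle. This is why every kernel in this tutorial starts with a bounds check.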
```python
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def hadamard_kernel(A, B, C):
        """C[i,j] = A[i,j] * B[i,j] (element-wise product, 2D grid)."""
        i, j = cuda.grid(2)
        if i < C.shape[0] and j < C.shape[1]:
            C[i, j] = A[i, j] * B[i, j]

    M, N = 64, 64
    A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N)
    B = jnp.ones((M, N), dtype=jnp.float32) * 2.0
    hadamard_fn = numba_cuda_kernel(
        hadamard_kernel,
        outs=jax.ShapeDtypeStruct((M, N), jnp.float32),
        launch_dims=(M, N),  # 2D launch
    )
    C = hadamard_fn(A, B)
    print("2D kernel max error:", float(jnp.max(jnp.abs(C - A * B))))
```
6. numba_cuda_callable – Multi-Kernel Wrapper
Sometimes one kernel is not enough. For example, a reduction may need two passes, or a pipeline may require a temporary device buffer between stages.
numba_cuda_callable wraps an arbitrary Python function that can:
- Launch multiple @cuda.jit kernels
- Allocate temporary device memory with cuda.device_array
- Use the XLA-managed CUDA stream (passed as the last argument)
Required function signature:
```python
def my_func(input1, input2, ..., output1, output2, ..., stream):
    # input* and output* are Numba CUDA device arrays
    # stream is a Numba CUDA stream from XLA
    ...
```
```python
import math

if NUMBA_CUDA_AVAILABLE:
    # ---- Kernel 1: element-wise square ----
    @cuda.jit
    def square_kernel(x, temp):
        i = cuda.grid(1)
        if i < temp.size:
            temp[i] = x[i] * x[i]

    # ---- Kernel 2: element-wise square root ----
    @cuda.jit
    def sqrt_kernel(temp, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = math.sqrt(temp[i])

    # ---- Multi-kernel callable: |x| = sqrt(x^2) ----
    def abs_via_two_kernels(x, out, stream):
        """
        Compute |x| using two kernels and a temporary buffer.
        Demonstrates a multi-kernel pipeline with device allocation.
        """
        n = x.shape[0]
        threads = 256
        blocks = (n + threads - 1) // threads
        # Temporary GPU buffer; Numba frees it (deferred) once the last reference is dropped
        temp = cuda.device_array(n, dtype=np.float32)
        # Launch both kernels on the XLA-managed stream; stream ordering
        # guarantees sqrt_kernel sees the values written by square_kernel
        square_kernel[blocks, threads, stream](x, temp)
        sqrt_kernel[blocks, threads, stream](temp, out)

    N = 512
    x = jnp.linspace(-5.0, 5.0, N, dtype=jnp.float32)
    abs_fn = numba_cuda_callable(
        abs_via_two_kernels,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )
    result = abs_fn(x)
    expected = jnp.abs(x)
    print("Multi-kernel |x| max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 5 values :", result[:5])
    print("Expected |x| :", expected[:5])
```
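The opening of this section noted that a reduction may need two passes. The sketch below models that data flow on the CPU, assuming one partial sum per block. Pass 1 writes per-block partial sums into a temporary buffer, which is the role cuda.device_array plays above. Pass 2 reduces the partial sums into the final output. In a real numba_cuda_callable, each pass would be its own @cuda.jit kernel launched on the same stream. two_pass_sum is a hypothetical name.

```python
# CPU-only model of a two-pass sum reduction (no GPU required).
# Pass 1 ~ kernel where each block sums its slice into temp[block_idx].
# Pass 2 ~ second kernel that sums temp into out[0].

def two_pass_sum(x, threads_per_block=4):
    n = len(x)
    blocks = (n + threads_per_block - 1) // threads_per_block
    temp = [0.0] * blocks  # temporary buffer between the two passes

    # Pass 1: per-block partial sums (the last block may be partly empty)
    for b in range(blocks):
        start = b * threads_per_block
        temp[b] = sum(x[start:start + threads_per_block])

    # Pass 2: reduce the partials. On a GPU, launching this kernel on the
    # same stream guarantees it starts only after pass 1 has finished.
    out = [sum(temp)]
    return out, temp


out, partials = two_pass_sum([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
print(partials)  # [10.0, 26.0, 19.0]
print(out)       # [55.0]
```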
7. Registering with XLACustomKernel
For production use, register your Numba CUDA kernel as a backend of an
XLACustomKernel primitive. This integrates the kernel into JAX’s lowering
pipeline and allows mixing with other backends (e.g., a Numba CPU fallback).
The kernel generator pattern is the same as for Warp (see Tutorial 6):
it is a Python callable that receives keyword arguments (forwarded from
primitive.bind) and returns a concrete kernel function.
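Stripped of any GPU code, the generator pattern is a factory that reads its keyword arguments once and returns a function specialized on them. The GPU-free sketch below bakes alpha into a closure. The Numba CUDA generator that follows does the same when its kernel captures alpha, because Numba treats closure variables as compile-time constants. leaky_relu_generator is a hypothetical name.

```python
# GPU-free sketch of the kernel-generator pattern: keyword arguments are
# read at generation (trace) time and frozen into the returned function.

def leaky_relu_generator(**kwargs):
    alpha = float(kwargs.get("alpha", 0.01))  # read once, when the kernel is generated

    def kernel(x):
        return [v if v > 0.0 else alpha * v for v in x]

    return kernel


k1 = leaky_relu_generator(alpha=0.1)
k2 = leaky_relu_generator()  # falls back to alpha=0.01
print(k1([-2.0, 3.0]))  # [-0.2, 3.0]
print(k2([-2.0, 3.0]))  # [-0.02, 3.0]
```

Each distinct alpha yields a distinct function (and, in the Numba version, a distinct compiled kernel). This is why such parameters must be plain Python values at trace time rather than traced JAX arrays.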
```python
if NUMBA_CUDA_AVAILABLE:
    # -----------------------------------------------------------------------
    # Kernel generator for element-wise leaky ReLU:
    #   out[i] = x[i] if x[i] > 0 else alpha * x[i]
    # 'alpha' is passed at trace time via kwargs.
    # -----------------------------------------------------------------------
    def leaky_relu_numba_cuda_generator(**kwargs):
        out_info = kwargs['outs'][0]
        n = out_info.shape[0]
        alpha = float(kwargs.get('alpha', 0.01))

        @cuda.jit
        def leaky_relu_kern(x, out):
            i = cuda.grid(1)
            if i < out.size:
                v = x[i]
                out[i] = v if v > 0.0 else alpha * v

        def kernel(x):
            return numba_cuda_kernel(
                leaky_relu_kern,
                outs=out_info,
                launch_dims=n,
            )(x)

        return kernel

    # Register the primitive
    leaky_relu_op = XLACustomKernel('tutorial_numba_cuda_leaky_relu')
    leaky_relu_op.def_numba_cuda_kernel(leaky_relu_numba_cuda_generator)
    print("Registered:", leaky_relu_op._kernels)
```
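```python
import functools

if NUMBA_CUDA_AVAILABLE:
    N = 256
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

    # alpha must stay a Python float: jax.jit traces keyword arguments by
    # default, and the generator calls float(alpha), which fails on a tracer.
    @functools.partial(jax.jit, static_argnames='alpha')
    def jitted_leaky_relu(x, alpha=0.1):
        return leaky_relu_op(
            x,
            outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)],
            alpha=alpha,
        )[0]

    r = jitted_leaky_relu(x, alpha=0.1)
    expected = jnp.where(x > 0, x, 0.1 * x)
    print("Leaky ReLU max error:", float(jnp.max(jnp.abs(r - expected))))
    print("Values around 0 :", r[N//2 - 3 : N//2 + 3])
```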
8. Neuroscience Example: Parallel Spike Threshold Detection
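A common operation in spiking neural networks is threshold detection. Given membrane potentials V and a threshold V_th, it detects which neurons fire and resets the potentials of those that did. JAX arrays are immutable, so the reset potentials are written to a new output array V_out rather than updated in place.
We implement this as two Numba CUDA kernels launched back-to-back on the same stream inside one numba_cuda_callable:
1. Detect spikes: spikes[i] = (V[i] >= V_th)
2. Reset potentials: V_out[i] = spikes[i] ? V_rest : V[i]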
```python
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def detect_spikes_kernel(V, V_th, spikes):
        """spikes[i] = 1 if V[i] >= V_th[0] else 0."""
        i = cuda.grid(1)
        if i < V.size:
            spikes[i] = 1 if V[i] >= V_th[0] else 0

    @cuda.jit
    def reset_potential_kernel(V, spikes, V_rest, V_out):
        """V_out[i] = V_rest[0] if spikes[i] else V[i]."""
        i = cuda.grid(1)
        if i < V.size:
            V_out[i] = V_rest[0] if spikes[i] else V[i]

    def lif_step(V, V_th, V_rest, spikes_out, V_out, stream):
        """Threshold-and-reset part of one LIF step."""
        n = V.shape[0]
        threads = 256
        blocks = (n + threads - 1) // threads
        # Same stream: the reset kernel reads spikes_out after detection writes it
        detect_spikes_kernel[blocks, threads, stream](V, V_th, spikes_out)
        reset_potential_kernel[blocks, threads, stream](V, spikes_out, V_rest, V_out)

    print("LIF step function defined.")
```
```python
if NUMBA_CUDA_AVAILABLE:
    N_NEURONS = 10_000
    rng = np.random.default_rng(0)
    V = jnp.array(rng.uniform(-75.0, -50.0, N_NEURONS).astype(np.float32))
    V_th = jnp.array([-55.0], dtype=jnp.float32)    # threshold (mV)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)  # reset potential (mV)

    lif_fn = numba_cuda_callable(
        lif_step,
        outs=[
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.int32),    # spikes
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.float32),  # V_out
        ],
    )
    spikes, V_out = lif_fn(V, V_th, V_rest)

    # Verify
    expected_spikes = (V >= V_th[0]).astype(jnp.int32)
    expected_V_out = jnp.where(expected_spikes, V_rest[0], V)
    print(f"Neurons: {N_NEURONS}")
    print(f"Spikes detected: {int(spikes.sum())} / {N_NEURONS} ({100*float(spikes.mean()):.1f}%)")
    print(f"Spike detection error: {int(jnp.sum(spikes != expected_spikes))}")
    print(f"V_out max error: {float(jnp.max(jnp.abs(V_out - expected_V_out))):.6f} mV")
```
```python
if NUMBA_CUDA_AVAILABLE:
    import time

    @jax.jit
    def numba_lif_step(V, V_th, V_rest):
        return lif_fn(V, V_th, V_rest)

    @jax.jit
    def jax_lif_step(V, V_th, V_rest):
        spikes = (V >= V_th[0]).astype(jnp.int32)
        V_out = jnp.where(spikes, V_rest[0], V)
        return spikes, V_out

    # Warm up (triggers compilation so it is excluded from the timings)
    jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    jax.block_until_ready(jax_lif_step(V, V_th, V_rest))

    N_TRIALS = 500
    # perf_counter is a monotonic, high-resolution clock suited to benchmarking
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    numba_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_lif_step(V, V_th, V_rest))
    jax_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    print(f"Numba CUDA LIF step : {numba_time:.3f} ms")
    print(f"JAX native LIF step : {jax_time:.3f} ms")
```
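The benchmark above reports a mean over a loop of calls. The sketch below is a small reusable timing helper. It uses time.perf_counter (a monotonic clock) and reports the median of per-call samples, which is less sensitive to one-off stalls than the mean. time_call_ms is a hypothetical name, and the stand-in workload only keeps the example runnable without a GPU.

```python
import statistics
import time

# Hypothetical helper: median wall-clock time per call, in milliseconds.
# For JAX work, pass a lambda that ends in jax.block_until_ready(...) so that
# asynchronous dispatch is not mistaken for finished computation.

def time_call_ms(fn, n_trials=100, n_warmup=3):
    for _ in range(n_warmup):  # exclude compilation / first-call overhead
        fn()
    samples = []
    for _ in range(n_trials):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return statistics.median(samples)


# CPU-only stand-in workload
print(f"{time_call_ms(lambda: sum(range(10_000)), n_trials=20):.4f} ms")
```

With the functions defined above, a call such as time_call_ms(lambda: jax.block_until_ready(numba_lif_step(V, V_th, V_rest))) would time the Numba CUDA path.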
9. Performance Tips
9.1 Choose the right launch configuration
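- A warp is 32 threads. Use block sizes that are multiples of 32 (128, 256, 512) so that no warp runs partially empty.
- Very small blocks can hit the per-SM limit on resident blocks before the SM's thread capacity is filled, which lowers occupancy. Very large blocks need more registers and shared memory per block, which can limit how many blocks fit on an SM.
- Use launch_dims for simple 1-D problems; specify explicit grid/block for fine control.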
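The block-size guidance can be made concrete with a deliberately simplified occupancy estimate. The sketch assumes per-SM limits of 2048 resident threads and 32 resident blocks, which match some data-center GPUs such as the A100 but differ on other architectures. It also ignores registers and shared memory, which CUDA's occupancy calculator does account for. Both function names are hypothetical.

```python
# Simplified occupancy model: only the per-SM thread and block limits are
# considered (defaults are assumptions; check your device's real limits).

WARP = 32

def theoretical_occupancy(block_size, max_threads_per_sm=2048, max_blocks_per_sm=32):
    warps_per_block = (block_size + WARP - 1) // WARP  # a partial warp still costs a full warp
    max_warps = max_threads_per_sm // WARP
    resident_blocks = min(max_blocks_per_sm, max_warps // warps_per_block)
    return resident_blocks * warps_per_block / max_warps


def lane_utilization(block_size):
    # Fraction of launched warp lanes that correspond to real threads
    warps_per_block = (block_size + WARP - 1) // WARP
    return block_size / (warps_per_block * WARP)


for bs in (32, 64, 96, 256, 1024):
    print(bs, theoretical_occupancy(bs))
# 32 -> 0.5 (block limit reached first), 64 -> 1.0, 96 -> 0.984375,
# 256 -> 1.0, 1024 -> 1.0
print(lane_utilization(48), lane_utilization(256))  # 0.75 1.0
```

Under these assumptions, 32-thread blocks cap occupancy at 50% because the 32-block limit is reached first. A 48-thread block wastes a quarter of its second warp.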
9.2 Minimize thread divergence
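Threads in the same warp execute in lock-step. When a conditional branch goes different ways for different threads of one warp (divergence), the warp executes each path in turn with the inactive threads masked off, so a divergent branch costs roughly the sum of both paths. Where possible, arrange data so threads in the same warp take the same branch.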
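The effect of data arrangement can be shown by counting divergent warps. The counting model below assumes one thread per element and a single `if flag[i]: ... else: ...` branch. It compares a random layout with the same flags grouped by value. divergent_warps is a hypothetical name. Reordering data this way is only worthwhile if results can be mapped back cheaply, for example when neurons of the same type are stored contiguously.

```python
import random

# Counting model: a warp "diverges" when its 32 threads do not all take the
# same side of `if flag[i]: ... else: ...`.

WARP = 32

def divergent_warps(flags):
    count = 0
    for start in range(0, len(flags), WARP):
        warp = flags[start:start + WARP]
        if any(warp) and not all(warp):
            count += 1
    return count


random.seed(0)
flags = [random.random() < 0.5 for _ in range(1024)]  # 32 warps, random branch choice
print(divergent_warps(flags))          # typically every warp diverges
print(divergent_warps(sorted(flags)))  # at most 1: only the warp at the boundary is mixed
```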
9.3 Cache the wrapped callable
Each call to numba_cuda_kernel / numba_cuda_callable registers a new FFI target.
Create the wrapped callable once at module level and reuse it.
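When the array size is not known at module import time, one option is to memoize wrapper construction per configuration. The sketch below uses functools.lru_cache so that each distinct (size, dtype) builds its wrapper once. make_wrapper and get_wrapper are hypothetical. make_wrapper stands in for a numba_cuda_kernel(...) call and is replaced by a counting stub so the sketch runs without a GPU.

```python
import functools

# Hypothetical memoization pattern. In real code, make_wrapper would call e.g.
#   numba_cuda_kernel(exp_decay_kernel,
#                     outs=jax.ShapeDtypeStruct((n,), dtype), launch_dims=n)
# Here it is a counting stub so the example runs anywhere.

build_count = 0

def make_wrapper(n, dtype):
    global build_count
    build_count += 1
    return lambda x: [v * 2 for v in x]  # stub "kernel" for demonstration


@functools.lru_cache(maxsize=None)
def get_wrapper(n, dtype="float32"):
    return make_wrapper(n, dtype)


for _ in range(3):
    get_wrapper(1024)  # built once, then served from the cache
get_wrapper(2048)      # new size -> one more build
print(build_count)     # 2
```

Inside a jax.jit-compiled function, x.shape[0] and x.dtype are static Python values at trace time. Calling get_wrapper(x.shape[0], str(x.dtype)) there would therefore build at most one wrapper per shape.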
```python
import math

if NUMBA_CUDA_AVAILABLE:
    # Good pattern: create the callable once and reuse it
    @cuda.jit
    def exp_decay_kernel(x, decay, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] * math.exp(-decay[0])

    N = 1024
    # Create once at definition time
    _exp_decay_fn = numba_cuda_kernel(
        exp_decay_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def apply_exp_decay(x, decay):
        return _exp_decay_fn(x, decay)

    x = jnp.ones(N, dtype=jnp.float32)
    decay = jnp.array([0.1], dtype=jnp.float32)
    r = apply_exp_decay(x, decay)
    print("Exp decay result:", float(r[0]), "| Expected:", float(np.exp(-0.1)))
```
10. Summary
In this tutorial we covered:
- @cuda.jit – Write GPU kernels in Python; use cuda.grid(ndim) for thread indices and bounds-check with if i < size.
- numba_cuda_kernel – Single-kernel JAX wrapper. Specify the launch config via (grid, block) for explicit control or launch_dims for automatic decomposition. Supports 1-D, 2-D, and 3-D launches.
- numba_cuda_callable – Multi-kernel JAX wrapper. Your Python function receives Numba device arrays and the XLA-managed CUDA stream; it can launch multiple kernels and allocate temporary device memory.
- XLACustomKernel.def_numba_cuda_kernel – Register a kernel generator as the GPU backend of a multi-backend custom JAX primitive.
- Neuroscience application – LIF spike detection and reset implemented as a two-kernel callable, showing the key pattern for multi-stage GPU operations on a single stream.
Next Steps
- Tutorial 6: Custom GPU operators with Warp
- Tutorial 8: Custom CPU operators with Numba (@numba.njit)