Custom GPU Operators with Numba CUDA

This tutorial shows how to write custom GPU kernels using Numba CUDA and integrate them into the BrainEvent / JAX ecosystem.

Numba is a JIT compiler for Python that targets CUDA GPUs via its numba.cuda subpackage. Kernels are written in Python, compiled at first call, and run natively on the GPU. BrainEvent provides two functions that bridge Numba CUDA kernels into JAX via XLA’s Foreign Function Interface (FFI):

  • numba_cuda_kernel – wraps a single @cuda.jit kernel with fixed launch configuration.

  • numba_cuda_callable – wraps an arbitrary Python function that may launch multiple CUDA kernels, allocate temporary memory, and orchestrate multi-step GPU computation.

Contents

  1. Why Numba CUDA?

  2. Installation and Imports

  3. Writing Numba CUDA Kernels (@cuda.jit)

  4. numba_cuda_kernel – Single-Kernel Wrapper

  5. Launch Configuration: grid / block vs. launch_dims

  6. numba_cuda_callable – Multi-Kernel Wrapper

  7. Registering with XLACustomKernel

  8. Neuroscience Example: Parallel Spike Threshold Detection

  9. Performance Tips

  10. Summary

1. Why Numba CUDA?

| Feature                    | Numba CUDA                    | Warp        | Raw CUDA C++ |
|----------------------------|-------------------------------|-------------|--------------|
| Language                   | Python (@cuda.jit)            | Python-like | C++          |
| Low-level control          | Full thread/block/shared-mem  | Partial     | Full         |
| Shared memory              | Yes                           | Yes         | Yes          |
| Atomic operations          | Yes                           | Yes         | Yes          |
| Device-side allocation     | Yes (cuda.device_array)       | Limited     | Yes          |
| Multi-kernel orchestration | Yes (via numba_cuda_callable) | No          | Yes          |

Choose Numba CUDA when you need:

  • Fine-grained control over shared memory or thread synchronization

  • Multi-kernel pipelines with temporary device allocations

  • Familiar CUDA programming model in Python

Requirements:

  • NVIDIA GPU with CUDA

  • pip install numba + CUDA toolkit

  • JAX with GPU support (pip install jax[cuda12])

2. Installation and Imports

# Install if needed:
# !pip install numba -U
# !pip install brainevent[cuda12] -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, numba_cuda_kernel, numba_cuda_callable

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    import numba                      # needed later for numba.float32 in shared-memory kernels
    from numba import cuda
    NUMBA_CUDA_AVAILABLE = cuda.is_available()
    print(f"Numba version  : {numba.__version__}")
    print(f"CUDA available : {NUMBA_CUDA_AVAILABLE}")
    if NUMBA_CUDA_AVAILABLE:
        gpu = cuda.get_current_device()
        print(f"GPU            : {gpu.name}")
    else:
        print("Numba installed but CUDA device not found.")
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_CUDA_AVAILABLE = False

3. Writing Numba CUDA Kernels (@cuda.jit)

Numba CUDA kernels follow standard CUDA programming conventions:

  • Decorated with @cuda.jit

  • Each thread identifies itself via cuda.grid(ndim) (equivalent to blockIdx * blockDim + threadIdx)

  • Arrays received as Numba device arrays (zero-copy from JAX GPU memory)

  • Results are written in-place into output arrays (no return values)
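The indexing rule and the bounds check can be checked without a GPU. The sketch below is a hypothetical CPU emulation (the helpers `launch_1d` and `relu_thread` are illustrative names, not part of Numba or BrainEvent): it visits every (block, thread) pair sequentially and computes the same global index that `cuda.grid(1)` returns. With 1000 elements and 256-thread blocks, 4 blocks launch 1024 threads, so 24 threads must skip their write.

```python
import math
import numpy as np

def launch_1d(kernel, grid, block, *args):
    """Run kernel(i, *args) once per (blockIdx.x, threadIdx.x) pair, sequentially.

    i is the global index cuda.grid(1) would return:
    blockIdx.x * blockDim.x + threadIdx.x.
    """
    for block_idx in range(grid):
        for thread_idx in range(block):
            i = block_idx * block + thread_idx
            kernel(i, *args)

def relu_thread(i, x, out):
    # Same body as the @cuda.jit ReLU kernel, minus the decorator.
    if i < out.size:  # threads past the end of the array do nothing
        out[i] = max(x[i], 0.0)

n, block = 1000, 256
grid = math.ceil(n / block)  # 4 blocks -> 1024 threads, 24 of them idle
x = np.linspace(-3.0, 3.0, n, dtype=np.float32)
out = np.zeros_like(x)
launch_1d(relu_thread, grid, block, x, out)
```

Without the guard, NumPy would raise `IndexError` for i = 1000; on a GPU an unguarded write past the end of `out` is undefined behaviour and may silently corrupt memory.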

3.1 Basic Element-wise Kernel

import math  # module-level: Numba cannot compile `import` statements inside a kernel

if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def elementwise_relu_kernel(x, out):
        """Element-wise ReLU: out[i] = max(x[i], 0)."""
        i = cuda.grid(1)          # global thread index
        if i < out.size:          # bounds check
            out[i] = max(x[i], 0.0)

    @cuda.jit
    def elementwise_sigmoid_kernel(x, out):
        """Element-wise sigmoid: out[i] = 1 / (1 + exp(-x[i]))."""
        i = cuda.grid(1)
        if i < out.size:
            out[i] = 1.0 / (1.0 + math.exp(-x[i]))

    print("Kernels defined:", elementwise_relu_kernel, elementwise_sigmoid_kernel)

3.2 Kernel with Shared Memory

Shared memory is fast, on-chip memory shared between threads in the same block. It is useful for reduction operations or when threads need to communicate.

if NUMBA_CUDA_AVAILABLE:
    BLOCK_SIZE = 256

    @cuda.jit
    def block_sum_kernel(x, block_sums):
        """
        Sums the elements handled by each block with a shared-memory tree reduction:
        block_sums[blockIdx.x] = sum of the x elements covered by block blockIdx.x.
        Must be launched with exactly BLOCK_SIZE (a power of two) threads per block.
        """
        shared = cuda.shared.array(BLOCK_SIZE, dtype=numba.float32)

        tx  = cuda.threadIdx.x
        pos = cuda.grid(1)

        # Load into shared memory
        shared[tx] = x[pos] if pos < x.size else 0.0
        cuda.syncthreads()

        # Parallel reduction in shared memory
        stride = BLOCK_SIZE // 2
        while stride > 0:
            if tx < stride:
                shared[tx] += shared[tx + stride]
            cuda.syncthreads()
            stride //= 2

        # Thread 0 writes the block result
        if tx == 0:
            block_sums[cuda.blockIdx.x] = shared[0]

    print("Shared memory kernel defined.")
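To see why the loop in `block_sum_kernel` yields one partial sum per block, here is a NumPy emulation (the function `block_sum_reference` is hypothetical and only for illustration). Each row of `shared` plays the role of one block's shared array, and updating a whole row per step models the barrier that `cuda.syncthreads()` provides between steps. The stride-halving loop only reaches every element when the block size is a power of two.

```python
import numpy as np

BLOCK_SIZE = 256  # must be a power of two for the stride-halving loop

def block_sum_reference(x, block_size=BLOCK_SIZE):
    """Emulate block_sum_kernel: one partial sum per block of block_size elements."""
    n_blocks = (x.size + block_size - 1) // block_size
    # Zero-pad so every block is full, mirroring `x[pos] if pos < x.size else 0.0`.
    padded = np.zeros(n_blocks * block_size, dtype=np.float32)
    padded[: x.size] = x
    shared = padded.reshape(n_blocks, block_size)  # one "shared array" per block
    stride = block_size // 2
    while stride > 0:
        # Threads tx < stride add their partner at tx + stride. The whole step
        # finishes before the next one starts, as cuda.syncthreads() enforces.
        shared[:, :stride] += shared[:, stride : 2 * stride]
        stride //= 2
    return shared[:, 0].copy()  # what thread 0 of each block writes out

x = np.arange(1000, dtype=np.float32)
sums = block_sum_reference(x)
```

The result is still a vector of partial sums, one per block; producing a single scalar needs a second pass over `block_sums`, the kind of multi-step pipeline that `numba_cuda_callable` is meant for.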

4. numba_cuda_kernel – Single-Kernel Wrapper

numba_cuda_kernel wraps a single @cuda.jit kernel so it can be called with JAX GPU arrays. The kernel receives Numba CUDA device arrays (zero-copy from JAX device memory) and writes results into the output buffer.

Function signature:

numba_cuda_kernel(
    kernel,                    # @cuda.jit decorated function
    outs,                      # jax.ShapeDtypeStruct or list thereof
    *,
    grid=None, block=None,     # explicit CUDA launch config
    launch_dims=None,          # OR total threads (auto grid/block)
    threads_per_block=256,     # only used with launch_dims
    shared_mem=0,              # dynamic shared memory bytes
) -> callable

The kernel function signature must be:

kernel(input1, input2, ..., output1, output2, ...)

Inputs first, then outputs – all as Numba device arrays.
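The inputs-then-outputs convention can be modelled on the CPU. In this sketch `call_with_outputs` is a hypothetical stand-in for what a wrapper such as `numba_cuda_kernel` does conceptually (it is not BrainEvent's implementation): it allocates one buffer per output spec, appends the buffers after the inputs, calls the kernel, and returns the filled buffers. Plain `(shape, dtype)` tuples stand in for `jax.ShapeDtypeStruct` so only NumPy is needed. The buffers come from `np.empty`, so a kernel that skipped an element would leave garbage there; every output element should be written.

```python
import numpy as np

def call_with_outputs(kernel, outs, *inputs):
    """Hypothetical CPU model: allocate outputs, then call kernel(*inputs, *outputs)."""
    single = not isinstance(outs, list)
    specs = [outs] if single else outs
    outputs = [np.empty(shape, dtype) for shape, dtype in specs]  # uninitialized
    kernel(*inputs, *outputs)  # inputs first, then outputs; the kernel returns nothing
    return outputs[0] if single else tuple(outputs)

def split_body(x, pos_out, neg_out):
    # Vectorised stand-in for the per-thread body of split_kernel.
    pos_out[:] = np.maximum(x, 0.0)
    neg_out[:] = np.minimum(x, 0.0)

def double_body(x, out):
    out[:] = 2.0 * x

x = np.linspace(-2.0, 2.0, 8, dtype=np.float32)
pos, neg = call_with_outputs(split_body, [((8,), np.float32), ((8,), np.float32)], x)
y = call_with_outputs(double_body, ((8,), np.float32), x)  # single output spec
```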

if NUMBA_CUDA_AVAILABLE:
    N = 1024
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

    # Wrap the ReLU kernel
    relu_fn = numba_cuda_kernel(
        elementwise_relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        grid=4,
        block=256,
    )

    result  = relu_fn(x)
    expected = jnp.maximum(x, 0.0)
    print("ReLU max error:", float(jnp.max(jnp.abs(result - expected))))

    # Wrap the sigmoid kernel using launch_dims (auto grid/block)
    sigmoid_fn = numba_cuda_kernel(
        elementwise_sigmoid_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,           # launch exactly N threads
        threads_per_block=128,
    )

    result_sig  = sigmoid_fn(x)
    expected_sig = 1.0 / (1.0 + jnp.exp(-x))
    print("Sigmoid max error:", float(jnp.max(jnp.abs(result_sig - expected_sig))))

4.1 Multiple Outputs

if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def split_kernel(x, pos_out, neg_out):
        """Split positive and negative parts of x."""
        i = cuda.grid(1)
        if i < x.size:
            v = x[i]
            pos_out[i] = max(v, 0.0)
            neg_out[i] = min(v, 0.0)

    N = 512
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)

    split_fn = numba_cuda_kernel(
        split_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # pos_out
            jax.ShapeDtypeStruct((N,), jnp.float32),  # neg_out
        ],
        launch_dims=N,
    )

    pos, neg = split_fn(x)

    print("pos_out[:5]:", pos[:5])   # max(x, 0)
    print("neg_out[:5]:", neg[:5])   # min(x, 0)
    print("pos + neg == x:", bool(jnp.allclose(pos + neg, x)))

4.2 JIT Compatibility

if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def add_kernel(x, y, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] + y[i]

    N = 256
    add_fn = numba_cuda_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def jitted_add(a, b):
        return add_fn(a, b)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 2.0

    r = jitted_add(a, b)
    print("JIT add max error:", float(jnp.max(jnp.abs(r - (a + b)))))

    # Repeated calls reuse the compiled executable; compilation happens only on the first call
    for _ in range(5):
        r = jitted_add(a, b)
    print("Multiple JIT calls OK:", bool(jnp.allclose(r, a + b)))

5. Launch Configuration: grid / block vs. launch_dims

CUDA kernels need a grid/block decomposition specifying how many threads to launch.

Option A – explicit grid and block:

numba_cuda_kernel(kernel, outs=..., grid=8, block=128)
# launches 8 blocks × 128 threads = 1024 total threads

Option B – launch_dims (auto-compute):

numba_cuda_kernel(kernel, outs=..., launch_dims=1024, threads_per_block=256)
# auto: block=256, grid=ceil(1024/256)=4

2D / 3D launches:

# 2D: launch M×N threads
numba_cuda_kernel(kernel, outs=..., launch_dims=(M, N))
# auto: block=(16,16), grid=(ceil(M/16), ceil(N/16))
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def hadamard_kernel(A, B, C):
        """C[i, j] = A[i, j] * B[i, j]  (element-wise product on a 2D grid)."""
        i, j = cuda.grid(2)
        if i < C.shape[0] and j < C.shape[1]:
            C[i, j] = A[i, j] * B[i, j]

    M, N = 64, 64
    A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N)
    B = jnp.ones((M, N), dtype=jnp.float32) * 2.0

    hadamard_fn = numba_cuda_kernel(
        hadamard_kernel,
        outs=jax.ShapeDtypeStruct((M, N), jnp.float32),
        launch_dims=(M, N),      # 2D launch
    )

    C = hadamard_fn(A, B)
    print("2D kernel max error:", float(jnp.max(jnp.abs(C - A * B))))
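The automatic rules quoted in this section (1-D: `block=threads_per_block`, `grid=ceil(n / block)`; 2-D: `block=(16, 16)`, `grid=(ceil(M/16), ceil(N/16))`) fit in a few lines. `auto_launch` below is a hypothetical helper that reproduces those comments for 1-D and 2-D launches only; it is not BrainEvent's source, and 3-D launches are left out because the block shape used for them is not stated here.

```python
import math

def auto_launch(launch_dims, threads_per_block=256):
    """Hypothetical helper: (grid, block) from launch_dims, per the rules above."""
    if isinstance(launch_dims, int):
        block = threads_per_block
        grid = math.ceil(launch_dims / block)  # round up so every element gets a thread
        return grid, block
    if len(launch_dims) == 2:
        block = (16, 16)
        grid = tuple(math.ceil(d / b) for d, b in zip(launch_dims, block))
        return grid, block
    raise NotImplementedError("only 1-D and 2-D launches are modelled here")
```

Because the grid is rounded up, the launch can contain more threads than `launch_dims` requests (for example 8 × 128 = 1024 threads for 1000 elements), which is why every kernel in this tutorial keeps its bounds check.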

6. numba_cuda_callable – Multi-Kernel Wrapper

Sometimes one kernel is not enough. For example, a reduction may need two passes, or a pipeline may require a temporary device buffer between stages.

numba_cuda_callable wraps an arbitrary Python function that can:

  • Launch multiple @cuda.jit kernels

  • Allocate temporary device memory with cuda.device_array

  • Use the XLA-managed CUDA stream (passed as the last argument)

Required function signature:

def my_func(input1, input2, ..., output1, output2, ..., stream):
    # input* and output* are Numba CUDA device arrays
    # stream is a Numba CUDA stream from XLA
    ...
import math  # module-level: Numba cannot compile `import` statements inside a kernel

if NUMBA_CUDA_AVAILABLE:
    # ---- Kernel 1: element-wise square ----
    @cuda.jit
    def square_kernel(x, temp):
        i = cuda.grid(1)
        if i < temp.size:
            temp[i] = x[i] * x[i]

    # ---- Kernel 2: element-wise square root ----
    @cuda.jit
    def sqrt_kernel(temp, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = math.sqrt(temp[i])

    # ---- Multi-kernel callable: |x| = sqrt(x^2) ----
    def abs_via_two_kernels(x, out, stream):
        """
        Compute |x| using two kernels and a temporary buffer.
        Demonstrates multi-kernel pipeline with device allocation.
        """
        n = x.shape[0]
        threads = 256
        blocks  = (n + threads - 1) // threads

        # Temporary GPU buffer; Numba releases it once the array object is garbage-collected
        temp = cuda.device_array(n, dtype=np.float32)

        # Launch both kernels on the XLA-managed stream
        square_kernel[blocks, threads, stream](x, temp)
        sqrt_kernel  [blocks, threads, stream](temp, out)

    N = 512
    x = jnp.linspace(-5.0, 5.0, N, dtype=jnp.float32)

    abs_fn = numba_cuda_callable(
        abs_via_two_kernels,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    result   = abs_fn(x)
    expected = jnp.abs(x)
    print("Multi-kernel |x| max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 5 values :", result[:5])
    print("Expected |x|   :", expected[:5])
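A reduction that needs two passes follows the same shape as the callable above. Below is a CPU model of it, written with the `(inputs..., outputs..., stream)` signature that a `numba_cuda_callable` function uses, so the orchestration logic can be checked without a GPU. The name `two_pass_sum_cpu` is hypothetical; on the GPU each pass would be a kernel launched on `stream` and `temp` would come from `cuda.device_array`, while here `stream` is accepted but ignored.

```python
import numpy as np

THREADS = 256  # elements handled per "block" in pass 1

def two_pass_sum_cpu(x, out, stream=None):
    """CPU model of a two-kernel reduction with the (inputs, outputs, stream) signature.

    `stream` is unused here; on the GPU both kernels would be launched on it.
    """
    n_blocks = (x.size + THREADS - 1) // THREADS
    temp = np.zeros(n_blocks, dtype=np.float32)  # plays the role of cuda.device_array
    # Pass 1 ("kernel 1"): one partial sum per block.
    for b in range(n_blocks):
        temp[b] = x[b * THREADS:(b + 1) * THREADS].sum()
    # Pass 2 ("kernel 2"): reduce the partial sums into the single output element.
    out[0] = temp.sum()

x = np.ones(1000, dtype=np.float32)
out = np.empty(1, dtype=np.float32)
two_pass_sum_cpu(x, out, stream=None)
```

Ordering matters: pass 2 reads what pass 1 wrote, and launching both kernels on the same CUDA stream is what guarantees that order on the GPU.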

7. Registering with XLACustomKernel

For production use, register your Numba CUDA kernel as a backend of an XLACustomKernel primitive. This integrates the kernel into JAX’s lowering pipeline and allows mixing with other backends (e.g., a Numba CPU fallback).

The kernel generator pattern is the same as for Warp (see Tutorial 6): it is a Python callable that receives keyword arguments (forwarded from primitive.bind) and returns a concrete kernel function.

if NUMBA_CUDA_AVAILABLE:
    # -----------------------------------------------------------------------
    # Kernel generator for element-wise leaky ReLU:
    #   out[i] = x[i] if x[i] > 0 else alpha * x[i]
    # 'alpha' is passed at trace time via kwargs.
    # -----------------------------------------------------------------------

    def leaky_relu_numba_cuda_generator(**kwargs):
        out_info = kwargs['outs'][0]
        n        = out_info.shape[0]
        alpha    = float(kwargs.get('alpha', 0.01))

        @cuda.jit
        def leaky_relu_kern(x, out):
            i = cuda.grid(1)
            if i < out.size:
                v = x[i]
                out[i] = v if v > 0.0 else alpha * v

        def kernel(x):
            return numba_cuda_kernel(
                leaky_relu_kern,
                outs=out_info,
                launch_dims=n,
            )(x)

        return kernel

    # Register the primitive
    leaky_relu_op = XLACustomKernel('tutorial_numba_cuda_leaky_relu')
    leaky_relu_op.def_numba_cuda_kernel(leaky_relu_numba_cuda_generator)

    print("Registered:", leaky_relu_op._kernels)
if NUMBA_CUDA_AVAILABLE:
    N = 256
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

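    from functools import partial

    # alpha is read by the kernel generator at trace time, so it must be a static
    # (hashable, concrete) Python float rather than a traced JAX value.
    @partial(jax.jit, static_argnames=('alpha',))
    def jitted_leaky_relu(x, alpha=0.1):
        return leaky_relu_op(
            x,
            outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)],
            alpha=alpha,
        )[0]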

    r = jitted_leaky_relu(x, alpha=0.1)

    expected = jnp.where(x > 0, x, 0.1 * x)
    print("Leaky ReLU max error:", float(jnp.max(jnp.abs(r - expected))))
    print("Values around 0     :", r[N//2 - 3 : N//2 + 3])
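The generator pattern hinges on one detail: `alpha` is read from the keyword arguments and captured by the kernel's closure, so it is a constant of the generated kernel rather than an input array. A different `alpha` therefore means a different generated kernel, and the value must be a concrete Python number when the generator runs (a traced JAX value cannot be passed through `float()`), so a `jax.jit` wrapper has to treat `alpha` as a static argument. The CPU analogue below (`leaky_relu_generator_cpu` is a hypothetical name) shows the closure mechanics with NumPy in place of `@cuda.jit`.

```python
import numpy as np

def leaky_relu_generator_cpu(**kwargs):
    """Hypothetical CPU analogue of leaky_relu_numba_cuda_generator."""
    alpha = float(kwargs.get('alpha', 0.01))  # needs a concrete number

    def kernel(x):
        # alpha comes from the enclosing scope; it is not an argument of the kernel.
        return np.where(x > 0.0, x, alpha * x)

    return kernel

x = np.array([-1.0, 0.0, 2.0])
k_default = leaky_relu_generator_cpu()         # alpha falls back to 0.01
k_01 = leaky_relu_generator_cpu(alpha=0.1)     # a separate kernel with alpha = 0.1
```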

8. Neuroscience Example: Parallel Spike Threshold Detection

A common operation in spiking neural networks: given membrane potentials V and a threshold V_th, detect which neurons fire and reset the potentials of those neurons to V_rest.

We implement this as two Numba CUDA kernels combined into a single operator with numba_cuda_callable:

  1. Detect spikes: spikes[i] = (V[i] >= V_th)

  2. Reset potentials: V_out[i] = spikes[i] ? V_rest : V[i]
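A NumPy reference for these two steps is handy for validating the GPU output. The function name `lif_detect_and_reset` is hypothetical, and it takes scalar thresholds for simplicity, whereas the kernels below receive `V_th` and `V_rest` as one-element arrays. Step 2 depends on the spike flags produced by step 1, which is why the GPU version launches the two kernels in order on the same stream.

```python
import numpy as np

def lif_detect_and_reset(V, V_th, V_rest):
    """NumPy reference: detect spikes, then reset the neurons that fired."""
    spikes = (V >= V_th).astype(np.int32)        # step 1: threshold test
    V_out = np.where(spikes == 1, V_rest, V)     # step 2: reset spiking neurons
    return spikes, V_out.astype(V.dtype)

V = np.array([-70.0, -55.0, -54.0, -60.0], dtype=np.float32)
spikes, V_out = lif_detect_and_reset(V, V_th=-55.0, V_rest=-70.0)
```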

if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def detect_spikes_kernel(V, V_th, spikes):
        """spikes[i] = 1 if V[i] >= V_th[0] else 0."""
        i = cuda.grid(1)
        if i < V.size:
            spikes[i] = 1 if V[i] >= V_th[0] else 0

    @cuda.jit
    def reset_potential_kernel(V, spikes, V_rest, V_out):
        """V_out[i] = V_rest[0] if spikes[i] else V[i]."""
        i = cuda.grid(1)
        if i < V.size:
            V_out[i] = V_rest[0] if spikes[i] else V[i]

    def lif_step(V, V_th, V_rest, spikes_out, V_out, stream):
        """Threshold-and-reset part of a LIF step: detect spikes, then reset V."""
        n       = V.shape[0]
        threads = 256
        blocks  = (n + threads - 1) // threads

        detect_spikes_kernel[blocks, threads, stream](V, V_th, spikes_out)
        reset_potential_kernel[blocks, threads, stream](V, spikes_out, V_rest, V_out)

    print("LIF step function defined.")
if NUMBA_CUDA_AVAILABLE:
    N_NEURONS = 10_000

    rng = np.random.default_rng(0)
    V      = jnp.array(rng.uniform(-75.0, -50.0, N_NEURONS).astype(np.float32))
    V_th   = jnp.array([-55.0], dtype=jnp.float32)   # threshold (mV)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)   # reset potential (mV)

    lif_fn = numba_cuda_callable(
        lif_step,
        outs=[
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.int32),    # spikes
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.float32),  # V_out
        ],
    )

    spikes, V_out = lif_fn(V, V_th, V_rest)

    # Verify
    expected_spikes = (V >= V_th[0]).astype(jnp.int32)
    expected_V_out  = jnp.where(expected_spikes, V_rest[0], V)

    print(f"Neurons: {N_NEURONS}")
    print(f"Spikes detected: {int(spikes.sum())} / {N_NEURONS}  ({100*float(spikes.mean()):.1f}%)")
    print(f"Spike detection error: {int(jnp.sum(spikes != expected_spikes))}")
    print(f"V_out max error: {float(jnp.max(jnp.abs(V_out - expected_V_out))):.6f} mV")
if NUMBA_CUDA_AVAILABLE:
    import time

    @jax.jit
    def numba_lif_step(V, V_th, V_rest):
        return lif_fn(V, V_th, V_rest)

    @jax.jit
    def jax_lif_step(V, V_th, V_rest):
        spikes = (V >= V_th[0]).astype(jnp.int32)
        V_out  = jnp.where(spikes, V_rest[0], V)
        return spikes, V_out

    # Warm up
    jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    jax.block_until_ready(jax_lif_step(V, V_th, V_rest))

    N_TRIALS = 500

    t0 = time.perf_counter()  # monotonic, high-resolution timer
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    numba_time = (time.perf_counter() - t0) / N_TRIALS * 1000  # ms per call

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_lif_step(V, V_th, V_rest))
    jax_time = (time.perf_counter() - t0) / N_TRIALS * 1000  # ms per call

    print(f"Numba CUDA LIF step : {numba_time:.3f} ms")
    print(f"JAX native LIF step : {jax_time:.3f} ms")

9. Performance Tips

9.1 Choose the right launch configuration

  • A warp is 32 threads; use block sizes that are multiples of 32 (128, 256, 512).

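  • Very small blocks can leave an SM underused, because each SM hosts only a limited number of resident blocks; very large blocks (e.g. 1024 threads) leave fewer blocks resident per SM when threads use many registers or the block uses a lot of shared memory.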

  • Use launch_dims for simple 1-D problems; specify explicit grid/block for fine control.
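The arithmetic behind these rules is easy to check for a given problem size. `launch_stats` is a hypothetical helper (not part of Numba or BrainEvent) that rejects block sizes that are not multiples of 32 and reports the grid size, warps per block, and how many launched threads fall past the end of the data.

```python
import math

WARP_SIZE = 32

def launch_stats(n, threads_per_block):
    """Grid size, warps per block, and idle threads for a 1-D launch over n elements."""
    if threads_per_block % WARP_SIZE != 0:
        raise ValueError("threads_per_block should be a multiple of the warp size (32)")
    grid = math.ceil(n / threads_per_block)
    launched = grid * threads_per_block
    return {
        "grid": grid,
        "warps_per_block": threads_per_block // WARP_SIZE,
        "idle_threads": launched - n,  # threads that fail the bounds check
    }

stats = launch_stats(10_000, 256)
```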

9.2 Minimize thread divergence

Threads in the same warp are scheduled together (SIMT). When threads of one warp take different sides of a branch (divergence), the warp executes each path in turn with the threads that did not take that path masked off, which lowers throughput. Where possible, arrange data so threads in the same warp take the same branch.
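Divergence can be counted before writing any kernel. `divergent_warps` below is a hypothetical helper that groups per-thread branch outcomes into warps of 32 and counts the warps whose threads disagree. For random membrane potentials, the spike test from Section 8 splits almost every warp, while sorting the outcomes leaves at most one mixed warp. This is only a counting model: it ignores the cost of reordering the data and says nothing about actual runtimes.

```python
import numpy as np

WARP_SIZE = 32

def divergent_warps(branch_taken):
    """Count warps whose threads do not all take the same branch."""
    flags = np.asarray(branch_taken, dtype=bool)
    n_full = flags.size // WARP_SIZE * WARP_SIZE  # ignore a partial final warp
    warps = flags[:n_full].reshape(-1, WARP_SIZE)
    mixed = warps.any(axis=1) & ~warps.all(axis=1)
    return int(mixed.sum())

rng = np.random.default_rng(0)
V = rng.uniform(-75.0, -50.0, 4096)
taken = V >= -55.0                                # per-thread outcome of the spike test
unsorted_count = divergent_warps(taken)           # nearly all 128 warps are mixed
sorted_count = divergent_warps(np.sort(taken))    # False block, then True block
```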

9.3 Use shared memory for reuse

If multiple threads access the same data, load it into shared memory first and synchronize with cuda.syncthreads() before reading.

9.4 Cache the wrapped callable

Each call to numba_cuda_kernel / numba_cuda_callable registers a new FFI target. Create the wrapped callable once at module level and reuse it.

import math  # module-level: Numba cannot compile `import` statements inside a kernel

if NUMBA_CUDA_AVAILABLE:
    # Good pattern: create the callable once and reuse
    @cuda.jit
    def exp_decay_kernel(x, decay, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] * math.exp(-decay[0])

    N = 1024
    # Create once at definition time
    _exp_decay_fn = numba_cuda_kernel(
        exp_decay_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def apply_exp_decay(x, decay):
        return _exp_decay_fn(x, decay)

    x     = jnp.ones(N, dtype=jnp.float32)
    decay = jnp.array([0.1], dtype=jnp.float32)

    r = apply_exp_decay(x, decay)
    print("Exp decay result:", float(r[0]), "| Expected:", float(np.exp(-0.1)))
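A wrapped callable fixes its output shape when it is created (`outs=jax.ShapeDtypeStruct((N,), ...)` above), so code that sees several input sizes needs one wrapper per size. One way to keep the create-once rule is to memoize the factory on size and dtype with `functools.lru_cache`; inside `jax.jit`, `x.shape[0]` is a plain Python int, so it can serve as a cache key. In the sketch below `get_exp_decay_fn` is a hypothetical name, and its body returns a placeholder so the caching logic runs without a GPU; in practice it would return `numba_cuda_kernel(exp_decay_kernel, outs=jax.ShapeDtypeStruct((n,), jnp.dtype(dtype_name)), launch_dims=n)`.

```python
import functools

build_count = 0  # how many wrappers were actually created

@functools.lru_cache(maxsize=None)
def get_exp_decay_fn(n, dtype_name="float32"):
    """Build (once per key) the wrapped callable for length-n inputs.

    Placeholder body: the real version would return
    numba_cuda_kernel(exp_decay_kernel,
                      outs=jax.ShapeDtypeStruct((n,), jnp.dtype(dtype_name)),
                      launch_dims=n)
    """
    global build_count
    build_count += 1
    return ("wrapped-callable", n, dtype_name)  # stand-in object

f_a = get_exp_decay_fn(1024)
f_b = get_exp_decay_fn(1024)   # cache hit: no new wrapper
f_c = get_exp_decay_fn(2048)   # new size: one more wrapper
```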

10. Summary

In this tutorial we covered:

  1. @cuda.jit – Write GPU kernels in Python; use cuda.grid(ndim) for thread indices and bounds-check with if i < size.

  2. numba_cuda_kernel – Single-kernel JAX wrapper. Specify launch config via (grid, block) for explicit control or launch_dims for automatic decomposition. Supports 1-D, 2-D, and 3-D launches.

  3. numba_cuda_callable – Multi-kernel JAX wrapper. Your Python function receives Numba device arrays and the XLA-managed CUDA stream; it can launch multiple kernels and allocate temporary device memory.

  4. XLACustomKernel.def_numba_cuda_kernel – Register a kernel generator as the GPU backend of a multi-backend custom JAX primitive.

  5. Neuroscience application – LIF spike detection and reset implemented as a two-kernel callable, showing the key pattern for multi-kernel GPU operations.

Next Steps

  • Tutorial 6: Custom GPU operators with Warp

  • Tutorial 8: Custom CPU operators with Numba (@numba.njit)

References