Custom GPU Operators with Warp
This tutorial shows how to write custom GPU kernels using NVIDIA Warp and integrate them into the BrainEvent / JAX ecosystem.
NVIDIA Warp is a Python framework for high-performance
GPU kernel authoring. Kernels are written in Python-like syntax, JIT-compiled to CUDA PTX,
and can be called seamlessly from JAX via warp.jax_experimental.jax_kernel.
Contents
1. Why Warp?
2. Installation and Imports
3. Writing Your First Warp Kernel (including calling Warp kernels from JAX)
4. Type Annotations – jaxinfo_to_warpinfo / jaxtype_to_warptype
5. Kernel Generators – Dynamic Kernel Construction
6. In-place (Accumulation) vs. Pure-output Patterns
7. Registering Kernels with XLACustomKernel
8. Neuroscience Example: Sparse Synaptic Input Accumulation
9. Multiple Backends with XLACustomKernel
10. Summary
1. Why Warp?
| Feature | Warp | Raw CUDA C++ |
|---|---|---|
| Language | Python-like syntax | C++ |
| Compilation | Automatic JIT | Manual |
| JAX integration | Built-in | Manual XLA FFI |
| Autodiff | Limited (scalar ops) | Manual |
| Best for | Custom GPU ops in Python | Maximum control |
Warp is the recommended path when you want GPU acceleration without leaving Python.
BrainEvent’s XLACustomKernel infrastructure makes it trivial to register a Warp kernel
as a backend for any custom JAX primitive.
Requirements:
NVIDIA GPU with CUDA
warp-lang (pip install warp-lang; imported as import warp)
JAX with GPU support (pip install jax[cuda12])
2. Installation and Imports
# Install if needed
# !pip install warp-lang -U
# !pip install brainevent[cuda12] -U
import jax
import jax.numpy as jnp
import numpy as np
import brainevent
from brainevent import XLACustomKernel, jaxinfo_to_warpinfo, jaxtype_to_warptype
print(f"JAX version : {jax.__version__}")
print(f"JAX backend : {jax.default_backend()}")
print(f"BrainEvent : {brainevent.__version__}")
try:
    import warp
    from warp.jax_experimental import jax_kernel
    warp.config.quiet = True
    print(f"Warp version : {warp.__version__}")
    WARP_AVAILABLE = True
except ImportError:
    print("Warp not installed. Run: pip install warp-lang")
    WARP_AVAILABLE = False
3. Writing Your First Warp Kernel
A Warp kernel is a Python function decorated with @warp.kernel. Key rules:
No Python data structures: kernels operate only on Warp scalars and arrays.
The thread index is obtained via warp.tid(), which replaces blockIdx.x * blockDim.x + threadIdx.x in CUDA C.
Array arguments must be annotated using warp.array(dtype=..., ndim=...).
The kernel body runs once per thread, so you typically launch one thread per element.
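The one-thread-per-element rule can be emulated serially in plain NumPy. The names launch and relu_body are hypothetical and are not part of Warp. launch calls the kernel body once for each thread index, which is the value warp.tid() would return on the GPU. On the GPU those calls run in parallel instead of in a loop.

```python
import numpy as np


def launch(kernel_body, dim, *arrays):
    """Hypothetical serial stand-in for a GPU launch: one call per thread id."""
    for tid in range(dim):
        kernel_body(tid, *arrays)


def relu_body(tid, x, out):
    # Each "thread" reads and writes only its own element, so no two
    # threads touch the same output slot and no synchronization is needed.
    out[tid] = max(x[tid], np.float32(0.0))


x = np.linspace(-2.0, 2.0, 9, dtype=np.float32)
out = np.empty_like(x)
launch(relu_body, x.size, x, out)
print(out)
```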
3.1 Element-wise ReLU
if WARP_AVAILABLE:
    @warp.kernel
    def relu_kernel(
        x: warp.array(dtype=warp.float32, ndim=1),
        out: warp.array(dtype=warp.float32, ndim=1),
    ):
        i = warp.tid()  # thread index = element index
        out[i] = warp.max(x[i], warp.float32(0.0))
    print("relu_kernel defined successfully")
    print(f"Kernel type: {type(relu_kernel)}")
3.2 Calling the Kernel via jax_kernel
jax_kernel wraps a Warp kernel so it can be called with JAX arrays.
Signature:
fn = jax_kernel(
warp_kernel,
launch_dims=[n], # total threads to launch per dimension
num_outputs=1, # how many output arrays the kernel writes
output_dims={'out': (n,)} # shape of each output (allocated by Warp)
)
(result,) = fn(x)  # pass only input arrays; outputs come back as a tuple
There are two output modes:
output_dims: Warp allocates the output buffer; you pass only the inputs.
in_out_argnames: you pass a pre-allocated buffer (e.g., jnp.zeros(...)) after the inputs, and Warp writes into it.
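The two modes differ only in who supplies the output buffer. The mock below is a hypothetical NumPy stand-in that mirrors those calling conventions. make_wrapper is not part of Warp, and it runs the threads serially. In output_dims mode you pass only the inputs. In in_out_argnames mode you pass the initial buffer after them.

```python
import numpy as np


def make_wrapper(body, launch_dim, output_shape=None, in_out=False):
    """Hypothetical mock of the two calling conventions (not Warp's jax_kernel)."""

    def call(*args):
        if in_out:
            # in-out mode: the last positional argument is the caller's buffer.
            *inputs, buf = args
            out = np.array(buf, dtype=np.float32, copy=True)  # caller's array untouched
        else:
            # output_dims mode: the wrapper allocates the buffer itself; np.empty
            # stresses that its initial contents must not be relied upon.
            inputs = args
            out = np.empty(output_shape, dtype=np.float32)
        for tid in range(launch_dim):  # serial stand-in for the parallel launch
            body(tid, *inputs, out)
        return (out,)

    return call


def relu_body(tid, x, out):
    out[tid] = max(x[tid], 0.0)


def accumulate_body(tid, x, acc):
    acc[tid] += x[tid]


x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)
relu = make_wrapper(relu_body, launch_dim=4, output_shape=(4,))
(r,) = relu(x)  # pass inputs only
accumulate = make_wrapper(accumulate_body, launch_dim=4, in_out=True)
(a,) = accumulate(x, np.ones(4, dtype=np.float32))  # inputs, then the initial buffer
```

The in-out branch copies the caller's buffer. This mimics JAX's functional semantics: input arrays are immutable, and the result comes back as a new array.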
if WARP_AVAILABLE:
    N = 1024
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)
    # Build the JAX-callable wrapper
    relu_fn = jax_kernel(
        relu_kernel,
        launch_dims=[N],
        num_outputs=1,
        output_dims={'out': (N,)},
    )
    # Call it: returns a tuple of output arrays
    (result,) = relu_fn(x)
    # Verify against the JAX reference
    expected = jnp.maximum(x, 0.0)
    print("Max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 8 values:", result[:8])
4. Type Annotations – jaxinfo_to_warpinfo / jaxtype_to_warptype
When embedding a Warp kernel inside a kernel generator (a function that receives shape/dtype information at trace time), you need to create the Warp type annotations dynamically. BrainEvent provides two helpers:
from brainevent import jaxinfo_to_warpinfo, jaxtype_to_warptype
# Convert jax.ShapeDtypeStruct -> warp.array(dtype=..., ndim=...)
warp_arr_type = jaxinfo_to_warpinfo(jax.ShapeDtypeStruct((1024,), jnp.float32))
# Convert numpy/JAX dtype -> warp scalar type
warp_scalar_type = jaxtype_to_warptype(jnp.float32) # -> warp.float32
These utilities support: float16, float32, float64, int8–int64, uint8–uint64, bool.
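A hypothetical lookup table keyed by NumPy dtype shows what such a conversion involves. It returns the Warp scalar type names as strings, so it runs without Warp installed. It illustrates the mapping the list above describes and is not BrainEvent's actual implementation. It rejects any dtype outside that list, for example complex types.

```python
import numpy as np

# Hypothetical table: NumPy dtype -> name of the matching Warp scalar type.
_WARP_SCALAR_NAMES = {
    np.dtype(np.float16): "warp.float16",
    np.dtype(np.float32): "warp.float32",
    np.dtype(np.float64): "warp.float64",
    np.dtype(np.int8): "warp.int8",
    np.dtype(np.int16): "warp.int16",
    np.dtype(np.int32): "warp.int32",
    np.dtype(np.int64): "warp.int64",
    np.dtype(np.uint8): "warp.uint8",
    np.dtype(np.uint16): "warp.uint16",
    np.dtype(np.uint32): "warp.uint32",
    np.dtype(np.uint64): "warp.uint64",
    np.dtype(np.bool_): "warp.bool",
}


def warp_scalar_name(dtype):
    """Map anything np.dtype accepts to the name of a Warp scalar type."""
    key = np.dtype(dtype)
    if key not in _WARP_SCALAR_NAMES:
        raise TypeError(f"no Warp scalar type for dtype {key}")
    return _WARP_SCALAR_NAMES[key]


def warp_array_desc(shape, dtype):
    """Describe the warp.array(dtype=..., ndim=...) annotation for shape/dtype."""
    return f"warp.array(dtype={warp_scalar_name(dtype)}, ndim={len(shape)})"


print(warp_array_desc((8, 4), "float32"))  # warp.array(dtype=warp.float32, ndim=2)
```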
if WARP_AVAILABLE:
    for jax_dtype in [jnp.float32, jnp.float64, jnp.int32, jnp.bool_]:
        warp_type = jaxtype_to_warptype(jax_dtype)
        info = jax.ShapeDtypeStruct((8, 4), jax_dtype)
        warp_arr = jaxinfo_to_warpinfo(info)
        print(f"  jnp.{jax_dtype.__name__:<8} -> warp scalar: {warp_type} | warp array: {warp_arr}")
5. Kernel Generators – Dynamic Kernel Construction
When integrating with XLACustomKernel, kernels are not defined statically.
Instead you define a kernel generator: a plain Python function that receives
shape/dtype keyword arguments (forwarded from primitive.bind) and returns a
callable that runs the actual computation.
This pattern allows the same generator to handle different dtypes and shapes without re-registering the primitive.
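The pattern can also be sketched without Warp; every name in this minimal sketch is hypothetical. relu_generator plays the role of a kernel generator. It reads the static outs metadata, closes over it, and returns a kernel specialized to that shape and dtype. The lru_cache keyed on the static metadata stands in for whatever caching the real dispatcher performs. That BrainEvent caches in a similar way is an assumption, not a documented behavior.

```python
import functools
from collections import namedtuple

import numpy as np

# Hashable stand-in for jax.ShapeDtypeStruct (hypothetical, for illustration).
ShapeDtype = namedtuple("ShapeDtype", ["shape", "dtype"])

build_count = 0  # counts how many specialized kernels were built


def relu_generator(**kwargs):
    """Hypothetical generator: specialize a ReLU kernel on static shape/dtype."""
    global build_count
    build_count += 1
    out_info = kwargs["outs"][0]
    n = out_info.shape[0]
    dtype = np.dtype(out_info.dtype)
    zero = dtype.type(0)  # constant baked into this specialization

    def kernel(x):
        if x.shape != (n,) or x.dtype != dtype:
            raise ValueError("input does not match the specialized shape/dtype")
        return (np.maximum(x, zero),)

    return kernel


@functools.lru_cache(maxsize=None)
def get_kernel(out_info):
    # Build once per distinct static signature, then reuse.
    return relu_generator(outs=[out_info])


x32 = np.array([-1.0, 3.0], dtype=np.float32)
x64 = np.array([-1.0, 3.0, -5.0], dtype=np.float64)
(y32,) = get_kernel(ShapeDtype((2,), "float32"))(x32)
(y64,) = get_kernel(ShapeDtype((3,), "float64"))(x64)
(y32_again,) = get_kernel(ShapeDtype((2,), "float32"))(x32)  # cache hit: no rebuild
```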
5.1 Template for a Warp Kernel Generator
if WARP_AVAILABLE:
    def my_relu_kernel_generator(**kwargs):
        """
        Kernel generator for element-wise ReLU.
        kwargs contains whatever was passed to XLACustomKernel.__call__,
        e.g. kwargs['outs'] = [jax.ShapeDtypeStruct(shape, dtype)]
        """
        # --- 1. Extract shape/dtype information from kwargs ---------------
        out_info = kwargs['outs'][0]  # jax.ShapeDtypeStruct
        n = out_info.shape[0]
        # --- 2. Build Warp type annotations dynamically -------------------
        x_warp_type = jaxinfo_to_warpinfo(out_info)  # same dtype for input
        out_warp_type = jaxinfo_to_warpinfo(out_info)
        # --- 3. Define the @warp.kernel with dynamic type annotations -----
        @warp.kernel
        def relu_kern(
            x: x_warp_type,
            out: out_warp_type,
        ):
            i = warp.tid()
            out[i] = warp.max(x[i], out_warp_type.dtype(0.0))
        # --- 4. Return the concrete kernel function -----------------------
        def kernel(x):
            fn = jax_kernel(
                relu_kern,
                launch_dims=[n],
                num_outputs=1,
                output_dims={'out': (n,)},
            )
            return fn(x)
        return kernel
    print("Kernel generator defined.")
6. In-place (Accumulation) vs. Pure-output Patterns
Many neuroscience operations scatter-add values into an output buffer
(e.g., synaptic current accumulation). Warp handles this via atomic operations
and the in_out_argnames mechanism.
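Atomics are needed because several threads may target the same output slot, and an unsynchronized read-modify-write can lose updates. NumPy has a deterministic, serial analogue of that hazard. Fancy-indexed += is buffered, so duplicate indices are applied only once. np.add.at performs unbuffered accumulation, which matches the guarantee warp.atomic_add gives across GPU threads.

```python
import numpy as np

targets = np.array([0, 0, 2, 0, 2])  # slot 0 is hit three times, slot 2 twice
values = np.array([1.0, 1.0, 1.0, 1.0, 1.0])

naive = np.zeros(3)
naive[targets] += values  # buffered: each slot receives only one update

accumulated = np.zeros(3)
np.add.at(accumulated, targets, values)  # unbuffered: every update is applied

print(naive)        # [1. 0. 1.]
print(accumulated)  # [3. 0. 2.]
```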
6.1 Pure output (Warp allocates)
fn = jax_kernel(kernel, launch_dims=[N], num_outputs=1, output_dims={'out': (N,)})
result, = fn(x) # only pass inputs
6.2 In-place / accumulation (caller provides buffer)
fn = jax_kernel(kernel, launch_dims=[M], num_outputs=1, in_out_argnames=['acc'])
result, = fn(x, jnp.zeros((N,), dtype)) # pass input THEN the initial output buffer
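The in_out_argnames list tells Warp which arguments are both input and output. The kernel starts from the buffer's existing contents, so this is the mode to use for atomic accumulation. You control the initial values (typically zeros), whereas a buffer allocated through output_dims should not be assumed to be zero-initialized.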
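The initial buffer is an input, so it does not have to be zeros. One possible use is to carry state across simulation steps. You pass the decayed current from the previous step as the buffer, and the kernel adds this step's synaptic input on top. The NumPy sketch below is serial. The decay factor alpha and the helper accumulate_inputs are hypothetical.

```python
import numpy as np


def accumulate_inputs(acc, targets, values):
    """Serial stand-in for the scatter-add kernel: add values into a copy of acc."""
    out = acc.copy()
    np.add.at(out, targets, values)  # duplicate-safe, like warp.atomic_add
    return out


alpha = 0.5  # hypothetical per-step decay factor
current = np.zeros(3)
steps = [
    (np.array([0, 1]), np.array([1.0, 2.0])),
    (np.array([1, 1, 2]), np.array([1.0, 1.0, 4.0])),
]
for targets, values in steps:
    # Initial buffer = decayed state; new synaptic input is accumulated on top.
    current = accumulate_inputs(alpha * current, targets, values)

print(current)
```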
if WARP_AVAILABLE:
    # Scatter-add example: for each element i of 'values',
    # add values[i] * scale to acc[targets[i]].
    N_SRC = 512  # source elements
    N_DST = 128  # destination (output) size
    @warp.kernel
    def scatter_add_kernel(
        values: warp.array(dtype=warp.float32, ndim=1),
        targets: warp.array(dtype=warp.int32, ndim=1),
        scale: warp.array(dtype=warp.float32, ndim=1),  # 1-element array
        acc: warp.array(dtype=warp.float32, ndim=1),    # in-place output
    ):
        i = warp.tid()
        # Atomic add is thread-safe: multiple threads may target the same slot
        warp.atomic_add(acc, targets[i], values[i] * scale[0])
    # Create test data
    rng = np.random.default_rng(0)
    values = jnp.array(rng.random(N_SRC).astype(np.float32))
    targets = jnp.array(rng.integers(0, N_DST, N_SRC).astype(np.int32))
    scale = jnp.array([2.0], dtype=jnp.float32)
    # Build the callable with an in-place accumulator
    scatter_fn = jax_kernel(
        scatter_add_kernel,
        launch_dims=[N_SRC],
        num_outputs=1,
        in_out_argnames=['acc'],  # 'acc' is both input and output
    )
    # Run: pass (values, targets, scale, initial_acc)
    init_acc = jnp.zeros(N_DST, dtype=jnp.float32)
    (result,) = scatter_fn(values, targets, scale, init_acc)
    # Verify against a NumPy reference
    ref = np.zeros(N_DST, dtype=np.float32)
    np.add.at(ref, np.array(targets), np.array(values) * 2.0)
    print("Scatter-add max error:", float(jnp.max(jnp.abs(result - jnp.array(ref)))))
    print("Result sum:", float(result.sum()), "| Expected:", float(ref.sum()))
7. Registering Kernels with XLACustomKernel
XLACustomKernel is BrainEvent’s central abstraction for multi-backend custom
JAX primitives. It lets you register different backend implementations
(Warp, Numba, Pallas, …) for the same logical operation, then dispatch to the
right one at runtime.
Workflow:
1. Create an XLACustomKernel instance with a unique name.
2. Register your Warp kernel generator via def_warp_kernel().
3. (Optionally) register a CPU fallback via def_numba_kernel().
4. Call the primitive with kernel(x, outs=[...]).
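A toy registry can show how the pieces of this workflow connect. Everything below is hypothetical: ToyCustomKernel is not BrainEvent's XLACustomKernel, and its def_kernel method is invented for illustration. Calling the op forwards outs and any extra keyword arguments to the registered generator. The function the generator returns receives only the positional arrays.

```python
import numpy as np


class ToyCustomKernel:
    """Hypothetical miniature of a multi-backend primitive (illustration only)."""

    def __init__(self, name):
        self.name = name
        self.generators = {}

    def def_kernel(self, backend, generator):
        self.generators[backend] = generator

    def __call__(self, *arrays, outs, backend, **static_kwargs):
        generator = self.generators[backend]
        # Static metadata goes to the generator; arrays go to the kernel it returns.
        kernel = generator(outs=outs, **static_kwargs)
        return kernel(*arrays)


def numpy_scale_add_generator(**kwargs):
    shape, dtype = kwargs["outs"][0]

    def run(a, b, c):
        out = np.empty(shape, dtype=dtype)
        out[...] = a * b + c
        return (out,)

    return run


op = ToyCustomKernel("toy_scale_add")
op.def_kernel("cpu", numpy_scale_add_generator)
a, b, c = np.arange(4.0), np.full(4, 2.0), np.ones(4)
(result,) = op(a, b, c, outs=[((4,), np.float64)], backend="cpu")
```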
if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 1: Define the kernel generator
    # -----------------------------------------------------------------------
    def warp_scale_add_generator(**kwargs):
        """Element-wise: out[i] = a[i] * b[i] + c[i]"""
        out_info = kwargs['outs'][0]
        n = out_info.shape[0]
        t = jaxinfo_to_warpinfo(out_info)
        @warp.kernel
        def kern(
            a: t,
            b: t,
            c: t,
            out: t,
        ):
            i = warp.tid()
            out[i] = a[i] * b[i] + c[i]
        def run(a, b, c):
            fn = jax_kernel(kern, launch_dims=[n], num_outputs=1,
                            output_dims={'out': (n,)})
            return fn(a, b, c)
        return run
    # -----------------------------------------------------------------------
    # Step 2: Create and register the primitive
    # -----------------------------------------------------------------------
    scale_add_op = XLACustomKernel('tutorial_warp_scale_add')
    scale_add_op.def_warp_kernel(warp_scale_add_generator)
    print("Registered backends:", scale_add_op._kernels)
    print("Default backends  :", scale_add_op.defaults)
if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 3: Call the primitive
    # -----------------------------------------------------------------------
    N = 256
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.full(N, 2.0, dtype=jnp.float32)
    c = jnp.ones(N, dtype=jnp.float32)
    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)
    result = scale_add_op(a, b, c, outs=[out_spec])
    expected = a * b + c
    print("Max error:", float(jnp.max(jnp.abs(result[0] - expected))))
    print("First 5  :", result[0][:5])
    # -----------------------------------------------------------------------
    # Step 4: Use inside jax.jit (the primitive is JIT-compatible)
    # -----------------------------------------------------------------------
    @jax.jit
    def jitted_op(a, b, c):
        return scale_add_op(a, b, c, outs=[jax.ShapeDtypeStruct(a.shape, a.dtype)])[0]
    r = jitted_op(a, b, c)
    print("JIT result matches:", bool(jnp.allclose(r, expected)))
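8. Neuroscience Example: Sparse Synaptic Input Accumulation
A classic operation in spiking neural network simulation:
given a binary spike vector spikes (shape [N_pre]) and a CSR weight matrix
(data, indices, indptr), compute the postsynaptic current
$$I_j = \sum_{i=1}^{N_{\text{pre}}} s_i \, W_{ij}, \qquad j = 1, \dots, N_{\text{post}},$$
where s_i ∈ {0, 1} is the spike of presynaptic neuron i and W is the weight matrix stored in CSR form. Because s_i is binary, only the rows of neurons that spiked contribute to the sum.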
We implement this with a Warp kernel that:
Iterates over pre-synaptic neurons in parallel
Skips silent neurons (no spike)
Atomically accumulates weights into the postsynaptic current buffer
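A serial reference that follows the same three steps is useful for checking a Warp kernel's output on small problems. The sketch below uses NumPy only, and its helper name is hypothetical. It loops over spiking presynaptic neurons, reads each row's column range from indptr, and accumulates the homogeneous weight into the postsynaptic buffer. The result is compared against the dense product.

```python
import numpy as np


def csr_binary_mv_reference(weight, indices, indptr, spikes, n_post):
    """Serial NumPy reference for the event-driven CSR kernel (homogeneous weight)."""
    post = np.zeros(n_post, dtype=np.float32)
    for i in np.flatnonzero(spikes):  # visit spiking pre-neurons only
        cols = indices[indptr[i]:indptr[i + 1]]  # post-neurons targeted by row i
        np.add.at(post, cols, weight)  # duplicate-safe accumulation
    return post


rng = np.random.default_rng(0)
n_pre, n_post, w = 50, 20, np.float32(0.1)
mask = rng.random((n_pre, n_post)) < 0.2  # connectivity pattern
# Build CSR arrays by hand: row pointers and column indices of the nonzeros
# (np.nonzero returns them in row-major order, matching indptr).
indptr = np.concatenate([[0], np.cumsum(mask.sum(axis=1))])
indices = np.nonzero(mask)[1]
spikes = rng.random(n_pre) < 0.3

post = csr_binary_mv_reference(w, indices, indptr, spikes, n_post)
dense = spikes.astype(np.float32) @ (mask * w)
```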
if WARP_AVAILABLE:
    def csr_binary_mv_warp_generator(**kwargs):
        """
        Kernel generator for CSR × binary-vector multiplication.
        Signature: kernel(weights, indices, indptr, spikes) -> post_current
        """
        weight_info = kwargs['weight_info']
        spike_info = kwargs['spike_info']
        indices_info = kwargs['indices_info']
        indptr_info = kwargs['indptr_info']
        n_pre = indptr_info.shape[0] - 1
        n_post = kwargs['n_post']
        out_dtype = kwargs['outs'][0].dtype
        # Build Warp type descriptors
        w_type = jaxinfo_to_warpinfo(weight_info)
        idx_type = jaxinfo_to_warpinfo(indices_info)
        indptr_type = jaxinfo_to_warpinfo(indptr_info)
        spk_type = jaxinfo_to_warpinfo(spike_info)
        out_type = warp.array(dtype=jaxtype_to_warptype(out_dtype), ndim=1)
        @warp.kernel
        def mv_kern(
            weights: w_type,
            indices: idx_type,
            indptr: indptr_type,
            spikes: spk_type,
            posts: out_type,
        ):
            i = warp.tid()  # one thread per pre-synaptic neuron
            if spikes[i]:  # skip silent neurons
                w = weights[0]  # scalar weight (homogeneous)
                for j in range(indptr[i], indptr[i + 1]):
                    warp.atomic_add(posts, indices[j], w)
        def kernel(weights, indices, indptr, spikes):
            fn = jax_kernel(
                mv_kern,
                launch_dims=[n_pre],
                num_outputs=1,
                in_out_argnames=['posts'],
            )
            return fn(weights, indices, indptr, spikes,
                      jnp.zeros(n_post, dtype=out_dtype))
        return kernel
    print("CSR binary MV kernel generator defined.")
if WARP_AVAILABLE:
    import scipy.sparse as sp
    # Build a random CSR connectivity matrix
    N_PRE = 1000
    N_POST = 500
    PROB = 0.05  # 5 % connection probability
    W = 0.1      # homogeneous weight
    rng = np.random.default_rng(42)
    dense = (rng.random((N_PRE, N_POST)) < PROB).astype(np.float32) * W
    csr = sp.csr_matrix(dense)
    data = jnp.array([W], dtype=jnp.float32)  # scalar weight
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr = jnp.array(csr.indptr, dtype=jnp.int32)
    # Generate binary spikes (10 % firing rate)
    spikes = jnp.array(rng.random(N_PRE) < 0.10, dtype=jnp.bool_)
    # Register the primitive
    csr_mv_op = XLACustomKernel('tutorial_warp_csr_mv')
    csr_mv_op.def_warp_kernel(csr_binary_mv_warp_generator)
    # Build the output spec and call
    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)
    result = csr_mv_op(
        data, indices, indptr, spikes,
        outs=[out_spec],
        # extra kwargs forwarded to the generator:
        weight_info=jax.ShapeDtypeStruct(data.shape, data.dtype),
        spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
        indices_info=jax.ShapeDtypeStruct(indices.shape, indices.dtype),
        indptr_info=jax.ShapeDtypeStruct(indptr.shape, indptr.dtype),
        n_post=N_POST,
    )
    # Reference: dense matmul
    spikes_f = spikes.astype(jnp.float32)
    expected = spikes_f @ jnp.array(dense)
    print(f"Network: {N_PRE} pre -> {N_POST} post | {int(spikes.sum())} spikes")
    print(f"Max error vs dense reference: {float(jnp.max(jnp.abs(result[0] - expected))):.6f}")
    print(f"Post current range: [{float(result[0].min()):.3f}, {float(result[0].max()):.3f}]")
if WARP_AVAILABLE:
    import time
    @jax.jit
    def warp_mv(data, indices, indptr, spikes):
        return csr_mv_op(
            data, indices, indptr, spikes,
            outs=[out_spec],
            weight_info=jax.ShapeDtypeStruct(data.shape, data.dtype),
            spike_info=jax.ShapeDtypeStruct(spikes.shape, spikes.dtype),
            indices_info=jax.ShapeDtypeStruct(indices.shape, indices.dtype),
            indptr_info=jax.ShapeDtypeStruct(indptr.shape, indptr.dtype),
            n_post=N_POST,
        )[0]
    @jax.jit
    def dense_mv(spikes_f, dense):
        return spikes_f @ dense
    # Transfer the dense matrix to the device once, outside the timed loops
    dense_j = jnp.array(dense)
    # Warm up (triggers JIT compilation)
    jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    jax.block_until_ready(dense_mv(spikes_f, dense_j))
    N_TRIALS = 200
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    warp_time = (time.perf_counter() - t0) / N_TRIALS * 1000
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(dense_mv(spikes_f, dense_j))
    dense_time = (time.perf_counter() - t0) / N_TRIALS * 1000
    print(f"Warp sparse kernel : {warp_time:.3f} ms")
    print(f"JAX dense matmul   : {dense_time:.3f} ms")
    print(f"Speedup            : {dense_time / warp_time:.2f}x")
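An operation count suggests why an event-driven kernel can win. Assume connections are independent with probability p and neurons spike independently with probability r. The expected number of atomic adds per call is then N_pre · r · N_post · p, while a dense vector-matrix product always touches N_pre · N_post entries. With the parameters used above:

```python
N_PRE, N_POST = 1000, 500
PROB, RATE = 0.05, 0.10  # connection probability and firing probability

expected_spikes = N_PRE * RATE    # ~100 active pre-neurons
expected_fan_out = N_POST * PROB  # ~25 targets per active neuron
sparse_updates = expected_spikes * expected_fan_out
dense_macs = N_PRE * N_POST

print(f"sparse atomic adds ~ {sparse_updates:.0f}")   # sparse atomic adds ~ 2500
print(f"dense multiply-adds = {dense_macs}")          # dense multiply-adds = 500000
print(f"ratio ~ {dense_macs / sparse_updates:.0f}x")  # ratio ~ 200x
```

Operation counts are not wall-clock time. Atomic contention, irregular memory access, and kernel launch overhead all matter, so the measured speedup can be much smaller than this ratio and will vary by GPU.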
9. Multiple Backends with XLACustomKernel
You can register multiple backends for the same operation and switch at runtime.
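One plausible way to resolve such a switch is a preference list per platform. An explicit request wins, and otherwise the first registered backend in the platform's list is used. The backend names, their ordering, and select_backend itself are assumptions for illustration, not BrainEvent's dispatch rules.

```python
# Hypothetical preference order per platform (illustration only).
PREFERENCES = {
    "gpu": ["warp", "pallas", "numba_cuda"],
    "cpu": ["numba"],
}


def select_backend(platform, registered, requested=None):
    """Pick a backend name for `platform` from the `registered` set."""
    if requested is not None:
        if requested not in registered:
            raise ValueError(f"backend {requested!r} is not registered")
        return requested
    for name in PREFERENCES.get(platform, []):
        if name in registered:
            return name
    raise RuntimeError(f"no registered backend can run on {platform!r}")


registered = {"warp", "numba"}
print(select_backend("gpu", registered))                    # warp
print(select_backend("cpu", registered))                    # numba
print(select_backend("gpu", registered, requested="warp"))  # warp
```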
# CPU fallback using Numba (demonstrated here even if no GPU is available)
try:
    import numba
    from brainevent import numba_kernel
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def _scale_add_numba(a, b, c, out):
        for i in numba.prange(out.size):
            out[i] = a[i] * b[i] + c[i]
    def numba_scale_add_generator(**kwargs):
        out_info = kwargs['outs'][0]
        def kernel(a, b, c):
            return numba_kernel(_scale_add_numba, outs=out_info)(a, b, c)
        return kernel
# Create an op with both GPU (Warp) and CPU (Numba) backends
multi_backend_op = XLACustomKernel('tutorial_multi_backend_scale_add')
if WARP_AVAILABLE:
    multi_backend_op.def_warp_kernel(warp_scale_add_generator)  # GPU
if NUMBA_AVAILABLE:
    multi_backend_op.def_numba_kernel(numba_scale_add_generator)  # CPU
print("Registered backends:", list(multi_backend_op._kernels.keys()))
# On GPU, Warp is the default; on CPU, Numba is used automatically
N = 128
a = jnp.arange(N, dtype=jnp.float32)
b = jnp.full(N, 3.0, dtype=jnp.float32)
c = jnp.ones(N, dtype=jnp.float32)
r = multi_backend_op(a, b, c, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])
print("Result matches:", bool(jnp.allclose(r[0], a * b + c)))
10. Summary
In this tutorial we covered:
@warp.kernel: write GPU kernels in Python-like syntax; use warp.tid() for the thread index.
jax_kernel: wrap a Warp kernel so JAX can call it with jax.Array inputs.
output_dims mode: Warp allocates the output buffer.
in_out_argnames mode: the caller provides the initial buffer (needed for atomic accumulation).
jaxinfo_to_warpinfo / jaxtype_to_warptype: convert JAX dtype/shape info to Warp types for dynamic kernel construction inside kernel generators.
XLACustomKernel.def_warp_kernel: register a Warp kernel generator as the GPU backend of a multi-backend custom JAX primitive.
Neuroscience application: a sparse CSR × binary-spike matrix-vector product implemented with Warp atomic operations, demonstrating the key pattern used throughout BrainEvent.
Next Steps
Tutorial 7: Custom GPU operators with Numba CUDA (@cuda.jit)
Tutorial 8: Custom CPU operators with Numba (@numba.njit)