Custom GPU Operators with Warp

This tutorial shows how to write custom GPU kernels using NVIDIA Warp and integrate them into the BrainEvent / JAX ecosystem.

NVIDIA Warp is a Python framework for high-performance GPU kernel authoring. Kernels are written in Python-like syntax, JIT-compiled to CUDA PTX, and can be called directly from JAX via warp.jax_experimental.jax_kernel.

Contents

  1. Why Warp?

  2. Installation and Imports

  3. Writing Your First Warp Kernel

  4. Type Annotations – jaxinfo_to_warpinfo / jaxtype_to_warptype

  5. Kernel Generators – Dynamic Kernel Construction

  6. In-place (accumulation) vs. Pure-output Patterns

  7. Registering Kernels with XLACustomKernel

  8. Neuroscience Example: Sparse Synaptic Input Accumulation

  9. Multiple Backends with XLACustomKernel

  10. Summary

1. Why Warp?

| Feature         | Warp                     | Raw CUDA C++    |
|-----------------|--------------------------|-----------------|
| Language        | Python-like syntax       | C++             |
| Compilation     | Automatic JIT            | Manual          |
| JAX integration | Built-in (jax_kernel)    | Manual XLA FFI  |
| Autodiff        | Limited (scalar ops)     | Manual          |
| Best for        | Custom GPU ops in Python | Maximum control |

Warp is the recommended path when you want GPU acceleration without leaving Python. With BrainEvent’s XLACustomKernel infrastructure, registering a Warp kernel as a backend for a custom JAX primitive takes a single def_warp_kernel() call.

Requirements:

  • NVIDIA GPU with CUDA

  • pip install warp-lang (imported as warp)

  • JAX with GPU support (pip install "jax[cuda12]")

2. Installation and Imports

# Install if needed
# !pip install warp-lang -U
# !pip install brainevent[cuda12] -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, jaxinfo_to_warpinfo, jaxtype_to_warptype

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    import warp
    from warp.jax_experimental import jax_kernel
    warp.config.quiet = True
    print(f"Warp version   : {warp.__version__}")
    WARP_AVAILABLE = True
except ImportError:
    print("Warp not installed. Run: pip install warp-lang")
    WARP_AVAILABLE = False

3. Writing Your First Warp Kernel

A Warp kernel is a Python function decorated with @warp.kernel. Key rules:

  • Kernel arguments must be Warp types (scalars, vectors, matrices, structs, arrays). Python lists, dicts and other Python objects are not allowed.

  • The thread index is obtained via warp.tid(). It replaces blockIdx.x * blockDim.x + threadIdx.x in CUDA C.

  • Every argument needs a type annotation. Arrays are annotated with warp.array(dtype=..., ndim=...).

  • The kernel body runs once per thread, so you typically launch one thread per element.
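The plain-Python sketch below makes the "one thread per element" model concrete by emulating a 1-D launch on the CPU. The helper `emulate_launch` is hypothetical and not part of Warp. It calls the kernel body once for every thread index, the value `warp.tid()` would return on the GPU. It illustrates the execution model only. A real Warp launch runs these bodies concurrently and in no guaranteed order.

```python
def emulate_launch(body, dim, *arrays):
    """Hypothetical helper: run `body(tid, *arrays)` once per thread index.

    On a GPU the calls run concurrently and in no particular order; here they
    run sequentially, which is enough to show what each thread computes.
    """
    for tid in range(dim):
        body(tid, *arrays)


def relu_body(i, x, out):
    # Mirrors relu_kernel: thread i reads x[i] and writes only out[i],
    # so no two threads ever touch the same output element.
    out[i] = max(x[i], 0.0)


x = [-2.0, -0.5, 0.0, 0.5, 2.0]
out = [0.0] * len(x)
emulate_launch(relu_body, len(x), x, out)
print(out)  # [0.0, 0.0, 0.0, 0.5, 2.0]
```

Each thread writes a distinct element here, so this kernel needs no synchronization. Atomics only become necessary when several threads target the same slot, as in the scatter-add pattern of section 6.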

3.1 Element-wise ReLU

if WARP_AVAILABLE:
    @warp.kernel
    def relu_kernel(
        x:   warp.array(dtype=warp.float32, ndim=1),
        out: warp.array(dtype=warp.float32, ndim=1),
    ):
        i = warp.tid()           # thread index = element index
        out[i] = warp.max(x[i], warp.float32(0.0))

    print("relu_kernel defined successfully")
    print(f"Kernel type: {type(relu_kernel)}")

3.2 Calling the Kernel via jax_kernel

jax_kernel wraps a Warp kernel so it can be called with JAX arrays.

Signature:

fn = jax_kernel(
    warp_kernel,
    launch_dims=[n],         # total threads to launch per dimension
    num_outputs=1,           # how many output arrays the kernel writes
    output_dims={'out': (n,)},  # shape of each output (allocated by Warp)
)
(result,) = fn(x)  # pass only the input arrays; outputs are returned as a tuple

There are two output modes:

  • output_dims – Warp allocates the output buffer; you only pass inputs.

  • in_out_argnames – You pass a pre-allocated (e.g., jnp.zeros) buffer; Warp writes into it.

if WARP_AVAILABLE:
    N = 1024
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)

    # Build the JAX-callable wrapper
    relu_fn = jax_kernel(
        relu_kernel,
        launch_dims=[N],
        num_outputs=1,
        output_dims={'out': (N,)},
    )

    # Call it – returns a tuple of output arrays
    (result,) = relu_fn(x)

    # Verify against JAX reference
    expected = jnp.maximum(x, 0.0)
    print("Max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 8 values:", result[:8])

4. Type Annotations – jaxinfo_to_warpinfo / jaxtype_to_warptype

When embedding a Warp kernel inside a kernel generator (a function that receives shape/dtype information at trace time), you need to create the Warp type annotations dynamically. BrainEvent provides two helpers:

from brainevent import jaxinfo_to_warpinfo, jaxtype_to_warptype

# Convert jax.ShapeDtypeStruct  ->  warp.array(dtype=..., ndim=...)
warp_arr_type = jaxinfo_to_warpinfo(jax.ShapeDtypeStruct((1024,), jnp.float32))

# Convert numpy/JAX dtype  ->  warp scalar type
warp_scalar_type = jaxtype_to_warptype(jnp.float32)  # -> warp.float32

These utilities support float16, float32, float64, int8–int64, uint8–uint64 and bool.
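A Warp array annotation records only the element type and the number of dimensions, not the sizes. That is why a single annotation fits inputs of any length. The hypothetical `describe_warp_array` below sketches this conversion as a string. It assumes each Warp scalar type has the same name as the corresponding NumPy dtype. It is an illustration only, not BrainEvent's implementation of `jaxinfo_to_warpinfo`.

```python
import numpy as np

# The dtypes listed above as supported; anything else is rejected.
SUPPORTED = {
    "float16", "float32", "float64",
    "int8", "int16", "int32", "int64",
    "uint8", "uint16", "uint32", "uint64",
    "bool",
}


def describe_warp_array(shape, dtype):
    """Hypothetical sketch: (shape, dtype) -> text of the matching Warp annotation."""
    name = np.dtype(dtype).name
    if name not in SUPPORTED:
        raise TypeError(f"unsupported dtype: {name}")
    # Only ndim enters the type; the concrete sizes stay a runtime property.
    return f"warp.array(dtype=warp.{name}, ndim={len(shape)})"


print(describe_warp_array((1024,), np.float32))  # warp.array(dtype=warp.float32, ndim=1)
print(describe_warp_array((8, 4), np.bool_))     # warp.array(dtype=warp.bool, ndim=2)
```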

if WARP_AVAILABLE:
    import jax

    for jax_dtype in [jnp.float32, jnp.float64, jnp.int32, jnp.bool_]:
        warp_type = jaxtype_to_warptype(jax_dtype)
        info = jax.ShapeDtypeStruct((8, 4), jax_dtype)
        warp_arr = jaxinfo_to_warpinfo(info)
        print(f"  jnp.{jax_dtype.__name__:<8} -> warp scalar: {warp_type}  |  warp array: {warp_arr}")

5. Kernel Generators – Dynamic Kernel Construction

When integrating with XLACustomKernel, kernels are not defined statically. Instead, you define a kernel generator: a plain Python function that receives shape/dtype keyword arguments (forwarded from primitive.bind) and returns a callable that runs the actual computation.

This pattern allows the same generator to handle different dtypes and shapes without re-registering the primitive.
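To show the generator pattern in isolation, the sketch below uses NumPy in place of the GPU kernel. `relu_generator` is a hypothetical name. It reads the `outs` keyword, which is a list of objects exposing `.shape` and `.dtype`. Here a small `Spec` namedtuple stands in for `jax.ShapeDtypeStruct`. The generator then returns a closure specialized to that shape and dtype.

```python
from collections import namedtuple

import numpy as np

# Stand-in for jax.ShapeDtypeStruct (assumption: only .shape and .dtype are used).
Spec = namedtuple("Spec", ["shape", "dtype"])


def relu_generator(**kwargs):
    """Hypothetical kernel generator: kwargs -> callable specialized to one shape/dtype."""
    out_info = kwargs["outs"][0]
    n = out_info.shape[0]
    dtype = np.dtype(out_info.dtype)
    zero = dtype.type(0)  # constant baked in for this dtype

    def kernel(x):
        if x.shape != (n,):
            raise ValueError(f"expected shape {(n,)}, got {x.shape}")
        return (np.maximum(x.astype(dtype), zero),)  # a tuple, like jax_kernel returns

    return kernel


k32 = relu_generator(outs=[Spec((4,), np.float32)])
k64 = relu_generator(outs=[Spec((3,), np.int64)])
print(k32(np.array([-1.0, 2.0, -3.0, 4.0]))[0])  # [0. 2. 0. 4.]
print(k64(np.array([-5, 0, 7]))[0])              # [0 0 7]
```

The same generator produced two independent kernels, one per dtype/shape combination, and neither call required registering anything new.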

5.1 Template for a Warp Kernel Generator

if WARP_AVAILABLE:
    def my_relu_kernel_generator(**kwargs):
        """
        Kernel generator for element-wise ReLU.

        kwargs contains whatever was passed to XLACustomKernel.__call__,
        e.g. kwargs['outs'] = [jax.ShapeDtypeStruct(shape, dtype)]
        """
        # --- 1. Extract shape/dtype information from kwargs ---------------
        out_info = kwargs['outs'][0]            # jax.ShapeDtypeStruct
        n = out_info.shape[0]

        # --- 2. Build Warp type annotations dynamically -------------------
        x_warp_type   = jaxinfo_to_warpinfo(out_info)   # same dtype for input
        out_warp_type = jaxinfo_to_warpinfo(out_info)

        # --- 3. Define the @warp.kernel with dynamic type annotations -----
        @warp.kernel
        def relu_kern(
            x:   x_warp_type,
            out: out_warp_type,
        ):
            i = warp.tid()
            out[i] = warp.max(x[i], out_warp_type.dtype(0.0))

        # --- 4. Return the concrete kernel function -----------------------
        def kernel(x):
            fn = jax_kernel(
                relu_kern,
                launch_dims=[n],
                num_outputs=1,
                output_dims={'out': (n,)},
            )
            return fn(x)

        return kernel

    print("Kernel generator defined.")

6. In-place (Accumulation) vs. Pure-output Patterns

Many neuroscience operations scatter-add values into an output buffer (e.g., synaptic current accumulation). Warp handles this via atomic operations and the in_out_argnames mechanism.

6.1 Pure output (Warp allocates)

fn = jax_kernel(kernel, launch_dims=[N], num_outputs=1, output_dims={'out': (N,)})
result, = fn(x)  # only pass inputs

6.2 In-place / accumulation (caller provides buffer)

fn = jax_kernel(kernel, launch_dims=[M], num_outputs=1, in_out_argnames=['acc'])
result, = fn(x, jnp.zeros((N,), dtype))  # pass input THEN the initial output buffer

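The in_out_argnames list marks the arguments that the kernel both reads and writes. Because the caller supplies their initial contents (typically zeros), the kernel can safely accumulate into them with atomic operations such as warp.atomic_add.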
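NumPy on the CPU shows why atomics and a zero-initialized buffer matter. A buffered fancy-index update such as `acc[targets] += values` applies only one update per repeated index. This resembles the lost updates that plain (non-atomic) writes from many GPU threads can produce. `np.add.at` accumulates every contribution, which is the behavior `warp.atomic_add` provides. The result is also built on top of whatever the buffer held initially.

```python
import numpy as np

targets = np.array([0, 2, 2, 2, 1])  # slot 2 is hit three times
values = np.array([1.0, 1.0, 1.0, 1.0, 1.0])

# Non-accumulating update: repeated indices collapse to a single write.
naive = np.zeros(3)
naive[targets] += values
print(naive)  # [1. 1. 1.]

# Accumulating update: every contribution lands, like warp.atomic_add.
acc = np.zeros(3)
np.add.at(acc, targets, values)
print(acc)  # [1. 1. 3.]

# Accumulation starts from the buffer's initial contents, hence passing zeros.
acc_from_ones = np.ones(3)
np.add.at(acc_from_ones, targets, values)
print(acc_from_ones)  # [2. 2. 4.]
```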

if WARP_AVAILABLE:
    # Scatter-add example: for each element i of 'values',
    # add values[i] * scale to acc[targets[i]].

    N_SRC = 512   # source elements
    N_DST = 128   # destination (output) size

    @warp.kernel
    def scatter_add_kernel(
        values:  warp.array(dtype=warp.float32, ndim=1),
        targets: warp.array(dtype=warp.int32,   ndim=1),
        scale:   warp.array(dtype=warp.float32, ndim=1),  # 1-element array
        acc:     warp.array(dtype=warp.float32, ndim=1),  # in-place output
    ):
        i = warp.tid()
        # Atomic add is thread-safe – multiple threads may target the same slot
        warp.atomic_add(acc, targets[i], values[i] * scale[0])

    # Create test data
    rng = np.random.default_rng(0)
    values  = jnp.array(rng.random(N_SRC).astype(np.float32))
    targets = jnp.array(rng.integers(0, N_DST, N_SRC).astype(np.int32))
    scale   = jnp.array([2.0], dtype=jnp.float32)

    # Build callable with in-place accumulator
    scatter_fn = jax_kernel(
        scatter_add_kernel,
        launch_dims=[N_SRC],
        num_outputs=1,
        in_out_argnames=['acc'],       # 'acc' is both input and output
    )

    # Run: pass (values, targets, scale, initial_acc)
    init_acc = jnp.zeros(N_DST, dtype=jnp.float32)
    (result,) = scatter_fn(values, targets, scale, init_acc)

    # Verify with NumPy reference
    ref = np.zeros(N_DST, dtype=np.float32)
    np.add.at(ref, np.array(targets), np.array(values) * 2.0)
    print("Scatter-add max error:", float(jnp.max(jnp.abs(result - jnp.array(ref)))))
    print("Result sum:", float(result.sum()), "| Expected:", float(ref.sum()))

7. Registering Kernels with XLACustomKernel

XLACustomKernel is BrainEvent’s central abstraction for multi-backend custom JAX primitives. It lets you register different backend implementations (Warp, Numba, Pallas, …) for the same logical operation, then dispatch to the right one at runtime.

Workflow:

  1. Create an XLACustomKernel instance with a unique name

  2. Register your Warp kernel generator via def_warp_kernel()

  3. (Optionally) register a CPU fallback via def_numba_kernel()

  4. Call the primitive like a function, passing the output specs: op(x, ..., outs=[...])
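A small pure-Python registry can mimic the four steps above. It shows how `outs` and any extra keyword arguments reach the generator, while the returned kernel receives only the array arguments. `ToyCustomKernel`, its `def_kernel` method and the explicit `backend=` argument are hypothetical. They illustrate the control flow only and are not BrainEvent's `XLACustomKernel` API.

```python
from collections import namedtuple

# Stand-in for jax.ShapeDtypeStruct.
Spec = namedtuple("Spec", ["shape", "dtype"])


class ToyCustomKernel:
    """Hypothetical multi-backend op: maps a backend name to a kernel generator."""

    def __init__(self, name):
        self.name = name
        self.generators = {}

    def def_kernel(self, backend, generator):
        self.generators[backend] = generator

    def __call__(self, *args, outs, backend, **kwargs):
        if backend not in self.generators:
            raise NotImplementedError(f"{self.name}: no '{backend}' kernel")
        # The generator sees `outs` plus every extra kwarg; the kernel it
        # returns sees only the array arguments.
        kernel = self.generators[backend](outs=outs, **kwargs)
        return kernel(*args)


def scale_add_generator(**kwargs):
    n = kwargs["outs"][0].shape[0]

    def run(a, b, c):
        return ([a[i] * b[i] + c[i] for i in range(n)],)

    return run


op = ToyCustomKernel("toy_scale_add")
op.def_kernel("cpu", scale_add_generator)
(out,) = op([0, 1, 2], [2, 2, 2], [1, 1, 1], outs=[Spec((3,), "float32")], backend="cpu")
print(out)  # [1, 3, 5]
```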

if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 1: Define the kernel generator
    # -----------------------------------------------------------------------
    def warp_scale_add_generator(**kwargs):
        """Element-wise: out[i] = a[i] * b[i] + c[i]"""
        out_info = kwargs['outs'][0]
        n        = out_info.shape[0]
        t        = jaxinfo_to_warpinfo(out_info)

        @warp.kernel
        def kern(
            a:   t,
            b:   t,
            c:   t,
            out: t,
        ):
            i = warp.tid()
            out[i] = a[i] * b[i] + c[i]

        def run(a, b, c):
            fn = jax_kernel(kern, launch_dims=[n], num_outputs=1,
                            output_dims={'out': (n,)})
            return fn(a, b, c)

        return run

    # -----------------------------------------------------------------------
    # Step 2: Create and register the primitive
    # -----------------------------------------------------------------------
    scale_add_op = XLACustomKernel('tutorial_warp_scale_add')
    scale_add_op.def_warp_kernel(warp_scale_add_generator)

    print("Registered backends:", scale_add_op._kernels)
    print("Default backends  :", scale_add_op.defaults)

if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 3: Call the primitive
    # -----------------------------------------------------------------------
    N = 256
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.full(N, 2.0, dtype=jnp.float32)
    c = jnp.ones(N, dtype=jnp.float32)

    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)
    result   = scale_add_op(a, b, c, outs=[out_spec])

    expected = a * b + c
    print("Max error:", float(jnp.max(jnp.abs(result[0] - expected))))
    print("First 5  :", result[0][:5])

    # -----------------------------------------------------------------------
    # Step 4: Use inside jax.jit (the primitive is JIT-compatible)
    # -----------------------------------------------------------------------
    @jax.jit
    def jitted_op(a, b, c):
        return scale_add_op(a, b, c, outs=[jax.ShapeDtypeStruct(a.shape, a.dtype)])[0]

    r = jitted_op(a, b, c)
    print("JIT result matches:", bool(jnp.allclose(r, expected)))

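8. Neuroscience Example: Sparse Synaptic Input Accumulation

A classic operation in spiking neural network simulation: given a binary spike vector spikes (shape [N_pre]) and a CSR weight matrix W (data, indices, indptr) of shape [N_pre, N_post], compute the postsynaptic current

\[I_{\text{post}}[j] = \sum_{i:\,\text{spikes}[i]>0} W_{ij}, \qquad W_{ij} \neq 0 \text{ only if } j \in \{\text{indices}[k] : \text{indptr}[i] \le k < \text{indptr}[i+1]\}\]

In the example below all synapses share one homogeneous weight w (data holds a single value), so each spiking neuron i adds w to every target j listed in row i.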

We implement this with a Warp kernel that:

  1. Iterates over pre-synaptic neurons in parallel

  2. Skips silent neurons (no spike)

  3. Atomically accumulates weights into the postsynaptic current buffer
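Before trusting the GPU kernel, it helps to have a serial reference with the same structure. The NumPy sketch below (`csr_binary_mv_reference` is a hypothetical name) loops over presynaptic neurons and skips the silent ones. For each spiking neuron it adds the homogeneous weight `w` to every postsynaptic target in that row, which is the work a single Warp thread does for one neuron. The 3 × 4 matrix is made up for illustration.

```python
import numpy as np


def csr_binary_mv_reference(w, indices, indptr, spikes, n_post):
    """Serial sketch of the Warp kernel: one loop iteration per pre neuron."""
    post = np.zeros(n_post, dtype=np.float32)
    for i in range(len(indptr) - 1):
        if not spikes[i]:          # step 2: skip silent neurons
            continue
        for k in range(indptr[i], indptr[i + 1]):
            post[indices[k]] += w  # step 3 (serial here, so no atomics needed)
    return post


# 3 pre x 4 post connectivity:
#   pre 0 -> post 1, 3;  pre 1 -> post 0;  pre 2 -> post 1, 2
indptr = np.array([0, 2, 3, 5])
indices = np.array([1, 3, 0, 1, 2])
spikes = np.array([True, False, True])
w = 0.5

post = csr_binary_mv_reference(w, indices, indptr, spikes, n_post=4)
print(post)  # [0.  1.  0.5 0.5]

# Cross-check against the dense form of the formula: I_post = spikes @ W.
dense = np.zeros((3, 4), dtype=np.float32)
for i in range(3):
    dense[i, indices[indptr[i]:indptr[i + 1]]] = w
assert np.allclose(post, spikes.astype(np.float32) @ dense)
```

Post neuron 1 receives input from both spiking neurons 0 and 2, so it gets 2w = 1.0. This is the kind of collision that makes warp.atomic_add necessary in the parallel version.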

if WARP_AVAILABLE:
    def csr_binary_mv_warp_generator(**kwargs):
        """
        Kernel generator for CSR × binary-vector multiplication.
        Signature: kernel(weights, indices, indptr, spikes) -> post_current
        """
        weight_info  = kwargs['weight_info']
        spike_info   = kwargs['spike_info']
        indices_info = kwargs['indices_info']
        indptr_info  = kwargs['indptr_info']
        n_pre        = indptr_info.shape[0] - 1
        n_post       = kwargs['n_post']
        out_dtype    = kwargs['outs'][0].dtype

        # Build Warp type descriptors
        w_type      = jaxinfo_to_warpinfo(weight_info)
        idx_type    = jaxinfo_to_warpinfo(indices_info)
        indptr_type = jaxinfo_to_warpinfo(indptr_info)
        spk_type    = jaxinfo_to_warpinfo(spike_info)
        out_type    = warp.array(dtype=jaxtype_to_warptype(out_dtype), ndim=1)

        @warp.kernel
        def mv_kern(
            weights: w_type,
            indices: idx_type,
            indptr:  indptr_type,
            spikes:  spk_type,
            posts:   out_type,
        ):
            i = warp.tid()              # one thread per pre-synaptic neuron
            if spikes[i]:               # skip silent neurons
                w = weights[0]          # scalar weight (homogeneous)
                for j in range(indptr[i], indptr[i + 1]):
                    warp.atomic_add(posts, indices[j], w)

        def kernel(weights, indices, indptr, spikes):
            fn = jax_kernel(
                mv_kern,
                launch_dims=[n_pre],
                num_outputs=1,
                in_out_argnames=['posts'],
            )
            return fn(weights, indices, indptr, spikes,
                      jnp.zeros(n_post, dtype=out_dtype))

        return kernel

    print("CSR binary MV kernel generator defined.")

if WARP_AVAILABLE:
    import scipy.sparse as sp

    # Build a random CSR connectivity matrix
    N_PRE  = 1000
    N_POST = 500
    PROB   = 0.05   # 5 % connection probability
    W      = 0.1    # homogeneous weight

    rng = np.random.default_rng(42)
    dense = (rng.random((N_PRE, N_POST)) < PROB).astype(np.float32) * W
    csr   = sp.csr_matrix(dense)

    data   = jnp.array([W], dtype=jnp.float32)          # scalar weight
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)

    # Generate binary spikes (10 % firing rate)
    spikes = jnp.array(rng.random(N_PRE) < 0.10, dtype=jnp.bool_)

    # Register the primitive
    csr_mv_op = XLACustomKernel('tutorial_warp_csr_mv')
    csr_mv_op.def_warp_kernel(csr_binary_mv_warp_generator)

    # Build output spec and call
    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)

    result = csr_mv_op(
        data, indices, indptr, spikes,
        outs=[out_spec],
        # extra kwargs forwarded to the generator:
        weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),
        spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),
        indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),
        indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),
        n_post       = N_POST,
    )

    # Reference: dense matmul
    spikes_f  = spikes.astype(jnp.float32)
    expected  = spikes_f @ jnp.array(dense)

    print(f"Network: {N_PRE} pre -> {N_POST} post  |  {int(spikes.sum())} spikes")
    print(f"Max error vs dense reference: {float(jnp.max(jnp.abs(result[0] - expected))):.6f}")
    print(f"Post current range: [{float(result[0].min()):.3f}, {float(result[0].max()):.3f}]")

if WARP_AVAILABLE:
    import time

    @jax.jit
    def warp_mv(data, indices, indptr, spikes):
        return csr_mv_op(
            data, indices, indptr, spikes,
            outs=[out_spec],
            weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),
            spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),
            indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),
            indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),
            n_post       = N_POST,
        )[0]

    @jax.jit
    def dense_mv(spikes_f, dense):
        return spikes_f @ dense

    # Move the dense matrix to the device once, outside the timed loops,
    # so the dense baseline is not charged for a host-to-device copy per call
    dense_j = jnp.array(dense)

    # Warm up (triggers compilation)
    jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    jax.block_until_ready(dense_mv(spikes_f, dense_j))

    N_TRIALS = 200
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    warp_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(dense_mv(spikes_f, dense_j))
    dense_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    print(f"Warp sparse kernel : {warp_time:.3f} ms")
    print(f"JAX dense matmul   : {dense_time:.3f} ms")
    print(f"Speedup            : {dense_time / warp_time:.2f}x")

9. Multiple Backends with XLACustomKernel

You can register multiple backends for the same operation and switch at runtime.

# CPU fallback using Numba (demonstrated here even if GPU is unavailable)
try:
    import numba
    from brainevent import numba_kernel
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def _scale_add_numba(a, b, c, out):
        for i in numba.prange(out.size):
            out[i] = a[i] * b[i] + c[i]

    def numba_scale_add_generator(**kwargs):
        out_info = kwargs['outs'][0]

        def kernel(a, b, c):
            return numba_kernel(_scale_add_numba, outs=out_info)(a, b, c)

        return kernel

    # Create op with both GPU (Warp) and CPU (Numba) backends
    multi_backend_op = XLACustomKernel('tutorial_multi_backend_scale_add')

    if WARP_AVAILABLE:
        multi_backend_op.def_warp_kernel(warp_scale_add_generator)   # GPU

    multi_backend_op.def_numba_kernel(numba_scale_add_generator)     # CPU

    print("Registered backends:", list(multi_backend_op._kernels.keys()))

    # On GPU, Warp is default; on CPU, Numba is used automatically
    N = 128
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.full(N, 3.0, dtype=jnp.float32)
    c = jnp.ones(N, dtype=jnp.float32)
    r = multi_backend_op(a, b, c, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])
    print("Result matches:", bool(jnp.allclose(r[0], a * b + c)))

10. Summary

In this tutorial we covered:

  1. @warp.kernel – Write GPU kernels in Python-like syntax; use warp.tid() for the thread index.

  2. jax_kernel – Wrap a Warp kernel so JAX can call it with jax.Array inputs.

    • output_dims mode: Warp allocates the output buffer.

    • in_out_argnames mode: caller provides the initial buffer (needed for atomic accumulation).

  3. jaxinfo_to_warpinfo / jaxtype_to_warptype – Convert JAX dtype/shape info to Warp types for dynamic kernel construction inside kernel generators.

  4. XLACustomKernel.def_warp_kernel – Register a Warp kernel generator as the GPU backend of a multi-backend custom JAX primitive, optionally alongside a Numba CPU backend registered with def_numba_kernel.

  5. Neuroscience application – Sparse CSR × binary-spike matrix-vector product implemented with Warp atomic operations, demonstrating the key pattern used throughout BrainEvent.

Next Steps

  • Tutorial 7: Custom GPU operators with Numba CUDA (@cuda.jit)

  • Tutorial 8: Custom CPU operators with Numba (@numba.njit)

References