Custom CUDA (GPU) kernels

This tutorial mirrors the C++ tutorial but targets the GPU. You write a CUDA kernel plus a thin C++ launcher; brainevent compiles it with nvcc, generates the XLA FFI wrapper, caches the build, and registers it as a JAX FFI target.

Prerequisites: an NVIDIA GPU + driver, jax[cuda12]/jax[cuda13] (which bundles nvcc), and a host C++ compiler. See Installation.

import jax
import jax.numpy as jnp
import numpy as np
import brainevent

1. A vector-add kernel

A CUDA entry point takes the tensors plus an int64_t stream. Two differences from the CPU case:

  • Always #include <cuda_runtime.h> yourself (it is not auto-injected).

  • The trailing int64_t stream argument carries the CUDA stream; cast it to cudaStream_t for the launch.

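The // @BE comment annotates the entry point with its arg_spec, so brainevent discovers it automatically — no functions dict needed. The tokens after the function name map positionally onto the C++ parameters: arg for an input tensor, ret for an output tensor, and stream for the trailing stream argument.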

CUDA_SRC = r"""
#include <cuda_runtime.h>
#include "brainevent/common.h"

// Elementwise out[i] = a[i] + b[i]; one thread per element.
__global__ void add_kernel(const float* a, const float* b, float* out, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = a[idx] + b[idx];  // guard the tail of the last block
}

// Assumes a, b and out are float32 with the same number of elements
// (data_ptr() is reinterpreted as float*; nothing here checks dtype or size).
// @BE vector_add arg arg ret stream
void vector_add(const BE::Tensor a, const BE::Tensor b,
                BE::Tensor out, int64_t stream) {
    int n = a.numel();
    // (n + 255) / 256 would be a zero-sized grid, which is an invalid
    // launch configuration -- and there is no work to do anyway.
    if (n == 0) return;
    add_kernel<<<(n + 255) / 256, 256, 0, (cudaStream_t)stream>>>(
        static_cast<const float*>(a.data_ptr()),
        static_cast<const float*>(b.data_ptr()),
        static_cast<float*>(out.data_ptr()), n);
    // Check the launch here so a failure is reported against this kernel
    // instead of surfacing at some later, unrelated check.
    BE_CHECK_KERNEL_LAUNCH();
}
"""

mod = brainevent.load_cuda_inline(
    name="my_kernels",
    cuda_sources=CUDA_SRC,
    functions=None,  # discovered from the // @BE annotation
)
# Two float32 inputs of 1024 elements: all ones and all twos.
a = jnp.ones(1024, dtype=jnp.float32)
b = jnp.full(1024, 2.0, dtype=jnp.float32)

# The kernel fills one output buffer with the same shape and dtype as `a`.
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
add_call = jax.ffi.ffi_call("my_kernels.vector_add", out_type)
result = add_call(a, b)

print(result[:5])  # [3. 3. 3. 3. 3.]

2. Scalar attributes

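Add scalar parameters with attr.<name> tokens, placed after the tensors and before the stream. At call time they are passed as keyword arguments to the callable returned by jax.ffi.ffi_call (scale_factor=... below). The explicit :float32 suffix pins the type; the bare attr.scale_factor form would infer it from the C++ signature.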

CUDA_SRC = r"""
#include <cuda_runtime.h>
#include "brainevent/common.h"

// Elementwise out[i] = x[i] * factor; one thread per element.
__global__ void scale_kernel(const float* x, float* out, int n, float factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = x[idx] * factor;  // guard the tail of the last block
}

// Assumes x and out are float32 with the same number of elements
// (data_ptr() is reinterpreted as float*; nothing here checks dtype or size).
// @BE scale_by arg ret attr.scale_factor:float32 stream
void scale_by(const BE::Tensor x, BE::Tensor out,
              float scale_factor, int64_t stream) {
    int n = x.numel();
    // (n + 255) / 256 would be a zero-sized grid, which is an invalid
    // launch configuration -- and there is no work to do anyway.
    if (n == 0) return;
    scale_kernel<<<(n + 255) / 256, 256, 0, (cudaStream_t)stream>>>(
        static_cast<const float*>(x.data_ptr()),
        static_cast<float*>(out.data_ptr()), n, scale_factor);
    BE_CHECK_KERNEL_LAUNCH();
}
"""

mod = brainevent.load_cuda_inline(name="scale_ops", cuda_sources=CUDA_SRC)
# scale_factor fills the attr.scale_factor:float32 slot, so pass a float32 scalar.
result = jax.ffi.ffi_call(
    "scale_ops.scale_by",
    jax.ShapeDtypeStruct(a.shape, a.dtype),
)(a, scale_factor=np.float32(3.0))

print(result[:5])  # [3. 3. 3. 3. 3.]

3. Compiler options

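load_cuda_inline forwards optimization flags to nvcc. use_fast_math=True (nvcc --use_fast_math) trades floating-point accuracy for speed. It flushes denormals to zero and uses approximate division, square-root and transcendental intrinsics, which often gives ~10-30% on FP-heavy kernels. Validate numerics before enabling it. optimization_level maps to -O<n> (default 3). These flags are part of the cache key, so changing them triggers a clean recompile.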

# Rebuild the current CUDA_SRC (the scale_by source from section 2) with fast
# math enabled; compiler flags are part of the build cache key.
mod = brainevent.load_cuda_inline(
    name="fast_kernels",
    cuda_sources=CUDA_SRC,
    use_fast_math=True,
    optimization_level=3,  # -O3, the default, spelled out for clarity
)

4. Loading from a file & jax.jit

Keep kernels in .cu files for real projects, and call them inside jitted functions.

# mod = brainevent.load_cuda_file("kernels/my_kernel.cu")

@jax.jit
def add_jit(a, b):
    """Return ``a + b`` computed by the ``my_kernels.vector_add`` FFI target.

    The output buffer has the same shape and dtype as ``a``.
    """
    out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
    add_call = jax.ffi.ffi_call("my_kernels.vector_add", out_type)
    return add_call(a, b)

print(add_jit(a, b)[:5])  # [3. 3. 3. 3. 3.]

Where to next

  • The full arg_spec grammar, the BE::Tensor C++ API, the compiler options, and the caching model are all in the Reference section under Custom kernels.

  • Prefer a higher-level path? The Numba, Numba-CUDA, and Warp tutorials need no separate compiler.