Custom CUDA (GPU) kernels
This tutorial mirrors the C++ tutorial but targets the GPU. You write a CUDA kernel plus a thin C++ launcher; brainevent compiles it with nvcc, generates the XLA FFI wrapper, caches the build, and registers it as a JAX FFI target.
Prerequisites: an NVIDIA GPU with its driver, `jax[cuda12]` or `jax[cuda13]` (which bundles nvcc), and a host C++ compiler. See Installation.
import jax
import jax.numpy as jnp
import numpy as np
import brainevent
1. A vector-add kernel
A CUDA entry point takes the tensors plus a trailing `int64_t stream`. There are two differences from the CPU case:

- Always `#include <cuda_runtime.h>` yourself (it is not auto-injected).
- The trailing `int64_t stream` argument carries the CUDA stream; cast it to `cudaStream_t` for the launch.
The `// @BE` comment annotates the entry point with its `arg_spec`, so brainevent discovers it automatically — no `functions` dict is needed.
# CUDA translation unit for section 1: one __global__ kernel plus the host
# launcher `vector_add`. Its `// @BE` line declares two input tensors (`arg`),
# one output tensor (`ret`) and the trailing CUDA stream.
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include "brainevent/common.h"

// Element-wise out[i] = a[i] + b[i]; one thread per element.
__global__ void add_kernel(const float* a, const float* b, float* out, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = a[idx] + b[idx];
}

// @BE vector_add arg arg ret stream
void vector_add(const BE::Tensor a, const BE::Tensor b,
                BE::Tensor out, int64_t stream) {
    const int n = static_cast<int>(a.numel());
    // Empty input: a zero-block grid is an invalid launch configuration,
    // so return before launching anything.
    if (n == 0) return;
    constexpr int kThreads = 256;
    const int blocks = (n + kThreads - 1) / kThreads;
    add_kernel<<<blocks, kThreads, 0, (cudaStream_t)stream>>>(
        static_cast<const float*>(a.data_ptr()),
        static_cast<const float*>(b.data_ptr()),
        static_cast<float*>(out.data_ptr()), n);
    // Check the launch right away so a failure is not reported by some
    // unrelated later CUDA call.
    BE_CHECK_KERNEL_LAUNCH();
}
"""
# Build the source with nvcc (cached), generate the XLA FFI wrapper and
# register the entry points under the "my_kernels." prefix. Passing
# functions=None lets brainevent find them from the `// @BE` annotations.
mod = brainevent.load_cuda_inline(
    functions=None,
    cuda_sources=CUDA_SRC,
    name="my_kernels",
)
a = jnp.ones((1024,), dtype=jnp.float32)
b = jnp.full((1024,), 2.0, dtype=jnp.float32)
# The FFI call needs the output's shape/dtype up front; here it mirrors `a`.
result = jax.ffi.ffi_call(
    "my_kernels.vector_add",
    jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype),
)(a, b)
print(result[:5])  # 1 + 2 everywhere: [3. 3. 3. 3. 3.]
2. Scalar attributes
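Add scalar parameters with `attr.<name>` tokens, placed after the tensors and before the stream. The explicit `:float32` suffix pins the type; the bare `attr.scale_factor` form would infer it from the C++ signature. On the Python side, pass the scalar as a keyword argument with the same name (here `scale_factor=np.float32(3.0)`).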
# CUDA source for section 2: `scale_by` multiplies a tensor by a scalar.
# In its `// @BE` line, `attr.scale_factor:float32` declares that scalar
# between the tensor arguments and the trailing stream.
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include "brainevent/common.h"

// Element-wise out[i] = x[i] * factor; one thread per element.
__global__ void scale_kernel(const float* x, float* out, int n, float factor) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = x[idx] * factor;
}

// @BE scale_by arg ret attr.scale_factor:float32 stream
void scale_by(const BE::Tensor x, BE::Tensor out,
              float scale_factor, int64_t stream) {
    const int n = static_cast<int>(x.numel());
    // Empty input: a zero-block grid is an invalid launch configuration,
    // so return before launching anything.
    if (n == 0) return;
    constexpr int kThreads = 256;
    const int blocks = (n + kThreads - 1) / kThreads;
    scale_kernel<<<blocks, kThreads, 0, (cudaStream_t)stream>>>(
        static_cast<const float*>(x.data_ptr()),
        static_cast<float*>(out.data_ptr()), n, scale_factor);
    BE_CHECK_KERNEL_LAUNCH();
}
"""
mod = brainevent.load_cuda_inline(
    cuda_sources=CUDA_SRC,
    name="scale_ops",
)
# Scalar attributes are passed as keyword arguments; np.float32 matches the
# `:float32` type pinned in the `// @BE` line.
result = jax.ffi.ffi_call(
    "scale_ops.scale_by",
    jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype),
)(a, scale_factor=np.float32(3.0))
print(result[:5])  # `a` is all ones, so: [3. 3. 3. 3. 3.]
3. Compiler options
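`load_cuda_inline` forwards optimization flags to nvcc. `use_fast_math=True` trades floating-point accuracy for speed on FP-heavy kernels. nvcc's fast-math mode uses approximate math intrinsics and flushes denormals to zero. Gains of roughly 10–30% are common but depend on the workload. `optimization_level` maps to `-O<n>` (default 3). These flags are part of the cache key, so changing them triggers a clean recompile.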
# Rebuild the current CUDA_SRC under a new name with tuned nvcc flags.
mod = brainevent.load_cuda_inline(
    name="fast_kernels",
    cuda_sources=CUDA_SRC,
    use_fast_math=True,    # relaxed floating-point accuracy for speed
    optimization_level=3,  # -O3 (the default, spelled out for clarity)
)
4. Loading from a file & jax.jit
Keep kernels in .cu files for real projects, and call them inside jitted functions.
# For real projects, load the kernels from a .cu file instead, e.g.:
# mod = brainevent.load_cuda_file("kernels/my_kernel.cu")


@jax.jit
def add_jit(a, b):
    """Return ``a + b`` computed by the ``my_kernels.vector_add`` FFI target.

    The call is traced and compiled by ``jax.jit``. The output has the same
    shape and dtype as ``a``.
    """
    out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.ffi.ffi_call("my_kernels.vector_add", out_type)(a, b)


print(add_jit(a, b)[:5])  # [3. 3. 3. 3. 3.]
Where to next
The full `arg_spec` grammar, the `BE::Tensor` C++ API, the compiler options, and the caching model are all documented in the Reference section under Custom kernels. Prefer a higher-level path? The Numba, Numba-CUDA, and Warp tutorials need no separate compiler.