Quick Start#
CUDA (GPU)#
Write a CUDA kernel, compile it, and call it from JAX:
import jax
import jax.numpy as jnp
import brainevent
CUDA_SRC = r"""
#include <cuda_runtime.h>
#include "brainevent/common.h"
__global__ void add_kernel(const float* a, const float* b, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = a[idx] + b[idx];
}
// @BE vector_add arg arg ret stream
void vector_add(const BE::Tensor a, const BE::Tensor b,
BE::Tensor out, int64_t stream) {
int n = a.numel();
add_kernel<<<(n+255)/256, 256, 0, (cudaStream_t)stream>>>(
static_cast<const float*>(a.data_ptr()),
static_cast<const float*>(b.data_ptr()),
static_cast<float*>(out.data_ptr()), n);
}
"""
# Compile and register in one call
mod = brainevent.load_cuda_inline(
name="my_kernels",
cuda_sources=CUDA_SRC,
functions={"vector_add": ["arg", "arg", "ret", "stream"]},
)
# Call from JAX
a = jnp.ones(1024, dtype=jnp.float32)
b = jnp.full(1024, 2.0, dtype=jnp.float32)
result = jax.ffi.ffi_call(
"my_kernels.vector_add",
jax.ShapeDtypeStruct(a.shape, a.dtype),
)(a, b)
print(result) # [3. 3. 3. ... 3.]
The functions dict maps each function name to its arg_spec
token list. brainevent auto-generates the XLA FFI wrapper and registers the
function as "my_kernels.vector_add".
Tip
Instead of a functions dict you can annotate entry points directly in
the CUDA source with // @BE:
// @BE vector_add arg arg ret stream
void vector_add(const BE::Tensor a, const BE::Tensor b,
BE::Tensor out, int64_t stream) { ... }
Then pass functions=None (the default) and brainevent discovers them
automatically.
CPU (C++)#
CPU kernels work the same way but use load_cpp_inline and don’t need CUDA:
import jax
import jax.numpy as jnp
import brainevent
CPP_SRC = r"""
#include "brainevent/common.h"
void add_one(const BE::Tensor x, BE::Tensor y) {
int n = x.numel();
const float* in_ptr = static_cast<const float*>(x.data_ptr());
float* out_ptr = static_cast<float*>(y.data_ptr());
for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] + 1.0f;
}
"""
# Auto-detects arg_spec from C++ signature (const -> arg, non-const -> ret)
mod = brainevent.load_cpp_inline(
name="my_cpu_ops",
cpp_sources=CPP_SRC,
functions=["add_one"], # list form: auto-detect arg_spec
)
cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.array([1.0, 2.0, 3.0]), cpu)
result = jax.ffi.ffi_call(
"my_cpu_ops.add_one",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="broadcast_all",
)(x)
print(result) # [2. 3. 4.]
For CPU functions you can pass a list of function names instead of a dict.
brainevent will parse the C++ signatures automatically: const BE::Tensor
parameters become "arg" tokens and non-const BE::Tensor parameters
become "ret" tokens.
Using @jax.jit#
All registered FFI targets work seamlessly inside @jax.jit:
@jax.jit
def add_jit(x, y):
return jax.ffi.ffi_call(
"my_kernels.vector_add",
jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x, y)
result = add_jit(a, b)
Loading from Files#
Instead of inline source strings, compile directly from files on disk:
# Single file — name defaults to the file stem
mod = brainevent.load_cuda_file("kernels/my_kernel.cu")
# Explicit functions dict if not using // @BE annotations
mod = brainevent.load_cuda_file(
"kernels/my_kernel.cu",
functions={"my_func": ["arg", "ret", "stream"]},
)
# Entire directory (uses ninja for parallel compilation when available)
mod = brainevent.load_cuda_dir(
"kernels/",
functions={"func_a": ["arg", "ret", "stream"],
"func_b": ["arg", "arg", "ret", "stream"]},
)