Custom Operators with Raw CUDA/C++

Seamless bridge between native C++/CUDA kernels and JAX via XLA FFI.

Write kernels in plain C++ (CPU) or CUDA (GPU), and call them from JAX with zero boilerplate. brainevent handles compilation, XLA FFI wrapper generation, caching, and registration automatically.

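import jax
import brainevent

# CUDA_SRC is a string holding the CUDA source that defines vector_add.
mod = brainevent.load_cuda_inline(
    name="my_kernels",
    cuda_sources=CUDA_SRC,
    functions={"vector_add": ["arg", "arg", "ret", "stream"]},
)

# Call from JAX. out_spec describes the output's shape and dtype
# (for example, a jax.ShapeDtypeStruct); a and b are the input arrays.
result = jax.ffi.ffi_call("my_kernels.vector_add", out_spec)(a, b)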
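
The call above takes inputs a and b and an output specification out_spec without showing how they are built. A minimal sketch of one plausible setup follows, assuming the kernel adds two float32 vectors of equal length and writes one output of the same shape. The length n and the sample values are illustrative assumptions, not taken from brainevent.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs (assumed): two float32 vectors of the same length.
n = 1024
a = jnp.arange(n, dtype=jnp.float32)
b = jnp.ones(n, dtype=jnp.float32)

# The single "ret" slot in the functions mapping corresponds to one output.
# jax.ffi.ffi_call needs that output's shape and dtype up front, because
# XLA allocates the result buffer before the kernel runs.
out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
```

With these definitions in scope, the ffi_call line can be used unchanged. For an element-wise kernel the output spec simply mirrors the inputs.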

Key Features

  • Zero boilerplate — write standard CUDA/C++ and call it from JAX

  • Automatic FFI wrapper generation — no manual XLA FFI binding code

  • Multi-platform — CUDA (GPU) and C++ (CPU)

  • Smart caching — SHA-256-based compilation cache, survives process restarts

  • Ninja parallel builds — fast multi-file compilation when ninja is available

  • Auto-registration — compiled functions are automatically registered as JAX FFI targets; re-importing the same module is a no-op
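
To illustrate the idea behind a SHA-256-based compilation cache, here is a minimal sketch of how a cache key could be derived from the build inputs. It is not brainevent's actual implementation: the function name cache_key and its parameters (sources, flags, toolchain) are hypothetical, chosen only for this example.

```python
import hashlib


def cache_key(sources, flags=(), toolchain=""):
    """Return a hex SHA-256 digest identifying one build configuration.

    Hypothetical helper: hashes the source strings, the compiler flags,
    and a toolchain version string.
    """
    h = hashlib.sha256()
    for group in (tuple(sources), tuple(flags), (toolchain,)):
        # Prefix each group with its item count and each item with its
        # byte length, so different splits of the same text cannot collide.
        h.update(len(group).to_bytes(8, "little"))
        for item in group:
            data = item.encode("utf-8")
            h.update(len(data).to_bytes(8, "little"))
            h.update(data)
    return h.hexdigest()


src = "__global__ void f() {}"
k1 = cache_key([src], ["-O3"])
k2 = cache_key([src], ["-O3"])  # identical inputs give the same key
k3 = cache_key([src], ["-O2"])  # a changed flag gives a different key
```

Because the digest depends only on the input bytes, the same inputs produce the same key in a later process. That is what lets a cache directory keyed this way survive restarts, unlike Python's built-in hash(), which is salted per process for strings.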