Custom C++ (CPU) kernels
This tutorial walks through writing a CPU kernel in plain C++ and calling it from JAX with brainevent. You write a standard C++ function; brainevent generates the XLA FFI wrapper, compiles it, caches the result, and registers it as a JAX FFI target.
Prerequisites: a host C++ compiler (g++ or clang++). No GPU is required.
For a task-focused summary see the how-to guide Compile a raw CUDA/C++ kernel; for the design, see The custom-kernel architecture.
import jax
import jax.numpy as jnp
import brainevent
1. The simplest kernel
A kernel is just a C++ function that takes BE::Tensor arguments. The rule that drives everything: const BE::Tensor is an input, non-const BE::Tensor is an output. Include "brainevent/common.h" to get the BE::Tensor type; the internal FFI headers are injected for you.
CPP_SRC = r"""
#include "brainevent/common.h"
void add_one(const BE::Tensor x, BE::Tensor y) {
    int n = x.numel();
    const float* in_ptr = static_cast<const float*>(x.data_ptr());
    float* out_ptr = static_cast<float*>(y.data_ptr());
    for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] + 1.0f;
}
"""
# For CPU functions, pass a *list* of names: brainevent infers the arg_spec
# from the signature (const -> input, non-const -> output).
mod = brainevent.load_cpp_inline(
    name="my_cpu_ops",
    cpp_sources=CPP_SRC,
    functions=["add_one"],
)
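To make the const/non-const rule concrete, the sketch below derives an arg_spec from a C++ signature string. It is an illustration only: `infer_arg_spec` is a hypothetical helper, and brainevent's real parser may work differently and handle more cases (references, scalar attributes, comments).

```python
import re


def infer_arg_spec(signature: str) -> list[str]:
    """Map each BE::Tensor parameter to "arg" (const) or "ret" (non-const).

    Hypothetical helper for illustration; not part of brainevent.
    """
    # Grab the text between the outermost parentheses of the signature.
    match = re.search(r"\((.*)\)", signature, flags=re.DOTALL)
    if match is None:
        raise ValueError(f"no parameter list found in: {signature!r}")
    spec = []
    for param in match.group(1).split(","):
        tokens = param.split()
        if "BE::Tensor" not in tokens:
            raise ValueError(f"unsupported parameter: {param.strip()!r}")
        # const tensors are read-only inputs; everything else is an output.
        spec.append("arg" if "const" in tokens else "ret")
    return spec


print(infer_arg_spec("void add_one(const BE::Tensor x, BE::Tensor y)"))
# ['arg', 'ret']
```

The result matches the explicit form `["arg", "ret"]` that the dict variant of `functions` would spell out by hand.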
Now call it from JAX. The compiled function is registered under "<name>.<function>", here "my_cpu_ops.add_one". jax.ffi.ffi_call needs the output shape and dtype, passed here as a jax.ShapeDtypeStruct.
cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), cpu)
result = jax.ffi.ffi_call(
    "my_cpu_ops.add_one",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="broadcast_all",
)(x)
print(result) # [2. 3. 4.]
2. Scalar attributes
Kernels often need scalar parameters (a learning rate, a scale factor). Declare them after the tensors in the C++ signature and describe each one with an attr.<name>:<type> token, such as attr.factor:float32. Here we use the explicit dict form of functions to spell out the arg_spec.
The parameter order is fixed: inputs, then outputs, then scalar attributes.
CPP_SRC = r"""
#include "brainevent/common.h"
void scale_by(const BE::Tensor x, BE::Tensor out, float factor) {
    int n = x.numel();
    const float* in_ptr = static_cast<const float*>(x.data_ptr());
    float* out_ptr = static_cast<float*>(out.data_ptr());
    for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] * factor;
}
"""
mod = brainevent.load_cpp_inline(
    name="scale_ops",
    cpp_sources=CPP_SRC,
    functions={"scale_by": ["arg", "ret", "attr.factor:float32"]},
)
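The ordering rule can be checked mechanically before compiling. The sketch below uses a hypothetical helper, `check_arg_spec_order`, that accepts only the token shapes shown in this tutorial ("arg", "ret", and "attr.<name>:<type>"); the full grammar in the Reference section may allow more.

```python
def check_arg_spec_order(arg_spec: list[str]) -> bool:
    """Return True if tokens appear as inputs, then outputs, then attributes.

    Hypothetical helper for illustration; not part of brainevent.
    """
    rank = {"arg": 0, "ret": 1, "attr": 2}
    previous = 0
    for token in arg_spec:
        kind = token.split(".", 1)[0]  # "attr.factor:float32" -> "attr"
        if kind not in rank:
            raise ValueError(f"unknown arg_spec token: {token!r}")
        if rank[kind] < previous:
            return False  # e.g. an input listed after an output
        previous = rank[kind]
    return True


print(check_arg_spec_order(["arg", "ret", "attr.factor:float32"]))  # True
print(check_arg_spec_order(["ret", "arg", "attr.factor:float32"]))  # False
```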
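Scalar attributes are passed as keyword arguments to the callable returned by ffi_call, not to ffi_call itself. Here the value is wrapped in np.float32 so that it matches the float32 type declared in the arg_spec.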
import numpy as np
result = jax.ffi.ffi_call(
    "scale_ops.scale_by",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="broadcast_all",
)(x, factor=np.float32(10.0))
print(result) # [10. 20. 30.]
3. Loading from a file
For real projects, keep kernels in a .cpp file and load it instead of an inline string. load_cpp_file takes the same keyword arguments.
# mod = brainevent.load_cpp_file(
#     "kernels/my_ops.cpp",
#     functions=["add_one", "scale_by"],
# )
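To try the file-based route, the kernel source has to exist on disk first. The following is a minimal sketch using only the standard library. The `write_kernel` helper is hypothetical, the `kernels/my_ops.cpp` path mirrors the commented call above, and the temporary directory is an assumption made so the example leaves no files behind.

```python
import tempfile
from pathlib import Path

CPP_SRC = r"""
#include "brainevent/common.h"

void add_one(const BE::Tensor x, BE::Tensor y) {
    int n = x.numel();
    const float* in_ptr = static_cast<const float*>(x.data_ptr());
    float* out_ptr = static_cast<float*>(y.data_ptr());
    for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] + 1.0f;
}
"""


def write_kernel(root: Path, relative_path: str, source: str) -> Path:
    """Write kernel source under root, creating parent directories."""
    path = root / relative_path
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(source.lstrip(), encoding="utf-8")
    return path


with tempfile.TemporaryDirectory() as tmp:
    kernel_path = write_kernel(Path(tmp), "kernels/my_ops.cpp", CPP_SRC)
    print(kernel_path.name)  # my_ops.cpp
    # The path could then be handed to the loader, as in the call above:
    # mod = brainevent.load_cpp_file(str(kernel_path), functions=["add_one"])
```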
4. Works inside jax.jit
Registered FFI targets compose with JAX transformations like any other primitive.
@jax.jit
def add_one_jit(x):
    return jax.ffi.ffi_call(
        "my_cpu_ops.add_one",
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        vmap_method="broadcast_all",
    )(x)
print(add_one_jit(x))
Where to next
The GPU counterpart: Custom CUDA (GPU) kernels (next tutorial).
The arg_spec token grammar and the BE::Tensor API live in the Reference section.
Compiled artifacts are cached by a hash of the source and flags, so re-running this notebook recompiles nothing.
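The caching behaviour can be pictured as a content-addressed lookup: the same source and flags always map to the same key, so a previously built artifact can be reused. A sketch with `hashlib` follows. The `cache_key` helper, the choice of SHA-256, and the example flags are assumptions, not brainevent's actual scheme.

```python
import hashlib


def cache_key(source: str, flags: list[str]) -> str:
    """Derive a deterministic key from kernel source and compiler flags.

    Hypothetical helper; brainevent's real hashing scheme may differ.
    """
    digest = hashlib.sha256()
    digest.update(source.encode("utf-8"))
    for flag in flags:
        # A separator keeps ["-O3", "-g"] distinct from ["-O3-g"].
        digest.update(b"\0" + flag.encode("utf-8"))
    return digest.hexdigest()


src = "void add_one(const BE::Tensor x, BE::Tensor y) {}"
same = cache_key(src, ["-O3"]) == cache_key(src, ["-O3"])
changed = cache_key(src, ["-O3"]) == cache_key(src + " ", ["-O3"])
print(same, changed)  # True False
```

Any edit to the source or the flags produces a different key, which is why an unchanged notebook can skip compilation entirely.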