Custom C++ (CPU) kernels

This tutorial walks through writing a CPU kernel in plain C++ and calling it from JAX with brainevent. You write a standard C++ function; brainevent generates the XLA FFI wrapper, compiles it, caches the result, and registers it as a JAX FFI target.

Prerequisites: a host C++ compiler (g++ or clang++). No GPU is required.

For a task-focused summary, see the how-to guide "Compile a raw CUDA/C++ kernel"; for the design, see "The custom-kernel architecture".

import jax
import jax.numpy as jnp
import brainevent

1. The simplest kernel

A kernel is just a C++ function that takes BE::Tensor arguments. The rule that drives everything: const BE::Tensor is an input, non-const BE::Tensor is an output. Include "brainevent/common.h" to get the BE::Tensor type; the internal FFI headers are injected for you.

CPP_SRC = r"""
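#include <cstdint>
#include "brainevent/common.h"

// Assumes x and y are float32 tensors with the same number of elements.
void add_one(const BE::Tensor x, BE::Tensor y) {
    const std::int64_t n = x.numel();  // 64-bit count avoids int overflow
    const float* in_ptr = static_cast<const float*>(x.data_ptr());
    float* out_ptr = static_cast<float*>(y.data_ptr());
    for (std::int64_t i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] + 1.0f;
}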
"""

# For CPU functions, pass a *list* of names: brainevent infers the arg_spec
# from the signature (const -> input, non-const -> output).
mod = brainevent.load_cpp_inline(
    name="my_cpu_ops",
    cpp_sources=CPP_SRC,
    functions=["add_one"],
)
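The inference rule is mechanical enough to sketch. The toy function below (infer_arg_spec is a hypothetical name, not part of brainevent, and it handles only tensor parameters) applies the const/non-const rule to a signature string, showing which tokens the list form is expected to produce for add_one.

```python
import re

# Toy illustration, not brainevent's actual parser: map each parameter of a
# C++ signature to an arg_spec token using the const/non-const rule.
def infer_arg_spec(signature: str) -> list:
    params = signature[signature.index("(") + 1 : signature.rindex(")")]
    tokens = []
    for param in params.split(","):
        param = param.strip()
        if re.match(r"const\s+BE::Tensor\b", param):
            tokens.append("arg")  # read-only tensor -> input
        elif re.match(r"BE::Tensor\b", param):
            tokens.append("ret")  # writable tensor -> output
        else:
            raise ValueError(f"toy parser only handles tensors: {param!r}")
    return tokens

print(infer_arg_spec("void add_one(const BE::Tensor x, BE::Tensor y)"))
# ['arg', 'ret']
```

Assuming the real inference follows the same rule, functions=["add_one"] is shorthand for the dict form functions={"add_one": ["arg", "ret"]}.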

Now call it from JAX. The compiled function is registered under "<name>.<function>", here "my_cpu_ops.add_one". jax.ffi.ffi_call needs the output shape and dtype up front, passed as a jax.ShapeDtypeStruct.

cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), cpu)

result = jax.ffi.ffi_call(
    "my_cpu_ops.add_one",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="broadcast_all",
)(x)

print(result)  # [2. 3. 4.]

2. Scalar attributes

Kernels often need scalar parameters (a learning rate, a scale factor). Declare them after the tensors and describe each with an attr.<name>:<type> token, such as attr.factor:float32. Here we use the explicit dict form of the functions argument to spell out the full arg_spec.

The parameter order is fixed: inputs, then outputs, then scalar attributes.

CPP_SRC = r"""
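#include <cstdint>
#include "brainevent/common.h"

// Assumes x and out are float32 tensors with the same number of elements.
void scale_by(const BE::Tensor x, BE::Tensor out, float factor) {
    const std::int64_t n = x.numel();  // 64-bit count avoids int overflow
    const float* in_ptr = static_cast<const float*>(x.data_ptr());
    float* out_ptr = static_cast<float*>(out.data_ptr());
    for (std::int64_t i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] * factor;
}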
"""

mod = brainevent.load_cpp_inline(
    name="scale_ops",
    cpp_sources=CPP_SRC,
    functions={"scale_by": ["arg", "ret", "attr.factor:float32"]},
)
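Because the parameter order is fixed, a misordered arg_spec can be caught before compiling. The helper below (check_arg_spec_order is hypothetical, not a brainevent function) verifies that "arg" tokens come first, then "ret", then attr.* tokens.

```python
# Hypothetical helper, not part of brainevent: validate that an arg_spec
# lists inputs ("arg"), then outputs ("ret"), then attributes ("attr.*").
def check_arg_spec_order(tokens):
    last_rank = 0
    for tok in tokens:
        if tok == "arg":
            rank = 0
        elif tok == "ret":
            rank = 1
        elif tok.startswith("attr."):
            rank = 2
        else:
            raise ValueError(f"unknown token {tok!r}")
        if rank < last_rank:
            raise ValueError(f"{tok!r} appears out of order in {tokens}")
        last_rank = rank
    return True

print(check_arg_spec_order(["arg", "ret", "attr.factor:float32"]))  # True
```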

Scalar attributes are passed as keyword arguments to the callable returned by ffi_call, not to ffi_call itself. Wrapping the value in np.float32 keeps its type in line with the float32 declared in the arg_spec.

import numpy as np

result = jax.ffi.ffi_call(
    "scale_ops.scale_by",
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    vmap_method="broadcast_all",
)(x, factor=np.float32(10.0))

print(result)  # [10. 20. 30.]

3. Loading from a file

For real projects, keep kernels in a .cpp file and load it instead of an inline string. load_cpp_file takes the same keyword arguments.

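# Assumes kernels/my_ops.cpp defines both add_one and scale_by.
# mod = brainevent.load_cpp_file(
#     "kernels/my_ops.cpp",
#     name="my_ops",
#     functions={"add_one": ["arg", "ret"],
#                "scale_by": ["arg", "ret", "attr.factor:float32"]},
# )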

4. Works inside jax.jit

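Registered FFI targets work inside jax.jit like any other primitive, and vmap_method="broadcast_all" lets them run under jax.vmap. Differentiation is not automatic: jax.grad needs a hand-written rule, for example via jax.custom_vjp.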

@jax.jit
def add_one_jit(x):
    return jax.ffi.ffi_call(
        "my_cpu_ops.add_one",
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        vmap_method="broadcast_all",
    )(x)

print(add_one_jit(x))  # [2. 3. 4.]
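Because jax.grad cannot differentiate through an ffi_call on its own, a common pattern is to wrap the call in jax.custom_vjp and supply the backward rule by hand. The sketch below does this for scale_by; so that it runs without a compiler, a plain JAX multiply (_scale_forward, a hypothetical stand-in) takes the place of the ffi_call. factor goes in nondiff_argnums because FFI attributes must stay static Python values.

```python
import functools

import jax
import jax.numpy as jnp

# Stand-in for jax.ffi.ffi_call("scale_ops.scale_by", ...)(x, factor=...)
# so this sketch runs without a compiled kernel.
def _scale_forward(x, factor):
    return x * factor

# factor is non-differentiable: it is a static attribute, not an array.
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_by(x, factor):
    return _scale_forward(x, factor)

def _scale_by_fwd(x, factor):
    return _scale_forward(x, factor), None  # no residuals needed

def _scale_by_bwd(factor, _res, g):
    # d(x * factor)/dx = factor, so the input cotangent is g * factor.
    return (g * factor,)

scale_by.defvjp(_scale_by_fwd, _scale_by_bwd)

grad_x = jax.grad(lambda v: scale_by(v, 10.0).sum())(jnp.array([1.0, 2.0, 3.0]))
print(grad_x)  # [10. 10. 10.]
```

Swapping _scale_forward for the real ffi_call leaves the forward/backward wiring unchanged; the backward rule could itself call a second compiled kernel.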

Where to next

  • The GPU counterpart: Custom CUDA (GPU) kernels (next tutorial).

  • The arg_spec token grammar and the BE::Tensor API are documented in the Reference section.

  • Compiled artifacts are cached by a hash of the source and flags, so re-running this notebook with unchanged sources and flags reuses the cached builds instead of recompiling.
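Content-addressed caching fits in a few lines. The function below is an illustration only (cache_key is hypothetical; brainevent's real key may include more inputs, such as the compiler version): it hashes the source and flags with SHA-256, so identical inputs map to the same key and any change to either produces a new one.

```python
import hashlib

# Illustrative cache key, not brainevent's actual scheme.
def cache_key(source: str, flags: list) -> str:
    h = hashlib.sha256()
    h.update(source.encode("utf-8"))
    for flag in flags:
        h.update(b"\0")  # separator so ["-O2", "-g"] != ["-O2-g"]
        h.update(flag.encode("utf-8"))
    return h.hexdigest()

src = "void f() {}"
print(cache_key(src, ["-O2"]) == cache_key(src, ["-O2"]))  # True
print(cache_key(src, ["-O2"]) == cache_key(src, ["-O3"]))  # False
```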