arg_spec System
Every function compiled by brainevent needs an arg_spec — a
list of tokens that describes the function’s parameter types. This tells
brainevent how to generate the XLA FFI wrapper code.
Token Reference
| Token | Description | C++ Parameter Type |
|---|---|---|
| `"arg"` | Input tensor (read-only) | `const BE::Tensor` |
| `"ret"` | Output tensor (pre-allocated by XLA) | `BE::Tensor` |
| `"stream"` | CUDA stream handle | `int64_t` |
| `"attr.<name>"` | Scalar attribute; type auto-inferred from the C++ signature | Depends on the C++ parameter type (see table below) |
| `"attr.<name>:<type>"` | Scalar attribute; type given explicitly in the token | Depends on `<type>` (see table below) |
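The token grammar above is small enough to model in a few lines. The sketch below splits one canonical token into its kind, attribute name and optional explicit type. `Token` and `parse_token` are hypothetical names for illustration, not part of brainevent's API, and the real parser may accept or reject inputs differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Token:
    kind: str                     # "arg", "ret", "stream" or "attr"
    name: Optional[str] = None    # attribute name (attr tokens only)
    dtype: Optional[str] = None   # explicit ":<type>" suffix, if given


def parse_token(tok: str) -> Token:
    """Hypothetical parser for one canonical arg_spec token."""
    if tok in ("arg", "ret", "stream"):
        return Token(tok)
    if tok.startswith("attr."):
        name, sep, dtype = tok[len("attr."):].partition(":")
        if not name or (sep and not dtype):
            raise ValueError(f"malformed attr token: {tok!r}")
        # No suffix means the type must be inferred from the C++ signature.
        return Token("attr", name, dtype or None)
    raise ValueError(f"unknown token: {tok!r}")


spec = ["arg", "ret", "attr.scale_factor:float32", "attr.offset", "stream"]
for tok in spec:
    print(parse_token(tok))
```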
kernix Compatible Aliases
kernix also accepts tokens from the kernix naming convention. They are transparently normalised before parsing, so both styles can be mixed freely.
| kernix token | Normalised to | Notes |
|---|---|---|
| `"args"` | `"arg"` | Plural alias for a single input tensor |
| `"rets"` | `"ret"` | Plural alias for a single output tensor |
| `"ctx.stream"` | `"stream"` | Alternative CUDA stream token |
| `"attrs.<name>"` | `"attr.<name>"` | Bare attribute (type still inferred from C++ signature) |
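Normalisation amounts to a token-by-token rewrite. A minimal sketch, assuming the four aliases in the table are the only ones and that everything after the `attrs.` prefix is carried over unchanged. `normalise_token` and `normalise_spec` are hypothetical helpers, not brainevent functions.

```python
# Hypothetical alias table mirroring the kernix aliases listed above.
_EXACT_ALIASES = {"args": "arg", "rets": "ret", "ctx.stream": "stream"}


def normalise_token(tok: str) -> str:
    """Map a kernix-style token onto its canonical spelling."""
    if tok in _EXACT_ALIASES:
        return _EXACT_ALIASES[tok]
    if tok.startswith("attrs."):
        # Drop the plural prefix; the attribute name is kept verbatim.
        return "attr." + tok[len("attrs."):]
    return tok  # already canonical (unknown tokens are left for the parser)


def normalise_spec(spec: list) -> list:
    return [normalise_token(t) for t in spec]


# Both styles can be mixed in one spec.
print(normalise_spec(["args", "rets", "attrs.scale_factor", "stream"]))
```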
Attribute Types
The `attr` token supports the types below. With the bare `"attr.<name>"` form, kernix infers the type by parsing the C++ function signature. The explicit `"attr.<name>:<type>"` form is also accepted and takes precedence over the inferred type.
| Type String | C++ Type (impl param) | Python Value | XLA FFI scalar |
|---|---|---|---|
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
|  |  |  | ✓ native |
| `float16` | `uint16_t` (raw bits) | raw bits as `numpy.uint16` | ⚠ via uint16 |
| `bfloat16` | `uint16_t` (raw bits) | raw bits as `numpy.uint16` | ⚠ via uint16 |
Note
float16 / bfloat16 attrs: XLA FFI has no native scalar attribute decoding for 16-bit float types, so kernix maps them to `uint16_t`. The C++ function receives the raw bit pattern and must reinterpret it internally, e.g. `__half h = __ushort_as_half(bits);` (from `<cuda_fp16.h>`) or `__nv_bfloat16 b = __ushort_as_bfloat16(bits);` (from `<cuda_bf16.h>`). At call time, pass the raw bits: `scale=numpy.float16(1.5).view(numpy.uint16)`. For most ML use cases, float32 is preferable.
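To see exactly which bits the C++ side receives, or to build them without NumPy, the standard library's `struct` module can produce the same patterns. This is a sketch: `float16_bits` and `bfloat16_bits` are hypothetical helper names, and the bfloat16 conversion assumes round-to-nearest-even on the float32 encoding and skips NaN handling.

```python
import struct


def float16_bits(x: float) -> int:
    """IEEE 754 half-precision bit pattern of x as an unsigned 16-bit int.

    For finite values in half range this equals
    numpy.float16(x).view(numpy.uint16).
    """
    return struct.unpack("<H", struct.pack("<e", x))[0]


def bfloat16_bits(x: float) -> int:
    """bfloat16 bit pattern: top 16 bits of the float32 encoding,
    rounded to nearest-even (NaN handling omitted for brevity)."""
    (f32,) = struct.unpack("<I", struct.pack("<f", x))
    rounding_bias = 0x7FFF + ((f32 >> 16) & 1)
    return ((f32 + rounding_bias) >> 16) & 0xFFFF


print(hex(float16_bits(1.5)))   # 0x3e00
print(hex(bfloat16_bits(1.5)))  # 0x3fc0
```

At call time, wrap the result in `numpy.uint16(...)` (e.g. `scale=numpy.uint16(float16_bits(1.5))`) so the attribute keeps the 16-bit unsigned type the C++ parameter expects.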
CUDA Header Convention
kernix does not auto-inject `<cuda_runtime.h>` into your kernel source. Always add it explicitly at the top so that `cudaStream_t` and other CUDA runtime types are in scope:
```cpp
#include <cuda_runtime.h>        // ← required for cudaStream_t, etc.
#include "brainevent/common.h"   // BE::Tensor, BE_CUDA_CHECK, ...

// @BE my_kernel arg ret stream
void my_kernel(const BE::Tensor x, BE::Tensor out, int64_t stream) {
    // The stream arrives as an integer handle; convert it back.
    auto s = reinterpret_cast<cudaStream_t>(stream);
    // ...
}
```
Examples
Basic CUDA kernel (two inputs, one output, stream):
```python
functions={"vector_add": ["arg", "arg", "ret", "stream"]}
```
Kernel with a scalar attribute, bare form (type inferred from C++):
```python
functions={"scale_by": ["arg", "ret", "attr.scale_factor", "stream"]}
```
Kernel with a scalar attribute, explicit form:
```python
functions={"scale_by": ["arg", "ret", "attr.scale_factor:float32", "stream"]}
```
Multiple outputs:
```python
functions={"split": ["arg", "ret", "ret", "stream"]}
```
Multiple attributes:
```python
functions={"scale_add": ["arg", "ret", "attr.scale", "attr.offset", "stream"]}
```
kernix style (equivalent to the basic and bare-form examples above):
```python
# After normalisation these are identical to the canonical specs above:
functions={"vector_add": ["args", "args", "rets", "ctx.stream"]}
functions={"scale_by": ["args", "rets", "attrs.scale_factor", "ctx.stream"]}
```
Function Signature Convention
Your C++ function parameters must follow this order:
1. Input tensors (`"arg"` tokens) as `const BE::Tensor`
2. Output tensors (`"ret"` tokens) as `BE::Tensor`
3. Scalar attributes (`"attr.*"` tokens) as the corresponding C++ type
4. CUDA stream (`"stream"` token) as `int64_t`
```cpp
// Matches: ["arg", "ret", "attr.scale", "stream"]
void my_kernel(const BE::Tensor input,
               BE::Tensor output,
               float scale,
               int64_t stream);
```
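Assuming the arg_spec tokens appear in the same order as the C++ parameters (as the `// Matches:` comment above suggests), a spec that breaks this order cannot line up with a conforming signature. The sketch below checks the ordering before compilation. `check_order` is a hypothetical helper; treating the stream token as optional (as in CPU functions) and allowing at most one are assumptions.

```python
# Required position of each token kind: inputs, outputs, attributes, stream.
_RANK = {"arg": 0, "ret": 1, "attr": 2, "stream": 3}


def _kind(tok: str) -> str:
    """Collapse 'attr.<name>[:<type>]' tokens to the single kind 'attr'."""
    return "attr" if tok.startswith("attr.") else tok


def check_order(spec: list) -> None:
    """Hypothetical pre-flight check; raises ValueError on a misordered spec."""
    if spec.count("stream") > 1:
        raise ValueError("at most one 'stream' token is allowed")
    prev_rank = 0
    for tok in spec:
        kind = _kind(tok)
        if kind not in _RANK:
            raise ValueError(f"unknown token: {tok!r}")
        if _RANK[kind] < prev_rank:
            raise ValueError(f"{tok!r} is out of order in {spec}")
        prev_rank = _RANK[kind]


check_order(["arg", "ret", "attr.scale", "stream"])  # OK: matches my_kernel
check_order(["arg", "arg", "ret"])                   # OK: CPU-style, no stream
```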
Warning
`const` on `BE::Tensor` is the only thing that distinguishes an input from an output in auto-detection.
`const BE::Tensor` freezes only the tensor metadata (shape, dtype), not the underlying data buffer. C++ silently allows `static_cast<float*>(param.data_ptr())` even on a `const BE::Tensor`, so there is no compiler warning when an output tensor is accidentally marked `const`.
If every `BE::Tensor` parameter is `const`, auto-detection raises:
```text
KernelError: No non-const Tensor output found in 'my_func'.
Mark input Tensors with 'const' to distinguish inputs from outputs,
or use the explicit dict form: functions={'my_func': ['arg', 'ret', ...]}
```
Rule: remove `const` from every tensor the kernel writes to, regardless of whether C++ requires it.
Attribute Type Inference
When you write "attr.name" without a type suffix, kernix parses the C++
function signature to determine the type automatically. The following C++
parameter types are recognised:
| C++ Parameter Type | Inferred attr type |
|---|---|
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
|  |  |
Leading `const` qualifiers are stripped before lookup. Pointer types, `__half`, `__nv_bfloat16`, and other non-standard types are not auto-inferred; use the explicit `"attr.name:<type>"` form instead.
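The lookup described above can be sketched as follows. The mapping is an assumption for illustration only (common fixed-width integer types plus `bool`, `float` and `double`); the table above is the authoritative list. `infer_attr_type` is a hypothetical name, not a brainevent function.

```python
import re

# ASSUMED mapping, for illustration; consult the table above for the real one.
_CPP_TO_ATTR = {
    "bool": "bool",
    "int8_t": "int8", "int16_t": "int16", "int32_t": "int32", "int64_t": "int64",
    "uint8_t": "uint8", "uint16_t": "uint16", "uint32_t": "uint32", "uint64_t": "uint64",
    "float": "float32", "double": "float64",
}


def infer_attr_type(cpp_type: str) -> str:
    """Hypothetical sketch: map a C++ parameter type to an attr type string."""
    # Strip one or more leading 'const' qualifiers before the lookup.
    base = re.sub(r"^(?:const\s+)+", "", cpp_type.strip())
    if base not in _CPP_TO_ATTR:
        # Pointers, __half, __nv_bfloat16, etc. end up here.
        raise TypeError(f"cannot infer an attr type from {cpp_type!r}; "
                        "use the explicit 'attr.<name>:<type>' form")
    return _CPP_TO_ATTR[base]


print(infer_attr_type("const float"))  # float32
```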
Auto-Detection (CPU)
For CPU functions, you can pass a list of function names instead of a dict. kernix will parse the C++ signatures and infer the full arg_spec automatically:
- `const BE::Tensor` parameters become `"arg"`
- Non-const `BE::Tensor` parameters become `"ret"`
- Scalar parameters become `"attr.<name>:<type>"`
```python
# These are equivalent:
functions=["add_one"]
functions={"add_one": ["arg", "ret"]}
```
```cpp
// Auto-detected as ["arg", "ret"]
void add_one(const BE::Tensor x, BE::Tensor y);
```
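The detection rules above can be sketched as a small signature scanner. This is a simplified model, not kernix's parser: `detect_spec` is hypothetical, it only understands flat `void name(type name, ...)` declarations (no templates, defaults or references), its scalar table is a small assumed subset, and it raises `ValueError` where kernix raises `KernelError`.

```python
import re

# Small ASSUMED subset of scalar types, for illustration only.
_SCALARS = {"bool": "bool", "int32_t": "int32", "int64_t": "int64",
            "float": "float32", "double": "float64"}


def detect_spec(signature: str) -> list:
    """Hypothetical sketch: derive an arg_spec from a simple C++ declaration."""
    m = re.fullmatch(r"\s*void\s+(\w+)\s*\((.*)\)\s*;?\s*", signature, re.S)
    if m is None:
        raise ValueError("expected a declaration like 'void name(params);'")
    func, params = m.groups()
    spec = []
    for param in (p.strip() for p in params.split(",")):
        if not param:
            continue
        ptype, _, pname = param.rpartition(" ")
        ptype = " ".join(ptype.split())          # normalise whitespace
        if ptype == "const BE::Tensor":
            spec.append("arg")                   # read-only input
        elif ptype == "BE::Tensor":
            spec.append("ret")                   # written output
        else:
            base = re.sub(r"^(?:const\s+)+", "", ptype)
            if base not in _SCALARS:
                raise TypeError(f"cannot infer type of {pname!r} in {func!r}")
            spec.append(f"attr.{pname}:{_SCALARS[base]}")
    if "ret" not in spec:
        raise ValueError(f"No non-const Tensor output found in {func!r}.")
    return spec


print(detect_spec("void add_one(const BE::Tensor x, BE::Tensor y);"))
# ['arg', 'ret']
```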
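Passing Attributes at Call Time
Scalar attributes are passed as keyword arguments to the callable returned by `jax.ffi.ffi_call()`, not to `ffi_call()` itself. Pass NumPy scalars (e.g. `np.float32(3.0)`) so that each attribute carries the exact type its C++ parameter expects:
```python
import jax
import numpy as np
# x is a jax.Array input; out_type describes the single "ret" tensor
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
# CORRECT: attributes go to the callable returned by ffi_call()
jax.ffi.ffi_call("target", out_type)(x, scale_factor=np.float32(3.0))
# WRONG -- attributes must NOT go to ffi_call() directly
jax.ffi.ffi_call("target", out_type, scale_factor=...)(x)
```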
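Assuming each `attr` token corresponds to exactly one keyword argument of the same name on the returned callable, a mismatch between the spec and the call site is easy to catch before it reaches XLA. A minimal sketch, assuming canonical (already normalised) tokens; `attr_names` and `check_call_kwargs` are hypothetical helpers, not part of brainevent or JAX.

```python
def attr_names(spec: list) -> list:
    """Names declared by 'attr.<name>[:<type>]' tokens, in spec order."""
    return [tok[len("attr."):].partition(":")[0]
            for tok in spec if tok.startswith("attr.")]


def check_call_kwargs(spec: list, kwargs: dict) -> None:
    """Hypothetical pre-flight check: one keyword argument per attr token."""
    expected, given = set(attr_names(spec)), set(kwargs)
    missing, extra = expected - given, given - expected
    if missing or extra:
        raise TypeError(f"missing attrs {sorted(missing)}, "
                        f"unexpected attrs {sorted(extra)}")


spec = ["arg", "ret", "attr.scale", "attr.offset:float32", "stream"]
# Only the names are checked here; real values would be NumPy scalars.
check_call_kwargs(spec, {"scale": 2.0, "offset": 0.5})
```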