{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Custom C++ (CPU) kernels\n\nThis tutorial walks through writing a CPU kernel in plain C++ and calling it from JAX with `brainevent`. You write a standard C++ function; `brainevent` generates the XLA FFI wrapper, compiles it, caches the result, and registers it as a JAX FFI target.\n\n**Prerequisites:** a host C++ compiler (`g++` or `clang++`). No GPU is required.\n\nFor a task-focused summary see the how-to guide *Compile a raw CUDA/C++ kernel*; for the design, see *The custom-kernel architecture*."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import jax\nimport jax.numpy as jnp\nimport brainevent"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 1. The simplest kernel\n\nA kernel is just a C++ function that takes `BE::Tensor` arguments. The rule that drives everything: **`const BE::Tensor` is an input, non-`const` `BE::Tensor` is an output**. Include `\"brainevent/common.h\"` to get the `BE::Tensor` type; the internal FFI headers are injected for you."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "CPP_SRC = r\"\"\"\n#include \"brainevent/common.h\"\n\nvoid add_one(const BE::Tensor x, BE::Tensor y) {\n    int n = x.numel();\n    const float* in_ptr = static_cast<const float*>(x.data_ptr());\n    float* out_ptr = static_cast<float*>(y.data_ptr());\n    for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] + 1.0f;\n}\n\"\"\"\n\n# For CPU functions, pass a *list* of names: brainevent infers the arg_spec\n# from the signature (const -> input, non-const -> output).\nmod = brainevent.load_cpp_inline(\n    name=\"my_cpu_ops\",\n    cpp_sources=CPP_SRC,\n    functions=[\"add_one\"],\n)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Now call it from JAX. The compiled function is registered under `\"<name>.<function>\"`, here `\"my_cpu_ops.add_one\"`. `ffi_call` needs the output shape and dtype."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "cpu = jax.devices(\"cpu\")[0]\nx = jax.device_put(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), cpu)\n\nresult = jax.ffi.ffi_call(\n    \"my_cpu_ops.add_one\",\n    jax.ShapeDtypeStruct(x.shape, x.dtype),\n    vmap_method=\"broadcast_all\",\n)(x)\n\nprint(result)  # [2. 3. 4.]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 2. Scalar attributes\n\nKernels often need scalar parameters (a learning rate, a scale factor). Declare them after the tensors and describe them with an `attr.<name>` token. Here we use the explicit dict form of `functions` to spell out the arg_spec.\n\nThe parameter order is fixed: **inputs, then outputs, then scalar attributes**."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "CPP_SRC = r\"\"\"\n#include \"brainevent/common.h\"\n\nvoid scale_by(const BE::Tensor x, BE::Tensor out, float factor) {\n    int n = x.numel();\n    const float* in_ptr = static_cast<const float*>(x.data_ptr());\n    float* out_ptr = static_cast<float*>(out.data_ptr());\n    for (int i = 0; i < n; ++i) out_ptr[i] = in_ptr[i] * factor;\n}\n\"\"\"\n\nmod = brainevent.load_cpp_inline(\n    name=\"scale_ops\",\n    cpp_sources=CPP_SRC,\n    functions={\"scale_by\": [\"arg\", \"ret\", \"attr.factor:float32\"]},\n)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "Scalar attributes are passed as **keyword arguments to the callable returned by** `ffi_call` -- not to `ffi_call` itself."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import numpy as np\n\nresult = jax.ffi.ffi_call(\n    \"scale_ops.scale_by\",\n    jax.ShapeDtypeStruct(x.shape, x.dtype),\n    vmap_method=\"broadcast_all\",\n)(x, factor=np.float32(10.0))\n\nprint(result)  # [10. 20. 30.]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 3. Loading from a file\n\nFor real projects, keep kernels in a `.cpp` file and load it instead of an inline string. `load_cpp_file` takes the same keyword arguments."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# mod = brainevent.load_cpp_file(\n#     \"kernels/my_ops.cpp\",\n#     functions=[\"add_one\", \"scale_by\"],\n# )"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 4. Works inside `jax.jit`\n\nRegistered FFI targets compose with JAX transformations like any other primitive."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "@jax.jit\ndef add_one_jit(x):\n    return jax.ffi.ffi_call(\n        \"my_cpu_ops.add_one\",\n        jax.ShapeDtypeStruct(x.shape, x.dtype),\n        vmap_method=\"broadcast_all\",\n    )(x)\n\nprint(add_one_jit(x))"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Where to next\n\n- The GPU counterpart: *Custom CUDA (GPU) kernels* (next tutorial).\n- `arg_spec` token grammar and the `BE::Tensor` API live in the Reference section.\n- Compiled artifacts are cached by a hash of the source and flags, so re-running this notebook recompiles nothing."
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
