{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Custom CUDA (GPU) kernels\n\nThis tutorial mirrors the C++ tutorial but targets the GPU. You write a CUDA kernel plus a thin C++ launcher; `brainevent` compiles it with `nvcc`, generates the XLA FFI wrapper, caches the build, and registers it as a JAX FFI target.\n\n**Prerequisites:** an NVIDIA GPU + driver, `jax[cuda12]`/`jax[cuda13]` (which bundles `nvcc`), and a host C++ compiler. See *Installation*."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport brainevent"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 1. A vector-add kernel\n\nA CUDA entry point takes the tensors plus an `int64_t stream`. Two differences from the CPU case:\n\n- Always `#include <cuda_runtime.h>` yourself (it is **not** auto-injected).\n- The trailing `int64_t stream` argument carries the CUDA stream; cast it to `cudaStream_t` for the launch.\n\nThe `// @BE` comment annotates the entry point with its arg_spec, so `brainevent` discovers it automatically -- no `functions` dict needed."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "CUDA_SRC = r\"\"\"\n#include <cuda_runtime.h>\n#include \"brainevent/common.h\"\n\n__global__ void add_kernel(const float* a, const float* b, float* out, int n) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) out[idx] = a[idx] + b[idx];\n}\n\n// @BE vector_add arg arg ret stream\nvoid vector_add(const BE::Tensor a, const BE::Tensor b,\n                BE::Tensor out, int64_t stream) {\n    int n = a.numel();\n    add_kernel<<<(n + 255) / 256, 256, 0, (cudaStream_t)stream>>>(\n        static_cast<const float*>(a.data_ptr()),\n        static_cast<const float*>(b.data_ptr()),\n        static_cast<float*>(out.data_ptr()), n);\n}\n\"\"\"\n\nmod = brainevent.load_cuda_inline(\n    name=\"my_kernels\",\n    cuda_sources=CUDA_SRC,\n    functions=None,  # discovered from the // @BE annotation\n)"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "a = jnp.ones(1024, dtype=jnp.float32)\nb = jnp.full(1024, 2.0, dtype=jnp.float32)\n\nresult = jax.ffi.ffi_call(\n    \"my_kernels.vector_add\",\n    jax.ShapeDtypeStruct(a.shape, a.dtype),\n)(a, b)\n\nprint(result[:5])  # [3. 3. 3. 3. 3.]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 2. Scalar attributes\n\nAdd scalar parameters with `attr.<name>` tokens, after the tensors and before the stream. The explicit `:float32` suffix pins the type; the bare `attr.scale_factor` form would infer it from the C++ signature."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "CUDA_SRC = r\"\"\"\n#include <cuda_runtime.h>\n#include \"brainevent/common.h\"\n\n__global__ void scale_kernel(const float* x, float* out, int n, float factor) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) out[idx] = x[idx] * factor;\n}\n\n// @BE scale_by arg ret attr.scale_factor:float32 stream\nvoid scale_by(const BE::Tensor x, BE::Tensor out,\n              float scale_factor, int64_t stream) {\n    int n = x.numel();\n    scale_kernel<<<(n + 255) / 256, 256, 0, (cudaStream_t)stream>>>(\n        static_cast<const float*>(x.data_ptr()),\n        static_cast<float*>(out.data_ptr()), n, scale_factor);\n    BE_CHECK_KERNEL_LAUNCH();\n}\n\"\"\"\n\nmod = brainevent.load_cuda_inline(name=\"scale_ops\", cuda_sources=CUDA_SRC)"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "result = jax.ffi.ffi_call(\n    \"scale_ops.scale_by\",\n    jax.ShapeDtypeStruct(a.shape, a.dtype),\n)(a, scale_factor=np.float32(3.0))\n\nprint(result[:5])  # [3. 3. 3. 3. 3.]"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 3. Compiler options\n\n`load_cuda_inline` forwards optimization flags to `nvcc`. `use_fast_math=True` trades a few ULPs of precision for ~10-30% on FP-heavy kernels; `optimization_level` maps to `-O<n>` (default 3). These flags are part of the cache key, so changing them triggers a clean recompile."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "mod = brainevent.load_cuda_inline(\n    name=\"fast_kernels\",\n    cuda_sources=CUDA_SRC,\n    optimization_level=3,\n    use_fast_math=True,\n)"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 4. Loading from a file & `jax.jit`\n\nKeep kernels in `.cu` files for real projects, and call them inside jitted functions."
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": "# mod = brainevent.load_cuda_file(\"kernels/my_kernel.cu\")\n\n@jax.jit\ndef add_jit(a, b):\n    return jax.ffi.ffi_call(\n        \"my_kernels.vector_add\",\n        jax.ShapeDtypeStruct(a.shape, a.dtype),\n    )(a, b)\n\nprint(add_jit(a, b)[:5])"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## Where to next\n\n- The full `arg_spec` grammar, the `BE::Tensor` C++ API, the compiler options, and the caching model are all in the Reference section under *Custom kernels*.\n- Prefer a higher-level path? The Numba, Numba-CUDA, and Warp tutorials need no separate compiler."
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
