{
 "nbformat": 4,
 "nbformat_minor": 4,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom GPU Operators with Numba CUDA\n",
    "\n",
    "This tutorial shows how to write custom GPU kernels using **Numba CUDA** and integrate them\n",
    "into the BrainEvent / JAX ecosystem.\n",
    "\n",
    "[Numba](https://numba.readthedocs.io/) is a JIT compiler for Python that targets CUDA GPUs\n",
    "via its `numba.cuda` subpackage. Kernels are written in Python, compiled at first call,\n",
    "and run natively on the GPU. BrainEvent provides two functions that bridge Numba CUDA kernels\n",
    "into JAX via XLA's Foreign Function Interface (FFI):\n",
    "\n",
    "- **`numba_cuda_kernel`** – wraps a single `@cuda.jit` kernel with fixed launch configuration.\n",
    "- **`numba_cuda_callable`** – wraps an arbitrary Python function that may launch *multiple*\n",
    "  CUDA kernels, allocate temporary memory, and orchestrate multi-step GPU computation.\n",
    "\n",
    "## Contents\n",
    "1. Why Numba CUDA?\n",
    "2. Installation and Imports\n",
    "3. Writing Numba CUDA Kernels (`@cuda.jit`)\n",
    "4. `numba_cuda_kernel` – Single-Kernel Wrapper\n",
    "5. Launch Configuration: `grid` / `block` vs. `launch_dims`\n",
    "6. `numba_cuda_callable` – Multi-Kernel Wrapper\n",
    "7. Registering with `XLACustomKernel`\n",
    "8. Neuroscience Example: Parallel Spike Threshold Detection\n",
    "9. Performance Tips\n",
    "10. Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Why Numba CUDA?\n",
    "\n",
    "| Feature | Numba CUDA | Warp | Raw CUDA C++ |\n",
    "|---------|------------|------|--------------|\n",
    "| Language | Python (`@cuda.jit`) | Python-like | C++ |\n",
    "| Low-level control | Full thread/block/shared-mem | Partial | Full |\n",
    "| Shared memory | Yes | Yes | Yes |\n",
    "| Atomic operations | Yes | Yes | Yes |\n",
    "| Device-side allocation | Yes (`cuda.device_array`) | Limited | Yes |\n",
    "| Multi-kernel orchestration | Yes (via `numba_cuda_callable`) | No | Yes |\n",
    "\n",
    "Choose Numba CUDA when you need:\n",
    "- Fine-grained control over shared memory or thread synchronization\n",
    "- Multi-kernel pipelines with temporary device allocations\n",
    "- Familiar CUDA programming model in Python\n",
    "\n",
    "**Requirements:**\n",
    "- NVIDIA GPU with CUDA\n",
    "- `pip install numba` + CUDA toolkit\n",
    "- JAX with GPU support (`pip install jax[cuda12]`)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Installation and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install if needed:\n",
    "# !pip install numba -U\n",
    "# !pip install brainevent[cuda12] -U\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import brainevent\n",
    "from brainevent import XLACustomKernel, numba_cuda_kernel, numba_cuda_callable\n",
    "\n",
    "print(f\"JAX version    : {jax.__version__}\")\n",
    "print(f\"JAX backend    : {jax.default_backend()}\")\n",
    "print(f\"BrainEvent     : {brainevent.__version__}\")\n",
    "\n",
    "try:\n",
    "    from numba import cuda\n",
    "    NUMBA_CUDA_AVAILABLE = cuda.is_available()\n",
    "    if NUMBA_CUDA_AVAILABLE:\n",
    "        import numba\n",
    "        print(f\"Numba version  : {numba.__version__}\")\n",
    "        print(f\"CUDA available : {NUMBA_CUDA_AVAILABLE}\")\n",
    "        gpu = cuda.get_current_device()\n",
    "        print(f\"GPU            : {gpu.name}\")\n",
    "    else:\n",
    "        print(\"Numba installed but CUDA device not found.\")\n",
    "except ImportError:\n",
    "    print(\"Numba not installed. Run: pip install numba\")\n",
    "    NUMBA_CUDA_AVAILABLE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Writing Numba CUDA Kernels (`@cuda.jit`)\n",
    "\n",
    "Numba CUDA kernels follow standard CUDA programming conventions:\n",
    "- Decorated with `@cuda.jit`\n",
    "- Each thread identifies itself via `cuda.grid(ndim)` (equivalent to `blockIdx * blockDim + threadIdx`)\n",
    "- Arrays received as Numba device arrays (zero-copy from JAX GPU memory)\n",
    "- Results are written **in-place** into output arrays (no return values)\n",
    "\n",
    "### 3.1 Basic Element-wise Kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    @cuda.jit\n",
    "    def elementwise_relu_kernel(x, out):\n",
    "        \"\"\"Element-wise ReLU: out[i] = max(x[i], 0).\"\"\"\n",
    "        i = cuda.grid(1)          # global thread index\n",
    "        if i < out.size:          # bounds check\n",
    "            out[i] = max(x[i], 0.0)\n",
    "\n",
    "    @cuda.jit\n",
    "    def elementwise_sigmoid_kernel(x, out):\n",
    "        \"\"\"Element-wise sigmoid: out[i] = 1 / (1 + exp(-x[i])).\"\"\"\n",
    "        import math\n",
    "        i = cuda.grid(1)\n",
    "        if i < out.size:\n",
    "            out[i] = 1.0 / (1.0 + math.exp(-x[i]))\n",
    "\n",
    "    print(\"Kernels defined:\", elementwise_relu_kernel, elementwise_sigmoid_kernel)"
   ]
  },
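  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index arithmetic behind `cuda.grid(1)` can be checked on the CPU. The following is a minimal\n",
    "pure-Python sketch, not Numba code: `global_indices` is a hypothetical helper that enumerates\n",
    "`blockIdx.x * blockDim.x + threadIdx.x` for a 1-D launch. It illustrates why the kernels above\n",
    "guard their writes with a bounds check when the array size is not a multiple of the block size.\n",
    "\n",
    "```python\n",
    "def global_indices(grid, block):\n",
    "    # Hypothetical helper: the index each thread would obtain from cuda.grid(1)\n",
    "    return [b * block + t for b in range(grid) for t in range(block)]\n",
    "\n",
    "n = 1000                           # array size, not a multiple of the block size\n",
    "block = 256\n",
    "grid = (n + block - 1) // block    # ceil(n / block) = 4 blocks\n",
    "idx = global_indices(grid, block)\n",
    "\n",
    "in_bounds = [i for i in idx if i < n]   # threads that pass the bounds check\n",
    "print(len(idx), len(in_bounds))\n",
    "```\n",
    "\n",
    "With these numbers, 1024 threads are launched for 1000 elements, so 24 threads must skip the write."
   ]
  },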
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Kernel with Shared Memory\n",
    "\n",
    "Shared memory is fast, on-chip memory shared between threads in the same block.\n",
    "It is useful for reduction operations or when threads need to communicate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    BLOCK_SIZE = 256\n",
    "\n",
    "    @cuda.jit\n",
    "    def block_sum_kernel(x, block_sums):\n",
    "        \"\"\"\n",
    "        Computes the sum of each block's elements using shared memory reduction.\n",
    "        block_sums[blockIdx.x] = sum of x elements processed by block blockIdx.x.\n",
    "        \"\"\"\n",
    "        shared = cuda.shared.array(BLOCK_SIZE, dtype=numba.float32)\n",
    "\n",
    "        tx  = cuda.threadIdx.x\n",
    "        pos = cuda.grid(1)\n",
    "\n",
    "        # Load into shared memory\n",
    "        shared[tx] = x[pos] if pos < x.size else 0.0\n",
    "        cuda.syncthreads()\n",
    "\n",
    "        # Parallel reduction in shared memory\n",
    "        stride = BLOCK_SIZE // 2\n",
    "        while stride > 0:\n",
    "            if tx < stride:\n",
    "                shared[tx] += shared[tx + stride]\n",
    "            cuda.syncthreads()\n",
    "            stride //= 2\n",
    "\n",
    "        # Thread 0 writes the block result\n",
    "        if tx == 0:\n",
    "            block_sums[cuda.blockIdx.x] = shared[0]\n",
    "\n",
    "    print(\"Shared memory kernel defined.\")"
   ]
  },
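  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the reduction loop computes, the same stride pattern can be emulated with NumPy.\n",
    "This is a CPU-only sketch under the assumption that the kernel is launched with one thread per\n",
    "shared-memory slot (block size equal to `BLOCK_SIZE`, a power of two). `block_sum_reference` is a\n",
    "hypothetical name, not part of Numba or BrainEvent.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def block_sum_reference(x, block_size=256):\n",
    "    # CPU emulation of the per-block tree reduction (block_size must be a power of two)\n",
    "    n_blocks = (x.size + block_size - 1) // block_size\n",
    "    padded = np.zeros(n_blocks * block_size, dtype=np.float32)\n",
    "    padded[:x.size] = x                      # out-of-range threads load 0.0\n",
    "    shared = padded.reshape(n_blocks, block_size).copy()\n",
    "    stride = block_size // 2\n",
    "    while stride > 0:\n",
    "        # every 'thread' tx < stride adds the element stride positions away\n",
    "        shared[:, :stride] += shared[:, stride:2 * stride]\n",
    "        stride //= 2\n",
    "    return shared[:, 0]                      # the value thread 0 of each block writes\n",
    "\n",
    "x = np.arange(1000, dtype=np.float32)\n",
    "sums = block_sum_reference(x)\n",
    "print(sums.shape, float(sums.sum()))\n",
    "```\n",
    "\n",
    "Each entry of `sums` is a partial result; a second pass (or a host-side sum) is still needed to\n",
    "obtain the total over the whole array."
   ]
  },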
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. `numba_cuda_kernel` – Single-Kernel Wrapper\n",
    "\n",
    "`numba_cuda_kernel` wraps a single `@cuda.jit` kernel so it can be called with\n",
    "JAX GPU arrays.  The kernel receives Numba CUDA device arrays (zero-copy from JAX\n",
    "device memory) and writes results into the output buffer.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "numba_cuda_kernel(\n",
    "    kernel,                    # @cuda.jit decorated function\n",
    "    outs,                      # jax.ShapeDtypeStruct or list thereof\n",
    "    *,\n",
    "    grid=None, block=None,     # explicit CUDA launch config\n",
    "    launch_dims=None,          # OR total threads (auto grid/block)\n",
    "    threads_per_block=256,     # only used with launch_dims\n",
    "    shared_mem=0,              # dynamic shared memory bytes\n",
    ") -> callable\n",
    "```\n",
    "\n",
    "The kernel function signature must be:\n",
    "```python\n",
    "kernel(input1, input2, ..., output1, output2, ...)\n",
    "```\n",
    "Inputs first, then outputs – all as Numba device arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    N = 1024\n",
    "    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)\n",
    "\n",
    "    # Wrap the ReLU kernel\n",
    "    relu_fn = numba_cuda_kernel(\n",
    "        elementwise_relu_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "        grid=4,\n",
    "        block=256,\n",
    "    )\n",
    "\n",
    "    result  = relu_fn(x)\n",
    "    expected = jnp.maximum(x, 0.0)\n",
    "    print(\"ReLU max error:\", float(jnp.max(jnp.abs(result - expected))))\n",
    "\n",
    "    # Wrap the sigmoid kernel using launch_dims (auto grid/block)\n",
    "    sigmoid_fn = numba_cuda_kernel(\n",
    "        elementwise_sigmoid_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "        launch_dims=N,           # launch exactly N threads\n",
    "        threads_per_block=128,\n",
    "    )\n",
    "\n",
    "    result_sig  = sigmoid_fn(x)\n",
    "    expected_sig = 1.0 / (1.0 + jnp.exp(-x))\n",
    "    print(\"Sigmoid max error:\", float(jnp.max(jnp.abs(result_sig - expected_sig))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Multiple Outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    @cuda.jit\n",
    "    def split_kernel(x, pos_out, neg_out):\n",
    "        \"\"\"Split positive and negative parts of x.\"\"\"\n",
    "        i = cuda.grid(1)\n",
    "        if i < x.size:\n",
    "            v = x[i]\n",
    "            pos_out[i] = max(v, 0.0)\n",
    "            neg_out[i] = min(v, 0.0)\n",
    "\n",
    "    N = 512\n",
    "    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)\n",
    "\n",
    "    split_fn = numba_cuda_kernel(\n",
    "        split_kernel,\n",
    "        outs=[\n",
    "            jax.ShapeDtypeStruct((N,), jnp.float32),  # pos_out\n",
    "            jax.ShapeDtypeStruct((N,), jnp.float32),  # neg_out\n",
    "        ],\n",
    "        launch_dims=N,\n",
    "    )\n",
    "\n",
    "    pos, neg = split_fn(x)\n",
    "\n",
    "    print(\"pos_out[:5]:\", pos[:5])   # max(x, 0)\n",
    "    print(\"neg_out[:5]:\", neg[:5])   # min(x, 0)\n",
    "    print(\"pos + neg == x:\", bool(jnp.allclose(pos + neg, x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 JIT Compatibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    @cuda.jit\n",
    "    def add_kernel(x, y, out):\n",
    "        i = cuda.grid(1)\n",
    "        if i < out.size:\n",
    "            out[i] = x[i] + y[i]\n",
    "\n",
    "    N = 256\n",
    "    add_fn = numba_cuda_kernel(\n",
    "        add_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "        launch_dims=N,\n",
    "    )\n",
    "\n",
    "    @jax.jit\n",
    "    def jitted_add(a, b):\n",
    "        return add_fn(a, b)\n",
    "\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.ones(N, dtype=jnp.float32) * 2.0\n",
    "\n",
    "    r = jitted_add(a, b)\n",
    "    print(\"JIT add max error:\", float(jnp.max(jnp.abs(r - (a + b)))))\n",
    "\n",
    "    # Call multiple times (JIT is amortized after first call)\n",
    "    for _ in range(5):\n",
    "        r = jitted_add(a, b)\n",
    "    print(\"Multiple JIT calls OK:\", bool(jnp.allclose(r, a + b)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Launch Configuration: `grid` / `block` vs. `launch_dims`\n",
    "\n",
    "CUDA kernels need a grid/block decomposition specifying how many threads to launch.\n",
    "\n",
    "**Option A – explicit `grid` and `block`:**\n",
    "```python\n",
    "numba_cuda_kernel(kernel, outs=..., grid=8, block=128)\n",
    "# launches 8 blocks × 128 threads = 1024 total threads\n",
    "```\n",
    "\n",
    "**Option B – `launch_dims` (auto-compute):**\n",
    "```python\n",
    "numba_cuda_kernel(kernel, outs=..., launch_dims=1024, threads_per_block=256)\n",
    "# auto: block=256, grid=ceil(1024/256)=4\n",
    "```\n",
    "\n",
    "**2D / 3D launches:**\n",
    "```python\n",
    "# 2D: launch M×N threads\n",
    "numba_cuda_kernel(kernel, outs=..., launch_dims=(M, N))\n",
    "# auto: block=(16,16), grid=(ceil(M/16), ceil(N/16))\n",
    "```"
   ]
  },
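  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The automatic decomposition described above is plain ceiling division. The sketch below is a\n",
    "hypothetical helper (`auto_launch_config` is not a BrainEvent function) that reproduces the\n",
    "arithmetic from the comments in this section, assuming the 16 x 16 default block for 2D launches\n",
    "stated there.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def auto_launch_config(launch_dims, threads_per_block=256):\n",
    "    # Hypothetical helper illustrating the arithmetic; not BrainEvent's implementation\n",
    "    if isinstance(launch_dims, int):\n",
    "        return math.ceil(launch_dims / threads_per_block), threads_per_block\n",
    "    # Assumed default for 2D launches: 16 x 16 threads per block\n",
    "    block = (16, 16)\n",
    "    grid = tuple(math.ceil(d / b) for d, b in zip(launch_dims, block))\n",
    "    return grid, block\n",
    "\n",
    "print(auto_launch_config(1024, 256))   # (4, 256)\n",
    "print(auto_launch_config(1000, 256))   # (4, 256): 24 surplus threads\n",
    "print(auto_launch_config((64, 100)))   # ((4, 7), (16, 16))\n",
    "```\n",
    "\n",
    "Whenever the launch size is not a multiple of the block size, surplus threads exist, which is why\n",
    "every kernel in this tutorial checks its index against the array bounds."
   ]
  },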
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    @cuda.jit\n",
    "    def matmul_element_kernel(A, B, C):\n",
    "        \"\"\"C[i,j] = A[i,j] * B[i,j]  (element-wise, 2D grid).\"\"\"\n",
    "        i, j = cuda.grid(2)\n",
    "        if i < C.shape[0] and j < C.shape[1]:\n",
    "            C[i, j] = A[i, j] * B[i, j]\n",
    "\n",
    "    M, N = 64, 64\n",
    "    A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N)\n",
    "    B = jnp.ones((M, N), dtype=jnp.float32) * 2.0\n",
    "\n",
    "    hadamard_fn = numba_cuda_kernel(\n",
    "        matmul_element_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((M, N), jnp.float32),\n",
    "        launch_dims=(M, N),      # 2D launch\n",
    "    )\n",
    "\n",
    "    C = hadamard_fn(A, B)\n",
    "    print(\"2D kernel max error:\", float(jnp.max(jnp.abs(C - A * B))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. `numba_cuda_callable` – Multi-Kernel Wrapper\n",
    "\n",
    "Sometimes one kernel is not enough.  For example, a reduction may need two passes,\n",
    "or a pipeline may require a temporary device buffer between stages.\n",
    "\n",
    "`numba_cuda_callable` wraps an **arbitrary Python function** that can:\n",
    "- Launch multiple `@cuda.jit` kernels\n",
    "- Allocate temporary device memory with `cuda.device_array`\n",
    "- Use the XLA-managed CUDA stream (passed as the last argument)\n",
    "\n",
    "**Required function signature:**\n",
    "```python\n",
    "def my_func(input1, input2, ..., output1, output2, ..., stream):\n",
    "    # input* and output* are Numba CUDA device arrays\n",
    "    # stream is a Numba CUDA stream from XLA\n",
    "    ...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    # ---- Kernel 1: element-wise square ----\n",
    "    @cuda.jit\n",
    "    def square_kernel(x, temp):\n",
    "        i = cuda.grid(1)\n",
    "        if i < temp.size:\n",
    "            temp[i] = x[i] * x[i]\n",
    "\n",
    "    # ---- Kernel 2: element-wise square root ----\n",
    "    @cuda.jit\n",
    "    def sqrt_kernel(temp, out):\n",
    "        import math\n",
    "        i = cuda.grid(1)\n",
    "        if i < out.size:\n",
    "            out[i] = math.sqrt(temp[i])\n",
    "\n",
    "    # ---- Multi-kernel callable: |x| = sqrt(x^2) ----\n",
    "    def abs_via_two_kernels(x, out, stream):\n",
    "        \"\"\"\n",
    "        Compute |x| using two kernels and a temporary buffer.\n",
    "        Demonstrates multi-kernel pipeline with device allocation.\n",
    "        \"\"\"\n",
    "        n = x.shape[0]\n",
    "        threads = 256\n",
    "        blocks  = (n + threads - 1) // threads\n",
    "\n",
    "        # Temporary buffer on the GPU (freed after this function returns)\n",
    "        temp = cuda.device_array(n, dtype=np.float32)\n",
    "\n",
    "        # Launch both kernels on the XLA-managed stream\n",
    "        square_kernel[blocks, threads, stream](x, temp)\n",
    "        sqrt_kernel  [blocks, threads, stream](temp, out)\n",
    "\n",
    "    N = 512\n",
    "    x = jnp.linspace(-5.0, 5.0, N, dtype=jnp.float32)\n",
    "\n",
    "    abs_fn = numba_cuda_callable(\n",
    "        abs_via_two_kernels,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "    )\n",
    "\n",
    "    result   = abs_fn(x)\n",
    "    expected = jnp.abs(x)\n",
    "    print(\"Multi-kernel |x| max error:\", float(jnp.max(jnp.abs(result - expected))))\n",
    "    print(\"First 5 values :\", result[:5])\n",
    "    print(\"Expected |x|   :\", expected[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Registering with `XLACustomKernel`\n",
    "\n",
    "For production use, register your Numba CUDA kernel as a backend of an\n",
    "`XLACustomKernel` primitive.  This integrates the kernel into JAX's lowering\n",
    "pipeline and allows mixing with other backends (e.g., a Numba CPU fallback).\n",
    "\n",
    "The **kernel generator** pattern is the same as for Warp (see Tutorial 6):\n",
    "it is a Python callable that receives keyword arguments (forwarded from\n",
    "`primitive.bind`) and returns a concrete kernel function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Kernel generator for element-wise leaky ReLU:\n",
    "    #   out[i] = x[i] if x[i] > 0 else alpha * x[i]\n",
    "    # 'alpha' is passed at trace time via kwargs.\n",
    "    # -----------------------------------------------------------------------\n",
    "\n",
    "    def leaky_relu_numba_cuda_generator(**kwargs):\n",
    "        out_info = kwargs['outs'][0]\n",
    "        n        = out_info.shape[0]\n",
    "        alpha    = float(kwargs.get('alpha', 0.01))\n",
    "\n",
    "        @cuda.jit\n",
    "        def leaky_relu_kern(x, out):\n",
    "            i = cuda.grid(1)\n",
    "            if i < out.size:\n",
    "                v = x[i]\n",
    "                out[i] = v if v > 0.0 else alpha * v\n",
    "\n",
    "        def kernel(x):\n",
    "            return numba_cuda_kernel(\n",
    "                leaky_relu_kern,\n",
    "                outs=out_info,\n",
    "                launch_dims=n,\n",
    "            )(x)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    # Register the primitive\n",
    "    leaky_relu_op = XLACustomKernel('tutorial_numba_cuda_leaky_relu')\n",
    "    leaky_relu_op.def_numba_cuda_kernel(leaky_relu_numba_cuda_generator)\n",
    "\n",
    "    print(\"Registered:\", leaky_relu_op._kernels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    N = 256\n",
    "    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)\n",
    "\n",
    "    @jax.jit\n",
    "    def jitted_leaky_relu(x, alpha=0.1):\n",
    "        return leaky_relu_op(\n",
    "            x,\n",
    "            outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)],\n",
    "            alpha=alpha,\n",
    "        )[0]\n",
    "\n",
    "    r = jitted_leaky_relu(x, alpha=0.1)\n",
    "\n",
    "    expected = jnp.where(x > 0, x, 0.1 * x)\n",
    "    print(\"Leaky ReLU max error:\", float(jnp.max(jnp.abs(r - expected))))\n",
    "    print(\"Values around 0     :\", r[N//2 - 3 : N//2 + 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Neuroscience Example: Parallel Spike Threshold Detection\n",
    "\n",
    "A common operation in spiking neural networks: given membrane potentials `V` and a\n",
    "threshold `V_th`, detect which neurons fire and reset their potentials in-place.\n",
    "\n",
    "We implement this as two Numba CUDA kernels fused via `numba_cuda_callable`:\n",
    "1. **Detect** spikes: `spikes[i] = (V[i] >= V_th)`\n",
    "2. **Reset** potentials: `V_reset[i] = spikes[i] ? V_rest : V[i]`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    @cuda.jit\n",
    "    def detect_spikes_kernel(V, V_th, spikes):\n",
    "        \"\"\"spikes[i] = 1 if V[i] >= V_th[0] else 0.\"\"\"\n",
    "        i = cuda.grid(1)\n",
    "        if i < V.size:\n",
    "            spikes[i] = 1 if V[i] >= V_th[0] else 0\n",
    "\n",
    "    @cuda.jit\n",
    "    def reset_potential_kernel(V, spikes, V_rest, V_out):\n",
    "        \"\"\"V_out[i] = V_rest[0] if spikes[i] else V[i].\"\"\"\n",
    "        i = cuda.grid(1)\n",
    "        if i < V.size:\n",
    "            V_out[i] = V_rest[0] if spikes[i] else V[i]\n",
    "\n",
    "    def lif_step(V, V_th, V_rest, spikes_out, V_out, stream):\n",
    "        \"\"\"One LIF step: detect spikes and reset membrane potential.\"\"\"\n",
    "        n       = V.shape[0]\n",
    "        threads = 256\n",
    "        blocks  = (n + threads - 1) // threads\n",
    "\n",
    "        detect_spikes_kernel[blocks, threads, stream](V, V_th, spikes_out)\n",
    "        reset_potential_kernel[blocks, threads, stream](V, spikes_out, V_rest, V_out)\n",
    "\n",
    "    print(\"LIF step function defined.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    N_NEURONS = 10_000\n",
    "\n",
    "    rng = np.random.default_rng(0)\n",
    "    V      = jnp.array(rng.uniform(-75.0, -50.0, N_NEURONS).astype(np.float32))\n",
    "    V_th   = jnp.array([-55.0], dtype=jnp.float32)   # threshold (mV)\n",
    "    V_rest = jnp.array([-70.0], dtype=jnp.float32)   # reset potential (mV)\n",
    "\n",
    "    lif_fn = numba_cuda_callable(\n",
    "        lif_step,\n",
    "        outs=[\n",
    "            jax.ShapeDtypeStruct((N_NEURONS,), jnp.int32),    # spikes\n",
    "            jax.ShapeDtypeStruct((N_NEURONS,), jnp.float32),  # V_out\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    spikes, V_out = lif_fn(V, V_th, V_rest)\n",
    "\n",
    "    # Verify\n",
    "    expected_spikes = (V >= V_th[0]).astype(jnp.int32)\n",
    "    expected_V_out  = jnp.where(expected_spikes, V_rest[0], V)\n",
    "\n",
    "    print(f\"Neurons: {N_NEURONS}\")\n",
    "    print(f\"Spikes detected: {int(spikes.sum())} / {N_NEURONS}  ({100*float(spikes.mean()):.1f}%)\")\n",
    "    print(f\"Spike detection error: {int(jnp.sum(spikes != expected_spikes))}\")\n",
    "    print(f\"V_out max error: {float(jnp.max(jnp.abs(V_out - expected_V_out))):.6f} mV\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    import time\n",
    "\n",
    "    @jax.jit\n",
    "    def numba_lif_step(V, V_th, V_rest):\n",
    "        return lif_fn(V, V_th, V_rest)\n",
    "\n",
    "    @jax.jit\n",
    "    def jax_lif_step(V, V_th, V_rest):\n",
    "        spikes = (V >= V_th[0]).astype(jnp.int32)\n",
    "        V_out  = jnp.where(spikes, V_rest[0], V)\n",
    "        return spikes, V_out\n",
    "\n",
    "    # Warm up\n",
    "    jax.block_until_ready(numba_lif_step(V, V_th, V_rest))\n",
    "    jax.block_until_ready(jax_lif_step(V, V_th, V_rest))\n",
    "\n",
    "    N_TRIALS = 500\n",
    "\n",
    "    t0 = time.time()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(numba_lif_step(V, V_th, V_rest))\n",
    "    numba_time = (time.time() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    t0 = time.time()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(jax_lif_step(V, V_th, V_rest))\n",
    "    jax_time = (time.time() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    print(f\"Numba CUDA LIF step : {numba_time:.3f} ms\")\n",
    "    print(f\"JAX native LIF step : {jax_time:.3f} ms\")"
   ]
  },
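  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A mean over a timing loop can be skewed by a few slow iterations. As a hedged alternative, the\n",
    "sketch below times each call separately and reports the median. `time_ms` is a hypothetical\n",
    "helper, shown here with a trivial CPU workload so that it runs without a GPU.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "import time\n",
    "\n",
    "def time_ms(fn, n_trials=100):\n",
    "    # Hypothetical helper: per-call wall time in milliseconds, summarized by the median\n",
    "    samples = []\n",
    "    for _ in range(n_trials):\n",
    "        t0 = time.perf_counter()\n",
    "        fn()\n",
    "        samples.append((time.perf_counter() - t0) * 1000.0)\n",
    "    return statistics.median(samples)\n",
    "\n",
    "print(time_ms(lambda: sum(range(10_000))))\n",
    "```\n",
    "\n",
    "For the LIF comparison, `fn` would be something like\n",
    "`lambda: jax.block_until_ready(numba_lif_step(V, V_th, V_rest))`, so that each sample includes\n",
    "the wait for the GPU to finish. With only 10,000 neurons, both timings may largely reflect launch\n",
    "and dispatch overhead rather than kernel work."
   ]
  },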
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Performance Tips\n",
    "\n",
    "### 9.1 Choose the right launch configuration\n",
    "\n",
    "- A warp is 32 threads; use block sizes that are multiples of 32 (128, 256, 512).\n",
    "- Too few threads per block wastes warp slots; too many limits occupancy.\n",
    "- Use `launch_dims` for simple 1-D problems; specify explicit `grid`/`block` for fine control.\n",
    "\n",
    "### 9.2 Minimize thread divergence\n",
    "\n",
    "Threads in the same warp execute in lock-step. Conditional branches that differ between\n",
    "threads (divergence) serialize execution.  Where possible, arrange data so threads\n",
    "in the same warp take the same branch.\n",
    "\n",
    "### 9.3 Use shared memory for reuse\n",
    "\n",
    "If multiple threads access the same data, load it into shared memory first and\n",
    "synchronize with `cuda.syncthreads()` before reading.\n",
    "\n",
    "### 9.4 Cache the wrapped callable\n",
    "\n",
    "Each call to `numba_cuda_kernel` / `numba_cuda_callable` registers a new FFI target.\n",
    "Create the wrapped callable **once** at module level and reuse it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_CUDA_AVAILABLE:\n",
    "    # Good pattern: create the callable once and reuse\n",
    "    @cuda.jit\n",
    "    def exp_decay_kernel(x, decay, out):\n",
    "        import math\n",
    "        i = cuda.grid(1)\n",
    "        if i < out.size:\n",
    "            out[i] = x[i] * math.exp(-decay[0])\n",
    "\n",
    "    N = 1024\n",
    "    # Create once at definition time\n",
    "    _exp_decay_fn = numba_cuda_kernel(\n",
    "        exp_decay_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "        launch_dims=N,\n",
    "    )\n",
    "\n",
    "    @jax.jit\n",
    "    def apply_exp_decay(x, decay):\n",
    "        return _exp_decay_fn(x, decay)\n",
    "\n",
    "    x     = jnp.ones(N, dtype=jnp.float32)\n",
    "    decay = jnp.array([0.1], dtype=jnp.float32)\n",
    "\n",
    "    r = apply_exp_decay(x, decay)\n",
    "    print(\"Exp decay result:\", float(r[0]), \"| Expected:\", float(np.exp(-0.1)))"
   ]
  },
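  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the array size is not known in advance, a single module-level callable is not enough, but\n",
    "the same idea applies per shape. The following sketch uses `functools.lru_cache` to build one\n",
    "wrapper per distinct size. `get_wrapped` is a hypothetical stand-in that returns a placeholder\n",
    "instead of calling `numba_cuda_kernel`, so the caching behaviour can be checked without a GPU.\n",
    "\n",
    "```python\n",
    "from functools import lru_cache\n",
    "\n",
    "build_count = 0   # counts how often the (expensive) wrapper is built\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def get_wrapped(n):\n",
    "    # Stand-in for a call such as numba_cuda_kernel(kernel, outs=..., launch_dims=n);\n",
    "    # it only returns a placeholder here.\n",
    "    global build_count\n",
    "    build_count += 1\n",
    "    return ('wrapped', n)\n",
    "\n",
    "for n in (1024, 1024, 2048, 1024):\n",
    "    fn = get_wrapped(n)\n",
    "\n",
    "print(build_count)   # 2: one build per distinct size\n",
    "```\n",
    "\n",
    "In real code the body of `get_wrapped` would create and return the wrapped kernel, so repeated\n",
    "calls with the same size reuse the existing FFI target."
   ]
  },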
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Summary\n",
    "\n",
    "In this tutorial we covered:\n",
    "\n",
    "1. **`@cuda.jit`** – Write GPU kernels in Python; use `cuda.grid(ndim)` for thread indices\n",
    "   and bounds-check with `if i < size`.\n",
    "2. **`numba_cuda_kernel`** – Single-kernel JAX wrapper.  Specify launch config via\n",
    "   `(grid, block)` for explicit control or `launch_dims` for automatic decomposition.\n",
    "   Supports 1-D, 2-D, and 3-D launches.\n",
    "3. **`numba_cuda_callable`** – Multi-kernel JAX wrapper.  Your Python function receives\n",
    "   Numba device arrays and the XLA-managed CUDA stream; it can launch multiple kernels\n",
    "   and allocate temporary device memory.\n",
    "4. **`XLACustomKernel.def_numba_cuda_kernel`** – Register a kernel generator as the GPU\n",
    "   backend of a multi-backend custom JAX primitive.\n",
    "5. **Neuroscience application** – LIF spike detection and reset implemented as a\n",
    "   two-kernel callable, showing the key pattern for fused GPU operations.\n",
    "\n",
    "## Next Steps\n",
    "\n",
    "- **Tutorial 6**: Custom GPU operators with Warp\n",
    "- **Tutorial 8**: Custom CPU operators with Numba (`@numba.njit`)\n",
    "\n",
    "## References\n",
    "\n",
    "- [Numba CUDA documentation](https://numba.readthedocs.io/en/stable/cuda/index.html)\n",
    "- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)\n",
    "- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)"
   ]
  }
 ]
}
