{
 "nbformat": 4,
 "nbformat_minor": 4,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom CPU Operators with Numba\n",
    "\n",
    "This tutorial shows how to write high-performance CPU kernels using **Numba's `@njit`**\n",
    "decorator and integrate them into the BrainEvent / JAX ecosystem.\n",
    "\n",
    "[Numba](https://numba.readthedocs.io/) compiles Python functions to native machine code\n",
    "via LLVM, often reaching speeds comparable to C or Fortran for loop-heavy code. BrainEvent's `numba_kernel` function\n",
    "bridges Numba JIT-compiled functions into JAX via XLA's Foreign Function Interface (FFI),\n",
    "so your Numba kernels become first-class JAX operations compatible with `jax.jit`, `jax.vmap`,\n",
    "and other transforms.\n",
    "\n",
    "## Contents\n",
    "1. Why Numba on CPU?\n",
    "2. Installation and Imports\n",
    "3. Writing Numba JIT Kernels (`@numba.njit`)\n",
    "4. `numba_kernel` – Wrapping for JAX\n",
    "5. Parallel Kernels with `numba.prange`\n",
    "6. Multiple Inputs and Outputs\n",
    "7. Registering with `XLACustomKernel`\n",
    "8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication\n",
    "9. Combining Numba CPU and GPU Backends\n",
    "10. Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Why Numba on CPU?\n",
    "\n",
    "JAX runs on CPU, GPU, and TPU, but some algorithms do not map well to the GPU's\n",
    "massively parallel execution model:\n",
    "\n",
    "- **Sparse / irregular access patterns** – scattered reads and writes cannot be coalesced, so GPU threads waste memory bandwidth\n",
    "- **Sequential algorithms** – recurrences in which each iteration depends on the previous one\n",
    "- **Small to medium problem sizes** – kernel-launch and host-device transfer overhead dominates for small arrays\n",
    "- **CPU-only environments** – laptops, CI servers, edge devices\n",
    "\n",
    "| Property | JAX native (CPU) | Numba (`@njit`) | C extension |\n",
    "|----------|-----------------|-----------------|-------------|\n",
    "| Compilation | JIT (XLA) | JIT (LLVM) | Ahead of time |\n",
    "| Python overhead | Yes | Eliminated | Eliminated |\n",
    "| Parallelism | Limited | `prange` / OpenMP | pthread / OpenMP |\n",
    "| Custom loop structure | No | Yes | Yes |\n",
    "| Write in Python | Yes | Yes | No |\n",
    "\n",
    "Numba `@njit` lets you write the inner loop in Python while achieving native\n",
    "performance, and `brainevent.numba_kernel` makes the result a proper JAX primitive.\n",
    "\n",
    "**Requirements:** `pip install numba`  (no GPU needed)"
   ]
  },
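  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a concrete case of the *sequential algorithms* point. The sketch below implements a leaky\n",
    "integrator, `out[t] = a * out[t-1] + x[t]`. Every step depends on the previous one, so the loop\n",
    "cannot simply be split across GPU threads. `leaky_integrate_kernel` is a hypothetical example and\n",
    "is not part of BrainEvent. The fallback decorator is there only so the cell still runs as plain\n",
    "Python when Numba is not installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "try:\n",
    "    from numba import njit\n",
    "except ImportError:  # fallback: run the kernel as plain Python\n",
    "    def njit(f):\n",
    "        return f\n",
    "\n",
    "\n",
    "@njit\n",
    "def leaky_integrate_kernel(x, a, out):\n",
    "    # Hypothetical recurrence: out[t] = a[0] * out[t-1] + x[t], with out[-1] = 0\n",
    "    prev = 0.0\n",
    "    for t in range(x.size):\n",
    "        prev = a[0] * prev + x[t]  # depends on the previous iteration\n",
    "        out[t] = prev\n",
    "\n",
    "\n",
    "x = np.array([1.0, 0.0, 0.0, 1.0])\n",
    "a = np.array([0.5])\n",
    "y = np.empty_like(x)\n",
    "leaky_integrate_kernel(x, a, y)\n",
    "print(y)  # expected values: 1.0, 0.5, 0.25, 1.125"
   ]
  },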
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Installation and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install if needed:\n",
    "# !pip install numba -U\n",
    "# !pip install brainevent -U\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import brainevent\n",
    "from brainevent import XLACustomKernel, numba_kernel\n",
    "\n",
    "print(f\"JAX version    : {jax.__version__}\")\n",
    "print(f\"JAX backend    : {jax.default_backend()}\")\n",
    "print(f\"BrainEvent     : {brainevent.__version__}\")\n",
    "\n",
    "try:\n",
    "    import numba\n",
    "    print(f\"Numba version  : {numba.__version__}\")\n",
    "    NUMBA_AVAILABLE = True\n",
    "except ImportError:\n",
    "    print(\"Numba not installed. Run: pip install numba\")\n",
    "    NUMBA_AVAILABLE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Writing Numba JIT Kernels (`@numba.njit`)\n",
    "\n",
    "Rules for Numba CPU kernels used with `numba_kernel`:\n",
    "- Decorate with `@numba.njit`  (or `@numba.njit(parallel=True)` for parallelism)\n",
    "- Function signature: `kernel(input1, input2, ..., output1, output2, ...)`\n",
    "  – inputs first, then outputs; all as NumPy arrays (zero-copy from JAX)\n",
    "- **Write results into output arrays** – no return values\n",
    "- Standard Python math, NumPy slicing, and `for` loops all work\n",
    "\n",
    "### 3.1 Simple Element-wise Kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit\n",
    "    def add_kernel(x, y, out):\n",
    "        \"\"\"out[i] = x[i] + y[i]\"\"\"\n",
    "        for i in range(out.size):\n",
    "            out[i] = x[i] + y[i]\n",
    "\n",
    "    @numba.njit\n",
    "    def relu_kernel(x, out):\n",
    "        \"\"\"out[i] = max(x[i], 0.0)\"\"\"\n",
    "        for i in range(out.size):\n",
    "            v = x[i]\n",
    "            out[i] = v if v > 0.0 else 0.0\n",
    "\n",
    "    @numba.njit\n",
    "    def matvec_kernel(A, x, out):\n",
    "        \"\"\"Dense matrix-vector product: out = A @ x\"\"\"\n",
    "        rows, cols = A.shape\n",
    "        for i in range(rows):\n",
    "            total = A.dtype.type(0)\n",
    "            for j in range(cols):\n",
    "                total += A[i, j] * x[j]\n",
    "            out[i] = total\n",
    "\n",
    "    print(\"Numba kernels defined:\", add_kernel, relu_kernel, matvec_kernel)"
   ]
  },
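  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These kernels write into `out` and never return a value. Because of that, they can be unit-tested\n",
    "on plain NumPy arrays before being wrapped for JAX. The caller allocates the output buffer and\n",
    "then compares it with a NumPy reference. The sketch below needs only NumPy, plus Numba if it is\n",
    "installed. `scaled_add_kernel` is a hypothetical kernel written just for this check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "try:\n",
    "    from numba import njit\n",
    "except ImportError:  # fallback: run the kernel as plain Python\n",
    "    def njit(f):\n",
    "        return f\n",
    "\n",
    "\n",
    "@njit\n",
    "def scaled_add_kernel(x, y, alpha, out):\n",
    "    # Hypothetical kernel: out[i] = alpha[0] * x[i] + y[i]\n",
    "    for i in range(out.size):\n",
    "        out[i] = alpha[0] * x[i] + y[i]\n",
    "\n",
    "\n",
    "x = np.arange(5, dtype=np.float32)\n",
    "y = np.ones(5, dtype=np.float32)\n",
    "alpha = np.array([2.0], dtype=np.float32)\n",
    "\n",
    "out = np.empty_like(x)               # the caller allocates the output buffer\n",
    "scaled_add_kernel(x, y, alpha, out)  # the kernel fills it in place\n",
    "\n",
    "np.testing.assert_allclose(out, 2.0 * x + y)\n",
    "print(out)"
   ]
  },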
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Reduction Kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit\n",
    "    def sum_kernel(x, out):\n",
    "        \"\"\"out[0] = sum(x), accumulated in float64 for accuracy.\"\"\"\n",
    "        total = 0.0  # float64 accumulator avoids float32 round-off on long sums\n",
    "        for i in range(x.size):\n",
    "            total += x[i]\n",
    "        out[0] = total\n",
    "\n",
    "    @numba.njit\n",
    "    def max_kernel(x, out):\n",
    "        \"\"\"out[0] = max(x).\"\"\"\n",
    "        m = x[0]\n",
    "        for i in range(1, x.size):\n",
    "            if x[i] > m:\n",
    "                m = x[i]\n",
    "        out[0] = m\n",
    "\n",
    "    @numba.njit\n",
    "    def running_stats_kernel(x, mean_out, std_out):\n",
    "        \"\"\"Compute mean and population std (ddof=0) using two passes over x.\"\"\"\n",
    "        n = x.size\n",
    "        s = 0.0  # float64 accumulators avoid float32 round-off\n",
    "        for i in range(n):\n",
    "            s += x[i]\n",
    "        mean = s / n\n",
    "        var = 0.0\n",
    "        for i in range(n):\n",
    "            d = x[i] - mean\n",
    "            var += d * d\n",
    "        mean_out[0] = mean\n",
    "        std_out[0]  = (var / n) ** 0.5\n",
    "\n",
    "    print(\"Reduction kernels defined.\")"
   ]
  },
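  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`running_stats_kernel` reads `x` twice. A truly single-pass variant may be needed, for example for\n",
    "streaming data. In that case, Welford's online algorithm updates the mean and the sum of squared\n",
    "deviations together. This technique is swapped in here and is not part of the original tutorial.\n",
    "`welford_kernel` is a hypothetical name. Like `jnp.std`'s default, it returns the population\n",
    "standard deviation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "try:\n",
    "    from numba import njit\n",
    "except ImportError:  # fallback: run the kernel as plain Python\n",
    "    def njit(f):\n",
    "        return f\n",
    "\n",
    "\n",
    "@njit\n",
    "def welford_kernel(x, mean_out, std_out):\n",
    "    # Welford's online update: one pass over x, float64 accumulators\n",
    "    mean = 0.0\n",
    "    m2 = 0.0  # running sum of squared deviations from the current mean\n",
    "    for i in range(x.size):\n",
    "        xi = float(x[i])  # promote to float64 before accumulating\n",
    "        delta = xi - mean\n",
    "        mean += delta / (i + 1)\n",
    "        m2 += delta * (xi - mean)\n",
    "    mean_out[0] = mean\n",
    "    std_out[0] = math.sqrt(m2 / x.size)  # population std (ddof=0)\n",
    "\n",
    "\n",
    "x = np.arange(10_000, dtype=np.float32)\n",
    "mean_out = np.empty(1, dtype=np.float32)\n",
    "std_out = np.empty(1, dtype=np.float32)\n",
    "welford_kernel(x, mean_out, std_out)\n",
    "print(mean_out[0], np.mean(x, dtype=np.float64))\n",
    "print(std_out[0], np.std(x, dtype=np.float64))"
   ]
  },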
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. `numba_kernel` – Wrapping for JAX\n",
    "\n",
    "`numba_kernel` wraps a Numba CPU kernel so it can be called with JAX CPU arrays\n",
    "via XLA's typed FFI protocol.\n",
    "\n",
    "**Signature:**\n",
    "```python\n",
    "numba_kernel(\n",
    "    kernel,              # @numba.njit function\n",
    "    outs,                # jax.ShapeDtypeStruct or list thereof\n",
    "    *,\n",
    "    vmap_method=None,\n",
    "    input_output_aliases=None,\n",
    ") -> callable\n",
    "```\n",
    "\n",
    "The returned callable accepts JAX arrays as inputs and returns JAX arrays as outputs.\n",
    "It is compatible with `jax.jit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    N = 512\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.ones(N, dtype=jnp.float32) * 3.0\n",
    "\n",
    "    # Create the JAX-callable wrapper\n",
    "    add_fn = numba_kernel(\n",
    "        add_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "    )\n",
    "\n",
    "    result = add_fn(a, b)\n",
    "    # numba_kernel may return a 1-tuple even for a single output; unwrap if needed\n",
    "    result = result[0] if isinstance(result, tuple) else result\n",
    "\n",
    "    expected = a + b\n",
    "    print(\"Add max error  :\", float(jnp.max(jnp.abs(result - expected))))\n",
    "\n",
    "    # ---- ReLU ----\n",
    "    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)\n",
    "    relu_fn = numba_kernel(\n",
    "        relu_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "    )\n",
    "    r = relu_fn(x)\n",
    "    r = r[0] if isinstance(r, tuple) else r\n",
    "    print(\"ReLU max error :\", float(jnp.max(jnp.abs(r - jnp.maximum(x, 0.0)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    # ---- Reduction ----\n",
    "    N = 10_000\n",
    "    x = jnp.arange(N, dtype=jnp.float32)\n",
    "\n",
    "    sum_fn = numba_kernel(\n",
    "        sum_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((1,), jnp.float32),\n",
    "    )\n",
    "\n",
    "    s = sum_fn(x)\n",
    "    s = s[0] if isinstance(s, tuple) else s\n",
    "    print(f\"Sum: {float(s[0]):.1f}  |  Expected: {float(jnp.sum(x)):.1f}\")\n",
    "\n",
    "    # ---- Multiple outputs (mean and std from one kernel call) ----\n",
    "    stats_fn = numba_kernel(\n",
    "        running_stats_kernel,\n",
    "        outs=[\n",
    "            jax.ShapeDtypeStruct((1,), jnp.float32),  # mean\n",
    "            jax.ShapeDtypeStruct((1,), jnp.float32),  # std\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    mean_val, std_val = stats_fn(x)\n",
    "    print(f\"Mean: {float(mean_val[0]):.2f}  |  Std: {float(std_val[0]):.2f}\")\n",
    "    print(f\"jnp.mean: {float(jnp.mean(x)):.2f}  |  jnp.std: {float(jnp.std(x)):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 JIT Compatibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    N = 128\n",
    "    add_fn_cached = numba_kernel(\n",
    "        add_kernel,\n",
    "        outs=jax.ShapeDtypeStruct((N,), jnp.float32),\n",
    "    )\n",
    "\n",
    "    @jax.jit\n",
    "    def jitted_pipeline(a, b):\n",
    "        # Mix Numba kernel with standard JAX operations\n",
    "        temp = add_fn_cached(a, b)\n",
    "        temp = temp[0] if isinstance(temp, tuple) else temp\n",
    "        return jnp.sin(temp) * jnp.sqrt(jnp.abs(temp) + 1.0)\n",
    "\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.ones(N, dtype=jnp.float32)\n",
    "\n",
    "    r1 = jitted_pipeline(a, b)\n",
    "    r2 = jitted_pipeline(a * 2, b * 0.5)   # second call reuses compiled code\n",
    "\n",
    "    print(\"JIT pipeline output shape:\", r1.shape)\n",
    "    print(\"First 5 values           :\", r1[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Parallel Kernels with `numba.prange`\n",
    "\n",
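    "Adding `parallel=True` to `@numba.njit` and replacing `range` with `numba.prange`\n",
    "lets Numba distribute the loop iterations across CPU cores using threads.\n",
    "The iterations must be independent of one another. The exception is a simple reduction\n",
    "such as `total += ...`, which Numba recognizes and combines across threads.\n",
    "\n",
    "This is the simplest way to use multi-core CPUs without writing thread-management code."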
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit(parallel=True)\n",
    "    def parallel_add_kernel(x, y, out):\n",
    "        \"\"\"Parallel element-wise add using prange.\"\"\"\n",
    "        for i in numba.prange(out.size):  # parallelized loop\n",
    "            out[i] = x[i] + y[i]\n",
    "\n",
    "    @numba.njit(parallel=True)\n",
    "    def parallel_matvec_kernel(A, x, out):\n",
    "        \"\"\"Parallel matrix-vector product: each row computed by a separate thread.\"\"\"\n",
    "        rows, cols = A.shape\n",
    "        for i in numba.prange(rows):    # parallelize over rows\n",
    "            total = A.dtype.type(0)\n",
    "            for j in range(cols):       # inner loop stays sequential\n",
    "                total += A[i, j] * x[j]\n",
    "            out[i] = total\n",
    "\n",
    "    @numba.njit(parallel=True)\n",
    "    def parallel_exp_decay_kernel(trace, spikes, tau_inv, out):\n",
    "        \"\"\"\n",
    "        Exponential trace update used in STDP:\n",
    "          out[i] = trace[i] * exp(-tau_inv) + spikes[i]\n",
    "        \"\"\"\n",
    "        decay = np.exp(-tau_inv[0])  # module-level np; Numba cannot compile an import inside @njit\n",
    "        for i in numba.prange(out.size):\n",
    "            out[i] = trace[i] * decay + spikes[i]\n",
    "\n",
    "    print(\"Parallel kernels defined.\")"
   ]
  },
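  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Numba can also parallelize a loop that accumulates into a scalar with `+=`. It treats the loop\n",
    "as a reduction and combines the per-thread partial results. The sketch below uses a hypothetical\n",
    "`parallel_sum_sq_kernel` to sum squares with `prange`. A parallel floating-point reduction adds\n",
    "terms in a different order than a serial loop, so the last few bits can differ from `np.sum`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "try:\n",
    "    from numba import njit, prange\n",
    "except ImportError:  # fallback: serial plain Python when Numba is missing\n",
    "    prange = range\n",
    "    def njit(*args, **kwargs):\n",
    "        # Support both the @njit and @njit(parallel=True) forms\n",
    "        if len(args) == 1 and callable(args[0]):\n",
    "            return args[0]\n",
    "        return lambda f: f\n",
    "\n",
    "\n",
    "@njit(parallel=True)\n",
    "def parallel_sum_sq_kernel(x, out):\n",
    "    total = 0.0\n",
    "    for i in prange(x.size):\n",
    "        total += x[i] * x[i]  # scalar += inside prange is treated as a reduction\n",
    "    out[0] = total\n",
    "\n",
    "\n",
    "x = np.linspace(0.0, 1.0, 1001)\n",
    "out = np.empty(1)\n",
    "parallel_sum_sq_kernel(x, out)\n",
    "print(out[0], np.sum(x * x))"
   ]
  },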
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    import time\n",
    "\n",
    "    N = 1_000_000\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.ones(N, dtype=jnp.float32)\n",
    "\n",
    "    serial_fn   = numba_kernel(add_kernel,          outs=jax.ShapeDtypeStruct((N,), jnp.float32))\n",
    "    parallel_fn = numba_kernel(parallel_add_kernel, outs=jax.ShapeDtypeStruct((N,), jnp.float32))\n",
    "\n",
    "    # Warm up\n",
    "    jax.block_until_ready(serial_fn(a, b))\n",
    "    jax.block_until_ready(parallel_fn(a, b))\n",
    "\n",
    "    N_TRIALS = 20\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(serial_fn(a, b))\n",
    "    serial_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(parallel_fn(a, b))\n",
    "    parallel_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    import os\n",
    "    n_cores = os.cpu_count()\n",
    "    print(f\"N = {N:,}  |  CPU cores: {n_cores}\")\n",
    "    print(f\"Serial   : {serial_time:.2f} ms\")\n",
    "    print(f\"Parallel : {parallel_time:.2f} ms\")\n",
    "    print(f\"Speedup  : {serial_time / parallel_time:.2f}x\")"
   ]
  },
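  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An average over a single loop of trials is sensitive to outliers. A slightly more robust pattern uses\n",
    "`time.perf_counter` and reports the median of several trials. `median_time_ms` is a hypothetical\n",
    "helper, not a BrainEvent API. The sketch assumes the callable blocks until its work is done. For a\n",
    "JAX or `numba_kernel` callable, that means wrapping the call in `jax.block_until_ready`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "import time\n",
    "\n",
    "\n",
    "def median_time_ms(fn, *args, n_warmup=2, n_trials=20):\n",
    "    # Warm-up calls trigger any JIT compilation outside the timed region\n",
    "    for _ in range(n_warmup):\n",
    "        fn(*args)\n",
    "    samples = []\n",
    "    for _ in range(n_trials):\n",
    "        t0 = time.perf_counter()\n",
    "        fn(*args)\n",
    "        samples.append((time.perf_counter() - t0) * 1000.0)\n",
    "    return statistics.median(samples)\n",
    "\n",
    "\n",
    "# Usage with a plain Python callable; for a wrapped kernel one would pass e.g.\n",
    "# lambda a, b: jax.block_until_ready(parallel_fn(a, b))\n",
    "t_ms = median_time_ms(sum, range(100_000))\n",
    "print('median:', round(t_ms, 3), 'ms')"
   ]
  },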
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Multiple Inputs and Outputs\n",
    "\n",
    "Numba kernels can take any number of inputs and outputs.\n",
    "The `outs` argument to `numba_kernel` mirrors the output buffers:\n",
    "a single `ShapeDtypeStruct` for one output, a list for multiple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit(parallel=True)\n",
    "    def lif_dynamics_kernel(\n",
    "        V,         # membrane potentials  (N,)\n",
    "        I_ext,     # external current     (N,)\n",
    "        tau_inv,   # 1/tau_m              (1,)\n",
    "        V_th,      # threshold            (1,)\n",
    "        V_rest,    # reset potential      (1,)\n",
    "        dt,        # time step            (1,)\n",
    "        V_out,     # updated potentials   (N,)  – output\n",
    "        spikes,    # spike vector         (N,)  – output\n",
    "    ):\n",
    "        \"\"\"\n",
    "        One Euler step of leaky integrate-and-fire dynamics:\n",
    "          V_out[i] = V[i] + dt * (-(V[i] - V_rest) * tau_inv + I_ext[i])\n",
    "        then threshold and reset.\n",
    "        \"\"\"\n",
    "        th   = V_th[0]\n",
    "        vr   = V_rest[0]\n",
    "        ti   = tau_inv[0]\n",
    "        step = dt[0]\n",
    "\n",
    "        for i in numba.prange(V.size):\n",
    "            v_new = V[i] + step * (-(V[i] - vr) * ti + I_ext[i])\n",
    "            if v_new >= th:\n",
    "                spikes[i]  = 1\n",
    "                V_out[i]   = vr\n",
    "            else:\n",
    "                spikes[i]  = 0\n",
    "                V_out[i]   = v_new\n",
    "\n",
    "    print(\"LIF dynamics kernel defined.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    N = 5_000\n",
    "    rng = np.random.default_rng(0)\n",
    "\n",
    "    V      = jnp.array(rng.uniform(-75.0, -50.0, N).astype(np.float32))\n",
    "    I_ext  = jnp.array(rng.uniform(0.0,    5.0,  N).astype(np.float32))\n",
    "    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau_m = 20 ms\n",
    "    V_th   = jnp.array([-55.0], dtype=jnp.float32)\n",
    "    V_rest = jnp.array([-70.0], dtype=jnp.float32)\n",
    "    dt     = jnp.array([0.1],   dtype=jnp.float32)         # dt = 0.1 ms\n",
    "\n",
    "    lif_fn = numba_kernel(\n",
    "        lif_dynamics_kernel,\n",
    "        outs=[\n",
    "            jax.ShapeDtypeStruct((N,), jnp.float32),  # V_out\n",
    "            jax.ShapeDtypeStruct((N,), jnp.int32),    # spikes\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    V_new, spikes = lif_fn(V, I_ext, tau_inv, V_th, V_rest, dt)\n",
    "\n",
    "    print(f\"Neurons: {N}\")\n",
    "    print(f\"Spikes : {int(spikes.sum())} ({100*float(spikes.mean()):.1f}%)\")\n",
    "    print(f\"V range: [{float(V_new.min()):.2f}, {float(V_new.max()):.2f}] mV\")\n",
    "\n",
    "    # Verify against JAX reference\n",
    "    V_ref = V + dt[0] * (-(V - V_rest[0]) * tau_inv[0] + I_ext)\n",
    "    spk_ref = (V_ref >= V_th[0]).astype(jnp.int32)\n",
    "    V_ref   = jnp.where(spk_ref, V_rest[0], V_ref)\n",
    "\n",
    "    print(f\"V max error    : {float(jnp.max(jnp.abs(V_new - V_ref))):.6f} mV\")\n",
    "    print(f\"Spike mismatch : {int(jnp.sum(spikes != spk_ref))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Registering with `XLACustomKernel`\n",
    "\n",
    "For production use, embed your Numba kernel inside a **kernel generator** and\n",
    "register it with `XLACustomKernel`. The generator receives shape/dtype information\n",
    "forwarded from `primitive.bind` as keyword arguments (e.g. `outs`) and returns the\n",
    "concrete callable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Kernel generator: exponential trace update (STDP)\n",
    "    #   out[i] = trace[i] * decay + spikes[i]\n",
    "    # -----------------------------------------------------------------------\n",
    "\n",
    "    def exp_trace_numba_generator(**kwargs):\n",
    "        out_info = kwargs['outs'][0]\n",
    "        n        = out_info.shape[0]\n",
    "\n",
    "        @numba.njit(parallel=True)\n",
    "        def trace_kern(trace, spikes, tau_inv, out):\n",
    "            decay = np.exp(-tau_inv[0])  # module-level np; no import inside @njit\n",
    "            for i in numba.prange(n):\n",
    "                out[i] = trace[i] * decay + spikes[i]\n",
    "\n",
    "        # Wrap once per generator call instead of on every kernel invocation\n",
    "        trace_fn = numba_kernel(trace_kern, outs=out_info)\n",
    "\n",
    "        def kernel(trace, spikes, tau_inv):\n",
    "            return trace_fn(trace, spikes, tau_inv)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    # Register the primitive\n",
    "    trace_op = XLACustomKernel('tutorial_numba_exp_trace')\n",
    "    trace_op.def_numba_kernel(exp_trace_numba_generator)\n",
    "\n",
    "    print(\"Registered platforms:\", list(trace_op._kernels.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    N = 1000\n",
    "    trace   = jnp.zeros(N, dtype=jnp.float32)\n",
    "    spikes  = jnp.array(np.random.default_rng(1).random(N) < 0.1,\n",
    "                        dtype=jnp.float32)\n",
    "    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau = 20 ms\n",
    "\n",
    "    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)\n",
    "\n",
    "    @jax.jit\n",
    "    def update_trace(trace, spikes, tau_inv):\n",
    "        return trace_op(\n",
    "            trace, spikes, tau_inv,\n",
    "            outs=[out_spec],\n",
    "        )[0]\n",
    "\n",
    "    # Simulate 100 time steps of trace dynamics\n",
    "    import math\n",
    "    decay = math.exp(-float(tau_inv[0]))\n",
    "\n",
    "    trace_history = []\n",
    "    for step in range(100):\n",
    "        spikes = jnp.array(\n",
    "            np.random.default_rng(step).random(N) < 0.05,\n",
    "            dtype=jnp.float32\n",
    "        )\n",
    "        trace = update_trace(trace, spikes, tau_inv)\n",
    "        trace_history.append(float(trace.mean()))\n",
    "\n",
    "    print(\"Trace stats after 100 steps:\")\n",
    "    print(f\"  Mean  : {float(trace.mean()):.4f}\")\n",
    "    print(f\"  Max   : {float(trace.max()):.4f}\")\n",
    "    print(f\"  Steady-state (theory): {0.05 / (1 - decay):.4f}\")"
   ]
  },
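  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The theoretical value printed above comes from averaging the update `out[i] = trace[i] * decay + spikes[i]`.\n",
    "If each neuron spikes with probability `p` per step, the mean trace follows `m[t+1] = decay * m[t] + p`.\n",
    "Its fixed point is `m* = p / (1 - decay)`. The gap to the fixed point shrinks by a factor `decay`\n",
    "per step. After 100 steps with `tau = 20`, about `exp(-5)` (roughly 0.7%) of the initial gap\n",
    "therefore remains. The NumPy sketch below is independent of BrainEvent and checks the fixed point\n",
    "with a longer simulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "p, tau = 0.05, 20.0           # spike probability per step, time constant in steps\n",
    "decay = math.exp(-1.0 / tau)  # same decay factor as the trace kernel above\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_neurons, n_steps = 10_000, 400\n",
    "trace = np.zeros(n_neurons)\n",
    "for _ in range(n_steps):\n",
    "    spikes = (rng.random(n_neurons) < p).astype(np.float64)\n",
    "    trace = trace * decay + spikes  # same update rule as trace_kern\n",
    "\n",
    "m_star = p / (1.0 - decay)\n",
    "print(trace.mean(), m_star)  # should agree up to sampling noise"
   ]
  },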
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication\n",
    "\n",
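    "A core operation in neural network simulation:\n",
    "given a synaptic weight matrix stored in CSR (compressed sparse row) format and a\n",
    "float input vector, compute the matrix-vector product.\n",
    "\n",
    "Each output neuron (CSR row) can be computed independently. The work within a row\n",
    "is a short sequential loop with irregular, data-dependent memory access. Parallelizing\n",
    "over rows with `prange` therefore maps well onto multi-core CPUs."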
   ]
  },
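  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below spells out the three CSR arrays for a small 3 × 4 matrix. It computes the product\n",
    "with the same row loop that `csr_matvec_numba` uses, written in plain Python with NumPy.\n",
    "`csr_matvec_reference` is a hypothetical helper for illustration. Row `i` owns the slice\n",
    "`indptr[i]:indptr[i + 1]` of `data` and `indices`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dense_small = np.array([[1.0, 0.0, 2.0, 0.0],\n",
    "                        [0.0, 0.0, 0.0, 0.0],\n",
    "                        [0.0, 3.0, 0.0, 4.0]])\n",
    "\n",
    "data_s    = np.array([1.0, 2.0, 3.0, 4.0])  # non-zero values, row by row\n",
    "indices_s = np.array([0, 2, 1, 3])          # column index of each stored value\n",
    "indptr_s  = np.array([0, 2, 2, 4])          # row i spans data_s[indptr_s[i]:indptr_s[i+1]]\n",
    "\n",
    "\n",
    "def csr_matvec_reference(data, indices, indptr, x):\n",
    "    out = np.zeros(indptr.size - 1)\n",
    "    for i in range(indptr.size - 1):               # rows are independent\n",
    "        for k in range(indptr[i], indptr[i + 1]):  # short, irregular inner loop\n",
    "            out[i] += data[k] * x[indices[k]]\n",
    "    return out\n",
    "\n",
    "\n",
    "x_s = np.array([1.0, 10.0, 100.0, 1000.0])\n",
    "y_s = csr_matvec_reference(data_s, indices_s, indptr_s, x_s)\n",
    "print(y_s)  # expected values: 201, 0, 4030\n",
    "assert np.allclose(y_s, dense_small @ x_s)"
   ]
  },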
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit(parallel=True)\n",
    "    def csr_matvec_numba(\n",
    "        data,       # CSR non-zero values  (nnz,)\n",
    "        indices,    # CSR column indices   (nnz,)\n",
    "        indptr,     # CSR row pointers     (n_rows+1,)\n",
    "        x,          # input vector         (n_cols,)\n",
    "        out,        # output vector        (n_rows,)\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Sparse matrix-vector product (CSR format).\n",
    "        Each row is processed by one thread (parallel over rows).\n",
    "        \"\"\"\n",
    "        n_rows = indptr.size - 1\n",
    "        for i in numba.prange(n_rows):\n",
    "            total = out.dtype.type(0)\n",
    "            for k in range(indptr[i], indptr[i + 1]):\n",
    "                total += data[k] * x[indices[k]]\n",
    "            out[i] = total\n",
    "\n",
    "    def csr_mv_numba_generator(**kwargs):\n",
    "        out_info = kwargs['outs'][0]\n",
    "\n",
    "        csr_fn = numba_kernel(csr_matvec_numba, outs=out_info)  # wrap once\n",
    "\n",
    "        def kernel(data, indices, indptr, x):\n",
    "            return csr_fn(data, indices, indptr, x)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    csr_mv_op = XLACustomKernel('tutorial_numba_csr_matvec')\n",
    "    csr_mv_op.def_numba_kernel(csr_mv_numba_generator)\n",
    "\n",
    "    print(\"CSR MV operator registered.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    import scipy.sparse as sp\n",
    "\n",
    "    N_PRE  = 2000\n",
    "    N_POST = 1000\n",
    "    PROB   = 0.05\n",
    "\n",
    "    rng   = np.random.default_rng(42)\n",
    "    dense = (rng.random((N_POST, N_PRE)) < PROB).astype(np.float32)\n",
    "    dense *= rng.uniform(0.01, 0.5, dense.shape).astype(np.float32)\n",
    "    csr   = sp.csr_matrix(dense)\n",
    "\n",
    "    data    = jnp.array(csr.data,    dtype=jnp.float32)\n",
    "    indices = jnp.array(csr.indices, dtype=jnp.int32)\n",
    "    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)\n",
    "    x       = jnp.array(rng.random(N_PRE).astype(np.float32))\n",
    "\n",
    "    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)\n",
    "\n",
    "    result = csr_mv_op(\n",
    "        data, indices, indptr, x,\n",
    "        outs=[out_spec],\n",
    "    )[0]\n",
    "\n",
    "    expected = jnp.array(dense) @ x\n",
    "    print(f\"Network: {N_PRE} pre -> {N_POST} post  (nnz={csr.nnz})\")\n",
    "    print(f\"Max error vs dense: {float(jnp.max(jnp.abs(result - expected))):.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if NUMBA_AVAILABLE:\n",
    "    import time\n",
    "\n",
    "    @jax.jit\n",
    "    def numba_csr_mv(data, indices, indptr, x):\n",
    "        return csr_mv_op(\n",
    "            data, indices, indptr, x,\n",
    "            outs=[out_spec],\n",
    "        )[0]\n",
    "\n",
    "    @jax.jit\n",
    "    def jax_dense_mv(A, x):\n",
    "        return A @ x\n",
    "\n",
    "    A_jnp = jnp.array(dense)\n",
    "\n",
    "    # Warm up\n",
    "    jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))\n",
    "    jax.block_until_ready(jax_dense_mv(A_jnp, x))\n",
    "\n",
    "    N_TRIALS = 50\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))\n",
    "    numba_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(jax_dense_mv(A_jnp, x))\n",
    "    jax_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    print(f\"Numba CSR MV  : {numba_time:.2f} ms\")\n",
    "    print(f\"JAX dense MV  : {jax_time:.2f} ms\")\n",
    "    print(f\"Speedup       : {jax_time / numba_time:.2f}x  (sparsity: {1 - csr.nnz/(N_PRE*N_POST):.0%})\")"
   ]
  },
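  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Counting multiply-adds gives a rough expectation for the benchmark above. The dense product touches\n",
    "every one of the `N_POST × N_PRE` weights. The CSR kernel touches only the `nnz` stored ones, about\n",
    "a fraction `PROB` of them on average. The back-of-envelope calculation below uses the same sizes.\n",
    "Real timings also depend on memory access patterns, threading, and XLA's optimized dense kernels.\n",
    "The measured speedup therefore need not match this ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_PRE, N_POST, PROB = 2000, 1000, 0.05  # same sizes as the benchmark\n",
    "\n",
    "dense_macs = N_POST * N_PRE        # one multiply-add per matrix entry\n",
    "expected_nnz = PROB * dense_macs   # expected number of stored CSR entries\n",
    "work_ratio = dense_macs / expected_nnz\n",
    "print(dense_macs, expected_nnz, work_ratio)"
   ]
  },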
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Combining Numba CPU and GPU Backends\n",
    "\n",
    "The same `XLACustomKernel` primitive can have both a Numba CPU backend and a\n",
    "GPU backend (Warp or Numba CUDA). JAX automatically dispatches to the correct\n",
    "backend based on the device where the arrays live."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import warp\n",
    "    from warp.jax_experimental import jax_kernel as warp_jax_kernel\n",
    "    from brainevent import jaxinfo_to_warpinfo\n",
    "    warp.config.quiet = True\n",
    "    WARP_AVAILABLE = True\n",
    "except ImportError:\n",
    "    WARP_AVAILABLE = False\n",
    "\n",
    "if NUMBA_AVAILABLE:\n",
    "    # CPU backend (already shown above)\n",
    "    @numba.njit(parallel=True)\n",
    "    def scale_numba(x, s, out):\n",
    "        for i in numba.prange(out.size):\n",
    "            out[i] = x[i] * s[0]\n",
    "\n",
    "    def scale_numba_generator(**kwargs):\n",
    "        out_info = kwargs['outs'][0]\n",
    "\n",
    "        scale_fn = numba_kernel(scale_numba, outs=out_info)  # wrap once\n",
    "\n",
    "        def kernel(x, s):\n",
    "            return scale_fn(x, s)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    scale_op = XLACustomKernel('tutorial_multi_backend_scale')\n",
    "    scale_op.def_numba_kernel(scale_numba_generator)   # CPU backend\n",
    "\n",
    "    if WARP_AVAILABLE:\n",
    "        def scale_warp_generator(**kwargs):\n",
    "            out_info = kwargs['outs'][0]\n",
    "            n = out_info.shape[0]\n",
    "            t = jaxinfo_to_warpinfo(out_info)\n",
    "            s_type = warp.array(dtype=t.dtype, ndim=1)  # same dtype as the output, 1-D\n",
    "\n",
    "            @warp.kernel\n",
    "            def kern(x: t, s: s_type, out: t):\n",
    "                i = warp.tid()\n",
    "                out[i] = x[i] * s[0]\n",
    "\n",
    "            def kernel(x, s):\n",
    "                fn = warp_jax_kernel(kern, launch_dims=[n], num_outputs=1,\n",
    "                                     output_dims={'out': (n,)})\n",
    "                return fn(x, s)\n",
    "\n",
    "            return kernel\n",
    "\n",
    "        scale_op.def_warp_kernel(scale_warp_generator)  # GPU backend\n",
    "\n",
    "    print(\"Multi-backend scale op registered.\")\n",
    "    print(\"Backends:\", {p: list(b.keys()) for p, b in scale_op._kernels.items()})\n",
    "\n",
    "    # Use it\n",
    "    N = 256\n",
    "    x = jnp.arange(N, dtype=jnp.float32)\n",
    "    s = jnp.array([3.14], dtype=jnp.float32)\n",
    "\n",
    "    r = scale_op(x, s, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])[0]\n",
    "    print(f\"Result matches: {bool(jnp.allclose(r, x * 3.14, atol=1e-5))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Summary\n",
    "\n",
    "In this tutorial we covered:\n",
    "\n",
    "1. **`@numba.njit`** – JIT-compile Python to native machine code. Kernel signature:\n",
    "   `kernel(input1, ..., output1, ...)` – all NumPy arrays, no return values.\n",
    "2. **`@numba.njit(parallel=True)` + `numba.prange`** – Multi-threaded parallelism\n",
    "   on CPU cores by adding one decorator flag and swapping `range` for `prange`.\n",
    "3. **`numba_kernel(kernel, outs=...)`** – Wrap a Numba kernel as a JAX-callable\n",
    "   via XLA FFI. Returns a function compatible with `jax.jit`.\n",
    "4. **Multiple outputs** – Pass a list of `jax.ShapeDtypeStruct` to `outs` to get\n",
    "   multiple return arrays from a single kernel call.\n",
    "5. **`XLACustomKernel.def_numba_kernel`** – Register a kernel generator as the CPU\n",
    "   backend of a multi-backend custom JAX primitive.\n",
    "6. **Neuroscience applications** – LIF dynamics and sparse CSR matrix-vector product\n",
    "   implemented with parallel Numba, demonstrating realistic use cases.\n",
    "\n",
    "## Key Guidelines\n",
    "\n",
    "- Cache the wrapped callable (do **not** call `numba_kernel` inside `@jax.jit`);\n",
    "  create it once at definition time.\n",
    "- Use `@njit(parallel=True)` + `prange` for outer loops; keep inner loops sequential.\n",
    "- Prefer Numba on CPU for irregular / sparse access patterns; prefer GPU backends\n",
    "  (Warp, Numba CUDA) for large-scale parallel workloads.\n",
    "\n",
    "## Next Steps\n",
    "\n",
    "- **Tutorial 6**: Custom GPU operators with Warp\n",
    "- **Tutorial 7**: Custom GPU operators with Numba CUDA\n",
    "\n",
    "## References\n",
    "\n",
    "- [Numba documentation](https://numba.readthedocs.io/)\n",
    "- [Numba `prange` parallelism](https://numba.readthedocs.io/en/stable/user/parallel.html)\n",
    "- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)\n",
    "- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)"
   ]
  }
 ]
}
