{
 "nbformat": 4,
 "nbformat_minor": 4,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom GPU Operators with Warp\n",
    "\n",
    "This tutorial shows how to write custom GPU kernels using **NVIDIA Warp** and integrate them\n",
    "into the BrainEvent / JAX ecosystem.\n",
    "\n",
    "[NVIDIA Warp](https://github.com/NVIDIA/warp) is a Python framework for high-performance\n",
    "GPU kernel authoring. Kernels are written in Python-like syntax, JIT-compiled to CUDA PTX,\n",
    "and can be called seamlessly from JAX via `warp.jax_experimental.jax_kernel`.\n",
    "\n",
    "## Contents\n",
    "1. Why Warp?\n",
    "2. Installation and Imports\n",
    "3. Writing Your First Warp Kernel\n",
    "4. Type Annotations – `jaxinfo_to_warpinfo` / `jaxtype_to_warptype`\n",
    "5. Kernel Generators – Dynamic Kernel Construction\n",
    "6. In-place (Accumulation) vs. Pure-output Patterns\n",
    "7. Registering Kernels with `XLACustomKernel`\n",
    "8. Neuroscience Example: Sparse Synaptic Input Accumulation\n",
    "9. Multiple Backends with `XLACustomKernel`\n",
    "10. Summary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Why Warp?\n",
    "\n",
    "| Feature | Warp | Raw CUDA C++ |\n",
    "|---------|------|--------------|\n",
    "| Language | Python-like syntax | C++ |\n",
    "| Compilation | Automatic JIT | Manual |\n",
    "| JAX integration | Built-in (`jax_kernel`) | Manual XLA FFI |\n",
    "| Autodiff | Automatic adjoints via `warp.Tape` (gradient support through `jax_kernel` depends on the Warp version) | Manual |\n",
    "| Best for | Custom GPU ops in Python | Maximum control |\n",
    "\n",
    "Warp is the recommended path when you want GPU acceleration without leaving Python.\n",
    "BrainEvent's `XLACustomKernel` infrastructure makes it straightforward to register a Warp kernel\n",
    "as the GPU backend of any custom JAX primitive.\n",
    "\n",
    "**Requirements:**\n",
    "- NVIDIA GPU with CUDA\n",
    "- `pip install warp-lang` (imported as `import warp`)\n",
    "- JAX with GPU support (`pip install \"jax[cuda12]\"`)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Installation and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install if needed\n",
    "# !pip install warp-lang -U\n",
    "# !pip install brainevent[cuda12] -U\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import brainevent\n",
    "from brainevent import XLACustomKernel, jaxinfo_to_warpinfo, jaxtype_to_warptype\n",
    "\n",
    "print(f\"JAX version    : {jax.__version__}\")\n",
    "print(f\"JAX backend    : {jax.default_backend()}\")\n",
    "print(f\"BrainEvent     : {brainevent.__version__}\")\n",
    "\n",
    "try:\n",
    "    import warp\n",
    "    from warp.jax_experimental import jax_kernel\n",
    "    warp.config.quiet = True\n",
    "    print(f\"Warp version   : {warp.__version__}\")\n",
    "    WARP_AVAILABLE = True\n",
    "except ImportError:\n",
    "    print(\"Warp not installed. Run: pip install warp-lang\")\n",
    "    WARP_AVAILABLE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Writing Your First Warp Kernel\n",
    "\n",
    "A Warp kernel is a Python function decorated with `@warp.kernel`. Key rules:\n",
    "- **No arbitrary Python objects** – arguments and local variables must have Warp types (scalars, vectors/matrices, arrays, structs)\n",
    "- **Thread index** is obtained via `warp.tid()` (replaces `blockIdx.x * blockDim.x + threadIdx.x` in CUDA C)\n",
    "- **Every argument needs a type annotation**; arrays are annotated with `warp.array(dtype=..., ndim=...)`\n",
    "- The kernel body runs **once per thread**, so you typically launch one thread per element\n",
    "\n",
    "### 3.1 Element-wise ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    @warp.kernel\n",
    "    def relu_kernel(\n",
    "        x:   warp.array(dtype=warp.float32, ndim=1),\n",
    "        out: warp.array(dtype=warp.float32, ndim=1),\n",
    "    ):\n",
    "        i = warp.tid()           # thread index = element index\n",
    "        out[i] = warp.max(x[i], warp.float32(0.0))\n",
    "\n",
    "    print(\"relu_kernel defined successfully\")\n",
    "    print(f\"Kernel type: {type(relu_kernel)}\")"
   ]
  },
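  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The kernel body is scalar code that Warp executes once per thread index. As a mental model only (this is not how Warp actually runs kernels), the sketch below emulates a 1-D launch on the CPU.\n",
    "The hypothetical helper `launch_1d` calls the same body for every index in `range(dim)`, with the loop index standing in for `warp.tid()`.\n",
    "On the GPU these iterations run concurrently and in no guaranteed order. That is why each thread should write only to its own output slot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def relu_body(i, x, out):\n",
    "    # Same logic as relu_kernel; one call plays the role of one Warp thread\n",
    "    out[i] = max(x[i], np.float32(0.0))\n",
    "\n",
    "def launch_1d(body, dim, *args):\n",
    "    # Hypothetical CPU stand-in for a launch with launch_dims=[dim]:\n",
    "    # the loop index plays the role of warp.tid()\n",
    "    for tid in range(dim):\n",
    "        body(tid, *args)\n",
    "\n",
    "x_np = np.linspace(-2.0, 2.0, 8, dtype=np.float32)\n",
    "out_np = np.empty_like(x_np)\n",
    "launch_1d(relu_body, x_np.size, x_np, out_np)\n",
    "print(out_np)"
   ]
  },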
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Calling the Kernel via `jax_kernel`\n",
    "\n",
    "`jax_kernel` wraps a Warp kernel so it can be called with JAX arrays.\n",
    "\n",
    "**Signature:**\n",
    "```python\n",
    "fn = jax_kernel(\n",
    "    warp_kernel,\n",
    "    launch_dims=[n],         # number of threads to launch in each dimension\n",
    "    num_outputs=1,           # how many output arrays the kernel writes\n",
    "    output_dims={'out': (n,)} # shape of each output (allocated by Warp)\n",
    ")\n",
    "(result,) = fn(x)  # pass only the input arrays; outputs come back as a tuple\n",
    "```\n",
    "\n",
    "There are two output modes:\n",
    "- **`output_dims`** – Warp allocates the output buffer; you only pass inputs.\n",
    "- **`in_out_argnames`** – You pass a pre-allocated buffer (e.g., `jnp.zeros`); the kernel sees its initial contents and writes into it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    N = 1024\n",
    "    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)\n",
    "\n",
    "    # Build the JAX-callable wrapper\n",
    "    relu_fn = jax_kernel(\n",
    "        relu_kernel,\n",
    "        launch_dims=[N],\n",
    "        num_outputs=1,\n",
    "        output_dims={'out': (N,)},\n",
    "    )\n",
    "\n",
    "    # Call it – returns a tuple of output arrays\n",
    "    (result,) = relu_fn(x)\n",
    "\n",
    "    # Verify against JAX reference\n",
    "    expected = jnp.maximum(x, 0.0)\n",
    "    print(\"Max error:\", float(jnp.max(jnp.abs(result - expected))))\n",
    "    print(\"First 8 values:\", result[:8])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Type Annotations – `jaxinfo_to_warpinfo` / `jaxtype_to_warptype`\n",
    "\n",
    "When embedding a Warp kernel inside a **kernel generator** (a function that receives\n",
    "shape/dtype information at trace time), you need to create the Warp type annotations\n",
    "dynamically. BrainEvent provides two helpers:\n",
    "\n",
    "```python\n",
    "from brainevent import jaxinfo_to_warpinfo, jaxtype_to_warptype\n",
    "\n",
    "# Convert jax.ShapeDtypeStruct  ->  warp.array(dtype=..., ndim=...)\n",
    "warp_arr_type = jaxinfo_to_warpinfo(jax.ShapeDtypeStruct((1024,), jnp.float32))\n",
    "\n",
    "# Convert numpy/JAX dtype  ->  warp scalar type\n",
    "warp_scalar_type = jaxtype_to_warptype(jnp.float32)  # -> warp.float32\n",
    "```\n",
    "\n",
    "These utilities support: `float16`, `float32`, `float64`, `int8`–`int64`, `uint8`–`uint64`, `bool`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    for jax_dtype in [jnp.float32, jnp.float64, jnp.int32, jnp.bool_]:\n",
    "        warp_type = jaxtype_to_warptype(jax_dtype)\n",
    "        info = jax.ShapeDtypeStruct((8, 4), jax_dtype)\n",
    "        warp_arr = jaxinfo_to_warpinfo(info)\n",
    "        print(f\"  jnp.{jax_dtype.__name__:<8} -> warp scalar: {warp_type}  |  warp array: {warp_arr}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Kernel Generators – Dynamic Kernel Construction\n",
    "\n",
    "When integrating with `XLACustomKernel`, kernels are not defined statically.\n",
    "Instead you define a **kernel generator**: a plain Python function that receives\n",
    "shape/dtype keyword arguments (forwarded from `primitive.bind`) and returns a\n",
    "callable that runs the actual computation.\n",
    "\n",
    "This pattern allows the same generator to handle different dtypes and shapes\n",
    "without re-registering the primitive.\n",
    "\n",
    "### 5.1 Template for a Warp Kernel Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    def my_relu_kernel_generator(**kwargs):\n",
    "        \"\"\"\n",
    "        Kernel generator for element-wise ReLU.\n",
    "\n",
    "        kwargs contains whatever was passed to XLACustomKernel.__call__,\n",
    "        e.g. kwargs['outs'] = [jax.ShapeDtypeStruct(shape, dtype)]\n",
    "        \"\"\"\n",
    "        # --- 1. Extract shape/dtype information from kwargs ---------------\n",
    "        out_info = kwargs['outs'][0]            # jax.ShapeDtypeStruct\n",
    "        n = out_info.shape[0]\n",
    "\n",
    "        # --- 2. Build Warp type annotations dynamically -------------------\n",
    "        x_warp_type   = jaxinfo_to_warpinfo(out_info)   # same dtype for input\n",
    "        out_warp_type = jaxinfo_to_warpinfo(out_info)\n",
    "\n",
    "        # --- 3. Define the @warp.kernel with dynamic type annotations -----\n",
    "        @warp.kernel\n",
    "        def relu_kern(\n",
    "            x:   x_warp_type,\n",
    "            out: out_warp_type,\n",
    "        ):\n",
    "            i = warp.tid()\n",
    "            out[i] = warp.max(x[i], out_warp_type.dtype(0.0))\n",
    "\n",
    "        # --- 4. Return the concrete kernel function -----------------------\n",
    "        def kernel(x):\n",
    "            fn = jax_kernel(\n",
    "                relu_kern,\n",
    "                launch_dims=[n],\n",
    "                num_outputs=1,\n",
    "                output_dims={'out': (n,)},\n",
    "            )\n",
    "            return fn(x)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    print(\"Kernel generator defined.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. In-place (Accumulation) vs. Pure-output Patterns\n",
    "\n",
    "Many neuroscience operations **scatter-add** values into an output buffer\n",
    "(e.g., synaptic current accumulation). Warp handles this via atomic operations\n",
    "and the `in_out_argnames` mechanism.\n",
    "\n",
    "### 6.1 Pure output (Warp allocates)\n",
    "\n",
    "```python\n",
    "fn = jax_kernel(kernel, launch_dims=[N], num_outputs=1, output_dims={'out': (N,)})\n",
    "result, = fn(x)  # only pass inputs\n",
    "```\n",
    "\n",
    "### 6.2 In-place / accumulation (caller provides buffer)\n",
    "\n",
    "```python\n",
    "fn = jax_kernel(kernel, launch_dims=[M], num_outputs=1, in_out_argnames=['acc'])\n",
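    "result, = fn(x, jnp.zeros((N,), dtype))  # inputs, then the initial buffer, in kernel-signature order\n",
    "```\n",
    "\n",
    "The `in_out_argnames` list tells Warp which arguments are both input and output.\n",
    "The kernel therefore starts from a buffer whose contents you control (here, zeros), which is what\n",
    "atomic accumulation with `warp.atomic_add` needs; with `output_dims` you should not rely on the\n",
    "initial contents of the freshly allocated output."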
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    # Scatter-add example: for every source element i,\n",
    "    # add values[i] * scale[0] to acc[targets[i]].\n",
    "\n",
    "    N_SRC = 512   # source elements\n",
    "    N_DST = 128   # destination (output) size\n",
    "\n",
    "    @warp.kernel\n",
    "    def scatter_add_kernel(\n",
    "        values:  warp.array(dtype=warp.float32, ndim=1),\n",
    "        targets: warp.array(dtype=warp.int32,   ndim=1),\n",
    "        scale:   warp.array(dtype=warp.float32, ndim=1),  # 1-element array\n",
    "        acc:     warp.array(dtype=warp.float32, ndim=1),  # in-place output\n",
    "    ):\n",
    "        i = warp.tid()\n",
    "        # Atomic add is thread-safe – multiple threads may target the same slot\n",
    "        warp.atomic_add(acc, targets[i], values[i] * scale[0])\n",
    "\n",
    "    # Create test data\n",
    "    rng = np.random.default_rng(0)\n",
    "    values  = jnp.array(rng.random(N_SRC).astype(np.float32))\n",
    "    targets = jnp.array(rng.integers(0, N_DST, N_SRC).astype(np.int32))\n",
    "    scale   = jnp.array([2.0], dtype=jnp.float32)\n",
    "\n",
    "    # Build callable with in-place accumulator\n",
    "    scatter_fn = jax_kernel(\n",
    "        scatter_add_kernel,\n",
    "        launch_dims=[N_SRC],\n",
    "        num_outputs=1,\n",
    "        in_out_argnames=['acc'],       # 'acc' is both input and output\n",
    "    )\n",
    "\n",
    "    # Run: pass (values, targets, scale, initial_acc)\n",
    "    init_acc = jnp.zeros(N_DST, dtype=jnp.float32)\n",
    "    (result,) = scatter_fn(values, targets, scale, init_acc)\n",
    "\n",
    "    # Verify with NumPy reference\n",
    "    ref = np.zeros(N_DST, dtype=np.float32)\n",
    "    np.add.at(ref, np.array(targets), np.array(values) * 2.0)\n",
    "    print(\"Scatter-add max error:\", float(jnp.max(jnp.abs(result - jnp.array(ref)))))\n",
    "    print(\"Result sum:\", float(result.sum()), \"| Expected:\", float(ref.sum()))"
   ]
  },
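  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why `warp.atomic_add` rather than a plain `acc[targets[i]] += ...`? Several source elements can share a target index, and concurrent read-modify-write updates to the same slot can lose contributions.\n",
    "The NumPy cell below is a loose CPU analogy, not a real GPU race. A buffered fancy-index `+=` writes each duplicate index only once. `np.add.at`, in contrast, applies every contribution, which is the behavior the atomic kernel provides."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Six contributions aimed at four slots; slots 1 and 3 receive duplicates\n",
    "targets_demo = np.array([0, 1, 1, 3, 3, 3], dtype=np.int32)\n",
    "values_demo = np.ones(6, dtype=np.float32)\n",
    "\n",
    "# Buffered fancy-index update: each duplicate index is written once, not summed\n",
    "lossy = np.zeros(4, dtype=np.float32)\n",
    "lossy[targets_demo] += values_demo\n",
    "\n",
    "# Unbuffered accumulation: every contribution lands, like warp.atomic_add\n",
    "summed = np.zeros(4, dtype=np.float32)\n",
    "np.add.at(summed, targets_demo, values_demo)\n",
    "\n",
    "print('fancy-index += :', lossy)   # [1. 1. 0. 1.]\n",
    "print('np.add.at      :', summed)  # [1. 2. 0. 3.]"
   ]
  },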
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Registering Kernels with `XLACustomKernel`\n",
    "\n",
    "`XLACustomKernel` is BrainEvent's central abstraction for multi-backend custom\n",
    "JAX primitives. It lets you register different backend implementations\n",
    "(Warp, Numba, Pallas, …) for the same logical operation, then dispatch to the\n",
    "right one at runtime.\n",
    "\n",
    "**Workflow:**\n",
    "1. Create an `XLACustomKernel` instance with a unique name\n",
    "2. Register your Warp kernel generator via `def_warp_kernel()`\n",
    "3. (Optionally) register a CPU fallback via `def_numba_kernel()`\n",
    "4. Call the primitive as `op(*inputs, outs=[...])`, where `outs` holds one `jax.ShapeDtypeStruct` per output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Step 1: Define the kernel generator\n",
    "    # -----------------------------------------------------------------------\n",
    "    def warp_scale_add_generator(**kwargs):\n",
    "        \"\"\"Element-wise: out[i] = a[i] * b[i] + c[i]\"\"\"\n",
    "        out_info = kwargs['outs'][0]\n",
    "        n        = out_info.shape[0]\n",
    "        t        = jaxinfo_to_warpinfo(out_info)\n",
    "\n",
    "        @warp.kernel\n",
    "        def kern(\n",
    "            a:   t,\n",
    "            b:   t,\n",
    "            c:   t,\n",
    "            out: t,\n",
    "        ):\n",
    "            i = warp.tid()\n",
    "            out[i] = a[i] * b[i] + c[i]\n",
    "\n",
    "        def run(a, b, c):\n",
    "            fn = jax_kernel(kern, launch_dims=[n], num_outputs=1,\n",
    "                            output_dims={'out': (n,)})\n",
    "            return fn(a, b, c)\n",
    "\n",
    "        return run\n",
    "\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Step 2: Create and register the primitive\n",
    "    # -----------------------------------------------------------------------\n",
    "    scale_add_op = XLACustomKernel('tutorial_warp_scale_add')\n",
    "    scale_add_op.def_warp_kernel(warp_scale_add_generator)\n",
    "\n",
    "    print(\"Registered backends:\", scale_add_op._kernels)\n",
    "    print(\"Default backends  :\", scale_add_op.defaults)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Step 3: Call the primitive\n",
    "    # -----------------------------------------------------------------------\n",
    "    N = 256\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.full(N, 2.0, dtype=jnp.float32)\n",
    "    c = jnp.ones(N, dtype=jnp.float32)\n",
    "\n",
    "    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)\n",
    "    result   = scale_add_op(a, b, c, outs=[out_spec])\n",
    "\n",
    "    expected = a * b + c\n",
    "    print(\"Max error:\", float(jnp.max(jnp.abs(result[0] - expected))))\n",
    "    print(\"First 5  :\", result[0][:5])\n",
    "\n",
    "    # -----------------------------------------------------------------------\n",
    "    # Step 4: Use inside jax.jit (the primitive is JIT-compatible)\n",
    "    # -----------------------------------------------------------------------\n",
    "    @jax.jit\n",
    "    def jitted_op(a, b, c):\n",
    "        return scale_add_op(a, b, c, outs=[jax.ShapeDtypeStruct(a.shape, a.dtype)])[0]\n",
    "\n",
    "    r = jitted_op(a, b, c)\n",
    "    print(\"JIT result matches:\", bool(jnp.allclose(r, expected)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Neuroscience Example: Sparse Synaptic Input Accumulation\n",
    "\n",
    "A classic operation in spiking neural network simulation:\n",
    "given a binary spike vector `spikes` (shape `[N_pre]`) and a CSR weight matrix\n",
    "`(data, indices, indptr)`, compute the postsynaptic current\n",
    "\n",
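    "$$I_{\\text{post}}[j] = \\sum_{i:\\, \\text{spikes}[i]>0} \\; \\sum_{k=\\text{indptr}[i]}^{\\text{indptr}[i+1]-1} \\text{data}[k]\\,\\mathbf{1}\\big[\\text{indices}[k]=j\\big]$$\n",
    "that is, each spiking pre-synaptic neuron $i$ adds its stored weights to the postsynaptic columns listed in `indices[indptr[i]:indptr[i+1]]`.\n",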
    "\n",
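    "We implement this with a Warp kernel that:\n",
    "1. Iterates over pre-synaptic neurons in parallel (one thread per row of the CSR matrix)\n",
    "2. Skips silent neurons (no spike)\n",
    "3. Atomically accumulates weights into the postsynaptic current buffer\n",
    "\n",
    "In this example all synapses share one homogeneous weight, so `data` has a single element and the kernel reads `weights[0]` instead of a per-synapse value."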
   ]
  },
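  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the GPU version, it helps to have a plain sequential reference for this operation.\n",
    "The hypothetical helper `csr_binary_mv_reference` below follows the same loop structure as the Warp kernel: it walks the rows (pre-synaptic neurons) and skips silent ones. Because it runs sequentially, an ordinary `+=` is safe here.\n",
    "It accepts either a per-synapse `data` array or, as in this tutorial, a single homogeneous weight stored in an array of length 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def csr_binary_mv_reference(data, indices, indptr, spikes, n_post):\n",
    "    # Sequential CPU reference; the loop over i is what the Warp kernel parallelizes\n",
    "    post = np.zeros(n_post, dtype=data.dtype)\n",
    "    homogeneous = data.shape[0] == 1          # one shared weight for all synapses\n",
    "    for i in range(indptr.shape[0] - 1):      # one Warp thread per pre-synaptic neuron\n",
    "        if not spikes[i]:\n",
    "            continue                          # silent neurons contribute nothing\n",
    "        for k in range(indptr[i], indptr[i + 1]):\n",
    "            w = data[0] if homogeneous else data[k]\n",
    "            post[indices[k]] += w             # no race: iterations run one at a time\n",
    "    return post\n",
    "\n",
    "# Tiny 3 x 4 example: row 0 -> cols {1, 3}, row 1 -> col {0}, row 2 -> cols {1, 2}\n",
    "indptr_demo  = np.array([0, 2, 3, 5], dtype=np.int32)\n",
    "indices_demo = np.array([1, 3, 0, 1, 2], dtype=np.int32)\n",
    "data_demo    = np.array([0.5], dtype=np.float32)\n",
    "spikes_demo  = np.array([True, False, True])\n",
    "\n",
    "post_demo = csr_binary_mv_reference(data_demo, indices_demo, indptr_demo, spikes_demo, n_post=4)\n",
    "print(post_demo)  # rows 0 and 2 fire; expected values 0, 1.0, 0.5, 0.5"
   ]
  },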
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    def csr_binary_mv_warp_generator(**kwargs):\n",
    "        \"\"\"\n",
    "        Kernel generator for CSR × binary-vector multiplication.\n",
    "        Signature: kernel(weights, indices, indptr, spikes) -> post_current\n",
    "        \"\"\"\n",
    "        weight_info  = kwargs['weight_info']\n",
    "        spike_info   = kwargs['spike_info']\n",
    "        indices_info = kwargs['indices_info']\n",
    "        indptr_info  = kwargs['indptr_info']\n",
    "        n_pre        = indptr_info.shape[0] - 1\n",
    "        n_post       = kwargs['n_post']\n",
    "        out_dtype    = kwargs['outs'][0].dtype\n",
    "\n",
    "        # Build Warp type descriptors\n",
    "        w_type      = jaxinfo_to_warpinfo(weight_info)\n",
    "        idx_type    = jaxinfo_to_warpinfo(indices_info)\n",
    "        indptr_type = jaxinfo_to_warpinfo(indptr_info)\n",
    "        spk_type    = jaxinfo_to_warpinfo(spike_info)\n",
    "        out_type    = warp.array(dtype=jaxtype_to_warptype(out_dtype), ndim=1)\n",
    "\n",
    "        @warp.kernel\n",
    "        def mv_kern(\n",
    "            weights: w_type,\n",
    "            indices: idx_type,\n",
    "            indptr:  indptr_type,\n",
    "            spikes:  spk_type,\n",
    "            posts:   out_type,\n",
    "        ):\n",
    "            i = warp.tid()              # one thread per pre-synaptic neuron\n",
    "            if spikes[i]:               # skip silent neurons\n",
    "                w = weights[0]          # scalar weight (homogeneous)\n",
    "                for j in range(indptr[i], indptr[i + 1]):\n",
    "                    warp.atomic_add(posts, indices[j], w)\n",
    "\n",
    "        def kernel(weights, indices, indptr, spikes):\n",
    "            fn = jax_kernel(\n",
    "                mv_kern,\n",
    "                launch_dims=[n_pre],\n",
    "                num_outputs=1,\n",
    "                in_out_argnames=['posts'],\n",
    "            )\n",
    "            return fn(weights, indices, indptr, spikes,\n",
    "                      jnp.zeros(n_post, dtype=out_dtype))\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    print(\"CSR binary MV kernel generator defined.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    import scipy.sparse as sp\n",
    "\n",
    "    # Build a random CSR connectivity matrix\n",
    "    N_PRE  = 1000\n",
    "    N_POST = 500\n",
    "    PROB   = 0.05   # 5 % connection probability\n",
    "    W      = 0.1    # homogeneous weight\n",
    "\n",
    "    rng = np.random.default_rng(42)\n",
    "    dense = (rng.random((N_PRE, N_POST)) < PROB).astype(np.float32) * W\n",
    "    csr   = sp.csr_matrix(dense)\n",
    "\n",
    "    data    = jnp.array([W], dtype=jnp.float32)         # single homogeneous weight\n",
    "    indices = jnp.array(csr.indices, dtype=jnp.int32)\n",
    "    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)\n",
    "\n",
    "    # Generate binary spikes (10 % firing rate)\n",
    "    spikes = jnp.array(rng.random(N_PRE) < 0.10, dtype=jnp.bool_)\n",
    "\n",
    "    # Register the primitive\n",
    "    csr_mv_op = XLACustomKernel('tutorial_warp_csr_mv')\n",
    "    csr_mv_op.def_warp_kernel(csr_binary_mv_warp_generator)\n",
    "\n",
    "    # Build output spec and call\n",
    "    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)\n",
    "\n",
    "    result = csr_mv_op(\n",
    "        data, indices, indptr, spikes,\n",
    "        outs=[out_spec],\n",
    "        # extra kwargs forwarded to the generator:\n",
    "        weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),\n",
    "        spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),\n",
    "        indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),\n",
    "        indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),\n",
    "        n_post       = N_POST,\n",
    "    )\n",
    "\n",
    "    # Reference: dense matmul\n",
    "    spikes_f  = spikes.astype(jnp.float32)\n",
    "    expected  = spikes_f @ jnp.array(dense)\n",
    "\n",
    "    print(f\"Network: {N_PRE} pre -> {N_POST} post  |  {int(spikes.sum())} spikes\")\n",
    "    print(f\"Max error vs dense reference: {float(jnp.max(jnp.abs(result[0] - expected))):.6f}\")\n",
    "    print(f\"Post current range: [{float(result[0].min()):.3f}, {float(result[0].max()):.3f}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if WARP_AVAILABLE:\n",
    "    import time\n",
    "\n",
    "    @jax.jit\n",
    "    def warp_mv(data, indices, indptr, spikes):\n",
    "        return csr_mv_op(\n",
    "            data, indices, indptr, spikes,\n",
    "            outs=[out_spec],\n",
    "            weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),\n",
    "            spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),\n",
    "            indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),\n",
    "            indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),\n",
    "            n_post       = N_POST,\n",
    "        )[0]\n",
    "\n",
    "    @jax.jit\n",
    "    def dense_mv(spikes_f, dense):\n",
    "        return spikes_f @ dense\n",
    "\n",
    "    # Transfer the dense matrix to the device once, so the loop times only the matmul\n",
    "    dense_j = jnp.array(dense)\n",
    "\n",
    "    # Warm up (the first call triggers compilation)\n",
    "    jax.block_until_ready(warp_mv(data, indices, indptr, spikes))\n",
    "    jax.block_until_ready(dense_mv(spikes_f, dense_j))\n",
    "\n",
    "    N_TRIALS = 200\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(warp_mv(data, indices, indptr, spikes))\n",
    "    warp_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(N_TRIALS):\n",
    "        jax.block_until_ready(dense_mv(spikes_f, dense_j))\n",
    "    dense_time = (time.perf_counter() - t0) / N_TRIALS * 1000\n",
    "\n",
    "    print(f\"Warp sparse kernel : {warp_time:.3f} ms\")\n",
    "    print(f\"JAX dense matmul   : {dense_time:.3f} ms\")\n",
    "    print(f\"Speedup            : {dense_time / warp_time:.2f}x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Multiple Backends with `XLACustomKernel`\n",
    "\n",
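    "You can register multiple backends for the same operation; BrainEvent then uses the one registered for the platform the computation runs on (for example, Warp on GPU and Numba on CPU)."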
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPU fallback using Numba (demonstrated here even if GPU is unavailable)\n",
    "try:\n",
    "    import numba\n",
    "    from brainevent import numba_kernel\n",
    "    NUMBA_AVAILABLE = True\n",
    "except ImportError:\n",
    "    NUMBA_AVAILABLE = False\n",
    "\n",
    "if NUMBA_AVAILABLE:\n",
    "    @numba.njit(parallel=True)\n",
    "    def _scale_add_numba(a, b, c, out):\n",
    "        for i in numba.prange(out.size):\n",
    "            out[i] = a[i] * b[i] + c[i]\n",
    "\n",
    "    def numba_scale_add_generator(**kwargs):\n",
    "        out_info = kwargs['outs'][0]\n",
    "\n",
    "        def kernel(a, b, c):\n",
    "            return numba_kernel(_scale_add_numba, outs=out_info)(a, b, c)\n",
    "\n",
    "        return kernel\n",
    "\n",
    "    # Create op with both GPU (Warp) and CPU (Numba) backends\n",
    "    multi_backend_op = XLACustomKernel('tutorial_multi_backend_scale_add')\n",
    "\n",
    "    if WARP_AVAILABLE:\n",
    "        multi_backend_op.def_warp_kernel(warp_scale_add_generator)   # GPU\n",
    "\n",
    "    multi_backend_op.def_numba_kernel(numba_scale_add_generator)     # CPU\n",
    "\n",
    "    print(\"Registered backends:\", list(multi_backend_op._kernels.keys()))\n",
    "\n",
    "    # On GPU, Warp is default; on CPU, Numba is used automatically\n",
    "    N = 128\n",
    "    a = jnp.arange(N, dtype=jnp.float32)\n",
    "    b = jnp.full(N, 3.0, dtype=jnp.float32)\n",
    "    c = jnp.ones(N, dtype=jnp.float32)\n",
    "    r = multi_backend_op(a, b, c, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])\n",
    "    print(\"Result matches:\", bool(jnp.allclose(r[0], a * b + c)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Summary\n",
    "\n",
    "In this tutorial we covered:\n",
    "\n",
    "1. **`@warp.kernel`** – Write GPU kernels in Python-like syntax; use `warp.tid()` for the thread index.\n",
    "2. **`jax_kernel`** – Wrap a Warp kernel so JAX can call it with `jax.Array` inputs.\n",
    "   - `output_dims` mode: Warp allocates the output buffer.\n",
    "   - `in_out_argnames` mode: caller provides the initial buffer (needed for atomic accumulation).\n",
    "3. **`jaxinfo_to_warpinfo` / `jaxtype_to_warptype`** – Convert JAX dtype/shape info to Warp types\n",
    "   for dynamic kernel construction inside kernel generators.\n",
    "4. **`XLACustomKernel.def_warp_kernel`** – Register a Warp kernel generator as the GPU backend\n",
    "   of a multi-backend custom JAX primitive.\n",
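    "5. **Neuroscience application** – Sparse CSR × binary-spike matrix-vector product implemented\n",
    "   with Warp atomic operations, demonstrating the key pattern used throughout BrainEvent.\n",
    "6. **Multiple backends** – Register a Numba CPU kernel next to the Warp GPU kernel on the same primitive.\n",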
    "\n",
    "## Next Steps\n",
    "\n",
    "- **Tutorial 7**: Custom GPU operators with Numba CUDA (`@cuda.jit`)\n",
    "- **Tutorial 8**: Custom CPU operators with Numba (`@numba.njit`)\n",
    "\n",
    "## References\n",
    "\n",
    "- [NVIDIA Warp documentation](https://nvidia.github.io/warp/)\n",
    "- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)\n",
    "- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)"
   ]
  }
 ]
}
