{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ce02840707ba",
   "metadata": {},
   "source": [
    "# CuPy backend\n",
    "\n",
    "[CuPy](https://cupy.dev/) is a near drop-in replacement for NumPy that runs\n",
    "on NVIDIA GPUs via CUDA. Use this backend when you want GPU acceleration for\n",
    "array-API operations but don't need JAX's autodiff or JIT compilation.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f22c83e8da9c",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install \"brainunit[cupy]\"\n",
    "```\n",
    "\n",
    "Requires a working CUDA toolkit; the extra pulls in the `cupy-cuda12x` wheel.\n",
    "On CUDA 11, skip the extra: install `brainunit` on its own, then run\n",
    "`pip install cupy-cuda11x`. Install only one CuPy wheel per environment.\n"
   ]
  },
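  {
   "cell_type": "markdown",
   "id": "a71c3e90d4b2",
   "metadata": {},
   "source": [
    "To confirm that an installed CuPy wheel matches your CUDA major version, you can\n",
    "ask CuPy which CUDA runtime it loaded. `cupy.cuda.runtime.runtimeGetVersion()`\n",
    "returns an integer encoded as `1000 * major + 10 * minor` (for example, `12020`\n",
    "means CUDA 12.2). The sketch below decodes it. The helpers `decode_cuda_version`\n",
    "and `matching_wheel` are hypothetical names for illustration, not part of\n",
    "brainunit or CuPy.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "c5e8f2a1b037",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def decode_cuda_version(v):\n",
    "    # CUDA encodes versions as 1000 * major + 10 * minor, e.g. 12020 -> (12, 2).\n",
    "    return v // 1000, (v % 1000) // 10\n",
    "\n",
    "\n",
    "def matching_wheel(major):\n",
    "    # Hypothetical helper: map a CUDA major version to a cupy-cudaXXx wheel name\n",
    "    # (the wheels discussed above are cupy-cuda11x and cupy-cuda12x).\n",
    "    return f'cupy-cuda{major}x'\n",
    "\n",
    "\n",
    "try:\n",
    "    import cupy\n",
    "except ImportError:\n",
    "    print('cupy not installed')\n",
    "else:\n",
    "    major, minor = decode_cuda_version(cupy.cuda.runtime.runtimeGetVersion())\n",
    "    print(f'CuPy uses CUDA {major}.{minor}; matching wheel: {matching_wheel(major)}')\n"
   ]
  },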
  {
   "cell_type": "markdown",
   "id": "b34510f8b123",
   "metadata": {},
   "source": [
    "## Graceful import\n",
    "\n",
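    "If CuPy isn't installed (as on most CI runners and on laptops without an\n",
    "NVIDIA GPU), the GPU snippets below are skipped instead of crashing. The\n",
    "check only tests whether `cupy` imports; the GPU cells still need a visible\n",
    "CUDA device to run.\n"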
   ]
  },
  {
   "cell_type": "code",
   "id": "bd5ddc5676fa",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import cupy\n",
    "    HAVE_CUPY = True\n",
    "except ImportError:\n",
    "    HAVE_CUPY = False\n",
    "    print('cupy not installed; install with: pip install \"brainunit[cupy]\"')\n",
    "\n",
    "print('is_cupy_array on a non-cupy object:', u.is_cupy_array([1, 2, 3]))\n"
   ]
  },
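  {
   "cell_type": "markdown",
   "id": "0d9b6e4f2a18",
   "metadata": {},
   "source": [
    "`HAVE_CUPY` only records whether `import cupy` succeeded. A stricter probe also\n",
    "asks the CUDA runtime how many devices are visible, via CuPy's\n",
    "`cupy.cuda.runtime.getDeviceCount()`. The helper `cupy_usable` below is a\n",
    "hypothetical sketch, not a brainunit function. Setting `HAVE_CUPY = cupy_usable()`\n",
    "instead would also skip the GPU cells on machines that have CuPy installed but\n",
    "no usable GPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "e4a2c7f19b60",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import importlib.util\n",
    "\n",
    "\n",
    "def cupy_usable():\n",
    "    # Hypothetical helper: True only if CuPy imports AND a CUDA device is visible.\n",
    "    if importlib.util.find_spec('cupy') is None:\n",
    "        return False\n",
    "    try:\n",
    "        import cupy\n",
    "        return cupy.cuda.runtime.getDeviceCount() > 0\n",
    "    except Exception:\n",
    "        # Import errors or CUDA runtime errors (e.g. no NVIDIA driver) mean 'not usable'.\n",
    "        return False\n",
    "\n",
    "\n",
    "print('cupy usable:', cupy_usable())\n"
   ]
  },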
  {
   "cell_type": "markdown",
   "id": "12f9a361ab8c",
   "metadata": {},
   "source": [
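    "## Quick start\n",
    "\n",
    "Wrap a CuPy array in `u.Quantity`. The cell below prints the quantity, its\n",
    "backend, and the backend of the sum `q + q`.\n"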
   ]
  },
  {
   "cell_type": "code",
   "id": "a9dd53200824",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if HAVE_CUPY:\n",
    "    q = u.Quantity(cupy.array([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "    print(q)\n",
    "    print('backend =', q.backend)\n",
    "    print('(q + q).backend =', (q + q).backend)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22fea3f57a34",
   "metadata": {},
   "source": [
    "## Math operations\n",
    "\n",
    "When a quantity's mantissa is a CuPy array, `brainunit.math` functions dispatch\n",
    "to `array_api_compat.cupy`, so the computation runs on the GPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "bcdace876e93",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if HAVE_CUPY:\n",
    "    x = u.Quantity(cupy.linspace(0.0, cupy.pi, 5), unit=u.UNITLESS)\n",
    "    print(u.math.sin(x))\n"
   ]
  },
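  {
   "cell_type": "markdown",
   "id": "7b3d0e8c5f41",
   "metadata": {},
   "source": [
    "The dispatch step can be pictured with a deliberately simplified stand-in:\n",
    "inspect which module an array belongs to, then call that module's function.\n",
    "brainunit itself dispatches through `array_api_compat`; the `pick_namespace`\n",
    "helper below is hypothetical and only illustrates the idea. It runs on NumPy,\n",
    "so it works without a GPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "91f6a2d8c3e7",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def pick_namespace(arr):\n",
    "    # Hypothetical helper: CuPy's ndarray type lives in a module whose name starts\n",
    "    # with 'cupy'; everything else falls back to NumPy.\n",
    "    if type(arr).__module__.split('.')[0] == 'cupy':\n",
    "        import cupy\n",
    "        return cupy\n",
    "    return np\n",
    "\n",
    "\n",
    "x_host = np.linspace(0.0, np.pi, 5)\n",
    "xp = pick_namespace(x_host)  # NumPy here, because x_host is a NumPy array\n",
    "print(xp.__name__, xp.sin(x_host))\n"
   ]
  },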
  {
   "cell_type": "markdown",
   "id": "d4e61886867c",
   "metadata": {},
   "source": [
    "## Converting between backends\n",
    "\n",
    "`Quantity.to_cupy(device=...)` copies the mantissa to the chosen GPU (for\n",
    "example `device=0`); `Quantity.to_numpy()` copies it back to host memory.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d2bc61836cd2",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if HAVE_CUPY:\n",
    "    import numpy as np\n",
    "    q_cpu = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)\n",
    "    q_gpu = q_cpu.to_cupy(device=0)\n",
    "    print('mantissa lives on device', q_gpu.mantissa.device)\n",
    "    # round-trip back to NumPy\n",
    "    print(q_gpu.to_numpy())\n"
   ]
  },
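  {
   "cell_type": "markdown",
   "id": "5c0e7b9a1d24",
   "metadata": {},
   "source": [
    "At the CuPy level, moving data between host and GPU memory is an explicit copy.\n",
    "The sketch below shows that round trip with plain CuPy calls (`cupy.cuda.Device`,\n",
    "`cupy.asarray`, `cupy.asnumpy`). It is an illustration, not brainunit's\n",
    "implementation of `to_cupy` / `to_numpy`, and the helper name\n",
    "`host_gpu_roundtrip` is hypothetical. With CuPy installed it needs a visible GPU;\n",
    "without CuPy it simply returns a host copy.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "b8d1f4e6a093",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "try:\n",
    "    import cupy\n",
    "except ImportError:\n",
    "    cupy = None\n",
    "\n",
    "\n",
    "def host_gpu_roundtrip(host, device_id=0):\n",
    "    # Copy a NumPy array to GPU `device_id` and back using plain CuPy calls.\n",
    "    if cupy is None:\n",
    "        return host.copy()  # no CuPy: nothing to transfer, return a host copy\n",
    "    with cupy.cuda.Device(device_id):\n",
    "        dev = cupy.asarray(host)  # host -> device copy on the selected GPU\n",
    "    return cupy.asnumpy(dev)  # device -> host copy\n",
    "\n",
    "\n",
    "print(host_gpu_roundtrip(np.array([1.0, 2.0])))\n"
   ]
  },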
  {
   "cell_type": "markdown",
   "id": "e8568ee50bb9",
   "metadata": {},
   "source": [
    "## Requesting the backend explicitly\n",
    "\n",
    "If you request the CuPy backend when CuPy isn't installed, brainunit raises\n",
    "`BackendError` (not a bare `ImportError`), and its message includes the install hint.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "705010d03930",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "from brainunit import BackendError\n",
    "\n",
    "try:\n",
    "    with u.using_backend('cupy'):\n",
    "        u.Quantity([1.0, 2.0], unit=u.meter)\n",
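    "except BackendError as exc:\n",
    "    # Raised when CuPy is missing; the message includes the install hint.\n",
    "    print('expected without cupy:', exc)\n",
    "else:\n",
    "    # Reached only when CuPy is installed and the backend switch succeeded.\n",
    "    print('cupy is installed; no BackendError raised')\n"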
   ]
  },
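  {
   "cell_type": "markdown",
   "id": "2e7c9a5f0b16",
   "metadata": {},
   "source": [
    "A common Python pattern behind a dedicated error like this is to catch the\n",
    "`ImportError` and re-raise a domain-specific exception with `raise ... from exc`.\n",
    "The sketch below uses a hypothetical `DemoBackendError` and `require_backend`\n",
    "helper; it is not brainunit's code, and the real `BackendError` may be structured\n",
    "differently. Chaining with `from exc` keeps the original `ImportError` in\n",
    "`__cause__`, so the traceback still shows why the import failed.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "f03a6d2c8e59",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "\n",
    "class DemoBackendError(Exception):\n",
    "    # Hypothetical stand-in; brainunit's real BackendError may be defined differently.\n",
    "    pass\n",
    "\n",
    "\n",
    "def require_backend(module_name, install_hint):\n",
    "    # Import a backend module, turning ImportError into a domain-specific error\n",
    "    # whose message carries the install hint.\n",
    "    try:\n",
    "        return importlib.import_module(module_name)\n",
    "    except ImportError as exc:\n",
    "        raise DemoBackendError(\n",
    "            f'{module_name} backend requested but not installed; {install_hint}'\n",
    "        ) from exc\n",
    "\n",
    "\n",
    "try:\n",
    "    require_backend('cupy', 'install with: pip install \"brainunit[cupy]\"')\n",
    "except DemoBackendError as err:\n",
    "    print(err)\n",
    "    print('original cause:', repr(err.__cause__))\n",
    "else:\n",
    "    print('cupy is installed; nothing to report')\n"
   ]
  },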
  {
   "cell_type": "markdown",
   "id": "90fdeea8f965",
   "metadata": {},
   "source": [
    "## Limitations\n",
    "\n",
    "- CuPy has no automatic differentiation, and `brainunit.autograd` is JAX-only.\n",
    "- `brainunit.lax` and `brainunit.sparse` are also JAX-only.\n",
    "- To use these modules, first move the data off CuPy with `.to_jax()` (or `.to_numpy()`).\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
