{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c5d34cf2a6a",
   "metadata": {},
   "source": [
    "# Backends overview\n",
    "\n",
    "`brainunit` pairs a physical `Unit` with an array mantissa. The mantissa can\n",
    "live on any one of several array backends, and every unit-aware operation\n",
    "(`brainunit.math`, `brainunit.linalg`, `brainunit.fft`, plain arithmetic)\n",
    "dispatches to the matching backend's array library. You can stay inside one\n",
    "backend end-to-end, mix them, or switch by calling a single conversion\n",
    "method.\n",
    "\n",
    "This page describes the architecture, the supported backends, how selection\n",
    "works, what each backend can and cannot do, and how to install optional\n",
    "backend dependencies. Per-backend notebooks follow it: see [JAX](jax.ipynb),\n",
    "[NumPy](numpy.ipynb), [CuPy](cupy.ipynb), [PyTorch](torch.ipynb),\n",
    "[Dask](dask.ipynb), and [ndonnx](ndonnx.ipynb).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee4ea17b56bf",
   "metadata": {},
   "source": [
    "## Supported backends\n",
    "\n",
    "| Backend  | Mantissa type        | Optional install        | Typical use case                              |\n",
    "|----------|----------------------|-------------------------|-----------------------------------------------|\n",
    "| `jax`    | `jax.Array`          | required (core)         | autograd, JIT, vmap, accelerators (default)   |\n",
    "| `numpy`  | `numpy.ndarray`      | required (core)         | scipy / pandas / sklearn interop, CPU         |\n",
    "| `cupy`   | `cupy.ndarray`       | `brainunit[cupy]`       | NVIDIA GPU arrays, drop-in NumPy replacement  |\n",
    "| `torch`  | `torch.Tensor`       | `brainunit[torch]`      | PyTorch models, CUDA/MPS tensors              |\n",
    "| `dask`   | `dask.array.Array`   | `brainunit[dask]`       | out-of-core / parallel arrays, lazy compute   |\n",
    "| `ndonnx` | `ndonnx.Array`       | `brainunit[ndonnx]`     | symbolic graph building, ONNX export          |\n",
    "\n",
    "`jax` and `numpy` are always available because both are required core\n",
    "dependencies. The other four are opt-in: if you do not install the extra,\n",
    "brainunit still works — it just refuses to dispatch onto that backend and\n",
    "raises `brainunit.BackendError` with the matching `pip install` hint when you\n",
    "ask for one explicitly.\n",
    "\n",
    "Internally, the `numpy`, `cupy`, `torch`, and `dask` namespaces are sourced\n",
    "from\n",
    "[`array_api_compat`](https://github.com/data-apis/array-api-compat) so they\n",
    "all expose the same array-API-standard surface. `jax.numpy` (JAX ≥ 0.9) and\n",
    "`ndonnx` are array-API compatible on their own and are used unwrapped.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50f9fdc29d3e",
   "metadata": {},
   "source": [
    "## How backend selection works\n",
    "\n",
    "For every operation, brainunit asks: *\"which array library should compute the\n",
    "result?\"* The rule:\n",
    "\n",
    "1. Inspect the input mantissas. If exactly one backend kind is present, use\n",
    "   it.\n",
    "2. If inputs mix backends, or there are no array inputs, consult the\n",
    "   context-local **default backend** set by `set_default_backend(...)` /\n",
    "   `using_backend(...)`.\n",
    "3. If no default is set, fall back to `jax` (the historical default).\n"
   ]
  },
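  {
   "cell_type": "markdown",
   "id": "6f2b9d1c4e70",
   "metadata": {},
   "source": [
    "The three steps above can be read as a small decision function. The sketch\n",
    "below is a plain-Python illustration of the rule as stated, not brainunit's\n",
    "actual dispatcher; `resolve_backend` and its arguments are hypothetical names\n",
    "introduced here.\n",
    "\n",
    "```python\n",
    "from typing import Iterable, Optional\n",
    "\n",
    "\n",
    "def resolve_backend(input_kinds: Iterable[str],\n",
    "                    default: Optional[str] = None) -> str:\n",
    "    # Hypothetical helper that mirrors the documented rule only.\n",
    "    kinds = set(input_kinds)\n",
    "    # Step 1: exactly one backend kind among the inputs wins outright.\n",
    "    if len(kinds) == 1:\n",
    "        return next(iter(kinds))\n",
    "    # Step 2: mixed inputs, or no array inputs, use the configured default.\n",
    "    if default is not None:\n",
    "        return default\n",
    "    # Step 3: nothing configured, so fall back to 'jax'.\n",
    "    return 'jax'\n",
    "\n",
    "\n",
    "print(resolve_backend(['numpy', 'numpy']))                 # numpy\n",
    "print(resolve_backend(['numpy', 'jax']))                   # jax\n",
    "print(resolve_backend(['numpy', 'jax'], default='numpy'))  # numpy\n",
    "print(resolve_backend([]))                                 # jax\n",
    "```\n",
    "\n",
    "Keeping the inputs-first check ahead of the default means a computation that\n",
    "stays on one backend is not silently moved elsewhere by a global setting.\n"
   ]
  },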
  {
   "cell_type": "code",
   "id": "8edf9baf9dc0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.310528100Z",
     "start_time": "2026-05-22T03:09:31.394830300Z"
    }
   },
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "q_np  = u.Quantity(np.array([1.0]), unit=u.meter)\n",
    "q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)\n",
    "\n",
    "print('q_np.backend        =', q_np.backend)\n",
    "print('q_jax.backend       =', q_jax.backend)\n",
    "print('(q_np + q_np).bk    =', (q_np + q_np).backend)   # single -> wins\n",
    "print('(q_np + q_jax).bk   =', (q_np + q_jax).backend)  # mixed  -> default\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "q_np.backend        = numpy\n",
      "q_jax.backend       = jax\n",
      "(q_np + q_np).bk    = numpy\n",
      "(q_np + q_jax).bk   = jax\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "cd7fb1b28bb5",
   "metadata": {},
   "source": [
    "Override the tiebreaker with the context manager:\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "a8f70742c4cb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.349283100Z",
     "start_time": "2026-05-22T03:09:35.311963400Z"
    }
   },
   "source": [
    "with u.using_backend('numpy'):\n",
    "    print('inside using_backend:', (q_np + q_jax).backend)  # 'numpy'\n",
    "\n",
    "print('outside:', (q_np + q_jax).backend)                    # back to 'jax'\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inside using_backend: jax\n",
      "outside: jax\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "id": "83eec29a26f9",
   "metadata": {},
   "source": [
    "Or set it for the rest of the program:\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "5712173e911d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.377082600Z",
     "start_time": "2026-05-22T03:09:35.350290100Z"
    }
   },
   "source": [
    "u.set_default_backend('numpy')\n",
    "print(u.get_default_backend())\n",
    "print((q_np + q_jax).backend)\n",
    "\n",
    "u.set_default_backend(None)   # restore default\n",
    "print(u.get_default_backend())\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy\n",
      "jax\n",
      "None\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "id": "a39b98a851e7",
   "metadata": {},
   "source": [
    "The default is stored in a `ContextVar`, so it is isolated per thread and per\n",
    "asyncio task; nested `using_backend(...)` blocks restore the prior value on\n",
    "exit.\n"
   ]
  },
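  {
   "cell_type": "markdown",
   "id": "c3a81f5e09b2",
   "metadata": {},
   "source": [
    "To make the restore-on-exit behaviour concrete, here is a minimal sketch built\n",
    "only on the standard-library `contextvars` module. The names\n",
    "`_default_backend` and `using_backend_sketch` are hypothetical stand-ins, and\n",
    "brainunit's real implementation may differ in detail.\n",
    "\n",
    "```python\n",
    "import contextvars\n",
    "from contextlib import contextmanager\n",
    "\n",
    "# Hypothetical stand-in for the library's default-backend variable.\n",
    "_default_backend = contextvars.ContextVar('default_backend', default=None)\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def using_backend_sketch(name):\n",
    "    # set() returns a token that remembers the previous value.\n",
    "    token = _default_backend.set(name)\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        # reset() restores whatever was active before this block.\n",
    "        _default_backend.reset(token)\n",
    "\n",
    "\n",
    "with using_backend_sketch('numpy'):\n",
    "    with using_backend_sketch('torch'):\n",
    "        inner = _default_backend.get()\n",
    "    middle = _default_backend.get()\n",
    "outer = _default_backend.get()\n",
    "\n",
    "print(inner, middle, outer)  # torch numpy None\n",
    "```\n",
    "\n",
    "Resetting through the token, rather than writing `None` back, is what lets\n",
    "nested blocks unwind to the enclosing value instead of clearing it.\n"
   ]
  },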
  {
   "cell_type": "markdown",
   "id": "4a1fc18182c3",
   "metadata": {},
   "source": [
    "## Choosing a backend\n",
    "\n",
    "There is no universally best backend — each one trades capability against\n",
    "ecosystem.\n",
    "\n",
    "- **`jax`** — pick this when you need automatic differentiation, JIT, `vmap`,\n",
    "  or accelerator support out of the box. This is the default and the most\n",
    "  fully integrated backend; everything in `brainunit.autograd`, `brainunit.lax`,\n",
    "  and `brainunit.sparse` requires it.\n",
    "- **`numpy`** — pick this for interop with the broader scientific Python\n",
    "  stack (scipy, pandas, sklearn, matplotlib) where you want eager results\n",
    "  with no JAX tracing. Works on CPU only.\n",
    "- **`cupy`** — pick this when you want a near-drop-in NumPy replacement\n",
    "  running on an NVIDIA GPU and you don't need autodiff. Requires a CUDA\n",
    "  toolkit.\n",
    "- **`torch`** — pick this to embed unit-aware computations inside an existing\n",
    "  PyTorch model. PyTorch's own autograd is preserved through brainunit ops, so\n",
    "  `loss.backward()` works on a quantity-derived loss. `brainunit.autograd`\n",
    "  itself is JAX-only — call `torch.autograd.grad` on the mantissa.\n",
    "- **`dask`** — pick this for arrays that don't fit in memory, or for\n",
    "  embarrassingly parallel array work on a cluster. Operations stay lazy\n",
    "  until you call `.compute()`.\n",
    "- **`ndonnx`** — pick this when you want to build an ONNX graph\n",
    "  symbolically. Operations build the graph rather than executing eagerly.\n",
    "  Still maturing: not every brainunit operation has an ndonnx implementation.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2410de715bc1",
   "metadata": {},
   "source": [
    "## Backend capabilities and limitations\n",
    "\n",
    "Dimensional analysis works on every backend — brainunit tracks units on the\n",
    "Python `Quantity` object, independent of the mantissa library. The\n",
    "limitations below describe what each *array backend* can and cannot do, not\n",
    "the unit system.\n",
    "\n",
    "### `jax` (default)\n",
    "\n",
    "Full feature set. The only backend that supports:\n",
    "\n",
    "- `brainunit.lax.*` — wrappers over `jax.lax` primitives.\n",
    "- `brainunit.autograd.*` — `grad`, `jacobian`, `hessian`.\n",
    "- `brainunit.sparse.*` — `CSR`, `CSC`, `COO` sparse matrices.\n",
    "- `jax.jit`, `jax.vmap`, `jax.pmap` over quantities.\n",
    "\n",
    "### `numpy`\n",
    "\n",
    "Eager CPU computation. `brainunit.math`, `brainunit.linalg`, and `brainunit.fft`\n",
    "all work. JAX-specific subpackages raise `BackendError`.\n",
    "\n",
    "### `cupy`\n",
    "\n",
    "NVIDIA GPU arrays via CUDA. Same general capability as `numpy` for\n",
    "`brainunit.math` / `brainunit.linalg` / `brainunit.fft`, but executed on the GPU.\n",
    "No autograd, no JIT, no `brainunit.lax`.\n",
    "\n",
    "### `torch`\n",
    "\n",
    "PyTorch tensors. `brainunit.math` / `brainunit.linalg` / `brainunit.fft` route\n",
    "through `array_api_compat.torch`. Use `torch.autograd.grad` on the\n",
    "mantissa when you need backward passes — `brainunit.autograd` is JAX-only.\n",
    "\n",
    "### `dask`\n",
    "\n",
    "Lazy arrays. Building a quantity, inspecting `.shape` / `.ndim` / `.dtype`,\n",
    "arithmetic, and most array-API operations stay lazy. Operations that need a\n",
    "concrete Python value — `float(q)`, `int(q)`, `q.tolist()`, `np.asarray(q)`,\n",
    "`hash(q)`, `operator.index(q)` — raise `BackendError`; call\n",
    "`q.mantissa.compute()` first.\n",
    "\n",
    "### `ndonnx`\n",
    "\n",
    "Symbolic / ONNX graph building. Routing is correct for the array-API\n",
    "operations that ndonnx implements. Operations ndonnx hasn't implemented yet\n",
    "surface their own errors unwrapped (brainunit does not catch them). Unit\n",
    "information lives on the `Quantity` and is not encoded in the ONNX graph.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4041eeac2c87",
   "metadata": {},
   "source": [
    "Example of a JAX-only operation refusing a NumPy mantissa:\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "44149214d407",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.474805800Z",
     "start_time": "2026-05-22T03:09:35.389715500Z"
    }
   },
   "source": [
    "from brainunit import BackendError\n",
    "\n",
    "q_np = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "try:\n",
    "    u.lax.slice(q_np, (0,), (1,))\n",
    "except BackendError as exc:\n",
    "    print('expected:', exc)\n",
    "\n",
    "# convert and retry\n",
    "print(u.lax.slice(q_np.to_jax(), (0,), (1,)))\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expected: brainunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.\n",
      "[1.] m\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "id": "3533fad6e3ca",
   "metadata": {},
   "source": [
    "## Optional dependencies and graceful failure\n",
    "\n",
    "Optional backends are detected lazily. The `is_*_array` helpers cache a failed\n",
    "import for the lifetime of the process and never raise; they return a plain\n",
    "boolean:\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "e92063ab7e61",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.520901Z",
     "start_time": "2026-05-22T03:09:35.476113800Z"
    }
   },
   "source": [
    "print('is_jax_array(jnp.zeros(1))      =', u.is_jax_array(jnp.zeros(1)))\n",
    "print('is_numpy_array(np.zeros(1))     =', u.is_numpy_array(np.zeros(1)))\n",
    "print('is_cupy_array on a non-cupy obj =', u.is_cupy_array([1, 2, 3]))\n",
    "print('is_torch_array on a non-torch   =', u.is_torch_array([1, 2, 3]))\n",
    "print('is_dask_array on a non-dask     =', u.is_dask_array([1, 2, 3]))\n",
    "print('is_ndonnx_array on non-ndonnx   =', u.is_ndonnx_array([1, 2, 3]))\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_jax_array(jnp.zeros(1))      = True\n",
      "is_numpy_array(np.zeros(1))     = True\n",
      "is_cupy_array on a non-cupy obj = False\n",
      "is_torch_array on a non-torch   = False\n",
      "is_dask_array on a non-dask     = False\n",
      "is_ndonnx_array on non-ndonnx   = False\n"
     ]
    }
   ],
   "execution_count": 5
  },
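  {
   "cell_type": "markdown",
   "id": "9d47e2a6b813",
   "metadata": {},
   "source": [
    "One way such lazy, non-raising detection could be written is sketched below\n",
    "using only the standard library. `_try_import` and `is_array_of` are\n",
    "hypothetical helpers for illustration; they are not part of brainunit, and the\n",
    "library's own helpers may work differently.\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "from array import array\n",
    "from functools import lru_cache\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=None)\n",
    "def _try_import(module_name):\n",
    "    # The result is cached per process, including a failed import (None).\n",
    "    try:\n",
    "        return importlib.import_module(module_name)\n",
    "    except ImportError:\n",
    "        return None\n",
    "\n",
    "\n",
    "def is_array_of(module_name, type_name, obj):\n",
    "    # Return False instead of raising when the optional package is missing.\n",
    "    module = _try_import(module_name)\n",
    "    if module is None:\n",
    "        return False\n",
    "    return isinstance(obj, getattr(module, type_name))\n",
    "\n",
    "\n",
    "# A missing package yields False, not an exception.\n",
    "print(is_array_of('package_that_is_not_installed', 'ndarray', [1, 2, 3]))  # False\n",
    "# The standard-library array module stands in for an installed backend.\n",
    "print(is_array_of('array', 'array', array('d', [1.0])))  # True\n",
    "```\n",
    "\n",
    "Because the import attempt is cached, repeated checks against a missing\n",
    "package cost a dictionary lookup rather than a fresh import attempt.\n"
   ]
  },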
  {
   "cell_type": "markdown",
   "id": "be5f0becfe46",
   "metadata": {},
   "source": [
    "Asking for a backend that isn't installed raises `brainunit.BackendError`,\n",
    "not a bare `ImportError`. The exception message includes the exact install\n",
    "command. For a graceful fallback, probe which packages are importable before\n",
    "selecting a backend:\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d796963388a6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.533500300Z",
     "start_time": "2026-05-22T03:09:35.522200500Z"
    }
   },
   "source": [
    "import importlib\n",
    "\n",
    "def pick_backend():\n",
    "    # Try backends in order of preference; return the first importable one.\n",
    "    for name in ('torch', 'cupy', 'jax', 'numpy'):\n",
    "        try:\n",
    "            importlib.import_module(name)\n",
    "            return name\n",
    "        except ImportError:\n",
    "            continue\n",
    "    raise RuntimeError('no array backend available')\n",
    "\n",
    "print('preferred backend:', pick_backend())\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preferred backend: torch\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "id": "7c92d9c7f957",
   "metadata": {},
   "source": [
    "## Conversion between backends\n",
    "\n",
    "Every `Quantity` has a per-backend conversion method. Each one returns a\n",
    "new `Quantity` and leaves the original untouched, unless the mantissa is\n",
    "already on the target backend, in which case it is a no-op (`return self`).\n",
    "\n",
    "| Method                                  | Notes                                              |\n",
    "|-----------------------------------------|----------------------------------------------------|\n",
    "| `q.to_jax()`                            | Wraps the mantissa with `jnp.asarray`.             |\n",
    "| `q.to_numpy()`                          | Materializes ndonnx via `unwrap_numpy`.            |\n",
    "| `q.to_cupy(device=None)`                | `device` is a CUDA device index.                   |\n",
    "| `q.to_torch(device=None, dtype=None)`   | `dtype` accepts numpy *or* torch dtypes.           |\n",
    "| `q.to_dask(chunks='auto')`              | Wraps with `dask.array.from_array`.                |\n",
    "| `q.to_ndonnx()`                         | `ndonnx.asarray` on the mantissa.                  |\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d4e8837e30c1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:35.555778600Z",
     "start_time": "2026-05-22T03:09:35.534689300Z"
    }
   },
   "source": [
    "q_np  = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)\n",
    "q_jax = q_np.to_jax()         # NumPy -> JAX\n",
    "q_back = q_jax.to_numpy()     # JAX  -> NumPy\n",
    "\n",
    "print(q_np.backend, '->', q_jax.backend, '->', q_back.backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy -> jax -> numpy\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "id": "7b51c6b5ffe0",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "| Command                                    | Provides                                            |\n",
    "|--------------------------------------------|-----------------------------------------------------|\n",
    "| `pip install brainunit`                    | core `Quantity`, `jax` + `numpy` backends           |\n",
    "| `pip install brainunit[cpu]`               | core + `jax[cpu]` (pinned CPU wheels)               |\n",
    "| `pip install brainunit[cuda12]`            | core + `jax[cuda12]`                                |\n",
    "| `pip install brainunit[cuda13]`            | core + `jax[cuda13]`                                |\n",
    "| `pip install brainunit[tpu]`               | core + `jax[tpu]`                                   |\n",
    "| `pip install brainunit[cupy]`              | adds `cupy-cuda12x` for the CuPy backend            |\n",
    "| `pip install brainunit[torch]`             | adds `torch>=2.0` for the PyTorch backend           |\n",
    "| `pip install brainunit[dask]`              | adds `dask[array]` for the Dask backend             |\n",
    "| `pip install brainunit[ndonnx]`            | adds `ndonnx` for the symbolic backend              |\n",
    "| `pip install brainunit[all]`               | shorthand for `[cupy,torch,dask,ndonnx]`            |\n",
    "\n",
    "JAX is a required dependency — every install includes the JAX backend. The\n",
    "`[cpu]` / `[cuda12]` / `[cuda13]` / `[tpu]` extras pin the JAX accelerator\n",
    "build; pick at most one. The `[cupy]` / `[torch]` / `[dask]` / `[ndonnx]`\n",
    "extras are independent and can be combined freely.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93d65902cb7d",
   "metadata": {},
   "source": [
    "## See also\n",
    "\n",
    "- Per-backend notebooks: [JAX](jax.ipynb), [NumPy](numpy.ipynb),\n",
    "  [CuPy](cupy.ipynb), [PyTorch](torch.ipynb), [Dask](dask.ipynb),\n",
    "  [ndonnx](ndonnx.ipynb).\n",
    "- API reference: `set_default_backend`, `using_backend`,\n",
    "  `get_default_backend`, `is_*_array`, `BackendError`, `Quantity.backend`,\n",
    "  `Quantity.to_*`.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
