{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cccd053f1ced",
   "metadata": {},
   "source": [
    "# PyTorch backend\n",
    "\n",
    "The PyTorch backend lets you embed unit-aware computations inside an\n",
    "existing PyTorch model. PyTorch's own autograd is preserved through\n",
    "brainunit operations, so `.backward()` works on a loss computed from a\n",
    "quantity's mantissa.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc16c406401f",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install brainunit[torch]\n",
    "```\n",
    "\n",
    "The extra pins `torch>=2.0`. Pick a CPU or CUDA wheel via PyTorch's own\n",
    "install matrix if you need a specific accelerator build.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "151bf960726b",
   "metadata": {},
   "source": [
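    "## Graceful import\n",
    "\n",
    "Import torch inside a `try` block so the notebook still runs, with the\n",
    "torch-specific cells skipped, when PyTorch is not installed.\n"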
   ]
  },
  {
   "cell_type": "code",
   "id": "82b9407e01f5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:49.831247900Z",
     "start_time": "2026-05-22T03:09:45.099621900Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "\n",
    "try:\n",
    "    import torch\n",
    "    HAVE_TORCH = True\n",
    "except ImportError:\n",
    "    HAVE_TORCH = False\n",
    "    print('torch not installed; install with: pip install brainunit[torch]')\n"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "e758af4d0f0a",
   "metadata": {},
   "source": [
    "## Quick start\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "b4eb60989c87",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:50.455858Z",
     "start_time": "2026-05-22T03:09:50.243528700Z"
    }
   },
   "source": [
    "if HAVE_TORCH:\n",
    "    q = u.Quantity(torch.tensor([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "    print(q)\n",
    "    print('backend =', q.backend)\n",
    "    print('(q + q).backend =', (q + q).backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 2. 3.] m\n",
      "backend = torch\n",
      "(q + q).backend = torch\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "id": "f813a03bdca0",
   "metadata": {},
   "source": [
    "## Conversion\n",
    "\n",
    "`Quantity.to_torch(device=..., dtype=...)` returns a torch-backed quantity.\n",
    "Its `dtype` argument accepts either a torch dtype (`torch.float32`) or a\n",
    "NumPy dtype (`np.float32`); brainunit translates the latter automatically.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d82caec68c4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:50.488432Z",
     "start_time": "2026-05-22T03:09:50.466387800Z"
    }
   },
   "source": [
    "if HAVE_TORCH:\n",
    "    import numpy as np\n",
    "    q_cpu = u.Quantity([1.0, 2.0], unit=u.meter)\n",
    "    q_f32 = q_cpu.to_torch(dtype=np.float32)  # NumPy dtype is translated to torch.float32\n",
    "    print(q_f32.mantissa.dtype)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n"
     ]
    }
   ],
   "execution_count": 3
  },
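  {
   "cell_type": "markdown",
   "id": "5a7c2e91d4b3",
   "metadata": {},
   "source": [
    "The same method can also target a device and take a torch dtype directly.\n",
    "This is a sketch that assumes `device` is forwarded to torch, so a device\n",
    "string such as `'cpu'` (or `'cuda'` on a GPU machine) is accepted.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "9f0e3b6a1c72",
   "metadata": {},
   "source": [
    "if HAVE_TORCH:\n",
    "    # Assumption: `device` accepts a torch device string; 'cpu' works on any machine.\n",
    "    q_f64 = q_cpu.to_torch(device='cpu', dtype=torch.float64)\n",
    "    print(q_f64.mantissa.dtype, q_f64.mantissa.device)\n"
   ],
   "outputs": [],
   "execution_count": null
  },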
  {
   "cell_type": "markdown",
   "id": "e4696fc6a454",
   "metadata": {},
   "source": [
    "## Gradients with PyTorch autograd\n",
    "\n",
    "`brainunit.autograd.grad` is JAX-only. For PyTorch, call `torch.autograd.grad`\n",
    "(or `.backward()`) on a scalar computed from the result's mantissa. Units still\n",
    "propagate through brainunit operations, but the gradient itself is computed by\n",
    "torch and comes back as a plain tensor without units.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "ba2b7a3d5e12",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:50.504998100Z",
     "start_time": "2026-05-22T03:09:50.490424900Z"
    }
   },
   "source": [
    "if HAVE_TORCH:\n",
    "    x = torch.tensor([1.0, 2.0], requires_grad=True)\n",
    "    q = u.Quantity(x, unit=u.meter) ** 2     # area\n",
    "    loss = q.mantissa.sum()\n",
    "    grads, = torch.autograd.grad(loss, x)\n",
    "    print('grads:', grads)                   # 2 * x\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grads: tensor([2., 4.])\n"
     ]
    }
   ],
   "execution_count": 4
  },
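  {
   "cell_type": "markdown",
   "id": "c41d8e2f7a05",
   "metadata": {},
   "source": [
    "The `.backward()` form of the same computation accumulates the gradient\n",
    "into `x.grad` instead of returning it. The sketch below also prints `q.unit`\n",
    "to show that the unit is kept alongside the torch graph.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "2b6f9d0e8c13",
   "metadata": {},
   "source": [
    "if HAVE_TORCH:\n",
    "    x = torch.tensor([1.0, 2.0], requires_grad=True)\n",
    "    q = u.Quantity(x, unit=u.meter) ** 2     # mantissa is x**2, unit is m^2\n",
    "    q.mantissa.sum().backward()              # fills x.grad in place\n",
    "    print('unit:', q.unit)\n",
    "    print('x.grad:', x.grad)                 # 2 * x, as with torch.autograd.grad\n"
   ],
   "outputs": [],
   "execution_count": null
  },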
  {
   "cell_type": "markdown",
   "id": "b930a9a48d5f",
   "metadata": {},
   "source": [
    "## Requesting the backend explicitly\n",
    "\n",
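    "If PyTorch is not installed, requesting it with `u.using_backend('torch')`\n",
    "raises `BackendError`, just as with the other optional backends. The cell\n",
    "below exercises this path only when torch is missing.\n"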
   ]
  },
  {
   "cell_type": "code",
   "id": "e34efddeaf17",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:50.513363300Z",
     "start_time": "2026-05-22T03:09:50.504998100Z"
    }
   },
   "source": [
    "from brainunit import BackendError\n",
    "\n",
    "if not HAVE_TORCH:\n",
    "    try:\n",
    "        with u.using_backend('torch'):\n",
    "            u.Quantity([1.0], unit=u.meter)\n",
    "    except BackendError as exc:\n",
    "        print('expected without torch:', exc)\n"
   ],
   "outputs": [],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "id": "1e54e8c43a84",
   "metadata": {},
   "source": [
    "## Limitations\n",
    "\n",
    "- `brainunit.autograd`, `brainunit.lax`, and `brainunit.sparse` are JAX-only.\n",
    "- To use them on torch-backed data, convert first with `.to_jax()` or `.to_numpy()`.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
