{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "87c138b4779e",
   "metadata": {},
   "source": [
    "# NumPy backend\n",
    "\n",
    "The NumPy backend runs eagerly on the CPU and is always available, since\n",
    "`numpy` is a core dependency of brainunit. Choose it for interoperability\n",
    "with the wider scientific Python stack (SciPy, pandas, scikit-learn,\n",
    "Matplotlib) when JAX tracing would get in the way.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38aee184d417",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "NumPy is installed automatically as a dependency of brainunit, so nothing\n",
    "extra is needed.\n",
    "\n",
    "```bash\n",
    "pip install brainunit\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4ac9215f233",
   "metadata": {},
   "source": [
    "## Quick start\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "048fb5e120ae",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:09:58.976374100Z",
     "start_time": "2026-05-22T03:09:57.934789700Z"
    }
   },
   "source": [
    "import numpy as np\n",
    "import brainunit as u\n",
    "\n",
    "q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "print(q)\n",
    "print('backend =', q.backend)\n",
    "print('(q + q).backend =', (q + q).backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 2. 3.] m\n",
      "backend = numpy\n",
      "(q + q).backend = numpy\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "9677e4b3083b",
   "metadata": {},
   "source": [
    "## Math, linalg, FFT\n",
    "\n",
    "The `brainunit.math`, `brainunit.linalg`, and `brainunit.fft` subpackages\n",
    "all dispatch to NumPy when a quantity's mantissa is a NumPy array.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "45cf917abdd9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:00.869721200Z",
     "start_time": "2026-05-22T03:09:58.977371Z"
    }
   },
   "source": [
    "x = u.Quantity(np.linspace(0.0, np.pi, 5), unit=u.UNITLESS)\n",
    "u.math.sin(x)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.00000000e+00, 7.07106781e-01, 1.00000000e+00, 7.07106781e-01,\n",
       "       1.22464680e-16])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "id": "e72d61c35219",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:01.028421200Z",
     "start_time": "2026-05-22T03:10:00.901274100Z"
    }
   },
   "source": [
    "A = u.Quantity(np.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.meter)\n",
    "u.linalg.norm(A)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(5.477226, \"m\")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "id": "806b5f2b2522",
   "metadata": {},
   "source": [
    "## Setting NumPy as the default\n",
    "\n",
    "Use the `u.using_backend('numpy')` context manager to make NumPy the\n",
    "default for a block of code, so that even quantities built from Python\n",
    "lists get NumPy mantissas.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d80bacca49f9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:01.047471600Z",
     "start_time": "2026-05-22T03:10:01.029429Z"
    }
   },
   "source": [
    "with u.using_backend('numpy'):\n",
    "    q = u.Quantity([1.0, 2.0], unit=u.meter)\n",
    "    print(type(q.mantissa).__module__, q.backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy numpy\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "id": "6bd231849fd5",
   "metadata": {},
   "source": [
    "## NumPy ufunc interop\n",
    "\n",
    "Standard NumPy ufuncs such as `np.add` accept quantities, propagate their\n",
    "units, and raise `UnitMismatchError` when the dimensions do not match.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d6ec40228072",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:01.067200400Z",
     "start_time": "2026-05-22T03:10:01.048471200Z"
    }
   },
   "source": [
    "a = u.Quantity(np.array([1.0]), unit=u.meter)\n",
    "b = u.Quantity(np.array([2.0]), unit=u.meter)\n",
    "np.add(a, b)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([3.], \"m\")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "id": "6080d3ec085e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:01.078372300Z",
     "start_time": "2026-05-22T03:10:01.067200400Z"
    }
   },
   "source": [
    "from brainunit import UnitMismatchError\n",
    "\n",
    "c = u.Quantity(np.array([1.0]), unit=u.second)\n",
    "try:\n",
    "    np.add(a, c)\n",
    "except UnitMismatchError as exc:\n",
    "    print('expected:', exc)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expected: Cannot convert to a unit with different dimensions. (units are s and m).\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "id": "c9698d5e2af5",
   "metadata": {},
   "source": [
    "## JAX-only subpackages\n",
    "\n",
    "`brainunit.lax`, `brainunit.autograd`, and `brainunit.sparse` are built on\n",
    "JAX primitives. Calling them on a NumPy-backed quantity raises\n",
    "`BackendError`; convert the input with `q.to_jax()` first.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "9974cc05a064",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:01.148081300Z",
     "start_time": "2026-05-22T03:10:01.079364700Z"
    }
   },
   "source": [
    "from brainunit import BackendError\n",
    "\n",
    "q = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "try:\n",
    "    u.lax.slice(q, (0,), (1,))\n",
    "except BackendError as exc:\n",
    "    print('expected:', exc)\n",
    "\n",
    "u.lax.slice(q.to_jax(), (0,), (1,))\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "expected: brainunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Quantity([1.], \"m\")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "id": "4e430620c085",
   "metadata": {},
   "source": [
    "## See also\n",
    "\n",
    "- [Backends overview](overview.ipynb): the supported backends and how to select one.\n",
    "- [JAX backend](jax.ipynb): the default backend.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
