{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "108666acc6bd",
   "metadata": {},
   "source": [
    "# JAX backend (default)\n",
    "\n",
    "JAX is `brainunit`'s default array backend. Every install ships with it because\n",
    "`jax` is a required core dependency. The JAX backend is the only one that\n",
    "supports `brainunit.autograd`, `brainunit.lax`, and `brainunit.sparse`, and it is\n",
    "the only backend that participates in `jax.jit` / `jax.vmap` / `jax.pmap`.\n",
    "\n",
    "This notebook shows the JAX backend in isolation. For the multi-backend\n",
    "story see [overview](overview.ipynb).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "610d36f28fd2",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install brainunit          # core: pulls in jax + numpy\n",
    "pip install brainunit[cpu]     # core + jax[cpu] CPU wheels\n",
    "pip install brainunit[cuda12]  # core + jax[cuda12] for NVIDIA GPU\n",
    "pip install brainunit[tpu]     # core + jax[tpu] for Google TPU\n",
    "```\n"
   ]
  },
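  {
   "cell_type": "markdown",
   "id": "jax-platform-check",
   "metadata": {},
   "source": [
    "The extras pick which JAX build gets installed. To confirm which hardware JAX\n",
    "actually found, you can ask JAX directly. This minimal sketch uses plain JAX\n",
    "calls (`jax.default_backend()`, `jax.devices()`), not a `brainunit` API. Note\n",
    "that JAX's own term *backend* here means the hardware platform, which is a\n",
    "different thing from `brainunit`'s array backend.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "# Platform JAX dispatches to by default, e.g. 'cpu', 'gpu' or 'tpu'.\n",
    "platform = jax.default_backend()\n",
    "# All devices JAX can see on that default platform.\n",
    "devices = jax.devices()\n",
    "\n",
    "print('platform =', platform)\n",
    "print('devices  =', devices)\n",
    "```\n",
    "\n",
    "If a GPU or TPU extra was installed but the platform still reads `cpu`, JAX\n",
    "did not find a usable accelerator and fell back to CPU.\n"
   ]
  },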
  {
   "cell_type": "markdown",
   "id": "fdcda4bbb905",
   "metadata": {},
   "source": [
    "## Quick start\n",
    "\n",
    "Multiplying a JAX array by a unit produces a JAX-backed `Quantity`. The\n",
    "`.backend` property echoes that back.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "104428bc3404",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:17.577616500Z",
     "start_time": "2026-05-22T03:10:14.219425500Z"
    }
   },
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.meter)\n",
    "print(q)\n",
    "print('backend =', q.backend)\n",
    "print('(q + q).backend =', (q + q).backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 2. 3.] m\n",
      "backend = jax\n",
      "(q + q).backend = jax\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "id": "7d84fd37e014",
   "metadata": {},
   "source": [
    "## Automatic differentiation\n",
    "\n",
    "`brainunit.autograd.grad` is JAX-only and propagates units through the\n",
    "derivative. The derivative of `x ** 3` w.r.t. a length is an area.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "2cad00fb38aa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:17.730635200Z",
     "start_time": "2026-05-22T03:10:17.586615900Z"
    }
   },
   "source": [
    "f = lambda x: x ** 3\n",
    "x = 3.0 * u.meter\n",
    "u.autograd.grad(f)(x)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(27., \"m^2\")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
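  {
   "cell_type": "markdown",
   "id": "grad-units-sketch",
   "metadata": {},
   "source": [
    "The unit bookkeeping behind this result can be pictured in two steps:\n",
    "differentiate the unitless mantissa with plain `jax.grad`, then give the\n",
    "derivative the unit of the output divided by the unit of the input. The sketch\n",
    "below assumes a toy unit representation (a dict of base-unit exponents) and\n",
    "hypothetical helpers `unit_pow`, `unit_div` and `grad_with_units`. It is not\n",
    "how `brainunit.autograd` is implemented.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "# Toy unit: base-unit name -> integer exponent, e.g. {'m': 2} is m^2.\n",
    "def unit_pow(unit, n):\n",
    "    return {k: e * n for k, e in unit.items()}\n",
    "\n",
    "def unit_div(a, b):\n",
    "    out = dict(a)\n",
    "    for k, e in b.items():\n",
    "        out[k] = out.get(k, 0) - e\n",
    "    return {k: e for k, e in out.items() if e != 0}  # drop cancelled units\n",
    "\n",
    "def grad_with_units(f, out_unit_rule):\n",
    "    # out_unit_rule maps the input unit to the unit of f(x).\n",
    "    def wrapped(x_mantissa, x_unit):\n",
    "        d = jax.grad(f)(x_mantissa)                       # unitless derivative\n",
    "        d_unit = unit_div(out_unit_rule(x_unit), x_unit)  # unit(f) / unit(x)\n",
    "        return d, d_unit\n",
    "    return wrapped\n",
    "\n",
    "f = lambda x: x ** 3\n",
    "dfdx = grad_with_units(f, lambda unit: unit_pow(unit, 3))\n",
    "value, unit = dfdx(3.0, {'m': 1})\n",
    "print(value, unit)  # 3 * 3.0 ** 2 = 27.0, with unit {'m': 2}\n",
    "```\n"
   ]
  },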
  {
   "cell_type": "markdown",
   "id": "704b6b9c0648",
   "metadata": {},
   "source": [
    "## JIT and vmap\n",
    "\n",
    "Quantities are registered as JAX pytrees, so they flow through `jit` and\n",
    "`vmap` without manual unpacking. The compiled function sees the mantissa\n",
    "as the leaf and the unit as static metadata.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "d80524f3a941",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:17.799198300Z",
     "start_time": "2026-05-22T03:10:17.739858200Z"
    }
   },
   "source": [
    "@jax.jit\n",
    "def kinetic_energy(m, v):\n",
    "    return 0.5 * m * v ** 2\n",
    "\n",
    "m = 1.5 * u.kgram\n",
    "v = 4.0 * u.meter / u.second\n",
    "kinetic_energy(m, v)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(12., \"J\")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "id": "33cb4252fd22",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:17.958945300Z",
     "start_time": "2026-05-22T03:10:17.800196800Z"
    }
   },
   "source": [
    "speeds = u.math.arange(0.0 * u.meter / u.second,\n",
    "                       5.0 * u.meter / u.second,\n",
    "                       1.0 * u.meter / u.second)\n",
    "jax.vmap(lambda v: kinetic_energy(1.5 * u.kgram, v))(speeds)\n"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([ 0.    0.75  3.    6.75 12.  ], \"J\")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
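  {
   "cell_type": "markdown",
   "id": "pytree-static-unit-sketch",
   "metadata": {},
   "source": [
    "To see what *mantissa as the leaf, unit as static metadata* means in practice,\n",
    "here is a self-contained sketch with a hypothetical `ToyQuantity` class\n",
    "registered via `jax.tree_util.register_pytree_node`. It is not `brainunit`'s\n",
    "implementation, only the general pytree mechanism. Because the unit lives in\n",
    "the pytree's auxiliary data, `jit` treats it as part of the cache key. A call\n",
    "with the same unit reuses the compiled function, while a new unit retraces.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import tree_util\n",
    "\n",
    "\n",
    "class ToyQuantity:\n",
    "    # Hypothetical stand-in for a unit-carrying array (not brainunit.Quantity).\n",
    "    def __init__(self, mantissa, unit):\n",
    "        self.mantissa = mantissa  # array data -> pytree leaf (traced)\n",
    "        self.unit = unit          # plain str -> aux data (static)\n",
    "\n",
    "\n",
    "def _flatten(q):\n",
    "    return (q.mantissa,), q.unit  # (children, aux_data)\n",
    "\n",
    "\n",
    "def _unflatten(unit, children):\n",
    "    return ToyQuantity(children[0], unit)\n",
    "\n",
    "\n",
    "tree_util.register_pytree_node(ToyQuantity, _flatten, _unflatten)\n",
    "\n",
    "n_traces = 0\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def double(q):\n",
    "    global n_traces\n",
    "    n_traces += 1  # Python side effect: runs only while tracing\n",
    "    return ToyQuantity(2.0 * q.mantissa, q.unit)\n",
    "\n",
    "\n",
    "q = ToyQuantity(jnp.array([1.0, 2.0, 3.0]), 'm')\n",
    "leaves, treedef = tree_util.tree_flatten(q)\n",
    "print(leaves)  # a single leaf: the mantissa array\n",
    "\n",
    "double(q)                                             # first call: trace + compile\n",
    "double(ToyQuantity(jnp.array([4.0, 5.0, 6.0]), 'm'))  # same unit and shape: cached\n",
    "double(ToyQuantity(jnp.array([4.0, 5.0, 6.0]), 's'))  # new unit: retrace\n",
    "traces_before_vmap = n_traces\n",
    "print('traces:', traces_before_vmap)\n",
    "\n",
    "# vmap batches the leaf along axis 0; the unit rides along untouched.\n",
    "batched = jax.vmap(double)(ToyQuantity(jnp.ones((4, 3)), 'm'))\n",
    "print(batched.mantissa.shape, batched.unit)\n",
    "```\n",
    "\n",
    "The same reasoning applies to any pytree whose auxiliary data is a unit.\n",
    "Changing the unit changes the tree structure, so JIT-compiled code is\n",
    "specialized per unit as well as per shape and dtype.\n"
   ]
  },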
  {
   "cell_type": "markdown",
   "id": "efe71af273b7",
   "metadata": {},
   "source": [
    "## Setting JAX as the default explicitly\n",
    "\n",
    "`jax` is already the fallback default, but you can be explicit. Useful when\n",
    "you build a `Quantity` from a Python list (no array yet) and want to pin the\n",
    "backend.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "12e4e4f42c98",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:17.987864900Z",
     "start_time": "2026-05-22T03:10:17.967347400Z"
    }
   },
   "source": [
    "with u.using_backend('jax'):\n",
    "    q = u.Quantity([1.0, 2.0], unit=u.meter)\n",
    "    print(type(q.mantissa).__module__, q.backend)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jaxlib._jax jax\n"
     ]
    }
   ],
   "execution_count": 5
  },
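  {
   "cell_type": "markdown",
   "id": "scoped-default-sketch",
   "metadata": {},
   "source": [
    "The scoped-default idea itself fits in a few lines of plain Python. The names\n",
    "below (`using_backend_sketch`, `to_array`, `_CONSTRUCTORS`) are hypothetical.\n",
    "They only illustrate a context manager that decides how backend-less input such\n",
    "as a Python list is converted, and are not `brainunit`'s implementation.\n",
    "\n",
    "```python\n",
    "import contextlib\n",
    "import contextvars\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical registry: backend name -> array constructor.\n",
    "_CONSTRUCTORS = {'jax': jnp.asarray, 'numpy': np.asarray}\n",
    "_default_backend = contextvars.ContextVar('default_backend', default='jax')\n",
    "\n",
    "\n",
    "@contextlib.contextmanager\n",
    "def using_backend_sketch(name):\n",
    "    token = _default_backend.set(name)\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        _default_backend.reset(token)  # restore the previous default on exit\n",
    "\n",
    "\n",
    "def to_array(data):\n",
    "    # Convert input that has no backend yet using the current default.\n",
    "    return _CONSTRUCTORS[_default_backend.get()](data)\n",
    "\n",
    "\n",
    "a = to_array([1.0, 2.0])          # default: jax\n",
    "with using_backend_sketch('numpy'):\n",
    "    b = to_array([1.0, 2.0])      # inside the block: numpy\n",
    "c = to_array([1.0, 2.0])          # after the block: jax again\n",
    "print(isinstance(a, jax.Array), isinstance(b, np.ndarray), isinstance(c, jax.Array))\n",
    "```\n",
    "\n",
    "Using a `contextvars.ContextVar` rather than a module-level global keeps the\n",
    "default isolated per thread and per asyncio task.\n"
   ]
  },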
  {
   "cell_type": "markdown",
   "id": "baaf9c106c11",
   "metadata": {},
   "source": [
    "## Mixed backends\n",
    "\n",
    "Mixing a JAX-backed quantity with one from another backend falls through the\n",
    "default-backend tiebreaker. By default the result lands on JAX.\n"
   ]
  },
  {
   "cell_type": "code",
   "id": "0839ffec4ee5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-05-22T03:10:18.036753100Z",
     "start_time": "2026-05-22T03:10:17.988863800Z"
    }
   },
   "source": [
    "import numpy as np\n",
    "\n",
    "q_np  = u.Quantity(np.array([1.0]), unit=u.meter)\n",
    "q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)\n",
    "print((q_np + q_jax).backend)         # 'jax' (default tiebreaker)\n",
    "\n",
    "with u.using_backend('numpy'):\n",
    "    print((q_np + q_jax).backend)     # 'numpy'\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jax\n",
      "jax\n"
     ]
    }
   ],
   "execution_count": 6
  },
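  {
   "cell_type": "markdown",
   "id": "numpy-jax-deferral-sketch",
   "metadata": {},
   "source": [
    "For comparison, plain arrays without units also end up on JAX when mixed.\n",
    "`jax.Array` sets a higher `__array_priority__` than `numpy.ndarray`, so\n",
    "NumPy's binary operators defer to JAX's reflected ones and the result is a\n",
    "`jax.Array` whichever operand comes first. This sketch uses only NumPy and\n",
    "JAX. Whether `brainunit`'s tiebreaker relies on this behavior is an assumption\n",
    "not confirmed here.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "a_np = np.array([1.0])\n",
    "a_jax = jnp.array([2.0])\n",
    "\n",
    "left = a_np + a_jax   # ndarray.__add__ defers, so jax.Array.__radd__ runs\n",
    "right = a_jax + a_np  # jax.Array.__add__ handles it directly\n",
    "\n",
    "print(type(left).__module__, type(right).__module__)\n",
    "print(isinstance(left, jax.Array), isinstance(right, jax.Array))\n",
    "```\n"
   ]
  },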
  {
   "cell_type": "markdown",
   "id": "64606d75945b",
   "metadata": {},
   "source": [
    "## See also\n",
    "\n",
    "- [Backends overview](overview.ipynb) — selection rules and capabilities.\n",
    "- `brainunit.autograd`, `brainunit.lax`, `brainunit.sparse` — JAX-only subpackages.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
