{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# JAX Transforms: JIT and vmap with Units\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/jax_integration/jax_transforms.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/jax_integration/jax_transforms.ipynb)\n",
    "\n",
    "brainunit `Quantity` objects can be passed to and returned from functions under JAX's core transformations, with units preserved:\n",
    "\n",
    "- **`jax.jit`** — Just-in-time compilation for faster execution\n",
    "- **`jax.vmap`** — Automatic vectorization (batching)\n",
    "- **Composing transforms** — Combine jit, vmap, and grad"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.535232300Z",
     "start_time": "2026-03-04T15:10:20.644890800Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `jax.jit` — JIT Compilation\n",
    "\n",
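    "`jax.jit` traces a function the first time it sees a given combination of input shapes and dtypes, compiles it with XLA, and reuses the compiled code on later calls.\n",
    "`Quantity` arguments and return values are supported: units are checked and propagated while the function is traced, and the compiled code operates on the underlying numeric values."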
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.722242500Z",
     "start_time": "2026-03-04T15:10:21.535232300Z"
    }
   },
   "source": [
    "# A physics computation\n",
    "def kinetic_energy(m, v):\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "# JIT-compiled version\n",
    "jit_ke = jax.jit(kinetic_energy)\n",
    "\n",
    "m = 5.0 * u.kilogram\n",
    "v = 10.0 * u.meter / u.second\n",
    "\n",
    "result = jit_ke(m, v)\n",
    "print('KE:', result)  # 250 J"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KE: 250. J\n"
     ]
    }
   ],
   "execution_count": 2
  },
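  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under `jax.jit`, the Python body runs only while JAX traces the function; later calls with the same input structure reuse the compiled code.\n",
    "The sketch below counts traces with a Python side effect. It assumes that brainunit stores a `Quantity`'s unit as static pytree metadata, so an argument in grams instead of kilograms is expected to trigger a retrace, while a new value with the same unit is not.\n",
    "The names `trace_log` and `traced_ke` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "\n",
    "trace_log = []  # illustrative: one entry is appended per trace\n",
    "\n",
    "@jax.jit\n",
    "def traced_ke(m, v):\n",
    "    trace_log.append('traced')  # Python side effect: runs while tracing, not on cached calls\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "v = 10.0 * u.meter / u.second\n",
    "traced_ke(5.0 * u.kilogram, v)   # first call: trace and compile\n",
    "traced_ke(6.0 * u.kilogram, v)   # same unit, shape and dtype: reuse the compiled code\n",
    "traced_ke(5000.0 * u.gram, v)    # different unit (assumed static metadata): retrace\n",
    "print('Number of traces:', len(trace_log))"
   ],
   "outputs": [],
   "execution_count": null
  },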
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.771050300Z",
     "start_time": "2026-03-04T15:10:21.723243400Z"
    }
   },
   "source": [
    "# Decorator syntax also works\n",
    "@jax.jit\n",
    "def coulomb_force(q1, q2, r):\n",
    "    k = 8.9875e9 * u.newton * u.meter**2 / u.coulomb**2\n",
    "    return k * q1 * q2 / r**2\n",
    "\n",
    "q1 = 1.6e-19 * u.coulomb  # elementary charge e (magnitude of the electron's charge)\n",
    "q2 = 1.6e-19 * u.coulomb\n",
    "r = 1e-10 * u.meter       # 1 angstrom\n",
    "\n",
    "print('Coulomb force:', coulomb_force(q1, q2, r))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coulomb force: 2.3007996e-08 N\n"
     ]
    }
   ],
   "execution_count": 3
  },
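  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because units are checked while the function is traced, a dimensionally inconsistent expression inside a jitted function fails on the first call, before any compilation happens.\n",
    "A minimal sketch, assuming the mismatch is reported as brainunit's `u.DimensionMismatchError`; the function name `add_length_and_time` is illustrative."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "\n",
    "@jax.jit\n",
    "def add_length_and_time(x, t):\n",
    "    return x + t  # metres + seconds: dimensionally inconsistent\n",
    "\n",
    "caught = False\n",
    "try:\n",
    "    add_length_and_time(1.0 * u.meter, 2.0 * u.second)\n",
    "except u.DimensionMismatchError as err:\n",
    "    # Raised during tracing, so no XLA compilation takes place\n",
    "    caught = True\n",
    "    print('Caught during tracing:', type(err).__name__)\n",
    "print('Error caught:', caught)"
   ],
   "outputs": [],
   "execution_count": null
  },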
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JIT with array computations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.865756200Z",
     "start_time": "2026-03-04T15:10:21.772061Z"
    }
   },
   "source": [
    "@jax.jit\n",
    "def rms_voltage(v_samples):\n",
    "    \"\"\"Root-mean-square voltage.\"\"\"\n",
    "    return u.math.sqrt(u.math.mean(v_samples**2))\n",
    "\n",
    "samples = jnp.array([1.0, -2.0, 3.0, -1.5, 2.5]) * u.volt\n",
    "print('RMS voltage:', rms_voltage(samples))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMS voltage: 2.1213202 V\n"
     ]
    }
   ],
   "execution_count": 4
  },
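  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Derived units propagate inside compiled code in the same way. As an illustrative extension of the RMS example, the sketch below computes the mean power dissipated in a resistor, `P = mean(V**2) / R`.\n",
    "The result has the dimension of power, so it can be read out in watts with `to_decimal(u.watt)`. The 10 Ω load and the helper name `mean_power` are illustrative choices, not part of the original example."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def mean_power(v_samples, R):\n",
    "    # P = mean(V^2) / R; dimensionally V^2 / ohm = W\n",
    "    return u.math.mean(v_samples**2) / R\n",
    "\n",
    "samples = jnp.array([1.0, -2.0, 3.0, -1.5, 2.5]) * u.volt\n",
    "P = mean_power(samples, 10.0 * u.ohm)\n",
    "print('Mean power:', P)\n",
    "print('In watts:', P.to_decimal(u.watt))"
   ],
   "outputs": [],
   "execution_count": null
  },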
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `jax.vmap` — Automatic Vectorization\n",
    "\n",
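    "`vmap` turns a function written for a single example into one that maps over a batch axis (the leading axis by default), without an explicit Python loop.\n",
    "A batched `Quantity` carries one unit for the whole array, so every element of the batch shares that unit."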
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.032540200Z",
     "start_time": "2026-03-04T15:10:21.866765Z"
    }
   },
   "source": [
    "# Compute kinetic energy for a batch of velocities\n",
    "m = 2.0 * u.kilogram\n",
    "velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second\n",
    "\n",
    "# Without vmap: would need a loop\n",
    "# With vmap: automatic batching\n",
    "batch_ke = jax.vmap(lambda v: 0.5 * m * v**2)\n",
    "energies = batch_ke(velocities)\n",
    "print('Batch KE:', energies)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch KE: [ 1.  4.  9. 16. 25.] J\n"
     ]
    }
   ],
   "execution_count": 5
  },
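  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is the explicit loop that `vmap` replaces, using the same mass and velocities as the cell above (re-declared so the cell runs on its own).\n",
    "Both give the same energies; `vmap` traces the per-element function once and runs one batched computation instead of one call per element."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "m = 2.0 * u.kilogram\n",
    "velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second\n",
    "\n",
    "def kinetic_energy(v):\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "# Explicit Python loop: one call per element, values collected in joules\n",
    "loop_values = jnp.array([kinetic_energy(velocities[i]).to_decimal(u.joule)\n",
    "                         for i in range(velocities.shape[0])])\n",
    "\n",
    "# vmap: a single call on the whole batch\n",
    "vmap_values = jax.vmap(kinetic_energy)(velocities).to_decimal(u.joule)\n",
    "\n",
    "print('Loop:', loop_values)\n",
    "print('vmap:', vmap_values)\n",
    "print('Match:', bool(jnp.allclose(loop_values, vmap_values)))"
   ],
   "outputs": [],
   "execution_count": null
  },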
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.441757400Z",
     "start_time": "2026-03-04T15:10:22.047798900Z"
    }
   },
   "source": [
    "# vmap over vector norms\n",
    "def vector_norm(v):\n",
    "    return u.math.sqrt(u.math.sum(v**2))\n",
    "\n",
    "# Batch of 3D vectors\n",
    "vectors = jnp.array([\n",
    "    [1., 0., 0.],\n",
    "    [0., 1., 0.],\n",
    "    [3., 4., 0.],\n",
    "    [1., 1., 1.]\n",
    "]) * u.meter\n",
    "\n",
    "norms = jax.vmap(vector_norm)(vectors)\n",
    "print('Norms:', norms)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Norms: [1.        1.        5.        1.73205078] m\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### vmap with multiple arguments"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.551738500Z",
     "start_time": "2026-03-04T15:10:22.489924800Z"
    }
   },
   "source": [
    "# Ohm's law for multiple resistors: V = I * R\n",
    "def ohm_law(I, R):\n",
    "    return I * R\n",
    "\n",
    "currents = jnp.array([0.1, 0.2, 0.5, 1.0]) * u.ampere\n",
    "resistances = jnp.array([100., 50., 20., 10.]) * u.ohm\n",
    "\n",
    "# vmap over both arguments\n",
    "voltages = jax.vmap(ohm_law)(currents, resistances)\n",
    "print('Voltages:', voltages)  # all 10V"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Voltages: [10. 10. 10. 10.] V\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.592317300Z",
     "start_time": "2026-03-04T15:10:22.551738500Z"
    }
   },
   "source": [
    "# vmap with in_axes: batch over currents only, same resistance for all\n",
    "R_fixed = 100.0 * u.ohm\n",
    "voltages_fixed_R = jax.vmap(ohm_law, in_axes=(0, None))(currents, R_fixed)\n",
    "print('Voltages (fixed R):', voltages_fixed_R)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Voltages (fixed R): [ 10.  20.  50. 100.] V\n"
     ]
    }
   ],
   "execution_count": 8
  },
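  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same mechanism works in the other direction: `in_axes=(None, 0)` keeps the current fixed and batches over the resistances.\n",
    "The 0.5 A current and the name `voltages_fixed_I` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def ohm_law(I, R):\n",
    "    return I * R\n",
    "\n",
    "I_fixed = 0.5 * u.ampere\n",
    "resistances = jnp.array([100., 50., 20., 10.]) * u.ohm\n",
    "\n",
    "# None: I is passed unchanged to every call; 0: R is mapped over its leading axis\n",
    "voltages_fixed_I = jax.vmap(ohm_law, in_axes=(None, 0))(I_fixed, resistances)\n",
    "print('Voltages (fixed I):', voltages_fixed_I)"
   ],
   "outputs": [],
   "execution_count": null
  },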
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### vmap for matrix operations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.800705800Z",
     "start_time": "2026-03-04T15:10:22.593526100Z"
    }
   },
   "source": [
    "# Rotate a batch of 2D vectors by 90 degrees counterclockwise\n",
    "rotation_90 = jnp.array([[0., -1.], [1., 0.]])  # dimensionless rotation matrix\n",
    "\n",
    "def rotate(v):\n",
    "    return rotation_90 @ v\n",
    "\n",
    "points = jnp.array([[1., 0.], [0., 1.], [1., 1.], [2., 3.]]) * u.meter\n",
    "\n",
    "rotated = jax.vmap(rotate)(points)\n",
    "print('Original points:')\n",
    "print(points)\n",
    "print('Rotated 90 degrees:')\n",
    "print(rotated)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original points:\n",
      "[[1. 0.]\n",
      " [0. 1.]\n",
      " [1. 1.]\n",
      " [2. 3.]] m\n",
      "Rotated 90 degrees:\n",
      "[[ 0.  1.]\n",
      " [-1.  0.]\n",
      " [-1.  1.]\n",
      " [-3.  2.]] m\n"
     ]
    }
   ],
   "execution_count": 9
  },
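  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the rotation is a fixed linear map, the same batch can also be handled with one matrix product: stacking the points as columns gives R [p_1 ... p_n] = [R p_1 ... R p_n].\n",
    "The check below compares the two approaches; it assumes `Quantity` supports `.T` like a JAX array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "rotation_90 = jnp.array([[0., -1.], [1., 0.]])  # dimensionless\n",
    "points = jnp.array([[1., 0.], [0., 1.], [1., 1.], [2., 3.]]) * u.meter\n",
    "\n",
    "# vmap: apply the rotation to one 2-vector at a time\n",
    "rotated_vmap = jax.vmap(lambda v: rotation_90 @ v)(points)\n",
    "\n",
    "# One matrix product on the points as columns, then transpose back to rows\n",
    "rotated_matmul = (rotation_90 @ points.T).T\n",
    "\n",
    "same = jnp.allclose(rotated_vmap.to_decimal(u.meter), rotated_matmul.to_decimal(u.meter))\n",
    "print('Same result:', bool(same))"
   ],
   "outputs": [],
   "execution_count": null
  },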
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Composing Transforms\n",
    "\n",
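    "JAX transforms compose: each returns an ordinary function that another transform can wrap. For gradients, the examples below use `u.autograd.grad`, brainunit's unit-aware counterpart of `jax.grad`, so each gradient carries units (here J/m = N)."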
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.883579900Z",
     "start_time": "2026-03-04T15:10:22.800705800Z"
    }
   },
   "source": [
    "# jit + vmap: fast batched computation\n",
    "fast_batch_ke = jax.jit(jax.vmap(lambda v: 0.5 * (2.0 * u.kilogram) * v**2))\n",
    "\n",
    "vs = jnp.linspace(0., 10., 5) * u.meter / u.second\n",
    "print('Fast batch KE:', fast_batch_ke(vs))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fast batch KE: [  0.     6.25  25.    56.25 100.  ] J\n"
     ]
    }
   ],
   "execution_count": 10
  },
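  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above closes over a fixed 2 kg mass. Passing the mass as an argument with `in_axes=(None, 0)` instead lets one compiled, batched function serve any mass with the same unit.\n",
    "A minimal sketch; the name `batched_ke` and the 3 kg value are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def kinetic_energy(m, v):\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "# Mass shared across the batch (None), velocities batched (0); jit compiles the result\n",
    "batched_ke = jax.jit(jax.vmap(kinetic_energy, in_axes=(None, 0)))\n",
    "\n",
    "vs = jnp.linspace(0., 10., 5) * u.meter / u.second\n",
    "ke_2kg = batched_ke(2.0 * u.kilogram, vs)\n",
    "ke_3kg = batched_ke(3.0 * u.kilogram, vs)  # new value, same unit and shapes: no recompilation\n",
    "print('KE (2 kg):', ke_2kg)\n",
    "print('KE (3 kg):', ke_3kg)"
   ],
   "outputs": [],
   "execution_count": null
  },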
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:23.050267Z",
     "start_time": "2026-03-04T15:10:22.884918800Z"
    }
   },
   "source": [
    "# vmap + grad: batch of gradients\n",
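    "def spring_force(x):\n",
    "    # Despite the name, this returns -U(x) = -(1/2) k x**2, the negative of the\n",
    "    # spring's potential energy; its derivative w.r.t. x is the force F = -k x.\n",
    "    k = 100.0 * u.newton / u.meter\n",
    "    return -0.5 * k * x**2\n",
    "\n",
    "# Gradient of -U(x) at each position gives the restoring force there\n",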
    "batch_grad = jax.vmap(u.autograd.grad(spring_force))\n",
    "\n",
    "positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]) * u.meter\n",
    "forces = batch_grad(positions)\n",
    "print('Positions:', positions)\n",
    "print('Forces:', forces)  # F = -kx"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positions: [0.  0.1 0.2 0.30000001 0.40000001] m\n",
      "Forces: [ -0.       -10.       -20.       -30.00000191 -40.      ] N\n"
     ]
    }
   ],
   "execution_count": 11
  },
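  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the batched gradients should agree with Hooke's law, F = -k x, computed directly from the positions.\n",
    "A minimal sketch; `neg_spring_energy` is an illustrative name for the same function as above (the negative potential energy)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "k = 100.0 * u.newton / u.meter  # spring constant\n",
    "\n",
    "def neg_spring_energy(x):\n",
    "    # -U(x) = -(1/2) k x^2, whose derivative is the restoring force -k x\n",
    "    return -0.5 * k * x**2\n",
    "\n",
    "positions = jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]) * u.meter\n",
    "\n",
    "# Forces from automatic differentiation, one per position\n",
    "autodiff_forces = jax.vmap(u.autograd.grad(neg_spring_energy))(positions)\n",
    "\n",
    "# Forces from Hooke's law, computed directly\n",
    "hooke_forces = -k * positions\n",
    "\n",
    "agree = jnp.allclose(autodiff_forces.to_decimal(u.newton),\n",
    "                     hooke_forces.to_decimal(u.newton))\n",
    "print('Matches F = -k x:', bool(agree))"
   ],
   "outputs": [],
   "execution_count": null
  },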
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:23.097347200Z",
     "start_time": "2026-03-04T15:10:23.051283500Z"
    }
   },
   "source": [
    "# jit + vmap + grad: a compiled, batched gradient function\n",
    "fast_batch_grad = jax.jit(jax.vmap(u.autograd.grad(spring_force)))\n",
    "print('Fast batch forces:', fast_batch_grad(positions))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fast batch forces: [ -0.       -10.       -20.       -30.00000191 -40.      ] N\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Transform | Purpose | Example |\n",
    "|-----------|---------|--------|\n",
    "| `jax.jit(f)` | Compile for speed | `jax.jit(f)(5.0 * u.meter)` |\n",
    "| `jax.vmap(f)` | Automatic batching | `jax.vmap(f)(batch_of_quantities)` |\n",
    "| `jax.vmap(f, in_axes=(0, None))` | Batch over some arguments only | `None` marks an unbatched argument |\n",
    "| `jax.jit(jax.vmap(u.autograd.grad(f)))` | Compose transforms | Compiled, batched gradients with units |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
