{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatic Differentiation with Units\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/jax_integration/autograd.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/jax_integration/autograd.ipynb)\n",
    "\n",
    "`brainunit.autograd` provides unit-aware automatic differentiation. When you differentiate a \n",
    "function that operates on physical quantities, the derivatives automatically carry the correct \n",
    "units.\n",
    "\n",
    "Available functions:\n",
    "- `grad` — Gradient (first derivative)\n",
    "- `value_and_grad` — Compute value and gradient simultaneously\n",
    "- `jacobian` / `jacrev` / `jacfwd` — Jacobian matrix\n",
    "- `hessian` — Hessian matrix (second derivatives)\n",
    "- `vector_grad` — Element-wise gradient for vector-valued functions"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:19.643370700Z",
     "start_time": "2026-03-04T15:10:18.831701100Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `grad` — First Derivative\n",
    "\n",
    "The gradient of a function `f(x)` with respect to `x` has unit `unit(f) / unit(x)`.\n",
    "\n",
    "For example, if `f(v) = 0.5 * m * v^2` (kinetic energy in Joules) and `v` is in m/s,\n",
    "then `df/dv` has unit `J / (m/s) = kg * m/s` (momentum)."
   ]
  },
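  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the unit rule above, the derivative can be worked out by hand. This is only a dimensional sketch using the symbols from the text:\n",
    "\n",
    "$$\\frac{d}{dv}\\left(\\tfrac{1}{2} m v^2\\right) = m v$$\n",
    "\n",
    "$$\\frac{\\mathrm{J}}{\\mathrm{m/s}} = \\frac{\\mathrm{kg\\,m^2/s^2}}{\\mathrm{m/s}} = \\mathrm{kg\\,m/s}$$\n",
    "\n",
    "With a mass of 2 kg and a velocity of 3 m/s, this gives $m v = 6\\,\\mathrm{kg\\,m/s}$, which is the value the cells below should reproduce."
   ]
  },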
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:19.876325900Z",
     "start_time": "2026-03-04T15:10:19.663112100Z"
    }
   },
   "source": [
    "# Kinetic energy: KE = 0.5 * m * v^2\n",
    "mass = 2.0 * u.kilogram\n",
    "\n",
    "def kinetic_energy(v):\n",
    "    return 0.5 * mass * v**2\n",
    "\n",
    "v = 3.0 * u.meter / u.second\n",
    "print('KE:', kinetic_energy(v))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KE: 9. J\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.101189Z",
     "start_time": "2026-03-04T15:10:19.891104100Z"
    }
   },
   "source": [
    "# dKE/dv = m * v  (momentum)\n",
    "dKE_dv = u.autograd.grad(kinetic_energy)\n",
    "print('dKE/dv:', dKE_dv(v))  # 2 * 3 = 6 kg*m/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dKE/dv: 6. kg * m / s\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.155141900Z",
     "start_time": "2026-03-04T15:10:20.105260800Z"
    }
   },
   "source": [
    "# Gravitational potential energy: PE = m * g * h\n",
    "g = 9.81 * u.meter / u.second**2\n",
    "\n",
    "def potential_energy(h):\n",
    "    return mass * g * h\n",
    "\n",
    "# dPE/dh = m * g (force, in Newtons)\n",
    "dPE_dh = u.autograd.grad(potential_energy)\n",
    "print('dPE/dh:', dPE_dh(10.0 * u.meter))  # 2 * 9.81 = 19.62 N"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dPE/dh: 19.62 N\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Differentiating with respect to specific arguments\n",
    "\n",
    "Use `argnums` to select which argument to differentiate with respect to."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.323793400Z",
     "start_time": "2026-03-04T15:10:20.157128700Z"
    }
   },
   "source": [
    "def power_dissipated(V, R):\n",
    "    \"\"\"P = V^2 / R\"\"\"\n",
    "    return V**2 / R\n",
    "\n",
    "V = 12.0 * u.volt\n",
    "R = 4.0 * u.ohm\n",
    "\n",
    "# dP/dV = 2V/R\n",
    "dP_dV = u.autograd.grad(power_dissipated, argnums=0)\n",
    "print('dP/dV:', dP_dV(V, R))  # 2*12/4 = 6 V/ohm = 6 A (equivalently W/V)\n",
    "\n",
    "# dP/dR = -V^2/R^2\n",
    "dP_dR = u.autograd.grad(power_dissipated, argnums=1)\n",
    "print('dP/dR:', dP_dR(V, R))  # -144/16 = -9 V^2/ohm^2 = -9 A^2 (equivalently W/ohm)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dP/dV: 6. A\n",
      "dP/dR: -9. A^2\n"
     ]
    }
   ],
   "execution_count": 5
  },
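  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed units can be cross-checked by hand. The following is a dimensional sketch that relies only on Ohm's law, $\\mathrm{V} = \\mathrm{A}\\,\\Omega$:\n",
    "\n",
    "$$\\frac{\\partial P}{\\partial V} = \\frac{2V}{R}, \\qquad \\frac{\\mathrm{V}}{\\Omega} = \\mathrm{A}$$\n",
    "\n",
    "$$\\frac{\\partial P}{\\partial R} = -\\frac{V^2}{R^2}, \\qquad \\frac{\\mathrm{V^2}}{\\Omega^2} = \\mathrm{A^2}$$\n",
    "\n",
    "Since $\\mathrm{W} = \\mathrm{A^2}\\,\\Omega$, the unit `W/ohm` is the same as `A^2`, which matches the output above."
   ]
  },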
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `value_and_grad` — Compute Both at Once\n",
    "\n",
    "When you need both the function value and its gradient, `value_and_grad` returns both from a single call,\n",
    "so the function does not have to be evaluated a second time just to obtain its value."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.358480500Z",
     "start_time": "2026-03-04T15:10:20.341149200Z"
    }
   },
   "source": [
    "val_grad_fn = u.autograd.value_and_grad(kinetic_energy)\n",
    "\n",
    "value, gradient = val_grad_fn(3.0 * u.meter / u.second)\n",
    "print('Value (KE):', value)     # 9.0 J\n",
    "print('Gradient (momentum):', gradient)  # 6.0 kg*m/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value (KE): 9. J\n",
      "Gradient (momentum): 6. kg * m / s\n"
     ]
    }
   ],
   "execution_count": 6
  },
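  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### With auxiliary outputs\n",
    "\n",
    "Use `has_aux=True` when the function returns a pair `(output, aux)`. Only the first element is differentiated;\n",
    "the auxiliary value is passed through unchanged, and the call returns `((output, aux), gradient)`."
   ]
  },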
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.383548400Z",
     "start_time": "2026-03-04T15:10:20.359649Z"
    }
   },
   "source": [
    "def energy_with_info(v):\n",
    "    ke = 0.5 * mass * v**2\n",
    "    momentum = mass * v\n",
    "    return ke, momentum  # (scalar to diff, auxiliary)\n",
    "\n",
    "val_grad_aux = u.autograd.value_and_grad(energy_with_info, has_aux=True)\n",
    "(ke, momentum), grad = val_grad_aux(3.0 * u.meter / u.second)\n",
    "\n",
    "print('KE:', ke)\n",
    "print('Momentum (aux):', momentum)\n",
    "print('Gradient:', grad)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KE: 9. J\n",
      "Momentum (aux): 6. kg * m / s\n",
      "Gradient: 6. kg * m / s\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `jacobian` — Jacobian Matrix\n",
    "\n",
    "For a function `f: R^n -> R^m`, the Jacobian is an `m x n` matrix where \n",
    "`J[i,j] = df_i/dx_j`. The unit of each entry is `unit(f_i) / unit(x_j)`."
   ]
  },
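  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate how the entries of one Jacobian can carry different units, consider a hypothetical function that is not used elsewhere in this notebook: $f(t, a) = \\left(\\tfrac{1}{2} a t^2,\\; a t\\right)$, with $t$ in seconds and $a$ in $\\mathrm{m/s^2}$, so the outputs are in $\\mathrm{m}$ and $\\mathrm{m/s}$. Worked by hand:\n",
    "\n",
    "$$J = \\begin{pmatrix} a t & \\tfrac{1}{2} t^2 \\\\ a & t \\end{pmatrix}, \\qquad \\text{units: } \\begin{pmatrix} \\mathrm{m/s} & \\mathrm{s^2} \\\\ \\mathrm{m/s^2} & \\mathrm{s} \\end{pmatrix}$$\n",
    "\n",
    "Each entry follows `unit(f_i) / unit(x_j)`; for example, the top-right entry is $\\mathrm{m} / (\\mathrm{m/s^2}) = \\mathrm{s^2}$. This is a pencil-and-paper illustration of the rule only; how the library returns such a mixed-unit Jacobian is not shown here."
   ]
  },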
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.688629300Z",
     "start_time": "2026-03-04T15:10:20.386006200Z"
    }
   },
   "source": [
    "# Simple scalar-to-scalar: Jacobian reduces to a scalar (the derivative)\n",
    "def f(x):\n",
    "    return x**3\n",
    "\n",
    "J = u.autograd.jacobian(f)(2.0 * u.meter)\n",
    "print('Jacobian of x^3 at x=2m:', J)  # 3 * 4 = 12 m^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jacobian of x^3 at x=2m: 12. m^2\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.815721100Z",
     "start_time": "2026-03-04T15:10:20.695035400Z"
    }
   },
   "source": [
    "# Forward-mode Jacobian (more efficient when input dim < output dim)\n",
    "J_fwd = u.autograd.jacfwd(f)(2.0 * u.meter)\n",
    "print('jacfwd:', J_fwd)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jacfwd: 12. m^2\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:20.835254200Z",
     "start_time": "2026-03-04T15:10:20.816722700Z"
    }
   },
   "source": [
    "# Reverse-mode Jacobian (more efficient when output dim < input dim)\n",
    "J_rev = u.autograd.jacrev(f)(2.0 * u.meter)\n",
    "print('jacrev:', J_rev)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jacrev: 12. m^2\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `hessian` — Second Derivatives\n",
    "\n",
    "The Hessian matrix contains all second partial derivatives:\n",
    "`H[i,j] = d^2f / (dx_i * dx_j)`.\n",
    "\n",
    "Unit: `unit(f) / (unit(x))^2`."
   ]
  },
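  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Applying the rule to $f(x) = x^3$ with $x$ in metres gives a chain of units that drops one power of length per derivative. This hand derivation is a sketch for checking the cells that follow:\n",
    "\n",
    "$$f(x) = x^3 \\;[\\mathrm{m^3}], \\qquad f'(x) = 3x^2 \\;[\\mathrm{m^2}], \\qquad f''(x) = 6x \\;[\\mathrm{m}]$$\n",
    "\n",
    "At $x = 2\\,\\mathrm{m}$ this gives $f'' = 12\\,\\mathrm{m}$. For a spring energy $U = \\tfrac{1}{2} k x^2$ with $k$ in $\\mathrm{N/m}$, the same rule gives $U'' = k$ with unit $\\mathrm{J/m^2}$, which is equivalent to $\\mathrm{N/m}$."
   ]
  },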
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.028344600Z",
     "start_time": "2026-03-04T15:10:20.835254200Z"
    }
   },
   "source": [
    "# f(x) = x^3  ->  f''(x) = 6x\n",
    "H = u.autograd.hessian(f)(2.0 * u.meter)\n",
    "print('Hessian of x^3 at x=2m:', H)  # 6*2 = 12 m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hessian of x^3 at x=2m: 12. m\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.147431600Z",
     "start_time": "2026-03-04T15:10:21.028344600Z"
    }
   },
   "source": [
    "# Quadratic function: f(x) = 0.5 * k * x^2  (spring potential)\n",
    "k = 100.0 * u.newton / u.meter  # spring constant\n",
    "\n",
    "def spring_energy(x):\n",
    "    return 0.5 * k * x**2\n",
    "\n",
    "# First derivative: dU/dx = k*x (magnitude of the restoring force; F = -dU/dx)\n",
    "print('Force:', u.autograd.grad(spring_energy)(0.1 * u.meter))\n",
    "\n",
    "# Hessian (second derivative): d^2U/dx^2 = k (stiffness; J/m^2 is equivalent to N/m)\n",
    "print('Stiffness:', u.autograd.hessian(spring_energy)(0.1 * u.meter))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Force: 10. N\n",
      "Stiffness: 100. J / m^2\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `vector_grad` — Element-wise Gradient\n",
    "\n",
    "For a function that maps a vector to a vector element-wise, `vector_grad` computes\n",
    "the gradient for each element independently."
   ]
  },
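  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to read this: for an element-wise function the Jacobian is diagonal, and the element-wise gradient is that diagonal. The sketch below works this out by hand for $f_i(x) = x_i^2$; it describes the mathematics and makes no claim about how `vector_grad` is implemented:\n",
    "\n",
    "$$\\frac{\\partial f_i}{\\partial x_j} = 2 x_i\\,\\delta_{ij}, \\qquad \\left(\\frac{d f_i}{d x_i}\\right)_{i} = (2x_1, 2x_2, \\dots, 2x_n)$$\n",
    "\n",
    "With $x$ in metres, $f$ is in $\\mathrm{m^2}$ and each gradient entry is in $\\mathrm{m^2}/\\mathrm{m} = \\mathrm{m}$."
   ]
  },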
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.426794200Z",
     "start_time": "2026-03-04T15:10:21.148933200Z"
    }
   },
   "source": [
    "def square(x):\n",
    "    return x**2\n",
    "\n",
    "x = jnp.array([1., 2., 3., 4.]) * u.meter\n",
    "vg = u.autograd.vector_grad(square)(x)\n",
    "print('vector_grad of x^2:', vg)  # [2, 4, 6, 8] m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vector_grad of x^2: [2. 4. 6. 8.] m\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:21.535232300Z",
     "start_time": "2026-03-04T15:10:21.440792Z"
    }
   },
   "source": [
    "# With return_value=True, get both the gradient and the function value\n",
    "vg, val = u.autograd.vector_grad(square, return_value=True)(x)\n",
    "print('Gradients:', vg)\n",
    "print('Values:', val)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradients: [2. 4. 6. 8.] m\n",
      "Values: [ 1.  4.  9. 16.] m^2\n"
     ]
    }
   ],
   "execution_count": 14
  },
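  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Example: Optimization with Units\n",
    "\n",
    "This example runs gradient descent on a physical objective function. Because the gradient has unit\n",
    "`unit(E) / unit(x)`, the learning rate must have unit `unit(x)^2 / unit(E)` so that the update\n",
    "`lr * grad` has the same unit as `x`. Here that is `m^2 / J`, which is the same as `m / N`."
   ]
  },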
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:22.480260200Z",
     "start_time": "2026-03-04T15:10:21.535232300Z"
    }
   },
   "source": [
    "# Find the position where total energy is minimized\n",
    "# E(x) = 0.5*k*x^2 + m*g*x  (spring + gravity along x-axis)\n",
    "k_spring = 50.0 * u.newton / u.meter\n",
    "m_obj = 1.0 * u.kilogram\n",
    "g_acc = 9.81 * u.meter / u.second**2\n",
    "\n",
    "def total_energy(x):\n",
    "    return 0.5 * k_spring * x**2 + m_obj * g_acc * x\n",
    "\n",
    "grad_E = u.autograd.grad(total_energy)\n",
    "\n",
    "# Simple gradient descent\n",
    "x = 0.0 * u.meter\n",
    "lr = 0.01 * u.meter / u.newton  # learning rate with correct units\n",
    "\n",
    "for i in range(100):\n",
    "    x = x - lr * grad_E(x)\n",
    "\n",
    "print('Equilibrium position:', x)  # should be ~ -m*g/k = -0.1962 m\n",
    "print('Analytical solution:', -m_obj * g_acc / k_spring)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Equilibrium position: -0.19620001 m\n",
      "Analytical solution: -0.1962 m\n"
     ]
    }
   ],
   "execution_count": 15
  },
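  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convergence of the loop above can be checked by hand. This is a sketch that uses the symbols from the code, with $\\eta$ standing for `lr`:\n",
    "\n",
    "$$x_{n+1} = x_n - \\eta\\,(k x_n + m g) = (1 - \\eta k)\\,x_n - \\eta\\,m g$$\n",
    "\n",
    "The fixed point is $x^* = -m g / k$, and the error obeys\n",
    "\n",
    "$$x_{n+1} - x^* = (1 - \\eta k)\\,(x_n - x^*)$$\n",
    "\n",
    "Here $\\eta k = 0.01\\,\\mathrm{m/N} \\times 50\\,\\mathrm{N/m} = 0.5$ is dimensionless, so the error halves on every step. After 100 iterations the result agrees with $x^* = -(9.81 / 50)\\,\\mathrm{m} = -0.1962\\,\\mathrm{m}$ up to floating-point rounding."
   ]
  },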
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Function | Description | Output Unit |\n",
    "|----------|------------|------------|\n",
    "| `grad(f)` | First derivative | `unit(f) / unit(x)` |\n",
    "| `value_and_grad(f)` | Value + first derivative | `(unit(f), unit(f)/unit(x))` |\n",
    "| `jacobian(f)` | Jacobian matrix | `unit(f_i) / unit(x_j)` |\n",
    "| `hessian(f)` | Hessian matrix | `unit(f) / unit(x)^2` |\n",
    "| `vector_grad(f)` | Element-wise gradient | `unit(f) / unit(x)` |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
