{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "19cec5d4",
   "metadata": {},
   "source": [
    "# Physical units\n",
    "\n",
    "**What you'll learn / who it's for (simulation *and* training).** How\n",
    "``brainpy.state`` uses ``brainunit`` to attach physical units to every quantity,\n",
    "how to construct models unit-safely, how unit-aware initializers work, and how the\n",
    "unit system turns silent modeling mistakes into loud errors instead of wrong numbers."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c45a648c",
   "metadata": {},
   "source": [
    "## Why units belong in the model\n",
    "\n",
    "A membrane time constant of ``10`` is meaningless: 10 milliseconds and 10 seconds\n",
    "differ by three orders of magnitude and produce completely different dynamics.\n",
    "Brain models are dense with such quantities — millivolts, milliseconds,\n",
    "nanosiemens, milliamps — and mixing their scales is one of the most common, most\n",
    "silent sources of wrong results.\n",
    "\n",
    "``brainpy.state`` carries units through the entire model with\n",
    "[``brainunit``](https://brainunit.readthedocs.io). A unitful value is a\n",
    "``Quantity`` = magnitude × unit. Arithmetic checks dimensions, conversions are\n",
    "explicit, and an incompatible combination raises **before** a single time step\n",
    "runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "db0a51c8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:12.235750Z",
     "iopub.status.busy": "2026-06-17T09:11:12.235527Z",
     "iopub.status.idle": "2026-06-17T09:11:16.319051Z",
     "shell.execute_reply": "2026-06-17T09:11:16.318160Z"
    }
   },
   "outputs": [
   ],
   "source": [
    "import brainpy\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d184903",
   "metadata": {},
   "source": [
    "## Quantities: magnitude × unit\n",
    "\n",
    "Build a quantity by multiplying a number (or array) by a unit. The units you meet\n",
    "constantly in point-neuron modeling:\n",
    "\n",
    "- **Voltage** — ``u.mV`` (millivolt)\n",
    "- **Time** — ``u.ms`` (millisecond)\n",
    "- **Conductance** — ``u.nS`` (nanosiemens), ``u.mS`` (millisiemens)\n",
    "- **Current** — ``u.mA`` (milliamp), ``u.nA`` (nanoamp)\n",
    "- **Frequency** — ``u.Hz``"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2a093349",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:16.321488Z",
     "iopub.status.busy": "2026-06-17T09:11:16.321125Z",
     "iopub.status.idle": "2026-06-17T09:11:16.460707Z",
     "shell.execute_reply": "2026-06-17T09:11:16.459925Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10. ms\n",
      "-50. mV\n",
      "weight has dimension: m^-2 kg^-1 s^3 A^2\n"
     ]
    }
   ],
   "source": [
    "tau = 10. * u.ms\n",
    "V_th = -50. * u.mV\n",
    "weight = 0.6 * u.mS\n",
    "current = 20. * u.mA\n",
    "\n",
    "print(tau)\n",
    "print(V_th)\n",
    "print('weight has dimension:', weight.dim)"
   ]
  },
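  {
   "cell_type": "markdown",
   "id": "a7d3e901",
   "metadata": {},
   "source": [
    "Quantities are not limited to scalars. The cell below is a small sketch, assuming a\n",
    "recent ``brainunit`` release in which ``Quantity`` exposes ``.unit`` and ``.mantissa``\n",
    "(the raw numbers, read in that unit). Multiplying a JAX array by a unit gives a single\n",
    "quantity whose elements all share that unit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d3e902",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# One quantity holding three membrane potentials; the unit applies to every element.\n",
    "V = jnp.array([-70., -65., -60.]) * u.mV\n",
    "\n",
    "print(V.unit)                # the attached unit\n",
    "print(V.mantissa)            # raw magnitudes, to be read in V.unit\n",
    "print(V.to_decimal(u.volt))  # explicit conversion to plain numbers in volts"
   ]
  },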
  {
   "cell_type": "markdown",
   "id": "d66e2c52",
   "metadata": {},
   "source": [
    "### Arithmetic is dimensionally checked\n",
    "\n",
    "Adding compatible quantities works (and rescales as needed); combining\n",
    "incompatible ones is an error — which is exactly what you want."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6debbb70",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:16.462119Z",
     "iopub.status.busy": "2026-06-17T09:11:16.461814Z",
     "iopub.status.idle": "2026-06-17T09:11:16.470125Z",
     "shell.execute_reply": "2026-06-17T09:11:16.469424Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100 ms + 0.5 s = 600.0 ms\n",
      "V / I = -32.5 Mohm\n"
     ]
    }
   ],
   "source": [
    "# Compatible: different time units add correctly.\n",
    "total = 100. * u.ms + 0.5 * u.second\n",
    "print('100 ms + 0.5 s =', total.to_decimal(u.ms), 'ms')\n",
    "\n",
    "# Derived units fall out of the algebra: V / I -> resistance.\n",
    "R = (-65. * u.mV) / (2. * u.nA)\n",
    "print('V / I =', R)"
   ]
  },
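  {
   "cell_type": "markdown",
   "id": "b41c6f03",
   "metadata": {},
   "source": [
    "Multiplication follows the same algebra. As a further illustration (the membrane values\n",
    "are made up for the example), a membrane resistance times a membrane capacitance should\n",
    "come out as a time, and ``to_decimal(u.ms)`` confirms it: the conversion would raise if\n",
    "the product were not a time. This assumes ``brainunit`` provides the prefixed units\n",
    "``u.Mohm`` and ``u.pF``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41c6f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "R_m = 100. * u.Mohm   # illustrative membrane resistance\n",
    "C_m = 200. * u.pF     # illustrative membrane capacitance\n",
    "\n",
    "# ohm * farad = second: 100e6 ohm * 200e-12 F = 0.02 s\n",
    "tau_m = R_m * C_m\n",
    "print('tau_m =', tau_m.to_decimal(u.ms), 'ms')"
   ]
  },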
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ffe9e47c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:16.472192Z",
     "iopub.status.busy": "2026-06-17T09:11:16.472023Z",
     "iopub.status.idle": "2026-06-17T09:11:16.477163Z",
     "shell.execute_reply": "2026-06-17T09:11:16.476660Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UnitMismatchError -> Cannot calculate \n",
      "-65. mV + 10. ms, because units do not match: mV != ms\n"
     ]
    }
   ],
   "source": [
    "# Incompatible: adding a voltage to a time is caught immediately.\n",
    "try:\n",
    "    bad = (-65. * u.mV) + (10. * u.ms)\n",
    "except Exception as err:\n",
    "    print(type(err).__name__, '->', err)"
   ]
  },
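  {
   "cell_type": "markdown",
   "id": "c92e0a15",
   "metadata": {},
   "source": [
    "The same checks apply inside math functions. An exponential only makes sense for a\n",
    "dimensionless argument, so a decay factor has to be written as ``exp(-t / tau)``. The\n",
    "sketch below assumes that ``u.math.exp`` rejects a unitful argument by raising an\n",
    "exception; the exact exception type may differ between versions, so the cell catches\n",
    "``Exception``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92e0a16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "t = 10. * u.ms\n",
    "tau = 10. * u.ms\n",
    "\n",
    "# t / tau is dimensionless, so the exponential is well defined.\n",
    "decay = u.math.exp(-t / tau)\n",
    "print('decay factor:', decay)\n",
    "\n",
    "# The exponential of a time is dimensionally meaningless and should be rejected.\n",
    "try:\n",
    "    u.math.exp(t)\n",
    "    caught = False\n",
    "except Exception as err:\n",
    "    caught = True\n",
    "    print(type(err).__name__, '->', err)"
   ]
  },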
  {
   "cell_type": "markdown",
   "id": "7b8bb0d8",
   "metadata": {},
   "source": [
    "### Converting to plain numbers, on purpose\n",
    "\n",
    "When you need a raw array (for plotting, or interfacing with code that does not\n",
    "understand units), convert *explicitly* with ``.to_decimal(unit)``. Being explicit\n",
    "documents the unit you are assuming."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ea60f1b5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:16.478699Z",
     "iopub.status.busy": "2026-06-17T09:11:16.478502Z",
     "iopub.status.idle": "2026-06-17T09:11:16.482816Z",
     "shell.execute_reply": "2026-06-17T09:11:16.481949Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-65.0\n",
      "-0.065\n"
     ]
    }
   ],
   "source": [
    "v = -65. * u.mV\n",
    "print(v.to_decimal(u.mV))   # -> -65.0\n",
    "print(v.to_decimal(u.volt)) # -> -0.065"
   ]
  },
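  {
   "cell_type": "markdown",
   "id": "d05b7e27",
   "metadata": {},
   "source": [
    "A typical case is handing data to code that expects plain numbers. In this sketch,\n",
    "``count_above`` is a hypothetical stand-in for such unit-unaware code (it is not part of\n",
    "``brainpy`` or ``brainunit``). The conversion happens once, at the boundary, and names\n",
    "the unit both arguments are expressed in, even though the threshold was written in volts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d05b7e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "def count_above(values, threshold):\n",
    "    # Hypothetical unit-unaware helper: plain numbers in, plain int out.\n",
    "    return int(np.sum(np.asarray(values) > float(threshold)))\n",
    "\n",
    "V = jnp.array([-70., -52., -48., -45.]) * u.mV\n",
    "V_th_volt = -0.05 * u.volt   # the same threshold as -50 mV, written in volts\n",
    "\n",
    "# Convert both arguments to one explicit unit before leaving the unit system.\n",
    "n_above = count_above(V.to_decimal(u.mV), V_th_volt.to_decimal(u.mV))\n",
    "print('neurons above threshold:', n_above)"
   ]
  },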
  {
   "cell_type": "markdown",
   "id": "95fd51a8",
   "metadata": {},
   "source": [
    "## Unit-safe construction\n",
    "\n",
    "Every built-in neuron, synapse, and output takes unitful parameters. Supplying them\n",
    "with units is self-documenting, and a quantity with the wrong dimension raises an\n",
    "error as soon as it is combined with incompatible units, instead of silently producing\n",
    "wrong dynamics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b1d58449",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:16.484965Z",
     "iopub.status.busy": "2026-06-17T09:11:16.484795Z",
     "iopub.status.idle": "2026-06-17T09:11:19.221512Z",
     "shell.execute_reply": "2026-06-17T09:11:19.220511Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V is stored as a unitful quantity: [0. 0. 0.] mV\n"
     ]
    }
   ],
   "source": [
    "neuron = brainpy.state.LIF(\n",
    "    100,\n",
    "    V_rest=-65. * u.mV,\n",
    "    V_th=-50. * u.mV,\n",
    "    V_reset=-65. * u.mV,\n",
    "    tau=10. * u.ms,\n",
    "    R=1. * u.ohm,\n",
    ")\n",
    "brainstate.nn.init_all_states(neuron)\n",
    "print('V is stored as a unitful quantity:', neuron.V.value[:3])"
   ]
  },
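  {
   "cell_type": "markdown",
   "id": "e6f18c39",
   "metadata": {},
   "source": [
    "The protection is not specific to built-in models: any function written with quantities\n",
    "gets it. As an illustration, the hypothetical helper ``lif_steady_state`` below (not a\n",
    "``brainpy`` API) returns ``V_rest + R * I``, the potential a leaky integrator settles to\n",
    "under a constant input current when the threshold is ignored. Resistance times current\n",
    "is a voltage, so the sum is valid; passing a voltage in place of the current makes the\n",
    "addition fail. The values are illustrative, and ``u.Mohm`` is assumed to be available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f18c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "def lif_steady_state(V_rest, R, I):\n",
    "    # Hypothetical helper: steady-state potential of a leaky integrator under constant input.\n",
    "    return V_rest + R * I\n",
    "\n",
    "# R * I = 100e6 ohm * 0.2e-9 A = 20 mV, so V_inf = -65 mV + 20 mV.\n",
    "V_inf = lif_steady_state(-65. * u.mV, 100. * u.Mohm, 0.2 * u.nA)\n",
    "print('V_inf =', V_inf.to_decimal(u.mV), 'mV')\n",
    "\n",
    "# A voltage where a current belongs: R * I becomes ohm * volt, which cannot be added to mV.\n",
    "try:\n",
    "    lif_steady_state(-65. * u.mV, 100. * u.Mohm, 0.2 * u.mV)\n",
    "    caught = False\n",
    "except Exception as err:\n",
    "    caught = True\n",
    "    print('caught:', type(err).__name__)"
   ]
  },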
  {
   "cell_type": "markdown",
   "id": "40b626fb",
   "metadata": {},
   "source": [
    "## Unit-aware initializers\n",
    "\n",
    "State is rarely a single constant — membrane potentials might be drawn from a\n",
    "distribution, weights from a Kaiming scheme. ``braintools.init`` initializers take a\n",
    "``unit=`` argument so the array they produce carries the right dimension. Pass the\n",
    "*initializer object* to the model; it is invoked with the correct shape during\n",
    "``init_all_states``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "119c4526",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:19.224040Z",
     "iopub.status.busy": "2026-06-17T09:11:19.223537Z",
     "iopub.status.idle": "2026-06-17T09:11:19.644536Z",
     "shell.execute_reply": "2026-06-17T09:11:19.643696Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean initial V: -55.00637 mV\n",
      "initial g unit matches mS: [0. 0. 0.] mS\n"
     ]
    }
   ],
   "source": [
    "# Heterogeneous initial voltages, in millivolts.\n",
    "neuron = brainpy.state.LIF(\n",
    "    1000,\n",
    "    V_rest=-65. * u.mV, V_th=-50. * u.mV, V_reset=-65. * u.mV, tau=10. * u.ms,\n",
    "    V_initializer=braintools.init.Normal(-55., 2., unit=u.mV),\n",
    ")\n",
    "brainstate.nn.init_all_states(neuron)\n",
    "print('mean initial V:', u.math.mean(neuron.V.value))\n",
    "\n",
    "# Synaptic conductances initialized to zero, with units.\n",
    "syn = brainpy.state.Expon(\n",
    "    1000, tau=5. * u.ms,\n",
    "    g_initializer=braintools.init.Constant(0. * u.mS),\n",
    ")\n",
    "brainstate.nn.init_all_states(syn)\n",
    "print('initial g unit matches mS:', syn.g.value[:3])"
   ]
  },
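  {
   "cell_type": "markdown",
   "id": "f2a94d4b",
   "metadata": {},
   "source": [
    "To see why passing the object (rather than an array) matters, here is a minimal,\n",
    "hypothetical initializer written from scratch with ``jax.random``. It sketches the\n",
    "pattern only and is not how ``braintools.init.Normal`` is implemented. It stores the\n",
    "mean, spread, and unit, and produces an array only when called with a shape, which is\n",
    "the role ``init_all_states`` plays for the real initializers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a94d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "\n",
    "class NormalInitSketch:\n",
    "    # Hypothetical unit-aware initializer: Gaussian values with a unit, drawn on demand.\n",
    "\n",
    "    def __init__(self, mean, std, unit, seed=0):\n",
    "        self.mean, self.std, self.unit, self.seed = mean, std, unit, seed\n",
    "\n",
    "    def __call__(self, shape):\n",
    "        key = jax.random.PRNGKey(self.seed)\n",
    "        samples = self.mean + self.std * jax.random.normal(key, shape)\n",
    "        return samples * self.unit   # attach the unit to the whole array\n",
    "\n",
    "v_init = NormalInitSketch(-55., 2., unit=u.mV)\n",
    "V0 = v_init((1000,))   # the caller decides the shape, not the initializer\n",
    "print(V0.shape, u.math.mean(V0))"
   ]
  },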
  {
   "cell_type": "markdown",
   "id": "bb106c5e",
   "metadata": {},
   "source": [
    "Common initializers you'll reach for:\n",
    "\n",
    "- ``braintools.init.Constant(value)`` — a fixed value (often ``0. * u.mS``).\n",
    "- ``braintools.init.Normal(mean, std, unit=...)`` — Gaussian, e.g. heterogeneous\n",
    "  ``V``.\n",
    "- ``braintools.init.Uniform(low, high, unit=...)`` — uniform over a range.\n",
    "- ``braintools.init.KaimingNormal(unit=...)`` — weight initialization for trainable\n",
    "  layers (see {doc}`/concepts/differentiability`).\n",
    "- ``braintools.init.ZeroInit(unit=...)`` — zeros with a unit, e.g. a bias in\n",
    "  ``u.mA``."
   ]
  },
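  {
   "cell_type": "markdown",
   "id": "0b7c5e5d",
   "metadata": {},
   "source": [
    "For example, a population whose initial potentials are spread between reset and\n",
    "threshold can use the ``Uniform`` initializer from the list, passed exactly like\n",
    "``Normal`` above. This sketch assumes ``Uniform`` samples from the half-open range\n",
    "``[low, high)``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b7c5e5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "\n",
    "neuron = brainpy.state.LIF(\n",
    "    500,\n",
    "    V_rest=-65. * u.mV, V_th=-50. * u.mV, V_reset=-65. * u.mV, tau=10. * u.ms,\n",
    "    V_initializer=braintools.init.Uniform(-65., -50., unit=u.mV),\n",
    ")\n",
    "brainstate.nn.init_all_states(neuron)\n",
    "\n",
    "V0 = neuron.V.value.to_decimal(u.mV)   # plain array, explicitly in mV\n",
    "print('initial V range:', V0.min(), 'to', V0.max(), 'mV')"
   ]
  },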
  {
   "cell_type": "markdown",
   "id": "75b407a8",
   "metadata": {},
   "source": [
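    "## Pitfalls the unit system catches\n",
    "\n",
    "These mistakes would be invisible in a unitless framework and produce subtly (or\n",
    "catastrophically) wrong dynamics. With units, they either fail loudly or are\n",
    "handled correctly by the algebra. A constructor may store a parameter without\n",
    "checking its dimension, so a mismatch can surface only when the value is first used."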
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bb2160f2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:19.646246Z",
     "iopub.status.busy": "2026-06-17T09:11:19.646028Z",
     "iopub.status.idle": "2026-06-17T09:11:19.649416Z",
     "shell.execute_reply": "2026-06-17T09:11:19.648878Z"
    }
   },
   "outputs": [],
   "source": [
    "# Pitfall 1: a time constant given in the wrong dimension.\n",
    "try:\n",
    "    brainpy.state.LIF(10, tau=10. * u.mV)   # mV where ms is required\n",
    "except Exception as err:\n",
    "    print('caught wrong-dimension tau:', type(err).__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a400e7d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-17T09:11:19.651224Z",
     "iopub.status.busy": "2026-06-17T09:11:19.651073Z",
     "iopub.status.idle": "2026-06-17T09:11:19.655452Z",
     "shell.execute_reply": "2026-06-17T09:11:19.654555Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6 mS + 50 nS = 0.60005 mS (not 50.6)\n"
     ]
    }
   ],
   "source": [
    "# Pitfall 2: mixing conductance scales. mS and nS differ by 1e6;\n",
    "# the algebra keeps them straight instead of silently adding magnitudes.\n",
    "g_total = 0.6 * u.mS + 50. * u.nS\n",
    "print('0.6 mS + 50 nS =', g_total.to_decimal(u.mS), 'mS (not 50.6)')"
   ]
  },
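  {
   "cell_type": "markdown",
   "id": "1c8d6f6f",
   "metadata": {},
   "source": [
    "Two related cases follow the same rules. Adding a bare number to a quantity is rejected\n",
    "rather than assumed to share its unit, and comparisons rescale both sides before\n",
    "comparing. The cell below sketches both, assuming that mixing a plain float with a\n",
    "unitful quantity raises an exception."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c8d6f70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainunit as u\n",
    "\n",
    "# Pitfall 3: a bare number is not silently given the unit of the other operand.\n",
    "try:\n",
    "    (-65. * u.mV) + 5.\n",
    "    caught = False\n",
    "except Exception as err:\n",
    "    caught = True\n",
    "    print('caught unitless addition:', type(err).__name__)\n",
    "\n",
    "# Comparisons rescale: -50 mV is below -45 mV, even when the latter is written in volts.\n",
    "below = bool((-50. * u.mV) < (-0.045 * u.volt))\n",
    "print('-50 mV < -0.045 V:', below)"
   ]
  },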
  {
   "cell_type": "markdown",
   "id": "277950d4",
   "metadata": {},
   "source": [
    "## Recap\n",
    "\n",
    "- Quantities are **magnitude × unit**; arithmetic is dimensionally checked and\n",
    "  derived units fall out automatically.\n",
    "- Construct models with **unitful parameters**; convert to plain numbers only\n",
    "  **explicitly**, via ``.to_decimal(unit)``.\n",
    "- **Unit-aware initializers** (``braintools.init.*(..., unit=...)``) attach the\n",
    "  right dimension to initial state.\n",
    "- The unit system turns whole categories of modeling error into exceptions instead\n",
    "  of silently wrong results.\n",
    "\n",
    "## See also\n",
    "\n",
    "- {doc}`/concepts/state-paradigm` — the states these units live in.\n",
    "- {doc}`/concepts/model-anatomy` — where unitful parameters enter neurons and\n",
    "  synapses.\n",
    "- {doc}`/concepts/differentiability` — unitful weight initializers for trainable\n",
    "  layers."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
