{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Checking Function Units\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/check_units.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/check_units.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "In scientific computing, it is crucial to ensure that function parameters and return values have the correct units. To streamline this process, we can use the `brainunit.check_units` decorator to validate the units of function parameters and return values."
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to import the necessary libraries and modules."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.762968600Z",
     "start_time": "2026-03-04T15:10:35.941873600Z"
    }
   },
   "source": [
    "import brainunit\n",
    "from brainunit import volt, mV, meter, second, check_dims, check_units, assign_units, DimensionMismatchError, UnitMismatchError"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking Units"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `check_dims` Decorator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `check_dims` decorator is used to validate the dimensions of input arguments or return values of a function. It ensures that the dimensions match the expected dimensions, helping to avoid errors caused by unit mismatches.\n",
    "\n",
    "We will demonstrate the usage of `check_dims` through several examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Basic Usage\n",
    "We can use the `check_dims` decorator to validate whether the input arguments of a function have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.795142Z",
     "start_time": "2026-03-04T15:10:36.781640700Z"
    }
   },
   "source": [
    "@check_dims(v=volt.dim)\n",
    "def a_function(v, x):\n",
    "    \"\"\"\n",
    "    v must have units of volt, and x can have any (or no) unit.\n",
    "    \"\"\"\n",
    "    pass"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Dimensions\n",
    "The following calls are correct because the `v` argument has units of volt or are `strings` or `None`:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.199737400Z",
     "start_time": "2026-03-04T15:10:36.796373400Z"
    }
   },
   "source": [
    "a_function(3 * mV, 5 * second)\n",
    "a_function(5 * volt, \"something\")\n",
    "a_function([1, 2, 3] * volt, None)\n",
    "a_function([1 * volt, 2 * volt, 3 * volt], None)\n",
    "a_function(\"a string\", None)\n",
    "a_function(None, None)"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Units\n",
    "The following calls will raise a `DimensionMismatchError` because the `v` argument does not have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.517901100Z",
     "start_time": "2026-03-04T15:10:37.302738300Z"
    }
   },
   "source": [
    "try:\n",
    "    a_function(5 * second, None)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)\n",
    "    \n",
    "try:\n",
    "    a_function(5, None)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)\n",
    "    \n",
    "try:\n",
    "    a_function(object(), None)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '5 s' (unit is s).\n",
      "Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '5' (unit is 1).\n",
      "Function 'a_function' expected a array with dimension metre ** 2 * kilogram * second ** -3 * amp ** -1 for argument 'v' but got '<object object at 0x000001F4EC3DF6C0>' (unit is 1).\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Return Values\n",
    "\n",
    "The `check_dims` decorator can also be used to validate whether the return value of a function has the expected dimensions."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.578183600Z",
     "start_time": "2026-03-04T15:10:37.548253300Z"
    }
   },
   "source": [
    "@check_dims(result=second.dim)\n",
    "def b_function(return_second):\n",
    "    \"\"\"\n",
    "    If return_second is True, return a value in seconds; otherwise, return a value in volts.\n",
    "    \"\"\"\n",
    "    if return_second:\n",
    "        return 5 * second\n",
    "    else:\n",
    "        return 3 * volt"
   ],
   "outputs": [],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Value\n",
    "The following call is correct because the return value has dimensions of seconds."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.614063200Z",
     "start_time": "2026-03-04T15:10:37.596529500Z"
    }
   },
   "source": [
    "b_function(True)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(5, \"s\")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Value\n",
    "The following call will raise a `DimensionMismatchError` because the return value has dimensions of volts instead of seconds."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.637562800Z",
     "start_time": "2026-03-04T15:10:37.617524Z"
    }
   },
   "source": [
    "try:\n",
    "    b_function(False)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The return value of function 'b_function' was expected to have dimension s but was '3 V' (unit is m^2 kg s^-3 A^-1).\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Multiple Return Values\n",
    "\n",
    "The `check_dims` decorator can also validate multiple return values to ensure they have the expected dimensions."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.664242Z",
     "start_time": "2026-03-04T15:10:37.645612900Z"
    }
   },
   "source": [
    "@check_dims(result=(second.dim, volt.dim))\n",
    "def d_function(true_result):\n",
    "    \"\"\"\n",
    "    If true_result is True, return values in seconds and volts; otherwise, return values in volts and seconds.\n",
    "    \"\"\"\n",
    "    if true_result:\n",
    "        return 5 * second, 3 * volt\n",
    "    else:\n",
    "        return 3 * volt, 5 * second"
   ],
   "outputs": [],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "\n",
    "The following call is correct because the return values have dimensions of seconds and volts, respectively."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.686431900Z",
     "start_time": "2026-03-04T15:10:37.666650500Z"
    }
   },
   "source": [
    "d_function(True)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Quantity(5, \"s\"), Quantity(3, \"V\"))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "The following call will raise a `DimensionMismatchError` because the return values are in volts and seconds, which do not match the expected order."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.701130600Z",
     "start_time": "2026-03-04T15:10:37.687440400Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function(False)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The return value of function 'd_function' was expected to have dimension s but was '3 V' (unit is m^2 kg s^-3 A^-1).\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Dictionary Return Values\n",
    "The `check_dims` decorator can also validate dictionary return values to ensure they have the expected dimensions."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.746893700Z",
     "start_time": "2026-03-04T15:10:37.702130900Z"
    }
   },
   "source": [
    "@check_dims(result={'u': second.dim, 'v': (volt.dim, meter.dim)})\n",
    "def d_function2(true_result):\n",
    "    \"\"\"\n",
    "    Return different dictionary results based on the value of true_result.\n",
    "    \"\"\"\n",
    "    if true_result == 0:\n",
    "        return {'u': 5 * second, 'v': (3 * volt, 2 * meter)}\n",
    "    elif true_result == 1:\n",
    "        return 3 * volt, 5 * second\n",
    "    else:\n",
    "        return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}"
   ],
   "outputs": [],
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "The following call is correct because the return values match the expected dimensions."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.783734300Z",
     "start_time": "2026-03-04T15:10:37.747889300Z"
    }
   },
   "source": [
    "d_function2(0)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'u': Quantity(5, \"s\"), 'v': (Quantity(3, \"V\"), Quantity(2, \"m\"))}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "The following calls will raise a `TypeError` or `DimensionMismatchError` because the return values do not match the expected dimensions."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.819808600Z",
     "start_time": "2026-03-04T15:10:37.805511100Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function2(1)\n",
    "except TypeError as e:\n",
    "    print(e)\n",
    "try:\n",
    "    d_function2(2)\n",
    "except DimensionMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected a return value of type {'u': second, 'v': (metre ** 2 * kilogram * second ** -3 * amp ** -1, metre)} but got (Quantity(3, \"V\"), Quantity(5, \"s\"))\n",
      "The return value of function 'd_function2' was expected to have dimension m but was '2 V' (unit is m^2 kg s^-3 A^-1).\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `check_units` Decorator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `check_units` decorator is used to validate the dimensions of input arguments or return values of a function. It ensures that the dimensions match the expected dimensions, helping to avoid errors caused by unit mismatches.\n",
    "\n",
    "We will demonstrate the usage of `check_units` through several examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Basic Usage\n",
    "We can use the `check_units` decorator to validate whether the input arguments of a function have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.832187900Z",
     "start_time": "2026-03-04T15:10:37.820805700Z"
    }
   },
   "source": [
    "@check_units(v=volt)\n",
    "def a_function(v, x):\n",
    "    \"\"\"\n",
    "    v must have units of volt, and x can have any (or no) unit.\n",
    "    \"\"\"\n",
    "    pass"
   ],
   "outputs": [],
   "execution_count": 14
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Dimensions\n",
    "The following calls are correct because the `v` argument has units of volt or are `strings` or `None`:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.857961300Z",
     "start_time": "2026-03-04T15:10:37.832187900Z"
    }
   },
   "source": [
    "a_function(3 * volt, 5 * second)\n",
    "a_function(5 * volt, \"something\")\n",
    "a_function([1, 2, 3] * volt, None)\n",
    "# lists that can be converted should also work\n",
    "a_function([1 * volt, 2 * volt, 3 * volt], None)\n",
    "# Strings and None are also allowed to pass\n",
    "a_function(\"a string\", None)\n",
    "a_function(None, None)"
   ],
   "outputs": [],
   "execution_count": 15
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Units\n",
    "The following calls will raise a `DimensionMismatchError` because the `v` argument does not have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.892803700Z",
     "start_time": "2026-03-04T15:10:37.871404800Z"
    }
   },
   "source": [
    "try:\n",
    "    a_function(5 * second, None)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)\n",
    "    \n",
    "try:\n",
    "    a_function(5, None)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)\n",
    "    \n",
    "try:\n",
    "    a_function(object(), None)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Function 'a_function' expected a array with unit Unit(\"V\") for argument 'v' but got '5 s' (unit is s).\n",
      "Function 'a_function' expected a array with unit Unit(\"V\") for argument 'v' but got '5' (unit is 1).\n",
      "Function 'a_function' expected a array with unit Unit(\"V\") for argument 'v' but got '<object object at 0x000001F4EC3DF4B0>' (unit is 1).\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Return Values\n",
    "\n",
    "The `check_units` decorator can also be used to validate whether the return value of a function has the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.907751300Z",
     "start_time": "2026-03-04T15:10:37.898080400Z"
    }
   },
   "source": [
    "@check_units(result=second)\n",
    "def b_function(return_second):\n",
    "    \"\"\"\n",
    "    Return a value in seconds if return_second is True, otherwise return\n",
    "    a value in volt.\n",
    "    \"\"\"\n",
    "    if return_second:\n",
    "        return 5 * second\n",
    "    else:\n",
    "        return 3 * volt"
   ],
   "outputs": [],
   "execution_count": 17
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Value\n",
    "The following call is correct because the return value has units of seconds."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.923402100Z",
     "start_time": "2026-03-04T15:10:37.907751300Z"
    }
   },
   "source": [
    "b_function(True)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity(5, \"s\")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Value\n",
    "The following call will raise a `UnitMismatchError` because the return value has units of volts instead of seconds."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.941484900Z",
     "start_time": "2026-03-04T15:10:37.924423Z"
    }
   },
   "source": [
    "try:\n",
    "    b_function(False)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The return value of function 'b_function' was expected to have unit s but got unit V (value: Quantity(3, \"V\")) (units are s and V).\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Multiple Return Values\n",
    "\n",
    "The `check_units` decorator can also validate multiple return values to ensure they have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.956427200Z",
     "start_time": "2026-03-04T15:10:37.944929400Z"
    }
   },
   "source": [
    "@check_units(result=(second, volt))\n",
    "def d_function(true_result):\n",
    "    \"\"\"\n",
    "    Return a value in seconds if return_second is True, otherwise return\n",
    "    a value in volt.\n",
    "    \"\"\"\n",
    "    if true_result:\n",
    "        return 5 * second, 3 * volt\n",
    "    else:\n",
    "        return 3 * volt, 5 * second"
   ],
   "outputs": [],
   "execution_count": 20
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "\n",
    "The following call is correct because the return values have units of seconds and volts, respectively."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.971764800Z",
     "start_time": "2026-03-04T15:10:37.957677900Z"
    }
   },
   "source": [
    "d_function(True)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Quantity(5, \"s\"), Quantity(3, \"V\"))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "The following call will raise a `UnitMismatchError` because the return values are in volts and seconds, which do not match the expected order."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:37.990799Z",
     "start_time": "2026-03-04T15:10:37.972755500Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function(False)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The return value of function 'd_function' was expected to have unit s but got unit V (value: Quantity(3, \"V\")) (units are s and V).\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Validating Dictionary Return Values\n",
    "The `check_units` decorator can also validate dictionary return values to ensure they have the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.027897Z",
     "start_time": "2026-03-04T15:10:37.994810300Z"
    }
   },
   "source": [
    "@check_units(result={'u': second, 'v': (volt, meter)})\n",
    "def d_function2(true_result):\n",
    "    \"\"\"\n",
    "    Return a value in seconds if return_second is True, otherwise return\n",
    "    a value in volt.\n",
    "    \"\"\"\n",
    "    if true_result == 0:\n",
    "        return {'u': 5 * second, 'v': (3 * volt, 2 * meter)}\n",
    "    elif true_result == 1:\n",
    "        return 3 * volt, 5 * second\n",
    "    else:\n",
    "        return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}"
   ],
   "outputs": [],
   "execution_count": 23
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "The following call is correct because the return values match the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.169526200Z",
     "start_time": "2026-03-04T15:10:38.032120400Z"
    }
   },
   "source": [
    "d_function2(0)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'u': Quantity(5, \"s\"), 'v': (Quantity(3, \"V\"), Quantity(2, \"m\"))}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 24
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "The following calls will raise a `TypeError` or `UnitMismatchError` because the return values do not match the expected units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.246052700Z",
     "start_time": "2026-03-04T15:10:38.173522700Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function2(1)\n",
    "except TypeError as e:\n",
    "    print(e)\n",
    "try:\n",
    "    d_function2(2)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected a return value of type {'u': Unit(\"s\"), 'v': (Unit(\"V\"), Unit(\"m\"))} but got (Quantity(3, \"V\"), Quantity(5, \"s\"))\n",
      "The return value of function 'd_function2' was expected to have unit m but got unit V (value: Quantity(2, \"V\")) (units are m and V).\n"
     ]
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Assigning Units"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `assign_units` Decorator\n",
    "The `assign_units` decorator is used to automatically assign units to the input arguments or return values of a function. It ensures that the values are converted to the specified units, simplifying unit handling in scientific computations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Basic Usage\n",
    "\n",
    "We can use the `assign_units` decorator to automatically assign units to the input arguments of a function."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.415539900Z",
     "start_time": "2026-03-04T15:10:38.295085700Z"
    }
   },
   "source": [
    "@assign_units(v=volt)\n",
    "def a_function(v, x):\n",
    "    \"\"\"\n",
    "    v will be assigned units of volt, and x can have any (or no) unit.\n",
    "    \"\"\"\n",
    "    return v"
   ],
   "outputs": [],
   "execution_count": 26
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Units\n",
    "The following calls are correct because the `v` argument is automatically converted to volts."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.527247900Z",
     "start_time": "2026-03-04T15:10:38.450925Z"
    }
   },
   "source": [
    "assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)\n",
    "assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)\n",
    "assert a_function(5 * volt, \"something\") == (5 * volt).to_decimal(volt)"
   ],
   "outputs": [],
   "execution_count": 27
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Units\n",
    "The following calls will raise a `UnitMismatchError` or `TypeError` because the `v` argument cannot be converted to volts."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.653846900Z",
     "start_time": "2026-03-04T15:10:38.538432500Z"
    }
   },
   "source": [
    "try:\n",
    "    a_function(5 * second, None)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)\n",
    "\n",
    "try:\n",
    "    a_function(5, None)\n",
    "except TypeError as e:\n",
    "    print(e)\n",
    "\n",
    "try:\n",
    "    a_function(object(), None)\n",
    "except TypeError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cannot convert to the decimal number using a unit with different dimensions. (units are s and V).\n",
      "Function 'a_function' expected a Quantity object for argument 'v' but got '5'\n",
      "Function 'a_function' expected a Quantity object for argument 'v' but got '<object object at 0x000001F4EC3DF6C0>'\n"
     ]
    }
   ],
   "execution_count": 28
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Return Values\n",
    "The `assign_units` decorator can also be used to automatically assign units to the return value of a function."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.731045700Z",
     "start_time": "2026-03-04T15:10:38.681373Z"
    }
   },
   "source": [
    "@assign_units(result=second)\n",
    "def b_function():\n",
    "    \"\"\"\n",
    "    The return value will be assigned units of seconds.\n",
    "    \"\"\"\n",
    "    return 5"
   ],
   "outputs": [],
   "execution_count": 29
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Value\n",
    "The following call is correct because the return value is automatically converted to seconds."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.742681500Z",
     "start_time": "2026-03-04T15:10:38.732514500Z"
    }
   },
   "source": [
    "assert b_function() == 5 * second"
   ],
   "outputs": [],
   "execution_count": 30
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Multiple Return Values\n",
    "The `assign_units` decorator can also assign units to multiple return values."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.756119Z",
     "start_time": "2026-03-04T15:10:38.742681500Z"
    }
   },
   "source": [
    "@assign_units(result=(second, volt))\n",
    "def d_function():\n",
    "    \"\"\"\n",
    "    The return values will be assigned units of seconds and volts, respectively.\n",
    "    \"\"\"\n",
    "    return 5, 3"
   ],
   "outputs": [],
   "execution_count": 31
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "The following call is correct because the return values are automatically converted to seconds and volts."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.814000800Z",
     "start_time": "2026-03-04T15:10:38.758129Z"
    }
   },
   "source": [
    "assert d_function()[0] == 5 * second\n",
    "assert d_function()[1] == 3 * volt"
   ],
   "outputs": [],
   "execution_count": 32
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Dictionary Return Values\n",
    "The `assign_units` decorator can also assign units to dictionary return values."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.824344600Z",
     "start_time": "2026-03-04T15:10:38.814989900Z"
    }
   },
   "source": [
    "@assign_units(result={'u': second, 'v': (volt, meter)})\n",
    "def d_function2(true_result):\n",
    "    \"\"\"\n",
    "    The return values will be assigned units based on the dictionary specification.\n",
    "    \"\"\"\n",
    "    if true_result == 0:\n",
    "        return {'u': 5, 'v': (3, 2)}\n",
    "    elif true_result == 1:\n",
    "        return 3, 5\n",
    "    else:\n",
    "        return 3, 5"
   ],
   "outputs": [],
   "execution_count": 33
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "The following call is correct because the return values are automatically converted to the specified units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.836326900Z",
     "start_time": "2026-03-04T15:10:38.825345700Z"
    }
   },
   "source": [
    "d_function2(0)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'u': Quantity(5, \"s\"), 'v': (Quantity(3, \"V\"), Quantity(2, \"m\"))}"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 34
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "The following call will raise a `TypeError` because the return values do not match the expected structure."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:38.847523300Z",
     "start_time": "2026-03-04T15:10:38.837339900Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function2(1)\n",
    "except TypeError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected a return value of pytree PyTreeDef({'u': *, 'v': (*, *)}) with type {'u': Unit(\"s\"), 'v': (Unit(\"V\"), Unit(\"m\"))}, but got the pytree PyTreeDef((*, *)) and the value (3, 5)\n"
     ]
    }
   ],
   "execution_count": 35
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Through the examples above, we can see the utility of the `assign_units` decorator in automatically assigning units to input arguments and return values. It simplifies unit handling in scientific computations, ensuring consistency and reducing the likelihood of errors."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainpy-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
