{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Assigning Function Units\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/assign_units.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/assign_units.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Many existing scientific computing functions are designed to operate on dimensionless (unitless) data. ``brainunit`` provides an interface that makes these functions usable with physical quantities without modifying the existing frameworks or their underlying implementations. The core idea is:\n",
    "\n",
    "- Dimensionless processing before function calls: before a scientific computing function is invoked, its inputs are converted to dimensionless values, so that the function internally performs only unitless numerical operations. For example, ``b = a.to_decimal(UNIT)`` converts the quantity ``a`` into the dimensionless number ``b``, expressed in the physical unit ``UNIT``.\n",
    "- Restoring physical units after computation: once the computation is complete and results are returned, the appropriate physical units are reattached to them.\n",
    "\n",
    "Specifically, `brainunit` provides the ``assign_units`` decorator, which automates both steps: it converts a function's inputs to unitless values on the way in and attaches physical units to its results on the way out. In principle, this approach applies to any Python-based scientific computing library, preserving physical semantics and interpretability at the inputs and outputs without altering the library's existing structure."
   ]
  },
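  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two steps above can also be carried out by hand. The sketch below strips the unit with `to_decimal`, runs a unitless computation, and reattaches the unit afterwards. The `double` function is a hypothetical stand-in for any routine that only accepts plain numbers; it is not part of `brainunit`.\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "\n",
    "def double(x):\n",
    "    # Hypothetical unitless routine: it only ever sees plain numbers\n",
    "    return 2 * x\n",
    "\n",
    "a = 3 * u.mV                  # a quantity with units\n",
    "b = a.to_decimal(u.volt)      # step 1: strip units (b is about 0.003, in volts)\n",
    "c = double(b)                 # unitless computation\n",
    "result = c * u.volt           # step 2: reattach units (about 0.006 V)\n",
    "```\n",
    "\n",
    "Doing this by hand for every argument and return value quickly becomes repetitive, which is the bookkeeping that `assign_units` takes over."
   ]
  },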
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we need to import the necessary libraries and modules."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.409445800Z",
     "start_time": "2026-03-04T15:10:33.532013300Z"
    }
   },
   "source": [
    "import brainunit\n",
    "from brainunit import volt, mV, meter, second, assign_units, UnitMismatchError"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `assign_units` Decorator\n",
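    "The `assign_units` decorator handles units at a function's boundaries. For each listed input argument, it converts the incoming quantity to the specified unit and passes the resulting unitless value to the function. For the return value, declared with the `result` keyword, it attaches the specified units to the unitless value the function returns. This keeps unit handling out of the function body."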
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Basic Usage\n",
    "\n",
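    "Pass keyword arguments to `assign_units` that map argument names to units. Each listed argument must be a quantity whose unit is compatible with the given unit; inside the function, it arrives as a plain number expressed in that unit. Arguments that are not listed are passed through unchanged."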
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.420733300Z",
     "start_time": "2026-03-04T15:10:34.409445800Z"
    }
   },
   "source": [
    "@assign_units(v=volt)\n",
    "def a_function(v, x):\n",
    "    \"\"\"\n",
    "    v must have the dimensions of voltage and is passed in as a plain number in volts; x can have any (or no) unit.\n",
    "    \"\"\"\n",
    "    return v"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Units\n",
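    "The following calls succeed because `v` has the dimensions of voltage, so it can be converted to a plain number in volts (for example, `3 * mV` becomes `0.003`). The second argument is not checked, so it can be anything."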
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.434423Z",
     "start_time": "2026-03-04T15:10:34.421724400Z"
    }
   },
   "source": [
    "assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)\n",
    "assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)\n",
    "assert a_function(5 * volt, \"something\") == (5 * volt).to_decimal(volt)"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Units\n",
    "The following calls fail because `v` cannot be converted to volts: a quantity with different dimensions raises a `UnitMismatchError`, while a plain number or any other non-quantity object raises a `TypeError`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.447250600Z",
     "start_time": "2026-03-04T15:10:34.435545600Z"
    }
   },
   "source": [
    "try:\n",
    "    a_function(5 * second, None)\n",
    "except UnitMismatchError as e:\n",
    "    print(e)\n",
    "\n",
    "try:\n",
    "    a_function(5, None)\n",
    "except TypeError as e:\n",
    "    print(e)\n",
    "\n",
    "try:\n",
    "    a_function(object(), None)\n",
    "except TypeError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cannot convert to the decimal number using a unit with different dimensions. (units are s and V).\n",
      "Function 'a_function' expected a Quantity object for argument 'v' but got '5'\n",
      "Function 'a_function' expected a Quantity object for argument 'v' but got '<object object at 0x0000022CE1BAF550>'\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Return Values\n",
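    "The `assign_units` decorator can also assign units to the return value of a function. Specify the unit with the `result` keyword; the plain value returned by the function is then converted into a quantity with that unit."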
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.472331600Z",
     "start_time": "2026-03-04T15:10:34.448265600Z"
    }
   },
   "source": [
    "@assign_units(result=second)\n",
    "def b_function():\n",
    "    \"\"\"\n",
    "    The return value will be assigned units of seconds.\n",
    "    \"\"\"\n",
    "    return 5"
   ],
   "outputs": [],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Value\n",
    "The plain number `5` returned by the function is converted into the quantity `5 * second`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.512803800Z",
     "start_time": "2026-03-04T15:10:34.474343300Z"
    }
   },
   "source": [
    "assert b_function() == 5 * second"
   ],
   "outputs": [],
   "execution_count": 6
  },
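  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Input and return-value specifications can be combined in a single decorator. In the sketch below, the hypothetical `speed` function (not part of `brainunit`) receives its distance and duration as plain numbers in meters and seconds, and its plain result is given units of meters per second. It assumes that a compound unit such as `u.meter / u.second` is accepted as the `result` specification.\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "\n",
    "@u.assign_units(distance=u.meter, duration=u.second, result=u.meter / u.second)\n",
    "def speed(distance, duration):\n",
    "    # Both arguments are plain numbers here, in meters and seconds\n",
    "    return distance / duration\n",
    "\n",
    "# 2000 ms is converted to 2 s before the function body runs\n",
    "velocity = speed(10 * u.meter, 2000 * u.ms)\n",
    "print(velocity.to_decimal(u.meter / u.second))\n",
    "```"
   ]
  },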
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Multiple Return Values\n",
    "If a function returns a tuple, pass a tuple of units as `result`; each element of the returned tuple receives the corresponding unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.584606100Z",
     "start_time": "2026-03-04T15:10:34.516023100Z"
    }
   },
   "source": [
    "@assign_units(result=(second, volt))\n",
    "def d_function():\n",
    "    \"\"\"\n",
    "    The return values will be assigned units of seconds and volts, respectively.\n",
    "    \"\"\"\n",
    "    return 5, 3"
   ],
   "outputs": [],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "The first element of the returned tuple becomes `5 * second`, and the second becomes `3 * volt`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.596965900Z",
     "start_time": "2026-03-04T15:10:34.586050900Z"
    }
   },
   "source": [
    "t, v = d_function()\n",
    "assert t == 5 * second\n",
    "assert v == 3 * volt"
   ],
   "outputs": [],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Assigning Units to Dictionary Return Values\n",
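    "The `result` specification can also be a nested structure, such as a dictionary whose values are units or tuples of units. The value returned by the function must then have exactly the same structure."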
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.739650600Z",
     "start_time": "2026-03-04T15:10:34.597982400Z"
    }
   },
   "source": [
    "@assign_units(result={'u': second, 'v': (volt, meter)})\n",
    "def d_function2(true_result):\n",
    "    \"\"\"\n",
    "    The return values will be assigned units based on the dictionary specification.\n",
    "    \"\"\"\n",
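    "    if true_result == 0:\n",
    "        # Matches the declared structure {'u': ..., 'v': (..., ...)}\n",
    "        return {'u': 5, 'v': (3, 2)}\n",
    "    else:\n",
    "        # Deliberately returns a tuple, which does not match the declared structure\n",
    "        return 3, 5"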
   ],
   "outputs": [],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Correct Return Values\n",
    "When the function returns a dictionary with the declared structure, each leaf receives its unit: `u` is given units of seconds, and the two elements of `v` are given units of volts and meters, respectively."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.207244700Z",
     "start_time": "2026-03-04T15:10:34.864228Z"
    }
   },
   "source": [
    "d_function2(0)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'u': Quantity(5, \"s\"), 'v': (Quantity(3, \"V\"), Quantity(2, \"m\"))}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Incorrect Return Values\n",
    "If the returned value does not match the declared structure (here, a tuple instead of a dictionary), a `TypeError` is raised."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.314564300Z",
     "start_time": "2026-03-04T15:10:35.229875600Z"
    }
   },
   "source": [
    "try:\n",
    "    d_function2(1)\n",
    "except TypeError as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected a return value of pytree PyTreeDef({'u': *, 'v': (*, *)}) with type {'u': Unit(\"s\"), 'v': (Unit(\"V\"), Unit(\"m\"))}, but got the pytree PyTreeDef((*, *)) and the value (3, 5)\n"
     ]
    }
   ],
   "execution_count": 11
  },
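  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same mechanism lets existing unitless numerical code be reused with array-valued quantities. The sketch below wraps a trapezoidal integral written with plain NumPy. `integrate_voltage` is a hypothetical example function, and the sketch assumes that `assign_units` accepts array-valued quantities and compound result units such as `u.volt * u.second`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "@u.assign_units(t=u.second, v=u.volt, result=u.volt * u.second)\n",
    "def integrate_voltage(t, v):\n",
    "    # t and v arrive as unitless arrays, in seconds and volts\n",
    "    t = np.asarray(t)\n",
    "    v = np.asarray(v)\n",
    "    # Trapezoidal rule written with plain NumPy operations\n",
    "    return float(np.sum((v[1:] + v[:-1]) * np.diff(t)) / 2)\n",
    "\n",
    "times = jnp.linspace(0.0, 100.0, 101) * u.ms   # 0 to 100 ms\n",
    "voltages = jnp.full(101, 20.0) * u.mV          # constant 20 mV\n",
    "area = integrate_voltage(times, voltages)\n",
    "# A constant 20 mV held for 0.1 s gives roughly 2 mV*s\n",
    "print(area.to_decimal(u.mV * u.second))\n",
    "```\n",
    "\n",
    "The integration routine itself never touches units; the decorator alone decides which units the inputs are expressed in and which unit the result carries."
   ]
  },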
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The examples above show how the `assign_units` decorator converts input arguments to unitless values and attaches units to return values. Because the decorated function only sees plain numbers, existing unitless code can be reused unchanged, while the declared units make the expected inputs and outputs explicit and reject arguments with incompatible dimensions at the function boundary."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainpy-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
