{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Math Operations\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/physical_units/math_operations_with_quantity.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/physical_units/math_operations_with_quantity.ipynb)"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As in NumPy and JAX NumPy, arithmetic operators on `Quantity` arrays apply elementwise."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "a = [20, 30, 40, 50] * u.mV\n",
    "b = jnp.arange(4) * u.mV\n",
    "b"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Addition and Subtraction"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Quantities can only be added or subtracted if they have the same physical dimension (here, both are voltages in mV). The result keeps that unit."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "c = a - b\n",
    "c"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "c + b",
   "outputs": [],
   "execution_count": null
  },
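  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Operands with different dimensions cannot be added or subtracted. As a small illustration with arbitrary values, adding a time to a voltage raises an error. As elsewhere in this notebook, a generic `Exception` is caught so the message can be printed."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "v = jnp.array([1.0, 2.0]) * u.mV\n",
    "t = jnp.array([1.0, 2.0]) * u.ms\n",
    "try:\n",
    "    v + t  # voltage + time: the dimensions differ\n",
    "except Exception as e:\n",
    "    print(type(e).__name__, e)"
   ],
   "outputs": [],
   "execution_count": null
  },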
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Multiplication and Division"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
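   "source": "Multiplication and division operate on the numerical values and combine the units. The dimension exponents are added for multiplication and subtracted for division. For example, `mV * mV` gives `mV^2`, while `mV / mV` is dimensionless."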
  },
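  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Here is a quick check of this unit algebra using arbitrary scalar values. Dividing a voltage by a time gives a quantity with the dimension of volt per second, and multiplying two voltages gives the dimension of volt squared. The comparisons rely on the `.dim` attribute of a Quantity or Unit."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "rate = (10.0 * u.mV) / (2.0 * u.ms)  # voltage / time\n",
    "print(rate)\n",
    "print(rate.dim == (u.volt / u.second).dim)  # expected: True\n",
    "\n",
    "sq = (3.0 * u.mV) * (2.0 * u.mV)  # voltage * voltage\n",
    "print(sq.dim == (u.volt ** 2).dim)  # expected: True"
   ],
   "outputs": [],
   "execution_count": null
  },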
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "A = jnp.array([[1, 2], [3, 4]]) * u.mV\n",
    "B = jnp.array([[5, 6], [7, 8]]) * u.mV\n",
    "\n",
    "A, B"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A * B # element-wise multiplication",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A @ B # matrix multiplication",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A.dot(B) # matrix multiplication",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A / 2 # divide by a scalar",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "If the result is dimensionless, the unit is removed and a plain `jax.Array` is returned."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A / (2 * u.mV) # same dimension: the result is dimensionless, so a jax.Array is returned",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A / (2 * u.mA) # different dimensions: a Quantity (voltage / current) is returned",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Power"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
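   "source": "The power operator raises the value of the quantity to the given scalar power and raises the unit to the same power. In other words, the dimension exponents are multiplied by the scalar (for example, `mV ** 2` gives `mV^2`)."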
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "A ** 2 # element-wise power",
   "outputs": [],
   "execution_count": null
  },
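  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Taking a root undoes a power on the unit as well. Below is a minimal sketch that uses `u.math.sqrt` from brainunit's math module on an arbitrary scalar."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "x = 3.0 * u.mV\n",
    "x2 = x ** 2  # value 9.0, unit mV^2\n",
    "print(x2)\n",
    "print(x2.dim == (u.mV ** 2).dim)  # expected: True\n",
    "print(u.math.sqrt(x2))  # value 3.0, unit back to mV"
   ],
   "outputs": [],
   "execution_count": null
  },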
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Built-in Functions"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The `Quantity` class provides a number of built-in methods for operating on quantities:\n",
    "- unary operations\n",
    "    - positive (`+`)\n",
    "    - negative (`-`)\n",
    "    - absolute (`abs`)\n",
    "    - invert (`~`)\n",
    "- logical operations\n",
    "    - all\n",
    "    - any\n",
    "- shape operations\n",
    "    - reshape\n",
    "    - resize\n",
    "    - squeeze\n",
    "    - unsqueeze\n",
    "    - split\n",
    "    - swapaxes\n",
    "    - transpose\n",
    "    - ravel\n",
    "    - take\n",
    "    - repeat\n",
    "    - diagonal\n",
    "    - trace\n",
    "- mathematical functions\n",
    "    - nonzero\n",
    "    - argmax\n",
    "    - argmin\n",
    "    - argsort\n",
    "    - var\n",
    "    - round\n",
    "    - std\n",
    "    - sum\n",
    "    - cumsum\n",
    "    - cumprod\n",
    "    - max\n",
    "    - mean\n",
    "    - min\n",
    "    - ptp\n",
    "    - clip\n",
    "    - conj\n",
    "    - dot\n",
    "    - fill\n",
    "    - item\n",
    "    - prod\n",
    "    - clamp\n",
    "    - sort\n",
    "\n",
    "For more details on these functions, refer to the [documentation](https://brainunit.readthedocs.io/apis/generated/brainunit.Quantity.html)."
   ]
  },
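  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Here are a few of these methods in action. As a rough guide based on the usual unit rules (not an exhaustive specification): reductions such as `sum`, `mean` and `max` keep the unit, index-returning methods such as `argmax` return plain integer arrays, and `var` returns the squared unit."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "x = jnp.array([[1.0, 2.0], [3.0, 4.0]]) * u.mV\n",
    "print(x.sum())         # total, in mV\n",
    "print(x.mean(axis=0))  # column means, in mV\n",
    "print(x.max())         # largest element, in mV\n",
    "print(x.argmax())      # flat index of the largest element (no unit)\n",
    "print(x.var())         # variance, in mV^2\n",
    "print(x.reshape(4))    # same values, new shape\n",
    "print(x.transpose())   # rows and columns swapped"
   ],
   "outputs": [],
   "execution_count": null
  },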
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Indexing, Slicing and Iterating"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "One-dimensional Quantities can be indexed, sliced and iterated over, much like lists and other Python sequences."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "a = jnp.arange(10) ** 3 * u.mV\n",
    "a"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "a[2]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "a[2:5]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Only a Quantity with the same dimension can be assigned to a slice of a Quantity."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# equivalent to a[0:6:2] = 1000 * u.mV:\n",
    "# from the start to position 6 (exclusive), set every 2nd element to 1000 mV\n",
    "a[:6:2] = 1000 * u.mV\n",
    "a"
   ],
   "outputs": [],
   "execution_count": null
  },
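  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Assigning a value with a different dimension is rejected. This small sketch uses a separate array so that `a` is left untouched."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "w = jnp.zeros(3) * u.mV\n",
    "try:\n",
    "    w[0] = 5 * u.ms  # a time cannot be stored in a voltage array\n",
    "except Exception as e:\n",
    "    print(type(e).__name__, e)"
   ],
   "outputs": [],
   "execution_count": null
  },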
  {
   "metadata": {},
   "cell_type": "code",
   "source": "a[::-1] # reversed a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "for i in a:\n",
    "    print(i**(1 / 3.))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "A multidimensional Quantity can have one index per axis. These indices are given in a tuple separated by commas:"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def f(x, y):\n",
    "    return 10 * x + y\n",
    "b = jnp.fromfunction(f, (5, 4), dtype=jnp.int32) * u.mV\n",
    "b"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "b[2, 3]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "b[0:5, 1]  # each row in the second column of b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "b[:, 1]  # equivalent to the previous example",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "b[1:3, :]  # each column in the second and third row of b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "When fewer indices are provided than the number of axes, the missing indices are considered complete slices:"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "b[-1]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The expression within brackets in `b[i]` is treated as `i` followed by as many instances of `:` as needed to represent the remaining axes. As in NumPy, this can also be written with dots as `b[i, ...]`.\n",
    "\n",
    "The dots (`...`) represent as many colons as needed to produce a complete indexing tuple. For example, if `x` is a Quantity with 5 axes, then\n",
    "- `x[1, 2, ...]` is equivalent to `x[1, 2, :, :, :]`,\n",
    "- `x[..., 3]` to `x[:, :, :, :, 3]` and\n",
    "- `x[4, ..., 5, :]` to `x[4, :, :, 5, :]`."
   ]
  },
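  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "These equivalences can be checked on shapes alone. This short sketch uses an arbitrary 5-axis Quantity of zeros. The first index in the last comparison is changed from 4 to 1 so that it stays within bounds for this shape."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "x = jnp.zeros((2, 3, 4, 5, 6)) * u.mV\n",
    "print(x[1, 2, ...].shape == x[1, 2, :, :, :].shape)    # True, shape (4, 5, 6)\n",
    "print(x[..., 3].shape == x[:, :, :, :, 3].shape)        # True, shape (2, 3, 4, 5)\n",
    "print(x[1, ..., 2, :].shape == x[1, :, :, 2, :].shape)  # True, shape (3, 4, 6)"
   ],
   "outputs": [],
   "execution_count": null
  },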
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "c = jnp.array([[[0, 1, 2], [10, 12, 13]], [[100, 101, 102], [110, 112, 113]]]) * u.mV # a 3D array (two stacked 2D arrays)\n",
    "c.shape"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "c[1, ...] # same as c[1, :, :] or c[1]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "c[..., 2] # same as c[:, :, 2]",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Iterating over a multidimensional Quantity is done with respect to the first axis:"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "for row in b:\n",
    "    print(row)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Operating on Subsets\n",
    "\n",
    "The `.at` property, modeled on JAX's `jax.Array.at`, can be used to read or update a subset of a Quantity. The following examples operate on subsets of a Quantity:"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "q = jnp.arange(5.0) * u.mV\n",
    "q"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[2].add(10 * u.mV)",
   "outputs": [],
   "execution_count": null
  },
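  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As with `jax.Array.at`, these calls return a new Quantity rather than modifying the original in place. Here is a minimal check on a fresh array, assuming the `set` update method that mirrors JAX's `.at[...].set`."
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "p = jnp.arange(3.0) * u.mV\n",
    "p2 = p.at[0].set(5 * u.mV)  # returns an updated copy\n",
    "print(p)   # original values 0, 1, 2 mV\n",
    "print(p2)  # first element replaced by 5 mV"
   ],
   "outputs": [],
   "execution_count": null
  },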
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[10].add(10 * u.mV)  # out-of-bounds indices are ignored",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[20].add(10 * u.mV, mode='clip') # out-of-bounds indices are clamped to the last valid index",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[2].get()",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[20].get()  # out-of-bounds indices clipped",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "brainunit checks the units involved in these operations and raises an error on a dimension mismatch, for example when adding a unitless number to a voltage:"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "try:\n",
    "    q.at[2].add(10)\n",
    "except Exception as e:\n",
    "    print(e)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
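   "source": "brainunit also allows a custom `fill_value` for `.at[...].get(mode='fill')`. The fill value should have the same dimension as the Quantity, so a unitless value is rejected:"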
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "q.at[20].get(mode='fill', fill_value=-1 * u.mV)  # custom fill value",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "try:\n",
    "    q.at[20].get(mode='fill', fill_value=-1)\n",
    "except Exception as e:\n",
    "    print(e)"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "brainpy-dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
