{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NumPy Functions\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/numpy_functions.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/numpy_functions.ipynb)\n",
    "\n",
    "`brainunit.math` provides 500+ unit-aware functions compatible with NumPy/JAX.\n",
    "They are categorized by how they handle units:\n",
    "\n",
    "1. **Array Creation** — create arrays with or without units\n",
    "2. **Functions Accepting Unitless** — require dimensionless input (trig, exp, log)\n",
    "3. **Functions Changing Unit** — output unit differs from input (multiply, sqrt, dot)\n",
    "4. **Functions Keeping Unit** — output unit matches input (sort, sum, reshape)\n",
    "5. **Functions Removing Unit** — return dimensionless results (comparisons, argmax)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:49.033177600Z",
     "start_time": "2026-03-04T15:10:48.014202800Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Array Creation\n",
    "\n",
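    "These mirror NumPy's array constructors. Attach a unit with the `unit=` keyword, or pass `Quantity` arguments (endpoints, fill values) and the unit is taken from them.\n",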
    "\n",
    "Includes: `array`, `asarray`, `zeros`, `ones`, `full`, `eye`, `arange`, `linspace`, `logspace`, `meshgrid`, and their `_like` variants."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.276236Z",
     "start_time": "2026-03-04T15:10:49.067938600Z"
    }
   },
   "source": [
    "# zeros and ones with unit\n",
    "print('zeros:', u.math.zeros(3, unit=u.volt))\n",
    "print('ones:', u.math.ones((2, 2), unit=u.meter))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zeros: [0. 0. 0.] V\n",
      "ones: [[1. 1.]\n",
      " [1. 1.]] m\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.433808600Z",
     "start_time": "2026-03-04T15:10:50.319709900Z"
    }
   },
   "source": [
    "# arange with Quantity endpoints\n",
    "times = u.math.arange(0 * u.ms, 10 * u.ms, step=2 * u.ms)\n",
    "print('arange:', times)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arange: [0 2 4 6 8] ms\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.680392100Z",
     "start_time": "2026-03-04T15:10:50.434799700Z"
    }
   },
   "source": [
    "# linspace between Quantity endpoints\n",
    "voltages = u.math.linspace(0 * u.mV, 100 * u.mV, 5)\n",
    "print('linspace:', voltages)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "linspace: [  0.  25.  50.  75. 100.] mV\n"
     ]
    }
   ],
   "execution_count": 4
  },
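  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Array-creation functions return a `Quantity`, which pairs a numeric mantissa with a unit. The sketch below assumes the `Quantity.mantissa` and `Quantity.unit` attributes available in recent brainunit releases. It inspects the result of `linspace` and uses `to_decimal` to read the values in another compatible unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "# linspace infers the unit (ms) from its Quantity endpoints\n",
    "t = u.math.linspace(0 * u.ms, 1 * u.ms, 5)\n",
    "\n",
    "print('unit:', t.unit)          # the unit attached to the array\n",
    "print('mantissa:', t.mantissa)  # raw numbers, expressed in that unit\n",
    "print('in seconds:', t.to_decimal(u.second))  # plain numbers, converted to seconds"
   ],
   "outputs": [],
   "execution_count": null
  },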
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:51.217818600Z",
     "start_time": "2026-03-04T15:10:50.686777300Z"
    }
   },
   "source": [
    "# eye with unit\n",
    "print('eye:')\n",
    "print(u.math.eye(3, unit=u.ohm))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "eye:\n",
      "[[1. 0. 0.]\n",
      " [0. 1. 0.]\n",
      " [0. 0. 1.]] ohm\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:51.457004900Z",
     "start_time": "2026-03-04T15:10:51.295935200Z"
    }
   },
   "source": [
    "# full_like: create an array shaped like `template`, filled with a Quantity value\n",
    "template = jnp.array([1., 2., 3.]) * u.newton\n",
    "print('full_like:', u.math.full_like(template, 99.0 * u.newton))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full_like: [99. 99. 99.] N\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Functions Accepting Unitless\n",
    "\n",
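    "These functions require dimensionless inputs. Exponentials, logarithms and trigonometric functions are only defined for pure numbers: for example, the series exp(x) = 1 + x + x²/2 + … would add terms with different units if x carried one. Passing a `Quantity` with a physical unit raises an error, so convert it to a dimensionless value first.\n",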
    "\n",
    "Includes: `exp`, `log`, `sin`, `cos`, `tan`, `arcsin`, `arctan2`, `sinh`, `cosh`, `deg2rad`, `rad2deg`, `logaddexp`, and more."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:52.362309900Z",
     "start_time": "2026-03-04T15:10:51.532137900Z"
    }
   },
   "source": [
    "# Trigonometric functions on dimensionless values\n",
    "angles = jnp.array([0., jnp.pi/6, jnp.pi/4, jnp.pi/3, jnp.pi/2])\n",
    "print('sin:', u.math.sin(angles))\n",
    "print('cos:', u.math.cos(angles))  # cos(pi/2) is ~0 up to float32 rounding"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sin: [0.         0.5        0.70710677 0.86602545 1.        ]\n",
      "cos: [ 1.0000000e+00  8.6602539e-01  7.0710677e-01  4.9999997e-01\n",
      " -4.3711388e-08]\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.029833500Z",
     "start_time": "2026-03-04T15:10:52.740224500Z"
    }
   },
   "source": [
    "# Exponential and logarithm\n",
    "x = jnp.array([0., 1., 2.])\n",
    "print('exp:', u.math.exp(x))\n",
    "print('log:', u.math.log(u.math.exp(x)))  # round-trip"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "exp: [1.        2.7182817 7.389056 ]\n",
      "log: [0. 1. 2.]\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.206070500Z",
     "start_time": "2026-03-04T15:10:53.091651500Z"
    }
   },
   "source": [
    "# Passing a Quantity with a physical unit raises a TypeError\n",
    "try:\n",
    "    u.math.exp(2.0 * u.meter)\n",
    "except TypeError as e:\n",
    "    print(type(e).__name__, ':', e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TypeError : exp requires a dimensionless \"x\" when \"unit_to_scale\" is not provided. Got Quantity(unit=m, dim=m). Pass \"unit_to_scale=<Unit>\" to scale before applying exp, or convert explicitly to a dimensionless value first.\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.265774500Z",
     "start_time": "2026-03-04T15:10:53.209671300Z"
    }
   },
   "source": [
    "# Convert to dimensionless first using to_decimal\n",
    "ratio = (5.0 * u.mV).to_decimal(u.volt)  # 0.005, dimensionless\n",
    "print('exp(ratio):', u.math.exp(ratio))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "exp(ratio): 1.0050125\n"
     ]
    }
   ],
   "execution_count": 10
  },
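  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The error message shown earlier also suggests an alternative: passing `unit_to_scale` lets the function perform the conversion itself. The sketch below assumes `unit_to_scale` behaves as that message describes, expressing the input in the given unit before applying `exp`. If that holds, both calls should agree."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "v = 5.0 * u.mV\n",
    "\n",
    "# Option 1: convert explicitly to a plain number in volts\n",
    "a = u.math.exp(v.to_decimal(u.volt))\n",
    "\n",
    "# Option 2: let exp scale the input by the given unit first\n",
    "b = u.math.exp(v, unit_to_scale=u.volt)\n",
    "\n",
    "print('explicit:', a)\n",
    "print('unit_to_scale:', b)"
   ],
   "outputs": [],
   "execution_count": null
  },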
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.347675300Z",
     "start_time": "2026-03-04T15:10:53.266773200Z"
    }
   },
   "source": [
    "# Angle conversions\n",
    "print('deg2rad(180):', u.math.deg2rad(180.0))\n",
    "print('rad2deg(pi):', u.math.rad2deg(jnp.pi))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "deg2rad(180): 3.1415927\n",
      "rad2deg(pi): 180.0\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.422144Z",
     "start_time": "2026-03-04T15:10:53.348672400Z"
    }
   },
   "source": [
    "# arctan2 for 2D angle computation\n",
    "y = jnp.array([1., 0., -1.])\n",
    "x = jnp.array([0., 1., 0.])\n",
    "print('arctan2:', u.math.arctan2(y, x))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arctan2: [ 1.5707964  0.        -1.5707964]\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Functions Changing Unit\n",
    "\n",
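    "These functions produce outputs whose unit differs from that of their inputs.\n",
    "The output unit is derived from the input units by the operation itself: multiplication multiplies units, division divides them, and `sqrt` halves every unit exponent (m² becomes m).\n",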
    "\n",
    "Includes: `multiply`, `divide`, `power`, `sqrt`, `square`, `reciprocal`, `prod`, `dot`, `matmul`, `inner`, `outer`, `kron`, `cross`, `convolve`, and more."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.761693200Z",
     "start_time": "2026-03-04T15:10:53.427718600Z"
    }
   },
   "source": [
    "# multiply: units multiply\n",
    "force = jnp.array([10., 20.]) * u.newton\n",
    "distance = jnp.array([5., 3.]) * u.meter\n",
    "work = u.math.multiply(force, distance)\n",
    "print('F * d (work):', work)  # N * m = J"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F * d (work): [50. 60.] J\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:54.165954200Z",
     "start_time": "2026-03-04T15:10:53.805569900Z"
    }
   },
   "source": [
    "# divide: units divide\n",
    "print('d / t (speed):', u.math.divide(100.0 * u.meter, 10.0 * u.second))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "d / t (speed): 10. m / s\n"
     ]
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:54.517223300Z",
     "start_time": "2026-03-04T15:10:54.265653700Z"
    }
   },
   "source": [
    "# sqrt and square\n",
    "area = jnp.array([4., 9., 16.]) * u.meter2\n",
    "print('sqrt(area):', u.math.sqrt(area))  # m^2 -> m\n",
    "\n",
    "side = jnp.array([2., 3., 4.]) * u.meter\n",
    "print('square(side):', u.math.square(side))  # m -> m^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sqrt(area): [2. 3. 4.] m\n",
      "square(side): [ 4.  9. 16.] m^2\n"
     ]
    }
   ],
   "execution_count": 15
  },
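  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`square` is the special case of `power` with exponent 2. The sketch below assumes a plain integer exponent and cubes a length to get a volume, then reads the values back in cubic meters."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "side = jnp.array([1., 2., 3.]) * u.meter\n",
    "\n",
    "# power raises both the values and the unit to the exponent: m -> m^3\n",
    "volume = u.math.power(side, 3)\n",
    "print('power(side, 3):', volume)\n",
    "\n",
    "# Read the values back as plain numbers in cubic meters\n",
    "print('in m^3:', volume.to_decimal(u.meter ** 3))"
   ],
   "outputs": [],
   "execution_count": null
  },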
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:55.291199200Z",
     "start_time": "2026-03-04T15:10:54.527034700Z"
    }
   },
   "source": [
    "# reciprocal\n",
    "resistance = jnp.array([10., 50.]) * u.ohm\n",
    "print('1/R (conductance):', u.math.reciprocal(resistance))  # 1/ohm = S"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/R (conductance): [0.1  0.02] S\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:55.375519Z",
     "start_time": "2026-03-04T15:10:55.326536300Z"
    }
   },
   "source": [
    "# dot product\n",
    "a = jnp.array([1., 2., 3.]) * u.meter\n",
    "b = jnp.array([4., 5., 6.]) * u.meter\n",
    "print('dot(a, b):', u.math.dot(a, b))  # m^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dot(a, b): 32. m^2\n"
     ]
    }
   ],
   "execution_count": 17
  },
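  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`matmul` follows the same rule as `dot`: the output unit is the product of the input units. In the sketch below, which uses illustrative numbers not taken from any real circuit, a conductance matrix times a voltage vector gives currents (S × V = A), in the spirit of Ohm's law I = G V."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Illustrative 2x2 conductance matrix and voltage vector\n",
    "G = jnp.array([[2., 0.],\n",
    "               [1., 3.]]) * u.siemens\n",
    "V = jnp.array([1., 2.]) * u.volt\n",
    "\n",
    "# matmul multiplies the units: S * V = A\n",
    "I = u.math.matmul(G, V)\n",
    "print('G @ V:', I)\n",
    "print('in amperes:', I.to_decimal(u.ampere))"
   ],
   "outputs": [],
   "execution_count": null
  },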
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:55.468118900Z",
     "start_time": "2026-03-04T15:10:55.375519Z"
    }
   },
   "source": [
    "# outer product\n",
    "x = jnp.array([1., 2.]) * u.volt\n",
    "y = jnp.array([3., 4., 5.]) * u.ampere\n",
    "print('outer(V, A):')\n",
    "print(u.math.outer(x, y))  # V * A = W"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "outer(V, A):\n",
      "[[ 3.  4.  5.]\n",
      " [ 6.  8. 10.]] W\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:55.998245Z",
     "start_time": "2026-03-04T15:10:55.468118900Z"
    }
   },
   "source": [
    "# prod: the unit is raised to the power of the number of elements\n",
    "vals = jnp.array([2., 3., 4.]) * u.meter\n",
    "print('prod:', u.math.prod(vals))  # 24 m^3"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prod: 24. m^3\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Functions Keeping Unit\n",
    "\n",
    "These functions preserve the unit of the input in the output.\n",
    "\n",
    "Includes: `sum`, `mean`, `std`, `min`, `max`, `sort`, `cumsum`, `reshape`, `transpose`, `concatenate`, `stack`, `split`, `flip`, `roll`, `clip`, `abs`, `negative`, `diff`, `interp`, `where`, `unique`, and many more."
   ]
  },
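  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Functions that keep the unit and take several quantities must reconcile those quantities into a single unit. Our understanding of brainunit's behavior is that values with the same dimension but different scales (V and mV) are converted, while mismatched dimensions raise `DimensionMismatchError`. The following sketch illustrates both cases with `u.math.add`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "# Same dimension, different scales: the mV value is converted before adding\n",
    "total = u.math.add(1.0 * u.volt, 500.0 * u.mV)\n",
    "print('1 V + 500 mV =', total)\n",
    "\n",
    "# Different dimensions (length vs. time) cannot be added\n",
    "mismatch_raised = False\n",
    "try:\n",
    "    u.math.add(1.0 * u.meter, 1.0 * u.second)\n",
    "except u.DimensionMismatchError as e:\n",
    "    mismatch_raised = True\n",
    "    print('DimensionMismatchError:', e)"
   ],
   "outputs": [],
   "execution_count": null
  },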
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.368386800Z",
     "start_time": "2026-03-04T15:10:56.049508500Z"
    }
   },
   "source": [
    "# Statistical functions\n",
    "data = jnp.array([1., 3., 5., 7., 9.]) * u.volt\n",
    "print('sum:', u.math.sum(data))\n",
    "print('mean:', u.math.mean(data))\n",
    "print('std:', u.math.std(data))\n",
    "print('min:', u.math.min(data))\n",
    "print('max:', u.math.max(data))\n",
    "print('median:', u.math.median(data))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum: 25. V\n",
      "mean: 5. V\n",
      "std: 2.828427 V\n",
      "min: 1. V\n",
      "max: 9. V\n",
      "median: 5. V\n"
     ]
    }
   ],
   "execution_count": 20
  },
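  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not every statistic keeps the unit. `std` returns volts, but the variance is the mean squared deviation, so it carries squared units. The sketch below assumes `u.math.var` behaves like `numpy.var` (population variance, `ddof=0`) and returns a result in V²."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "data = jnp.array([1., 3., 5., 7., 9.]) * u.volt\n",
    "\n",
    "# Variance squares the deviations, so the unit becomes V^2\n",
    "variance = u.math.var(data)\n",
    "print('var:', variance)\n",
    "\n",
    "# std is the square root of var, which brings the unit back to V\n",
    "print('sqrt(var):', u.math.sqrt(variance))"
   ],
   "outputs": [],
   "execution_count": null
  },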
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.557535100Z",
     "start_time": "2026-03-04T15:10:56.377386800Z"
    }
   },
   "source": [
    "# Shape manipulation\n",
    "M = jnp.arange(6.).reshape(2, 3) * u.ampere\n",
    "print('original:', M)\n",
    "print('reshape (3,2):', u.math.reshape(M, (3, 2)))\n",
    "print('transpose:', u.math.transpose(M))\n",
    "print('flatten:', u.math.flatten(M))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "original: [[0. 1. 2.]\n",
      " [3. 4. 5.]] A\n",
      "reshape (3,2): [[0. 1.]\n",
      " [2. 3.]\n",
      " [4. 5.]] A\n",
      "transpose: [[0. 3.]\n",
      " [1. 4.]\n",
      " [2. 5.]] A\n",
      "flatten: [0. 1. 2. 3. 4. 5.] A\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.675282400Z",
     "start_time": "2026-03-04T15:10:56.560542200Z"
    }
   },
   "source": [
    "# Concatenation and stacking\n",
    "a = jnp.array([1., 2.]) * u.meter\n",
    "b = jnp.array([3., 4.]) * u.meter\n",
    "print('concatenate:', u.math.concatenate([a, b]))\n",
    "print('stack:', u.math.stack([a, b]))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "concatenate: [1. 2. 3. 4.] m\n",
      "stack: [[1. 2.]\n",
      " [3. 4.]] m\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.767392600Z",
     "start_time": "2026-03-04T15:10:56.675282400Z"
    }
   },
   "source": [
    "# Sorting and cumulative sums\n",
    "unsorted = jnp.array([3., 1., 4., 1., 5.]) * u.pascal\n",
    "print('sort:', u.math.sort(unsorted))\n",
    "print('cumsum:', u.math.cumsum(unsorted))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sort: [1. 1. 3. 4. 5.] Pa\n",
      "cumsum: [ 3.  4.  8.  9. 14.] Pa\n"
     ]
    }
   ],
   "execution_count": 23
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.881222400Z",
     "start_time": "2026-03-04T15:10:56.767392600Z"
    }
   },
   "source": [
    "# abs and negative\n",
    "mixed = jnp.array([-2., 3., -1., 4.]) * u.newton\n",
    "print('abs:', u.math.abs(mixed))\n",
    "print('negative:', u.math.negative(mixed))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "abs: [2. 3. 1. 4.] N\n",
      "negative: [ 2. -3.  1. -4.] N\n"
     ]
    }
   ],
   "execution_count": 24
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.930902400Z",
     "start_time": "2026-03-04T15:10:56.881853700Z"
    }
   },
   "source": [
    "# clip (clamp values to a range)\n",
    "x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt\n",
    "print('clip [1V, 3V]:', u.math.clip(x, 1.0 * u.volt, 3.0 * u.volt))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clip [1V, 3V]: [1.  1.5 2.5 3.  3. ] V\n"
     ]
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:56.976611600Z",
     "start_time": "2026-03-04T15:10:56.932362700Z"
    }
   },
   "source": [
    "# diff: discrete difference between consecutive elements (same unit as input)\n",
    "t = jnp.array([0., 1., 4., 9., 16.]) * u.second\n",
    "print('diff:', u.math.diff(t))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "diff: [1. 3. 5. 7.] s\n"
     ]
    }
   ],
   "execution_count": 26
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:57.033508100Z",
     "start_time": "2026-03-04T15:10:56.977834500Z"
    }
   },
   "source": [
    "# where: conditional selection\n",
    "cond = jnp.array([True, False, True, False])\n",
    "a = jnp.array([1., 2., 3., 4.]) * u.meter\n",
    "b = jnp.array([10., 20., 30., 40.]) * u.meter\n",
    "print('where:', u.math.where(cond, a, b))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "where: [ 1. 20.  3. 40.] m\n"
     ]
    }
   ],
   "execution_count": 27
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Functions Removing Unit\n",
    "\n",
    "These functions return dimensionless results (booleans, indices, etc.).\n",
    "\n",
    "Includes: `equal`, `greater`, `less`, `isclose`, `allclose`, `argmax`, `argmin`, `argsort`, `nonzero`, `sign`, `all`, `any`, `count_nonzero`, and more."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Comparisons return boolean arrays (dimensionless)\n",
    "a = jnp.array([1., 2., 3.]) * u.volt\n",
    "b = jnp.array([2., 2., 2.]) * u.volt\n",
    "print('equal:', u.math.equal(a, b))\n",
    "print('greater:', u.math.greater(a, b))\n",
    "print('less:', u.math.less(a, b))"
   ],
   "outputs": [],
   "execution_count": null
  },
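  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparisons drop the unit, but their boolean result can still serve as a mask that selects elements from the original `Quantity`, and the selection keeps its unit. The sketch below assumes eager (non-`jit`) execution, because boolean-mask indexing produces a data-dependent shape that JAX cannot trace."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "a = jnp.array([1., 2., 3.]) * u.volt\n",
    "b = jnp.array([2., 2., 2.]) * u.volt\n",
    "\n",
    "# The mask itself is a plain boolean array\n",
    "mask = u.math.greater(a, b)\n",
    "\n",
    "# Indexing the Quantity with the mask keeps the unit\n",
    "selected = a[mask]\n",
    "print('mask:', mask)\n",
    "print('a[a > b]:', selected)"
   ],
   "outputs": [],
   "execution_count": null
  },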
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# isclose and allclose for approximate comparison\n",
    "x = jnp.array([1.0, 2.0, 3.0]) * u.meter\n",
    "y = jnp.array([1.0, 2.00001, 3.0]) * u.meter\n",
    "print('isclose:', u.math.isclose(x, y))\n",
    "print('allclose:', u.math.allclose(x, y))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Index-finding functions\n",
    "data = jnp.array([10., 5., 30., 15., 25.]) * u.watt\n",
    "print('argmax:', u.math.argmax(data))  # index of max\n",
    "print('argmin:', u.math.argmin(data))  # index of min\n",
    "print('argsort:', u.math.argsort(data))  # indices that would sort"
   ],
   "outputs": [],
   "execution_count": null
  },
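  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The indices returned by `argmax` and `argsort` are unitless, but they index the original `Quantity` without losing its unit. The minimal sketch below uses the same power readings as the cell above."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "data = jnp.array([10., 5., 30., 15., 25.]) * u.watt\n",
    "\n",
    "# argmax gives a plain integer index; indexing with it keeps the unit\n",
    "peak = data[u.math.argmax(data)]\n",
    "print('peak:', peak)\n",
    "\n",
    "# Reordering with argsort gives the same result as u.math.sort\n",
    "ordered = data[u.math.argsort(data)]\n",
    "print('ordered:', ordered)"
   ],
   "outputs": [],
   "execution_count": null
  },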
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# sign and logical operations\n",
    "mixed = jnp.array([-3., 0., 5., -1., 2.]) * u.ampere\n",
    "print('sign:', u.math.sign(mixed))\n",
    "print('any > 0:', u.math.any(u.math.greater(mixed, 0 * u.ampere)))\n",
    "print('all > 0:', u.math.all(u.math.greater(mixed, 0 * u.ampere)))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# count_nonzero\n",
    "sparse_data = jnp.array([0., 1., 0., 0., 3., 0., 2.]) * u.volt\n",
    "print('count_nonzero:', u.math.count_nonzero(sparse_data))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quick Reference: Unit Behavior by Category\n",
    "\n",
    "| Category | Example Functions | Unit Behavior |\n",
    "|----------|------------------|---------------|\n",
    "| **Array Creation** | `zeros`, `ones`, `arange`, `linspace` | Set via `unit=` or taken from `Quantity` arguments |\n",
    "| **Accept Unitless** | `sin`, `cos`, `exp`, `log` | Input must be dimensionless |\n",
    "| **Change Unit** | `multiply`, `sqrt`, `dot`, `outer` | Output unit derived from operation |\n",
    "| **Keep Unit** | `sum`, `mean`, `sort`, `reshape`, `clip` | Same unit as input |\n",
    "| **Remove Unit** | `equal`, `argmax`, `sign`, `allclose` | Dimensionless output |\n",
    "\n",
    "For the complete function listing, see the [API documentation](https://brainunit.readthedocs.io/apis/brainunit.math.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
