{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Array Creation\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/array_creation.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/array_creation.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "The functions listed below create arrays or `Quantity` objects with specific properties, such as arrays filled with a given value, identity matrices, or matrices with ones on the diagonal. They belong to the `brainunit.math` module and handle both plain numerical arrays and `Quantity` objects that carry units."
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:31.388017800Z",
     "start_time": "2026-03-04T15:10:30.451361100Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.array` & `brainunit.math.asarray`\n",
    "\n",
    "Convert the input to a quantity or array.\n",
    "\n",
    "- If `unit` is provided, the input is checked to have that unit. If the input has the same dimension but a different scale (for example, milliseconds instead of seconds), it is converted to the provided unit.\n",
    "- If `unit` is not provided, the input is converted to an array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:31.992977100Z",
     "start_time": "2026-03-04T15:10:31.406532700Z"
    }
   },
   "source": [
    "u.math.asarray([1, 2, 3])                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1, 2, 3], dtype=int32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.261432Z",
     "start_time": "2026-03-04T15:10:32.090526600Z"
    }
   },
   "source": [
    "u.math.asarray([1, 2, 3], unit=u.second)    # with plain numbers, the recorded output is still a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1, 2, 3], dtype=int32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.366911800Z",
     "start_time": "2026-03-04T15:10:32.263434200Z"
    }
   },
   "source": [
    "# passes the check: the input already has the provided unit\n",
    "u.math.asarray([1 * u.second, 2 * u.second], unit=u.second)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([1 2], \"s\")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.437720200Z",
     "start_time": "2026-03-04T15:10:32.370083Z"
    }
   },
   "source": [
    "# fails: seconds and amperes have different dimensions, so no conversion is possible\n",
    "try:\n",
    "    u.math.asarray([1 * u.second, 2 * u.second], unit=u.ampere)\n",
    "except Exception as e:\n",
    "    print(e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cannot convert to a unit with different dimensions. (units are s and A).\n"
     ]
    }
   ],
   "execution_count": 5
  },
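  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The description above says that an input with the same dimension but a different scale is converted to the provided unit. The sketch below illustrates that case; it assumes `u.ms` (milliseconds) is exported by `brainunit` and that `Quantity.to_decimal(unit)` returns the numeric values expressed in `unit`. The exact printed form may vary between versions.\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "\n",
    "# seconds in, milliseconds requested: same dimension (time), different scale\n",
    "t = u.math.asarray([1 * u.second, 2 * u.second], unit=u.ms)\n",
    "\n",
    "# expressed back in seconds, the values should be unchanged: 1 and 2\n",
    "print(t.to_decimal(u.second))\n",
    "```"
   ]
  },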
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.arange`\n",
    "Return evenly spaced values within a given interval."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.515230900Z",
     "start_time": "2026-03-04T15:10:32.463641300Z"
    }
   },
   "source": [
    "u.math.arange(5)                                    # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 1, 2, 3, 4], dtype=int32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.582604Z",
     "start_time": "2026-03-04T15:10:32.516238100Z"
    }
   },
   "source": [
    "u.math.arange(5 * u.second, step=1 * u.second) # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([0 1 2 3 4], \"s\")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.615306100Z",
     "start_time": "2026-03-04T15:10:32.583708900Z"
    }
   },
   "source": [
    "u.math.arange(3, 9, 1)                                          # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([3, 4, 5, 6, 7, 8], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.683835600Z",
     "start_time": "2026-03-04T15:10:32.615306100Z"
    }
   },
   "source": [
    "u.math.arange(3 * u.second, 9 * u.second, 1 * u.second)   # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([3 4 5 6 7 8], \"s\")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.array_split`\n",
    "Split an array into multiple sub-arrays."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.773784800Z",
     "start_time": "2026-03-04T15:10:32.686941100Z"
    }
   },
   "source": [
    "a = jnp.arange(9)\n",
    "\n",
    "u.math.array_split(a, 3)      # return a list of jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Array([0, 1, 2], dtype=int32),\n",
       " Array([3, 4, 5], dtype=int32),\n",
       " Array([6, 7, 8], dtype=int32)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:32.824553800Z",
     "start_time": "2026-03-04T15:10:32.774820Z"
    }
   },
   "source": [
    "q = jnp.arange(9) * u.second\n",
    "\n",
    "u.math.array_split(q, 3)   # return a list of Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Quantity([0 1 2], \"s\"), Quantity([3 4 5], \"s\"), Quantity([6 7 8], \"s\")]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 11
  },
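  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike a plain `split`, `array_split` also accepts a number of sections that does not divide the axis evenly. Assuming `u.math.array_split` follows the NumPy rule that the first `len % n` pieces receive one extra element, a length-10 array split into 3 pieces looks like this:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "parts = u.math.array_split(jnp.arange(10), 3)   # 10 is not divisible by 3\n",
    "\n",
    "# the first 10 % 3 == 1 piece gets one extra element\n",
    "print([int(p.shape[0]) for p in parts])         # [4, 3, 3]\n",
    "```"
   ]
  },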
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.linspace`\n",
    "Return evenly spaced numbers over a specified interval."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.062212900Z",
     "start_time": "2026-03-04T15:10:32.825625100Z"
    }
   },
   "source": [
    "u.math.linspace(0, 10, 5)                               # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.179765100Z",
     "start_time": "2026-03-04T15:10:33.085363600Z"
    }
   },
   "source": [
    "u.math.linspace(0 * u.second, 10 * u.second, 5)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([ 0.   2.5  5.   7.5 10. ], \"s\")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.logspace`\n",
    "Return numbers spaced evenly on a log scale."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.431479100Z",
     "start_time": "2026-03-04T15:10:33.223035300Z"
    }
   },
   "source": [
    "u.math.logspace(0, 10, 5)                               # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1.0000000e+00, 3.1622775e+02, 1.0000000e+05, 3.1622776e+07,\n",
       "       1.0000000e+10], dtype=float32)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.561577300Z",
     "start_time": "2026-03-04T15:10:33.456551400Z"
    }
   },
   "source": [
    "u.math.logspace(0 * u.second, 10 * u.second, 5)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([1.0000000e+00 3.1622775e+02 1.0000000e+05 3.1622776e+07 1.0000000e+10], \"s\")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 15
  },
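  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default base of 10, `logspace(start, stop, num)` equals `10 ** linspace(start, stop, num)`, which is why `logspace(0, 10, 5)` above yields 10^0, 10^2.5, 10^5, 10^7.5 and 10^10. A short check of that relationship for plain numbers, assuming `u.math.logspace` mirrors `jax.numpy.logspace` here:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "exponents = jnp.linspace(0, 10, 5)      # [0, 2.5, 5, 7.5, 10]\n",
    "via_logspace = u.math.logspace(0, 10, 5)\n",
    "via_power = 10.0 ** exponents\n",
    "\n",
    "# the two should agree up to float32 rounding\n",
    "print(bool(jnp.allclose(via_logspace, via_power)))\n",
    "```"
   ]
  },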
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.meshgrid`\n",
    "Return coordinate matrices from coordinate vectors."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.737786300Z",
     "start_time": "2026-03-04T15:10:33.570692700Z"
    }
   },
   "source": [
    "x = jnp.array([1, 2, 3])\n",
    "y = jnp.array([4, 5])\n",
    "\n",
    "u.math.meshgrid(x, y)           # return a list of jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Array([[1, 2, 3],\n",
       "        [1, 2, 3]], dtype=int32),\n",
       " Array([[4, 4, 4],\n",
       "        [5, 5, 5]], dtype=int32)]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:33.969708300Z",
     "start_time": "2026-03-04T15:10:33.761812Z"
    }
   },
   "source": [
    "x_q = jnp.array([1, 2, 3]) * u.second\n",
    "y_q = jnp.array([4, 5]) * u.second\n",
    "\n",
    "u.math.meshgrid(x_q, y_q)    # return a list of Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Quantity([[1 2 3]\n",
       "           [1 2 3]], \"s\"),\n",
       " Quantity([[4 4 4]\n",
       "           [5 5 5]], \"s\")]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 17
  },
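  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The outputs above follow the default Cartesian (`'xy'`) indexing of NumPy's `meshgrid`: both grids have shape `(len(y), len(x))`, each row of the first grid repeats `x`, and each column of the second grid repeats `y`. A small sketch that checks these properties, assuming `u.math.meshgrid` keeps that default:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x = jnp.array([1, 2, 3])\n",
    "y = jnp.array([4, 5])\n",
    "\n",
    "X, Y = u.math.meshgrid(x, y)\n",
    "\n",
    "# rows follow y, columns follow x\n",
    "print(X.shape, Y.shape)                  # (2, 3) (2, 3)\n",
    "\n",
    "# X[i, j] == x[j] and Y[i, j] == y[i]\n",
    "print(bool(jnp.all(X == x[None, :])), bool(jnp.all(Y == y[:, None])))\n",
    "```"
   ]
  },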
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `brainunit.math.vander`\n",
    "Generate a Vandermonde matrix.\n",
    "\n",
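    "Each row of a Vandermonde matrix holds the terms of a geometric progression, determined by the input vector `x` and the number of columns `N` (which defaults to the length of `x`). By default the powers decrease from left to right, so the last column is all ones, as the output below shows.\n"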
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.079169700Z",
     "start_time": "2026-03-04T15:10:34.000515600Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.vander(a)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1, 1, 1],\n",
       "       [4, 2, 1],\n",
       "       [9, 3, 1]], dtype=int32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 18
  },
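  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Column `j` of the default (decreasing-power) Vandermonde matrix holds `x ** (N - 1 - j)`. The sketch below rebuilds the matrix above with broadcasting; it uses plain `jax.numpy` operations to illustrate the definition and is not necessarily how `brainunit` implements `vander`:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x = jnp.array([1, 2, 3])\n",
    "N = x.shape[0]                          # number of columns defaults to len(x)\n",
    "\n",
    "# powers N-1, ..., 1, 0 from left to right\n",
    "powers = jnp.arange(N - 1, -1, -1)\n",
    "manual = x[:, None] ** powers           # broadcast to shape (len(x), N)\n",
    "\n",
    "print(manual.tolist())                  # [[1, 1, 1], [4, 2, 1], [9, 3, 1]]\n",
    "print(bool(jnp.array_equal(manual, u.math.vander(x))))\n",
    "```"
   ]
  },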
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions that accept `Quantity`\n",
    "\n",
    "The functions below return a `Quantity` when they receive a `Quantity` argument (or, for `empty`, `ones`, and `zeros`, a `unit` keyword)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.full`\n",
    "Returns a quantity or array filled with a specific value."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.140831700Z",
     "start_time": "2026-03-04T15:10:34.079169700Z"
    }
   },
   "source": [
    "u.math.full(3, 4)                   # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([4, 4, 4], dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.201697300Z",
     "start_time": "2026-03-04T15:10:34.141831400Z"
    }
   },
   "source": [
    "u.math.full(3, 4 * u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([4 4 4], \"s\")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.empty`\n",
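    "Return a new quantity or array of the given shape and type, without initializing entries. Because XLA cannot create uninitialized arrays, the entries are zeros in practice, as the outputs below show; code should not rely on their contents."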
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.270603600Z",
     "start_time": "2026-03-04T15:10:34.204732600Z"
    }
   },
   "source": [
    "u.math.empty((2, 2))                    # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[0., 0.],\n",
       "       [0., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.329150900Z",
     "start_time": "2026-03-04T15:10:34.272586500Z"
    }
   },
   "source": [
    "u.math.empty((2, 2), unit=u.second) # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[0. 0.]\n",
       "          [0. 0.]], \"s\")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.ones`\n",
    "Return a new quantity or array of the given shape and type, filled with ones."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.343505600Z",
     "start_time": "2026-03-04T15:10:34.330542100Z"
    }
   },
   "source": [
    "u.math.ones((2, 2))                     # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1., 1.],\n",
       "       [1., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 23
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.357404400Z",
     "start_time": "2026-03-04T15:10:34.343505600Z"
    }
   },
   "source": [
    "u.math.ones((2, 2), unit=u.second)  # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1. 1.]\n",
       "          [1. 1.]], \"s\")"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 24
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.zeros`\n",
    "Return a new quantity or array of the given shape and type, filled with zeros."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.370243400Z",
     "start_time": "2026-03-04T15:10:34.357404400Z"
    }
   },
   "source": [
    "u.math.zeros((2, 2))                    # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[0., 0.],\n",
       "       [0., 0.]], dtype=float32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.384978100Z",
     "start_time": "2026-03-04T15:10:34.370243400Z"
    }
   },
   "source": [
    "u.math.zeros((2, 2), unit=u.second) # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[0. 0.]\n",
       "          [0. 0.]], \"s\")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 26
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.full_like`\n",
    "Return a new quantity or array with the same shape and type as a given array or quantity, filled with `fill_value`.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.430959600Z",
     "start_time": "2026-03-04T15:10:34.384978100Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.full_like(a, 4)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([4, 4, 4], dtype=int32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 27
  },
  {
   "cell_type": "code",
   "source": [
    "try:\n",
    "    u.math.full_like(a, 4 * u.second)         # raises: a is a plain array but fill_value has a unit\n",
    "except Exception as e:\n",
    "    print(e)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.442427300Z",
     "start_time": "2026-03-04T15:10:34.430959600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full_like requires \"fill_value\" to be dimensionless when \"a\" is a plain array, but got fill_value with unit=s. Either pass a plain number as fill_value or wrap \"a\" as a Quantity.\n"
     ]
    }
   ],
   "execution_count": 28
  },
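  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The error message suggests wrapping `a` as a `Quantity`. A minimal sketch of that fix, not executed in this notebook: it assumes `full_like` accepts a `fill_value` whose unit matches the template, and that `Quantity.to_decimal(unit)` returns the values expressed in `unit`.\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "a = jnp.array([1, 2, 3])\n",
    "q = a * u.second                        # the template now carries a unit\n",
    "\n",
    "filled = u.math.full_like(q, 4 * u.second)\n",
    "\n",
    "# expressed in seconds, every entry should be 4\n",
    "print(filled.to_decimal(u.second))\n",
    "```"
   ]
  },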
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.empty_like`\n",
    "Return a new quantity or array with the same shape and type as a given array or quantity, without initializing entries.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.456820700Z",
     "start_time": "2026-03-04T15:10:34.443476Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.empty_like(a)       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 0, 0], dtype=int32)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 29
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.474343300Z",
     "start_time": "2026-03-04T15:10:34.458282600Z"
    }
   },
   "source": [
    "q = jnp.array([1, 2, 3]) * u.second\n",
    "\n",
    "u.math.empty_like(q)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([0 0 0], \"s\")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 30
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.ones_like`\n",
    "Return a new quantity or array with the same shape and type as a given array, filled with ones."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.584606100Z",
     "start_time": "2026-03-04T15:10:34.510794500Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.ones_like(a)       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1, 1, 1], dtype=int32)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 31
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:34.606396100Z",
     "start_time": "2026-03-04T15:10:34.586050900Z"
    }
   },
   "source": [
    "q = jnp.array([1, 2, 3]) * u.second\n",
    "\n",
    "u.math.ones_like(q)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([1 1 1], \"s\")"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 32
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.zeros_like`\n",
    "Return a new quantity or array with the same shape and type as a given array, filled with zeros."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.017005400Z",
     "start_time": "2026-03-04T15:10:34.673237600Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.zeros_like(a)       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0, 0, 0], dtype=int32)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 33
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.229875600Z",
     "start_time": "2026-03-04T15:10:35.075363900Z"
    }
   },
   "source": [
    "q = jnp.array([1, 2, 3]) * u.second\n",
    "\n",
    "u.math.zeros_like(q)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([0 0 0], \"s\")"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 34
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.fill_diagonal`\n",
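    "Fill the main diagonal of an array of any dimensionality. Because JAX arrays are immutable, the function returns a new array instead of modifying the input in place."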
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.598787Z",
     "start_time": "2026-03-04T15:10:35.258585900Z"
    }
   },
   "source": [
    "a = jnp.zeros((3, 3))\n",
    "\n",
    "u.math.fill_diagonal(a, 4)      # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[4., 0., 0.],\n",
       "       [0., 4., 0.],\n",
       "       [0., 0., 4.]], dtype=float32)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 35
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.681172800Z",
     "start_time": "2026-03-04T15:10:35.599294700Z"
    }
   },
   "source": [
    "q = jnp.zeros((3, 3)) * u.second\n",
    "\n",
    "u.math.fill_diagonal(q, 4 * u.second)   # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[4. 0. 0.]\n",
       "          [0. 4. 0.]\n",
       "          [0. 0. 4.]], \"s\")"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 36
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions that accept a `unit` keyword\n",
    "\n",
    "The functions below take a `unit` keyword argument; when it is given, the result is a `Quantity` in that unit."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.eye`\n",
    "Returns a 2-D quantity or array with ones on the diagonal and zeros elsewhere."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:35.938555200Z",
     "start_time": "2026-03-04T15:10:35.684617Z"
    }
   },
   "source": [
    "u.math.eye(3)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1., 0., 0.],\n",
       "       [0., 1., 0.],\n",
       "       [0., 0., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 37
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.053545400Z",
     "start_time": "2026-03-04T15:10:35.942874100Z"
    }
   },
   "source": [
    "u.math.eye(3, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1. 0. 0.]\n",
       "          [0. 1. 0.]\n",
       "          [0. 0. 1.]], \"s\")"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 38
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.identity`\n",
    "Return a square identity matrix (ones on the main diagonal, zeros elsewhere) as a quantity or array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.087742400Z",
     "start_time": "2026-03-04T15:10:36.056066300Z"
    }
   },
   "source": [
    "u.math.identity(3)                  # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1., 0., 0.],\n",
       "       [0., 1., 0.],\n",
       "       [0., 0., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 39
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.175455300Z",
     "start_time": "2026-03-04T15:10:36.114283200Z"
    }
   },
   "source": [
    "u.math.identity(3, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1. 0. 0.]\n",
       "          [0. 1. 0.]\n",
       "          [0. 0. 1.]], \"s\")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 40
  },
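  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the outputs above, `unit=u.second` yields the same numbers as the plain call, with the unit attached. A consistency check along those lines, assuming that multiplying a plain array by `u.second` produces an equivalent `Quantity` and that `Quantity.to_decimal(unit)` returns the values expressed in `unit`:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "with_keyword = u.math.eye(3, unit=u.second)\n",
    "by_multiplying = u.math.identity(3) * u.second\n",
    "\n",
    "# compare the numeric values, both expressed in seconds\n",
    "same = jnp.array_equal(with_keyword.to_decimal(u.second),\n",
    "                       by_multiplying.to_decimal(u.second))\n",
    "print(bool(same))\n",
    "```"
   ]
  },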
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.tri`\n",
    "Return a quantity or array with ones at and below the given diagonal and zeros elsewhere.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.273780900Z",
     "start_time": "2026-03-04T15:10:36.177465100Z"
    }
   },
   "source": [
    "u.math.tri(3)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1., 0., 0.],\n",
       "       [1., 1., 0.],\n",
       "       [1., 1., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 41
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.298065500Z",
     "start_time": "2026-03-04T15:10:36.275779900Z"
    }
   },
   "source": [
    "u.math.tri(3, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1. 0. 0.]\n",
       "          [1. 1. 0.]\n",
       "          [1. 1. 1.]], \"s\")"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 42
  },
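  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tri(N)` can be read as the mask behind `tril` (shown later in this notebook): for the default `k=0`, `tril(m)` keeps the entries of a square matrix `m` where `tri` has ones and zeroes the rest. A short sketch of that relationship using plain arrays:\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "mask = u.math.tri(3)                       # ones at and below the main diagonal\n",
    "m = jnp.arange(9.0).reshape(3, 3)          # an arbitrary 3x3 matrix\n",
    "\n",
    "# tril(m) should equal m masked by tri(3)\n",
    "print(bool(jnp.array_equal(u.math.tril(m), m * mask)))\n",
    "\n",
    "# for a matrix of ones, tril reproduces the mask itself\n",
    "print(bool(jnp.array_equal(u.math.tril(jnp.ones((3, 3))), mask)))\n",
    "```"
   ]
  },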
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `brainunit.math.diag`\n",
    "Extract a diagonal or construct a diagonal array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.355291300Z",
     "start_time": "2026-03-04T15:10:36.300065900Z"
    }
   },
   "source": [
    "a = jnp.array([1, 2, 3])\n",
    "\n",
    "u.math.diag(a)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1, 0, 0],\n",
       "       [0, 2, 0],\n",
       "       [0, 0, 3]], dtype=int32)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 43
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.412976300Z",
     "start_time": "2026-03-04T15:10:36.356729500Z"
    }
   },
   "source": [
    "u.math.diag(a, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1 0 0]\n",
       "          [0 2 0]\n",
       "          [0 0 3]], \"s\")"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 44
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extracting a diagonal with `brainunit.math.diag`\n",
    "When the input is 2-D, `diag` extracts its main diagonal instead of constructing a diagonal matrix."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.468278100Z",
     "start_time": "2026-03-04T15:10:36.413977200Z"
    }
   },
   "source": [
    "a = jnp.ones((3, 3))\n",
    "\n",
    "u.math.diag(a)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1., 1., 1.], dtype=float32)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 45
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.508031900Z",
     "start_time": "2026-03-04T15:10:36.469279800Z"
    }
   },
   "source": [
    "u.math.diag(a, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([1. 1. 1.], \"s\")"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 46
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
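    "### `brainunit.math.tril` and `brainunit.math.triu`\n",
    "Lower and upper triangles of an array.\n",
    "\n",
    "`tril` returns a copy of a matrix with the elements above the `k`-th diagonal zeroed; `triu` returns a copy with the elements below the `k`-th diagonal zeroed.\n",
    "For quantities or arrays with `ndim` greater than 2, both apply to the last two axes. The examples below use `tril`."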
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.557376100Z",
     "start_time": "2026-03-04T15:10:36.509262100Z"
    }
   },
   "source": [
    "a = jnp.ones((3, 3))\n",
    "\n",
    "u.math.tril(a)                       # return a jax.Array"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[1., 0., 0.],\n",
       "       [1., 1., 0.],\n",
       "       [1., 1., 1.]], dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 47
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:36.576482900Z",
     "start_time": "2026-03-04T15:10:36.558375900Z"
    }
   },
   "source": [
    "u.math.tril(a, unit=u.second)    # return a Quantity"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Quantity([[1. 0. 0.]\n",
       "          [1. 1. 0.]\n",
       "          [1. 1. 1.]], \"s\")"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 48
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
