{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quickstart Guide\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/getting_started/quickstart.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/getting_started/quickstart.ipynb)\n",
    "\n",
    "**brainunit** is a unit-aware scientific computing library built on JAX.\n",
    "It carries physical units through arithmetic, linear algebra, FFTs, automatic\n",
    "differentiation, and JIT compilation, catching dimension errors at runtime.\n",
    "\n",
    "This guide covers the essentials in 5 minutes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "```bash\n",
    "pip install brainunit\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:13.544066400Z",
     "start_time": "2026-03-04T15:10:12.644493500Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Quantities\n",
    "\n",
    "A `Quantity` pairs a numeric value (its *mantissa*) with a physical unit. The simplest way to create one is to multiply a value by a unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:13.778319700Z",
     "start_time": "2026-03-04T15:10:13.544066400Z"
    }
   },
   "source": [
    "# Scalars\n",
    "mass = 5.0 * u.kilogram\n",
    "speed = 10.0 * u.meter / u.second\n",
    "print('mass:', mass)\n",
    "print('speed:', speed)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mass: 5. kg\n",
      "speed: 10. m / s\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.000740100Z",
     "start_time": "2026-03-04T15:10:13.824224300Z"
    }
   },
   "source": [
    "# Arrays\n",
    "voltages = jnp.array([1.0, 2.5, 3.7]) * u.mV\n",
    "print('voltages:', voltages)\n",
    "print('shape:', voltages.shape, 'dtype:', voltages.dtype)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "voltages: [1.  2.5 3.70000005] mV\n",
      "shape: (3,) dtype: float32\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.048193900Z",
     "start_time": "2026-03-04T15:10:14.025932200Z"
    }
   },
   "source": [
    "# Direct construction\n",
    "current = u.Quantity(jnp.array([0.1, 0.2, 0.3]), unit=u.ampere)\n",
    "print('current:', current)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "current: [0.1 0.2 0.30000001] A\n"
     ]
    }
   ],
   "execution_count": 4
  },
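  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both construction styles yield equivalent quantities, and indexing or slicing a `Quantity` keeps its unit.\n",
    "The check below is a small sketch; it assumes `Unit` objects support `==` comparison."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "values = jnp.array([1.0, 2.5, 3.7])\n",
    "a = values * u.mV                    # value * unit\n",
    "b = u.Quantity(values, unit=u.mV)    # explicit constructor\n",
    "\n",
    "# Same unit and same numbers\n",
    "same_unit = a.unit == b.unit\n",
    "same_values = jnp.allclose(a.to_decimal(u.mV), b.to_decimal(u.mV))\n",
    "print('same unit:', same_unit, '| same values:', same_values)\n",
    "\n",
    "# Indexing and slicing return quantities in the original unit\n",
    "print('first element:', a[0])\n",
    "print('last two:', a[1:])"
   ],
   "outputs": [],
   "execution_count": null
  },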
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arithmetic with Units\n",
    "\n",
    "Units propagate automatically. Addition and subtraction require operands of the same dimension, while multiplication and division combine the units."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.074540100Z",
     "start_time": "2026-03-04T15:10:14.049697800Z"
    }
   },
   "source": [
    "# Addition: same dimension required\n",
    "t1 = 500.0 * u.ms\n",
    "t2 = 1.5 * u.second\n",
    "print('t1 + t2:', t1 + t2)  # result is expressed in the first operand's unit (ms)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t1 + t2: 2000. ms\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.146356600Z",
     "start_time": "2026-03-04T15:10:14.076610Z"
    }
   },
   "source": [
    "# Multiplication: units multiply\n",
    "F = 10.0 * u.newton\n",
    "d = 3.0 * u.meter\n",
    "print('work = F * d:', F * d)  # N * m = J"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "work = F * d: 30. J\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.164365200Z",
     "start_time": "2026-03-04T15:10:14.146356600Z"
    }
   },
   "source": [
    "# Division: units divide\n",
    "print('speed = d / t:', (100.0 * u.meter) / (10.0 * u.second))  # m/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "speed = d / t: 10. m / s\n"
     ]
    }
   ],
   "execution_count": 7
  },
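  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Derived units fall out of ordinary formulas. This sketch applies Ohm's law, $I = V / R$, and $P = I V$ with made-up circuit values, then reads the results back in amperes and watts with `to_decimal()`:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "V = 2.0 * u.volt\n",
    "R = 4.0 * u.ohm      # illustrative resistance\n",
    "I = V / R            # volt / ohm has the dimension of current\n",
    "P = I * V            # ampere * volt has the dimension of power\n",
    "print('I in A:', I.to_decimal(u.ampere))\n",
    "print('P in W:', P.to_decimal(u.watt))"
   ],
   "outputs": [],
   "execution_count": null
  },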
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.183800300Z",
     "start_time": "2026-03-04T15:10:14.167796100Z"
    }
   },
   "source": [
    "# Dimension mismatch raises error\n",
    "try:\n",
    "    result = 5.0 * u.meter + 3.0 * u.second\n",
    "except Exception as e:\n",
    "    print('Error:', e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: Cannot calculate \n",
      "5. m + 3. s, because units do not match: m != s\n"
     ]
    }
   ],
   "execution_count": 8
  },
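  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The check compares physical *dimensions*, not exact units: kilometers and meters are both lengths, so they can be added, while meters and seconds cannot. The `dim` attribute exposes what is being compared:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "a = 1.0 * u.kmeter\n",
    "b = 500.0 * u.meter\n",
    "c = 3.0 * u.second\n",
    "\n",
    "print('km vs m, same dim:', a.dim == b.dim)   # both are lengths\n",
    "print('m vs s, same dim:', b.dim == c.dim)    # length vs time\n",
    "\n",
    "total = a + b                                 # allowed: same dimension\n",
    "print('total in m:', total.to_decimal(u.meter))"
   ],
   "outputs": [],
   "execution_count": null
  },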
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unit Conversion\n",
    "\n",
    "Use `to_decimal()` to extract the numeric value in a target unit as a plain array,\n",
    "or `in_unit()` to get a new `Quantity` expressed in the target unit."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.209327100Z",
     "start_time": "2026-03-04T15:10:14.196807300Z"
    }
   },
   "source": [
    "distance = 2.5 * u.kmeter\n",
    "print('In meters:', distance.to_decimal(u.meter))       # 2500.0\n",
    "print('In cm:', distance.to_decimal(u.cmeter))           # 250000.0\n",
    "print('As Quantity:', distance.in_unit(u.meter))          # 2500.0 m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In meters: 2500.0\n",
      "In cm: 250000.0\n",
      "As Quantity: 2500. m\n"
     ]
    }
   ],
   "execution_count": 9
  },
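  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Target units can be composed on the fly, and converting to a unit of a different dimension raises an error. Here 10 m/s is expressed in cm/ms (10 m/s = 1000 cm/s = 1 cm/ms):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "speed = 10.0 * u.meter / u.second\n",
    "print('in cm/ms:', speed.to_decimal(u.cmeter / u.ms))\n",
    "\n",
    "# Converting to an incompatible unit (time) fails\n",
    "try:\n",
    "    speed.to_decimal(u.second)\n",
    "    converted = True\n",
    "except Exception as e:\n",
    "    converted = False\n",
    "    print('Error:', e)"
   ],
   "outputs": [],
   "execution_count": null
  },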
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
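    "## Quantity Attributes\n",
    "\n",
    "A `Quantity` exposes its numeric data (`mantissa`), its `unit`, and its physical dimension (`dim`), along with the usual array metadata."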
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.232815600Z",
     "start_time": "2026-03-04T15:10:14.210643400Z"
    }
   },
   "source": [
    "q = jnp.array([[1., 2.], [3., 4.]]) * u.volt\n",
    "print('mantissa:', q.mantissa)   # numeric array\n",
    "print('unit:', q.unit)           # the unit\n",
    "print('dim:', q.dim)             # physical dimension\n",
    "print('shape:', q.shape)         # array shape\n",
    "print('dtype:', q.dtype)         # array dtype"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mantissa: [[1. 2.]\n",
      " [3. 4.]]\n",
      "unit: V\n",
      "dim: m^2 kg s^-3 A^-1\n",
      "shape: (2, 2)\n",
      "dtype: float32\n"
     ]
    }
   ],
   "execution_count": 10
  },
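  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The mantissa is always expressed in the quantity's current unit. Converting with `in_unit()` rescales the mantissa while the dimension stays the same:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "q = jnp.array([[1., 2.], [3., 4.]]) * u.volt\n",
    "q_mv = q.in_unit(u.mV)\n",
    "\n",
    "print('mantissa in V: ', q.mantissa)\n",
    "print('mantissa in mV:', q_mv.mantissa)   # scaled by 1000\n",
    "print('same dimension:', q.dim == q_mv.dim)"
   ],
   "outputs": [],
   "execution_count": null
  },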
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unit-Aware Math Functions\n",
    "\n",
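    "`brainunit.math` provides 500+ unit-aware counterparts of NumPy functions. Reductions such as `sum` and `mean` keep the input unit, while functions such as `sqrt` transform it."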
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.380438400Z",
     "start_time": "2026-03-04T15:10:14.233931Z"
    }
   },
   "source": [
    "data = jnp.array([2., 4., 6., 8., 10.]) * u.newton\n",
    "\n",
    "print('sum:', u.math.sum(data))           # keeps unit\n",
    "print('mean:', u.math.mean(data))         # keeps unit\n",
    "print('sqrt:', u.math.sqrt(4.0 * u.meter2))  # changes unit: m^2 -> m\n",
    "print('sort:', u.math.sort(jnp.array([3., 1., 2.]) * u.volt))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum: 30. N\n",
      "mean: 6. N\n",
      "sqrt: 2. m\n",
      "sort: [1. 2. 3.] V\n"
     ]
    }
   ],
   "execution_count": 11
  },
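  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some functions, such as `exp`, only make sense for dimensionless input, and `brainunit.math` rejects a quantity that carries units. One approach is to convert to a plain number first. In this sketch, `v_scale` is an arbitrary scale chosen only for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "v = jnp.array([-70.0, -55.0, 0.0]) * u.mV\n",
    "v_scale = 10.0 * u.mV   # illustrative scale, not a library constant\n",
    "\n",
    "# exp of a dimensional quantity is rejected\n",
    "try:\n",
    "    u.math.exp(v)\n",
    "    rejected = False\n",
    "except Exception as e:\n",
    "    rejected = True\n",
    "    print('Error type:', type(e).__name__)\n",
    "\n",
    "# Strip units explicitly: v / v_scale as a plain (dimensionless) array\n",
    "x = v.to_decimal(u.mV) / v_scale.to_decimal(u.mV)\n",
    "y = u.math.exp(x)\n",
    "print('exp(v / v_scale):', y)"
   ],
   "outputs": [],
   "execution_count": null
  },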
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
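    "## Physical Constants\n",
    "\n",
    "`brainunit.constants` provides common physical constants as `Quantity` objects, so they carry their units into calculations."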
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.411745700Z",
     "start_time": "2026-03-04T15:10:14.396388100Z"
    }
   },
   "source": [
    "from brainunit import constants\n",
    "\n",
    "print('Avogadro number:', constants.avogadro)\n",
    "print('Boltzmann constant:', constants.boltzmann)\n",
    "print('Elementary charge:', constants.elementary_charge)\n",
    "print('Electron mass:', constants.electron_mass)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avogadro number: 6.0221406e+23 1 / mol\n",
      "Boltzmann constant: 1.380649e-23 J / K\n",
      "Elementary charge: 1.6021766e-19 C\n",
      "Electron mass: 9.109383e-31 kg\n"
     ]
    }
   ],
   "execution_count": 12
  },
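  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because constants carry units, derived quantities come out with the right dimension. For example, the gas constant $R = N_A k_B$ and the thermal voltage $k_B T / e$ at an assumed temperature of 310 K:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "from brainunit import constants\n",
    "\n",
    "# Gas constant R = N_A * k_B, read back in J / (mol K)\n",
    "R = constants.avogadro * constants.boltzmann\n",
    "R_unit = u.joule / (u.mole * u.kelvin)\n",
    "print('R in J/(mol K):', R.to_decimal(R_unit))\n",
    "\n",
    "# Thermal voltage k_B * T / e at an assumed temperature of 310 K\n",
    "T = 310.0 * u.kelvin\n",
    "V_T = constants.boltzmann * T / constants.elementary_charge\n",
    "print('V_T in mV:', V_T.to_decimal(u.mV))"
   ],
   "outputs": [],
   "execution_count": null
  },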
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## JAX Transforms: `jit`, `vmap`, `grad`\n",
    "\n",
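    "Quantities can be passed directly to functions transformed by `jax.jit` and `jax.vmap`. For gradients, use `u.autograd.grad`, a unit-aware counterpart of `jax.grad` that attaches the correct unit to the derivative."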
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.451918200Z",
     "start_time": "2026-03-04T15:10:14.411745700Z"
    }
   },
   "source": [
    "# JIT compilation\n",
    "@jax.jit\n",
    "def kinetic_energy(m, v):\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "KE = kinetic_energy(2.0 * u.kilogram, 3.0 * u.meter / u.second)\n",
    "print('KE =', KE)  # kg * m^2 / s^2 = J"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KE = 9. J\n"
     ]
    }
   ],
   "execution_count": 13
  },
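  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dimension checks still apply inside `jax.jit`: the mismatch is detected while the function is traced, so the error surfaces in Python before any compiled code runs. The function below is deliberately wrong:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "\n",
    "@jax.jit\n",
    "def bad_sum(x, t):\n",
    "    return x + t   # length + time: dimensionally invalid\n",
    "\n",
    "try:\n",
    "    bad_sum(1.0 * u.meter, 2.0 * u.second)\n",
    "    caught = False\n",
    "except Exception as e:\n",
    "    caught = True\n",
    "    print('Caught under jit:', type(e).__name__)"
   ],
   "outputs": [],
   "execution_count": null
  },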
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.552472800Z",
     "start_time": "2026-03-04T15:10:14.452936100Z"
    }
   },
   "source": [
    "# vmap: vectorize over a batch\n",
    "velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second\n",
    "energies = jax.vmap(lambda v: kinetic_energy(2.0 * u.kilogram, v))(velocities)\n",
    "print('Batch KE:', energies)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch KE: [ 1.  4.  9. 16. 25.] J\n"
     ]
    }
   ],
   "execution_count": 14
  },
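  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`jax.vmap` can map over several quantity arguments at once through `in_axes`. For purely elementwise formulas like this one, ordinary broadcasting gives the same result without `vmap`. The masses below are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def kinetic_energy(m, v):\n",
    "    return 0.5 * m * v**2\n",
    "\n",
    "masses = jnp.array([1.0, 2.0, 3.0]) * u.kilogram\n",
    "velocities = jnp.array([1.0, 2.0, 3.0]) * u.meter / u.second\n",
    "\n",
    "# Map over the leading axis of both arguments\n",
    "energies = jax.vmap(kinetic_energy, in_axes=(0, 0))(masses, velocities)\n",
    "print('vmap:', energies)\n",
    "\n",
    "# Elementwise formula, so plain broadcasting gives the same result\n",
    "print('broadcast:', kinetic_energy(masses, velocities))"
   ],
   "outputs": [],
   "execution_count": null
  },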
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:14.932069600Z",
     "start_time": "2026-03-04T15:10:14.554579700Z"
    }
   },
   "source": [
    "# grad: automatic differentiation with unit tracking\n",
    "dKE_dv = u.autograd.grad(lambda v: 0.5 * (2.0 * u.kilogram) * v**2)\n",
    "print('dKE/dv at v=3 m/s:', dKE_dv(3.0 * u.meter / u.second))  # momentum: kg * m/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dKE/dv at v=3 m/s: 6. kg * m / s\n"
     ]
    }
   ],
   "execution_count": 15
  },
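  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The derivative's unit is the output unit divided by the input unit. For a spring with potential energy $U(x) = k x^2 / 2$, the force $F = -dU/dx$ therefore comes out in J/m, that is, newtons. The spring constant and displacement are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "\n",
    "k = 4.0 * u.newton / u.meter   # illustrative spring constant\n",
    "\n",
    "def potential_energy(x):\n",
    "    return 0.5 * k * x**2\n",
    "\n",
    "dU_dx = u.autograd.grad(potential_energy)\n",
    "F = -dU_dx(0.5 * u.meter)      # F = -dU/dx = -k x\n",
    "print('F =', F)"
   ],
   "outputs": [],
   "execution_count": null
  },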
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unit Validation with Decorators\n",
    "\n",
    "Use `@u.check_units` to enforce unit contracts on function arguments: each argument is validated against its declared unit when the function is called."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:15.012564900Z",
     "start_time": "2026-03-04T15:10:14.956779Z"
    }
   },
   "source": [
    "@u.check_units(v=u.meter / u.second, t=u.second)\n",
    "def displacement(v, t):\n",
    "    return v * t\n",
    "\n",
    "print('displacement:', displacement(10.0 * u.meter / u.second, 5.0 * u.second))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "displacement: 50. m\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:15.026082700Z",
     "start_time": "2026-03-04T15:10:15.013566800Z"
    }
   },
   "source": [
    "# Wrong units raise an error\n",
    "try:\n",
    "    displacement(10.0 * u.kilogram, 5.0 * u.second)\n",
    "except Exception as e:\n",
    "    print('Error:', e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error: Function 'displacement' expected a array with unit Unit(\"m / s\") for argument 'v' but got '10. kg' (unit is kg).\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What's Next?\n",
    "\n",
    "- **[Quantity](../physical_units/quantity.ipynb)** — Creating and manipulating quantities in depth\n",
    "- **[Standard Units](../physical_units/standard_units.ipynb)** — All available SI and non-SI units\n",
    "- **[Unit Conversion](../physical_units/conversion.ipynb)** — Converting between units\n",
    "- **[NumPy Functions](../unit_operations/numpy_functions.ipynb)** — 500+ unit-aware math functions\n",
    "- **[Linear Algebra](../unit_operations/linalg_functions.ipynb)** — Unit-aware linalg\n",
    "- **[Unit Validation](../unit_operations/check_units.ipynb)** — check_dims and check_units decorators"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
