{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# JAX LAX Functions\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/lax_functions.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/lax_functions.ipynb)\n",
    "\n",
    "`brainunit.lax` provides unit-aware wrappers around JAX's low-level `jax.lax` primitives.\n",
    "In JAX, these primitives are the building blocks on which `jax.numpy` is implemented; `brainunit.lax` exposes them directly and adds unit checking and unit propagation.\n",
    "\n",
    "The functions are grouped by how they handle units:\n",
    "\n",
    "- **Keeping unit**: slicing, sorting, cumulative ops, padding, broadcasting\n",
    "- **Changing unit**: arithmetic (`mul`, `div`, `integer_pow`), `rsqrt`, `dot_general`, `batch_matmul`, `conv`\n",
    "- **Removing unit**: comparisons (`eq`, `lt`, `gt`, ...)\n",
    "- **Accepting unitless**: trigonometric and special functions (`erf`, `logistic`, `bessel_i0e`, `bessel_i1e`)\n",
    "- **Linear algebra**: `cholesky`, `eig`, `qr`, `svd`, `triangular_solve`"
   ]
  },
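  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal illustration of what the wrapper adds, the next cell applies plain `jax.lax.mul` to the raw numbers and `u.lax.mul` to the quantities. It assumes two `Quantity` features not used elsewhere in this notebook: `.mantissa` (the numeric part) and `.to_decimal(unit)` (the value expressed in a given unit, as a plain array)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "i = jnp.array([1., 2.]) * u.ampere\n",
    "r = jnp.array([10., 20.]) * u.ohm\n",
    "\n",
    "# Plain jax.lax works on bare numbers: the unit is lost\n",
    "raw = jax.lax.mul(i.mantissa, r.mantissa)\n",
    "\n",
    "# brainunit.lax applies the same primitive and tracks the unit (A * ohm = V)\n",
    "v = u.lax.mul(i, r)\n",
    "\n",
    "print('jax.lax.mul:', raw)\n",
    "print('u.lax.mul:', v)\n",
    "print('u.lax.mul in volts:', v.to_decimal(u.volt))"
   ],
   "outputs": [],
   "execution_count": null
  },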
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.216267200Z",
     "start_time": "2026-03-04T15:10:43.348812500Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Keep Unit\n",
    "\n",
    "These operations rearrange, slice, accumulate, or transform values elementwise; the result carries the same unit as the input."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Slicing Operations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.750514400Z",
     "start_time": "2026-03-04T15:10:44.216267200Z"
    }
   },
   "source": [
    "x = jnp.array([10., 20., 30., 40., 50.]) * u.volt\n",
    "\n",
    "# Static slice: indices 1 up to (but not including) 4\n",
    "print('slice [1:4]:', u.lax.slice(x, (1,), (4,)))\n",
    "\n",
    "# Dynamic slice: start at index 2, take 3 elements\n",
    "print('dynamic_slice:', u.lax.dynamic_slice(x, (2,), (3,)))\n",
    "\n",
    "# Slice in a specific dimension\n",
    "M = jnp.arange(12.).reshape(3, 4) * u.ampere\n",
    "print('slice_in_dim (rows 0:2):')\n",
    "print(u.lax.slice_in_dim(M, 0, 2, axis=0))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "slice [1:4]: [20. 30. 40.] V\n",
      "dynamic_slice: [30. 40. 50.] V\n",
      "slice_in_dim (rows 0:2):\n",
      "[[0. 1. 2. 3.]\n",
      " [4. 5. 6. 7.]] A\n"
     ]
    }
   ],
   "execution_count": 2
  },
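  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distinction between `slice` and `dynamic_slice` matters inside `jax.jit`: `slice` needs Python-integer bounds, whereas `dynamic_slice` accepts a traced start index (the slice size must still be static). The sketch below assumes quantities can be passed through `jax.jit`, which relies on brainunit registering `Quantity` as a JAX pytree. Following `jax.lax.dynamic_slice`, an out-of-range start index is clamped so that the window stays inside the array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "x = jnp.array([10., 20., 30., 40., 50.]) * u.volt\n",
    "\n",
    "@jax.jit\n",
    "def window(x, start):\n",
    "    # start is a traced value, so it can change without recompiling;\n",
    "    # the window size (3) must stay static\n",
    "    return u.lax.dynamic_slice(x, (start,), (3,))\n",
    "\n",
    "w1 = window(x, 1)\n",
    "w2 = window(x, 4)  # start 4 is clamped to 2 so the 3-element window fits\n",
    "print('start=1:', w1)\n",
    "print('start=4:', w2)"
   ],
   "outputs": [],
   "execution_count": null
  },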
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dynamic Updates"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.897169900Z",
     "start_time": "2026-03-04T15:10:44.847563700Z"
    }
   },
   "source": [
    "arr = jnp.array([1., 2., 3., 4., 5.]) * u.meter\n",
    "update = jnp.array([99., 88.]) * u.meter\n",
    "\n",
    "# Update a slice starting at index 1\n",
    "result = u.lax.dynamic_update_slice(arr, update, (1,))\n",
    "print('dynamic_update_slice:', result)  # [1, 99, 88, 4, 5] m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dynamic_update_slice: [ 1. 99. 88.  4.  5.] m\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sorting"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:44.976472300Z",
     "start_time": "2026-03-04T15:10:44.897169900Z"
    }
   },
   "source": [
    "unsorted = jnp.array([3., 1., 4., 1., 5., 9., 2., 6.]) * u.newton\n",
    "print('sort:', u.lax.sort(unsorted))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sort: [1. 1. 2. 3. 4. 5. 6. 9.] N\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:45.008861500Z",
     "start_time": "2026-03-04T15:10:44.976472300Z"
    }
   },
   "source": [
    "# Top-k: largest k elements\n",
    "values, indices = u.lax.top_k(unsorted, 3)\n",
    "print('top_k values:', values)\n",
    "print('top_k indices:', indices)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "top_k values: [9. 6. 5.] N\n",
      "top_k indices: [5 7 4]\n"
     ]
    }
   ],
   "execution_count": 5
  },
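  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`top_k` returns only the largest entries. A common idiom for the smallest ones is to negate, take the top k, and negate back. The returned indices can also index the original quantity directly. The sketch below relies only on the functions shown above plus ordinary `Quantity` indexing."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "forces = jnp.array([3., 1., 4., 1., 5., 9., 2., 6.]) * u.newton\n",
    "\n",
    "# Largest 3 values and their positions\n",
    "values, indices = u.lax.top_k(forces, 3)\n",
    "# Indexing with the returned positions recovers the same values, unit included\n",
    "gathered = forces[indices]\n",
    "\n",
    "# Smallest 2 values: negate, take the top 2, negate back\n",
    "neg_top, _ = u.lax.top_k(u.lax.neg(forces), 2)\n",
    "smallest = u.lax.neg(neg_top)\n",
    "\n",
    "print('gathered:', gathered)\n",
    "print('smallest 2:', smallest)"
   ],
   "outputs": [],
   "execution_count": null
  },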
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cumulative Operations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:45.619488100Z",
     "start_time": "2026-03-04T15:10:45.009869200Z"
    }
   },
   "source": [
    "vals = jnp.array([1., 3., 2., 5., 4.]) * u.watt\n",
    "\n",
    "print('cumsum:', u.lax.cumsum(vals, axis=0))\n",
    "print('cummin:', u.lax.cummin(vals, axis=0))\n",
    "print('cummax:', u.lax.cummax(vals, axis=0))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cumsum: [ 1.  4.  6. 11. 15.] W\n",
      "cummin: [1. 1. 1. 1. 1.] W\n",
      "cummax: [1. 3. 3. 5. 5.] W\n"
     ]
    }
   ],
   "execution_count": 6
  },
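  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because cumulative operations keep whatever unit their input has, multiplying a power signal by a time step before `cumsum` yields accumulated energy (a rectangle-rule integral). The step `dt = 0.5 s` is a hypothetical value chosen for illustration; since W * s has the dimension of energy, the result can be read out in joules."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "power = jnp.array([1., 3., 2., 5., 4.]) * u.watt\n",
    "dt = 0.5 * u.second  # hypothetical sampling interval\n",
    "\n",
    "# Rectangle rule: energy[k] = (power[0] + ... + power[k]) * dt, unit W * s = J\n",
    "energy = u.lax.cumsum(power * dt, axis=0)\n",
    "print('energy:', energy)\n",
    "print('in joules:', energy.to_decimal(u.joule))"
   ],
   "outputs": [],
   "execution_count": null
  },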
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Padding"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:45.831687600Z",
     "start_time": "2026-03-04T15:10:45.718415800Z"
    }
   },
   "source": [
    "signal = jnp.array([1., 2., 3.]) * u.volt\n",
    "\n",
    "# Pad with 2 zeros on the left and 1 on the right; each config entry is (low, high, interior)\n",
    "padded = u.lax.pad(signal, 0.0 * u.volt, [(2, 1, 0)])\n",
    "print('padded:', padded)  # [0, 0, 1, 2, 3, 0] V"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "padded: [0. 0. 1. 2. 3. 0.] V\n"
     ]
    }
   ],
   "execution_count": 7
  },
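  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each entry of the padding configuration is a `(low, high, interior)` triple, as in `jax.lax.pad`. Interior padding inserts values between neighbouring elements, and a negative `low` or `high` trims elements instead of adding them. The sketch below passes the padding value as a quantity in the signal's unit, as in the cell above."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "signal = jnp.array([1., 2., 3.]) * u.volt\n",
    "zero = 0.0 * u.volt\n",
    "\n",
    "# Interior padding of 1: a zero between neighbouring samples (zero-stuffing)\n",
    "stuffed = u.lax.pad(signal, zero, [(0, 0, 1)])\n",
    "\n",
    "# Negative low padding removes the first element instead of adding one\n",
    "trimmed = u.lax.pad(signal, zero, [(-1, 0, 0)])\n",
    "\n",
    "print('interior:', stuffed)  # [1, 0, 2, 0, 3] V\n",
    "print('trimmed:', trimmed)   # [2, 3] V"
   ],
   "outputs": [],
   "execution_count": null
  },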
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clamping"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:45.889951Z",
     "start_time": "2026-03-04T15:10:45.833284800Z"
    }
   },
   "source": [
    "x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt\n",
    "clamped = u.lax.clamp(1.0 * u.volt, x, 3.0 * u.volt)\n",
    "print('clamp [1V, 3V]:', clamped)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clamp [1V, 3V]: [1.  1.5 2.5 3.  3. ] V\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Broadcasting"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:45.976306400Z",
     "start_time": "2026-03-04T15:10:45.891961100Z"
    }
   },
   "source": [
    "v = jnp.array([1., 2., 3.]) * u.meter\n",
    "print('broadcast to (4, 3):')\n",
    "print(u.lax.broadcast(v, (4,)))  # adds a leading dimension of size 4"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "broadcast to (4, 3):\n",
      "[[1. 2. 3.]\n",
      " [1. 2. 3.]\n",
      " [1. 2. 3.]\n",
      " [1. 2. 3.]] m\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Negation"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.017989700Z",
     "start_time": "2026-03-04T15:10:45.977313400Z"
    }
   },
   "source": [
    "print('neg:', u.lax.neg(jnp.array([1., -2., 3.]) * u.pascal))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "neg: [-1.  2. -3.] Pa\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Type Conversion"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.078001700Z",
     "start_time": "2026-03-04T15:10:46.019332500Z"
    }
   },
   "source": [
    "x_int = jnp.array([1, 2, 3]) * u.meter\n",
    "x_float = u.lax.convert_element_type(x_int, jnp.float32)\n",
    "print('int to float:', x_float)\n",
    "print('dtype:', x_float.dtype)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "int to float: [1. 2. 3.] m\n",
      "dtype: float32\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Change Unit\n",
    "\n",
    "These arithmetic and algebraic operations produce results whose unit is derived from the input units (for example, ampere * ohm = volt)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Arithmetic: `mul`, `div`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.138234600Z",
     "start_time": "2026-03-04T15:10:46.078998600Z"
    }
   },
   "source": [
    "current = jnp.array([1., 2., 3.]) * u.ampere\n",
    "resistance = jnp.array([10., 20., 30.]) * u.ohm\n",
    "\n",
    "# Ohm's law: V = I * R\n",
    "voltage = u.lax.mul(current, resistance)\n",
    "print('V = I * R:', voltage)  # ampere * ohm = volt"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "V = I * R: [10. 40. 90.] V\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.197545400Z",
     "start_time": "2026-03-04T15:10:46.138234600Z"
    }
   },
   "source": [
    "# Division\n",
    "power = jnp.array([100., 200.]) * u.watt\n",
    "v = jnp.array([10., 20.]) * u.volt\n",
    "print('P / V:', u.lax.div(power, v))  # watt / volt = ampere"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P / V: [10. 10.] A\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `integer_pow` — Power with integer exponent"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.442012800Z",
     "start_time": "2026-03-04T15:10:46.202109700Z"
    }
   },
   "source": [
    "lengths = jnp.array([2., 3., 4.]) * u.meter\n",
    "print('squared:', u.lax.integer_pow(lengths, 2))  # m^2\n",
    "print('cubed:', u.lax.integer_pow(lengths, 3))    # m^3"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "squared: [ 4.  9. 16.] m^2\n",
      "cubed: [ 8. 27. 64.] m^3\n"
     ]
    }
   ],
   "execution_count": 14
  },
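  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unit-changing primitives compose, and every intermediate carries its derived unit. For example, kinetic energy `0.5 * m * v**2` built from `integer_pow` and `mul` comes out with the dimension of energy without manual bookkeeping. The masses and speeds below are made-up illustration values."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "mass = jnp.array([2., 4.]) * u.kilogram\n",
    "speed = jnp.array([3., 1.]) * u.meter / u.second\n",
    "\n",
    "# kg * (m/s)^2 = kg m^2 s^-2, the dimension of energy\n",
    "energy = 0.5 * u.lax.mul(mass, u.lax.integer_pow(speed, 2))\n",
    "print('kinetic energy:', energy)\n",
    "print('in joules:', energy.to_decimal(u.joule))  # [9, 2]"
   ],
   "outputs": [],
   "execution_count": null
  },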
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `rsqrt` — Reciprocal square root\n",
    "\n",
    "For an input with unit `U`, the result has unit `U^(-1/2)`; for example, `m^2` becomes `1/m`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.533354600Z",
     "start_time": "2026-03-04T15:10:46.445424900Z"
    }
   },
   "source": [
    "areas = jnp.array([4., 9., 16.]) * u.meter2\n",
    "print('rsqrt:', u.lax.rsqrt(areas))  # 1/m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rsqrt: [0.5        0.33333334 0.25      ] 1 / m\n"
     ]
    }
   ],
   "execution_count": 15
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `dot_general` — Generalized dot product"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.651389200Z",
     "start_time": "2026-03-04T15:10:46.535358700Z"
    }
   },
   "source": [
    "# Matrix multiplication via dot_general\n",
    "a = jnp.array([[1., 2.], [3., 4.]]) * u.meter\n",
    "b = jnp.array([[5., 6.], [7., 8.]]) * u.second\n",
    "\n",
    "# Contract over last axis of a and first axis of b\n",
    "result = u.lax.dot_general(a, b, (((1,), (0,)), ((), ())))\n",
    "print('dot_general (matmul):')\n",
    "print(result)  # m * s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dot_general (matmul):\n",
      "[[19. 22.]\n",
      " [43. 50.]] m * s\n"
     ]
    }
   ],
   "execution_count": 16
  },
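  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dimension numbers have the form `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`, following `jax.lax.dot_general`. Contracting axis 0 of two vectors gives an ordinary dot product; with a force and a displacement (illustrative values) the result is work, with unit N * m = J."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "force = jnp.array([3., 4.]) * u.newton\n",
    "disp = jnp.array([2., 1.]) * u.meter\n",
    "\n",
    "# Contract axis 0 of both vectors, no batch axes: 3*2 + 4*1 = 10\n",
    "work = u.lax.dot_general(force, disp, (((0,), (0,)), ((), ())))\n",
    "print('work:', work)\n",
    "print('in joules:', work.to_decimal(u.joule))"
   ],
   "outputs": [],
   "execution_count": null
  },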
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `batch_matmul` — Batched matrix multiplication"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.780381Z",
     "start_time": "2026-03-04T15:10:46.659336700Z"
    }
   },
   "source": [
    "# Batch of 2 matrices: (batch=2, rows=3, cols=4) @ (batch=2, rows=4, cols=2)\n",
    "A = jnp.ones((2, 3, 4)) * u.volt\n",
    "B = jnp.ones((2, 4, 2)) * u.ampere\n",
    "\n",
    "C = u.lax.batch_matmul(A, B)\n",
    "print('batch_matmul shape:', C.shape)  # (2, 3, 2)\n",
    "print('batch_matmul unit:', C.unit)    # V * A = W"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch_matmul shape: (2, 3, 2)\n",
      "batch_matmul unit: W\n"
     ]
    }
   ],
   "execution_count": 17
  },
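  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`batch_matmul` can be seen as a special case of `dot_general`: axis 0 of both operands is a batch axis, and the last axis of `A` is contracted with the middle axis of `B`. The sketch below checks that the two spellings agree, using non-constant inputs so that the comparison is meaningful."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "A = jnp.arange(24.).reshape(2, 3, 4) * u.volt\n",
    "B = jnp.arange(16.).reshape(2, 4, 2) * u.ampere\n",
    "\n",
    "# Contract A's axis 2 with B's axis 1; treat axis 0 of both as a batch axis\n",
    "dims = (((2,), (1,)), ((0,), (0,)))\n",
    "via_dot = u.lax.dot_general(A, B, dims)\n",
    "via_batch = u.lax.batch_matmul(A, B)\n",
    "\n",
    "print('shapes:', via_dot.shape, via_batch.shape)\n",
    "print('agree:', jnp.allclose(via_dot.to_decimal(u.watt), via_batch.to_decimal(u.watt)))"
   ],
   "outputs": [],
   "execution_count": null
  },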
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
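    "### `rem` — Remainder\n",
    "\n",
    "Unlike the other operations in this section, `rem` does not change the unit: the remainder has the same unit as its operands."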
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.862935700Z",
     "start_time": "2026-03-04T15:10:46.781390200Z"
    }
   },
   "source": [
    "a = jnp.array([7., 10., 15.]) * u.meter\n",
    "b = jnp.array([3., 4., 6.]) * u.meter\n",
    "print('remainder:', u.lax.rem(a, b))  # [1, 2, 3] m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "remainder: [1. 2. 3.] m\n"
     ]
    }
   ],
   "execution_count": 18
  },
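  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`rem` follows `jax.lax.rem`, whose result takes the sign of the dividend (like C's `fmod`). Python's `%` and `jnp.mod` instead take the sign of the divisor, so the two disagree for negative operands. The sketch compares them, applying `jnp.mod` to the mantissas."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "a = jnp.array([-7., 7.]) * u.meter\n",
    "b = jnp.array([3., -3.]) * u.meter\n",
    "\n",
    "r = u.lax.rem(a, b)                  # sign of the dividend: [-1, 1] m\n",
    "m = jnp.mod(a.mantissa, b.mantissa)  # sign of the divisor: [2, -2]\n",
    "print('u.lax.rem:', r)\n",
    "print('jnp.mod:', m)"
   ],
   "outputs": [],
   "execution_count": null
  },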
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Remove Unit (Comparisons)\n",
    "\n",
    "Comparison operations return dimensionless boolean arrays."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.060552800Z",
     "start_time": "2026-03-04T15:10:46.864931800Z"
    }
   },
   "source": [
    "a = jnp.array([1., 3., 5.]) * u.volt\n",
    "b = jnp.array([2., 3., 4.]) * u.volt\n",
    "\n",
    "print('eq:', u.lax.eq(a, b))    # [F, T, F]\n",
    "print('ne:', u.lax.ne(a, b))    # [T, F, T]\n",
    "print('lt:', u.lax.lt(a, b))    # [T, F, F]\n",
    "print('le:', u.lax.le(a, b))    # [T, T, F]\n",
    "print('gt:', u.lax.gt(a, b))    # [F, F, T]\n",
    "print('ge:', u.lax.ge(a, b))    # [F, T, T]"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "eq: [False  True False]\n",
      "ne: [ True False  True]\n",
      "lt: [ True False False]\n",
      "le: [ True  True False]\n",
      "gt: [False False  True]\n",
      "ge: [False  True  True]\n"
     ]
    }
   ],
   "execution_count": 19
  },
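  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because comparisons return plain boolean arrays, they combine directly with ordinary `jax.numpy` logic. As an illustration with a made-up membrane-potential trace and threshold, the sketch below finds upward threshold crossings by comparing each sample with its predecessor."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "v = jnp.array([-70., -55., -45., -60., -40.]) * u.mV  # made-up trace\n",
    "threshold = jnp.full((4,), -50.) * u.mV\n",
    "\n",
    "prev = u.lax.slice(v, (0,), (4,))  # samples 0..3\n",
    "curr = u.lax.slice(v, (1,), (5,))  # samples 1..4\n",
    "\n",
    "# Upward crossing: below threshold before, at or above it now\n",
    "crossing = jnp.logical_and(u.lax.lt(prev, threshold), u.lax.ge(curr, threshold))\n",
    "print('crossing:', crossing)\n",
    "print('crossing at samples:', jnp.nonzero(crossing)[0] + 1)"
   ],
   "outputs": [],
   "execution_count": null
  },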
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions Accepting Unitless Input\n",
    "\n",
    "Trigonometric and special functions require dimensionless inputs; passing a quantity that carries a physical unit raises an error."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.170556600Z",
     "start_time": "2026-03-04T15:10:47.061059300Z"
    }
   },
   "source": [
    "# Inverse trigonometric functions take dimensionless ratios and return angles in radians\n",
    "ratios = jnp.array([0.0, 0.5, 1.0])\n",
    "print('asin:', u.lax.asin(ratios))\n",
    "print('acos:', u.lax.acos(ratios))\n",
    "print('atan:', u.lax.atan(ratios))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "asin: [0.        0.5235988 1.5707964]\n",
      "acos: [1.5707964 1.0471976 0.       ]\n",
      "atan: [0.        0.4636476 0.7853982]\n"
     ]
    }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.275417Z",
     "start_time": "2026-03-04T15:10:47.172561800Z"
    }
   },
   "source": [
    "# Special functions\n",
    "x = jnp.array([0.0, 0.5, 1.0, 2.0])\n",
    "print('logistic:', u.lax.logistic(x))\n",
    "print('erf:', u.lax.erf(x))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logistic: [0.5        0.62245935 0.7310586  0.880797  ]\n",
      "erf: [0.         0.5204999  0.84270084 0.9953222 ]\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.431105400Z",
     "start_time": "2026-03-04T15:10:47.278435600Z"
    }
   },
   "source": [
    "# Bessel functions\n",
    "print('bessel_i0e:', u.lax.bessel_i0e(x))\n",
    "print('bessel_i1e:', u.lax.bessel_i1e(x))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bessel_i0e: [1.0000001  0.6450353  0.46575963 0.3085083 ]\n",
      "bessel_i1e: [0.         0.1564208  0.20791042 0.21526928]\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.444069Z",
     "start_time": "2026-03-04T15:10:47.432126900Z"
    }
   },
   "source": [
    "# Passing a quantity with units raises an error\n",
    "try:\n",
    "    u.lax.logistic(jnp.array([1.0, 2.0]) * u.volt)\n",
    "except Exception as e:\n",
    "    print(type(e).__name__, ':', e)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TypeError : logistic requires a dimensionless \"x\" when \"unit_to_scale\" is not provided. Got Quantity(unit=V, dim=m^2 kg s^-3 A^-1). Pass \"unit_to_scale=<Unit>\" to scale before applying logistic, or convert explicitly to a dimensionless value first.\n"
     ]
    }
   ],
   "execution_count": 23
  },
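  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The error message names two remedies: convert the quantity to a plain number in a chosen unit first (here with `Quantity.to_decimal`), or pass `unit_to_scale`. Either way, the numerical result depends on the reference unit, because 1 V is the number 1 in volts but 1000 in millivolts. The sketch below takes the `unit_to_scale` keyword from the error message above."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "v = jnp.array([1.0, 2.0]) * u.volt\n",
    "\n",
    "# Option 1: strip the unit explicitly, with volts as the reference unit\n",
    "y_explicit = u.lax.logistic(v.to_decimal(u.volt))\n",
    "\n",
    "# Option 2: the unit_to_scale keyword named in the error message above\n",
    "y_scaled = u.lax.logistic(v, unit_to_scale=u.volt)\n",
    "\n",
    "# A different reference unit changes the input numbers (1000, 2000) and the result\n",
    "y_mv = u.lax.logistic(v.to_decimal(u.mV))\n",
    "\n",
    "print('volts:', y_explicit)\n",
    "print('unit_to_scale:', y_scaled)\n",
    "print('millivolts:', y_mv)"
   ],
   "outputs": [],
   "execution_count": null
  },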
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LAX Linear Algebra\n",
    "\n",
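    "`brainunit.lax` also provides low-level linear algebra primitives. The unit of each result follows from the decomposition: a Cholesky factor carries the square root of the input unit, while in a QR decomposition `Q` is dimensionless and `R` carries the input unit."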
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.012711200Z",
     "start_time": "2026-03-04T15:10:47.444069Z"
    }
   },
   "source": [
    "# Cholesky decomposition (positive definite matrix)\n",
    "M = jnp.array([[4., 2.], [2., 3.]]) * u.meter2\n",
    "L = u.lax.cholesky(M)\n",
    "print('cholesky:')\n",
    "print(L)  # unit: m (sqrt of m^2)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cholesky:\n",
      "[[2.        0.       ]\n",
      " [1.        1.41421354]] m\n"
     ]
    }
   ],
   "execution_count": 24
  },
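  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check is that the factor reproduces the input, `L @ L.T = M`, including the unit: m * m = m^2. The sketch below forms `L @ L.T` with `dot_general` by contracting axis 1 of `L` with axis 1 of `L`, which avoids an explicit transpose."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "M = jnp.array([[4., 2.], [2., 3.]]) * u.meter2\n",
    "L = u.lax.cholesky(M)\n",
    "\n",
    "# (L @ L.T)[i, j] = sum_k L[i, k] * L[j, k]: contract axis 1 of L with axis 1 of L\n",
    "recon = u.lax.dot_general(L, L, (((1,), (1,)), ((), ())))\n",
    "print('L @ L.T:')\n",
    "print(recon)\n",
    "print('matches M:', jnp.allclose(recon.to_decimal(u.meter2), M.to_decimal(u.meter2), atol=1e-5))"
   ],
   "outputs": [],
   "execution_count": null
  },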
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.383906700Z",
     "start_time": "2026-03-04T15:10:48.262118300Z"
    }
   },
   "source": [
    "# QR decomposition\n",
    "A = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.newton\n",
    "Q, R = u.lax.qr(A)\n",
    "print('Q (orthogonal, dimensionless):')\n",
    "print(Q)\n",
    "print('R (upper triangular, with unit):')\n",
    "print(R)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q (orthogonal, dimensionless):\n",
      "[[-0.1690309   0.8970853   0.40824777]\n",
      " [-0.5070926   0.27602604 -0.8164966 ]\n",
      " [-0.84515435 -0.34503248  0.40824845]]\n",
      "R (upper triangular, with unit):\n",
      "[[-5.91607952 -7.4373579]\n",
      " [ 0.         0.82807958]\n",
      " [ 0.         0.       ]] N\n"
     ]
    }
   ],
   "execution_count": 25
  },
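  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A similar check works for QR: `Q @ R` should reproduce `A`, and `Q` should have orthonormal columns. The sketch uses `u.get_mantissa` to read `Q` as a plain array whether or not it is wrapped in a dimensionless `Quantity`, and reads `R` in newtons."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "A = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.newton\n",
    "Q, R = u.lax.qr(A)\n",
    "\n",
    "q = jnp.asarray(u.get_mantissa(Q))  # dimensionless factor as a plain array\n",
    "r = R.to_decimal(u.newton)          # upper-triangular factor, read in newtons\n",
    "\n",
    "print('Q @ R == A:', jnp.allclose(q @ r, A.to_decimal(u.newton), atol=1e-4))\n",
    "print('Q^T Q == I:', jnp.allclose(q.T @ q, jnp.eye(q.shape[1]), atol=1e-5))"
   ],
   "outputs": [],
   "execution_count": null
  },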
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.583705600Z",
     "start_time": "2026-03-04T15:10:48.407221300Z"
    }
   },
   "source": [
    "# Triangular solve: solve L @ x = b, where L is lower triangular.\n",
    "# jax.lax.linalg.triangular_solve expects b to have at least 2 dimensions,\n",
    "# so b is written as a (2, 1) column vector.\n",
    "L = jnp.array([[2., 0.], [1., 3.]]) * u.ohm\n",
    "b = jnp.array([[4.], [7.]]) * u.volt\n",
    "\n",
    "x = u.lax.triangular_solve(L, b, left_side=True, lower=True)\n",
    "print('triangular_solve:')\n",
    "print(x)  # values [[2.], [1.6667]]; physically the unit should be V / ohm = A"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Array Creation"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.949435400Z",
     "start_time": "2026-03-04T15:10:48.740237900Z"
    }
   },
   "source": [
    "# Create an index array (always dimensionless)\n",
    "idx = u.lax.iota(jnp.int32, 5)\n",
    "print('iota:', idx)\n",
    "\n",
    "# zeros_like_array\n",
    "template = jnp.array([1., 2., 3.]) * u.volt\n",
    "print('zeros_like:', u.lax.zeros_like_array(template))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iota: [0 1 2 3 4]\n",
      "zeros_like: [0. 0. 0.] V\n"
     ]
    }
   ],
   "execution_count": 27
  },
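  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `iota` returns a plain index array, a unit can be attached afterwards by multiplying with a quantity. A typical use is building a time axis from a step size; the 0.1 ms step below is an arbitrary example."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "dt = 0.1 * u.ms  # hypothetical simulation time step\n",
    "\n",
    "# iota gives 0, 1, ..., 4 as plain floats; multiplying by dt attaches the unit\n",
    "t = u.lax.iota(jnp.float32, 5) * dt\n",
    "print('time axis:', t)\n",
    "print('in ms:', t.to_decimal(u.ms))"
   ],
   "outputs": [],
   "execution_count": null
  },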
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Category | Functions | Unit Behavior |\n",
    "|----------|----------|---------------|\n",
    "| **Slicing and updates** | `slice`, `dynamic_slice`, `slice_in_dim`, `dynamic_update_slice` | Keep unit |\n",
    "| **Sorting** | `sort`, `top_k` | Keep unit (indices from `top_k` are unitless) |\n",
    "| **Cumulative** | `cumsum`, `cummin`, `cummax` | Keep unit |\n",
    "| **Layout** | `pad`, `broadcast` | Keep unit |\n",
    "| **Elementwise** | `clamp`, `neg`, `rem`, `convert_element_type` | Keep unit |\n",
    "| **Arithmetic** | `mul`, `div`, `integer_pow`, `rsqrt` | Change unit |\n",
    "| **Products** | `dot_general`, `batch_matmul`, `conv` | Change unit |\n",
    "| **Comparisons** | `eq`, `ne`, `lt`, `le`, `gt`, `ge` | Remove unit |\n",
    "| **Trig/Special** | `asin`, `acos`, `atan`, `logistic`, `erf`, `bessel_i0e`, `bessel_i1e` | Require unitless |\n",
    "| **Linalg** | `cholesky`, `qr`, `svd`, `triangular_solve` | Varies |\n",
    "| **Array creation** | `iota`, `zeros_like_array` | `iota` is unitless; `zeros_like_array` keeps unit |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
