{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear Algebra Functions\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/mathematical_functions/linalg_functions.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/mathematical_functions/linalg_functions.ipynb)\n",
    "\n",
    "`brainunit.linalg` provides unit-aware linear algebra functions. These functions automatically\n",
    "track how units propagate through linear algebra operations:\n",
    "\n",
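    "- **Keeping unit**: `trace`, `diagonal`, `norm`, `matrix_norm`, `matrix_transpose`\n",
    "- **Changing unit**: `dot`, `matmul`, `cross`, `det`, `inv`, `pinv`, `solve`, `lstsq`, `cholesky`, `kron`, `matrix_power`\n",
    "- **Removing unit**: `eig`/`eigh`, `svd`, and `qr` return dimensionless orthogonal factors or eigenvectors, while the eigenvalues, singular values, or `R` factor carry the unit; `cond`, `matrix_rank`, and `slogdet` return unitless numbers"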
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.574876300Z",
     "start_time": "2026-03-04T15:10:45.616495Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Keep Unit\n",
    "\n",
    "These functions preserve the input unit in the output."
   ]
  },
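  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see why these functions keep the unit: each is homogeneous of degree 1, meaning `f(c * A) == c * f(A)`\n",
    "for a positive scale factor `c`. Changing the unit of `A` rescales its numbers by some `c`, and the result rescales\n",
    "by the same `c`, so it must carry the same unit. The sketch below checks this with plain `jax.numpy` arrays; the\n",
    "factor `c = 1000.0` is an arbitrary stand-in for a unit conversion such as V to mV, not part of `brainunit`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Arbitrary positive factor standing in for a unit conversion (e.g. V -> mV)\n",
    "c = 1000.0\n",
    "A_plain = jnp.array([[1., 2.], [3., 4.]])\n",
    "\n",
    "# Degree-1 homogeneity: f(c * A) / f(A) == c, so the output keeps A's unit\n",
    "trace_ratio = jnp.trace(c * A_plain) / jnp.trace(A_plain)\n",
    "norm_ratio = jnp.linalg.norm(c * A_plain) / jnp.linalg.norm(A_plain)\n",
    "diag_ratio = jnp.diagonal(c * A_plain) / jnp.diagonal(A_plain)\n",
    "print('trace ratio:', trace_ratio)\n",
    "print('norm ratio:', norm_ratio)\n",
    "print('diagonal ratio:', diag_ratio)"
   ],
   "outputs": [],
   "execution_count": null
  },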
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `trace` — Sum of diagonal elements"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:46.937801100Z",
     "start_time": "2026-03-04T15:10:46.575880700Z"
    }
   },
   "source": [
    "A = jnp.array([[1., 2.], [3., 4.]]) * u.volt\n",
    "print('A:')\n",
    "print(A)\n",
    "print('trace(A):', u.linalg.trace(A))  # 1 + 4 = 5 V"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "A:\n",
      "[[1. 2.]\n",
      " [3. 4.]] V\n",
      "trace(A): 5. V\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `diagonal` — Extract diagonal elements"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.031279300Z",
     "start_time": "2026-03-04T15:10:46.938798600Z"
    }
   },
   "source": [
    "print('diagonal(A):', u.linalg.diagonal(A))  # [1, 4] V"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "diagonal(A): [1. 4.] V\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `norm` — Vector and matrix norms"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.158509600Z",
     "start_time": "2026-03-04T15:10:47.031279300Z"
    }
   },
   "source": [
    "v = jnp.array([3., 4.]) * u.meter\n",
    "print('v:', v)\n",
    "print('L2 norm:', u.linalg.norm(v))          # sqrt(9+16) = 5 m\n",
    "print('L1 norm:', u.linalg.norm(v, ord=1))   # 3+4 = 7 m\n",
    "print('Inf norm:', u.linalg.norm(v, ord=jnp.inf))  # max(3,4) = 4 m"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v: [3. 4.] m\n",
      "L2 norm: 5. m\n",
      "L1 norm: 7. m\n",
      "Inf norm: 4. m\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.275417Z",
     "start_time": "2026-03-04T15:10:47.158509600Z"
    }
   },
   "source": [
    "# Matrix norms\n",
    "print('Frobenius norm:', u.linalg.norm(A))               # sqrt(1+4+9+16) V\n",
    "print('Matrix L2 norm:', u.linalg.matrix_norm(A, ord=2)) # largest singular value"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Frobenius norm: 5.477226 V\n",
      "Matrix L2 norm: 5.464986 V\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `matrix_transpose` — Transpose"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.367642800Z",
     "start_time": "2026-03-04T15:10:47.278435600Z"
    }
   },
   "source": [
    "B = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.ampere\n",
    "print('B shape:', B.shape)\n",
    "print('B^T shape:', u.linalg.matrix_transpose(B).shape)\n",
    "print('B^T:')\n",
    "print(u.linalg.matrix_transpose(B))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "B shape: (2, 3)\n",
      "B^T shape: (3, 2)\n",
      "B^T:\n",
      "[[1. 4.]\n",
      " [2. 5.]\n",
      " [3. 6.]] A\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Change Unit\n",
    "\n",
    "These functions produce outputs whose unit differs from the inputs' units and follows from the algebra\n",
    "of the operation: products multiply units, inverses invert them, and square roots halve their exponents."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `dot`, `matmul` — Dot and matrix products\n",
    "\n",
    "When multiplying quantities, units multiply: `[meter] @ [second] → [meter * second]`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.423301600Z",
     "start_time": "2026-03-04T15:10:47.373596Z"
    }
   },
   "source": [
    "# Vector dot product\n",
    "a = jnp.array([1., 2., 3.]) * u.meter\n",
    "b = jnp.array([4., 5., 6.]) * u.newton\n",
    "print('a . b:', u.linalg.dot(a, b))  # 1*4 + 2*5 + 3*6 = 32 m*N (= 32 J)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a . b: 32. J\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.463650900Z",
     "start_time": "2026-03-04T15:10:47.424315Z"
    }
   },
   "source": [
    "# Matrix-vector multiplication\n",
    "M = jnp.array([[1., 0.], [0., 2.]]) * u.ohm\n",
    "i = jnp.array([3., 4.]) * u.ampere\n",
    "print('M @ i:', u.linalg.matmul(M, i))  # [3, 8] V  (Ohm's law: V = IR)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M @ i: [3. 8.] V\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:47.887475200Z",
     "start_time": "2026-03-04T15:10:47.464660600Z"
    }
   },
   "source": [
    "# Matrix-matrix multiplication: units multiply\n",
    "P = jnp.array([[1., 2.], [3., 4.]]) * u.meter\n",
    "Q = jnp.array([[5., 6.], [7., 8.]]) * u.meter\n",
    "print('P @ Q:')\n",
    "print(u.linalg.matmul(P, Q))  # m * m = m^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P @ Q:\n",
      "[[19. 22.]\n",
      " [43. 50.]] m^2\n"
     ]
    }
   ],
   "execution_count": 9
  },
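  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Units multiply in `dot` and `matmul` because these products are bilinear: scaling the left operand by `c` and the\n",
    "right operand by `d` scales the result by `c * d`. The plain `jax.numpy` check below reuses the numbers of `P` and `Q`;\n",
    "the factors `c` and `d` are arbitrary illustration values."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "P_plain = jnp.array([[1., 2.], [3., 4.]])\n",
    "Q_plain = jnp.array([[5., 6.], [7., 8.]])\n",
    "c, d = 2.0, 3.0  # arbitrary scale factors for the two operands\n",
    "\n",
    "# Bilinearity: (c P) @ (d Q) == (c d) (P @ Q), so the result carries unit(P) * unit(Q)\n",
    "lhs = (c * P_plain) @ (d * Q_plain)\n",
    "rhs = c * d * (P_plain @ Q_plain)\n",
    "print(lhs)\n",
    "print(jnp.allclose(lhs, rhs))"
   ],
   "outputs": [],
   "execution_count": null
  },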
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `cross` — Cross product"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.101005900Z",
     "start_time": "2026-03-04T15:10:47.935199100Z"
    }
   },
   "source": [
    "# Lorentz force F = q * (v x B); this cell computes only the v x B part\n",
    "velocity = jnp.array([1., 0., 0.]) * u.meter / u.second\n",
    "b_field = jnp.array([0., 0., 1.]) * u.tesla\n",
    "print('v x B:', u.linalg.cross(velocity, b_field))  # [0, -1, 0] m*T/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v x B: [ 0. -1.  0.] m * T / s\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `det` — Determinant\n",
    "\n",
    "For an NxN matrix whose entries have unit `u`, every term of the determinant is a product of N entries, so the determinant has unit `u^N`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.161050200Z",
     "start_time": "2026-03-04T15:10:48.101993200Z"
    }
   },
   "source": [
    "M2 = jnp.array([[2., 1.], [1., 3.]]) * u.meter\n",
    "print('det(M2):', u.linalg.det(M2))  # 2*3 - 1*1 = 5 m^2"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "det(M2): 5. m^2\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.238338200Z",
     "start_time": "2026-03-04T15:10:48.162050100Z"
    }
   },
   "source": [
    "# 3x3 determinant: unit^3\n",
    "M3 = jnp.array([[1., 0., 2.], [0., 1., 0.], [2., 0., 1.]]) * u.second\n",
    "print('det(M3):', u.linalg.det(M3))  # s^3"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "det(M3): -3. s^3\n"
     ]
    }
   ],
   "execution_count": 12
  },
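  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `u^N` rule can also be checked numerically: for an NxN matrix, `det(c * M) == c**N * det(M)`. The sketch below\n",
    "uses plain `jax.numpy` with the same entries as `M2` and `M3`; the factor `c = 10.0` is an arbitrary stand-in for a unit change."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "c = 10.0  # arbitrary stand-in for a unit conversion factor\n",
    "M2_plain = jnp.array([[2., 1.], [1., 3.]])\n",
    "M3_plain = jnp.array([[1., 0., 2.], [0., 1., 0.], [2., 0., 1.]])\n",
    "\n",
    "# det(c * M) / det(M) == c**N for an N x N matrix, hence the unit u**N\n",
    "ratio2 = jnp.linalg.det(c * M2_plain) / jnp.linalg.det(M2_plain)\n",
    "ratio3 = jnp.linalg.det(c * M3_plain) / jnp.linalg.det(M3_plain)\n",
    "print('2x2 ratio:', ratio2)  # close to c**2\n",
    "print('3x3 ratio:', ratio3)  # close to c**3"
   ],
   "outputs": [],
   "execution_count": null
  },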
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `inv` — Matrix inverse\n",
    "\n",
    "The inverse of a matrix with unit `u` has unit `1/u`, because `A @ inv(A)` must equal the dimensionless identity matrix."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.411736400Z",
     "start_time": "2026-03-04T15:10:48.239338600Z"
    }
   },
   "source": [
    "R = jnp.array([[2., 1.], [1., 3.]]) * u.ohm\n",
    "print('R:')\n",
    "print(R)\n",
    "print('inv(R):')\n",
    "print(u.linalg.inv(R))  # unit: 1/ohm = siemens"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "R:\n",
      "[[2. 1.]\n",
      " [1. 3.]] ohm\n",
      "inv(R):\n",
      "[[ 0.60000002 -0.2]\n",
      " [-0.2  0.40000001]] S\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.455585900Z",
     "start_time": "2026-03-04T15:10:48.412738300Z"
    }
   },
   "source": [
    "# Verify: R @ inv(R) should be identity (dimensionless)\n",
    "R_inv = u.linalg.inv(R)\n",
    "print('R @ R_inv:')\n",
    "print(u.linalg.matmul(R, R_inv))  # approximately identity"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "R @ R_inv:\n",
      "[[1.0000000e+00 0.0000000e+00]\n",
      " [1.4901161e-08 1.0000000e+00]]\n"
     ]
    }
   ],
   "execution_count": 14
  },
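  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `1/u` rule corresponds to degree -1 homogeneity: `inv(c * R) == inv(R) / c`. A plain `jax.numpy` sketch with the\n",
    "numbers of `R`, where `c` is an arbitrary illustration factor:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "c = 1000.0  # arbitrary stand-in for a unit conversion (e.g. ohm -> milliohm)\n",
    "R_plain = jnp.array([[2., 1.], [1., 3.]])\n",
    "\n",
    "# Degree -1: inv(c * R) * c == inv(R), so the inverse carries 1/unit(R)\n",
    "inv_R = jnp.linalg.inv(R_plain)\n",
    "rescaled = jnp.linalg.inv(c * R_plain) * c\n",
    "print(inv_R)\n",
    "print(jnp.allclose(rescaled, inv_R, rtol=1e-4))"
   ],
   "outputs": [],
   "execution_count": null
  },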
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `solve` — Solve linear system Ax = b\n",
    "\n",
    "If A has unit `u_A` and b has unit `u_b`, then x has unit `u_b / u_A`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.785997700Z",
     "start_time": "2026-03-04T15:10:48.456859800Z"
    }
   },
   "source": [
    "# Physical example: Ohm's law in a circuit network\n",
    "# R * I = V  -->  I = solve(R, V)\n",
    "R_matrix = jnp.array([[10., 2.], [2., 8.]]) * u.ohm\n",
    "V_vector = jnp.array([12., 6.]) * u.volt\n",
    "\n",
    "I_solution = u.linalg.solve(R_matrix, V_vector)\n",
    "print('Current solution:', I_solution)  # volt / ohm = ampere"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current solution: [1.10526311 0.47368422] A\n"
     ]
    }
   ],
   "execution_count": 15
  },
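  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `u_b / u_A` rule can be read off from scaling: if `A` is multiplied by `c` and `b` by `d`, the solution of\n",
    "`A x = b` is multiplied by `d / c`. The sketch below checks this with plain `jax.numpy` on the circuit numbers above;\n",
    "`c` and `d` are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "R_plain = jnp.array([[10., 2.], [2., 8.]])\n",
    "V_plain = jnp.array([12., 6.])\n",
    "c, d = 2.0, 5.0  # arbitrary scale factors for A and b\n",
    "\n",
    "x = jnp.linalg.solve(R_plain, V_plain)\n",
    "x_scaled = jnp.linalg.solve(c * R_plain, d * V_plain)\n",
    "# x_scaled == (d / c) * x, so x carries unit(b) / unit(A)\n",
    "print(x)\n",
    "print(jnp.allclose(x_scaled, (d / c) * x, rtol=1e-4))"
   ],
   "outputs": [],
   "execution_count": null
  },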
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `cholesky` — Cholesky decomposition\n",
    "\n",
    "For a matrix with unit `u`, the Cholesky factor has unit `sqrt(u)`,\n",
    "since `L @ L^T = A` and `sqrt(u) * sqrt(u) = u`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:48.996612800Z",
     "start_time": "2026-03-04T15:10:48.830065800Z"
    }
   },
   "source": [
    "# Positive definite matrix\n",
    "S = jnp.array([[4., 2.], [2., 5.]]) * u.meter2\n",
    "L = u.linalg.cholesky(S)\n",
    "print('S:')\n",
    "print(S)\n",
    "print('cholesky(S):')\n",
    "print(L)  # unit: m (sqrt of m^2)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S:\n",
      "[[4. 2.]\n",
      " [2. 5.]] m^2\n",
      "cholesky(S):\n",
      "[[2. 0.]\n",
      " [1. 2.]] m\n"
     ]
    }
   ],
   "execution_count": 16
  },
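  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scaling gives the same answer: for a positive factor `c`, `cholesky(c * S) == sqrt(c) * cholesky(S)`, so the factor\n",
    "has unit `sqrt(u)`. A plain `jax.numpy` check with the entries of `S` (`c = 4.0` is arbitrary and chosen so the square\n",
    "root is exact):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "S_plain = jnp.array([[4., 2.], [2., 5.]])\n",
    "c = 4.0  # arbitrary positive scale factor\n",
    "\n",
    "L_plain = jnp.linalg.cholesky(S_plain)  # lower-triangular factor\n",
    "L_scaled = jnp.linalg.cholesky(c * S_plain)\n",
    "print(L_plain)\n",
    "# L @ L.T reconstructs S, and the factor scales by sqrt(c)\n",
    "print(jnp.allclose(L_plain @ L_plain.T, S_plain))\n",
    "print(jnp.allclose(L_scaled, jnp.sqrt(c) * L_plain))"
   ],
   "outputs": [],
   "execution_count": null
  },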
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `kron` — Kronecker product"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:49.287883800Z",
     "start_time": "2026-03-04T15:10:49.033177600Z"
    }
   },
   "source": [
    "X = jnp.array([[1., 0.], [0., 1.]]) * u.meter\n",
    "Y = jnp.array([[2., 3.], [4., 5.]]) * u.second\n",
    "print('kron(X, Y):')\n",
    "print(u.linalg.kron(X, Y))  # m * s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kron(X, Y):\n",
      "[[2. 3. 0. 0.]\n",
      " [4. 5. 0. 0.]\n",
      " [0. 0. 2. 3.]\n",
      " [0. 0. 4. 5.]] m * s\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `matrix_power` — Raise matrix to integer power"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:49.513263Z",
     "start_time": "2026-03-04T15:10:49.330786900Z"
    }
   },
   "source": [
    "T = jnp.array([[1., 1.], [0., 1.]]) * u.meter\n",
    "print('T^2:', u.linalg.matrix_power(T, 2))   # m^2\n",
    "print('T^3:', u.linalg.matrix_power(T, 3))   # m^3"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "T^2: [[1. 2.]\n",
      " [0. 1.]] m^2\n",
      "T^3: [[1. 3.]\n",
      " [0. 1.]] m^3\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `lstsq` — Least-squares solution"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.303028200Z",
     "start_time": "2026-03-04T15:10:49.694483200Z"
    }
   },
   "source": [
    "# Fit y = a*x + b, where x is in seconds and y is in meters\n",
    "# Design matrix has columns [x, 1]\n",
    "x_data = jnp.array([0., 1., 2., 3., 4.]) * u.second\n",
    "y_data = jnp.array([1.1, 2.9, 5.2, 6.8, 9.1]) * u.meter\n",
    "\n",
    "# The design matrix must carry a single unit (or none), so the constant\n",
    "# column is also tagged with seconds, i.e. it holds the value '1 s'.\n",
    "A_design = jnp.stack([x_data.mantissa, jnp.ones(5)], axis=1) * u.second\n",
    "result = u.linalg.lstsq(A_design, y_data)\n",
    "# Both coefficients are reported in m/s: the slope is a = 1.99 m/s, and the\n",
    "# intercept is b = 1.04 m/s * 1 s = 1.04 m.\n",
    "print('Coefficients:', result[0])  # [a, b / (1 s)], both in m/s"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Coefficients: [1.99000037 1.03999949] m / s\n"
     ]
    }
   ],
   "execution_count": 19
  },
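  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An alternative that avoids tagging the constant column with seconds is to fit on plain numbers and attach a unit\n",
    "to each coefficient afterwards: the slope gets the unit of `y / x` and the intercept the unit of `y`. This is a\n",
    "sketch of a manual workaround using `jax.numpy.linalg.lstsq`, not a `brainunit` feature, and it assumes the data\n",
    "are already expressed in seconds and meters."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "t_s = jnp.array([0., 1., 2., 3., 4.])       # time values, in seconds\n",
    "y_m = jnp.array([1.1, 2.9, 5.2, 6.8, 9.1])  # positions, in meters\n",
    "\n",
    "# Dimensionless design matrix with columns [t, 1]\n",
    "A_plain = jnp.stack([t_s, jnp.ones_like(t_s)], axis=1)\n",
    "coef, residuals, rank, sv = jnp.linalg.lstsq(A_plain, y_m)\n",
    "\n",
    "# Attach units column by column: slope in m/s, intercept in m\n",
    "slope = coef[0] * u.meter / u.second\n",
    "intercept = coef[1] * u.meter\n",
    "print('slope:', slope)\n",
    "print('intercept:', intercept)"
   ],
   "outputs": [],
   "execution_count": null
  },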
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `pinv` — Pseudoinverse"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.703255900Z",
     "start_time": "2026-03-04T15:10:50.586285300Z"
    }
   },
   "source": [
    "M_rect = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.volt\n",
    "print('M shape:', M_rect.shape)\n",
    "print('pinv(M) shape:', u.linalg.pinv(M_rect).shape)\n",
    "print('pinv(M):')\n",
    "print(u.linalg.pinv(M_rect))  # unit: 1/V"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "M shape: (3, 2)\n",
      "pinv(M) shape: (2, 3)\n",
      "pinv(M):\n",
      "[[-1.33333194 -0.33333257  0.66666567]\n",
      " [ 1.08333218  0.33333272 -0.41666582]] 1 / V\n"
     ]
    }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Functions That Remove Unit (Decompositions)\n",
    "\n",
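    "Decompositions split a matrix into factors. The orthogonal factors and eigenvectors (`U`, `Vt`, `Q`) are dimensionless,\n",
    "while the singular values, eigenvalues, or triangular factor `R` carry the unit of the input. Scalar summaries such as\n",
    "`cond`, `matrix_rank`, and `slogdet` return unitless numbers."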
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `svd` — Singular Value Decomposition"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:50.809296600Z",
     "start_time": "2026-03-04T15:10:50.704286800Z"
    }
   },
   "source": [
    "M_svd = jnp.array([[1., 2.], [3., 4.]]) * u.meter\n",
    "U, S, Vt = u.linalg.svd(M_svd)\n",
    "print('U (dimensionless):', U)  \n",
    "print('S (with unit):', S)  # singular values carry the unit\n",
    "print('Vt (dimensionless):', Vt)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "U (dimensionless): [[-0.40455365 -0.9145144 ]\n",
      " [-0.9145144   0.40455353]]\n",
      "S (with unit): [5.46498537 0.36596605] m\n",
      "Vt (dimensionless): [[-0.5760485  -0.81741554]\n",
      " [ 0.81741554 -0.5760485 ]]\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `eig` / `eigh` — Eigenvalue decomposition"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:51.297364500Z",
     "start_time": "2026-03-04T15:10:50.849627200Z"
    }
   },
   "source": [
    "# Symmetric matrix: eigh is specialized for symmetric (Hermitian) input, is more\n",
    "# stable than eig there, and returns real eigenvalues in ascending order\n",
    "H = jnp.array([[2., 1.], [1., 3.]]) * u.joule\n",
    "eigenvalues, eigenvectors = u.linalg.eigh(H)\n",
    "print('Eigenvalues:', eigenvalues)    # J\n",
    "print('Eigenvectors (dimensionless):')\n",
    "print(eigenvectors)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eigenvalues: [1.38196611 3.61803389] J\n",
      "Eigenvectors (dimensionless):\n",
      "[[-0.85065085  0.52573115]\n",
      " [ 0.52573115  0.85065085]]\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `qr` — QR decomposition"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:51.532137900Z",
     "start_time": "2026-03-04T15:10:51.300451700Z"
    }
   },
   "source": [
    "M_qr = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) * u.meter\n",
    "Q_mat, R_mat = u.linalg.qr(M_qr)\n",
    "print('Q (dimensionless):')\n",
    "print(Q_mat)\n",
    "print('R (with unit):')\n",
    "print(R_mat)  # upper triangular, carries the unit"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q (dimensionless):\n",
      "[[-0.1690309   0.8970853 ]\n",
      " [-0.5070926   0.27602604]\n",
      " [-0.84515435 -0.34503248]]\n",
      "R (with unit):\n",
      "[[-5.91607952 -7.4373579]\n",
      " [ 0.         0.82807958]] m\n"
     ]
    }
   ],
   "execution_count": 23
  },
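  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split can be checked by scaling: multiplying a matrix by a positive factor `c` multiplies its singular values and\n",
    "eigenvalues by `c`, while the singular vectors and eigenvectors stay the same up to sign. A plain `jax.numpy` sketch\n",
    "using the entries of `M_svd` and `H` (`c = 8.0` is arbitrary):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "c = 8.0  # arbitrary positive scale factor\n",
    "M_plain = jnp.array([[1., 2.], [3., 4.]])\n",
    "H_plain = jnp.array([[2., 1.], [1., 3.]])\n",
    "\n",
    "# SVD: singular values scale by c (they carry the unit); U does not (up to sign)\n",
    "U1, S1, Vt1 = jnp.linalg.svd(M_plain)\n",
    "U2, S2, Vt2 = jnp.linalg.svd(c * M_plain)\n",
    "\n",
    "# eigh: eigenvalues scale by c; eigenvectors do not (up to sign)\n",
    "w1, V1 = jnp.linalg.eigh(H_plain)\n",
    "w2, V2 = jnp.linalg.eigh(c * H_plain)\n",
    "\n",
    "print('singular value ratio:', S2 / S1)\n",
    "print('eigenvalue ratio:', w2 / w1)\n",
    "print('|U| unchanged:', jnp.allclose(jnp.abs(U1), jnp.abs(U2), atol=1e-4))\n",
    "print('|V| unchanged:', jnp.allclose(jnp.abs(V1), jnp.abs(V2), atol=1e-4))"
   ],
   "outputs": [],
   "execution_count": null
  },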
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `cond` — Condition number"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:52.432038600Z",
     "start_time": "2026-03-04T15:10:51.617363500Z"
    }
   },
   "source": [
    "# Condition number is always dimensionless (ratio of singular values)\n",
    "well_cond = jnp.array([[1., 0.], [0., 1.]]) * u.meter\n",
    "ill_cond = jnp.array([[1., 1.], [1., 1.0001]]) * u.meter\n",
    "print('Well-conditioned:', u.linalg.cond(well_cond))\n",
    "print('Ill-conditioned:', u.linalg.cond(ill_cond))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Well-conditioned: 1.0\n",
      "Ill-conditioned: 39949.367\n"
     ]
    }
   ],
   "execution_count": 24
  },
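  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cond` and `matrix_rank` are homogeneous of degree 0: rescaling a matrix by any nonzero factor `c` leaves them\n",
    "unchanged, which is why their output has no unit. A plain `jax.numpy` sketch (`c = 1000.0` is arbitrary; the\n",
    "matrix `[[1, 2], [3, 4]]` is used for `cond` because its moderate condition number keeps float32 rounding small):"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "c = 1000.0  # arbitrary stand-in for a unit conversion\n",
    "full_rank_plain = jnp.array([[1., 2.], [3., 4.]])\n",
    "rank_def_plain = jnp.array([[1., 2.], [2., 4.]])\n",
    "\n",
    "# Degree 0: the condition number and rank do not change under rescaling\n",
    "cond_ratio = jnp.linalg.cond(c * full_rank_plain) / jnp.linalg.cond(full_rank_plain)\n",
    "print('cond ratio:', cond_ratio)  # close to 1\n",
    "print('ranks:', jnp.linalg.matrix_rank(rank_def_plain), jnp.linalg.matrix_rank(c * rank_def_plain))"
   ],
   "outputs": [],
   "execution_count": null
  },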
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `matrix_rank`"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.047505400Z",
     "start_time": "2026-03-04T15:10:52.876815100Z"
    }
   },
   "source": [
    "full_rank = jnp.array([[1., 2.], [3., 4.]]) * u.meter\n",
    "rank_def = jnp.array([[1., 2.], [2., 4.]]) * u.meter  # row 2 = 2 * row 1\n",
    "print('Full rank matrix rank:', u.linalg.matrix_rank(full_rank))\n",
    "print('Rank-deficient matrix rank:', u.linalg.matrix_rank(rank_def))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full rank matrix rank: 2\n",
      "Rank-deficient matrix rank: 1\n"
     ]
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
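    "### `slogdet` — Sign and log of determinant\n",
    "\n",
    "The sign is a pure number. `logabsdet` is the natural log of `|det|` taken as a plain number in the matrix's unit\n",
    "to the N-th power (here `det = 5 m^2`, so `logabsdet = ln(5)`); rescaling every entry by a factor `k` shifts it by `N * ln(k)`."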
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.243565600Z",
     "start_time": "2026-03-04T15:10:53.092652700Z"
    }
   },
   "source": [
    "M_det = jnp.array([[2., 1.], [1., 3.]]) * u.meter\n",
    "sign, logabsdet = u.linalg.slogdet(M_det)\n",
    "print('Sign:', sign)\n",
    "print('Log |det|:', logabsdet)  # ln(5), a plain number"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sign: 1.0\n",
      "Log |det|: 1.609438\n"
     ]
    }
   ],
   "execution_count": 26
  },
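  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the logarithm is applied to a plain number, `logabsdet` depends on the unit the entries are written in:\n",
    "rescaling every entry of an NxN matrix by `k` shifts it by `N * ln(k)`. The sketch below uses plain `jax.numpy` and\n",
    "writes the same 2x2 matrix in meters and in centimeters (`k = 100`)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "M_m = jnp.array([[2., 1.], [1., 3.]])  # entries in meters\n",
    "M_cm = 100.0 * M_m                     # the same matrix, entries in centimeters\n",
    "\n",
    "sign_m, log_m = jnp.linalg.slogdet(M_m)\n",
    "sign_cm, log_cm = jnp.linalg.slogdet(M_cm)\n",
    "# Same sign, but log|det| shifts by N * ln(100) with N = 2\n",
    "shift = log_cm - log_m\n",
    "print('log|det| in m^2:', log_m)\n",
    "print('log|det| in cm^2:', log_cm)\n",
    "print('shift:', shift, 'expected:', 2 * jnp.log(100.0))"
   ],
   "outputs": [],
   "execution_count": null
  },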
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Example: Solving a Physical System\n",
    "\n",
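    "Applying Kirchhoff's current law at each node of a resistor network gives a linear system `G @ V = I`, where `G` is the\n",
    "conductance matrix (siemens), `V` the node voltages, and `I` the net current injected at each node."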
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:10:53.670187900Z",
     "start_time": "2026-03-04T15:10:53.251558100Z"
    }
   },
   "source": [
    "# Three-node resistor network\n",
    "# Conductance matrix G (siemens) and current sources I (ampere)\n",
    "G = jnp.array([\n",
    "    [0.3, -0.1, -0.1],\n",
    "    [-0.1, 0.4, -0.2],\n",
    "    [-0.1, -0.2, 0.5]\n",
    "]) * u.siemens\n",
    "\n",
    "I_source = jnp.array([1.0, 0.0, -0.5]) * u.ampere\n",
    "\n",
    "# Solve for node voltages: G @ V = I  -->  V = solve(G, I)\n",
    "V_nodes = u.linalg.solve(G, I_source)\n",
    "print('Node voltages:', V_nodes)  # ampere / siemens = volt\n",
    "\n",
    "# Verify: G @ V should equal I\n",
    "print('Verification (G @ V):', u.linalg.matmul(G, V_nodes))"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node voltages: [3.71428561 1.00000012 0.14285721] V\n",
      "Verification (G @ V): [ 9.9999994e-01  3.3101866e-08 -5.0000000e-01] A\n"
     ]
    }
   ],
   "execution_count": 27
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Function | Unit Behavior | Example |\n",
    "|----------|--------------|--------|\n",
    "| `trace`, `diagonal`, `norm`, `matrix_transpose` | Keeps input unit | `trace([m]) → m` |\n",
    "| `dot`, `matmul`, `cross`, `kron` | Multiplies units | `[m] @ [s] → m*s` |\n",
    "| `det` | Unit^N for NxN | `det([m]_{2x2}) → m^2` |\n",
    "| `matrix_power(A, k)` | Unit^k | `matrix_power([m], 3) → m^3` |\n",
    "| `inv`, `pinv` | Reciprocal unit | `inv([ohm]) → 1/ohm` |\n",
    "| `solve(A, b)`, `lstsq(A, b)` | b_unit / A_unit | `solve([ohm], [V]) → A` |\n",
    "| `cholesky` | sqrt(unit) | `chol([m^2]) → m` |\n",
    "| `svd`, `eig`/`eigh`, `qr` | One factor carries the unit | `S`, eigenvalues, `R` have unit; `U`, `Vt`, `Q` dimensionless |\n",
    "| `cond`, `matrix_rank`, `slogdet` | Dimensionless | Always unitless output |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
