{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sparse Matrices with Units\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainunit/blob/master/docs/sparse_matrices/sparse_matrices.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainunit/blob/master/docs/sparse_matrices/sparse_matrices.ipynb)\n",
    "\n",
    "`brainunit.sparse` provides unit-aware sparse matrix classes built on top of JAX's sparse\n",
    "representations. Sparse matrices store only the non-zero elements and their positions, which saves\n",
    "memory for large matrices that are mostly zeros (common in scientific computing, e.g., connectivity matrices).\n",
    "\n",
    "Available formats:\n",
    "- **CSR** (Compressed Sparse Row) — efficient for row slicing and matrix-vector products\n",
    "- **CSC** (Compressed Sparse Column) — efficient for column slicing\n",
    "- **COO** (Coordinate) — efficient for constructing sparse matrices"
   ]
  },
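  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put the memory savings in numbers, here is a back-of-the-envelope sketch comparing a dense `float32` matrix with CSR storage.\n",
    "It assumes 4-byte values, 4-byte (`int32`) indices, and a hypothetical 10,000 x 10,000 matrix with 1% non-zeros.\n",
    "The actual overhead in JAX and `brainunit.sparse` may differ. Under these assumptions, CSR costs 8 bytes per stored element,\n",
    "while dense storage costs 4 bytes per entry, so CSR only saves memory below roughly 50% density."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Back-of-the-envelope memory estimate (assumes float32 values and int32 indices)\n",
    "n_rows, n_cols = 10_000, 10_000\n",
    "nse = n_rows * n_cols // 100  # hypothetical 1% density: number of stored elements\n",
    "\n",
    "bytes_per_value = 4  # float32\n",
    "bytes_per_index = 4  # int32\n",
    "\n",
    "dense_bytes = n_rows * n_cols * bytes_per_value\n",
    "# CSR: one value + one column index per stored element, plus n_rows + 1 row pointers\n",
    "csr_bytes = nse * (bytes_per_value + bytes_per_index) + (n_rows + 1) * bytes_per_index\n",
    "\n",
    "print(f'Dense: {dense_bytes / 1e6:.1f} MB')  # 400.0 MB\n",
    "print(f'CSR:   {csr_bytes / 1e6:.2f} MB')    # 8.04 MB"
   ],
   "outputs": [],
   "execution_count": null
  },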
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:12.839360800Z",
     "start_time": "2026-03-04T15:11:11.405758200Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp"
   ],
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Sparse Matrices from Dense\n",
    "\n",
    "The simplest way to create a sparse matrix is from a dense `Quantity` array."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:13.486142700Z",
     "start_time": "2026-03-04T15:11:12.839360800Z"
    }
   },
   "source": [
    "# A small dense matrix with units that contains several zeros\n",
    "dense = jnp.array([\n",
    "    [1., 0., 2.],\n",
    "    [0., 3., 0.],\n",
    "    [4., 0., 5.]\n",
    "]) * u.volt\n",
    "\n",
    "print('Dense matrix:')\n",
    "print(dense)\n",
    "print('Non-zero elements:', int(jnp.count_nonzero(dense.mantissa)), 'out of', dense.size)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dense matrix:\n",
      "[[1. 0. 2.]\n",
      " [0. 3. 0.]\n",
      " [4. 0. 5.]] V\n",
      "Non-zero elements: 5 out of 9\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CSR (Compressed Sparse Row)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:13.763182100Z",
     "start_time": "2026-03-04T15:11:13.527008700Z"
    }
   },
   "source": [
    "csr = u.sparse.csr_fromdense(dense)\n",
    "print('CSR:', csr)\n",
    "print('Shape:', csr.shape)\n",
    "print('Number of stored elements (nse):', csr.nse)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSR: CSR(float32[3, 3], nse=5)\n",
      "Shape: (3, 3)\n",
      "Number of stored elements (nse): 5\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:13.860643400Z",
     "start_time": "2026-03-04T15:11:13.763182100Z"
    }
   },
   "source": [
    "# Convert back to dense to verify\n",
    "print('Back to dense:')\n",
    "print(csr.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Back to dense:\n",
      "[[1. 0. 2.]\n",
      " [0. 3. 0.]\n",
      " [4. 0. 5.]] V\n"
     ]
    }
   ],
   "execution_count": 4
  },
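  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what CSR actually stores, the sketch below builds its three arrays (`data`, `indices`, `indptr`) by hand with NumPy\n",
    "for the same 3x3 matrix, following the usual SciPy/JAX layout. It illustrates the format on plain numbers and is not\n",
    "`brainunit`'s internal code. In a unit-aware matrix, the unit travels with `data`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "def to_csr(dense):\n",
    "    # Row-major scan: keep non-zeros, their column indices, and row boundaries\n",
    "    data, indices, indptr = [], [], [0]\n",
    "    for row in dense:\n",
    "        for j, value in enumerate(row):\n",
    "            if value != 0:\n",
    "                data.append(value)\n",
    "                indices.append(j)\n",
    "        indptr.append(len(data))  # end of this row = start of the next\n",
    "    return np.array(data), np.array(indices), np.array(indptr)\n",
    "\n",
    "M = np.array([[1., 0., 2.],\n",
    "              [0., 3., 0.],\n",
    "              [4., 0., 5.]])\n",
    "data, indices, indptr = to_csr(M)\n",
    "print('data   :', data)     # [1. 2. 3. 4. 5.]\n",
    "print('indices:', indices)  # [0 2 1 0 2]\n",
    "print('indptr :', indptr)   # [0 2 3 5]\n",
    "\n",
    "# Row i occupies data[indptr[i]:indptr[i+1]], which is why row access is cheap\n",
    "print('row 2 values:', data[indptr[2]:indptr[3]])  # [4. 5.]"
   ],
   "outputs": [],
   "execution_count": null
  },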
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CSC (Compressed Sparse Column)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:13.924495Z",
     "start_time": "2026-03-04T15:11:13.860643400Z"
    }
   },
   "source": [
    "csc = u.sparse.csc_fromdense(dense)\n",
    "print('CSC:', csc)\n",
    "print('CSC todense:')\n",
    "print(csc.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSC: CSC(float32[3, 3], nse=5)\n",
      "CSC todense:\n",
      "[[1. 0. 2.]\n",
      " [0. 3. 0.]\n",
      " [4. 0. 5.]] V\n"
     ]
    }
   ],
   "execution_count": 5
  },
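  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CSC is the column-wise mirror of CSR: the same three arrays, except that `indptr` marks column boundaries and `indices`\n",
    "holds row numbers (the CSC arrays of a matrix are the CSR arrays of its transpose). A NumPy sketch on plain numbers,\n",
    "for illustration only:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "def to_csc(dense):\n",
    "    # Column-major scan: keep non-zeros, their row indices, and column boundaries\n",
    "    data, indices, indptr = [], [], [0]\n",
    "    n_rows, n_cols = dense.shape\n",
    "    for j in range(n_cols):\n",
    "        for i in range(n_rows):\n",
    "            if dense[i, j] != 0:\n",
    "                data.append(dense[i, j])\n",
    "                indices.append(i)\n",
    "        indptr.append(len(data))  # end of this column = start of the next\n",
    "    return np.array(data), np.array(indices), np.array(indptr)\n",
    "\n",
    "M = np.array([[1., 0., 2.],\n",
    "              [0., 3., 0.],\n",
    "              [4., 0., 5.]])\n",
    "data, indices, indptr = to_csc(M)\n",
    "print('data   :', data)     # [1. 4. 3. 2. 5.]\n",
    "print('indices:', indices)  # [0 2 1 0 2]  (row indices)\n",
    "print('indptr :', indptr)   # [0 2 3 5]\n",
    "\n",
    "# Column j occupies data[indptr[j]:indptr[j+1]], which is why column access is cheap\n",
    "print('column 2 values:', data[indptr[2]:indptr[3]])  # [2. 5.]"
   ],
   "outputs": [],
   "execution_count": null
  },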
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### COO (Coordinate)"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.088872300Z",
     "start_time": "2026-03-04T15:11:13.925488500Z"
    }
   },
   "source": [
    "coo = u.sparse.coo_fromdense(dense)\n",
    "print('COO:', coo)\n",
    "print('COO todense:')\n",
    "print(coo.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "COO: COO(float32[3, 3], nse=5)\n",
      "COO todense:\n",
      "[[1. 0. 2.]\n",
      " [0. 3. 0.]\n",
      " [4. 0. 5.]] V\n"
     ]
    }
   ],
   "execution_count": 6
  },
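  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "COO is the most direct layout: parallel arrays of row indices, column indices, and values. Entries can be collected in any\n",
    "order and typically converted to a compressed format afterwards, which is what makes COO convenient for construction.\n",
    "The NumPy sketch below illustrates this on plain numbers; it is not `brainunit`'s implementation, and the duplicate-summing\n",
    "behavior of `np.add.at` is a property of this sketch only."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "M = np.array([[1., 0., 2.],\n",
    "              [0., 3., 0.],\n",
    "              [4., 0., 5.]])\n",
    "\n",
    "# COO keeps one (row, col, value) triplet per non-zero entry\n",
    "row, col = np.nonzero(M)  # returned in row-major order\n",
    "data = M[row, col]\n",
    "print('row :', row)   # [0 0 1 2 2]\n",
    "print('col :', col)   # [0 2 1 0 2]\n",
    "print('data:', data)  # [1. 2. 3. 4. 5.]\n",
    "\n",
    "# Construction from triplets: scatter-add the values into a zero matrix\n",
    "# (in this sketch, duplicate (row, col) pairs would be summed by np.add.at)\n",
    "rebuilt = np.zeros(M.shape)\n",
    "np.add.at(rebuilt, (row, col), data)\n",
    "print('Rebuilt matches:', np.array_equal(rebuilt, M))  # True"
   ],
   "outputs": [],
   "execution_count": null
  },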
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matrix-Vector Products\n",
    "\n",
    "The most common operation on a sparse matrix is the matrix-vector product (the `@` operator).\n",
    "Units multiply exactly as they do for dense matrices."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.311967500Z",
     "start_time": "2026-03-04T15:11:14.089889200Z"
    }
   },
   "source": [
    "# Sparse matrix (V) @ vector (A) = vector (V*A = W)\n",
    "v = jnp.array([1., 2., 3.]) * u.ampere\n",
    "\n",
    "print('CSR @ v:', csr @ v)  # V * A = W\n",
    "print('COO @ v:', coo @ v)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSR @ v: [ 7.  6. 19.] W\n",
      "COO @ v: [ 7.  6. 19.] W\n"
     ]
    }
   ],
   "execution_count": 7
  },
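  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A CSR matrix-vector product only touches the stored entries. The sketch below is a hand-written JAX version on plain\n",
    "(unitless) arrays: it expands `indptr` into one row id per stored value and uses `jax.ops.segment_sum` to add up each row's\n",
    "products. The unit bookkeeping is done separately by multiplying the units (V x A = W). This illustrates the technique;\n",
    "it is not `brainunit`'s actual kernel."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "# Hand-built CSR arrays (plain numbers) for [[1, 0, 2], [0, 3, 0], [4, 0, 5]]\n",
    "data = jnp.array([1., 2., 3., 4., 5.])\n",
    "indices = jnp.array([0, 2, 1, 0, 2])  # column of each stored value\n",
    "indptr = jnp.array([0, 2, 3, 5])      # row i lives in data[indptr[i]:indptr[i+1]]\n",
    "n_rows = indptr.shape[0] - 1\n",
    "\n",
    "def csr_matvec(data, indices, indptr, x, n_rows):\n",
    "    # Expand indptr into one row id per stored value: [0, 0, 1, 2, 2]\n",
    "    row_ids = jnp.repeat(jnp.arange(n_rows), jnp.diff(indptr),\n",
    "                         total_repeat_length=data.shape[0])\n",
    "    # Multiply each stored value by its matching x entry, then sum within each row\n",
    "    return jax.ops.segment_sum(data * x[indices], row_ids, num_segments=n_rows)\n",
    "\n",
    "x = jnp.array([1., 2., 3.])  # mantissa of the ampere vector\n",
    "y = csr_matvec(data, indices, indptr, x, n_rows)\n",
    "\n",
    "# Unit bookkeeping happens separately: volt * ampere = watt\n",
    "y_q = y * (u.volt * u.ampere)\n",
    "print(y_q)"
   ],
   "outputs": [],
   "execution_count": null
  },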
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.533375800Z",
     "start_time": "2026-03-04T15:11:14.311967500Z"
    }
   },
   "source": [
    "# Physical example: conductance matrix @ voltage = current\n",
    "# G (siemens) @ V (volts) = I (amperes)\n",
    "G_dense = jnp.array([\n",
    "    [0.5, -0.1, 0.0],\n",
    "    [-0.1, 0.3, -0.2],\n",
    "    [0.0, -0.2, 0.4]\n",
    "]) * u.siemens\n",
    "\n",
    "G_sparse = u.sparse.csr_fromdense(G_dense)\n",
    "voltages = jnp.array([10., 5., 2.]) * u.volt\n",
    "\n",
    "currents = G_sparse @ voltages\n",
    "print('Node currents:', currents)  # siemens * volt = ampere"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node currents: [ 4.5         0.09999999 -0.19999999] A\n"
     ]
    }
   ],
   "execution_count": 8
  },
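  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Working the rows out by hand makes the unit arithmetic explicit: every term is siemens times volts, i.e. amperes.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "I_0 &= 0.5\\,\\mathrm{S} \\cdot 10\\,\\mathrm{V} - 0.1\\,\\mathrm{S} \\cdot 5\\,\\mathrm{V} = 4.5\\,\\mathrm{A} \\\\\n",
    "I_1 &= -0.1\\,\\mathrm{S} \\cdot 10\\,\\mathrm{V} + 0.3\\,\\mathrm{S} \\cdot 5\\,\\mathrm{V} - 0.2\\,\\mathrm{S} \\cdot 2\\,\\mathrm{V} = 0.1\\,\\mathrm{A} \\\\\n",
    "I_2 &= -0.2\\,\\mathrm{S} \\cdot 5\\,\\mathrm{V} + 0.4\\,\\mathrm{S} \\cdot 2\\,\\mathrm{V} = -0.2\\,\\mathrm{A}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The printed `0.09999999` and `-0.19999999` are these exact values after `float32` rounding."
   ]
  },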
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Arithmetic Operations\n",
    "\n",
    "Sparse matrices support basic arithmetic with unit tracking."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.588896500Z",
     "start_time": "2026-03-04T15:11:14.534851Z"
    }
   },
   "source": [
    "# Scalar multiplication\n",
    "doubled = csr * 2\n",
    "print('CSR * 2:')\n",
    "print(doubled.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSR * 2:\n",
      "[[ 2.  0.  4.]\n",
      " [ 0.  6.  0.]\n",
      " [ 8.  0. 10.]] V\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.657948100Z",
     "start_time": "2026-03-04T15:11:14.588896500Z"
    }
   },
   "source": [
    "# Addition of same-format sparse matrices\n",
    "summed = csr + csr\n",
    "print('CSR + CSR:')\n",
    "print(summed.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CSR + CSR:\n",
      "[[ 2.  0.  4.]\n",
      " [ 0.  6.  0.]\n",
      " [ 8.  0. 10.]] V\n"
     ]
    }
   ],
   "execution_count": 10
  },
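  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because both operands share the same sparsity pattern, their sum can also be expressed by adding the stored values directly.\n",
    "As a sketch, the cell below rebuilds `csr + csr` with `with_data()` (covered in the next section) and checks that the two agree.\n",
    "This shortcut is only valid when the patterns match; matrices with different patterns need the built-in `+`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "dense = jnp.array([[1., 0., 2.], [0., 3., 0.], [4., 0., 5.]]) * u.volt\n",
    "csr = u.sparse.csr_fromdense(dense)\n",
    "\n",
    "# Same pattern on both sides, so adding the stored values is enough (volts + volts = volts)\n",
    "manual_sum = csr.with_data(csr.data + csr.data)\n",
    "builtin_sum = csr + csr\n",
    "\n",
    "same = jnp.allclose(manual_sum.todense().to_decimal(u.volt),\n",
    "                    builtin_sum.todense().to_decimal(u.volt))\n",
    "print('Same result:', bool(same))"
   ],
   "outputs": [],
   "execution_count": null
  },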
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modifying Data with `with_data()`\n",
    "\n",
    "The `with_data()` method creates a new sparse matrix with the same sparsity pattern\n",
    "but different stored values."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:14.676202900Z",
     "start_time": "2026-03-04T15:11:14.657948100Z"
    }
   },
   "source": [
    "# Scale all stored values\n",
    "scaled = csr.with_data(csr.data * 10)\n",
    "print('Original:')\n",
    "print(csr.todense())\n",
    "print('Scaled by 10:')\n",
    "print(scaled.todense())"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original:\n",
      "[[1. 0. 2.]\n",
      " [0. 3. 0.]\n",
      " [4. 0. 5.]] V\n",
      "Scaled by 10:\n",
      "[[10.  0. 20.]\n",
      " [ 0. 30.  0.]\n",
      " [40.  0. 50.]] V\n"
     ]
    }
   ],
   "execution_count": 11
  },
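  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the new data determines the unit, `with_data()` can also change the physical quantity a matrix represents while\n",
    "keeping its pattern. As a sketch, dividing the stored voltages by a hypothetical uniform 2-ohm resistance yields a matrix\n",
    "of currents; the resistance value is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "\n",
    "dense = jnp.array([[1., 0., 2.], [0., 3., 0.], [4., 0., 5.]]) * u.volt\n",
    "csr = u.sparse.csr_fromdense(dense)\n",
    "\n",
    "resistance = 2.0 * u.ohm  # hypothetical uniform resistance, for illustration only\n",
    "currents = csr.with_data(csr.data / resistance)  # volt / ohm = ampere\n",
    "\n",
    "print('Same number of stored elements:', currents.nse == csr.nse)\n",
    "print(currents.todense())"
   ],
   "outputs": [],
   "execution_count": null
  },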
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Example: Sparse Connectivity Matrix\n",
    "\n",
    "In neural network simulations, connectivity between neurons is often sparse: each neuron connects to only a small fraction of the others."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2026-03-04T15:11:15.046342200Z",
     "start_time": "2026-03-04T15:11:14.677631100Z"
    }
   },
   "source": [
    "# Create a sparse weight matrix (most connections are zero)\n",
    "n_neurons = 5\n",
    "weights_dense = jnp.array([\n",
    "    [0.0, 0.5, 0.0, 0.0, 0.3],\n",
    "    [0.0, 0.0, 0.8, 0.0, 0.0],\n",
    "    [0.2, 0.0, 0.0, 0.6, 0.0],\n",
    "    [0.0, 0.0, 0.0, 0.0, 0.4],\n",
    "    [0.1, 0.0, 0.0, 0.0, 0.0]\n",
    "]) * u.siemens  # synaptic conductance\n",
    "\n",
    "W = u.sparse.csr_fromdense(weights_dense)\n",
    "print('Weight matrix:', W)\n",
    "print('Sparsity:', 1.0 - W.nse / (n_neurons * n_neurons), '(fraction of zeros)')\n",
    "\n",
    "# Simplified linear model of synaptic currents: I = W @ V\n",
    "membrane_voltages = jnp.array([-70., -65., -80., -55., -60.]) * u.mV\n",
    "synaptic_currents = W @ membrane_voltages\n",
    "print('Synaptic currents:', synaptic_currents)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight matrix: CSR(float32[5, 5], nse=7)\n",
      "Sparsity: 0.72 (fraction of zeros)\n",
      "Synaptic currents: [-50.5 -64.  -47.  -24.   -7. ] mA\n"
     ]
    }
   ],
   "execution_count": 12
  },
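  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Real networks are much larger than five neurons, so the weight matrix is usually generated rather than typed in.\n",
    "A minimal sketch, assuming independent random connections with a hypothetical connection probability of 10% and a\n",
    "uniform weight. Building a dense mask first is fine at this size, but for very large networks it defeats the purpose\n",
    "of sparse storage."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainunit as u\n",
    "\n",
    "n_neurons = 100\n",
    "p_connect = 0.1  # hypothetical connection probability\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "# Each possible connection exists independently with probability p_connect\n",
    "mask = jax.random.bernoulli(key, p_connect, (n_neurons, n_neurons))\n",
    "weights_dense = jnp.where(mask, 0.5, 0.0) * u.siemens  # uniform weight where connected\n",
    "\n",
    "W = u.sparse.csr_fromdense(weights_dense)\n",
    "print('Stored connections:', W.nse)\n",
    "print('Sparsity:', 1.0 - W.nse / (n_neurons * n_neurons))\n",
    "\n",
    "membrane_voltages = jnp.full(n_neurons, -65.) * u.mV\n",
    "synaptic_currents = W @ membrane_voltages  # same simplified I = W @ V as above\n",
    "print('Current vector shape:', synaptic_currents.shape)"
   ],
   "outputs": [],
   "execution_count": null
  },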
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "| Format | Create | Best For |\n",
    "|--------|--------|----------|\n",
    "| `CSR` | `csr_fromdense(dense)` | Row slicing, matrix-vector products |\n",
    "| `CSC` | `csc_fromdense(dense)` | Column slicing |\n",
    "| `COO` | `coo_fromdense(dense)` | Building sparse matrices |\n",
    "\n",
    "| Operation | Syntax | Unit Behavior |\n",
    "|-----------|--------|---------------|\n",
    "| Matrix-vector | `sparse @ vector` | Units multiply |\n",
    "| Scalar multiply | `sparse * scalar` | Scales values, unit unchanged |\n",
    "| Addition | `sparse + sparse` | Same unit required |\n",
    "| To dense | `sparse.todense()` | Preserves unit |\n",
    "| Replace data | `sparse.with_data(new)` | New data determines unit |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
