Sparse Matrices with Units

brainunit.sparse provides unit-aware sparse matrix classes built on top of JAX’s sparse representations. A sparse matrix stores only its non-zero elements, which saves memory for large, mostly-zero matrices; such matrices are common in scientific computing (e.g., connectivity matrices).
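To put rough numbers on the memory saving, the sketch below compares a dense float32 matrix with a CSR layout that stores one value and one column index per non-zero, plus one row pointer per row. The 4-byte value and index sizes and the 1% density are assumptions for illustration; the real overhead depends on the library and its index dtype.

```python
def dense_bytes(n_rows, n_cols, value_bytes=4):
    # A dense array stores every entry, zero or not
    return n_rows * n_cols * value_bytes

def csr_bytes(n_rows, nnz, value_bytes=4, index_bytes=4):
    # CSR stores nnz values, nnz column indices and n_rows + 1 row pointers
    return nnz * (value_bytes + index_bytes) + (n_rows + 1) * index_bytes

n = 10_000
nnz = n * n // 100  # assume 1% of entries are non-zero

print(dense_bytes(n, n))  # 400000000 (about 400 MB)
print(csr_bytes(n, nnz))  # 8040004 (about 8 MB)
```

At this density CSR needs roughly 2% of the dense footprint. With these byte sizes each stored entry costs 8 bytes instead of 4, so the advantage disappears as the density approaches about 50%.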

Available formats:

  • CSR (Compressed Sparse Row) — efficient for row slicing and matrix-vector products

  • CSC (Compressed Sparse Column) — efficient for column slicing

  • COO (Coordinate) — efficient for constructing sparse matrices

import brainunit as u
import jax.numpy as jnp

Creating Sparse Matrices from Dense

The simplest way to create a sparse matrix is from a dense Quantity array.

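# A dense Quantity array (volts) with several zero entries
dense = jnp.array([
    [1., 0., 2.],
    [0., 3., 0.],
    [4., 0., 5.]
]) * u.volt

print('Dense matrix:')
print(dense)
# Count the non-zero entries from the data instead of hard-coding them
print('Non-zero elements:', int(jnp.count_nonzero(dense.mantissa)), 'out of', dense.mantissa.size)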
Dense matrix:
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]] V
Non-zero elements: 5 out of 9

CSR (Compressed Sparse Row)

csr = u.sparse.csr_fromdense(dense)
print('CSR:', csr)
print('Shape:', csr.shape)
print('Number of stored elements (nse):', csr.nse)
CSR: CSR(float32[3, 3], nse=5)
Shape: (3, 3)
Number of stored elements (nse): 5
# Convert back to dense to verify
print('Back to dense:')
print(csr.todense())
Back to dense:
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]] V
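A CSR matrix is conventionally described by three arrays: the stored values, the column index of each value, and a row-pointer array marking where each row starts. The sketch below builds these by hand with NumPy for the example matrix. It is unit-free and illustrative only; the names `data`, `indices` and `indptr` follow the common SciPy/JAX convention and are not a claim about brainunit's internals. It shows why reading one row is cheap: the row is a contiguous slice.

```python
import numpy as np

dense = np.array([[1., 0., 2.],
                  [0., 3., 0.],
                  [4., 0., 5.]])

# Non-zero positions in row-major order
rows, cols = np.nonzero(dense)
data = dense[rows, cols]   # stored values: [1, 2, 3, 4, 5]
indices = cols             # column index of each stored value: [0, 2, 1, 0, 2]

# indptr[i]:indptr[i+1] is the slice of data/indices that belongs to row i
indptr = np.concatenate([[0], np.cumsum(np.bincount(rows, minlength=dense.shape[0]))])
# -> [0, 2, 3, 5]

def csr_row(i):
    # Reading row i touches only its own stored entries, no scan of the whole matrix
    start, stop = indptr[i], indptr[i + 1]
    row = np.zeros(dense.shape[1])
    row[indices[start:stop]] = data[start:stop]
    return row

print(csr_row(2))  # [4. 0. 5.]
```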

CSC (Compressed Sparse Column)

csc = u.sparse.csc_fromdense(dense)
print('CSC:', csc)
print('CSC todense:')
print(csc.todense())
CSC: CSC(float32[3, 3], nse=5)
CSC todense:
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]] V
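CSC is the same idea applied to columns: values are stored column by column together with their row indices, and a column-pointer array marks where each column starts. A minimal, unit-free NumPy sketch (illustrative only, not brainunit's implementation) obtains it by compressing the transpose row-wise:

```python
import numpy as np

dense = np.array([[1., 0., 2.],
                  [0., 3., 0.],
                  [4., 0., 5.]])

# Compressing the transpose row-wise equals compressing the original column-wise
cols_t, rows_t = np.nonzero(dense.T)  # cols_t: column in dense, rows_t: row in dense
data = dense[rows_t, cols_t]          # values in column-major order: [1, 4, 3, 2, 5]
indices = rows_t                      # row index of each value: [0, 2, 1, 0, 2]
indptr = np.concatenate([[0], np.cumsum(np.bincount(cols_t, minlength=dense.shape[1]))])
# -> [0, 2, 3, 5]

def csc_col(j):
    # Column j is the contiguous slice data[indptr[j]:indptr[j+1]]
    start, stop = indptr[j], indptr[j + 1]
    col = np.zeros(dense.shape[0])
    col[indices[start:stop]] = data[start:stop]
    return col

print(csc_col(0))  # [1. 0. 4.]
```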

COO (Coordinate)

coo = u.sparse.coo_fromdense(dense)
print('COO:', coo)
print('COO todense:')
print(coo.todense())
COO: COO(float32[3, 3], nse=5)
COO todense:
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]] V
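COO stores one (row, col, value) triple per entry, so a matrix can be assembled by appending triples in any order, for instance while reading an edge list. The unit-free NumPy sketch below (the `coo_todense` helper is hypothetical) rebuilds the example matrix from unordered triples that include a repeated position. Summing duplicates is a common convention (SciPy's COO does this); the sketch does not establish how brainunit's COO treats duplicates.

```python
import numpy as np

# Triples collected in arbitrary order; position (2, 2) appears twice: 2 + 3 = 5
row = np.array([2, 0, 1, 0, 2, 2])
col = np.array([2, 2, 1, 0, 0, 2])
val = np.array([2., 2., 3., 1., 4., 3.])

def coo_todense(row, col, val, shape):
    out = np.zeros(shape)
    # np.add.at accumulates repeated (row, col) pairs instead of overwriting them
    np.add.at(out, (row, col), val)
    return out

print(coo_todense(row, col, val, (3, 3)))
```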

Matrix-Vector Products

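The most common operation on a sparse matrix is the matrix-vector product, written with the @ operator. Units combine exactly as they do for dense Quantity arrays: the result carries the product of the matrix unit and the vector unit.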

# Sparse matrix (V) @ vector (A) = vector (V*A = W)
v = jnp.array([1., 2., 3.]) * u.ampere

print('CSR @ v:', csr @ v)  # V * A = W
print('COO @ v:', coo @ v)
CSR @ v: [ 7.  6. 19.] W
COO @ v: [ 7.  6. 19.] W
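A CSR matrix-vector product visits only the stored entries: for row i it multiplies the values in data[indptr[i]:indptr[i+1]] by the matching entries of x and sums them. The unit-free NumPy sketch below (the `csr_matvec` helper is hypothetical) reproduces the 7, 6, 19 result. Because every term in a row sum has the same unit (here V·A), a unit-aware wrapper can plausibly compute on the raw numbers and attach the product unit at the end; this describes the arithmetic, not brainunit's actual implementation.

```python
import numpy as np

# CSR arrays for [[1, 0, 2], [0, 3, 0], [4, 0, 5]]
data = np.array([1., 2., 3., 4., 5.])
indices = np.array([0, 2, 1, 0, 2])
indptr = np.array([0, 2, 3, 5])

def csr_matvec(data, indices, indptr, x):
    # y[i] = sum over the stored entries of row i of value * x[column]
    n_rows = len(indptr) - 1
    y = np.zeros(n_rows)
    for i in range(n_rows):
        start, stop = indptr[i], indptr[i + 1]
        y[i] = np.dot(data[start:stop], x[indices[start:stop]])
    return y

x = np.array([1., 2., 3.])
print(csr_matvec(data, indices, indptr, x))  # [ 7.  6. 19.]
```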
# Physical example: conductance matrix @ voltage = current
# G (siemens) @ V (volts) = I (amperes)
G_dense = jnp.array([
    [0.5, -0.1, 0.0],
    [-0.1, 0.3, -0.2],
    [0.0, -0.2, 0.4]
]) * u.siemens

G_sparse = u.sparse.csr_fromdense(G_dense)
voltages = jnp.array([10., 5., 2.]) * u.volt

currents = G_sparse @ voltages
print('Node currents:', currents)  # siemens * volt = ampere
Node currents: 
[ 4.5         0.09999999 -0.19999999] A
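Written out for the second node, the product shows how the units combine term by term, and why the printed value is 0.09999999 rather than exactly 0.1 (float32 rounding):

```latex
I_i = \sum_j G_{ij} V_j, \qquad \mathrm{S} \cdot \mathrm{V} = \mathrm{A}

I_2 = (-0.1\,\mathrm{S})(10\,\mathrm{V}) + (0.3\,\mathrm{S})(5\,\mathrm{V}) + (-0.2\,\mathrm{S})(2\,\mathrm{V})
    = (-1 + 1.5 - 0.4)\,\mathrm{A} = 0.1\,\mathrm{A}
```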

Arithmetic Operations

Sparse matrices support scalar multiplication and addition, and units are tracked through both.

# Scalar multiplication
doubled = csr * 2
print('CSR * 2:')
print(doubled.todense())
CSR * 2:
[[ 2.  0.  4.]
 [ 0.  6.  0.]
 [ 8.  0. 10.]] V
# Addition of same-format sparse matrices
summed = csr + csr
print('CSR + CSR:')
print(summed.todense())
CSR + CSR:
[[ 2.  0.  4.]
 [ 0.  6.  0.]
 [ 8.  0. 10.]] V
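Scalar multiplication only touches the stored values, so the sparsity pattern is unchanged. Addition is equally simple when both operands share a pattern, as with csr + csr above: the data arrays are added. When the patterns differ, the result's pattern is the union of the two. The unit-free NumPy sketch below (the `coo_add` helper is hypothetical) handles the general case by concatenating COO triples and summing entries that land on the same position; a unit-aware version would also have to check that both operands carry the same unit.

```python
import numpy as np

def coo_add(a, b, shape):
    # a and b are (row, col, val) triples; concatenate them and sum duplicates
    row = np.concatenate([a[0], b[0]])
    col = np.concatenate([a[1], b[1]])
    val = np.concatenate([a[2], b[2]])
    flat = row * shape[1] + col                      # linear index of each entry
    keys, inverse = np.unique(flat, return_inverse=True)
    summed = np.zeros(len(keys))
    np.add.at(summed, inverse, val)                  # merge entries at the same position
    return keys // shape[1], keys % shape[1], summed

a = (np.array([0, 1]), np.array([0, 1]), np.array([1., 3.]))  # (0,0)=1, (1,1)=3
b = (np.array([1, 2]), np.array([1, 2]), np.array([4., 5.]))  # (1,1)=4, (2,2)=5
row, col, val = coo_add(a, b, (3, 3))
print(row, col, val)  # [0 1 2] [0 1 2] [1. 7. 5.]
```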

Modifying Data with with_data()

The with_data() method creates a new sparse matrix with the same sparsity pattern (the positions of the stored entries) but new stored values; the unit of the result comes from the new data.

# Scale all stored values
scaled = csr.with_data(csr.data * 10)
print('Original:')
print(csr.todense())
print('Scaled by 10:')
print(scaled.todense())
Original:
[[1. 0. 2.]
 [0. 3. 0.]
 [4. 0. 5.]] V
Scaled by 10:
[[10.  0. 20.]
 [ 0. 30.  0.]
 [40.  0. 50.]] V
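Because with_data() keeps the sparsity pattern fixed, a transformation passed through it is applied only to the stored values, and the implicit zeros are left untouched. That matches the corresponding dense, elementwise operation only when the transformation maps 0 to 0. A unit-free NumPy sketch of this reasoning (the `rebuild` helper is hypothetical and merely stands in for with_data):

```python
import numpy as np

dense = np.array([[1., 0., 2.],
                  [0., 3., 0.],
                  [4., 0., 5.]])
rows, cols = np.nonzero(dense)
data = dense[rows, cols]

def rebuild(new_data):
    # Same sparsity pattern, new stored values (the idea behind with_data)
    out = np.zeros_like(dense)
    out[rows, cols] = new_data
    return out

# f(0) == 0: transforming only the stored values matches the dense result
print(np.array_equal(rebuild(data * 10), dense * 10))          # True
print(np.array_equal(rebuild(np.sqrt(data)), np.sqrt(dense)))  # True

# f(0) != 0: exp(0) = 1 should appear at the implicit zeros, but they are not stored
print(np.array_equal(rebuild(np.exp(data)), np.exp(dense)))    # False
```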

Practical Example: Sparse Connectivity Matrix

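In neural network simulations, each neuron usually connects to only a small fraction of the others, so the connectivity (weight) matrix is mostly zeros and is a natural fit for a sparse format.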

# Create a sparse weight matrix (most connections are zero)
n_neurons = 5
weights_dense = jnp.array([
    [0.0, 0.5, 0.0, 0.0, 0.3],
    [0.0, 0.0, 0.8, 0.0, 0.0],
    [0.2, 0.0, 0.0, 0.6, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.4],
    [0.1, 0.0, 0.0, 0.0, 0.0]
]) * u.siemens  # synaptic conductance

W = u.sparse.csr_fromdense(weights_dense)
print('Weight matrix:', W)
print('Sparsity:', 1.0 - W.nse / (n_neurons * n_neurons), '(fraction of zeros)')

# Synaptic currents I = W @ V (siemens * mV = mA); a simplified model that ignores synaptic reversal potentials
membrane_voltages = jnp.array([-70., -65., -80., -55., -60.]) * u.mV
synaptic_currents = W @ membrane_voltages
print('Synaptic currents:', synaptic_currents)
Weight matrix: CSR(float32[5, 5], nse=7)
Sparsity: 0.72 (fraction of zeros)
Synaptic currents: [-50.5 -64.  -47.  -24.   -7. ] mA
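Larger connectivity matrices are usually generated by a wiring rule rather than typed out. The sketch below assumes a simple random rule in which each directed connection between distinct neurons exists independently with probability p and carries one constant weight; `random_connectivity`, p and the weight value are illustrative choices, not part of brainunit. It stays in plain NumPy; to use the result with brainunit, you would multiply it by a unit (for example `jnp.asarray(W) * u.siemens`) and pass it to `u.sparse.csr_fromdense`, as in the example above.

```python
import numpy as np

def random_connectivity(n, p, weight, seed=0):
    # Each directed connection i -> j (i != j) exists independently with probability p
    rng = np.random.default_rng(seed)
    mask = rng.random((n, n)) < p
    np.fill_diagonal(mask, False)  # no self-connections
    return np.where(mask, weight, 0.0)

n, p = 1000, 0.02
W = random_connectivity(n, p, weight=0.5)
nse = np.count_nonzero(W)
print('stored entries:', nse)
print('sparsity:', 1.0 - nse / (n * n))
```

Building a full dense array first is only reasonable for moderate n; for very large networks it is cheaper to generate the (row, col, value) triples directly and skip the dense intermediate.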

Summary

Format | Create | Best For
--- | --- | ---
CSR | csr_fromdense(dense) | Row slicing, matrix-vector products
CSC | csc_fromdense(dense) | Column slicing
COO | coo_fromdense(dense) | Building sparse matrices

Operation | Syntax | Unit Behavior
--- | --- | ---
Matrix-vector | sparse @ vector | Units multiply
Scalar multiply | sparse * scalar | Scales values
Addition | sparse + sparse | Same unit required
To dense | sparse.todense() | Preserves unit
Replace data | sparse.with_data(new) | New data determines unit