Sparse Matrices with Units
brainunit.sparse provides unit-aware sparse matrix classes built on top of JAX’s sparse
representations. Sparse matrices store only non-zero elements, saving memory for large
matrices that are mostly zeros (common in scientific computing, e.g., connectivity matrices).
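Here is a rough estimate of the memory saving in plain Python. It assumes float32 values and int32 indices (JAX's defaults when 64-bit mode is off) and ignores per-object overhead. Treat it as a back-of-the-envelope sketch, not a measurement of brainunit's actual footprint.

```python
def dense_bytes(n_rows, n_cols, value_bytes=4):
    """Bytes for a dense matrix that stores every entry (float32 assumed)."""
    return n_rows * n_cols * value_bytes


def csr_bytes(n_rows, nse, value_bytes=4, index_bytes=4):
    """Bytes for CSR: nse values, nse column indices, n_rows + 1 row pointers.

    float32 values and int32 indices are assumptions, not brainunit facts.
    """
    return nse * value_bytes + nse * index_bytes + (n_rows + 1) * index_bytes


n = 10_000
nse = n * n // 100  # 1% of the entries are non-zero

print('dense:', dense_bytes(n, n))  # 400000000 bytes (400 MB)
print('CSR:  ', csr_bytes(n, nse))  # 8040004 bytes (about 8 MB)
```

Each stored entry costs a value plus an index (8 bytes here, against 4 bytes per dense entry). Under these assumptions, CSR only pays off while well under half of the entries are non-zero.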
Available formats:
- CSR (Compressed Sparse Row): efficient for row slicing and matrix-vector products
- CSC (Compressed Sparse Column): efficient for column slicing
- COO (Coordinate): efficient for constructing sparse matrices
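The three formats differ only in how they index the stored values. The sketch below computes, in plain NumPy with units omitted, the arrays that the standard CSR, CSC, and COO layouts hold for the 3x3 example used on this page. The names (data, indices, indptr) follow the common SciPy/JAX convention. They do not describe brainunit's internal fields.

```python
import numpy as np

dense = np.array([[1., 0., 2.],
                  [0., 3., 0.],
                  [4., 0., 5.]])

# COO: one (row, col, value) triplet per non-zero, in row-major order
rows, cols = np.nonzero(dense)
coo_data = dense[rows, cols]

# CSR: values and column indices grouped by row;
# row i occupies csr_data[csr_indptr[i]:csr_indptr[i + 1]]
csr_data, csr_indices = coo_data, cols
csr_indptr = np.concatenate([[0], np.cumsum(np.count_nonzero(dense, axis=1))])

# CSC: the same idea column by column (equivalently, CSR of the transpose)
t_rows, t_cols = np.nonzero(dense.T)
csc_data = dense.T[t_rows, t_cols]
csc_indices = t_cols  # row index of each value in the original matrix
csc_indptr = np.concatenate([[0], np.cumsum(np.count_nonzero(dense, axis=0))])

print('COO  rows:', rows, 'cols:', cols, 'data:', coo_data)
print('CSR  indptr:', csr_indptr, 'indices:', csr_indices, 'data:', csr_data)
print('CSC  indptr:', csc_indptr, 'indices:', csc_indices, 'data:', csc_data)
```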
import brainunit as u
import jax.numpy as jnp
Creating Sparse Matrices from Dense
The simplest way to create a sparse matrix is from a dense Quantity array.
# A dense Quantity array in which most entries are zero
dense = jnp.array([
[1., 0., 2.],
[0., 3., 0.],
[4., 0., 5.]
]) * u.volt
print('Dense matrix:')
print(dense)
print('Non-zero elements:', int(jnp.count_nonzero(dense.mantissa)), 'out of', dense.mantissa.size)
Dense matrix:
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]] V
Non-zero elements: 5 out of 9
CSR (Compressed Sparse Row)
csr = u.sparse.csr_fromdense(dense)
print('CSR:', csr)
print('Shape:', csr.shape)
print('Number of stored elements (nse):', csr.nse)
CSR: CSR(float32[3, 3], nse=5)
Shape: (3, 3)
Number of stored elements (nse): 5
# Convert back to dense to verify
print('Back to dense:')
print(csr.todense())
Back to dense:
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]] V
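Decoding CSR is easy to reason about. Row i's values sit in one contiguous slice bounded by two consecutive row pointers, which is also why row slicing is efficient. Below is a minimal NumPy sketch of that decoding. It uses the standard CSR arrays for the example matrix, without units, and is not brainunit's own code.

```python
import numpy as np

# Standard CSR arrays for the 3x3 example (mantissas in volts)
data = np.array([1., 2., 3., 4., 5.])
indices = np.array([0, 2, 1, 0, 2])   # column of each stored value
indptr = np.array([0, 2, 3, 5])       # row i occupies data[indptr[i]:indptr[i + 1]]


def csr_row(i):
    """Return (columns, values) of row i by reading a single contiguous slice."""
    start, stop = indptr[i], indptr[i + 1]
    return indices[start:stop], data[start:stop]


def csr_todense(n_cols):
    """Rebuild the dense matrix one row slice at a time."""
    out = np.zeros((len(indptr) - 1, n_cols))
    for i in range(len(indptr) - 1):
        cols, vals = csr_row(i)
        out[i, cols] = vals
    return out


print(csr_row(2))       # columns and values stored for row 2
print(csr_todense(3))   # same numbers as csr.todense() above, minus the unit
```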
CSC (Compressed Sparse Column)
csc = u.sparse.csc_fromdense(dense)
print('CSC:', csc)
print('CSC todense:')
print(csc.todense())
CSC: CSC(float32[3, 3], nse=5)
CSC todense:
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]] V
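CSC mirrors CSR with rows and columns swapped, so a single column can be read from one contiguous slice. Here is a minimal sketch with hand-written CSC arrays for the same matrix, again in plain NumPy and without units.

```python
import numpy as np

# Standard CSC arrays for the 3x3 example
data = np.array([1., 4., 3., 2., 5.])
indices = np.array([0, 2, 1, 0, 2])   # row of each stored value
indptr = np.array([0, 2, 3, 5])       # column j occupies data[indptr[j]:indptr[j + 1]]


def csc_column(j, n_rows):
    """Materialize column j as a dense vector from one contiguous slice."""
    col = np.zeros(n_rows)
    start, stop = indptr[j], indptr[j + 1]
    col[indices[start:stop]] = data[start:stop]
    return col


print(csc_column(2, 3))  # [2. 0. 5.]
```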
COO (Coordinate)
coo = u.sparse.coo_fromdense(dense)
print('COO:', coo)
print('COO todense:')
print(coo.todense())
COO: COO(float32[3, 3], nse=5)
COO todense:
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]] V
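COO suits construction because each entry is an independent (row, column, value) triplet. Triplets can be appended in any order and converted once at the end. The sketch below assembles the example matrix that way. The helper coo_todense is hypothetical, and it sums repeated positions (as SciPy does when converting COO); this page states nothing about brainunit's behavior for duplicates.

```python
import numpy as np

# Append triplets one at a time, in arbitrary order
rows, cols, vals = [], [], []
for r, c, v in [(2, 2, 5.), (0, 0, 1.), (1, 1, 3.), (0, 2, 2.), (2, 0, 4.)]:
    rows.append(r)
    cols.append(c)
    vals.append(v)


def coo_todense(rows, cols, vals, shape):
    """Scatter triplets into a dense array; np.add.at sums repeated (row, col) pairs."""
    out = np.zeros(shape)
    np.add.at(out, (np.array(rows), np.array(cols)), np.array(vals))
    return out


print(coo_todense(rows, cols, vals, (3, 3)))
```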
Matrix-Vector Products
The key operation for sparse matrices is the matrix-vector product (@ operator).
Units multiply exactly as they do for dense matrices: the result carries the matrix unit times the vector unit.
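Numerically, a CSR product only visits the stored entries: each output element is a short dot product over one row's slice. The result unit depends only on the operand units (here V * A = W). A unit-aware library could therefore, in principle, do the arithmetic on plain mantissas and attach the product unit once. The sketch below assumes that split and is not brainunit's implementation.

```python
import numpy as np

# Standard CSR arrays of the example matrix (mantissas in volts)
data = np.array([1., 2., 3., 4., 5.])
indices = np.array([0, 2, 1, 0, 2])
indptr = np.array([0, 2, 3, 5])


def csr_matvec(data, indices, indptr, x):
    """y[i] = sum of data[k] * x[indices[k]] over the stored entries k of row i."""
    y = np.zeros(len(indptr) - 1)
    for i in range(len(y)):
        start, stop = indptr[i], indptr[i + 1]
        y[i] = np.dot(data[start:stop], x[indices[start:stop]])
    return y


x = np.array([1., 2., 3.])           # mantissas in amperes
y = csr_matvec(data, indices, indptr, x)
print(y, 'W')                        # volt * ampere = watt, attached once
```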
# Sparse matrix (V) @ vector (A) = vector (V*A = W)
v = jnp.array([1., 2., 3.]) * u.ampere
print('CSR @ v:', csr @ v) # V * A = W
print('COO @ v:', coo @ v)
CSR @ v: [ 7. 6. 19.] W
COO @ v: [ 7. 6. 19.] W
# Physical example: conductance matrix @ voltage = current
# G (siemens) @ V (volts) = I (amperes)
G_dense = jnp.array([
[0.5, -0.1, 0.0],
[-0.1, 0.3, -0.2],
[0.0, -0.2, 0.4]
]) * u.siemens
G_sparse = u.sparse.csr_fromdense(G_dense)
voltages = jnp.array([10., 5., 2.]) * u.volt
currents = G_sparse @ voltages
print('Node currents:', currents) # siemens * volt = ampere
Node currents:
[ 4.5 0.09999999 -0.19999999] A
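Writing out the three rows shows where the printed values come from. The exact results are 4.5 A, 0.1 A, and -0.2 A. The displayed 0.09999999 and -0.19999999 come from float32 rounding, not from a unit or sparsity effect.

```latex
\begin{aligned}
I_1 &= (0.5\,\mathrm{S})(10\,\mathrm{V}) + (-0.1\,\mathrm{S})(5\,\mathrm{V}) + (0\,\mathrm{S})(2\,\mathrm{V}) = 4.5\,\mathrm{A}\\
I_2 &= (-0.1\,\mathrm{S})(10\,\mathrm{V}) + (0.3\,\mathrm{S})(5\,\mathrm{V}) + (-0.2\,\mathrm{S})(2\,\mathrm{V}) = 0.1\,\mathrm{A}\\
I_3 &= (0\,\mathrm{S})(10\,\mathrm{V}) + (-0.2\,\mathrm{S})(5\,\mathrm{V}) + (0.4\,\mathrm{S})(2\,\mathrm{V}) = -0.2\,\mathrm{A}
\end{aligned}
```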
Arithmetic Operations
Sparse matrices support basic arithmetic with unit tracking.
# Scalar multiplication
doubled = csr * 2
print('CSR * 2:')
print(doubled.todense())
CSR * 2:
[[ 2. 0. 4.]
[ 0. 6. 0.]
[ 8. 0. 10.]] V
# Addition of same-format sparse matrices
summed = csr + csr
print('CSR + CSR:')
print(summed.todense())
CSR + CSR:
[[ 2. 0. 4.]
[ 0. 6. 0.]
[ 8. 0. 10.]] V
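Adding two sparse matrices that share a sparsity pattern reduces to adding their stored values. In general the patterns differ, and the result's pattern is the union of both. The sketch below applies that rule to COO triplets and also enforces the same-unit requirement for addition. The coo_add helper and its tuple layout are hypothetical, and units are plain strings purely for illustration.

```python
def coo_add(a, b):
    """Add two same-shape COO matrices given as (rows, cols, vals, unit) tuples.

    The units must match. Triplets from both operands are accumulated and
    duplicate (row, col) positions are summed, so the result's sparsity
    pattern is the union of the two input patterns.
    """
    rows_a, cols_a, vals_a, unit_a = a
    rows_b, cols_b, vals_b, unit_b = b
    if unit_a != unit_b:
        raise ValueError(f'unit mismatch: {unit_a} vs {unit_b}')
    acc = {}
    for r, c, v in zip(rows_a + rows_b, cols_a + cols_b, vals_a + vals_b):
        acc[(r, c)] = acc.get((r, c), 0.0) + v
    keys = sorted(acc)  # (row, col) tuples sort into row-major order
    return ([r for r, _ in keys], [c for _, c in keys], [acc[k] for k in keys], unit_a)


a = ([0, 0, 1, 2, 2], [0, 2, 1, 0, 2], [1., 2., 3., 4., 5.], 'V')  # the 3x3 example
b = ([0, 1], [0, 2], [10., 7.], 'V')                               # a different pattern
print(coo_add(a, b))
```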
Modifying Data with with_data()
The with_data() method creates a new sparse matrix with the same sparsity pattern
but different values.
# Scale all stored values
scaled = csr.with_data(csr.data * 10)
print('Original:')
print(csr.todense())
print('Scaled by 10:')
print(scaled.todense())
Original:
[[1. 0. 2.]
[0. 3. 0.]
[4. 0. 5.]] V
Scaled by 10:
[[10. 0. 20.]
[ 0. 30. 0.]
[40. 0. 50.]] V
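Conceptually, with_data() keeps the index arrays and swaps in a new value array. The sketch below models that with a hypothetical CSRArrays named tuple. The length check and the sharing of the index arrays are choices made in this sketch, not documented brainunit behavior.

```python
from typing import NamedTuple

import numpy as np


class CSRArrays(NamedTuple):
    """Hypothetical stand-in for a CSR matrix: sparsity pattern plus values."""
    data: np.ndarray
    indices: np.ndarray
    indptr: np.ndarray
    shape: tuple

    def with_data(self, new_data):
        # Same sparsity pattern, new values: one value per stored entry
        if new_data.shape != self.data.shape:
            raise ValueError('new data must have one value per stored entry')
        return self._replace(data=new_data)


csr = CSRArrays(np.array([1., 2., 3., 4., 5.]), np.array([0, 2, 1, 0, 2]),
                np.array([0, 2, 3, 5]), (3, 3))
scaled = csr.with_data(csr.data * 10)
print(scaled.data)                    # [10. 20. 30. 40. 50.]
print(scaled.indices is csr.indices)  # True: the pattern is shared, not copied
```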
Practical Example: Sparse Connectivity Matrix
In neural network simulations, connectivity between neurons is often sparse.
# Create a sparse weight matrix (most connections are zero)
n_neurons = 5
weights_dense = jnp.array([
[0.0, 0.5, 0.0, 0.0, 0.3],
[0.0, 0.0, 0.8, 0.0, 0.0],
[0.2, 0.0, 0.0, 0.6, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.4],
[0.1, 0.0, 0.0, 0.0, 0.0]
]) * u.siemens # synaptic conductance
W = u.sparse.csr_fromdense(weights_dense)
print('Weight matrix:', W)
print('Sparsity:', 1.0 - W.nse / (n_neurons * n_neurons), '(fraction of zeros)')
# Compute synaptic currents: I = W @ V
membrane_voltages = jnp.array([-70., -65., -80., -55., -60.]) * u.mV
synaptic_currents = W @ membrane_voltages
print('Synaptic currents:', synaptic_currents)
Weight matrix: CSR(float32[5, 5], nse=7)
Sparsity: 0.72 (fraction of zeros)
Synaptic currents: [-50.5 -64. -47. -24. -7. ] mA
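Connectivity is often specified as a list of synapses rather than as a dense matrix. Take a hypothetical edge list of (post, pre, weight) triplets that matches weights_dense row for row. The same currents can then be computed directly from those COO triplets with a scatter-add. Values are plain NumPy numbers, and the S * mV = mA bookkeeping lives only in the comments.

```python
import numpy as np

n_neurons = 5
# Hypothetical edge list matching weights_dense: (post, pre, weight in siemens).
# Row i of W collects the inputs to neuron i, so (W @ V)[i] = sum of weight * V[pre].
edges = [(0, 1, 0.5), (0, 4, 0.3), (1, 2, 0.8), (2, 0, 0.2),
         (2, 3, 0.6), (3, 4, 0.4), (4, 0, 0.1)]
post = np.array([e[0] for e in edges])
pre = np.array([e[1] for e in edges])
w = np.array([e[2] for e in edges])

v_mem = np.array([-70., -65., -80., -55., -60.])  # membrane voltages in mV

# COO matrix-vector product: scatter-add each synapse's contribution into its row
currents = np.zeros(n_neurons)
np.add.at(currents, post, w * v_mem[pre])

print('nse:', len(edges), 'sparsity:', 1 - len(edges) / n_neurons**2)
print('currents (S * mV = mA):', currents)
```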
Summary
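| Format | Create | Best For |
|---|---|---|
| CSR | `u.sparse.csr_fromdense(dense)` | Row slicing, matrix-vector products |
| CSC | `u.sparse.csc_fromdense(dense)` | Column slicing |
| COO | `u.sparse.coo_fromdense(dense)` | Building sparse matrices |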
| Operation | Syntax | Unit Behavior |
|---|---|---|
| Matrix-vector | `A @ v` | Units multiply |
| Scalar multiply | `A * 2` | Scales values; unit unchanged for a plain number |
| Addition | `A + B` | Same unit required |
| To dense | `A.todense()` | Preserves unit |
| Replace data | `A.with_data(new_data)` | New data determines unit |