CSR

class brainunit.sparse.CSR(args, *, shape)

Unit-aware Compressed Sparse Row (CSR) matrix.

Stores a 2-D sparse matrix in CSR format with optional physical-unit support via Quantity.

Parameters:
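  • args (tuple of (data, indices, indptr)) – data contains the stored non-zero values (jax.Array or Quantity), indices contains the column index of each stored value, and indptr is the row pointer array: the entries of row i occupy positions indptr[i] through indptr[i + 1] - 1 of data and indices.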

  • shape (tuple of int) – The (nrows, ncols) shape of the matrix.
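
To make the (data, indices, indptr) layout concrete, the following pure-Python sketch builds the three arrays for the 2×3 matrix used in the class example below. It uses plain lists instead of jax.Array, ignores units, and the helper name dense_to_csr_parts is hypothetical, not part of the library.

```python
def dense_to_csr_parts(dense):
    """Return (data, indices, indptr) for a list-of-lists matrix (hypothetical helper)."""
    data, indices, indptr = [], [], [0]
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)      # stored value
                indices.append(col)     # its column index
        indptr.append(len(data))        # end of this row's slice in data/indices
    return data, indices, indptr


dense = [[1., 0., 2.],
         [0., 0., 3.]]
data, indices, indptr = dense_to_csr_parts(dense)
print(data)     # [1.0, 2.0, 3.0]
print(indices)  # [0, 2, 2]
print(indptr)   # [0, 2, 3]
```

Each row contributes exactly one new entry to indptr, which is why indptr always has nrows + 1 elements.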

data

Non-zero values of shape (nse,).

Type:

jax.Array or Quantity

indices

Column indices of shape (nse,).

Type:

jax.Array

indptr

Row pointer array of shape (nrows + 1,).

Type:

jax.Array
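
Row i occupies the half-open slice indptr[i]:indptr[i + 1] of both data and indices, so indptr[-1] equals the number of stored elements (nse). The pure-Python sketch below reads one row at a time from hand-written CSR parts of the matrix [[1., 0., 2.], [0., 0., 3.]]; row_entries is a hypothetical helper, not library API.

```python
def row_entries(data, indices, indptr, i):
    """Return [(col, value), ...] for row i of a CSR matrix (hypothetical helper)."""
    start, stop = indptr[i], indptr[i + 1]
    return list(zip(indices[start:stop], data[start:stop]))


# CSR parts of [[1., 0., 2.], [0., 0., 3.]]
data = [1., 2., 3.]
indices = [0, 2, 2]
indptr = [0, 2, 3]

print(row_entries(data, indices, indptr, 0))  # [(0, 1.0), (2, 2.0)]
print(row_entries(data, indices, indptr, 1))  # [(2, 3.0)]
print(indptr[-1] == len(data))                # True: nse == indptr[-1]
```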

shape

Shape of the matrix (nrows, ncols).

Type:

tuple of int

nse

Number of stored elements.

Type:

int

dtype

Data type of the stored values.

Type:

dtype

See also

CSC

Unit-aware Compressed Sparse Column matrix.

csr_fromdense

Create a CSR matrix from a dense array.

csr_todense

Convert a CSR matrix to a dense array.

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.sparse as susparse
>>> dense = jnp.array([[1., 0., 2.], [0., 0., 3.]])
>>> csr = susparse.CSR.fromdense(dense)
>>> csr.shape
(2, 3)
>>> csr.todense()
Array([[1., 0., 2.],
       [0., 0., 3.]], dtype=float32)

todense()

Convert this CSR matrix to a dense array.

Returns:

Dense 2-D array equivalent to this sparse matrix.

Return type:

jax.Array or Quantity

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.sparse as susparse
>>> dense = jnp.array([[0., 3.], [4., 0.]])
>>> csr = susparse.CSR.fromdense(dense)
>>> csr.todense()
Array([[0., 3.],
       [4., 0.]], dtype=float32)
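
Conceptually, converting to dense scatters each stored value back to its (row, column) position. The pure-Python sketch below illustrates that scatter on hand-written CSR parts of the matrix from this example. It is an illustration under simplifying assumptions (plain lists, no units, no duplicate column indices within a row), not the library's implementation; csr_parts_to_dense is a hypothetical name.

```python
def csr_parts_to_dense(data, indices, indptr, shape):
    """Rebuild a list-of-lists matrix from CSR parts (hypothetical helper).

    Assumes no duplicate column indices within a row.
    """
    nrows, ncols = shape
    dense = [[0.0] * ncols for _ in range(nrows)]
    for row in range(nrows):
        # Walk this row's slice of data/indices and write each value in place.
        for k in range(indptr[row], indptr[row + 1]):
            dense[row][indices[k]] = data[k]
    return dense


# CSR parts of [[0., 3.], [4., 0.]]
print(csr_parts_to_dense([3., 4.], [1, 0], [0, 1, 2], (2, 2)))
# [[0.0, 3.0], [4.0, 0.0]]
```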

with_data(data)

Create a new CSR matrix with the same sparsity structure but different data.

Parameters:

data (Array | saiunit.Quantity) – New non-zero values. Must have the same shape, dtype, and unit as self.data.

Returns:

A new CSR matrix sharing the same indices and indptr but holding the provided data.

Return type:

CSR

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.sparse as susparse
>>> dense = jnp.array([[1., 0.], [0., 2.]])
>>> csr = susparse.CSR.fromdense(dense)
>>> new_csr = csr.with_data(csr.data * 5)
>>> new_csr.todense()
Array([[ 5.,  0.],
       [ 0., 10.]], dtype=float32)
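
Because with_data replaces only the values, it suits elementwise transformations that map zero to zero (such as the scaling above): the sparsity pattern is reused rather than recomputed. The sketch below models that contract with a hypothetical frozen dataclass, CSRSketch, which is not the library class; it checks only that the new values line up one-to-one with the existing indices, whereas the documented method also requires matching dtype and unit.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class CSRSketch:
    """Hypothetical stand-in for a CSR container; not the library class."""
    data: tuple
    indices: tuple
    indptr: tuple
    shape: tuple

    def with_data(self, data):
        # The new values must line up one-to-one with the existing indices.
        data = tuple(data)
        if len(data) != len(self.data):
            raise ValueError("data must have the same length (nse) as self.data")
        # Reuse indices, indptr and shape; only the values change.
        return replace(self, data=data)


# CSR parts of [[1., 0.], [0., 2.]]
csr = CSRSketch(data=(1., 2.), indices=(0, 1), indptr=(0, 1, 2), shape=(2, 2))
scaled = csr.with_data(v * 5 for v in csr.data)
print(scaled.data)                     # (5.0, 10.0)
print(scaled.indices is csr.indices)   # True: structure is shared, not copied
```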