brainevent.csrmm

brainevent.csrmm = <NameScope(brainevent.csrmm)>

Product of a CSR sparse matrix and a dense matrix.

Computes C = A @ B (or C = A.T @ B when transpose=True) where A is stored in Compressed Sparse Row format and B is a dense matrix.
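One way to think about the operation is to expand the CSR triple (data, indices, indptr) into a dense matrix and multiply it with `@`. The NumPy sketch below does exactly that. `csr_to_dense` and `csrmm_dense_reference` are hypothetical helpers written for illustration, not brainevent functions. Materializing A densely defeats the purpose of a sparse kernel, so a sketch like this is only useful for checking results on small inputs.

```python
import numpy as np


def csr_to_dense(data, indices, indptr, shape):
    """Expand a CSR triple into a dense (m, k) NumPy array (hypothetical helper)."""
    m, _ = shape
    indices = np.asarray(indices)
    # A (1,)-shaped `data` broadcasts to one shared weight per stored entry.
    vals = np.broadcast_to(np.asarray(data, dtype=float), indices.shape)
    # Row r owns entries indptr[r]:indptr[r + 1], so repeat r that many times.
    rows = np.repeat(np.arange(m), np.diff(np.asarray(indptr)))
    dense = np.zeros(shape, dtype=float)
    # np.add.at accumulates, so duplicate (row, col) pairs are summed, not overwritten.
    np.add.at(dense, (rows, indices), vals)
    return dense


def csrmm_dense_reference(data, indices, indptr, B, shape, transpose=False):
    """Dense stand-in for C = A @ B, or C = A.T @ B when transpose=True."""
    A = csr_to_dense(data, indices, indptr, shape)
    return (A.T if transpose else A) @ np.asarray(B, dtype=float)


# Same CSR arrays as in the Examples section: A = [[1, 0, 2], [0, 3, 4]].
data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 1, 2])
indptr = np.array([0, 2, 4])
B = np.array([[1.0, 0.5], [2.0, 1.5], [3.0, 2.5]])
C = csrmm_dense_reference(data, indices, indptr, B, shape=(2, 3))
# C has values [[7.0, 5.5], [18.0, 14.5]]
print(C)
```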

The function supports physical units via brainunit: data and B may be brainunit Quantity objects, in which case C carries the product of their units.

Parameters:
  • data (Array | ndarray | Quantity | Number) – Non-zero values of the CSR matrix. Shape (nse,) for heterogeneous weights or (1,) for a single homogeneous weight.

  • indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with integer dtype.

  • indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,) and same dtype as indices.

  • B (Array | ndarray | Quantity | Number) – Dense matrix. Shape (shape[0], cols) when transpose=True or (shape[1], cols) when transpose=False.

  • shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.

  • transpose (bool) – If True, transpose A before multiplication. Default is False.

  • backend (str | None) – Compute backend. One of 'numba', 'pallas', or None (auto-select). Default is None.
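To make the roles of data, indices and indptr concrete, the NumPy sketch below builds the three arrays from a dense matrix. `dense_to_csr` is a hypothetical helper, not a brainevent function. Applied to A = [[1, 0, 2], [0, 3, 4]], it reproduces the arrays used in the Examples section, and it gives indices and indptr the same integer dtype as the parameter list requires.

```python
import numpy as np


def dense_to_csr(A, index_dtype=np.int32):
    """Return (data, indices, indptr) for the non-zeros of a dense 2-D array (hypothetical helper)."""
    A = np.asarray(A)
    # np.nonzero scans in row-major order, so entries come out grouped by row
    # with ascending column indices inside each row.
    rows, cols = np.nonzero(A)
    data = A[rows, cols]
    # indptr[i + 1] - indptr[i] is the number of stored entries in row i.
    counts = np.bincount(rows, minlength=A.shape[0])
    indptr = np.concatenate([[0], np.cumsum(counts)]).astype(index_dtype)
    # indices and indptr share one integer dtype.
    return data, cols.astype(index_dtype), indptr


data, indices, indptr = dense_to_csr([[1.0, 0.0, 2.0], [0.0, 3.0, 4.0]])
print(data, indices, indptr)
```

Here nse is `len(data)`; passing a one-element data array instead would turn the same structure into a homogeneous-weight matrix.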

Returns:

C – Result matrix. Shape (shape[1], cols) when transpose=True or (shape[0], cols) when transpose=False.

Return type:

Array | ndarray | Quantity | Number
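The shape rules for B and C in the two modes can be collected into a single check. The sketch below is a hypothetical NumPy helper, not brainevent's own validation (the library may check different things or report errors differently). It only encodes the shapes listed under Parameters and Returns.

```python
import numpy as np


def check_csrmm_args(data, indices, indptr, B, shape, transpose=False):
    """Check the documented argument shapes and return the shape of C (hypothetical helper)."""
    m, k = shape
    data, indices, indptr, B = (np.asarray(x) for x in (data, indices, indptr, B))
    nse = indices.shape[0]
    if data.shape not in ((nse,), (1,)):
        raise ValueError(f"data must have shape ({nse},) or (1,), got {data.shape}")
    if indptr.shape != (m + 1,):
        raise ValueError(f"indptr must have shape ({m + 1},), got {indptr.shape}")
    if indptr.dtype != indices.dtype:
        raise ValueError("indices and indptr must share one integer dtype")
    # transpose=False contracts over the k columns of A; transpose=True over its m rows.
    rows_needed = m if transpose else k
    if B.ndim != 2 or B.shape[0] != rows_needed:
        raise ValueError(f"B must have shape ({rows_needed}, cols), got {B.shape}")
    cols = B.shape[1]
    # A @ B is (m, cols); A.T @ B is (k, cols).
    return (k, cols) if transpose else (m, cols)


idx = np.array([0, 2, 1, 2], dtype=np.int32)
ptr = np.array([0, 2, 4], dtype=np.int32)
print(check_csrmm_args(np.ones(4), idx, ptr, np.ones((3, 5)), (2, 3)))
```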

See also

csrmv

CSR matrix–vector multiplication.

binary_csrmm

Event-driven (binary) CSR matrix–matrix multiplication.

Notes

Custom JVP and transpose rules are provided for automatic differentiation with respect to data and B.
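Because C is linear in data for fixed B and linear in B for fixed data, the quantities such rules must produce follow from standard identities for a bilinear map. The block below states those identities; it is not taken from brainevent's source. The notation is introduced here for illustration: stored entry n has value d_n = data[n], column c(n) = indices[n], and row r(n), the row whose slice indptr[r(n)]:indptr[r(n)+1] contains n. The cotangent of C is written as C with a bar.

```latex
\begin{aligned}
&\text{transpose=False:} && C_{il} = \sum_{n:\,r(n)=i} d_n\, B_{c(n),l} \\
&&& \dot C_{il} = \sum_{n:\,r(n)=i} \bigl(\dot d_n\, B_{c(n),l} + d_n\, \dot B_{c(n),l}\bigr) \\
&&& \bar B = A^{\top}\bar C, \qquad \bar d_n = \sum_{l} \bar C_{r(n),l}\, B_{c(n),l} \\
&\text{transpose=True:} && C_{jl} = \sum_{n:\,c(n)=j} d_n\, B_{r(n),l} \\
&&& \bar B = A\,\bar C, \qquad \bar d_n = \sum_{l} \bar C_{c(n),l}\, B_{r(n),l} \\
&\text{homogeneous } (d_n \equiv d_0): && \bar d_0 = \sum_{n} \bar d_n
\end{aligned}
```

In both modes the cotangent for B is another CSR product with the opposite transpose flag, and the cotangent for data needs only one dot product per stored entry.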

Mathematically, the non-transposed operation computes:

C[i, l] = sum_{j in nz(i)} A[i, j] * B[j, l]

where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix.
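Read directly off the formula, output row i is a weighted sum of the rows of B selected by nz(i), and CSR stores nz(i) as the contiguous slice indptr[i]:indptr[i+1]. The plain-loop NumPy sketch below only illustrates that arithmetic; `csrmm_rows` is a hypothetical name, and brainevent's numba and pallas backends are not reproduced here.

```python
import numpy as np


def csrmm_rows(data, indices, indptr, B, shape):
    """Loop form of C[i, l] = sum_{j in nz(i)} A[i, j] * B[j, l] (hypothetical reference)."""
    m = shape[0]
    B = np.asarray(B, dtype=float)
    # Broadcasting lets a (1,)-shaped `data` act as one weight shared by all entries.
    vals = np.broadcast_to(np.asarray(data, dtype=float), (len(indices),))
    C = np.zeros((m, B.shape[1]))
    for i in range(m):
        # Positions indptr[i]:indptr[i + 1] hold the non-zeros of row i, i.e. nz(i).
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]
            C[i, :] += vals[p] * B[j, :]  # gather row j of B, scaled by A[i, j]
    return C


B = np.array([[1.0, 0.5], [2.0, 1.5], [3.0, 2.5]])
C = csrmm_rows(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 2, 1, 2]), np.array([0, 2, 4]), B, (2, 3))
print(C)
```

Each output row reads only its own slice of indices and data and writes only to C[i, :], so rows are independent of one another.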

When transpose=True, the transposed operation computes:

C[j, l] = sum_{i in nz_col(j)} A[i, j] * B[i, l]

where nz_col(j) denotes the set of row indices with non-zero entries in column j.
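CSR stores rows, not columns, so nz_col(j) is not available as a contiguous slice. One way to evaluate the transposed formula without building a column index is to walk the rows exactly as in the non-transposed case and scatter each contribution A[i, j] * B[i, :] into row j of C. The NumPy sketch below shows that approach; `csrmm_transposed_rows` is a hypothetical name, and this is not necessarily how brainevent's backends compute the transposed product.

```python
import numpy as np


def csrmm_transposed_rows(data, indices, indptr, B, shape):
    """Scatter form of C[j, l] = sum_{i in nz_col(j)} A[i, j] * B[i, l] (hypothetical reference)."""
    m, k = shape
    B = np.asarray(B, dtype=float)  # shape (m, cols) in the transposed case
    vals = np.broadcast_to(np.asarray(data, dtype=float), (len(indices),))
    C = np.zeros((k, B.shape[1]))  # one output row per column of A
    for i in range(m):
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]
            # Entry A[i, j] adds its share of B[i, :] to output row j.
            C[j, :] += vals[p] * B[i, :]
    return C


B_t = np.array([[1.0, 2.0], [3.0, 4.0]])
C_t = csrmm_transposed_rows(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 2, 1, 2]), np.array([0, 2, 4]), B_t, (2, 3))
print(C_t)
```

Unlike the row-wise case, different rows i can write to the same output row j, so a parallel version of this loop needs atomic additions or a column-oriented (CSC) copy of the sparsity structure.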

For homogeneous weights (data of shape (1,)), A[i, j] equals the constant data[0] for all structural non-zero positions.
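With a single shared weight, the non-transposed formula factors as C[i, l] = data[0] * sum_{j in nz(i)} B[j, l]. The NumPy sketch below uses that factorization; `csrmm_homogeneous` is a hypothetical name, and the sketch is not a description of brainevent's homogeneous kernels.

```python
import numpy as np


def csrmm_homogeneous(data, indices, indptr, B, shape):
    """transpose=False product when every stored entry of A equals data[0] (hypothetical reference)."""
    w = float(np.asarray(data).reshape(-1)[0])  # the single shared weight
    indices = np.asarray(indices)
    B = np.asarray(B, dtype=float)
    C = np.zeros((shape[0], B.shape[1]))
    for i in range(shape[0]):
        cols = indices[indptr[i]:indptr[i + 1]]  # nz(i)
        # Sum the selected rows of B first, then scale once by the shared weight.
        C[i, :] = w * B[cols, :].sum(axis=0)
    return C


B = np.array([[1.0, 0.5], [2.0, 1.5], [3.0, 2.5]])
C_h = csrmm_homogeneous(np.array([2.0]), np.array([0, 2, 1, 2]), np.array([0, 2, 4]), B, (2, 3))
print(C_h)
```

Factoring out the weight replaces one multiplication per stored entry with one per row, and the (1,)-shaped data array stores one value instead of nse.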

Examples

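>>> import jax.numpy as jnp
>>> from brainevent._csr.float import csrmm
>>> # CSR form of the 2 x 3 matrix A = [[1, 0, 2], [0, 3, 4]]
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)  # column of each stored value
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)      # row i spans data[indptr[i]:indptr[i+1]]
>>> B = jnp.array([[1.0, 0.5],
...                [2.0, 1.5],
...                [3.0, 2.5]])
>>> C = csrmm(data, indices, indptr, B, shape=(2, 3))
>>> C.shape
(2, 2)
>>> bool(jnp.allclose(C, jnp.array([[7.0, 5.5], [18.0, 14.5]])))  # C equals A @ B
True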