brainevent.csrmm
- brainevent.csrmm = <NameScope(brainevent.csrmm)>
Product of a CSR sparse matrix and a dense matrix.
Computes C = A @ B (or C = A.T @ B when transpose=True), where A is stored in Compressed Sparse Row (CSR) format and B is a dense matrix.
The function supports physical units via brainunit.
- Parameters:
data (Array | ndarray | Quantity | Number) – Non-zero values of the CSR matrix. Shape (nse,) for heterogeneous weights or (1,) for a single homogeneous weight.
indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with integer dtype.
indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,) and same dtype as indices.
B (Array | ndarray | Quantity | Number) – Dense matrix. Shape (shape[0], cols) when transpose=True or (shape[1], cols) when transpose=False.
shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.
transpose (bool) – If True, transpose A before multiplication. Default is False.
backend (str | None) – Compute backend. One of 'numba', 'pallas', or None (auto-select). Default is None.
- Returns:
C – Result matrix. Shape (shape[1], cols) when transpose=True or (shape[0], cols) when transpose=False.
- Return type:
Array | ndarray | Quantity | Number
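To make the argument layout concrete, the following pure-Python sketch expands the three CSR arrays into a dense matrix and derives the result shape from shape, cols, and transpose. The helper names csr_to_dense and result_shape are hypothetical and are not part of brainevent; the sketch only illustrates the conventions described in the parameter list, and its handling of repeated column indices (accumulation) is an assumption.

```python
def csr_to_dense(data, indices, indptr, shape):
    """Expand CSR arrays into a dense list of lists (illustration only)."""
    m, k = shape
    assert len(indptr) == m + 1, "indptr needs shape[0] + 1 entries"
    homogeneous = len(data) == 1  # one shared weight for every non-zero
    dense = [[0.0] * k for _ in range(m)]
    for i in range(m):
        # The non-zeros of row i sit at positions indptr[i] .. indptr[i + 1] - 1.
        for p in range(indptr[i], indptr[i + 1]):
            weight = data[0] if homogeneous else data[p]
            # Assumption of this sketch: repeated column indices accumulate.
            dense[i][indices[p]] += weight
    return dense


def result_shape(shape, cols, transpose=False):
    """Shape of C for a dense B with `cols` columns."""
    m, k = shape
    return (k, cols) if transpose else (m, cols)


dense = csr_to_dense([1.0, 2.0, 3.0, 4.0], [0, 2, 1, 2], [0, 2, 4], (2, 3))
# dense == [[1.0, 0.0, 2.0], [0.0, 3.0, 4.0]]

homog = csr_to_dense([0.5], [0, 2, 1, 2], [0, 2, 4], (2, 3))
# homog == [[0.5, 0.0, 0.5], [0.0, 0.5, 0.5]]
```

With a data array of shape (1,), every structural non-zero takes the same value, so only indices and indptr carry the sparsity pattern.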
See also
csrmv: CSR matrix–vector multiplication.
binary_csrmm: Event-driven (binary) CSR matrix–matrix multiplication.
Notes
Custom JVP and transpose rules are provided for automatic differentiation with respect to data and B.
Mathematically, the non-transposed operation computes:
    C[i, l] = sum_{j in nz(i)} A[i, j] * B[j, l]
where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix.
When transpose=True, the transposed operation computes:
    C[j, l] = sum_{i in nz_col(j)} A[i, j] * B[i, l]
where nz_col(j) denotes the set of row indices with non-zero entries in column j.
For homogeneous weights (data of shape (1,)), A[i, j] equals the constant data[0] for all structural non-zero positions.
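The two summation formulas in the Notes can be read as a single loop over the stored non-zeros, which is sketched below in pure Python. The function csrmm_reference is a hypothetical name used only for this illustration; it is not how brainevent implements csrmm (the numba and pallas backends are not shown), and it ignores units and automatic differentiation.

```python
def csrmm_reference(data, indices, indptr, B, shape, transpose=False):
    """Reference loops for C = A @ B or C = A.T @ B (illustration only)."""
    m, k = shape
    cols = len(B[0])
    homogeneous = len(data) == 1  # data of shape (1,): one shared weight
    C = [[0.0] * cols for _ in range(k if transpose else m)]
    for i in range(m):
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]
            a_ij = data[0] if homogeneous else data[p]
            for l in range(cols):
                if transpose:
                    # C[j, l] += A[i, j] * B[i, l]
                    C[j][l] += a_ij * B[i][l]
                else:
                    # C[i, l] += A[i, j] * B[j, l]
                    C[i][l] += a_ij * B[j][l]
    return C


data = [1.0, 2.0, 3.0, 4.0]
indices = [0, 2, 1, 2]
indptr = [0, 2, 4]

# Non-transposed: B has shape (shape[1], cols) = (3, 2).
C = csrmm_reference(data, indices, indptr,
                    [[1.0, 0.5], [2.0, 1.5], [3.0, 2.5]], shape=(2, 3))
# C == [[7.0, 5.5], [18.0, 14.5]]

# Transposed: B has shape (shape[0], cols) = (2, 2).
Ct = csrmm_reference(data, indices, indptr,
                     [[1.0, 2.0], [3.0, 4.0]], shape=(2, 3), transpose=True)
# Ct == [[1.0, 2.0], [9.0, 12.0], [14.0, 20.0]]
```

Both modes walk the same row-major storage; the transposed case only swaps which index selects the row of B and which selects the row of C, so no explicit transpose of A is needed in this sketch.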
Examples
>>> import jax.numpy as jnp
>>> from brainevent._csr.float import csrmm
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> B = jnp.array([[1.0, 0.5],
...                [2.0, 1.5],
...                [3.0, 2.5]])
>>> csrmm(data, indices, indptr, B, shape=(2, 3))