brainevent.csrmv
- brainevent.csrmv = <NameScope(brainevent.csrmv)>
Product of a CSR sparse matrix and a dense vector.
Computes y = A @ v (or y = A.T @ v when transpose=True), where A is stored in Compressed Sparse Row (CSR) format and v is a dense vector. Unlike the binary (event-driven) variant, every element of v contributes to the result regardless of its sign or magnitude.
The function supports physical units via brainunit. If data or v carry units, the result is returned in the corresponding product unit.
- Parameters:
    data (Array | ndarray | Quantity | Number) – Non-zero values of the CSR matrix. Shape (nse,) for heterogeneous weights, where nse is the number of stored (structurally non-zero) entries, or (1,) for a single homogeneous weight shared across all connections.
    indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with an integer dtype (int32, int64, uint32, or uint64).
    indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,), with the same dtype as indices.
    v (Array | ndarray | Quantity | Number) – Dense vector. Shape (shape[0],) when transpose=True, or (shape[1],) when transpose=False.
    shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.
    transpose (bool) – If True, the sparse matrix is transposed before multiplication, i.e. A.T @ v is computed. Default is False.
    backend (str | None) – Compute backend to use: 'numba', 'pallas', or None (auto-select). Default is None.
- Returns:
    y – Result vector. Shape (shape[1],) when transpose=True, or (shape[0],) when transpose=False.
- Return type:
    Array | ndarray | Quantity | Number
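The shape rules above are easy to mix up between the two transpose modes. The sketch below encodes them as plain checks. Both helper names (expected_vector_lengths, check_csrmv_args) are hypothetical and not part of brainevent, and the indptr[-1] == nse test is the standard CSR invariant added here rather than something this page states.

```python
import numpy as np


def expected_vector_lengths(shape, transpose=False):
    """Return (len(v), len(y)) implied by the documented shape rules (hypothetical helper)."""
    m, k = shape
    # A has shape (m, k): A @ v needs len(v) == k and yields m outputs,
    # while A.T @ v needs len(v) == m and yields k outputs.
    return (m, k) if transpose else (k, m)


def check_csrmv_args(data, indices, indptr, v, shape, transpose=False):
    """Validate argument shapes against the parameter list; return len(y) (hypothetical helper)."""
    m, _ = shape
    if indices.ndim != 1 or not np.issubdtype(indices.dtype, np.integer):
        raise ValueError("indices must be a 1-D integer array of shape (nse,)")
    nse = indices.shape[0]
    if indptr.shape != (m + 1,) or indptr.dtype != indices.dtype:
        raise ValueError("indptr must have shape (shape[0] + 1,) and the dtype of indices")
    # Standard CSR invariant (assumed here): the last row pointer equals nse.
    if int(indptr[-1]) != nse:
        raise ValueError("indptr[-1] must equal the number of stored entries")
    # Heterogeneous weights (nse,) or one homogeneous weight (1,).
    if data.shape not in ((nse,), (1,)):
        raise ValueError("data must have shape (nse,) or (1,)")
    v_len, y_len = expected_vector_lengths(shape, transpose)
    if v.shape != (v_len,):
        raise ValueError(f"v must have shape ({v_len},), got {v.shape}")
    return y_len
```

For a (2, 3) matrix, v must have length 3 and y has length 2 in the default mode; with transpose=True the roles swap.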
See also
csrmm
    CSR matrix–matrix multiplication.
binary_csrmv
    Event-driven (binary) CSR matrix–vector multiplication.
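To make the contrast with binary_csrmv concrete, the dense NumPy sketch below compares a float matvec with an event-driven reading of the same vector. The event-driven rule used here (each non-zero entry counts as a spike of unit size) is an assumption for illustration only; consult binary_csrmv's own documentation for its exact semantics. The vector is non-negative, so a "strictly positive" spike rule would produce the same mask.

```python
import numpy as np

# Dense view of the 2 x 3 matrix used in the Examples section.
A = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 4.0]])
v = np.array([0.5, 2.0, 0.0])  # non-negative, so "> 0" and "!= 0" give the same mask

# Float matvec (csrmv semantics): the value of every entry of v matters.
y_float = A @ v

# Event-driven reading (assumption): only the presence of activity matters,
# each active entry contributing as if it were 1.
spikes = v != 0
y_event = A @ spikes.astype(A.dtype)

print(y_float.tolist())  # [0.5, 6.0]
print(y_event.tolist())  # [1.0, 3.0]
```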
Notes
This operation is differentiable with respect to both data and v via custom JVP and transpose rules.
Mathematically, the non-transposed operation computes:

$$y[i] = \sum_{j \in \mathrm{nz}(i)} A[i, j] \, v[j]$$

where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix.
When transpose=True, the transposed operation computes:

$$y[j] = \sum_{i \in \mathrm{nz\_col}(j)} A[i, j] \, v[i]$$

where nz_col(j) denotes the set of row indices with non-zero entries in column j.
For homogeneous weights (data of shape (1,)), A[i, j] equals the constant data[0] at every structural non-zero position.
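Both formulas above, plus the homogeneous-weight case, can be written as a short NumPy reference. This is a minimal sketch for checking results, not brainevent's implementation: csrmv_reference is a hypothetical name, and it ignores units, backends, and the custom differentiation rules.

```python
import numpy as np


def csrmv_reference(data, indices, indptr, v, shape, transpose=False):
    """Plain NumPy reference for the Notes formulas (hypothetical, illustrative only)."""
    m, k = shape
    nse = indices.shape[0]
    # Homogeneous weights: broadcasting a (1,) array repeats data[0] for every
    # stored entry; a (nse,) array is left as-is.
    weights = np.broadcast_to(data, (nse,))
    # Row index of each stored entry, expanded from the row pointer.
    rows = np.repeat(np.arange(m), np.diff(indptr))
    if not transpose:
        # y[i] = sum over j in nz(i) of A[i, j] * v[j]
        y = np.zeros(m, dtype=np.result_type(weights, v))
        np.add.at(y, rows, weights * v[indices])
    else:
        # y[j] = sum over i in nz_col(j) of A[i, j] * v[i]
        y = np.zeros(k, dtype=np.result_type(weights, v))
        np.add.at(y, indices, weights * v[rows])
    return y
```

Expanding indptr into per-entry row ids turns both directions into the same scatter-add, differing only in which index array is gathered from and which is scattered into. In JAX, jax.ops.segment_sum plays the role that np.add.at plays here.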
Examples
>>> import jax.numpy as jnp
>>> from brainevent._csr.float import csrmv
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> v = jnp.array([1.0, 2.0, 3.0])
>>> csrmv(data, indices, indptr, v, shape=(2, 3))
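To see what the example call should produce, the sketch below expands the same CSR arrays into a dense matrix with NumPy and multiplies directly. It does not call brainevent; it only applies the formulas from the Notes, and it assumes column indices are unique within each row (true for this example).

```python
import numpy as np

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
v = np.array([1.0, 2.0, 3.0])
shape = (2, 3)

# Expand the CSR arrays into a dense matrix: row i owns entries indptr[i]:indptr[i+1].
A = np.zeros(shape, dtype=data.dtype)
for i in range(shape[0]):
    start, end = indptr[i], indptr[i + 1]
    A[i, indices[start:end]] = data[start:end]

print(A.tolist())        # [[1.0, 0.0, 2.0], [0.0, 3.0, 4.0]]
print((A @ v).tolist())  # [7.0, 18.0]

# transpose=True expects a vector of length shape[0].
w = np.array([1.0, 2.0])
print((A.T @ w).tolist())  # [1.0, 6.0, 10.0]
```

Row 0 stores 1.0 and 2.0 in columns 0 and 2, giving 1*1 + 2*3 = 7; row 1 stores 3.0 and 4.0 in columns 1 and 2, giving 3*2 + 4*3 = 18.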