brainevent.csrmv


brainevent.csrmv = <NameScope(brainevent.csrmv)>

Product of a CSR sparse matrix and a dense vector.

Computes y = A @ v (or y = A.T @ v when transpose=True) where A is stored in Compressed Sparse Row format and v is a dense vector. Unlike the binary (event-driven) variant, every element of v contributes to the result regardless of sign or magnitude.

The function supports physical units via brainunit. If data or v carry units, the result is returned in the corresponding product unit.

Parameters:
  • data (Array | ndarray | Quantity | Number) – Non-zero values of the CSR matrix. Shape (nse,) for heterogeneous weights or (1,) for a single homogeneous weight shared across all connections.

  • indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with integer dtype (int32, int64, uint32, or uint64).

  • indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,), with the same dtype as indices.

  • v (Array | ndarray | Quantity | Number) – Dense vector. Shape (shape[0],) when transpose=True or (shape[1],) when transpose=False.

  • shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.

  • transpose (bool) – If True, multiply by the transpose of the sparse matrix, i.e. compute y = A.T @ v. Default is False.

  • backend (str | None) – Compute backend to use. One of 'numba', 'pallas', or None (auto-select). Default is None.
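To make the shape and dtype conventions above concrete, the NumPy sketch below builds the three CSR arrays from a small dense matrix. `dense_to_csr` is a hypothetical helper written only for illustration; it is not part of brainevent. The matrix is chosen so that the resulting arrays match the ones used in the Examples section.

```python
import numpy as np


def dense_to_csr(dense, index_dtype=np.int32):
    """Hypothetical helper (not part of brainevent): dense 2-D array -> CSR arrays.

    Only entries that are exactly zero are dropped; every other entry is stored.
    """
    dense = np.asarray(dense)
    m, _ = dense.shape
    data, indices, indptr = [], [], [0]
    for i in range(m):
        cols = np.nonzero(dense[i])[0]        # column indices of the non-zeros in row i
        data.extend(dense[i, cols].tolist())  # their values, in column order
        indices.extend(cols.tolist())
        indptr.append(len(indices))           # running count of stored entries
    return (np.asarray(data, dtype=dense.dtype),
            np.asarray(indices, dtype=index_dtype),
            np.asarray(indptr, dtype=index_dtype))


A = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 4.0]])               # shape (m, k) = (2, 3)
data, indices, indptr = dense_to_csr(A)
# data    -> [1.0, 2.0, 3.0, 4.0]  shape (nse,) = (4,)
# indices -> [0, 2, 1, 2]          shape (nse,), int32
# indptr  -> [0, 2, 4]             shape (shape[0] + 1,) = (3,), int32
```

With these arrays, v must have length 3 (shape[1]) for transpose=False and length 2 (shape[0]) for transpose=True.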

Returns:

y – Result vector. Shape (shape[1],) when transpose=True or (shape[0],) when transpose=False.

Return type:

Array | ndarray | Quantity | Number

See also

csrmm

CSR matrix–matrix multiplication.

binary_csrmv

Event-driven (binary) CSR matrix–vector multiplication.

Notes

This operation is differentiable with respect to both data and v via custom JVP and transpose rules.
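The practical meaning of a transpose rule here is that the vector-Jacobian product with respect to v is itself a transposed CSR product, while the gradient with respect to data is ct[i] * v[j] at each stored position (i, j). The JAX sketch below checks both identities on a pure-JAX reference built from `jax.ops.segment_sum`. `csrmv_ref` is a hypothetical stand-in, not brainevent's kernel, and the sketch says nothing about how the library's custom JVP and transpose rules are actually written.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Same 2 x 3 matrix as in the Examples section, stored in CSR form.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
m, k = 2, 3
# Row id of every stored entry: [0, 0, 1, 1] (computed once with NumPy).
row_ids = np.repeat(np.arange(m), np.diff(indptr)).astype(np.int32)


def csrmv_ref(data, v):
    """Hypothetical pure-JAX reference for y = A @ v (not brainevent's kernel)."""
    return jax.ops.segment_sum(data * v[indices], row_ids, num_segments=m)


v = jnp.array([1.0, 2.0, 3.0])
ct = jnp.array([1.0, -1.0])  # an arbitrary cotangent for y

# Gradients of the scalar <ct, A @ v> with respect to data and v.
g_data, g_v = jax.grad(
    lambda d, x: jnp.dot(ct, csrmv_ref(d, x)), argnums=(0, 1)
)(data, v)

# Dense copy of A, used only to check the identities.
A = jnp.zeros((m, k)).at[row_ids, indices].set(data)
# g_v    -> [1.0, -3.0, -2.0]         equals A.T @ ct
# g_data -> [1.0, 3.0, -2.0, -3.0]    equals ct[row] * v[col] per stored entry
```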

Mathematically, the non-transposed operation computes:

y[i] = sum_{j in nz(i)} A[i, j] * v[j]

where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix.
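A direct reading of this formula walks each row's slice indptr[i]:indptr[i + 1] and takes a weighted sum of the gathered entries of v. The NumPy loop below is an illustrative reference only (`csrmv_rows` is a hypothetical name, and brainevent's Numba and Pallas kernels are not claimed to look like this). It deliberately uses negative and fractional entries in v, which all contribute to y.

```python
import numpy as np


def csrmv_rows(data, indices, indptr, v, shape):
    """Hypothetical reference for y[i] = sum_{j in nz(i)} A[i, j] * v[j]."""
    m, _ = shape
    y = np.zeros(m, dtype=np.result_type(data, v))
    for i in range(m):
        start, end = indptr[i], indptr[i + 1]  # stored entries of row i
        # Gather v at the row's column indices and take the weighted sum.
        y[i] = np.dot(data[start:end], v[indices[start:end]])
    return y


data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
v = np.array([1.0, -2.0, 0.5])  # negative and fractional entries still count
y = csrmv_rows(data, indices, indptr, v, shape=(2, 3))
# y -> [2.0, -4.0]   (row 0: 1*1 + 2*0.5, row 1: 3*(-2) + 4*0.5)
```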

When transpose=True, the transposed operation computes:

y[j] = sum_{i in nz_col(j)} A[i, j] * v[i]

where nz_col(j) denotes the set of row indices with non-zero entries in column j.
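CSR storage gives cheap access to rows but not to columns, so nz_col(j) is not directly available. One common way to evaluate the transposed product, shown in the NumPy sketch below, is to walk each row i once and scatter A[i, j] * v[i] into y[j]. This is an illustrative assumption about strategy, not a description of brainevent's kernels; `csrmv_transpose_scatter` is a hypothetical name.

```python
import numpy as np


def csrmv_transpose_scatter(data, indices, indptr, v, shape):
    """Hypothetical reference for y[j] = sum_{i in nz_col(j)} A[i, j] * v[i]."""
    m, k = shape
    y = np.zeros(k, dtype=np.result_type(data, v))
    for i in range(m):
        start, end = indptr[i], indptr[i + 1]
        # np.add.at accumulates correctly even if a column index repeats.
        np.add.at(y, indices[start:end], data[start:end] * v[i])
    return y


data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
v = np.array([1.0, 1.0])  # length shape[0], as required when transpose=True
y = csrmv_transpose_scatter(data, indices, indptr, v, shape=(2, 3))
# y -> [1.0, 3.0, 6.0], i.e. the column sums of A
```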

For homogeneous weights (data of shape (1,)), A[i, j] equals the constant data[0] for all structural non-zero positions.
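Because the weight is the same everywhere, it factors out of every row sum: y[i] = data[0] * sum_{j in nz(i)} v[j]. The NumPy sketch below checks that this factored form matches the dense product obtained by repeating data[0] at every stored position. It only illustrates the arithmetic; it does not imply that brainevent's kernels use this shortcut.

```python
import numpy as np

w = np.array([0.5])  # homogeneous data, shape (1,)
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
v = np.array([1.0, 2.0, 3.0])
m, k = 2, 3

# Equivalent heterogeneous form: repeat data[0] at every stored position.
data_full = np.full(indices.shape, w[0])
rows = np.repeat(np.arange(m), np.diff(indptr))  # row id of each stored entry
A = np.zeros((m, k))
A[rows, indices] = data_full

# The constant factors out of each row sum: y[i] = data[0] * sum_{j in nz(i)} v[j].
y_factored = w[0] * np.array(
    [v[indices[indptr[i]:indptr[i + 1]]].sum() for i in range(m)]
)
# y_factored -> [2.0, 2.5], identical to A @ v
```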


Examples

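>>> import jax.numpy as jnp
>>> from brainevent import csrmv
>>> # A = [[1, 0, 2],
>>> #      [0, 3, 4]] stored in CSR form
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> v = jnp.array([1.0, 2.0, 3.0])
>>> y = csrmv(data, indices, indptr, v, shape=(2, 3))  # A @ v = [1*1 + 2*3, 3*2 + 4*3] = [7, 18]
>>> yt = csrmv(data, indices, indptr, jnp.array([1.0, 1.0]), shape=(2, 3), transpose=True)  # A.T @ [1, 1] = [1, 3, 6]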