brainevent.coomv

- brainevent.coomv = <NameScope(brainevent.coomv)>
Perform COO sparse matrix-vector multiplication.

Computes the product of a sparse matrix stored in COO (Coordinate) format and a dense vector. With transpose=False the operation computes:

    y[i] = sum_{k} A[i, k] * v[k]

With transpose=True:

    y[k] = sum_{i} A[i, k] * v[i]

where A is the sparse matrix defined by (data, row, col).

- Parameters:
  - data (Array | ndarray | Quantity | Number) – Non-zero values of the sparse matrix. Either a scalar (shape (1,), for homogeneous weights) or a 1-D array of length nnz (heterogeneous weights).
  - row (Array | ndarray) – 1-D int array of row indices, length nnz.
  - col (Array | ndarray) – 1-D int array of column indices, length nnz.
  - vector (Array | ndarray | Quantity | Number) – Dense input vector. Shape (shape[1],) when transpose=False, or (shape[0],) when transpose=True.
  - shape (Tuple[int, int]) – Logical (m, k) shape of the sparse matrix.
  - transpose (bool) – If True, multiply by A.T. Default is False.
  - backend (str | None) – Compute backend (e.g. 'numba', 'pallas'). None selects the default.
- Returns:
  Result vector of shape (shape[0],) when transpose=False, or (shape[1],) when transpose=True. If data or vector carries physical units, the result carries the product of those units.
- Return type:
  Array | ndarray | Quantity | Number
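To make the shape conventions above concrete, here is a minimal dense reference, sketched under the assumption that only jax.numpy is available. It densifies the COO triplets (summing duplicate entries) and multiplies by A or A.T. The helpers coo_to_dense and coomv_dense_reference are hypothetical names for illustration, not part of brainevent, and they ignore units and the backend argument.

```python
import jax.numpy as jnp


def coo_to_dense(data, row, col, shape):
    """Hypothetical helper: densify COO triplets, summing duplicate (row, col) pairs.

    `data` may have shape (1,) (homogeneous weight) or (nnz,) (heterogeneous).
    """
    data = jnp.broadcast_to(jnp.asarray(data), row.shape)
    return jnp.zeros(shape, dtype=data.dtype).at[row, col].add(data)


def coomv_dense_reference(data, row, col, vector, shape, transpose=False):
    """Hypothetical dense reference for the semantics described above."""
    A = coo_to_dense(data, row, col, shape)
    # transpose=False: y[i] = sum_k A[i, k] * v[k], so len(vector) == shape[1]
    # transpose=True:  y[k] = sum_i A[i, k] * v[i], so len(vector) == shape[0]
    return A.T @ vector if transpose else A @ vector


data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 1, 1])
col = jnp.array([2, 0, 3])
shape = (2, 4)  # m = 2 rows, k = 4 columns

y = coomv_dense_reference(data, row, col, jnp.ones(4), shape)
yt = coomv_dense_reference(data, row, col, jnp.array([1.0, 10.0]), shape, transpose=True)
print(y.shape, yt.shape)  # (2,) (4,)
```

The forward result has one entry per row (shape[0]) and the transposed result one entry per column (shape[1]), matching the Returns description.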
See also
coomm – COO sparse matrix-matrix multiplication.
binary_coomv – Event-driven (binary) COO matrix-vector multiplication.
Notes
The kernel iterates over all nnz stored elements and, for each triplet (data[s], row[s], col[s]), accumulates data[s] * vector[col[s]] into y[row[s]] (forward) or data[s] * vector[row[s]] into y[col[s]] (transpose). When data is a scalar, the same weight is used for every non-zero, enabling a more compact kernel.
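The per-triplet accumulation described above can be expressed in pure JAX as a gather followed by a scatter-add. This is an illustrative sketch only, not brainevent's kernel: coomv_scatter_sketch is a hypothetical name, and units and backends are ignored. Because .at[...].add sums into repeated indices, duplicate triplets accumulate just as they would in a per-element loop, and a shape-(1,) data array broadcasts over all stored elements to model homogeneous weights.

```python
import jax.numpy as jnp


def coomv_scatter_sketch(data, row, col, vector, shape, transpose=False):
    """Gather/scatter-add sketch of COO matrix-vector accumulation.

    Hypothetical helper for illustration; not brainevent's kernel.
    """
    m, k = shape
    if transpose:
        # For each s: y[col[s]] += data[s] * vector[row[s]]
        contrib = data * vector[row]
        return jnp.zeros(k, dtype=contrib.dtype).at[col].add(contrib)
    # For each s: y[row[s]] += data[s] * vector[col[s]]
    contrib = data * vector[col]
    return jnp.zeros(m, dtype=contrib.dtype).at[row].add(contrib)


row = jnp.array([0, 1, 2])
col = jnp.array([1, 0, 2])
v = jnp.array([1.0, 2.0, 3.0])

# Heterogeneous weights: one value per stored element. Values: [2., 2., 9.]
y_het = coomv_scatter_sketch(jnp.array([1.0, 2.0, 3.0]), row, col, v, (3, 3))

# Homogeneous weight: shape (1,) broadcasts over all nnz elements. Values: [1., 0.5, 1.5]
y_hom = coomv_scatter_sketch(jnp.array([0.5]), row, col, v, (3, 3))

# Transpose: y[col[s]] accumulates data[s] * vector[row[s]]. Values: [4., 1., 9.]
y_tr = coomv_scatter_sketch(jnp.array([1.0, 2.0, 3.0]), row, col, v, (3, 3), transpose=True)
```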
Physical units attached via brainunit are split off before the computation and re-applied to the result.

This function supports automatic differentiation (JVP and transpose rules), vmap batching, and multiple hardware backends.

Examples
>>> import jax.numpy as jnp
>>> from brainevent import coomv
>>> data = jnp.array([1.0, 2.0, 3.0])
>>> row = jnp.array([0, 1, 2])
>>> col = jnp.array([1, 0, 2])
>>> v = jnp.array([1.0, 2.0, 3.0])
>>> # y = [1 * v[1], 2 * v[0], 3 * v[2]], i.e. values [2., 2., 9.]
>>> coomv(data, row, col, v, shape=(3, 3))
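The Notes state that coomv supports JVP and transpose rules as well as vmap batching. The sketch below illustrates why a transpose rule is the relevant piece, using a pure-JAX stand-in (the hypothetical coomv_sketch, not brainevent.coomv) on the same data as the example above. The reverse-mode gradient of <w, A v> with respect to v equals the transposed product A.T w, and jax.vmap maps the product over a batch of vectors.

```python
import jax
import jax.numpy as jnp


def coomv_sketch(data, row, col, vector, shape, transpose=False):
    # Hypothetical pure-JAX stand-in for brainevent.coomv (scatter-add form).
    m, k = shape
    if transpose:
        return jnp.zeros(k).at[col].add(data * vector[row])
    return jnp.zeros(m).at[row].add(data * vector[col])


data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 1, 2])
col = jnp.array([1, 0, 2])
shape = (3, 3)
v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([1.0, 1.0, 2.0])  # arbitrary cotangent vector

# Reverse mode: d/dv <w, A v> = A.T w, i.e. the transposed product.
grad_v = jax.grad(lambda x: jnp.vdot(w, coomv_sketch(data, row, col, x, shape)))(v)
at_w = coomv_sketch(data, row, col, w, shape, transpose=True)

# vmap: apply the same sparse matrix to a batch of vectors (one per row).
batch = jnp.stack([v, 2.0 * v])
ys = jax.vmap(lambda x: coomv_sketch(data, row, col, x, shape))(batch)
```

Here grad_v and at_w both evaluate to [2., 1., 6.], and ys has shape (2, 3) with rows [2., 2., 9.] and [4., 4., 18.].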