brainevent.coomv

brainevent.coomv = <NameScope(brainevent.coomv)>

Perform COO sparse matrix-vector multiplication.

Computes the product of a sparse matrix stored in COO (Coordinate) format and a dense vector.

With transpose=False the operation computes y = A @ vector:

y[i] = sum_{j} A[i, j] * vector[j]

With transpose=True it computes y = A.T @ vector:

y[j] = sum_{i} A[i, j] * vector[i]

where A is the (m, k) sparse matrix defined by the triplets (data, row, col), with (m, k) = shape.
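As a sanity check on the two formulas, the following pure NumPy sketch materializes A as a dense matrix and applies it directly. It is a small-input reference only, not how brainevent computes the product. `coo_to_dense` is a hypothetical helper, and summing entries stored at the same (row, col) position is an assumption that follows from the per-triplet accumulation described in the Notes.

```python
import numpy as np

def coo_to_dense(data, row, col, shape):
    """Hypothetical helper: build the dense (m, k) matrix described by COO triplets."""
    # A scalar or shape-(1,) data value is shared by every stored entry.
    data = np.broadcast_to(np.asarray(data, dtype=float), np.shape(row))
    A = np.zeros(shape, dtype=float)
    np.add.at(A, (row, col), data)  # unbuffered add: repeated (row, col) pairs accumulate
    return A

data = np.array([1.0, 2.0, 3.0])
row = np.array([0, 1, 2])
col = np.array([1, 0, 2])
v = np.array([1.0, 2.0, 3.0])

A = coo_to_dense(data, row, col, (3, 3))
y_forward = A @ v      # transpose=False: y[i] = sum_j A[i, j] * v[j]
y_transpose = A.T @ v  # transpose=True:  y[j] = sum_i A[i, j] * v[i]
```

For this input the sketch gives [2, 2, 9] for the forward product and [4, 1, 9] for the transposed one.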

Parameters:
  • data (Array | ndarray | Quantity | Number) – Non-zero values of the sparse matrix. Either a single value (a scalar or a shape-(1,) array) shared by every non-zero (homogeneous weights), or a 1-D array of length nnz holding one value per non-zero (heterogeneous weights).

  • row (Array | ndarray) – 1-D int array of row indices, length nnz.

  • col (Array | ndarray) – 1-D int array of column indices, length nnz.

  • vector (Array | ndarray | Quantity | Number) – Dense input vector. Shape (shape[1],) when transpose=False, or (shape[0],) when transpose=True.

  • shape (Tuple[int, int]) – Logical (m, k) shape of the sparse matrix.

  • transpose (bool) – If True, multiply by A.T. Default is False.

  • backend (str | None) – Compute backend (e.g. 'numba', 'pallas'). None selects the default.
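The shape rules in the parameter list can be summarized as a small validator. `check_coomv_args` is hypothetical and not part of brainevent; it only encodes the constraints stated above (it does not check that indices are in range) and returns the length the result vector should have.

```python
import numpy as np

def check_coomv_args(data, row, col, vector, shape, transpose=False):
    """Hypothetical check of the coomv argument contract; returns the output length."""
    m, k = shape
    row, col = np.asarray(row), np.asarray(col)
    if row.ndim != 1 or col.shape != row.shape:
        raise ValueError("row and col must be 1-D arrays of the same length nnz")
    if not (np.issubdtype(row.dtype, np.integer) and np.issubdtype(col.dtype, np.integer)):
        raise TypeError("row and col must hold integer indices")
    nnz = row.shape[0]
    # Homogeneous weights: scalar or shape (1,); heterogeneous: one value per non-zero.
    if np.shape(data) not in [(), (1,), (nnz,)]:
        raise ValueError(f"data must be a scalar, shape (1,), or shape ({nnz},)")
    # A @ vector needs len(vector) == k; A.T @ vector needs len(vector) == m.
    n_in, n_out = (m, k) if transpose else (k, m)
    if np.shape(vector) != (n_in,):
        raise ValueError(f"vector must have shape ({n_in},), got {np.shape(vector)}")
    return n_out
```

For a (2, 4) matrix, the forward product takes a length-4 vector and returns length 2, while transpose=True takes a length-2 vector and returns length 4.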

Returns:

Result vector. Shape (shape[0],) when transpose=False, or (shape[1],) when transpose=True. If data or vector carries brainunit units, the result carries the product of their units.

Return type:

Array | ndarray | Quantity | Number

See also

coomm

COO sparse matrix-matrix multiplication.

binary_coomv

Event-driven (binary) COO matrix-vector multiplication.

Notes

The kernel iterates over all nnz stored elements and, for each triplet (data[s], row[s], col[s]), accumulates data[s] * vector[col[s]] into y[row[s]] (forward) or data[s] * vector[row[s]] into y[col[s]] (transpose).
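The per-triplet accumulation can be written out directly. This is a minimal NumPy sketch of the described scatter-add, not brainevent's actual kernel; `coomv_loop` and `coomv_vectorized` are hypothetical names. The vectorized form uses `np.add.at` because it performs unbuffered accumulation, so repeated output indices each contribute, whereas `y[idx] += vals` would apply only one update per repeated index.

```python
import numpy as np

def coomv_loop(data, row, col, vector, shape, transpose=False):
    """Reference scatter-add over the nnz triplets (slow, for clarity only)."""
    m, k = shape
    nnz = len(row)
    data = np.broadcast_to(np.asarray(data, dtype=float), (nnz,))
    y = np.zeros(k if transpose else m)
    for s in range(nnz):
        if transpose:
            y[col[s]] += data[s] * vector[row[s]]  # accumulate into column index
        else:
            y[row[s]] += data[s] * vector[col[s]]  # accumulate into row index
    return y

def coomv_vectorized(data, row, col, vector, shape, transpose=False):
    """Same computation with np.add.at, which accumulates repeated indices."""
    m, k = shape
    data = np.broadcast_to(np.asarray(data, dtype=float), (len(row),))
    vector = np.asarray(vector, dtype=float)
    out_idx, in_idx, n_out = (col, row, k) if transpose else (row, col, m)
    y = np.zeros(n_out)
    np.add.at(y, out_idx, data * vector[in_idx])  # gather from in_idx, scatter to out_idx
    return y
```

Both functions treat a duplicated (row, col) pair as two separate contributions that add up in the output.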

When data is a scalar (homogeneous weights), the same weight is used for every non-zero, which allows a more compact kernel.
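The page does not describe the compact kernel itself. One plausible reading is that a shared weight can be factored out of each sum, y[i] = w * sum_{s: row[s] == i} vector[col[s]], so no per-entry weight array needs to be read. The NumPy sketch below only illustrates that algebraic identity for the forward product; `coomv_homogeneous` and `coomv_heterogeneous` are hypothetical helpers, not brainevent code.

```python
import numpy as np

def coomv_homogeneous(weight, row, col, vector, shape):
    """Forward COO mat-vec when every non-zero shares one scalar weight."""
    m, _ = shape
    gathered = np.asarray(vector, dtype=float)[col]            # vector[col[s]] per stored entry
    per_row = np.bincount(row, weights=gathered, minlength=m)  # sum gathered values per row
    return weight * per_row                                    # apply the shared weight once per row

def coomv_heterogeneous(data, row, col, vector, shape):
    """Same forward product with one weight per stored entry, for comparison."""
    m, _ = shape
    gathered = np.asarray(vector, dtype=float)[col]
    return np.bincount(row, weights=np.asarray(data, dtype=float) * gathered, minlength=m)

row = np.array([0, 1, 2, 0])
col = np.array([1, 0, 2, 2])
v = np.array([1.0, 2.0, 3.0])
y_shared = coomv_homogeneous(0.5, row, col, v, (3, 3))
y_full = coomv_heterogeneous(np.full(4, 0.5), row, col, v, (3, 3))
```

Here both versions give [2.5, 0.5, 1.5]; the homogeneous form multiplies by the weight m times instead of nnz times.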

Physical units attached via brainunit are split off from data and vector before the computation and re-applied to the result.

This function supports automatic differentiation (JVP and transpose rules), vmap batching, and multiple hardware backends.
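brainevent's actual JVP and transpose rules are not shown on this page. The NumPy sketch below only checks the identities such rules can rely on, since the product is linear in both vector and data: the adjoint identity <A v, u> = <v, A.T u> (so the transpose rule for vector is the same product with transpose flipped), and the per-entry gradient with respect to heterogeneous data, ct[row[s]] * vector[col[s]] in the forward case. `coomv_ref` and `coomv_data_grad` are hypothetical helpers.

```python
import numpy as np

def coomv_ref(data, row, col, vector, shape, transpose=False):
    """Reference COO mat-vec via np.add.at (repeated indices accumulate)."""
    m, k = shape
    data = np.broadcast_to(np.asarray(data, dtype=float), (len(row),))
    vector = np.asarray(vector, dtype=float)
    out_idx, in_idx, n_out = (col, row, k) if transpose else (row, col, m)
    y = np.zeros(n_out)
    np.add.at(y, out_idx, data * vector[in_idx])
    return y

def coomv_data_grad(ct, row, col, vector, transpose=False):
    """Gradient of sum(ct * coomv(data, ...)) with respect to per-entry data."""
    ct, vector = np.asarray(ct, dtype=float), np.asarray(vector, dtype=float)
    return ct[col] * vector[row] if transpose else ct[row] * vector[col]

rng = np.random.default_rng(0)
shape = (4, 3)
row = np.array([0, 1, 3, 3, 2])
col = np.array([2, 0, 1, 1, 0])
data = rng.normal(size=row.size)
v = rng.normal(size=shape[1])  # forward input, length k
u = rng.normal(size=shape[0])  # cotangent of the forward output, length m

# Adjoint identity: pulling u back through A equals the transpose=True product.
lhs = np.dot(coomv_ref(data, row, col, v, shape), u)
rhs = np.dot(v, coomv_ref(data, row, col, u, shape, transpose=True))

# Linearity in data: a finite difference on one entry matches the analytic gradient.
g = coomv_data_grad(u, row, col, v)
eps, s = 1e-6, 2
bumped = data.copy()
bumped[s] += eps
fd = (np.dot(coomv_ref(bumped, row, col, v, shape), u) - lhs) / eps
```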

Examples

>>> import jax.numpy as jnp
>>> from brainevent import coomv
>>> data = jnp.array([1.0, 2.0, 3.0])
>>> row = jnp.array([0, 1, 2])
>>> col = jnp.array([1, 0, 2])
>>> v = jnp.array([1.0, 2.0, 3.0])
>>> coomv(data, row, col, v, shape=(3, 3))
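To see what the example above computes, the same product can be reproduced in plain NumPy with a gather followed by a per-index sum. This is an independent reference based on the formulas at the top of the page, not a call into brainevent.

```python
import numpy as np

data = np.array([1.0, 2.0, 3.0])
row = np.array([0, 1, 2])
col = np.array([1, 0, 2])
v = np.array([1.0, 2.0, 3.0])
m, k = 3, 3

# transpose=False: gather v at col, then sum contributions into each row index.
y = np.bincount(row, weights=data * v[col], minlength=m)
# transpose=True: gather v at row, then sum contributions into each column index.
y_t = np.bincount(col, weights=data * v[row], minlength=k)
```

By the formulas above, coomv(data, row, col, v, shape=(3, 3)) should therefore hold the values [2, 2, 9], and the same call with transpose=True should hold [4, 1, 9].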