brainevent.csrmv_yw2y

brainevent.csrmv_yw2y = <NameScope(brainevent.csrmv_yw2y)>

Element-wise product of a vector and CSR weights, indexed by CSR structure.

For each non-zero entry j in the CSR matrix at position (row, col), computes out[j] = w[j] * y[row] (non-transposed) or out[j] = w[j] * y[col] (transposed). The result has the same shape as w and indices (i.e., one value per structural non-zero).

This operation is useful for computing per-synapse quantities in neural network models where y is a neuron-level vector and w contains per-synapse weights stored in CSR form.

The function supports physical units via brainunit.

Parameters:
  • y (Array | ndarray | Quantity | Number) – Dense vector indexed by the CSR structure. Shape (shape[0],) when transpose=False or (shape[1],) when transpose=True.

  • w (Array | ndarray | Quantity | Number) – Per-synapse weight values. Shape (nse,), must match the shape of indices.

  • indices (Array | ndarray) – Column indices of the CSR matrix. Shape (nse,) with integer dtype.

  • indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,) with integer dtype.

  • shape (tuple of int) – Two-element tuple (m, k) giving the logical shape of the CSR matrix.

  • transpose (bool) – If True, index y by column indices instead of row indices. Default is False.

  • backend (str | None) – Compute backend. Default is None (auto-select).
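The shape contract in the parameter list can be checked before calling the function. The following is a minimal NumPy sketch; `check_csr_inputs` is a hypothetical helper, not part of brainevent, and the `indptr` and `indices` range checks are standard CSR conventions assumed here rather than taken from this page.

```python
import numpy as np

def check_csr_inputs(y, w, indices, indptr, shape, transpose=False):
    """Hypothetical helper: validate inputs against the documented shapes."""
    m, k = shape
    y, w = np.asarray(y), np.asarray(w)
    indices, indptr = np.asarray(indices), np.asarray(indptr)

    if indices.ndim != 1 or not np.issubdtype(indices.dtype, np.integer):
        raise ValueError("indices must be a 1-D integer array")
    nse = indices.shape[0]
    if indptr.shape != (m + 1,) or not np.issubdtype(indptr.dtype, np.integer):
        raise ValueError("indptr must be an integer array of shape (shape[0] + 1,)")
    if w.shape != (nse,):
        raise ValueError("w must have shape (nse,), matching indices")
    expected_y = (k,) if transpose else (m,)
    if y.shape != expected_y:
        raise ValueError(f"y must have shape {expected_y} for transpose={transpose}")

    # Standard CSR conventions (assumed): indptr starts at 0, ends at nse,
    # never decreases, and column indices lie in [0, k).
    if indptr[0] != 0 or indptr[-1] != nse or np.any(np.diff(indptr) < 0):
        raise ValueError("indptr is not a valid CSR row pointer array")
    if nse > 0 and (indices.min() < 0 or indices.max() >= k):
        raise ValueError("indices contains a column outside [0, shape[1])")
    return nse

y = np.array([1.0, 2.0])
w = np.array([0.5, 0.3, 0.7, 0.1])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
nse = check_csr_inputs(y, w, indices, indptr, shape=(2, 3))
print(nse)
```

With `transpose=True` the same `y` would be rejected, because the documented shape is then `(shape[1],)`, which is `(3,)` for this matrix.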

Returns:

out – Per-synapse result vector. Shape (nse,), same as w.

Return type:

Array | ndarray | Quantity | Number

See also

csrmv

Standard CSR matrix–vector multiplication.

Notes

This operation is differentiable with respect to both y and w via custom JVP rules. The transpose rule (linear transposition in the JAX autodiff sense, unrelated to the transpose argument) is not yet implemented.
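Because out[j] = w[j] * y[row] is a product of the two inputs, the product rule gives the tangent w_dot[j] * y[row] + w[j] * y_dot[row]. The NumPy sketch below checks that expression against a finite difference. The names `yw2y` and `yw2y_jvp` are hypothetical stand-ins for the documented formula; whether the library's custom JVP rules are written this way is an assumption.

```python
import numpy as np

def _gather_index(indices, indptr, shape, transpose):
    # Row of each non-zero is recovered from indptr; column is indices[j].
    rows = np.repeat(np.arange(shape[0]), np.diff(indptr))
    return indices if transpose else rows

def yw2y(y, w, indices, indptr, shape, transpose=False):
    # NumPy stand-in for the documented formula, not the library kernel.
    return w * y[_gather_index(indices, indptr, shape, transpose)]

def yw2y_jvp(y, w, y_dot, w_dot, indices, indptr, shape, transpose=False):
    # Product rule applied to w[j] * y[g[j]].
    g = _gather_index(indices, indptr, shape, transpose)
    return w_dot * y[g] + w * y_dot[g]

indices = np.array([0, 2, 1, 2])
indptr = np.array([0, 2, 4])
shape = (2, 3)
y = np.array([1.0, 2.0])
w = np.array([0.5, 0.3, 0.7, 0.1])
y_dot = np.array([0.1, -0.2])            # tangent direction for y
w_dot = np.array([1.0, 0.0, -1.0, 2.0])  # tangent direction for w

eps = 1e-6
base = yw2y(y, w, indices, indptr, shape)
fd = (yw2y(y + eps * y_dot, w + eps * w_dot, indices, indptr, shape) - base) / eps
tangent = yw2y_jvp(y, w, y_dot, w_dot, indices, indptr, shape)
print(np.allclose(fd, tangent, atol=1e-5))
```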

Mathematically, for each structural non-zero entry j of the CSR matrix at position (row, col), the output is computed as:

out[j] = w[j] * y[row] (non-transposed, transpose=False)

out[j] = w[j] * y[col] (transposed, transpose=True)

where row is determined by the indptr array (the row to which the j-th non-zero belongs) and col = indices[j].
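The formulas above can be sketched in plain NumPy. This is a reference computation written from the definitions on this page, not the library's implementation; `csrmv_yw2y_reference` is a hypothetical name, and the only step beyond the text is recovering each non-zero's row by repeating row numbers according to the differences of `indptr`.

```python
import numpy as np

def csrmv_yw2y_reference(y, w, indices, indptr, shape, transpose=False):
    """Reference sketch: out[j] = w[j] * y[row_j] or w[j] * y[col_j]."""
    y, w = np.asarray(y), np.asarray(w)
    indices, indptr = np.asarray(indices), np.asarray(indptr)
    if transpose:
        gather = indices  # col = indices[j]
    else:
        # Row i owns non-zeros indptr[i] .. indptr[i + 1] - 1.
        gather = np.repeat(np.arange(shape[0]), np.diff(indptr))
    return w * y[gather]

w = np.array([0.5, 0.3, 0.7, 0.1])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)

# Non-transposed: y has one entry per row; the non-zeros sit in rows [0, 0, 1, 1].
y_row = np.array([1.0, 2.0])
out = csrmv_yw2y_reference(y_row, w, indices, indptr, shape=(2, 3))
print(out)    # approximately [0.5, 0.3, 1.4, 0.2]

# Transposed: y has one entry per column; the non-zeros sit in columns [0, 2, 1, 2].
y_col = np.array([1.0, 2.0, 4.0])
out_t = csrmv_yw2y_reference(y_col, w, indices, indptr, shape=(2, 3), transpose=True)
print(out_t)  # approximately [0.5, 1.2, 1.4, 0.4]
```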

This operation is distinct from standard sparse matrix–vector multiplication: it produces one output element per structural non-zero rather than one per matrix row. It is commonly used to compute per-synapse quantities in spiking neural network models.
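One way to see the difference from a matrix–vector product: in the transposed form, out[j] = w[j] * y[col] is exactly the j-th term of the sum that a standard product W @ y accumulates per row, so summing the per-non-zero outputs within each row recovers W @ y. The NumPy sketch below illustrates this under the formulas on this page; it does not call brainevent, and the dense matrix is built only for comparison.

```python
import numpy as np

w = np.array([0.5, 0.3, 0.7, 0.1])
indices = np.array([0, 2, 1, 2])
indptr = np.array([0, 2, 4])
m, k = 2, 3
y = np.array([1.0, 2.0, 4.0])  # one entry per column

# Row of each structural non-zero, recovered from indptr.
rows = np.repeat(np.arange(m), np.diff(indptr))

# Per-non-zero products (transposed form): one value per stored entry.
per_nonzero = w * y[indices]

# Reducing those values over each row gives the usual matrix-vector product.
row_sums = np.bincount(rows, weights=per_nonzero, minlength=m)

# Dense comparison, for illustration only.
dense = np.zeros((m, k))
dense[rows, indices] = w
print(per_nonzero.shape, row_sums.shape)    # (4,) (2,)
print(np.allclose(row_sums, dense @ y))
```

The per-non-zero result keeps the information that the row reduction discards, which is what makes it usable for per-synapse quantities.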

Examples

>>> import jax.numpy as jnp
>>> from brainevent import csrmv_yw2y
>>> y = jnp.array([1.0, 2.0])
>>> w = jnp.array([0.5, 0.3, 0.7, 0.1])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> csrmv_yw2y(y, w, indices, indptr, shape=(2, 3))