brainevent.csrmv_yw2y
- brainevent.csrmv_yw2y = <NameScope(brainevent.csrmv_yw2y)>
Element-wise product of a vector and CSR weights, indexed by CSR structure.
For each structural non-zero entry j of the CSR matrix at position (row, col), it computes out[j] = w[j] * y[row] (non-transposed) or out[j] = w[j] * y[col] (transposed). The result has the same shape as w and indices, i.e., one value per structural non-zero.

This operation is useful for computing per-synapse quantities in neural network models, where y is a neuron-level vector and w holds per-synapse weights stored in CSR form.

The function supports physical units via brainunit.

- Parameters:
  - y (Array | ndarray | Quantity | Number) – Dense vector indexed by the CSR structure. Shape (shape[0],) when transpose=False or (shape[1],) when transpose=True.
  - w (Array | ndarray | Quantity | Number) – Per-synapse weight values. Shape (nse,), where nse is the number of structural non-zeros; must match the shape of indices.
  - indices (Array | ndarray) – Column indices of the CSR matrix. Shape (nse,) with integer dtype.
  - indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,) with integer dtype.
  - shape (tuple of int) – Two-element tuple (m, k) giving the logical shape of the CSR matrix.
  - transpose (bool) – If True, index y by column indices instead of row indices. Default is False.
  - backend (str | None) – Compute backend. Default is None (auto-select).
- Returns:
  out – Per-synapse result vector. Shape (nse,), same as w.
- Return type:
  Array | ndarray | Quantity | Number
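The shape rules from the Parameters and Returns sections can be gathered into a single check. The sketch below is a hypothetical helper: check_csrmv_yw2y_shapes is not part of brainevent, it handles plain array inputs only (not the scalar Number forms or brainunit Quantity wrapping), and it adds the standard CSR invariant indptr[-1] == nse as an assumption not stated on this page.

```python
import numpy as np


def check_csrmv_yw2y_shapes(y, w, indices, indptr, shape, transpose=False):
    """Hypothetical helper (not part of brainevent).

    Restates the documented shape rules for plain arrays and returns the
    expected output shape (nse,). Raises on the first violation found.
    """
    m, k = shape
    y, w = np.asarray(y), np.asarray(w)
    indices, indptr = np.asarray(indices), np.asarray(indptr)

    # Both index arrays must be integer-typed.
    for name, arr in (("indices", indices), ("indptr", indptr)):
        if not np.issubdtype(arr.dtype, np.integer):
            raise TypeError(f"{name} must have an integer dtype, got {arr.dtype}")

    # indptr: one start offset per row plus a final end offset.
    if indptr.shape != (m + 1,):
        raise ValueError(f"indptr must have shape {(m + 1,)}, got {indptr.shape}")

    # indices and w: one entry per structural non-zero (nse).
    if indices.ndim != 1:
        raise ValueError("indices must be one-dimensional")
    nse = indices.shape[0]
    if w.shape != (nse,):
        raise ValueError(f"w must have shape {(nse,)}, got {w.shape}")

    # Assumed standard CSR invariant: the last pointer equals nse.
    if int(indptr[-1]) != nse:
        raise ValueError(f"indptr[-1] must equal nse={nse}, got {int(indptr[-1])}")

    # y is gathered by row ids (length m) or by column ids (length k).
    expected_y = (k,) if transpose else (m,)
    if y.shape != expected_y:
        raise ValueError(f"y must have shape {expected_y}, got {y.shape}")

    return (nse,)


# Usage on a 2 x 3 matrix with four structural non-zeros; returns (4,).
print(check_csrmv_yw2y_shapes(
    np.array([1.0, 2.0]),
    np.array([0.5, 0.3, 0.7, 0.1]),
    np.array([0, 2, 1, 2], dtype=np.int32),
    np.array([0, 2, 4], dtype=np.int32),
    shape=(2, 3),
))
```

Note that the same structure with transpose=True would require y of length 3 (shape[1]), so the length-2 y above would be rejected.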
See also
csrmv – Standard CSR matrix–vector multiplication.
Notes
This operation is differentiable with respect to both y and w via custom JVP rules. The transpose rule is not yet implemented.

Mathematically, for each structural non-zero entry j of the CSR matrix at position (row, col), the output is computed as:

    out[j] = w[j] * y[row]    (non-transposed, transpose=False)
    out[j] = w[j] * y[col]    (transposed, transpose=True)

where row is the row to which the j-th non-zero belongs, determined by the indptr array (indptr[row] <= j < indptr[row + 1]), and col = indices[j].

This operation is distinct from standard sparse matrix–vector multiplication: it produces one output element per structural non-zero rather than one per matrix row. It is commonly used to compute per-synapse quantities in spiking neural network models.
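To check the formula on small inputs, the following NumPy sketch computes the same per-synapse products on plain arrays. It is an illustrative reference under stated assumptions, not brainevent's implementation: csrmv_yw2y_reference is a hypothetical name, and it ignores units, the custom JVP rules and backend selection. Row ids are recovered by repeating each row index as many times as that row has non-zeros (np.diff(indptr)).

```python
import numpy as np


def csrmv_yw2y_reference(y, w, indices, indptr, shape, transpose=False):
    """Hypothetical NumPy sketch of out[j] = w[j] * y[row] or w[j] * y[col].

    Illustrative only: no units, no custom JVP, no backend selection.
    """
    m, _ = shape
    indptr = np.asarray(indptr)
    indices = np.asarray(indices)
    # Row id of every structural non-zero: row r owns the entries
    # indptr[r] <= j < indptr[r + 1], so r is repeated diff(indptr)[r] times.
    rows = np.repeat(np.arange(m), np.diff(indptr))
    # Gather y by row ids (transpose=False) or by column ids (transpose=True).
    gathered = np.asarray(y)[indices if transpose else rows]
    return np.asarray(w) * gathered


# Same data as the Examples section.
y = np.array([1.0, 2.0])
w = np.array([0.5, 0.3, 0.7, 0.1])
indices = np.array([0, 2, 1, 2], dtype=np.int32)
indptr = np.array([0, 2, 4], dtype=np.int32)
print(csrmv_yw2y_reference(y, w, indices, indptr, shape=(2, 3)))
# -> [0.5 0.3 1.4 0.2]
```

NumPy is used so the sketch runs without JAX. The same calls exist in jax.numpy, although under jit jnp.repeat needs a static total_repeat_length (here, nse) because the output length would otherwise depend on traced values.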
Examples
>>> import jax.numpy as jnp
>>> from brainevent import csrmv_yw2y
>>> y = jnp.array([1.0, 2.0])
>>> w = jnp.array([0.5, 0.3, 0.7, 0.1])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> csrmv_yw2y(y, w, indices, indptr, shape=(2, 3))
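One way to see how this operation relates to an ordinary CSR matrix–vector product: with transpose=True each output entry is w[j] * y[indices[j]], and summing those entries within each row yields A @ y. The NumPy sketch below checks this on the structure from the example above. It does not call brainevent; the length-3 vector x is an illustrative input, and the dense copy of A exists only for verification.

```python
import numpy as np

# CSR structure from the example: a 2 x 3 matrix A with four non-zeros.
w = np.array([0.5, 0.3, 0.7, 0.1])
indices = np.array([0, 2, 1, 2])
indptr = np.array([0, 2, 4])
m, k = 2, 3
x = np.array([1.0, 2.0, 3.0])  # length k, gathered by column ids

# Per-synapse products, matching the transpose=True formula.
per_synapse = w * x[indices]

# Row id of each non-zero, then a per-row sum (bincount also handles empty rows).
rows = np.repeat(np.arange(m), np.diff(indptr))
row_sums = np.bincount(rows, weights=per_synapse, minlength=m)

# Dense reconstruction of A, used only to verify the segment sums.
A = np.zeros((m, k))
A[rows, indices] = w
assert np.allclose(row_sums, A @ x)
print(row_sums)  # one value per row, unlike per_synapse (one per non-zero)
```

The per-row reduction is exactly the step that this function omits, which is why its output has length nse rather than m.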