brainevent.fcnmv


brainevent.fcnmv = <NameScope(brainevent.fcnmv)>

Sparse matrix–vector product with fixed connection number.

Computes y = W @ v (or y = W^T @ v when transpose=True) where W is a sparse weight matrix stored in fixed-connection-number format and v is a dense floating-point vector.

Parameters:
  • weights (Array | Quantity) – Non-zero weight values. Shape is (1,) for homogeneous weights or (num_pre, num_conn) for heterogeneous weights. Must have a floating-point dtype.

  • indices (Array) – Integer index array of shape (num_pre, num_conn) specifying the post-synaptic (column) indices of each connection. Every entry must lie in [0, num_post).

  • vector (Array | Quantity) – Dense vector to multiply with. Shape is (num_post,) when transpose=False or (num_pre,) when transpose=True.

  • shape (Tuple[int, int]) – Logical (num_pre, num_post) shape of the equivalent dense weight matrix.

  • transpose (bool) – If False, compute W @ v (fixed post-synaptic connections, gather mode). If True, compute W^T @ v (fixed pre-synaptic connections, scatter mode).
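The shape rules above can be checked before calling the operator. The sketch below uses a hypothetical helper, check_fcn_inputs, which is not part of brainevent; it assumes plain JAX arrays (no brainunit Quantity) and simply restates the documented contract: indices is (num_pre, num_conn), weights is (1,) or the same shape as indices, and the vector length depends on transpose.

```python
import jax.numpy as jnp


def check_fcn_inputs(weights, indices, vector, shape, transpose=False):
    """Hypothetical helper: validate arguments against the documented shapes.

    Works on plain JAX arrays only (not Quantity objects) and returns the
    expected shape of the result.
    """
    num_pre, num_post = shape
    if indices.ndim != 2 or indices.shape[0] != num_pre:
        raise ValueError(f"indices must have shape (num_pre, num_conn), got {indices.shape}")
    if weights.shape != (1,) and weights.shape != indices.shape:
        raise ValueError(f"weights must have shape (1,) or {indices.shape}, got {weights.shape}")
    if not jnp.issubdtype(weights.dtype, jnp.floating):
        raise TypeError(f"weights must be floating point, got {weights.dtype}")
    # W @ v needs len(v) == num_post; W^T @ v needs len(v) == num_pre.
    expected_len = num_pre if transpose else num_post
    if vector.shape != (expected_len,):
        raise ValueError(f"vector must have shape ({expected_len},), got {vector.shape}")
    return (num_post,) if transpose else (num_pre,)


weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
indices = jnp.array([[0, 1], [1, 2]])
print(check_fcn_inputs(weights, indices, jnp.ones(3), shape=(2, 3)))                  # (2,)
print(check_fcn_inputs(weights, indices, jnp.ones(2), shape=(2, 3), transpose=True))  # (3,)
```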

Returns:

Result vector. Shape is (num_pre,) when transpose=False or (num_post,) when transpose=True.

Return type:

Array | Quantity

See also

fcnmm

Float sparse matrix–matrix product with fixed connection number.

binary_fcnmv

Event-driven (binary) variant.

Notes

The sparse weight matrix W of shape (num_pre, num_post) is stored in fixed-connection-number format: each row i has exactly n_conn (= num_conn) stored entries, at column positions indices[i, :] with values weights[i, :] (or the single shared value weights[0] for homogeneous weights).
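To make the storage format concrete, the sketch below expands it into the equivalent dense matrix. fcn_to_dense is a hypothetical helper for illustration and testing only, not a brainevent function; it assumes repeated column indices within a row should be summed, which is what the formulas for y imply.

```python
import jax.numpy as jnp


def fcn_to_dense(weights, indices, shape):
    """Hypothetical helper: expand fixed-connection-number storage to a dense W.

    Row i receives weights[i, k] at column indices[i, k]. Homogeneous weights of
    shape (1,) are broadcast to every connection. Repeated column indices in a
    row are summed.
    """
    num_pre, _ = shape
    vals = jnp.broadcast_to(weights, indices.shape)
    rows = jnp.broadcast_to(jnp.arange(num_pre)[:, None], indices.shape)
    return jnp.zeros(shape, dtype=vals.dtype).at[rows, indices].add(vals)


weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
indices = jnp.array([[0, 1], [1, 2]])
print(fcn_to_dense(weights, indices, shape=(2, 3)))
# [[0.5 1.  0. ]
#  [0.  1.5 2. ]]
```

With this dense W, gather mode equals W @ v and scatter mode equals W.T @ v, which makes it a convenient (if memory-hungry) reference for small test cases.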

When transpose=False (gather mode), the matrix-vector product computes:

y[i] = sum_{k=0}^{n_conn-1} weights[i, k] * v[indices[i, k]]
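The gather formula maps directly onto a fancy-indexing gather followed by a row sum. The sketch below is a plain-JAX reference under a hypothetical name (fcnmv_gather_reference); it is meant for checking results on small inputs and is not how brainevent implements its kernels. Because weights of shape (1,) broadcast against (num_pre, num_conn), the same code also covers the homogeneous case.

```python
import jax.numpy as jnp


def fcnmv_gather_reference(weights, indices, vector):
    """Plain-JAX sketch of y[i] = sum_k weights[i, k] * v[indices[i, k]].

    weights may have shape (num_pre, num_conn) or (1,); the latter broadcasts
    over every connection.
    """
    gathered = vector[indices]                  # (num_pre, num_conn): v[indices[i, k]]
    return jnp.sum(weights * gathered, axis=1)  # reduce over the connections of row i


weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
indices = jnp.array([[0, 1], [1, 2]])
vector = jnp.array([1.0, 2.0, 3.0])
print(fcnmv_gather_reference(weights, indices, vector))  # [2.5 9. ]
```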

For homogeneous weights (weights has shape (1,), so every connection uses the same value w = weights[0]):

y[i] = w * sum_{k=0}^{n_conn-1} v[indices[i, k]]
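Factoring out w means each row needs only one multiplication after summing its gathered inputs. The sketch below (hypothetical name fcnmv_homo_reference, plain JAX, not the library's implementation) checks that this factored form agrees with the per-connection form in which w is broadcast to every connection.

```python
import jax.numpy as jnp


def fcnmv_homo_reference(w, indices, vector):
    """Plain-JAX sketch of y[i] = w * sum_k v[indices[i, k]] for a scalar weight w."""
    # Sum the gathered inputs per row first, then scale once per row.
    return w * jnp.sum(vector[indices], axis=1)


indices = jnp.array([[0, 1], [1, 2]])
vector = jnp.array([1.0, 2.0, 3.0])
weights = jnp.array([0.5], dtype=jnp.float32)  # homogeneous weights, shape (1,)
y = fcnmv_homo_reference(weights[0], indices, vector)
print(y)  # [1.5 2.5]

# Same result as multiplying every connection by w before summing.
per_connection = jnp.sum(jnp.broadcast_to(weights, indices.shape) * vector[indices], axis=1)
print(jnp.allclose(per_connection, y))  # True
```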

When transpose=True (scatter mode), y has shape (num_post,) and starts at zero; each row i adds its contributions to the target columns, so columns targeted by several connections accumulate all of them:

y[indices[i, k]] += weights[i, k] * v[i] for all i, k
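The scatter formula corresponds to a scatter-add, which in JAX is written with .at[...].add so that repeated targets accumulate rather than overwrite. The sketch below is a plain-JAX reference under a hypothetical name (fcnmv_scatter_reference), not brainevent's kernel. In the example, column 1 is targeted by both rows and receives 1.0 * 1.0 + 1.5 * 2.0 = 4.0.

```python
import jax.numpy as jnp


def fcnmv_scatter_reference(weights, indices, vector, num_post):
    """Plain-JAX sketch of y[indices[i, k]] += weights[i, k] * v[i], starting from y = 0.

    vector has shape (num_pre,); weights may be (num_pre, num_conn) or (1,).
    """
    contrib = jnp.broadcast_to(weights * vector[:, None], indices.shape)  # weights[i, k] * v[i]
    y = jnp.zeros((num_post,), dtype=contrib.dtype)
    # .at[...].add accumulates repeated targets instead of overwriting them.
    return y.at[indices].add(contrib)


weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
indices = jnp.array([[0, 1], [1, 2]])
u = jnp.array([1.0, 2.0])  # one entry per pre-synaptic row
print(fcnmv_scatter_reference(weights, indices, u, num_post=3))  # [0.5 4.  4. ]
```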

In both modes the computational cost is O(num_pre * n_conn) and does not depend on num_post, whereas a dense matrix-vector product costs O(num_pre * num_post). The format is therefore efficient whenever n_conn is much smaller than num_post.
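A quick count of multiply-adds shows the scale of the saving. The sizes below are hypothetical, chosen only to illustrate the ratio num_post / n_conn; real speedups also depend on memory access patterns and hardware.

```python
# Hypothetical sizes, chosen only to illustrate the scaling.
num_pre, num_post, n_conn = 10_000, 10_000, 100

sparse_macs = num_pre * n_conn    # one multiply-add per stored connection
dense_macs = num_pre * num_post   # one multiply-add per dense matrix entry

print(sparse_macs, dense_macs, dense_macs // sparse_macs)  # 1000000 100000000 100
```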

Examples

>>> import jax.numpy as jnp
>>> import brainevent
>>>
>>> weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> vector = jnp.array([1.0, 2.0, 3.0])
>>> y = brainevent.fcnmv(weights, indices, vector, shape=(2, 3), transpose=False)
>>> print(y)
[2.5 9. ]