brainevent.fcnmv
- brainevent.fcnmv = <NameScope(brainevent.fcnmv)>
Sparse matrix–vector product with fixed connection number.
Computes y = W @ v (or y = W^T @ v when transpose=True), where W is a sparse weight matrix stored in fixed-connection-number format and v is a dense floating-point vector.
- Parameters:
  - weights (Array | Quantity) – Non-zero weight values. Shape is (1,) for homogeneous weights or (num_pre, num_conn) for heterogeneous weights. Must have a floating-point dtype.
  - indices (Array) – Integer index array of shape (num_pre, num_conn) specifying the post-synaptic (column) indices of each connection.
  - vector (Array | Quantity) – Dense vector to multiply with. Expected shape is (num_post,) when transpose=False or (num_pre,) when transpose=True.
  - shape (Tuple[int, int]) – Logical (num_pre, num_post) shape of the equivalent dense weight matrix.
  - transpose (bool) – If False, compute W @ v (fixed post-synaptic connections, gather mode). If True, compute W^T @ v (fixed pre-synaptic connections, scatter mode).
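To make the storage format concrete, the sketch below expands (weights, indices) into the equivalent dense (num_pre, num_post) matrix with NumPy, so that transpose=False and transpose=True reduce to an ordinary W @ v and W.T @ v. The helper name fcn_to_dense is hypothetical and not part of brainevent; it also assumes that repeated column indices within one row are summed, which agrees with the sums written in the Notes.

```python
import numpy as np


def fcn_to_dense(weights, indices, shape):
    """Hypothetical helper (not part of brainevent): expand fixed-connection-number
    storage into the equivalent dense (num_pre, num_post) matrix."""
    weights = np.asarray(weights, dtype=np.float64)
    indices = np.asarray(indices)
    num_pre, num_post = shape
    n_conn = indices.shape[1]
    # Homogeneous weights of shape (1,) are shared by every connection.
    if weights.shape == (1,):
        weights = np.broadcast_to(weights, indices.shape)
    dense = np.zeros((num_pre, num_post), dtype=np.float64)
    # Row index of every stored connection, in the same order as indices.ravel().
    rows = np.repeat(np.arange(num_pre), n_conn)
    # np.add.at accumulates, so duplicate column indices in a row are summed.
    np.add.at(dense, (rows, indices.ravel()), weights.ravel())
    return dense


# Same inputs as in the Examples section.
weights = np.array([[0.5, 1.0], [1.5, 2.0]])
indices = np.array([[0, 1], [1, 2]])
W = fcn_to_dense(weights, indices, shape=(2, 3))
# W == [[0.5, 1.0, 0.0],
#       [0.0, 1.5, 2.0]]
y = W @ np.array([1.0, 2.0, 3.0])   # transpose=False: [2.5, 9.0]
y_t = W.T @ np.array([1.0, 2.0])    # transpose=True:  [0.5, 4.0, 4.0]
```

A dense expansion like this is only practical for small test cases; its value is as a correctness oracle, since it costs O(num_pre * num_post) memory.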
- Returns:
  Result vector of shape (num_pre,) when transpose=False or (num_post,) when transpose=True.
- Return type:
  Array | Quantity
See also
fcnmm – Float sparse matrix–matrix product with fixed connection number.
binary_fcnmv – Event-driven (binary) variant.
Notes
The sparse weight matrix W of shape (num_pre, num_post) is stored in fixed-connection-number format, where each row i has exactly n_conn non-zero entries (n_conn = indices.shape[1], written num_conn in the parameter list) at column positions indices[i, :].
When transpose=False (gather mode), the matrix-vector product computes:
$$y[i] = \sum_{k=0}^{n_\mathrm{conn}-1} \mathrm{weights}[i, k] \, v[\mathrm{indices}[i, k]]$$
For homogeneous weights (weights has shape (1,)), with the single shared value w = weights[0]:
$$y[i] = w \sum_{k=0}^{n_\mathrm{conn}-1} v[\mathrm{indices}[i, k]]$$
When transpose=True (scatter mode), each row's contributions are distributed to its target columns:
$$y[\mathrm{indices}[i, k]] \mathrel{+}= \mathrm{weights}[i, k] \, v[i] \quad \text{for all } i, k$$
The number of multiply-accumulate operations is O(num_pre * n_conn), independent of num_post, which makes this efficient for sparse connectivity (n_conn much smaller than num_post).
Examples
>>> import jax.numpy as jnp
>>> from brainevent._fcn.float import fcnmv
>>>
>>> weights = jnp.array([[0.5, 1.0], [1.5, 2.0]], dtype=jnp.float32)
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> vector = jnp.array([1.0, 2.0, 3.0])
>>> y = fcnmv(weights, indices, vector, shape=(2, 3), transpose=False)
>>> # y[0] = 0.5 * v[0] + 1.0 * v[1] = 2.5;  y[1] = 1.5 * v[1] + 2.0 * v[2] = 9.0
>>> print(y)
[2.5 9. ]
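The gather and scatter formulas from the Notes can be checked with a plain NumPy reference that visits only the num_pre * n_conn stored connections and never builds a dense matrix. fcnmv_reference is a hypothetical name used for illustration: it is not the brainevent kernel, it accepts plain arrays only (no Quantity units), and it computes in float64.

```python
import numpy as np


def fcnmv_reference(weights, indices, vector, shape, transpose=False):
    """Hypothetical NumPy reference for the Notes formulas (not the brainevent kernel).

    Only the num_pre * n_conn stored connections are visited; no dense matrix
    is built. Plain arrays only (no units); computation is done in float64.
    """
    weights = np.asarray(weights, dtype=np.float64)
    indices = np.asarray(indices)
    vector = np.asarray(vector, dtype=np.float64)
    _, num_post = shape
    homogeneous = weights.shape == (1,)
    if not transpose:
        # Gather mode: y[i] = sum_k weights[i, k] * v[indices[i, k]]
        gathered = vector[indices]                  # (num_pre, n_conn)
        if homogeneous:
            # y[i] = w * sum_k v[indices[i, k]]
            return weights[0] * gathered.sum(axis=1)
        return (weights * gathered).sum(axis=1)
    # Scatter mode: y[indices[i, k]] += weights[i, k] * v[i] for all i, k
    w = np.broadcast_to(weights, indices.shape) if homogeneous else weights
    contrib = w * vector[:, None]                   # (num_pre, n_conn)
    y = np.zeros(num_post, dtype=np.float64)
    # np.add.at accumulates contributions that land on the same column.
    np.add.at(y, indices.ravel(), contrib.ravel())
    return y


weights = np.array([[0.5, 1.0], [1.5, 2.0]])
indices = np.array([[0, 1], [1, 2]])
y = fcnmv_reference(weights, indices, [1.0, 2.0, 3.0], shape=(2, 3))               # [2.5, 9.0]
y_t = fcnmv_reference(weights, indices, [1.0, 2.0], shape=(2, 3), transpose=True)  # [0.5, 4.0, 4.0]
```

With the example inputs it reproduces [2.5, 9.0] for transpose=False. In a JAX version, the np.add.at scatter would typically be written with the functional update y.at[idx].add(vals).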