brainevent.fcnmm

- brainevent.fcnmm = <NameScope(brainevent.fcnmm)>
Sparse matrix–matrix product with fixed connection number.
Computes Y = W @ M (or Y = W^T @ M when transpose=True), where W is a sparse weight matrix stored in fixed-connection-number format and M is a dense floating-point matrix.

- Parameters:
  - weights (Array | Quantity) – Non-zero weight values. Shape is (1,) for homogeneous weights or (num_pre, num_conn) for heterogeneous weights. Must have a floating-point dtype.
  - indices (Array) – Integer index array of shape (num_pre, num_conn) specifying the post-synaptic (column) index of each connection.
  - matrix (Array | Quantity) – Dense matrix to multiply with, of shape (k, n), where k is num_post when transpose=False and num_pre when transpose=True.
  - shape (Tuple[int, int]) – Logical (num_pre, num_post) shape of the equivalent dense weight matrix.
  - transpose (bool) – If False, compute W @ M (fixed post-synaptic connections, gather mode). If True, compute W^T @ M (fixed pre-synaptic connections, scatter mode).
- Returns:
  Result matrix of shape (num_pre, n) when transpose=False, or (num_post, n) when transpose=True.
- Return type:
  Array | Quantity
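The weights, indices and shape arguments together describe an ordinary dense matrix W. The sketch below materializes that W with plain jax.numpy, which can help when checking small cases by hand. The helper name fcn_to_dense is hypothetical and not part of brainevent; it assumes plain JAX arrays (no Quantity units) and assumes that repeated column indices within a row accumulate, which is what the sum over k in the gather-mode formula implies.

```python
import jax.numpy as jnp


def fcn_to_dense(weights, indices, shape):
    """Materialize the dense (num_pre, num_post) matrix W of an FCN layout.

    Hypothetical helper for illustration only; not part of brainevent.
    Assumes plain JAX arrays (no Quantity units).
    """
    num_pre, num_post = shape
    # Broadcast homogeneous weights of shape (1,) to (num_pre, num_conn);
    # heterogeneous weights already have that shape and pass through unchanged.
    w = jnp.broadcast_to(weights, indices.shape)
    # Row index of every connection, shape (num_pre, 1); broadcasts over num_conn.
    rows = jnp.arange(num_pre)[:, None]
    dense = jnp.zeros((num_pre, num_post), dtype=w.dtype)
    # W[i, indices[i, k]] += weights[i, k]; repeated columns in a row accumulate.
    return dense.at[rows, indices].add(w)


weights = jnp.ones(1, dtype=jnp.float32)  # homogeneous
indices = jnp.array([[0, 1], [1, 2]])
print(fcn_to_dense(weights, indices, shape=(2, 3)))
# [[1. 1. 0.]
#  [0. 1. 1.]]
```

Scatter-add (.at[...].add) is used instead of plain assignment so that a row listing the same column twice gets the summed weight rather than only the last one.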
See also
fcnmv: Float sparse matrix–vector product with fixed connection number.
binary_fcnmm: Event-driven (binary) variant.
Notes
The sparse weight matrix W of shape (num_pre, num_post) is stored in fixed-connection-number format: each row i has exactly n_conn non-zero entries (n_conn is the num_conn dimension of indices), located at column positions indices[i, :].

When transpose=False (gather mode), each output element is:

$$Y[i, j] = \sum_{k=0}^{n_\mathrm{conn}-1} \mathrm{weights}[i, k] \, M[\mathrm{indices}[i, k], j]$$

For homogeneous weights (weights has shape (1,) and holds the single value w):

$$Y[i, j] = w \sum_{k=0}^{n_\mathrm{conn}-1} M[\mathrm{indices}[i, k], j]$$

When transpose=True (scatter mode), Y starts at zero and every connection adds its contribution to its target row:

$$Y[\mathrm{indices}[i, k], j] \mathrel{+}= \mathrm{weights}[i, k] \, M[i, j] \quad \text{for all } i, k, j$$

The computational cost is O(num_pre * n_conn * n), where n is the number of columns in M.

Examples
>>> import jax.numpy as jnp
>>> from brainevent._fcn.float import fcnmm
>>>
>>> weights = jnp.ones(1, dtype=jnp.float32)  # homogeneous
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> matrix = jnp.array([[1.0, 0.5],
...                     [2.0, 1.0],
...                     [3.0, 1.5]])
>>> y = fcnmm(weights, indices, matrix, shape=(2, 3), transpose=False)
>>> print(y)
[[3.  1.5]
 [5.  2.5]]
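The gather and scatter formulas in the Notes can be written as a short, unoptimized reference in plain jax.numpy, which is handy for cross-checking results on small inputs. This is a sketch under stated assumptions, not brainevent's kernel: the name fcnmm_ref is hypothetical, it assumes plain JAX arrays (no Quantity units), and it assumes repeated target indices accumulate. It reproduces the doctest above and also shows the shape change for transpose=True, where M must have num_pre rows and the result has num_post rows.

```python
import jax.numpy as jnp


def fcnmm_ref(weights, indices, matrix, shape, transpose=False):
    """Unoptimized reference for the Notes formulas (hypothetical helper).

    Assumes plain JAX arrays (no Quantity units); not brainevent's implementation.
    """
    _, num_post = shape
    # weights[i, k] for every connection; a (1,) array broadcasts to all of them.
    w = jnp.broadcast_to(weights, indices.shape)
    if not transpose:
        # Gather mode: matrix has shape (num_post, n).
        # matrix[indices] has shape (num_pre, num_conn, n) and holds M[indices[i, k], :].
        return jnp.sum(w[..., None] * matrix[indices], axis=1)
    # Scatter mode: matrix has shape (num_pre, n).
    # contrib[i, k, j] = weights[i, k] * M[i, j]
    contrib = w[..., None] * matrix[:, None, :]
    out = jnp.zeros((num_post, matrix.shape[1]), dtype=contrib.dtype)
    # Y[indices[i, k], j] += contrib[i, k, j]; repeated targets accumulate.
    return out.at[indices].add(contrib)


weights = jnp.ones(1, dtype=jnp.float32)  # homogeneous
indices = jnp.array([[0, 1], [1, 2]])

# Same inputs as the doctest (transpose=False): M has num_post = 3 rows.
matrix = jnp.array([[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]])
print(fcnmm_ref(weights, indices, matrix, shape=(2, 3)))
# [[3.  1.5]
#  [5.  2.5]]

# transpose=True: M needs num_pre = 2 rows; the result has num_post = 3 rows.
matrix_t = jnp.array([[1.0, 0.5], [2.0, 1.0]])
print(fcnmm_ref(weights, indices, matrix_t, shape=(2, 3), transpose=True))
# [[1.  0.5]
#  [3.  1.5]
#  [2.  1. ]]
```

In gather mode each output row reads n_conn rows of M, so no write conflicts arise; in scatter mode several connections can target the same output row, which is why the sketch uses a scatter-add rather than an assignment.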