brainevent.fcnmm



brainevent.fcnmm = <NameScope(brainevent.fcnmm)>

Sparse matrix–matrix product with fixed connection number.

Computes Y = W @ M (or Y = W^T @ M when transpose=True) where W is a sparse weight matrix stored in fixed-connection-number format and M is a dense floating-point matrix.

Parameters:
  • weights (Array | Quantity) – Non-zero weight values. Shape is (1,) for homogeneous weights or (num_pre, num_conn) for heterogeneous weights. Must have a floating-point dtype.

  • indices (Array) – Integer index array of shape (num_pre, num_conn) specifying the post-synaptic (column) indices of each connection.

  • matrix (Array | Quantity) – Dense matrix to multiply with, of shape (k, n), where k = num_post when transpose=False and k = num_pre when transpose=True.

  • shape (Tuple[int, int]) – Logical (num_pre, num_post) shape of the equivalent dense weight matrix.

  • transpose (bool) – If False, compute W @ M (fixed post-synaptic connections, gather mode). If True, compute W^T @ M (fixed pre-synaptic connections, scatter mode).
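The parameter constraints above can be summarized as a shape check. The sketch below uses a hypothetical helper, `fcnmm_output_shape`, written only for illustration (it is not part of brainevent). It assumes `indices` is a 2-D array and reads num_conn from its second axis.

```python
import jax.numpy as jnp


def fcnmm_output_shape(weights, indices, matrix, shape, transpose):
    """Hypothetical helper (not a brainevent API): check fcnmm arguments, return Y's shape."""
    num_pre, num_post = shape
    # indices: one row of column positions per row of W (assumed to be 2-D).
    if indices.ndim != 2 or indices.shape[0] != num_pre:
        raise ValueError(f"indices must have shape ({num_pre}, num_conn), got {indices.shape}")
    # weights: either one shared value (1,) or one value per connection.
    if weights.shape != (1,) and weights.shape != indices.shape:
        raise ValueError(f"weights must have shape (1,) or {indices.shape}, got {weights.shape}")
    # M is contracted against num_post rows for W @ M and num_pre rows for W^T @ M.
    k_expected = num_pre if transpose else num_post
    if matrix.ndim != 2 or matrix.shape[0] != k_expected:
        raise ValueError(f"matrix must have shape ({k_expected}, n), got {matrix.shape}")
    n = matrix.shape[1]
    return (num_post, n) if transpose else (num_pre, n)


idx = jnp.array([[0, 1], [1, 2]])
print(fcnmm_output_shape(jnp.ones(1), idx, jnp.ones((3, 4)), (2, 3), transpose=False))  # (2, 4)
print(fcnmm_output_shape(jnp.ones(1), idx, jnp.ones((2, 4)), (2, 3), transpose=True))   # (3, 4)
```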

Returns:

Result matrix of shape (num_pre, n) when transpose=False or (num_post, n) when transpose=True.

Return type:

Array | Quantity

See also

fcnmv

Float sparse matrix–vector product with fixed connection number.

binary_fcnmm

Event-driven (binary) variant.

Notes

The sparse weight matrix W of shape (num_pre, num_post) is stored in fixed-connection-number format, where each row i has exactly n_conn (= num_conn) non-zero entries at column positions indices[i, :].
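To make the storage format concrete, the sketch below expands (weights, indices, shape) into the equivalent dense W. `fcn_to_dense` is a hypothetical helper, not a brainevent function; it assumes plain JAX arrays (Quantity units are not handled) and that duplicate column indices within a row accumulate, as the sum over k in the formulas for Y implies.

```python
import jax.numpy as jnp


def fcn_to_dense(weights, indices, shape):
    """Hypothetical helper: expand fixed-connection-number storage into a dense W."""
    num_pre, num_post = shape
    # Broadcast homogeneous (1,) weights to one value per connection.
    w = jnp.broadcast_to(weights, indices.shape)
    # Row index i repeated n_conn times, paired with the flattened column indices.
    rows = jnp.repeat(jnp.arange(num_pre), indices.shape[1])
    cols = indices.reshape(-1)
    dense = jnp.zeros(shape, dtype=w.dtype)
    # .add (not .set) so repeated (i, col) pairs accumulate their weights.
    return dense.at[rows, cols].add(w.reshape(-1))


W = fcn_to_dense(jnp.ones(1, dtype=jnp.float32), jnp.array([[0, 1], [1, 2]]), (2, 3))
# W has rows [1, 1, 0] and [0, 1, 1]
```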

When transpose=False (gather mode), each output element is:

Y[i, j] = sum_{k=0}^{n_conn-1} weights[i, k] * M[indices[i, k], j]
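The gather formula translates directly into a row gather followed by a weighted sum over k. The sketch below is a pure-JAX reference for checking results, not brainevent's kernel; the name `fcnmm_gather_reference` is hypothetical.

```python
import jax.numpy as jnp


def fcnmm_gather_reference(weights, indices, matrix):
    """Pure-JAX sketch of Y = W @ M (transpose=False); not the brainevent kernel."""
    # Broadcast homogeneous (1,) weights to one value per connection: (num_pre, n_conn).
    w = jnp.broadcast_to(weights, indices.shape)
    # Row gather: gathered[i, k, :] = M[indices[i, k], :], shape (num_pre, n_conn, n).
    gathered = matrix[indices]
    # Y[i, j] = sum_k w[i, k] * gathered[i, k, j]
    return jnp.einsum("ik,ikj->ij", w, gathered)


weights = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # heterogeneous, shape (num_pre, num_conn)
indices = jnp.array([[0, 1], [1, 2]])
matrix = jnp.array([[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]])
y = fcnmm_gather_reference(weights, indices, matrix)
# rows: 1*M[0] + 2*M[1] = [5, 2.5] and 3*M[1] + 4*M[2] = [18, 9]
```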

For homogeneous weights (weights has shape (1,)), with w = weights[0]:

Y[i, j] = w * sum_{k=0}^{n_conn-1} M[indices[i, k], j]
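Because w is shared, it can be factored out of the sum, as the formula shows. The sketch below (hypothetical name `fcnmm_homo_reference`, pure JAX, not the library kernel) sums the gathered rows first and scales once per output element.

```python
import jax.numpy as jnp


def fcnmm_homo_reference(w, indices, matrix):
    """Pure-JAX sketch of the homogeneous gather case; w is the scalar weights[0]."""
    # Sum the gathered rows first: (num_pre, n_conn, n) -> (num_pre, n).
    row_sums = matrix[indices].sum(axis=1)
    # Then apply the single shared weight once per output element.
    return w * row_sums


indices = jnp.array([[0, 1], [1, 2]])
matrix = jnp.array([[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]])
y = fcnmm_homo_reference(2.0, indices, matrix)
# rows: 2 * (M[0] + M[1]) = [6, 3] and 2 * (M[1] + M[2]) = [10, 5]
```

Factoring w out replaces num_pre * n_conn * n multiplications with num_pre * n; the additions are unchanged.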

When transpose=True (scatter mode), Y starts at zero and each connection adds its contribution to a target row:

Y[indices[i, k], j] += weights[i, k] * M[i, j] for all i, k, j
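The scatter update maps naturally onto JAX's indexed `.at[...].add`, which accumulates when several connections target the same output row. The sketch below is a pure-JAX reference under that reading of the formula, not brainevent's kernel; `fcnmm_scatter_reference` and its `num_post` argument are hypothetical.

```python
import jax.numpy as jnp


def fcnmm_scatter_reference(weights, indices, matrix, num_post):
    """Pure-JAX sketch of Y = W^T @ M (transpose=True); not the brainevent kernel."""
    num_pre, n_conn = indices.shape
    w = jnp.broadcast_to(weights, indices.shape)
    # contrib[i, k, :] = weights[i, k] * M[i, :], shape (num_pre, n_conn, n)
    contrib = w[:, :, None] * matrix[:, None, :]
    # Start from zeros and scatter-add each contribution into row indices[i, k];
    # .at[...].add sums contributions that land on the same row.
    y = jnp.zeros((num_post, matrix.shape[1]), dtype=contrib.dtype)
    return y.at[indices.reshape(-1)].add(contrib.reshape(num_pre * n_conn, -1))


indices = jnp.array([[0, 1], [1, 2]])
m_pre = jnp.array([[1.0, 0.5], [2.0, 1.0]])  # shape (num_pre, n)
y = fcnmm_scatter_reference(jnp.ones(1), indices, m_pre, num_post=3)
# rows: M[0] = [1, 0.5], M[0] + M[1] = [3, 1.5], M[1] = [2, 1]
```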

The computational cost is O(num_pre * n_conn * n) in both modes, where n is the number of columns in M.
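As a rough illustration of what the bound means, the arithmetic below counts multiply-accumulates (one per (i, k, j) triple) and compares them with a dense product. The sizes are assumed for illustration only, and operation counts do not predict wall-clock time, which also depends on memory access patterns.

```python
def fcn_macs(num_pre, n_conn, n):
    """Multiply-accumulates implied by the formulas above: one per (i, k, j) triple."""
    return num_pre * n_conn * n


def dense_macs(num_pre, num_post, n):
    """Multiply-accumulates for a dense (num_pre, num_post) @ (num_post, n) product."""
    return num_pre * num_post * n


# Illustrative sizes (assumed, not taken from any brainevent benchmark).
sparse = fcn_macs(num_pre=10_000, n_conn=100, n=64)        # 64,000,000
dense = dense_macs(num_pre=10_000, num_post=10_000, n=64)  # 6,400,000,000
print(dense // sparse)  # 100, i.e. num_post / n_conn
```

The ratio num_post / n_conn does not depend on n, and the sparse count does not depend on num_post at all.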

Examples

>>> import jax.numpy as jnp
>>> from brainevent._fcn.float import fcnmm
>>>
>>> weights = jnp.ones(1, dtype=jnp.float32)  # homogeneous
>>> indices = jnp.array([[0, 1], [1, 2]])
>>> matrix = jnp.array([[1.0, 0.5],
...                     [2.0, 1.0],
...                     [3.0, 1.5]])
>>> y = fcnmm(weights, indices, matrix, shape=(2, 3), transpose=False)
>>> print(y)
[[3.  1.5]
 [5.  2.5]]
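One way to sanity-check the example is to write W out densely and use an ordinary matrix product. The sketch below uses only jax.numpy and does not call brainevent; the transpose=True values are what the scatter formula in the Notes predicts for the same indices, not output captured from the library.

```python
import jax.numpy as jnp

# Dense form of the example's W: row i has a 1.0 at each column in indices[i, :].
W = jnp.array([[1.0, 1.0, 0.0],
               [0.0, 1.0, 1.0]])
matrix = jnp.array([[1.0, 0.5],
                    [2.0, 1.0],
                    [3.0, 1.5]])
y_ref = W @ matrix  # dense counterpart of fcnmm(..., transpose=False)

# For transpose=True, M must have num_pre = 2 rows; the result has num_post = 3 rows.
m_pre = matrix[:2]
y_t_ref = W.T @ m_pre  # dense counterpart of fcnmm(..., transpose=True)
# y_t_ref rows: [1, 0.5], [3, 1.5], [2, 1]
```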