brainevent.binary_csrmm

brainevent.binary_csrmm = <NameScope(brainevent.binary_csrmm)>

Product of a CSR sparse matrix and a dense matrix using event-driven (binary) computation.

Computes C = A @ B (or C = A.T @ B when transpose=True), where A is stored in Compressed Sparse Row (CSR) format and B is a dense matrix whose entries are treated as binary events. Only entries of B that are True (boolean dtype) or strictly positive (floating-point dtype) contribute to the result, and each such entry counts as 1 regardless of its magnitude.

The function supports physical units via brainunit.

Parameters:
  • data (Array | ndarray | Quantity | Number) – Non-zero weight values of the CSR matrix. Shape (nse,) for heterogeneous weights or (1,) for a single homogeneous weight.

  • indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with integer dtype.

  • indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,), with the same dtype as indices.

  • B (Array | ndarray | Quantity | Number) – Dense event matrix. Shape (shape[0], cols) when transpose=True or (shape[1], cols) when transpose=False. Dtype may be boolean or floating-point.

  • shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.

  • transpose (bool) – If True, transpose A before multiplication. Default is False.

  • backend (str | None) – Compute backend. One of 'numba', 'pallas', or None (auto-select). Default is None.
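To make the roles of data, indices, and indptr concrete, here is a minimal pure-Python sketch that expands the CSR arrays into a dense matrix. The helper name csr_to_dense is hypothetical and not part of brainevent; it assumes no duplicate column indices within a row and treats a length-1 data array as one homogeneous weight, as described above.

```python
from typing import List, Sequence, Tuple


def csr_to_dense(
    data: Sequence[float],
    indices: Sequence[int],
    indptr: Sequence[int],
    shape: Tuple[int, int],
) -> List[List[float]]:
    """Hypothetical helper: expand CSR arrays into a dense list of rows."""
    m, k = shape
    if len(indptr) != m + 1:
        raise ValueError("indptr must have shape[0] + 1 entries")
    homogeneous = len(data) == 1  # a single weight shared by every non-zero
    dense = [[0.0] * k for _ in range(m)]
    for i in range(m):
        # The non-zeros of row i occupy positions indptr[i] .. indptr[i + 1] - 1.
        for p in range(indptr[i], indptr[i + 1]):
            dense[i][indices[p]] = data[0] if homogeneous else data[p]
    return dense
```

With the arrays from the Examples section, csr_to_dense([1.0, 2.0, 3.0, 4.0], [0, 2, 1, 2], [0, 2, 4], (2, 3)) gives [[1.0, 0.0, 2.0], [0.0, 3.0, 4.0]].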

Returns:

C – Result matrix. Shape (shape[1], cols) when transpose=True or (shape[0], cols) when transpose=False.

Return type:

Array | ndarray | Quantity | Number
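The shape rules for B and C follow from ordinary matrix multiplication with A of shape (m, k). The sketch below encodes them in a hypothetical helper, csrmm_shapes, which is not a brainevent function; it only restates the documented rules.

```python
from typing import Tuple


def csrmm_shapes(
    shape: Tuple[int, int], cols: int, transpose: bool
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Hypothetical helper: return (required B shape, resulting C shape)."""
    m, k = shape
    if transpose:
        # C = A.T @ B: A.T is (k, m), so B needs m rows and C has k rows.
        return (m, cols), (k, cols)
    # C = A @ B: A is (m, k), so B needs k rows and C has m rows.
    return (k, cols), (m, cols)
```

For shape=(2, 3) and two event columns, csrmm_shapes((2, 3), 2, False) returns ((3, 2), (2, 2)).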

See also

  • binary_csrmv – Binary CSR matrix–vector multiplication.

  • csrmm – Standard (non-event-driven) CSR matrix–matrix multiplication.

Notes

The operation is event-driven: entries of B that are inactive (False or <= 0) are skipped. Custom JVP and transpose rules are provided for automatic differentiation.
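The skipping rule can be illustrated by listing which rows of a single column of B are active events. This is a sketch of the documented rule only, not brainevent's kernel; the function name active_rows is hypothetical. A single "x > 0" test covers both dtypes, because in Python True > 0 is True and False > 0 is False.

```python
from typing import List, Sequence


def active_rows(column: Sequence) -> List[int]:
    """Hypothetical helper: indices j whose entry is an active event."""
    # True for booleans that are True and for floats strictly greater than 0.
    return [j for j, x in enumerate(column) if x > 0]
```

For example, active_rows([0.0, -1.0, 0.1, 7.5]) returns [2, 3]: the values 0.1 and 7.5 are both simply "active", and the zero and negative entries are skipped entirely.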

Mathematically, the non-transposed operation computes:

C[i, l] = sum_{j in nz(i)} A[i, j] * e(B[j, l])

where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix, and e(B[j, l]) is the event indicator:

e(B[j, l]) = 1  if B[j, l] is True (bool) or B[j, l] > 0 (float)
e(B[j, l]) = 0  otherwise
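The non-transposed formula translates directly into a gather over the non-zeros of each row. The following pure-Python reference is a readable sketch of that formula, not brainevent's implementation (which dispatches to Numba or Pallas kernels); the name binary_csrmm_ref is hypothetical.

```python
from typing import List, Sequence, Tuple


def binary_csrmm_ref(
    data: Sequence[float],
    indices: Sequence[int],
    indptr: Sequence[int],
    B: Sequence[Sequence],
    shape: Tuple[int, int],
) -> List[List[float]]:
    """Hypothetical reference for C[i, l] = sum_{j in nz(i)} A[i, j] * e(B[j, l])."""
    m, _ = shape
    cols = len(B[0])
    homogeneous = len(data) == 1
    C = [[0.0] * cols for _ in range(m)]
    for i in range(m):
        for p in range(indptr[i], indptr[i + 1]):  # j ranges over nz(i)
            j = indices[p]
            w = data[0] if homogeneous else data[p]
            for col in range(cols):
                if B[j][col] > 0:  # e(B[j, col]) == 1; inactive entries add nothing
                    C[i][col] += w
    return C
```

On the data from the Examples section this returns [[3.0, 2.0], [4.0, 7.0]], and replacing the boolean B with floats that have the same sign pattern (for instance 0.2 in place of True and -1.0 in place of False) gives the same result.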

When transpose=True, the transposed operation computes:

C[j, l] = sum_{i in nz_col(j)} A[i, j] * e(B[i, l])

where nz_col(j) denotes the set of row indices with non-zero entries in column j.
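Enumerating nz_col(j) directly would need column-wise (CSC) access, which CSR does not provide cheaply. One common alternative, shown here as an assumption rather than a description of brainevent's actual kernel, is to loop over the rows i of A and, for each active event B[i, l], scatter A[i, j] into C[j, l]. This visits the same (i, j) pairs as the formula, so the sums agree. The name binary_csrmm_t_ref is hypothetical.

```python
from typing import List, Sequence, Tuple


def binary_csrmm_t_ref(
    data: Sequence[float],
    indices: Sequence[int],
    indptr: Sequence[int],
    B: Sequence[Sequence],
    shape: Tuple[int, int],
) -> List[List[float]]:
    """Hypothetical reference for C = A.T @ e(B) using a row-wise scatter."""
    m, k = shape
    cols = len(B[0])
    homogeneous = len(data) == 1
    C = [[0.0] * cols for _ in range(k)]
    for i in range(m):
        for col in range(cols):
            if not (B[i][col] > 0):  # inactive event: row i contributes nothing here
                continue
            for p in range(indptr[i], indptr[i + 1]):
                j = indices[p]
                C[j][col] += data[0] if homogeneous else data[p]
    return C
```

With A = [[1, 0, 2], [0, 3, 4]] in CSR form and B = [[True, False], [True, True]], the result is [[1.0, 0.0], [3.0, 3.0], [6.0, 4.0]], which matches the dense product A.T @ B.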

For homogeneous weights (data of shape (1,)), A[i, j] is the constant data[0] for all non-zero positions.


Examples

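>>> import jax.numpy as jnp
>>> import brainevent
>>> # A is 2 x 3 in CSR form; its dense equivalent is [[1, 0, 2], [0, 3, 4]].
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> # With transpose=False, B needs shape[1] = 3 rows; here it has 2 event columns.
>>> B = jnp.array([[True, False],
...                [False, True],
...                [True, True]])
>>> C = brainevent.binary_csrmm(data, indices, indptr, B, shape=(2, 3))
>>> # Expected values, from the formula in Notes: [[3., 2.], [4., 7.]]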