brainevent.binary_csrmm
- brainevent.binary_csrmm = <NameScope(brainevent.binary_csrmm)>
Product of a CSR sparse matrix and a dense matrix using event-driven (binary) computation.
Computes C = A @ B (or C = A.T @ B when transpose=True), where A is stored in Compressed Sparse Row (CSR) format and B is a dense matrix whose entries are treated as binary events. Only entries of B that are True (boolean) or positive (float) contribute to the result.
The function supports physical units via brainunit.
- Parameters:
  - data (Array | ndarray | Quantity | Number) – Non-zero weight values of the CSR matrix. Shape (nse,) for heterogeneous weights or (1,) for a single homogeneous weight.
  - indices (Array | ndarray) – Column indices of the non-zero elements. Shape (nse,) with integer dtype.
  - indptr (Array | ndarray) – Row index pointer array. Shape (shape[0] + 1,) and the same dtype as indices.
  - B (Array | ndarray | Quantity | Number) – Dense event matrix. Shape (shape[0], cols) when transpose=True or (shape[1], cols) when transpose=False. Dtype may be boolean or floating-point.
  - shape (Tuple[int, int]) – Two-element tuple (m, k) giving the logical shape of the sparse matrix A.
  - transpose (bool) – If True, transpose A before multiplication. Default is False.
  - backend (str | None) – Compute backend. One of 'numba', 'pallas', or None (auto-select). Default is None.
- Returns:
  C – Result matrix. Shape (shape[1], cols) when transpose=True or (shape[0], cols) when transpose=False.
- Return type:
  Array | ndarray | Quantity | Number
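The shape rules for data, indices, indptr, B and C can be made concrete with a small NumPy sketch. The helpers `csr_to_dense` and `result_shape` are hypothetical illustrations, not part of brainevent: the first expands the CSR arrays into the dense A (broadcasting a `(1,)` homogeneous weight to every stored position), and the second applies the documented rules for how many rows B must have and what shape C takes.

```python
import numpy as np


def csr_to_dense(data, indices, indptr, shape):
    """Expand CSR arrays into a dense (m, k) array (hypothetical helper)."""
    m, k = shape
    data = np.asarray(data)
    indices = np.asarray(indices)
    indptr = np.asarray(indptr)
    if indptr.shape != (m + 1,):
        raise ValueError("indptr must have shape (shape[0] + 1,)")
    nse = indices.shape[0]
    if data.shape == (1,):
        # Homogeneous weight: every stored position uses data[0].
        values = np.full(nse, data[0])
    elif data.shape == (nse,):
        # Heterogeneous weights: one value per stored position.
        values = data
    else:
        raise ValueError("data must have shape (nse,) or (1,)")
    dense = np.zeros((m, k), dtype=values.dtype)
    for i in range(m):
        # Row i owns stored positions indptr[i] .. indptr[i + 1] - 1.
        for p in range(indptr[i], indptr[i + 1]):
            dense[i, indices[p]] += values[p]
    return dense


def result_shape(shape, b_shape, transpose=False):
    """Shape of C for A of logical shape (m, k) and B of shape b_shape (hypothetical)."""
    m, k = shape
    # B must be (shape[0], cols) when transposing, else (shape[1], cols).
    rows_needed = m if transpose else k
    if b_shape[0] != rows_needed:
        raise ValueError(f"B must have {rows_needed} rows, got {b_shape[0]}")
    return (k if transpose else m, b_shape[1])
```

The inner loop uses `+=`, so repeated column indices within a row would be summed, which is the usual CSR convention; whether binary_csrmm accepts such duplicates is not stated on this page.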
See also
- binary_csrmv – Binary CSR matrix–vector multiplication.
- csrmm – Standard (non-event-driven) CSR matrix–matrix multiplication.
Notes
The operation is event-driven: inactive entries of B (False or <= 0) are skipped. Custom JVP and transpose rules are provided for automatic differentiation.
Mathematically, the non-transposed operation computes:
$$C[i, l] = \sum_{j \in \mathrm{nz}(i)} A[i, j] \, e(B[j, l])$$
where nz(i) denotes the set of column indices with non-zero entries in row i of the CSR matrix, and e(B[j, l]) is the event indicator:
$$e(B[j, l]) = \begin{cases} 1 & \text{if } B[j, l] \text{ is True (bool) or } B[j, l] > 0 \text{ (float)} \\ 0 & \text{otherwise} \end{cases}$$
When transpose=True, the transposed operation computes:
$$C[j, l] = \sum_{i \in \mathrm{nz\_col}(j)} A[i, j] \, e(B[i, l])$$
where nz_col(j) denotes the set of row indices with non-zero entries in column j.
For homogeneous weights (data of shape (1,)), A[i, j] is the constant data[0] for all non-zero positions.
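A loop-based reference for the two formulas above may help when checking results. It is written in NumPy rather than JAX, and `event_indicator` and `binary_csrmm_reference` are hypothetical names: this sketch only mirrors the documented math and is not brainevent's Numba or Pallas kernel. For each stored entry A[i, j], the weight is added only to the output columns whose event is active, which is the skipping behaviour the Notes describe.

```python
import numpy as np


def event_indicator(B):
    """e(B): True where B is True (bool) or > 0 (numeric), else False."""
    B = np.asarray(B)
    return B if B.dtype == np.bool_ else B > 0


def binary_csrmm_reference(data, indices, indptr, B, shape, transpose=False):
    """Loop-based reference for the Notes formulas (hypothetical, NumPy only)."""
    m, k = shape
    data = np.asarray(data, dtype=np.float64)
    indices = np.asarray(indices)
    indptr = np.asarray(indptr)
    active = event_indicator(B)  # boolean mask with the same shape as B
    homogeneous = data.shape == (1,)
    out = np.zeros((k if transpose else m, active.shape[1]))
    for i in range(m):
        for p in range(indptr[i], indptr[i + 1]):
            j = indices[p]  # stored entry A[i, j]
            w = data[0] if homogeneous else data[p]
            if transpose:
                # C[j, l] += A[i, j] for every l with e(B[i, l]) = 1
                out[j, active[i]] += w
            else:
                # C[i, l] += A[i, j] for every l with e(B[j, l]) = 1
                out[i, active[j]] += w
    return out
```

Because e(B) depends only on whether an entry is positive, a float B with the same positive pattern as a boolean B (for example [[0.5, -1.0], [0.0, 2.0], [3.0, 0.1]] versus the boolean B in the example) yields the same C.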
Examples
>>> import jax.numpy as jnp
>>> from brainevent._csr.binary import binary_csrmm
>>> data = jnp.array([1.0, 2.0, 3.0, 4.0])
>>> indices = jnp.array([0, 2, 1, 2], dtype=jnp.int32)
>>> indptr = jnp.array([0, 2, 4], dtype=jnp.int32)
>>> B = jnp.array([[True, False],
...                [False, True],
...                [True, True]])
>>> binary_csrmm(data, indices, indptr, B, shape=(2, 3))
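The values the example should produce can be cross-checked with a dense NumPy computation. Here A is written out by hand from the CSR arrays (row 0 stores 1.0 and 2.0 in columns 0 and 2; row 1 stores 3.0 and 4.0 in columns 1 and 2), and because B is boolean, e(B) is simply B cast to 0/1 floats. The exact dtype and repr returned by binary_csrmm are not asserted here, only the numerical values.

```python
import numpy as np

# Dense form of the example's CSR matrix A, logical shape (2, 3).
A = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 4.0]])
B = np.array([[True, False],
              [False, True],
              [True, True]])

# For boolean events, e(B) is B as 0/1 floats, so C = A @ e(B).
C = A @ B.astype(A.dtype)
# C holds [[3.0, 2.0], [4.0, 7.0]]
```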