brainevent.binary_densemm



brainevent.binary_densemm = <NameScope(brainevent.binary_densemm)>

Perform event-driven dense matrix-matrix multiplication with binary spikes.

When transpose=False, computes weights[m, k] @ spikes[k, n] -> out[m, n] (dense matrix times binary matrix).

When transpose=True, computes weights[k, m].T @ spikes[k, n] -> out[m, n] (transposed dense matrix times binary matrix). Both weights and spikes share their first dimension k.
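The two modes can be pinned down with a dense NumPy reference. This is a minimal sketch, not brainevent's implementation. The helper name dense_reference is hypothetical. It performs an ordinary matrix product with no event-driven skipping, and it assumes spikes hold only boolean or 0/1 values.

```python
import numpy as np

def dense_reference(weights, spikes, transpose=False):
    """Hypothetical dense reference: reproduces the arithmetic only.

    Ignores event-driven skipping, backends, and units. Assumes spikes
    are boolean or hold only 0/1 values.
    """
    weights = np.asarray(weights)
    # Cast binary spikes to the weight dtype (True -> 1, False -> 0).
    s = np.asarray(spikes).astype(weights.dtype)
    # transpose=True contracts over the first axis of weights, i.e. uses W.T.
    w = weights.T if transpose else weights
    return w @ s

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4)).astype(np.float32)  # (m, k) = (3, 4)
S = rng.random((4, 2)) < 0.5                         # (k, n) = (4, 2), boolean
out = dense_reference(W, S)                          # (m, n) = (3, 2)
out_t = dense_reference(W.T, S, transpose=True)      # same result from (k, m) weights
```

Passing weights.T with transpose=True gives the same (m, n) result as passing weights with transpose=False. Only the storage layout of the weight matrix differs between the two modes.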

Parameters:
  • weights (array_like) – The weight matrix. Shape (m, k) when transpose=False, or (k, m) when transpose=True. Can be a brainunit quantity.

  • spikes (array_like) – The binary matrix. Shape (k, n) in both modes. Can be boolean or float. Can be a brainunit quantity.

  • transpose (bool) – If False, compute weights @ spikes. If True, compute weights.T @ spikes.

  • backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).

Returns:

result – Result matrix with shape (m, n). If inputs carry units, the result carries the product of the weight and spike units.

Return type:

array_like

Raises:

AssertionError – If the shared dimension k of weights and spikes does not match.
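Only the shared dimension k has to agree; m and n are free. The hypothetical helper below is not part of brainevent and does not reproduce its actual error message. It only spells out which axes are compared in each mode.

```python
def binary_densemm_out_shape(weights_shape, spikes_shape, transpose=False):
    """Hypothetical helper: return (m, n), or raise AssertionError on a k mismatch."""
    if transpose:
        k_w, m = weights_shape   # weights is (k, m)
    else:
        m, k_w = weights_shape   # weights is (m, k)
    k_s, n = spikes_shape        # spikes is (k, n) in both modes
    assert k_w == k_s, f"shared dimension mismatch: {k_w} != {k_s}"
    return (m, n)

print(binary_densemm_out_shape((3, 4), (4, 2)))                  # (3, 2)
print(binary_densemm_out_shape((4, 3), (4, 2), transpose=True))  # (3, 2)
```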

See also

binary_densemv

Matrix-vector variant of binary dense multiplication.

Notes

When transpose=False, the operation computes:

out[i, j] = sum_{k} W[i, k] * s[k, j]

where s[k, j] is 1 if spikes[k, j] is active and 0 otherwise.
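This sum can be evaluated in an event-driven way. For each output column j, only the weight columns W[:, k] with an active spike contribute, so the multiply-accumulate reduces to a gather-and-sum. The NumPy loop below is an illustrative sketch of that idea, not the library's kernel. The name event_driven_mm is hypothetical, and treating any nonzero spike value as active is an assumption for float inputs.

```python
import numpy as np

def event_driven_mm(weights, spikes):
    """Event-driven sketch of out[i, j] = sum_k W[i, k] * s[k, j] (transpose=False).

    Hypothetical helper: weights has shape (m, k), spikes has shape (k, n).
    Any nonzero spike value is treated as active (an assumption for floats).
    """
    weights = np.asarray(weights)
    active_mask = np.asarray(spikes) != 0           # s[k, j] == 1 where active
    m, n = weights.shape[0], active_mask.shape[1]
    out = np.zeros((m, n), dtype=weights.dtype)
    for j in range(n):
        active = np.flatnonzero(active_mask[:, j])  # indices k with an active spike
        # Summing the active columns W[:, k] replaces multiplying by 0/1 over every k.
        out[:, j] = weights[:, active].sum(axis=1)
    return out

W = np.arange(12, dtype=np.float32).reshape(3, 4)  # (m, k) = (3, 4)
S = np.array([[True, False],
              [False, True],
              [True, True],
              [False, False]])                       # (k, n) = (4, 2)
out = event_driven_mm(W, S)  # [[2, 3], [10, 11], [18, 19]]
```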

When transpose=True, the operation computes:

out[i, j] = sum_{k} W[k, i] * s[k, j]

where s[k, j] is 1 if spikes[k, j] is active and 0 otherwise.
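For transpose=True, each active spike selects a whole row W[k, :] instead of a column; in a C-ordered (row-major) array that row is contiguous in memory. The sketch below is again a hypothetical NumPy illustration under the same assumptions, checked against the dense weights.T @ spikes.

```python
import numpy as np

def event_driven_mm_t(weights, spikes):
    """Event-driven sketch of out[i, j] = sum_k W[k, i] * s[k, j] (transpose=True).

    Hypothetical helper: weights has shape (k, m), spikes has shape (k, n).
    Any nonzero spike value is treated as active (an assumption for floats).
    """
    weights = np.asarray(weights)
    active_mask = np.asarray(spikes) != 0
    m, n = weights.shape[1], active_mask.shape[1]
    out = np.zeros((m, n), dtype=weights.dtype)
    for j in range(n):
        active = np.flatnonzero(active_mask[:, j])  # indices k with an active spike
        # Each active spike contributes a whole row W[k, :] to column j.
        out[:, j] = weights[active, :].sum(axis=0)
    return out

W = np.arange(12, dtype=np.float32).reshape(4, 3)  # (k, m) = (4, 3)
S = np.array([[True, False],
              [False, True],
              [True, True],
              [False, False]])                       # (k, n) = (4, 2)
out_t = event_driven_mm_t(W, S)  # [[6, 9], [8, 11], [10, 13]]
```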

Boolean spikes are cast to weights.dtype before the underlying primitive is called, to avoid corruption of boolean buffers in the Pallas Triton backend.
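The cast itself is a plain dtype conversion: True becomes 1 and False becomes 0 in the weight dtype, so the arithmetic is unchanged. NumPy stands in for jax.numpy in this sketch, which illustrates only the conversion, not the Pallas workaround.

```python
import numpy as np

weights = np.ones((3, 4), dtype=np.float32)
spikes = np.array([[True, False],
                   [False, True],
                   [True, True],
                   [False, False]])
# Cast the boolean spikes to the weight dtype: True -> 1.0, False -> 0.0.
spikes_cast = spikes.astype(weights.dtype)
```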

Examples

>>> import jax.numpy as jnp
>>> from brainevent._dense.binary import binary_densemm
>>> weights = jnp.ones((3, 4), dtype=jnp.float32)
>>> spikes = jnp.array([[True, False],
...                     [False, True],
...                     [True, True],
...                     [False, False]])
>>> binary_densemm(weights, spikes, transpose=False)
Array([[2., 2.],
       [2., 2.],
       [2., 2.]], dtype=float32)
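The expected output above can be cross-checked without brainevent. For transpose=True, the same all-ones weights stored as (k, m) = (4, 3) should give the same result, since that mode computes weights.T @ spikes. The NumPy sketch below computes that expected value; it does not call binary_densemm.

```python
import numpy as np

weights_km = np.ones((4, 3), dtype=np.float32)  # (k, m) layout used with transpose=True
spikes = np.array([[True, False],
                   [False, True],
                   [True, True],
                   [False, False]])               # (k, n) = (4, 2)
# transpose=True computes weights.T @ spikes, with boolean spikes cast to the weight dtype.
expected = weights_km.T @ spikes.astype(weights_km.dtype)
print(expected)  # every entry is 2.0: each spike column has two active rows
```

With brainevent installed, binary_densemm(weights_km, spikes, transpose=True) would be expected to return this same (3, 2) array.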