brainevent.binary_densemm
- brainevent.binary_densemm = <NameScope(brainevent.binary_densemm)>
Perform event-driven dense matrix-matrix multiplication with binary spikes.
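When transpose=False, computes weights[m, k] @ spikes[k, n] -> out[m, n] (dense matrix times binary matrix). When transpose=True, computes weights[k, m].T @ spikes[k, n] -> out[m, n] (transposed dense matrix times binary matrix). In both modes, weights and spikes share the contraction dimension k. It is the second axis of weights when transpose=False, the first axis of weights when transpose=True, and always the first axis of spikes.
- Parameters: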
weights (array_like) – The weight matrix. Shape (m, k) when transpose=False, or (k, m) when transpose=True. Can be a brainunit quantity.
spikes (array_like) – The binary matrix. Shape (k, n) in both modes. Can be boolean or float. Can be a brainunit quantity.
transpose (bool) – If False, compute weights @ spikes. If True, compute weights.T @ spikes.
backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).
- Returns:
result – Result matrix with shape (m, n). If the inputs carry units, the result carries the product of the weight and spike units.
- Return type:
array_like
- Raises:
AssertionError – If the shared dimension k of weights and spikes does not match.
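The shape rules above can be summarized in a small pure-Python check. This is a minimal sketch that only mirrors the documented rules; expected_output_shape is a hypothetical helper, and the exact location and message of brainevent's own assertion are not described here.

```python
def expected_output_shape(weights_shape, spikes_shape, transpose):
    """Hypothetical helper (not part of brainevent).

    Mirrors the documented shape rules: returns (m, n), or raises
    AssertionError when the shared dimension k does not match.
    """
    if transpose:
        k_w, m = weights_shape      # weights: (k, m)
    else:
        m, k_w = weights_shape      # weights: (m, k)
    k_s, n = spikes_shape           # spikes: (k, n) in both modes
    assert k_w == k_s, f"shared dimension mismatch: {k_w} != {k_s}"
    return (m, n)


print(expected_output_shape((3, 4), (4, 2), transpose=False))  # (3, 2)
print(expected_output_shape((4, 3), (4, 2), transpose=True))   # (3, 2)
```

Passing weights of shape (3, 4) with transpose=True would fail the check, because k is then read from the first axis (3) rather than the second (4).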
See also
binary_densemv – Matrix-vector variant of binary dense multiplication.
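The matrix-matrix operation relates to the matrix-vector variant column by column: out[:, j] is the matrix-vector product of the weights with spikes[:, j]. The sketch below shows that relation with jax.vmap. reference_densemv and densemm_via_densemv are hypothetical helpers, not brainevent's binary_densemv. Treating any nonzero entry as an active spike is an assumption, and nothing here claims brainevent implements binary_densemm this way.

```python
import jax
import jax.numpy as jnp


def reference_densemv(weights, spikes_vec, transpose=False):
    """Hypothetical dense matrix-vector reference (not brainevent's binary_densemv).

    weights: (m, k), or (k, m) when transpose=True; spikes_vec: (k,).
    Any nonzero / True entry counts as an active spike (assumption).
    """
    s = (spikes_vec != 0).astype(weights.dtype)
    return weights.T @ s if transpose else weights @ s


def densemm_via_densemv(weights, spikes, transpose=False):
    # Map the matrix-vector reference over the columns of spikes (axis 1)
    # and stack the per-column results as the columns of the output (axis 1).
    column_mv = lambda col: reference_densemv(weights, col, transpose)
    return jax.vmap(column_mv, in_axes=1, out_axes=1)(spikes)


weights = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)  # (m, k) = (3, 4)
spikes = jnp.array([[True, False],
                    [False, True],
                    [True, True],
                    [False, False]])                      # (k, n) = (4, 2)

out = densemm_via_densemv(weights, spikes)                 # shape (3, 2)
```

Column 0 has active rows k = 0 and 2, so out[:, 0] sums columns 0 and 2 of weights, giving [2., 10., 18.].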
Notes
When transpose=False, the operation computes:

$$\mathrm{out}[i, j] = \sum_{k} W[i, k] \, s[k, j]$$

where s[k, j] is 1 if spikes[k, j] is active and 0 otherwise.
When transpose=True, the operation computes:

$$\mathrm{out}[i, j] = \sum_{k} W[k, i] \, s[k, j]$$

with s[k, j] defined as above.
Boolean spikes are converted to weights.dtype before the primitive call, to avoid boolean buffer corruption in the Pallas Triton backend.
Examples
>>> import jax.numpy as jnp
>>> from brainevent._dense.binary import binary_densemm
>>> weights = jnp.ones((3, 4), dtype=jnp.float32)
>>> spikes = jnp.array([[True, False],
...                     [False, True],
...                     [True, True],
...                     [False, False]])
>>> binary_densemm(weights, spikes, transpose=False)
Array([[2., 2.],
       [2., 2.],
       [2., 2.]], dtype=float32)
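The doctest above exercises only transpose=False. Below is a plain NumPy sketch of the event-driven reading of the Notes formulas that covers both modes. For each output column j, only indices k with an active spike contribute, so the sum over k becomes a sum over the active indices. event_driven_densemm is a hypothetical helper, not brainevent's kernel (which dispatches to the 'numba' or 'pallas' backends). Treating any nonzero float entry as active is an assumption.

```python
import numpy as np


def event_driven_densemm(weights, spikes, transpose=False):
    """Hypothetical NumPy sketch of the Notes formulas, event-driven style.

    For each output column j, only indices k with an active spike
    contribute. Any nonzero / True entry counts as active (assumption).
    """
    weights = np.asarray(weights)
    spikes = np.asarray(spikes)
    m = weights.shape[1] if transpose else weights.shape[0]
    n = spikes.shape[1]
    out = np.zeros((m, n), dtype=weights.dtype)
    for j in range(n):
        active = np.flatnonzero(spikes[:, j])  # indices k where s[k, j] == 1
        if transpose:
            # out[i, j] = sum over active k of W[k, i]
            out[:, j] = weights[active, :].sum(axis=0)
        else:
            # out[i, j] = sum over active k of W[i, k]
            out[:, j] = weights[:, active].sum(axis=1)
    return out


spikes = np.array([[True, False],
                   [False, True],
                   [True, True],
                   [False, False]])

# Same inputs as the doctest: every output entry is 2.0.
print(event_driven_densemm(np.ones((3, 4), dtype=np.float32), spikes))

# transpose=True: weights stored as (k, m) = (4, 3).
print(event_driven_densemm(np.ones((4, 3), dtype=np.float32), spikes, transpose=True))
```

An empty active set yields a zero column, which is where an event-driven kernel saves work when spikes are sparse.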