brainevent.binary_densemv
- brainevent.binary_densemv = <NameScope(brainevent.binary_densemv)>
Perform event-driven dense matrix-vector multiplication with binary spikes.
When transpose=False, computes weights[m, k] @ spikes[k] -> out[m] (dense matrix times binary vector).
When transpose=True, computes spikes[k] @ weights[k, n] -> out[n] (binary vector times dense matrix).
- Parameters:
weights (array_like) – The weight matrix. Shape (m, k) when transpose=False, or (k, n) when transpose=True. Can be a brainunit quantity.
spikes (array_like) – The binary vector with shape (k,). Can be boolean or float. Can be a brainunit quantity.
transpose (bool) – If False, compute weights @ spikes. If True, compute spikes @ weights.
backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).
- Returns:
result – Result vector. Shape (m,) when transpose=False, or (n,) when transpose=True. If inputs carry units, the result carries the product of the weight and spike units.
- Return type:
array_like
- Raises:
AssertionError – If the inner dimensions of weights and spikes do not match.
See also
binary_densemm – Matrix-matrix variant of binary dense multiplication.
Notes
The computation is event-driven: only the columns (when transpose=False) or rows (when transpose=True) of weights corresponding to active entries of spikes are accumulated. This is mathematically equivalent to a standard matrix-vector product with the binarized spike vector, but can be faster when spikes are sparse.
When transpose=False, the operation computes:
    out[i] = sum_{j} W[i, j] * s[j]
where s[j] is 1 if spikes[j] is active and 0 otherwise.
When transpose=True, the operation computes:
    out[j] = sum_{i} W[i, j] * s[i]
where s[i] is 1 if spikes[i] is active and 0 otherwise.
For boolean spikes, an entry is considered active when it is True. For float spikes, an entry is considered active when it is > 0.
Examples
>>> import jax.numpy as jnp
>>> from brainevent._dense.binary import binary_densemv
>>> weights = jnp.ones((3, 4), dtype=jnp.float32)
>>> spikes = jnp.array([True, False, True, False])
>>> binary_densemv(weights, spikes, transpose=False)
Array([2., 2., 2.], dtype=float32)
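The semantics described in the Notes can also be sketched in plain NumPy, which may help when checking results by hand. The function binary_densemv_reference below is a hypothetical helper written for illustration only. It is not part of brainevent and says nothing about how the numba or pallas backends are implemented. It assumes the activity rule stated in the Notes (True for boolean spikes, > 0 for float spikes) and gathers and sums only the active columns or rows.

```python
import numpy as np


def binary_densemv_reference(weights, spikes, transpose=False):
    """Hypothetical NumPy reference for the documented semantics.

    Illustrative only: this is not the library's implementation.
    """
    weights = np.asarray(weights)
    spikes = np.asarray(spikes)
    # Boolean spikes are active when True; float spikes are active when > 0.
    active = spikes if spikes.dtype == np.bool_ else spikes > 0
    if transpose:
        # spikes[k] @ weights[k, n] -> out[n]: sum the rows of active spikes.
        assert weights.shape[0] == spikes.shape[0], "inner dimensions must match"
        return weights[active, :].sum(axis=0)
    # weights[m, k] @ spikes[k] -> out[m]: sum the columns of active spikes.
    assert weights.shape[1] == spikes.shape[0], "inner dimensions must match"
    return weights[:, active].sum(axis=1)


weights = np.arange(12, dtype=np.float32).reshape(3, 4)
bool_spikes = np.array([True, False, True, False])
# Float spikes: 0.5 and 2.0 are active, 0.0 and -1.0 are not.
float_spikes = np.array([0.5, 0.0, 2.0, -1.0], dtype=np.float32)

out_bool = binary_densemv_reference(weights, bool_spikes)
out_float = binary_densemv_reference(weights, float_spikes)
out_t = binary_densemv_reference(weights.T, bool_spikes, transpose=True)

# All three results equal [2, 10, 18]: the sum of columns 0 and 2 of weights.
dense = weights @ bool_spikes.astype(np.float32)
```

In this sketch the float spike values 0.5 and 2.0 each contribute a weight of exactly 1, so out_float matches out_bool rather than the ordinary product weights @ float_spikes. This follows the definition of s[j] in the Notes.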