brainevent.binary_densemv

brainevent.binary_densemv = <NameScope(brainevent.binary_densemv)>

Perform event-driven dense matrix-vector multiplication with binary spikes.

When transpose=False, computes weights[m, k] @ spikes[k] -> out[m] (dense matrix times binary vector).

When transpose=True, computes spikes[k] @ weights[k, n] -> out[n] (binary vector times dense matrix).
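As a point of reference, both modes reduce to an ordinary matrix-vector product once spikes is converted to a 0/1 vector. The sketch below is a plain NumPy reference and not the library's kernel. The function name `dense_reference` is hypothetical, and it assumes the activity rule stated in the Notes (True for boolean spikes, > 0 for float spikes).

```python
import numpy as np


def dense_reference(weights, spikes, transpose=False):
    """Hypothetical NumPy reference for binary_densemv semantics (not brainevent's kernel)."""
    weights = np.asarray(weights)
    spikes = np.asarray(spikes)
    # Convert spikes to a 0/1 vector: True for booleans, > 0 for floats (assumed rule).
    if spikes.dtype == np.bool_:
        s = spikes.astype(weights.dtype)
    else:
        s = (spikes > 0).astype(weights.dtype)
    if transpose:
        # spikes[k] @ weights[k, n] -> out[n]
        return s @ weights
    # weights[m, k] @ spikes[k] -> out[m]
    return weights @ s


w = np.arange(12, dtype=np.float32).reshape(3, 4)
spk = np.array([True, False, True, False])
print(dense_reference(w, spk))                    # output of shape (3,)
print(dense_reference(w.T, spk, transpose=True))  # same values, computed as spikes @ weights
```

Passing `w.T` with `transpose=True` yields the same values as `w` with `transpose=False`, which is a quick way to see that the two modes differ only in which axis of weights is contracted.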

Parameters:
  • weights (array_like) – The weight matrix. Shape (m, k) when transpose=False, or (k, n) when transpose=True. Can be a brainunit quantity.

  • spikes (array_like) – The binary vector with shape (k,). Can be boolean or float. Can be a brainunit quantity.

  • transpose (bool) – If False, compute weights @ spikes. If True, compute spikes @ weights.

  • backend (str | None) – Backend to use for the computation. One of 'numba', 'pallas', or None (auto-select).

Returns:

result – Result vector. Shape (m,) when transpose=False, or (n,) when transpose=True. If inputs carry units, the result carries the product of the weight and spike units.

Return type:

array_like

Raises:

AssertionError – If the inner dimensions of weights and spikes do not match.
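Which dimension of weights counts as the "inner" one depends on transpose. A minimal sketch of such a check follows; the helper name `check_inner_dims` is hypothetical, and the library's actual assertion logic and message may differ.

```python
def check_inner_dims(weights_shape, spikes_shape, transpose=False):
    """Hypothetical shape check mirroring the documented AssertionError."""
    # transpose=False: weights (m, k) @ spikes (k,) contracts the last axis of weights.
    # transpose=True:  spikes (k,) @ weights (k, n) contracts the first axis of weights.
    k = weights_shape[0] if transpose else weights_shape[1]
    assert tuple(spikes_shape) == (k,), (
        f"spikes shape {tuple(spikes_shape)} does not match inner dimension {k} "
        f"of weights with shape {tuple(weights_shape)} (transpose={transpose})"
    )
    # The output length is the remaining (non-contracted) axis of weights.
    out_len = weights_shape[1] if transpose else weights_shape[0]
    return (out_len,)


print(check_inner_dims((3, 4), (4,)))                  # (3,)
print(check_inner_dims((4, 5), (4,), transpose=True))  # (5,)
try:
    check_inner_dims((3, 4), (3,))
except AssertionError as err:
    print("AssertionError:", err)
```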

See also

binary_densemm

Matrix-matrix variant of binary dense multiplication.

Notes

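The computation is event-driven: only the columns of weights (when transpose=False) or the rows of weights (when transpose=True) that correspond to active entries of spikes are accumulated, where an entry is active if it is True (boolean spikes) or > 0 (float spikes). This is mathematically equivalent to a standard matrix-vector product with spikes replaced by its 0/1 indicator, but it can be faster when only a small fraction of spikes are active.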
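The event-driven principle can be illustrated by gathering only the weight columns (or rows) at active spike positions and summing them. This NumPy sketch demonstrates the idea under the documented activity rule; it is not brainevent's Numba or Pallas kernel, and the name `event_driven_mv` is hypothetical.

```python
import numpy as np


def event_driven_mv(weights, spikes, transpose=False):
    """Hypothetical event-driven reference: sum weight slices at active spikes only."""
    weights = np.asarray(weights)
    spikes = np.asarray(spikes)
    # Active entries: True for booleans, > 0 for floats.
    active = spikes if spikes.dtype == np.bool_ else spikes > 0
    idx = np.flatnonzero(active)
    if transpose:
        # Sum the rows weights[idx, :]; result has shape (n,).
        return weights[idx, :].sum(axis=0)
    # Sum the columns weights[:, idx]; result has shape (m,).
    return weights[:, idx].sum(axis=1)


rng = np.random.default_rng(0)
w = rng.standard_normal((5, 1000)).astype(np.float32)
spk = rng.random(1000) < 0.02  # roughly 2% of entries active
dense = w @ spk.astype(np.float32)
print(np.allclose(event_driven_mv(w, spk), dense, atol=1e-5))
```

In this sketch the summation touches only `len(idx)` slices of weights rather than all k, which is where the potential saving for sparse spikes comes from. Summation order differs from the dense product, so results agree up to floating-point rounding rather than bit for bit.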

When transpose=False, the operation computes:

out[i] = sum_{j} W[i, j] * s[j]

where W denotes weights and s[j] is 1 if spikes[j] is active and 0 otherwise.

When transpose=True, the operation computes:

out[j] = sum_{i} W[i, j] * s[i]

where s[i] is 1 if spikes[i] is active and 0 otherwise.

For boolean spikes, an entry is considered active when it is True. For float spikes, an entry is considered active when it is > 0.
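Under this rule a float spike contributes its full weight slice regardless of its magnitude, while zero and negative values contribute nothing. A small NumPy check of that reading of the rule, using a hypothetical helper `to_binary`:

```python
import numpy as np


def to_binary(spikes):
    """Hypothetical helper: map spikes to the 0/1 vector s used in the formulas."""
    spikes = np.asarray(spikes)
    if spikes.dtype == np.bool_:
        return spikes.astype(np.float32)
    # Float spikes: active only when strictly greater than zero.
    return (spikes > 0).astype(np.float32)


w = np.array([[1.0, 10.0, 100.0, 1000.0]], dtype=np.float32)  # shape (1, 4)
float_spikes = np.array([0.5, 0.0, -1.0, 3.0], dtype=np.float32)
s = to_binary(float_spikes)
print(s)      # [1. 0. 0. 1.]
print(w @ s)  # [1001.]
```

The spike values 0.5 and 3.0 both select their weight column with coefficient 1, so the result is 1 + 1000 rather than 0.5 + 3000.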

Examples

>>> import jax.numpy as jnp
>>> from brainevent import binary_densemv
>>> weights = jnp.ones((3, 4), dtype=jnp.float32)
>>> spikes = jnp.array([True, False, True, False])
>>> binary_densemv(weights, spikes, transpose=False)
Array([2., 2., 2.], dtype=float32)