brainevent.binary_jitnmv

brainevent.binary_jitnmv = <NameScope(brainevent.binary_jitnmv)>

Event-driven matrix-vector multiplication with a JIT normal-distributed connectivity matrix.

Computes M @ v where M is a sparse matrix whose non-zero entries are drawn from a normal distribution with parameters w_loc (mean) and w_scale (standard deviation), and v is a binary event vector. Only positions where v is active (True or > 0) contribute to the output, making this operation event-driven and efficient for sparse neural activity patterns.
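The event-driven semantics can be checked against a dense reference. The sketch below is plain NumPy, not brainevent: it builds a small explicit matrix as a stand-in for M (the real operator never materializes it) and shows that multiplying by a binarized event vector equals summing the columns of M at the active positions. The magnitudes of positive entries therefore do not affect the result.

```python
import numpy as np

# Toy dense stand-in for the implicit matrix M (illustration only).
# Weights ~ Normal(w_loc, w_scale), each kept with probability prob.
rng = np.random.default_rng(0)
rows, cols, prob = 3, 5, 0.5
w_loc, w_scale = 1.0, 0.1
mask = rng.random((rows, cols)) < prob
M = np.where(mask, rng.normal(w_loc, w_scale, size=(rows, cols)), 0.0)

# Event vector: any entry > 0 (or True) counts as an event; its size is ignored.
v = np.array([0.7, 0.0, 2.0, 1.0, -1.0])
active = v > 0

# Event-driven view: sum only the columns of M selected by active events.
y_event = M[:, active].sum(axis=1)

# Equivalent dense view: ordinary product with the binarized vector.
y_dense = M @ active.astype(M.dtype)
assert np.allclose(y_event, y_dense)
```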

Parameters:
  • w_loc (Array | ndarray | Quantity | Number) – Location (mean) parameter of the normal distribution for the matrix weights. Scalar or 1-element array, optionally with physical units.

  • w_scale (Array | ndarray | Quantity | Number) – Scale (standard deviation) parameter of the normal distribution for the matrix weights. Must have the same physical dimension as w_loc.

  • prob (float) – Connection probability in the range [0, 1]. Controls the sparsity of the generated connectivity matrix.

  • vector (Array | ndarray | Quantity | Number) – The binary event vector to multiply with. Elements are treated as events (True/False, or >0/<=0 for numeric values). Its length must equal shape[1] when transpose=False and shape[0] when transpose=True.

  • seed (int | None) – Random seed for reproducible matrix generation. If None, a random seed is generated at compile time.

  • shape (Tuple[int, int]) – Shape of the implicit connectivity matrix as (rows, cols).

  • transpose (bool) – If True, compute M.T @ v instead of M @ v. Default is False.

  • corder (bool) – Memory layout order for kernel dispatch. True for C-order (row-major), False for Fortran-order (column-major). Default is True.

  • backend (str | None) – Compute backend to use ('numba', 'pallas', or None for automatic selection).
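The shape and transpose parameters together fix which vector length is accepted and which output length is produced. The helper expected_shapes below is hypothetical (it is not part of brainevent) and uses a dense NumPy matrix as a stand-in for the implicit one.

```python
import numpy as np

def expected_shapes(shape, transpose=False):
    """Hypothetical helper: (required vector length, output length) for
    M @ v (transpose=False) or M.T @ v (transpose=True), with M of shape (rows, cols)."""
    rows, cols = shape
    return (rows, cols) if transpose else (cols, rows)

rows, cols = 3, 5
M = np.arange(rows * cols, dtype=float).reshape(rows, cols)  # dense stand-in

# M @ v: v has length cols, result has length rows.
v_len, out_len = expected_shapes((rows, cols), transpose=False)
assert (M @ np.ones(v_len)).shape == (out_len,)

# M.T @ v: v has length rows, result has length cols.
v_len_t, out_len_t = expected_shapes((rows, cols), transpose=True)
assert (M.T @ np.ones(v_len_t)).shape == (out_len_t,)
```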

Returns:

The result vector of the matrix-vector product, of length shape[0] when transpose=False and shape[1] when transpose=True. If the inputs carry physical units, the output has units equal to the product of the weight units and the vector units.

Return type:

Array | ndarray | Quantity | Number

Raises:

brainunit.DimensionMismatchError – If w_loc and w_scale do not have the same physical dimension.

See also

binary_jitnmm

Event-driven matrix-matrix multiplication variant.

jitnmv

Float (non-event) matrix-vector multiplication with normal weights.

Notes

The connectivity matrix W is never materialized in memory. Instead, the pseudo-random structure is regenerated on-the-fly using the seed and prob parameters, following the same PRNG sequence as jitn to ensure consistency with todense().

The implicit weight matrix has entries:

W[i, j] = Normal(w_loc, w_scale) * Bernoulli(prob)

where Normal(w_loc, w_scale) is an independent draw for each non-zero position, and Bernoulli(prob) is 1 with probability prob and 0 otherwise. The connection mask is determined by a deterministic hash of (seed, i, j).
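To illustrate how a matrix can be regenerated rather than stored, the sketch below derives each entry from a PRNG seeded with (seed, i, j), applying the Bernoulli(prob) mask and then the Normal(w_loc, w_scale) draw described above. The function toy_jit_normal_dense is hypothetical, and NumPy's SeedSequence-based seeding stands in for brainevent's own PRNG, so its values will not match brainevent's matrices.

```python
import numpy as np

def toy_jit_normal_dense(w_loc, w_scale, prob, seed, shape):
    """Hypothetical todense()-style sketch: every entry is a pure function of
    (seed, i, j), so any entry can be regenerated without storing the matrix.
    Uses NumPy's seeding, NOT brainevent's PRNG."""
    rows, cols = shape
    W = np.zeros(shape)
    for i in range(rows):
        for j in range(cols):
            rng = np.random.default_rng([seed, i, j])  # deterministic per (seed, i, j)
            if rng.random() < prob:                    # Bernoulli(prob) connection mask
                W[i, j] = rng.normal(w_loc, w_scale)   # Normal(w_loc, w_scale) weight
    return W

W1 = toy_jit_normal_dense(1.0, 0.1, 0.5, seed=42, shape=(3, 5))
W2 = toy_jit_normal_dense(1.0, 0.1, 0.5, seed=42, shape=(3, 5))
assert np.array_equal(W1, W2)  # same seed -> identical matrix, nothing cached
```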

The event-driven matrix-vector product computes:

y[i] = sum_{j in C(i)} N_ij * spike[j]

where C(i) = {j : Bernoulli_ij = 1} is the set of pre-synaptic indices connected to post-synaptic neuron i, N_ij ~ Normal(w_loc, w_scale) is the connection weight, and spike[j] denotes vector[j], treated as a binary event (True/False or >0/<=0). Equivalently:

y[i] = sum_{j in C(i) : spike[j]=1} N_ij
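The second form of y[i] can be evaluated without ever forming W: loop over the active events only, regenerate each N_ij on demand, and skip unconnected pairs. This is a minimal pure-Python sketch built on a hypothetical per-(seed, i, j) generator, not brainevent's Numba or Pallas kernel; it is checked against a dense product built from the same generator.

```python
import numpy as np

def _entry(seed, i, j, prob, w_loc, w_scale):
    """Hypothetical per-entry generator: N_ij if (i, j) is connected, else None."""
    rng = np.random.default_rng([seed, i, j])
    if rng.random() < prob:
        return rng.normal(w_loc, w_scale)
    return None

def toy_event_mv(w_loc, w_scale, prob, spikes, seed, shape):
    """Matrix-free y[i] = sum over j in C(i) with spike[j] = 1 of N_ij (sketch only)."""
    rows, _ = shape
    active = np.flatnonzero(np.asarray(spikes) > 0)  # indices j with spike[j] = 1
    y = np.zeros(rows)
    for i in range(rows):
        for j in active:  # inactive columns are never visited
            w = _entry(seed, i, int(j), prob, w_loc, w_scale)
            if w is not None:  # j in C(i)
                y[i] += w
    return y

spikes = np.array([True, False, True, True, False])
y = toy_event_mv(1.0, 0.1, 0.5, spikes, seed=42, shape=(3, 5))

# Dense reference built from the same per-entry generator.
W = np.zeros((3, 5))
for i in range(3):
    for j in range(5):
        w = _entry(42, i, j, 0.5, 1.0, 0.1)
        W[i, j] = 0.0 if w is None else w
assert np.allclose(y, W @ spikes.astype(float))
```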

The connection length parameter clen = 2 / prob controls the stride between consecutive non-zero entries. For the target density prob, the average stride is about 1 / prob, i.e. half of clen.
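One way to obtain an average stride of about 1 / prob from clen = 2 / prob is to draw each stride uniformly from [1, clen - 1], whose mean is clen / 2. The sampler below is a hypothetical sketch of that idea (the uniform-stride rule and the name toy_row_indices are assumptions, not taken from brainevent) and checks that the empirical density comes out close to prob.

```python
import math
import numpy as np

def toy_row_indices(seed, i, prob, n_cols):
    """Hypothetical skip-based sampler for the connected columns of row i.

    Strides between consecutive connections are drawn uniformly from
    [1, clen - 1] with clen = ceil(2 / prob), so the mean stride is clen / 2,
    i.e. about 1 / prob. Illustration only, not brainevent's PRNG.
    """
    clen = max(2, math.ceil(2 / prob))
    rng = np.random.default_rng([seed, i])  # one stream per (seed, row)
    idx = []
    j = int(rng.integers(0, clen))          # random offset of the first connection
    while j < n_cols:
        idx.append(j)
        j += int(rng.integers(1, clen))     # upper bound is exclusive
    return idx

# The empirical density over many rows should be close to prob.
prob, n_rows, n_cols = 0.1, 20, 5000
total = sum(len(toy_row_indices(42, i, prob, n_cols)) for i in range(n_rows))
density = total / (n_rows * n_cols)
print(f"empirical density: {density:.3f}")
```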

This operation supports automatic differentiation (JVP and transpose rules) with respect to w_loc, w_scale, and vector. Batching (for example with jax.vmap) over the vector argument is promoted to the matrix-matrix variant binary_jitnmm.
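These rules follow from the structure of the product. Assuming each weight is reparameterized as w_loc + w_scale * z with z ~ Normal(0, 1) fixed by the seed (an assumption about the generator, not confirmed here), y is affine in w_loc and w_scale, its transpose with respect to the vector is multiplication by W.T, and a stack of vectors becomes a matrix-matrix product. The NumPy sketch below checks each property on a small dense stand-in.

```python
import numpy as np

rng = np.random.default_rng(0)
rows, cols, prob = 4, 6, 0.5
# Assumption: mask and standard-normal draws z are fixed by the seed, and
# weights are reparameterized as w_loc + w_scale * z.
mask = (rng.random((rows, cols)) < prob).astype(float)
z = rng.standard_normal((rows, cols))
events = np.array([1, 0, 1, 1, 0, 1], dtype=float)  # already binarized

def y_of(w_loc, w_scale, e):
    W = mask * (w_loc + w_scale * z)
    return W @ e

# y is affine in (w_loc, w_scale), so its tangents do not depend on the parameters.
dy_dloc = mask @ events            # number of active connected inputs per row
dy_dscale = (mask * z) @ events
w_loc, w_scale = 1.0, 0.1
assert np.allclose(y_of(w_loc, w_scale, events), w_loc * dy_dloc + w_scale * dy_dscale)

# Finite-difference check of the w_loc tangent.
eps = 1e-6
fd = (y_of(w_loc + eps, w_scale, events) - y_of(w_loc, w_scale, events)) / eps
assert np.allclose(fd, dy_dloc, atol=1e-4)

# Transpose rule of the linear map e -> W @ e: <W e, c> == <e, W.T c>.
W = mask * (w_loc + w_scale * z)
c = rng.standard_normal(rows)
assert np.isclose((W @ events) @ c, events @ (W.T @ c))

# Batching: event vectors stacked as columns give a matrix-matrix product.
V = np.stack([events, 1.0 - events], axis=1)  # shape (cols, 2)
assert np.allclose(W @ V, np.stack([W @ V[:, k] for k in range(2)], axis=1))
```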

Examples

>>> import jax.numpy as jnp
>>> from brainevent._jit_normal.binary import binary_jitnmv
>>> w_loc = jnp.array([1.0])
>>> w_scale = jnp.array([0.1])
>>> events = jnp.array([True, False, True, True, False])
>>> result = binary_jitnmv(w_loc, w_scale, 0.5, events, seed=42,
...                        shape=(3, 5))
>>> result.shape
(3,)