brainevent.binary_jitnmv
- brainevent.binary_jitnmv = <NameScope(brainevent.binary_jitnmv)>
Event-driven matrix-vector multiplication with a JIT normal-distributed connectivity matrix.
Computes M @ v, where M is a sparse matrix whose non-zero entries are drawn from a normal distribution with parameters w_loc (mean) and w_scale (standard deviation), and v is a binary event vector. Only positions where v is active (True or > 0) contribute to the output, which makes this operation event-driven and efficient for sparse neural activity patterns.
- Parameters:
    - w_loc (Array | ndarray | Quantity | Number) – Location (mean) parameter of the normal distribution for the matrix weights. Scalar or 1-element array, optionally with physical units.
    - w_scale (Array | ndarray | Quantity | Number) – Scale (standard deviation) parameter of the normal distribution for the matrix weights. Must have the same physical dimension as w_loc.
    - prob (float) – Connection probability in the range [0, 1]. Controls the sparsity of the generated connectivity matrix.
    - vector (Array | ndarray | Quantity | Number) – The binary event vector to multiply with. Elements are treated as events (True/False or > 0 / <= 0). Its length must match the contracted dimension of shape: shape[1] for M @ v, or shape[0] when transpose=True.
    - seed (int | None) – Random seed for reproducible matrix generation. If None, a random seed is generated at compile time.
    - shape (Tuple[int, int]) – Shape of the implicit connectivity matrix as (rows, cols).
    - transpose (bool) – If True, compute M.T @ v instead of M @ v. Default is False.
    - corder (bool) – Memory layout order for kernel dispatch: True for C-order (row-major), False for Fortran-order (column-major). Default is True.
    - backend (str | None) – Compute backend to use ('numba', 'pallas', or None for automatic selection).
- Returns:
    The result vector of the matrix-vector product, of length shape[0] (or shape[1] when transpose=True). If the inputs carry physical units, the output has units equal to the product of the weight units and the vector units.
- Return type:
    Array | ndarray | Quantity | Number
- Raises:
    brainunit.DimensionMismatchError – If w_loc and w_scale do not have the same physical dimension.
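brainevent never builds M. For small shapes, though, a dense reference makes the contract between vector, shape and transpose concrete.

The sketch below is for illustration only. `dense_binary_mv` is a hypothetical helper, not part of brainevent, and it takes an explicit matrix in place of the implicit normal-distributed one. It applies the event rule stated under Parameters (True or > 0 counts as an event) and checks the vector length that each transpose setting requires.

```python
import numpy as np


def dense_binary_mv(weights, events, transpose=False):
    """Dense reference for the event-driven product (hypothetical helper).

    weights: an explicit (rows, cols) matrix standing in for M.
    events:  1-D array; entries that are True or > 0 count as events.
    """
    weights = np.asarray(weights, dtype=float)
    active = np.asarray(events) > 0  # True/False or > 0 / <= 0 -> boolean mask
    if transpose:
        # M.T @ v: v indexes the rows of M, result has one entry per column
        if active.shape[0] != weights.shape[0]:
            raise ValueError("transpose=True needs len(events) == rows")
        return weights[active, :].sum(axis=0)
    # M @ v: v indexes the columns of M, result has one entry per row
    if active.shape[0] != weights.shape[1]:
        raise ValueError("transpose=False needs len(events) == cols")
    return weights[:, active].sum(axis=1)


M = np.array([[1.0, 0.0, 2.0],
              [0.0, 3.0, 0.0]])
print(dense_binary_mv(M, np.array([True, False, True])))          # [3. 0.]
print(dense_binary_mv(M, np.array([True, True]), transpose=True))  # [1. 3. 2.]
print(dense_binary_mv(M, np.array([0.5, -1.0, 0.0])))             # [1. 0.]
```

Only the columns (or, with transpose=True, the rows) selected by active events are summed. A float vector therefore behaves like its `> 0` mask, which is how the event rule above treats it.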
See also
    binary_jitnmm – Event-driven matrix-matrix multiplication variant.
    jitnmv – Float (non-event) matrix-vector multiplication with normal weights.
Notes
The connectivity matrix W (written M in the summary above) is never materialized in memory. Instead, the pseudo-random structure is regenerated on the fly from the seed and prob parameters, following the same PRNG sequence as jitn to ensure consistency with todense().

The implicit weight matrix has entries:
    W[i, j] = Normal(w_loc, w_scale) * Bernoulli(prob)

where
Normal(w_loc, w_scale) is an independent draw for each non-zero position, and Bernoulli(prob) is 1 with probability prob and 0 otherwise. The connection mask is determined by a deterministic hash of (seed, i, j).

The event-driven matrix-vector product computes:
    y[i] = sum_{j in C(i)} N_ij * spike[j]

where
C(i) = {j : Bernoulli_ij = 1} is the set of connected pre-synaptic indices for post-synaptic neuron i, N_ij ~ Normal(w_loc, w_scale) is the connection weight, and spike[j] (the j-th entry of vector) is treated as a binary event (True/False or > 0 / <= 0). Equivalently:

    y[i] = sum_{j in C(i) : spike[j] = 1} N_ij

The connection length parameter
clen = 2 / prob controls the average stride between non-zero entries.

This operation supports automatic differentiation (JVP and transpose rules) for
w_loc, w_scale, and vector. Batching over the vector dimension is promoted to binary_jitnmm.

Examples
>>> import jax.numpy as jnp
>>> from brainevent._jit_normal.binary import binary_jitnmv
>>> w_loc = jnp.array([1.0])
>>> w_scale = jnp.array([0.1])
>>> events = jnp.array([True, False, True, True, False])
>>> result = binary_jitnmv(w_loc, w_scale, 0.5, events, seed=42,
...                        shape=(3, 5))
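The Notes describe two ideas: regenerating the matrix on the fly from seed and prob instead of storing it, and a stride length clen = 2 / prob. The sketch below illustrates both under stated assumptions. It is not brainevent's kernel or PRNG, and the helpers `row_connections`, `todense` and `event_mv` are hypothetical.

The sketch makes these assumptions:

- Each row is replayed from a NumPy generator seeded with (seed, i).
- Non-zero columns are reached by strides drawn uniformly from [1, clen). The mean stride is therefore about clen / 2, roughly 1 / prob.

Because the event-driven product and the dense materialization replay the same draws, they agree. This mirrors the consistency with todense() that the Notes mention.

```python
import numpy as np


def row_connections(seed, i, n_cols, prob, w_loc, w_scale):
    """Regenerate the (column, weight) pairs of row i (hypothetical scheme).

    Row i always uses a generator seeded with (seed, i), so repeated calls
    replay the same draws. Column indices advance by strides drawn uniformly
    from [1, clen) with clen = ceil(2 / prob), so the mean stride is about
    clen / 2, i.e. roughly 1 / prob.
    """
    if not 0.0 < prob <= 1.0:
        raise ValueError("this sketch requires 0 < prob <= 1")
    rng = np.random.default_rng([seed, i])
    clen = max(2, int(np.ceil(2.0 / prob)))
    cols, weights = [], []
    j = int(rng.integers(0, clen))  # random offset of the first non-zero entry
    while j < n_cols:
        cols.append(j)
        weights.append(rng.normal(w_loc, w_scale))  # N_ij ~ Normal(w_loc, w_scale)
        j += int(rng.integers(1, clen))  # next stride, drawn from [1, clen)
    return np.array(cols, dtype=int), np.array(weights, dtype=float)


def todense(seed, shape, prob, w_loc, w_scale):
    """Materialize W row by row (only sensible for small shapes)."""
    rows, cols = shape
    W = np.zeros(shape)
    for i in range(rows):
        c, w = row_connections(seed, i, cols, prob, w_loc, w_scale)
        W[i, c] = w
    return W


def event_mv(seed, shape, prob, w_loc, w_scale, events):
    """y[i] = sum of N_ij over connected j whose event is active; W is never stored."""
    rows, cols = shape
    active = np.asarray(events) > 0  # True or > 0 counts as an event
    y = np.zeros(rows)
    for i in range(rows):
        c, w = row_connections(seed, i, cols, prob, w_loc, w_scale)
        y[i] = w[active[c]].sum()  # keep only weights whose column fired
    return y


events = np.array([True, False, True, True, False])
y = event_mv(42, (3, 5), 0.5, 1.0, 0.1, events)
W = todense(42, (3, 5), 0.5, 1.0, 0.1)
print(np.allclose(y, W @ events.astype(float)))  # True
```

Each call to event_mv holds only one regenerated row at a time and sums just the weights in C(i) whose event is active. This is the y[i] = sum_{j in C(i) : spike[j] = 1} N_ij form given in the Notes.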