brainevent.binary_jitnmm
- brainevent.binary_jitnmm(w_loc, w_scale, prob, B, seed=None, *, shape, transpose=False, corder=True, backend=None)
Event-driven matrix-matrix multiplication with a JIT normal-distributed connectivity matrix.
Computes M @ B, where M is a sparse matrix whose non-zero entries are drawn from a normal distribution with parameters w_loc (mean) and w_scale (standard deviation), and B is a 2-D binary event matrix. Only positions where B elements are active (True or > 0) contribute to the output.
- Parameters:
w_loc (Array | ndarray | Quantity | Number) – Location (mean) parameter of the normal distribution for the matrix weights. Scalar or 1-element array, optionally with physical units.
w_scale (Array | ndarray | Quantity | Number) – Scale (standard deviation) parameter of the normal distribution for the matrix weights. Must have the same physical dimension as w_loc.
prob (float) – Connection probability in the range [0, 1]. Controls the sparsity of the generated connectivity matrix.
B (Array | ndarray | Quantity | Number) – The binary event matrix to multiply with, shape (k, n). Elements are treated as events (True/False, or > 0 / <= 0).
seed (int | None) – Random seed for reproducible matrix generation. If None, a random seed is generated at compile time.
shape (Tuple[int, int]) – Shape of the implicit connectivity matrix as (rows, cols).
transpose (bool) – If True, compute M.T @ B instead of M @ B. Default is False.
corder (bool) – Memory layout order for kernel dispatch. True for C-order (row-major), False for Fortran-order (column-major). Default is True.
backend (str | None) – Compute backend to use ('numba', 'pallas', or None for automatic selection).
- Returns:
The result matrix of the matrix-matrix product with shape (shape[0], B.shape[1]) (or (shape[1], B.shape[1]) if transposed). If the inputs carry physical units, the output will have units equal to the product of the weight units and the B units.
- Return type:
Array | ndarray | Quantity | Number
- Raises:
brainunit.DimensionMismatchError – If w_loc and w_scale do not have the same physical dimension.
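The shape rule in the Returns entry can be sketched as a small pure-Python check. The helper name expected_output_shape is hypothetical and not part of brainevent. The inner-dimension check follows from ordinary matrix multiplication; whether the library validates it the same way is an assumption.

```python
def expected_output_shape(shape, b_shape, transpose=False):
    """Return the result shape for M @ B (or M.T @ B when transpose=True).

    Hypothetical helper for illustration; it is not part of brainevent.
    """
    rows, cols = shape
    k, n = b_shape
    # M is (rows, cols) and M.T is (cols, rows); the inner dimension must equal k.
    inner, outer = (rows, cols) if transpose else (cols, rows)
    if k != inner:
        raise ValueError(
            f"B has {k} rows, but the product needs {inner} "
            f"for shape={shape}, transpose={transpose}"
        )
    return (outer, n)


print(expected_output_shape((3, 5), (5, 2)))                  # (3, 2)
print(expected_output_shape((3, 5), (3, 2), transpose=True))  # (5, 2)
```

With shape=(3, 5), a B of shape (5, n) fits the default product, while transpose=True expects a B of shape (3, n).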
See also
binary_jitnmv: Event-driven matrix-vector multiplication variant.
jitnmm: Float (non-event) matrix-matrix multiplication with normal weights.
Notes
The connectivity matrix W is never materialized in memory. The pseudo-random structure is regenerated on-the-fly using the seed and prob parameters, matching the PRNG sequence used by jitn.
The implicit weight matrix has entries:
    W[i, j] = Normal(w_loc, w_scale) * Bernoulli(prob)
where Normal(w_loc, w_scale) is an independent draw for each non-zero position, and Bernoulli(prob) is 1 with probability prob and 0 otherwise.
The event-driven matrix-matrix product computes:
    Y[i, k] = sum_{j in C(i)} N_ij * spike[j, k]
where C(i) = {j : Bernoulli_ij = 1} is the set of connected pre-synaptic indices for post-synaptic neuron i, N_ij ~ Normal(w_loc, w_scale) is the connection weight, and spike[j, k] is treated as a binary event. Each column k of B is processed independently.
The connection length parameter clen = 2 / prob controls the average stride between non-zero entries.
This operation supports automatic differentiation (JVP and transpose rules) for w_loc, w_scale, and B. Batching over the B dimension is supported along axes 0, 1, and 2.
Examples
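The Notes relate clen = 2 / prob to the average stride between non-zero entries. A minimal pure-Python sketch of one way such a stride scheme can yield a density close to prob is given here. It assumes that strides are uniform integers in [1, clen - 1] and that clen is rounded up to an integer; neither detail is stated in this page. The function name sample_connected_columns is hypothetical, and the sketch does not reproduce the library's PRNG sequence.

```python
import math
import random


def sample_connected_columns(n_cols, prob, rng):
    """Pick connected column indices for one row by random strides.

    Assumption for illustration: strides are uniform integers in
    [1, clen - 1], so the mean stride is clen / 2, roughly 1 / prob.
    Valid for 0 < prob <= 1.
    """
    clen = math.ceil(2 / prob)       # rounding up is an assumption
    cols = []
    j = rng.randrange(0, clen)       # random starting offset in [0, clen - 1]
    while j < n_cols:
        cols.append(j)
        j += rng.randrange(1, clen)  # stride in [1, clen - 1]
    return cols


rng = random.Random(0)
cols = sample_connected_columns(200_000, 0.1, rng)
print(len(cols) / 200_000)  # empirical density, expected to be close to prob
```

Under these assumptions only the stride state is needed to walk a row, which is consistent with never materializing W.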
>>> import jax.numpy as jnp
>>> from brainevent._jit_normal.binary import binary_jitnmm
>>> w_loc = jnp.array([1.0])
>>> w_scale = jnp.array([0.1])
>>> B = jnp.array([[True, False], [False, True], [True, True],
...                [False, False], [True, False]])
>>> result = binary_jitnmm(w_loc, w_scale, 0.5, B, seed=42,
...                        shape=(3, 5))
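To make the semantics concrete, the following is a dense NumPy reference sketch of the formula for Y[i, k] in the Notes. The helper dense_reference is hypothetical. It materializes W, which the library avoids, and it uses NumPy's generator, so its values will not match the output of binary_jitnmm for the same seed. It only illustrates that summing the columns of W selected by active events equals a dense product with the binarized B.

```python
import numpy as np


def dense_reference(w_loc, w_scale, prob, B, seed, shape, transpose=False):
    """Dense NumPy reference for the semantics of the event-driven product.

    Hypothetical helper: not part of brainevent, and not bit-compatible with it.
    """
    rng = np.random.default_rng(seed)
    mask = rng.random(shape) < prob                    # Bernoulli(prob) connectivity
    weights = rng.normal(w_loc, w_scale, size=shape)   # Normal(w_loc, w_scale) draws
    W = np.where(mask, weights, 0.0)
    if transpose:
        W = W.T
    B = np.asarray(B)
    events = B if B.dtype == bool else B > 0           # True or > 0 counts as active
    # Event-driven form: for each output column k, add up the columns of W
    # selected by the active events in column k of B.
    Y = np.zeros((W.shape[0], events.shape[1]))
    for k in range(events.shape[1]):
        active = np.flatnonzero(events[:, k])
        Y[:, k] = W[:, active].sum(axis=1)
    return Y, W, events


B = np.array([[True, False], [False, True], [True, True],
              [False, False], [True, False]])
Y, W, events = dense_reference(1.0, 0.1, 0.5, B, seed=42, shape=(3, 5))
print(Y.shape)                                         # (3, 2)
print(np.allclose(Y, W @ events.astype(float)))        # True
```

Because inactive entries are skipped rather than multiplied, the work per output column scales with the number of active events in that column of B.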