brainevent.binary_jitnmm

brainevent.binary_jitnmm(w_loc, w_scale, prob, B, seed=None, *, shape, transpose=False, corder=True, backend=None)

Event-driven matrix-matrix multiplication with a JIT normal-distributed connectivity matrix.

Computes M @ B, where M is a sparse matrix whose non-zero entries are drawn from a normal distribution with mean w_loc and standard deviation w_scale, and B is a 2-D binary event matrix. Only active entries of B (True or > 0) contribute to the output.
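For small sizes, the semantics can be checked against a dense reference. The sketch below is an illustration only: it assumes an explicit weight matrix W (which binary_jitnmm never builds), uses the hypothetical helper name dense_binary_mm, and does not reproduce brainevent's random stream, so its values will not match binary_jitnmm entry for entry.

```python
import numpy as np


def dense_binary_mm(W, B, transpose=False):
    """Dense reference for the event-driven product (hypothetical helper).

    W: explicit (rows, cols) weight matrix, zero where unconnected.
    B: 2-D event matrix; True or > 0 counts as an event.
    """
    W = np.asarray(W, dtype=float)
    B = np.asarray(B)
    # Binarize: True / > 0 -> 1.0, False / <= 0 -> 0.0.
    events = (B > 0).astype(float)
    M = W.T if transpose else W
    if M.shape[1] != events.shape[0]:
        raise ValueError(f"incompatible shapes {M.shape} and {events.shape}")
    return M @ events


W = np.array([[0.9, 0.0, 1.1],
              [0.0, 1.2, 0.0]])
B = np.array([[True, False],
              [False, True],
              [True, True]])
print(dense_binary_mm(W, B))  # values [[2.0, 1.1], [0.0, 1.2]]
```

With transpose=True the same helper computes W.T @ B, so B must then have W.shape[0] rows.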

Parameters:
  • w_loc (Array | ndarray | Quantity | Number) – Location (mean) parameter of the normal distribution for the matrix weights. Scalar or 1-element array, optionally with physical units.

  • w_scale (Array | ndarray | Quantity | Number) – Scale (standard deviation) parameter of the normal distribution for the matrix weights. Must have the same physical dimension as w_loc.

  • prob (float) – Connection probability in the range [0, 1]. Controls the sparsity of the generated connectivity matrix.

  • B (Array | ndarray | Quantity | Number) – The binary event matrix to multiply with, shape (k, n). Each element is treated as a binary event: True or > 0 counts as an event, False or <= 0 does not.

  • seed (int | None) – Random seed for reproducible matrix generation. If None, a random seed is generated at compile time.

  • shape (Tuple[int, int]) – Shape of the implicit connectivity matrix as (rows, cols). B.shape[0] must equal shape[1] (or shape[0] when transpose=True).

  • transpose (bool) – If True, compute M.T @ B instead of M @ B. Default is False.

  • corder (bool) – Memory layout order for kernel dispatch. True for C-order (row-major), False for Fortran-order (column-major). Default is True.

  • backend (str | None) – Compute backend to use ('numba', 'pallas', or None for automatic selection).

Returns:

The product matrix, with shape (shape[0], B.shape[1]), or (shape[1], B.shape[1]) when transpose=True. If the inputs carry physical units, the output unit is the product of the weight unit and the unit of B.

Return type:

Array | ndarray | Quantity | Number

Raises:

brainunit.DimensionMismatchError – If w_loc and w_scale do not have the same physical dimension.

See also

  • binary_jitnmv – Event-driven matrix-vector multiplication variant.

  • jitnmm – Float (non-event) matrix-matrix multiplication with normal weights.

Notes

The connectivity matrix W is never materialized in memory. The pseudo-random structure is regenerated on-the-fly using the seed and prob parameters, matching the PRNG sequence used by jitn.
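The matrix-free idea can be sketched in NumPy by rebuilding each row of W from the seed and the row index whenever it is needed, then discarding it. This is a minimal sketch under stated assumptions: the per-row seeding via np.random.default_rng([seed, i]) and the helper names regenerate_row and matrix_free_binary_mm are hypothetical, the Bernoulli draw is made per entry instead of with the clen-based stride scheme, and the values do not match brainevent's PRNG sequence.

```python
import numpy as np


def regenerate_row(seed, i, n_cols, w_loc, w_scale, prob):
    """Rebuild row i of the implicit matrix from (seed, i) alone.

    Hypothetical seeding scheme, NOT brainevent's PRNG: each row gets its
    own NumPy generator keyed by [seed, i], so it can be recreated on demand.
    """
    rng = np.random.default_rng([seed, i])
    mask = rng.random(n_cols) < prob              # Bernoulli(prob) per entry
    weights = rng.normal(w_loc, w_scale, n_cols)  # Normal(w_loc, w_scale) per entry
    return np.where(mask, weights, 0.0)


def matrix_free_binary_mm(w_loc, w_scale, prob, B, seed, shape):
    """Compute W @ (B > 0) row by row without ever storing W."""
    n_rows, n_cols = shape
    events = (np.asarray(B) > 0).astype(float)    # binary events, shape (n_cols, n)
    out = np.zeros((n_rows, events.shape[1]))
    for i in range(n_rows):
        # Row i is regenerated, used once, and discarded.
        out[i] = regenerate_row(seed, i, n_cols, w_loc, w_scale, prob) @ events
    return out


B = np.array([[1, 0], [0, 1], [1, 1], [0, 0], [1, 0]], dtype=bool)
Y = matrix_free_binary_mm(1.0, 0.1, 0.5, B, seed=42, shape=(3, 5))
print(Y.shape)  # (3, 2)
```

Because every row depends only on (seed, i), repeated calls with the same seed give identical results while memory use stays proportional to one row.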

The implicit weight matrix has entries:

W[i, j] = Normal(w_loc, w_scale) * Bernoulli(prob)

where Normal(w_loc, w_scale) is an independent draw for each non-zero position, and Bernoulli(prob) is 1 with probability prob and 0 otherwise.
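To see what this definition implies statistically, W can be materialized for a small case and its empirical density and weight moments compared with prob, w_loc, and w_scale. This sketch uses NumPy's generator, not brainevent's PRNG, and the helper name sample_dense_W is hypothetical.

```python
import numpy as np


def sample_dense_W(w_loc, w_scale, prob, shape, seed=0):
    """Materialize W[i, j] = Normal(w_loc, w_scale) * Bernoulli(prob).

    Illustration only: uses NumPy's generator, not brainevent's PRNG.
    """
    rng = np.random.default_rng(seed)
    normal = rng.normal(w_loc, w_scale, size=shape)  # one draw per position
    bernoulli = rng.random(shape) < prob             # True with probability prob
    return normal * bernoulli


W = sample_dense_W(1.0, 0.1, 0.2, shape=(400, 500), seed=0)
nonzero = W[W != 0]
print(nonzero.size / W.size)          # close to prob = 0.2
print(nonzero.mean(), nonzero.std())  # close to w_loc = 1.0 and w_scale = 0.1
```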

The event-driven matrix-matrix product computes:

Y[i, l] = sum_{j in C(i)} N_ij * e[j, l]

where C(i) = {j : Bernoulli_ij = 1} is the set of connected pre-synaptic indices for post-synaptic neuron i, N_ij ~ Normal(w_loc, w_scale) is the connection weight, and e[j, l] = 1 if B[j, l] is active (True or > 0) and 0 otherwise. Each column l of B is processed independently, so each output entry is simply the sum of the weights N_ij over the connected rows j that carry an event.
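The event-driven character of this sum can be sketched as follows: for each column of B, find the rows j that carry an event and add the corresponding columns of W, so inactive entries cost nothing and no multiplication by B is needed. W is an explicit dense array here purely for illustration (binary_jitnmm generates it implicitly), and event_driven_mm is a hypothetical name.

```python
import numpy as np


def event_driven_mm(W, B):
    """Event-driven W @ B: sum the columns W[:, j] for every active B[j, l].

    W is an explicit dense matrix here for illustration only; zeros mark
    unconnected positions, so they drop out of the sum automatically.
    """
    W = np.asarray(W, dtype=float)
    B = np.asarray(B)
    Y = np.zeros((W.shape[0], B.shape[1]))
    for l in range(B.shape[1]):               # each column of B independently
        active = np.flatnonzero(B[:, l] > 0)  # rows j with an event (True or > 0)
        Y[:, l] = W[:, active].sum(axis=1)    # add weights; no multiply by B needed
    return Y


rng = np.random.default_rng(1)
W = rng.normal(1.0, 0.1, size=(4, 6)) * (rng.random((4, 6)) < 0.5)
B = rng.random((6, 3)) < 0.4
# Agrees with a dense product against the binarized events.
print(np.allclose(event_driven_mm(W, B), W @ B.astype(float)))  # True
```

The work per column scales with the number of active events rather than with B.shape[0], which is the payoff of the event-driven formulation when spikes are sparse.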

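The connection length parameter clen = 2 / prob controls the random stride between consecutive non-zero entries; for the expected density to equal prob, the average stride must be about 1 / prob, i.e. clen / 2.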
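A minimal sketch of stride-based sampling, assuming (this page does not specify it) that each stride is an integer drawn uniformly from [1, clen) with clen = ceil(2 / prob). Under that assumption the mean stride is clen / 2 ≈ 1 / prob, so roughly a fraction prob of the positions in a row are connected. The function name sample_connected_indices is hypothetical.

```python
import math

import numpy as np


def sample_connected_indices(n, prob, rng):
    """Indices of non-zero entries in a length-n row, found by random skips.

    Assumed scheme (not confirmed for brainevent): strides are integers drawn
    uniformly from [1, clen), with clen = ceil(2 / prob), so the mean stride
    is roughly 1 / prob.
    """
    clen = max(2, math.ceil(2 / prob))  # exclusive upper bound on the stride
    idx = []
    j = int(rng.integers(0, clen))      # random starting offset
    while j < n:
        idx.append(j)
        j += int(rng.integers(1, clen))  # jump straight to the next connection
    return np.array(idx, dtype=int)


rng = np.random.default_rng(0)
idx = sample_connected_indices(100_000, 0.1, rng)
print(idx.size / 100_000)   # close to prob = 0.1
print(np.diff(idx).mean())  # close to 1 / prob = 10
```

Skipping directly between connections means the cost per row is proportional to the number of connections, not to the row length.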

This operation supports automatic differentiation (JVP and transpose rules) with respect to w_loc, w_scale, and B. Under batching (e.g. jax.vmap), B may be batched along axis 0, 1, or 2.

Examples

>>> import jax.numpy as jnp
>>> import brainevent
>>> w_loc = jnp.array([1.0])
>>> w_scale = jnp.array([0.1])
>>> B = jnp.array([[True, False], [False, True], [True, True],
...                [False, False], [True, False]])
>>> result = brainevent.binary_jitnmm(w_loc, w_scale, 0.5, B, seed=42,
...                                   shape=(3, 5))
>>> result.shape
(3, 2)