brainevent.jitnmm

brainevent.jitnmm = <NameScope(brainevent.jitnmm)>

JIT normally distributed matrix-matrix product.

Computes W @ B (or W.T @ B when transpose is True), where W is a random matrix whose entries are drawn from Normal(w_loc, w_scale) and masked by Bernoulli(prob), without materialising W.

Parameters:
  • w_loc (Array | ndarray | Quantity | Number) – Mean of the normal weight distribution.

  • w_scale (Array | ndarray | Quantity | Number) – Standard deviation. Must share units with w_loc.

  • prob (float) – Connection probability in [0, 1].

  • B (Array | ndarray | Quantity | Number) – Right-hand matrix of shape (k, n) where k equals shape[0] when transpose is True, or shape[1] otherwise.

  • seed (int | None) – RNG seed. None generates a random seed.

  • shape (Tuple[int, int]) – Logical matrix shape (n_pre, n_post).

  • transpose (bool) – If True, multiply by the transpose. Default is False.

  • corder (bool) – Column-major iteration order. Default is True.

  • backend (str | None) – Compute backend.

Returns:

Output matrix of shape (shape[1], n) when transpose is True, or (shape[0], n) otherwise.

Return type:

Array | ndarray | Quantity | Number
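The shape rules stated for B and for the return value can be summarised in a small helper. This is a hypothetical sketch (`jitnmm_shapes` is not part of brainevent) that only encodes the documented relationships between shape, transpose and B.

```python
from typing import Tuple


def jitnmm_shapes(
    shape: Tuple[int, int], n_cols: int, transpose: bool
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (required shape of B, shape of the output) for W of logical shape `shape`.

    Hypothetical helper: it mirrors the documented shape rules and nothing more.
    """
    n_pre, n_post = shape
    if transpose:
        # Y = W.T @ B: B must have n_pre rows, and Y has n_post rows.
        return (n_pre, n_cols), (n_post, n_cols)
    # Y = W @ B: B must have n_post rows, and Y has n_pre rows.
    return (n_post, n_cols), (n_pre, n_cols)


print(jitnmm_shapes((100, 50), 10, transpose=False))  # ((50, 10), (100, 10))
print(jitnmm_shapes((100, 50), 10, transpose=True))   # ((100, 10), (50, 10))
```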

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • ValueError – If B is not 2-D, or if its leading dimension does not equal shape[0] (transpose=True) or shape[1] (transpose=False).
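The conditions listed under Raises might be checked roughly as follows. `check_jitnmm_inputs` is a hypothetical NumPy sketch written from the documented conditions only; the library's actual checks and error messages may differ.

```python
import math

import numpy as np


def check_jitnmm_inputs(prob, B, shape, transpose=False):
    """Raise ValueError under the conditions documented for jitnmm (hypothetical sketch)."""
    # prob must be a finite scalar in [0, 1].
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    p = float(prob)
    if not math.isfinite(p) or not (0.0 <= p <= 1.0):
        raise ValueError(f"prob must be finite and in [0, 1], got {p}")
    # B must be 2-D, with a leading dimension matching the contracted axis of W.
    B = np.asarray(B)
    if B.ndim != 2:
        raise ValueError(f"B must be 2-D, got ndim={B.ndim}")
    expected = shape[0] if transpose else shape[1]
    if B.shape[0] != expected:
        raise ValueError(f"B.shape[0] must be {expected}, got {B.shape[0]}")


check_jitnmm_inputs(0.1, np.ones((50, 10)), shape=(100, 50))  # passes silently
```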

See also

jitn

Materialise the full matrix.

jitnmv

Matrix-vector multiply variant.

Notes

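The connectivity matrix W has the logical shape (m, n) = shape, i.e. m = n_pre and n = n_post (here n is W's column count, not the number of columns of B, which are indexed by c below). It follows the model:

W[i, j] = N(w_loc, w_scale) * M[i, j]

where N(w_loc, w_scale) is a normal draw with mean w_loc and standard deviation w_scale, and M[i, j] ~ Bernoulli(prob) is a binary connectivity mask; both are determined by seed.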
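For small sizes, the model can be checked against a dense reference that does materialise W. The sketch below uses NumPy's `default_rng` as the random source, which is an assumption made purely for illustration: it reproduces the distribution of the weights, not the specific values brainevent generates for a given seed. `dense_normal_conn` is a hypothetical name.

```python
import numpy as np


def dense_normal_conn(w_loc, w_scale, prob, shape, seed):
    """Materialise W[i, j] = N(w_loc, w_scale) * M[i, j] with M[i, j] ~ Bernoulli(prob).

    Hypothetical reference: NumPy's RNG stands in for brainevent's generator.
    """
    rng = np.random.default_rng(seed)
    weights = rng.normal(loc=w_loc, scale=w_scale, size=shape)  # normal draws
    mask = rng.random(size=shape) < prob                        # Bernoulli(prob) mask
    return weights * mask


W = dense_normal_conn(0.0, 1.0, 0.1, shape=(100, 50), seed=42)
print(W.shape, np.mean(W != 0.0))  # fraction of non-zeros is close to prob = 0.1
```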

The matrix-matrix product computes:

Y[i, c] = sum_{j=0}^{n-1} W[i, j] * B[j, c]
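Evaluated term by term, this sum is an ordinary matrix product, which a quick NumPy check confirms. The small random masked W below merely stands in for the JIT-generated matrix, and `matmul_by_definition` is a hypothetical helper, not library code.

```python
import numpy as np


def matmul_by_definition(W, B):
    """Y[i, c] = sum_j W[i, j] * B[j, c], evaluated with explicit loops."""
    m, n = W.shape
    p = B.shape[1]  # number of columns of B
    Y = np.zeros((m, p))
    for i in range(m):
        for c in range(p):
            Y[i, c] = sum(W[i, j] * B[j, c] for j in range(n))
    return Y


rng = np.random.default_rng(0)
W = rng.normal(size=(6, 4)) * (rng.random((6, 4)) < 0.5)  # small masked normal W
B = rng.normal(size=(4, 3))
print(np.allclose(matmul_by_definition(W, B), W @ B))  # True
```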

When transpose=True, the operation becomes Y = W^T @ B:

Y[j, c] = sum_{i=0}^{m-1} W[i, j] * B[i, c]
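The transposed product can also be accumulated one row of W at a time: row i contributes the outer product of W[i, :] and B[i, :] to Y. This row-streaming form suits weights that are produced row by row; it is offered only as an illustration of the formula, not as a description of brainevent's kernels, and `transpose_matmul_by_rows` is hypothetical.

```python
import numpy as np


def transpose_matmul_by_rows(W, B):
    """Y = W.T @ B computed as the sum over i of outer(W[i, :], B[i, :])."""
    m, n = W.shape
    Y = np.zeros((n, B.shape[1]))
    for i in range(m):
        # Applies Y[j, c] += W[i, j] * B[i, c] for all j and c at once.
        Y += np.outer(W[i, :], B[i, :])
    return Y


rng = np.random.default_rng(1)
W = rng.normal(size=(6, 4)) * (rng.random((6, 4)) < 0.5)
B = rng.normal(size=(6, 3))  # transposed case: B has m = 6 rows
print(np.allclose(transpose_matmul_by_rows(W, B), W.T @ B))  # True
```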

The matrix W is never materialised; weights are generated and consumed on the fly.
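One simple way to avoid materialising W is to regenerate each row deterministically from (seed, i) and consume it immediately, so only a single row of length n exists at any time. The per-row seeding through NumPy's `default_rng([seed, i])` is an assumption chosen for this sketch; brainevent's kernels use their own generation scheme, and `jit_normal_matmul` and `_row` are hypothetical names.

```python
import numpy as np


def _row(w_loc, w_scale, prob, n, seed, i):
    """Regenerate row i of W deterministically from (seed, i)."""
    rng = np.random.default_rng([seed, i])
    return rng.normal(w_loc, w_scale, size=n) * (rng.random(n) < prob)


def jit_normal_matmul(w_loc, w_scale, prob, B, seed, shape, transpose=False):
    """Compute W @ B (or W.T @ B) without ever storing the full W."""
    m, n = shape
    B = np.asarray(B, dtype=float)
    out_rows = n if transpose else m
    Y = np.zeros((out_rows, B.shape[1]))
    for i in range(m):
        w_i = _row(w_loc, w_scale, prob, n, seed, i)  # only this row is held in memory
        if transpose:
            Y += np.outer(w_i, B[i])  # Y[j, c] += W[i, j] * B[i, c]
        else:
            Y[i] = w_i @ B            # Y[i, c] = sum_j W[i, j] * B[j, c]
    return Y


# Check against an explicitly materialised W built from the same rows.
W = np.stack([_row(0.0, 1.0, 0.1, 50, 42, i) for i in range(100)])
B = np.ones((50, 10))
print(np.allclose(jit_normal_matmul(0.0, 1.0, 0.1, B, seed=42, shape=(100, 50)), W @ B))  # True
```

The trade-off is recomputation instead of storage: both the plain and the transposed product regenerate the same rows, so repeated calls with one seed always see the same W.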

Examples

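>>> import jax.numpy as jnp
>>> import brainevent
>>> B = jnp.ones((50, 10))
>>> Y = brainevent.jitnmm(0.0, 1.0, 0.1, B, seed=42, shape=(100, 50))
>>> Y.shape
(100, 10)
>>> Bt = jnp.ones((100, 10))  # transpose=True: B needs shape[0] = 100 rows
>>> brainevent.jitnmm(0.0, 1.0, 0.1, Bt, seed=42, shape=(100, 50), transpose=True).shape
(50, 10)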