brainevent.jitn

brainevent.jitn = <NameScope(brainevent.jitn)>

Materialise a JIT connectivity matrix with normally distributed random weights.

Generates a dense matrix of the given shape (or its transpose when transpose is True) in which each element is drawn from Normal(w_loc, w_scale) and then independently kept with probability prob (Bernoulli masking); dropped elements are zero.

Parameters:
  • w_loc (Array | ndarray | Quantity | Number) – Mean of the normal weight distribution.

  • w_scale (Array | ndarray | Quantity | Number) – Standard deviation of the normal weight distribution. Must have the same physical unit as w_loc.

  • prob (float) – Connection probability in [0, 1].

  • seed (int) – RNG seed for reproducible connectivity.

  • shape (Tuple[int, int]) – Logical matrix shape (n_pre, n_post).

  • transpose (bool) – If True, return the transposed matrix of shape (n_post, n_pre). Default is False.

  • corder (bool) – If True (default), iterate in column-major order internally.

  • backend (str | None) – Compute backend (e.g. 'numba', 'pallas').

Returns:

Dense matrix with the given shape (n_pre, n_post), or shape[::-1] = (n_post, n_pre) when transpose is True.
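The shape rule can be illustrated with a small NumPy stand-in. This is a sketch that assumes, as the transpose parameter description suggests, that transpose=True returns exactly the transpose of the matrix produced with transpose=False for the same seed. `reference_jitn` is a hypothetical helper, not part of brainevent, and it does not reproduce brainevent's random stream.

```python
import numpy as np


def reference_jitn(w_loc, w_scale, prob, seed, shape, transpose=False):
    """Hypothetical NumPy stand-in for jitn (not brainevent's RNG or kernel)."""
    rng = np.random.default_rng(seed)
    weights = rng.normal(w_loc, w_scale, size=shape)  # Normal(w_loc, w_scale) draws
    mask = rng.random(size=shape) < prob              # Bernoulli(prob) mask
    W = weights * mask
    # The logical matrix is always generated as (n_pre, n_post);
    # transpose only changes the orientation of the returned array.
    return W.T if transpose else W


W = reference_jitn(0.0, 1.0, prob=0.1, seed=42, shape=(100, 50))
W_t = reference_jitn(0.0, 1.0, prob=0.1, seed=42, shape=(100, 50), transpose=True)
print(W.shape, W_t.shape)  # (100, 50) (50, 100)
```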

Return type:

Array | ndarray | Quantity | Number

Raises:

ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
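The three documented failure conditions can be mirrored by a simple validator. This is a sketch only: `check_prob` is a hypothetical helper written for illustration, and brainevent's own check may differ in its messages and in the order of its tests.

```python
import numpy as np


def check_prob(prob):
    """Hypothetical validator mirroring the documented ValueError conditions."""
    # Not a scalar: lists or arrays with at least one dimension.
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    p = float(prob)
    # Not finite: NaN or +/-inf.
    if not np.isfinite(p):
        raise ValueError(f"prob must be finite, got {p}")
    # Outside the closed interval [0, 1].
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"prob must be in [0, 1], got {p}")
    return p


print(check_prob(0.1))  # 0.1
try:
    check_prob(1.5)
except ValueError as err:
    print(err)  # prob must be in [0, 1], got 1.5
```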

See also

jitnmv

Matrix-vector multiply without materialising the matrix.

jitnmm

Matrix-matrix multiply without materialising the matrix.

jits

Scalar-weight variant (all non-zeros share one weight).

jitu

Variant with uniformly distributed weights.

Notes

Each entry W[i, j] of the generated matrix follows the model:

W[i, j] = N(w_loc, w_scale) * B[i, j]

where N(w_loc, w_scale) is a draw from a normal distribution with mean w_loc and standard deviation w_scale, and B[i, j] ~ Bernoulli(prob) is an independent binary mask. Equivalently:

  • W[i, j] ~ Normal(w_loc, w_scale) with probability prob

  • W[i, j] = 0 with probability 1 - prob

The expected value of each entry is E[W[i, j]] = prob * w_loc.
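The expectation can be sanity-checked by Monte Carlo sampling of the model above. The sketch below uses NumPy as an illustration rather than brainevent's sampling code, and `sample_entries` is a hypothetical helper. It also checks the variance Var[W[i, j]] = prob * (w_scale**2 + w_loc**2) - (prob * w_loc)**2, which follows from the same model but is not stated in this section.

```python
import numpy as np


def sample_entries(w_loc, w_scale, prob, n, seed=0):
    """Draw n independent entries W = N * B (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    normal = rng.normal(w_loc, w_scale, size=n)  # N ~ Normal(w_loc, w_scale)
    bern = rng.random(size=n) < prob             # B ~ Bernoulli(prob)
    return normal * bern                         # W = N * B


w_loc, w_scale, prob = 2.0, 0.5, 0.1
w = sample_entries(w_loc, w_scale, prob, n=1_000_000)

# Model predictions: mean 0.2, fraction of non-zeros 0.1, variance 0.385.
expected_mean = prob * w_loc
expected_var = prob * (w_scale**2 + w_loc**2) - (prob * w_loc) ** 2
print(w.mean(), expected_mean)
print((w != 0).mean(), prob)
print(w.var(), expected_var)
```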

For a given shape, the connectivity pattern and the normal variates are fully determined by seed and prob: repeating a call with identical arguments, including seed, always produces the same matrix.
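One way to get this kind of reproducibility is to derive both the mask and the standard-normal variates from a single seeded generator and apply w_loc and w_scale only afterwards. The NumPy sketch below assumes that design. `reference_jitn` is hypothetical and is not brainevent's implementation. In this model, changing w_loc or w_scale rescales the non-zero values but leaves the pattern unchanged.

```python
import numpy as np


def reference_jitn(w_loc, w_scale, prob, seed, shape):
    """Hypothetical stand-in: mask and variates depend only on (seed, prob, shape)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(shape)  # standard normal variates, fixed by seed
    u = rng.random(shape)           # uniforms; u < prob gives Bernoulli(prob)
    mask = u < prob
    # Location and scale are applied after sampling, so they never alter the mask.
    return np.where(mask, w_loc + w_scale * z, 0.0)


A = reference_jitn(0.0, 1.0, prob=0.1, seed=42, shape=(100, 50))
B = reference_jitn(0.0, 1.0, prob=0.1, seed=42, shape=(100, 50))
C = reference_jitn(0.0, 1.0, prob=0.1, seed=43, shape=(100, 50))
D = reference_jitn(5.0, 2.0, prob=0.1, seed=42, shape=(100, 50))

print(np.array_equal(A, B))  # True: same seed and arguments, same matrix
print(np.array_equal(A, C))  # a different seed gives a different matrix
print(np.array_equal(A != 0, D != 0))  # same seed and prob, same pattern
```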

This function materialises the full dense matrix. For implicit (non-materialised) products, use jitnmv() or jitnmm().
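The difference between materialising and implicit products can be sketched by regenerating each row from the seed whenever it is needed, so that W @ v is computed without ever storing W. The NumPy code below illustrates that idea under stated assumptions. It is not how brainevent's jitnmv is implemented, which may use a different random stream, iteration order, or compiled kernel. The per-row seeding from (seed, i) and the helper names `generate_row`, `materialise`, and `implicit_matvec` are all hypothetical.

```python
import numpy as np


def generate_row(i, w_loc, w_scale, prob, seed, n_post):
    # Hypothetical per-row stream: seeding from (seed, i) lets any row be
    # regenerated on demand without touching the other rows.
    rng = np.random.default_rng([seed, i])
    z = rng.standard_normal(n_post)
    keep = rng.random(n_post) < prob
    return np.where(keep, w_loc + w_scale * z, 0.0)


def materialise(w_loc, w_scale, prob, seed, shape):
    # Stores all n_pre * n_post entries, as a dense jitn-style call does.
    n_pre, n_post = shape
    return np.stack([generate_row(i, w_loc, w_scale, prob, seed, n_post)
                     for i in range(n_pre)])


def implicit_matvec(w_loc, w_scale, prob, seed, shape, v):
    # Computes W @ v one regenerated row at a time; only one row exists at once.
    n_pre, n_post = shape
    y = np.empty(n_pre)
    for i in range(n_pre):
        y[i] = generate_row(i, w_loc, w_scale, prob, seed, n_post) @ v
    return y


v = np.random.default_rng(0).standard_normal(40)
W = materialise(0.0, 1.0, 0.2, seed=7, shape=(30, 40))
y = implicit_matvec(0.0, 1.0, 0.2, seed=7, shape=(30, 40), v=v)
print(np.allclose(W @ v, y))  # True
```

The trade-off is recomputation for memory. The implicit version regenerates the random numbers on every product, but it never needs more than one row of storage.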

Examples

>>> import brainevent
>>> W = brainevent.jitn(0.0, 1.0, prob=0.1, seed=42, shape=(100, 50))
>>> W.shape
(100, 50)