brainevent.jitnmv

brainevent.jitnmv = <NameScope(brainevent.jitnmv)>

JIT normally-distributed matrix-vector product.

Computes W @ v (or W.T @ v when transpose is True), where W is a random matrix whose entries are drawn from Normal(w_loc, w_scale) and masked by Bernoulli(prob), without materialising W.

Parameters:
  • w_loc (Array | ndarray | Quantity | Number) – Mean of the normal weight distribution.

  • w_scale (Array | ndarray | Quantity | Number) – Standard deviation of the normal weight distribution. Must share units with w_loc.

  • prob (float) – Connection probability in [0, 1].

  • vector (Array | ndarray | Quantity | Number) – Input vector of shape (k,) where k equals shape[0] when transpose is True, or shape[1] otherwise.

  • seed (int | None) – RNG seed. None generates a random seed.

  • shape (Tuple[int, int]) – Logical matrix shape (n_pre, n_post).

  • transpose (bool) – If True, multiply by the transpose. Default is False.

  • corder (bool) – Column-major iteration order. Default is True.

  • backend (str | None) – Compute backend.

Returns:

Output vector. Shape is (shape[1],) when transpose is True, or (shape[0],) otherwise.

Return type:

Array | ndarray | Quantity | Number

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • ValueError – If vector is not 1-D or its length does not match the matrix shape.
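The documented checks can be restated as a short NumPy sketch. `check_jitnmv_args` is a hypothetical helper, not part of brainevent; it only mirrors the conditions listed under Raises and returns the output length implied by shape and transpose.

```python
import math
from typing import Tuple

import numpy as np


def check_jitnmv_args(prob, vector, shape: Tuple[int, int], transpose: bool = False) -> int:
    """Hypothetical re-statement of the documented checks; returns the output length."""
    # prob must be a finite scalar in [0, 1]
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    prob = float(prob)
    if not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be a finite value in [0, 1], got {prob}")

    # vector must be 1-D with length shape[0] (transpose) or shape[1] (no transpose)
    vector = np.asarray(vector)
    if vector.ndim != 1:
        raise ValueError(f"vector must be 1-D, got {vector.ndim}-D")
    n_pre, n_post = shape
    expected = n_pre if transpose else n_post
    if vector.shape[0] != expected:
        raise ValueError(f"vector length {vector.shape[0]} does not match expected {expected}")

    # Output length: shape[1] when transposed, shape[0] otherwise
    return n_post if transpose else n_pre
```

For example, `check_jitnmv_args(0.1, np.ones(50), (100, 50))` returns 100, while passing `np.ones(100)` with the same shape and `transpose=False` raises a ValueError.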

See also

  • jitn – Materialise the full matrix.

  • jitnmm – Matrix-matrix multiply variant.

Notes

The connectivity matrix W has shape (m, n) = shape = (n_pre, n_post) and follows the model:

W[i, j] = N(w_loc, w_scale) * B[i, j]

where N(w_loc, w_scale) is a normal draw and B[i, j] ~ Bernoulli(prob) is a binary mask, both determined by seed.
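As a reference for the model only, the following NumPy sketch materialises W explicitly. `dense_reference` is a hypothetical helper; brainevent uses its own random number generator and generation order (see corder), so the values will not match those used by jitnmv, only the distribution they are drawn from.

```python
import numpy as np


def dense_reference(w_loc, w_scale, prob, shape, seed=None):
    """Materialise W[i, j] = N(w_loc, w_scale) * B[i, j] with NumPy (reference only)."""
    rng = np.random.default_rng(seed)
    weights = rng.normal(loc=w_loc, scale=w_scale, size=shape)  # one normal draw per entry
    mask = rng.random(size=shape) < prob                          # Bernoulli(prob) mask
    return weights * mask


W = dense_reference(0.0, 1.0, 0.1, (100, 50), seed=42)
print(W.shape, np.count_nonzero(W) / W.size)  # roughly 10% of entries are non-zero
```

This uses O(m * n) memory, which is exactly what jitnmv avoids; it is useful only for checking the statistics of the model on small shapes.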

The matrix-vector product computes:

y[i] = sum_{j=0}^{n-1} W[i, j] * v[j]

When transpose=True, the operation becomes y = W^T @ v:

y[j] = sum_{i=0}^{m-1} W[i, j] * v[i]
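The two summation formulas can be checked directly with explicit loops. `matvec_loops` is a hypothetical illustration that takes an already materialised W; it shows that the transposed formula is the same as W.T @ v.

```python
import numpy as np


def matvec_loops(W, v, transpose=False):
    """Evaluate the summation formulas above with explicit loops."""
    m, n = W.shape
    if not transpose:
        # y[i] = sum_{j=0}^{n-1} W[i, j] * v[j]
        y = np.zeros(m)
        for i in range(m):
            for j in range(n):
                y[i] += W[i, j] * v[j]
    else:
        # y[j] = sum_{i=0}^{m-1} W[i, j] * v[i]
        y = np.zeros(n)
        for j in range(n):
            for i in range(m):
                y[j] += W[i, j] * v[i]
    return y


W = np.arange(6.0).reshape(2, 3)
print(matvec_loops(W, [1.0, 2.0, 3.0]))          # [ 8. 26.]
print(matvec_loops(W, [1.0, 1.0], transpose=True))  # [3. 5. 7.]
```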

The matrix is never materialised; weights are generated and consumed on the fly, avoiding O(m * n) memory.
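The on-the-fly idea can be sketched in NumPy by regenerating one row of W at a time from a deterministic per-row seed. This is an illustration under assumptions, not brainevent's kernel: the helper names `row_weights` and `onthefly_matvec` are hypothetical, and seeding each row with `(seed, i)` is a scheme chosen for this sketch. It keeps only O(n) scratch memory for the current row instead of the full O(m * n) matrix.

```python
import numpy as np


def row_weights(i, n, w_loc, w_scale, prob, seed):
    """Regenerate row i of W from (seed, i); the same call always yields the same row."""
    rng = np.random.default_rng([seed, i])  # reproducible stream per row (assumed scheme)
    weights = rng.normal(w_loc, w_scale, size=n)
    mask = rng.random(size=n) < prob
    return weights * mask


def onthefly_matvec(w_loc, w_scale, prob, vector, seed, shape, transpose=False):
    m, n = shape
    vector = np.asarray(vector, dtype=float)
    out = np.zeros(n if transpose else m)
    for i in range(m):
        row = row_weights(i, n, w_loc, w_scale, prob, seed)  # only the current row is stored
        if transpose:
            out += row * vector[i]  # y[j] += W[i, j] * v[i]
        else:
            out[i] = row @ vector   # y[i] = sum_j W[i, j] * v[j]
    return out


y = onthefly_matvec(0.0, 1.0, 0.1, np.ones(50), seed=42, shape=(100, 50))
print(y.shape)  # (100,)
```

Because every row is regenerated from the same (seed, i) pair, the forward and transposed products see the same W, which is what keeps W @ v and W.T @ v consistent without ever storing W.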

Examples

>>> import jax.numpy as jnp
>>> import brainevent
>>> v = jnp.ones(50)
>>> y = brainevent.jitnmv(0.0, 1.0, 0.1, v, seed=42, shape=(100, 50))
>>> y.shape
(100,)