brainevent.jitnmv
- brainevent.jitnmv = <NameScope(brainevent.jitnmv)>
JIT normally-distributed matrix-vector product.
Computes W @ v (or W.T @ v), where W is a random matrix whose entries are drawn from Normal(w_loc, w_scale) and masked by Bernoulli(prob), without materialising W.
- Parameters:
  w_loc (Array | ndarray | Quantity | Number) – Mean of the normal weight distribution.
  w_scale (Array | ndarray | Quantity | Number) – Standard deviation of the normal weight distribution. Must share units with w_loc.
  prob (float) – Connection probability in [0, 1].
  vector (Array | ndarray | Quantity | Number) – Input vector of shape (k,), where k equals shape[0] when transpose is True, or shape[1] otherwise.
  shape (Tuple[int, int]) – Logical matrix shape (n_pre, n_post).
  transpose (bool) – If True, multiply by the transpose of W. Default is False.
  corder (bool) – If True, use column-major iteration order. Default is True.
- Returns:
  Output vector of shape (shape[1],) when transpose is True, or (shape[0],) otherwise.
- Return type:
  Array | ndarray | Quantity | Number
- Raises:
  ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
  ValueError – If vector is not 1-D or its length does not match the matrix shape.
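The error conditions listed under Raises can be mirrored in a few lines of NumPy. This is a minimal sketch, not brainevent's own validation code: the helper name check_jitnmv_inputs is hypothetical, and it additionally returns the output shape implied by the Returns entry.

```python
import math

import numpy as np


def check_jitnmv_inputs(prob, vector, shape, transpose=False):
    """Hypothetical helper mirroring the documented ValueError conditions.

    Returns the expected output shape when all checks pass.
    """
    # prob must be a scalar, finite, and inside [0, 1].
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    prob = float(prob)
    if not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be finite and in [0, 1], got {prob}")

    # vector must be 1-D, with length shape[0] (transpose) or shape[1] (no transpose).
    if np.ndim(vector) != 1:
        raise ValueError(f"vector must be 1-D, got ndim={np.ndim(vector)}")
    n_pre, n_post = shape
    expected = n_pre if transpose else n_post
    if np.shape(vector)[0] != expected:
        raise ValueError(
            f"vector length {np.shape(vector)[0]} does not match expected {expected}"
        )

    # Output length is shape[1] when transposed, shape[0] otherwise.
    return (n_post,) if transpose else (n_pre,)
```

For instance, check_jitnmv_inputs(0.1, np.ones(50), (100, 50)) passes and returns (100,), while a prob of 1.5 or a vector of length 49 raises ValueError.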
Notes
The connectivity matrix W of shape (m, n), with (m, n) = shape, follows the model:

    W[i, j] = N(w_loc, w_scale) * B[i, j]

where N(w_loc, w_scale) is a normal draw and B[i, j] ~ Bernoulli(prob) is a binary mask, both determined by seed. The matrix-vector product computes:

    y[i] = sum_{j=0}^{n-1} W[i, j] * v[j]

When transpose=True, the operation becomes y = W^T @ v:

    y[j] = sum_{i=0}^{m-1} W[i, j] * v[i]

The matrix is never materialised: weights are generated and consumed on the fly, which avoids the O(m * n) memory a dense W would require.

Examples
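For small sizes, the model described in the Notes can be checked against an explicit dense construction. The sketch below is illustrative only: dense_jitnmv_reference is a hypothetical helper, and because it draws from NumPy's default_rng rather than brainevent's generator, it reproduces the distribution of W but not the exact matrix jitnmv builds for a given seed. It also allocates the full (m, n) array that jitnmv avoids.

```python
import numpy as np


def dense_jitnmv_reference(w_loc, w_scale, prob, vector, seed, shape, transpose=False):
    """Hypothetical dense reference for the jitnmv model (uses O(m * n) memory)."""
    m, n = shape
    rng = np.random.default_rng(seed)
    # Normal weights N(w_loc, w_scale): one draw per matrix entry.
    weights = rng.normal(loc=w_loc, scale=w_scale, size=(m, n))
    # Bernoulli(prob) mask B[i, j]: True with probability prob.
    mask = rng.random((m, n)) < prob
    W = weights * mask  # W[i, j] = N(w_loc, w_scale) * B[i, j]
    v = np.asarray(vector, dtype=float)
    # transpose=True: y[j] = sum_i W[i, j] * v[i]; otherwise y[i] = sum_j W[i, j] * v[j].
    return W.T @ v if transpose else W @ v
```

Calling dense_jitnmv_reference(0.0, 1.0, 0.1, np.ones(50), seed=42, shape=(100, 50)) returns an array of shape (100,), the same output shape as the jitnmv example that follows.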
>>> import jax.numpy as jnp
>>> from brainevent._jit_normal.float import jitnmv
>>> v = jnp.ones(50)
>>> y = jitnmv(0.0, 1.0, 0.1, v, seed=42, shape=(100, 50))
>>> y.shape
(100,)
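The memory saving comes from regenerating weights from the seed instead of storing them. The following sketch shows one way to do that, assuming a per-row NumPy generator seeded with (seed, i). brainevent's kernels use their own random-number scheme (and the corder option), so this illustrates the idea rather than the library's implementation; the helper name streaming_jitnmv is hypothetical.

```python
import numpy as np


def streaming_jitnmv(w_loc, w_scale, prob, vector, seed, shape, transpose=False):
    """Hypothetical on-the-fly mat-vec: rebuild each row of W, use it, discard it.

    Peak extra memory is one row (O(n)) instead of the O(m * n) a dense W needs.
    """
    m, n = shape
    v = np.asarray(vector, dtype=float)
    y = np.zeros(n if transpose else m)
    for i in range(m):
        # Deterministic per-row stream: the same (seed, i) always rebuilds the same row.
        rng = np.random.default_rng([seed, i])
        weights = rng.normal(w_loc, w_scale, size=n)  # N(w_loc, w_scale) draws
        mask = rng.random(n) < prob                   # Bernoulli(prob) mask B[i, :]
        row = weights * mask                          # W[i, :]
        if transpose:
            y += row * v[i]  # y[j] += W[i, j] * v[i]
        else:
            y[i] = row @ v   # y[i] = sum_j W[i, j] * v[j]
    return y
```

Because row i is always rebuilt from the same (seed, i) pair, the forward and transposed calls describe the same W without ever storing it: for any u and v, u @ streaming_jitnmv(..., v, ...) equals streaming_jitnmv(..., u, ..., transpose=True) @ v up to rounding.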