brainevent.jitnmm
- brainevent.jitnmm = <NameScope(brainevent.jitnmm)>
JIT normally-distributed matrix-matrix product.
Computes W @ B (or W.T @ B), where W is a random matrix with entries drawn from Normal(w_loc, w_scale) and masked by Bernoulli(prob), without materialising W.
- Parameters:
  - w_loc (Array | ndarray | Quantity | Number) – Mean of the normal weight distribution.
  - w_scale (Array | ndarray | Quantity | Number) – Standard deviation of the normal weight distribution. Must share units with w_loc.
  - prob (float) – Connection probability in [0, 1].
  - B (Array | ndarray | Quantity | Number) – Right-hand matrix of shape (k, n), where k equals shape[0] when transpose is True, or shape[1] otherwise.
  - shape (Tuple[int, int]) – Logical matrix shape (n_pre, n_post).
  - transpose (bool) – If True, compute W.T @ B instead of W @ B. Default is False.
  - corder (bool) – Column-major iteration order. Default is True.
- Returns:
  Output matrix of shape (shape[1], n) when transpose is True, or (shape[0], n) otherwise, where n is the number of columns of B.
- Return type:
  Array | ndarray | Quantity | Number
- Raises:
  - ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
  - ValueError – If B is not 2-D, or if its leading dimension does not equal shape[0] (transpose=True) or shape[1] (transpose=False).
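The shape rules and error conditions above can be restated as a small pure-Python check. This is a hypothetical helper (jitnmm_output_shape is not part of brainevent), written only to make the documented rules concrete; the error messages are invented, and it ignores units, seed and corder.

```python
import math
from typing import Tuple


def jitnmm_output_shape(
    shape: Tuple[int, int],
    b_shape: Tuple[int, ...],
    prob: float,
    transpose: bool = False,
) -> Tuple[int, int]:
    """Hypothetical helper: restates the documented checks and shape rules.

    Not brainevent code; the real jitnmm may validate in a different order
    or with different messages.
    """
    # prob must be a finite scalar in [0, 1].
    if not isinstance(prob, (int, float)):
        raise ValueError(f"prob must be a scalar, got {type(prob).__name__}")
    if not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be finite and in [0, 1], got {prob}")

    # B must be 2-D.
    if len(b_shape) != 2:
        raise ValueError(f"B must be 2-D, got {len(b_shape)}-D")

    n_pre, n_post = shape
    k, n = b_shape
    # Leading dimension of B: shape[0] for W.T @ B, shape[1] for W @ B.
    expected_k = n_pre if transpose else n_post
    if k != expected_k:
        raise ValueError(f"B has {k} rows, expected {expected_k}")

    # Output rows: shape[1] for W.T @ B, shape[0] for W @ B.
    return (n_post, n) if transpose else (n_pre, n)
```

For example, shape=(100, 50) with transpose=False accepts a B of shape (50, 10) and yields a (100, 10) result, matching the doctest in the Examples section; with transpose=True the same shape requires B to have 100 rows and yields 50 output rows.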
Notes
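The connectivity matrix W has shape (m, n), where m = shape[0] and n = shape[1] (in this section n is therefore not the number of columns of B). It follows the model

$$W_{ij} = N_{ij}(w_{\mathrm{loc}}, w_{\mathrm{scale}}) \cdot B^{\mathrm{mask}}_{ij}$$

where $N_{ij}(w_{\mathrm{loc}}, w_{\mathrm{scale}})$ is a normal draw and $B^{\mathrm{mask}}_{ij} \sim \mathrm{Bernoulli}(\mathrm{prob})$ is a binary mask, both determined by seed.

The matrix-matrix product computes, for each column c of B:

$$Y_{ic} = \sum_{j=0}^{n-1} W_{ij} \, B_{jc}$$

When transpose=True, the operation becomes $Y = W^{T} B$:

$$Y_{jc} = \sum_{i=0}^{m-1} W_{ij} \, B_{ic}$$

The matrix W is never materialised; weights are generated and consumed on the fly.

Examples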
>>> import jax.numpy as jnp
>>> from brainevent._jit_normal.float import jitnmm
>>> B = jnp.ones((50, 10))
>>> Y = jitnmm(0.0, 1.0, 0.1, B, seed=42, shape=(100, 50))
>>> Y.shape
(100, 10)
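The model in the Notes section, and the idea of consuming W without storing it, can be sketched in JAX as follows. This is a minimal sketch under stated assumptions, not brainevent's kernel: the per-row key scheme (jax.random.fold_in on the row index, then split into a weight key and a mask key) is an assumption chosen for illustration, so its values will not match jitnmm(..., seed=42), and corder is not modelled. The names dense_reference, streaming and _row_of_w are hypothetical.

```python
import jax
import jax.numpy as jnp


def _row_of_w(key, i, w_loc, w_scale, prob, n_post):
    # Hypothetical per-row generator (assumed key scheme, not brainevent's):
    # row i of W = Normal(w_loc, w_scale) weights * Bernoulli(prob) mask,
    # both drawn from a key derived from i, so any row can be regenerated.
    row_key = jax.random.fold_in(key, i)
    k_w, k_m = jax.random.split(row_key)
    weights = w_loc + w_scale * jax.random.normal(k_w, (n_post,))
    mask = jax.random.bernoulli(k_m, prob, (n_post,))
    return weights * mask


def dense_reference(key, w_loc, w_scale, prob, B, shape, transpose=False):
    # Hypothetical dense reference: materialises the full (m, n) matrix W.
    # Only suitable for small shapes and for testing the streaming version.
    n_pre, n_post = shape
    W = jnp.stack([_row_of_w(key, i, w_loc, w_scale, prob, n_post)
                   for i in range(n_pre)])
    return W.T @ B if transpose else W @ B


def streaming(key, w_loc, w_scale, prob, B, shape, transpose=False):
    # Hypothetical on-the-fly version: generates one row of W per step and
    # consumes it immediately, so the full W is never stored.
    n_pre, n_post = shape
    if not transpose:
        # Y[i, :] = W[i, :] @ B, one output row per generated row of W.
        return jax.lax.map(
            lambda i: _row_of_w(key, i, w_loc, w_scale, prob, n_post) @ B,
            jnp.arange(n_pre),
        )

    # Y = W.T @ B = sum_i outer(W[i, :], B[i, :]): every row of W touches
    # every output row, so accumulate outer products instead.
    def body(i, acc):
        row = _row_of_w(key, i, w_loc, w_scale, prob, n_post)
        return acc + jnp.outer(row, B[i])

    init = jnp.zeros((n_post, B.shape[1]), dtype=B.dtype)
    return jax.lax.fori_loop(0, n_pre, body, init)
```

Because both functions derive row i from the same key, streaming(...) and dense_reference(...) agree up to floating-point rounding. As a quick sanity check, prob=1.0 and w_scale=0.0 make every weight equal to w_loc, so with B of ones and transpose=False every output entry is w_loc * shape[1].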