JITCNormalR
- class brainevent.JITCNormalR(loc, scale=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)
Just-In-Time Connectivity Normal distribution matrix with Row-oriented representation.
This class represents a row-oriented sparse matrix optimized for JAX-based transformations where non-zero elements follow a normal distribution. It follows the Compressed Sparse Row (CSR) format conceptually, storing location (mean) and scale (standard deviation) parameters for the normal distribution, along with probability and seed information to determine the sparse structure.
The class is designed for efficient neural network connectivity patterns where weights follow a normal distribution but connectivity is sparse and stochastic.
- wloc
The location (mean) parameter of the normal distribution for non-zero elements.
- Type:
Union[jax.Array, u.Quantity]
- wscale
The scale (standard deviation) parameter of the normal distribution for non-zero elements.
- Type:
Union[jax.Array, u.Quantity]
- shape
The shape of the matrix as a tuple (rows, cols).
- Type:
MatrixShape
- corder
Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).
- Type:
bool
- dtype
The data type of the matrix elements (property inherited from parent).
Examples
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCNormalR
>>> # Create a normal distribution matrix with mean 1.5, std 0.2, probability 0.1, and seed 42
>>> normal_matrix = JITCNormalR((1.5, 0.2, 0.1, 42), shape=(10, 10))
>>> normal_matrix
JITCNormalR(shape=(10, 10), wloc=1.5, wscale=0.2, prob=0.1, seed=42, corder=False)
>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = normal_matrix @ vec
>>> # Apply scalar operation
>>> scaled = normal_matrix * 2.0
>>> # Convert to dense representation
>>> dense_matrix = normal_matrix.todense()
>>> # Transpose operation returns a JITCNormalC instance
>>> col_matrix = normal_matrix.transpose()
Notes
The mathematical model for this matrix is:
W[i, j] = Normal(mu, sigma) * Bernoulli(p)
where mu is wloc, sigma is wscale, p is prob, and the Bernoulli and Normal draws are both determined by the seed. For a matrix-vector product y = W @ x:
y[i] = sum_{j in C(i)} N_ij * x[j]
where C(i) is the deterministic random connection set for row i and N_ij ~ Normal(mu, sigma) is the weight for that connection.
Key properties:
- JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
- More memory-efficient than dense matrices for sparse connectivity patterns
- Well-suited for neural network connectivity matrices with normally distributed weights
- Optimized for matrix-vector operations common in neural simulations
- The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed
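The model above can be illustrated with a small standard-library sketch. It assumes a hypothetical per-entry seeding scheme (a private generator keyed on seed, i, and j); the helper names connection_weight and matvec are illustrative and are not part of brainevent, whose actual PRNG and kernels may work differently. The point is only that y = W @ x can be computed without ever storing W.

```python
import random


def connection_weight(seed, i, j, mu, sigma, p):
    """Return the toy weight W[i, j]: a Normal(mu, sigma) draw kept with probability p."""
    # Keying a private generator on (seed, i, j) makes each entry reproducible
    # without storing the matrix. Illustrative scheme only, not brainevent's.
    rng = random.Random(f"{seed}:{i}:{j}")
    connected = rng.random() < p       # Bernoulli(p) draw
    weight = rng.gauss(mu, sigma)      # Normal(mu, sigma) draw
    return weight if connected else 0.0


def matvec(seed, shape, mu, sigma, p, x):
    """Compute y = W @ x row by row, regenerating W[i, j] on the fly."""
    n_rows, n_cols = shape
    y = []
    for i in range(n_rows):
        acc = 0.0
        for j in range(n_cols):
            # Unconnected entries contribute 0, so this equals the sum over C(i).
            acc += connection_weight(seed, i, j, mu, sigma, p) * x[j]
        y.append(acc)
    return y


shape = (6, 8)
params = dict(mu=1.5, sigma=0.2, p=0.25)
x = [1.0] * shape[1]

# The same seed always reproduces the same implicit matrix, hence the same product.
y_first = matvec(42, shape, x=x, **params)
y_second = matvec(42, shape, x=x, **params)
```

Because every entry is a pure function of the seed and its indices, repeated products with the same seed agree exactly, which is the property that lets the structure stay implicit.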
- todense()
Convert the sparse normal-weight matrix to dense format.
Generates a full dense representation where each non-zero entry is drawn from Normal(wloc, wscale) at positions determined by the probability and seed.
- Parameters:
None
- Returns:
A dense matrix with the same shape as the sparse matrix. The data type matches the weight’s data type, and if the weight has units (is a u.Quantity), the returned array has the same units.
- Return type:
Array | Quantity
- Raises:
None
See also
jitn
The underlying function that materializes the matrix.
Notes
The dense matrix is generated element-wise as:
dense[i, j] = Normal(mu, sigma) if hash(seed, i, j) < p else 0
where mu = wloc, sigma = wscale, and p = prob.
Examples
>>> import brainunit as u
>>> from brainevent import JITCNormalR
>>> sparse_matrix = JITCNormalR((1.5, 0.2, 0.5, 42), shape=(10, 4))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape  # (10, 4)
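As a rough illustration of the element-wise rule in the Notes, the sketch below materializes a toy dense matrix in plain Python. The hash function (SHA-256 truncated to 48 bits) and the names uniform_hash and toy_todense are assumptions made for this example; they do not describe how brainevent hashes indices or generates weights.

```python
import hashlib
import random


def uniform_hash(seed, i, j):
    """Map (seed, i, j) to a reproducible float in [0, 1)."""
    digest = hashlib.sha256(f"{seed}:{i}:{j}".encode()).digest()
    # 48 bits fit exactly in a float, so the result is always strictly below 1.
    return int.from_bytes(digest[:6], "big") / 2**48


def toy_todense(seed, shape, mu, sigma, p):
    """Build dense[i][j] = Normal(mu, sigma) if uniform_hash(seed, i, j) < p else 0."""
    n_rows, n_cols = shape
    dense = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for j in range(n_cols):
            if uniform_hash(seed, i, j) < p:
                # A separate per-entry stream supplies the weight (hypothetical scheme).
                rng = random.Random(f"w:{seed}:{i}:{j}")
                dense[i][j] = rng.gauss(mu, sigma)
    return dense


dense = toy_todense(42, (10, 4), mu=1.5, sigma=0.2, p=0.5)
nonzero = sum(value != 0.0 for row in dense for value in row)
```

With p = 0.5 roughly half of the 40 entries should be non-zero, though the exact count depends on the hash and is not guaranteed.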
- transpose(axes=None)
Transpose the row-oriented matrix into a column-oriented matrix.
Returns a column-oriented matrix (JITCNormalC) with rows and columns swapped, preserving the same weight parameters, probability, and seed.
- Parameters:
axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.
- Returns:
A new column-oriented normal-weight matrix with transposed dimensions.
- Return type:
JITCNormalC
- Raises:
AssertionError – If axes is not None, since partial axis transposition is not supported.
See also
JITCNormalC.transpose
The inverse operation.
Notes
The transpose swaps the shape and inverts the corder flag so that the same PRNG sequence is used, ensuring mat.transpose().todense() equals mat.todense().T.
Examples
>>> from brainevent import JITCNormalC, JITCNormalR
>>> row_matrix = JITCNormalR((1.5, 0.2, 0.5, 42), shape=(30, 5))
>>> col_matrix = row_matrix.transpose()
>>> col_matrix.shape  # (5, 30)
>>> isinstance(col_matrix, JITCNormalC)  # True
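The invariant mat.transpose().todense() == mat.todense().T can be demonstrated on a toy model. The sketch below assumes a hypothetical ToyMatrix class whose transposed flag stands in for the layout flag: transposing swaps the shape and flips the flag, and each entry lookup maps back to the indices of the original matrix. This is an analogy only and makes no claim about how brainevent uses corder internally.

```python
import random


class ToyMatrix:
    """Toy seed-defined matrix; `transposed` plays the role of a layout flag."""

    def __init__(self, seed, shape, mu, sigma, p, transposed=False):
        self.seed = seed
        self.shape = shape
        self.mu = mu
        self.sigma = sigma
        self.p = p
        self.transposed = transposed

    def entry(self, i, j):
        # Map (i, j) back to the indices of the original, untransposed matrix
        # so that both views draw from the same per-entry random stream.
        r, c = (j, i) if self.transposed else (i, j)
        rng = random.Random(f"{self.seed}:{r}:{c}")
        connected = rng.random() < self.p
        weight = rng.gauss(self.mu, self.sigma)
        return weight if connected else 0.0

    def todense(self):
        rows, cols = self.shape
        return [[self.entry(i, j) for j in range(cols)] for i in range(rows)]

    def transpose(self):
        rows, cols = self.shape
        # Swap the shape and flip the flag; no data is copied or regenerated here.
        return ToyMatrix(self.seed, (cols, rows), self.mu, self.sigma, self.p,
                         transposed=not self.transposed)


mat = ToyMatrix(42, (30, 5), mu=1.5, sigma=0.2, p=0.5)
dense = mat.todense()
dense_t = mat.transpose().todense()

# Transposing the materialized matrix by hand gives the reference result.
expected = [list(col) for col in zip(*dense)]
```

Here dense_t and expected hold the same values, and transposing twice returns a matrix whose dense form equals the original.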