JITCNormalC
- class brainevent.JITCNormalC(loc, scale=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)
Just-In-Time Connectivity (JITC) matrix with normally distributed weights and a column-oriented representation.
This class represents a column-oriented sparse matrix whose non-zero elements follow a normal distribution, and it can be used inside JAX transformations. Conceptually it follows the Compressed Sparse Column (CSC) format. Instead of storing the non-zero entries, it stores the location (mean) and scale (standard deviation) of the normal distribution, together with a connection probability and a random seed; these parameters determine both the sparse structure and the weights.
The column-oriented structure makes column-based operations more efficient than row-based ones. This makes the class the column-oriented counterpart (and transpose) of JITCNormalR.
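The sketch below illustrates why a column-oriented layout favors column access. It is a minimal pure-Python version of y = W @ x, computed as a sum of columns (y = sum over j of x[j] * W[:, j]), where each column is regenerated from the seed and its column index. The helpers `generate_column` and `column_matvec` and the per-column seeding scheme are hypothetical. brainevent's actual kernels run on JAX and use their own random-number generation.

```python
import random


def generate_column(seed, col, n_rows, loc, scale, prob):
    """Regenerate column `col` of a hypothetical JIT normal matrix.

    Assumption (not brainevent's scheme): each column gets its own RNG
    stream derived from (seed, col), so it can be rebuilt on demand
    instead of being stored.
    """
    rng = random.Random(seed * 1_000_003 + col)  # hypothetical per-column seeding
    column = []
    for _ in range(n_rows):
        keep = rng.random() < prob      # Bernoulli(prob) connectivity
        weight = rng.gauss(loc, scale)  # Normal(loc, scale) weight
        column.append(weight if keep else 0.0)
    return column


def column_matvec(x, shape, loc, scale, prob, seed):
    """Compute y = W @ x by accumulating x[j] * W[:, j], one column at a time."""
    n_rows, n_cols = shape
    assert len(x) == n_cols
    y = [0.0] * n_rows
    for j in range(n_cols):
        if x[j] == 0.0:
            continue  # a zero input contributes nothing, so this column is never generated
        col = generate_column(seed, j, n_rows, loc, scale, prob)
        for i in range(n_rows):
            y[i] += x[j] * col[i]
    return y


# Usage: only the nonzero entries of x trigger column generation.
y = column_matvec([1.0, 0.0, 2.0, -1.0, 0.0, 0.5], (4, 6), loc=1.5, scale=0.2, prob=0.5, seed=42)
print(y)
```

In this sketch, the column view lets the matrix-vector product skip any column whose input is zero. The sketch uses this as a design choice; it makes no claim about how JITCNormalC schedules its work.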
- wloc
The location (mean) parameter of the normal distribution for non-zero elements.
- Type:
Union[jax.Array, u.Quantity]
- wscale
The scale (standard deviation) parameter of the normal distribution for non-zero elements.
- Type:
Union[jax.Array, u.Quantity]
- shape
The shape of the matrix as a tuple (rows, cols).
- Type:
MatrixShape
- corder
Flag indicating the memory layout order of the matrix: False (default) for Fortran order (column-major), True for C order (row-major).
- Type:
bool
- dtype
The data type of the matrix elements (property inherited from the parent class).
Examples
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCNormalC
>>> # Create a normal distribution matrix with mean 1.5, std 0.2, probability 0.1, and seed 42
>>> normal_matrix = JITCNormalC((1.5, 0.2, 0.1, 42), shape=(10, 10))
>>> normal_matrix
JITCNormalC(shape=(10, 10), wloc=1.5, wscale=0.2, prob=0.1, seed=42, corder=False)
>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = normal_matrix @ vec
>>> # Apply a scalar operation
>>> scaled = normal_matrix * 2.0
>>> # Convert to a dense representation
>>> dense_matrix = normal_matrix.todense()
>>> # The transpose operation returns a JITCNormalR instance
>>> row_matrix = normal_matrix.transpose()
Notes
The mathematical model is the same as for JITCNormalR:
W[i, j] = Normal(mu, sigma) * Bernoulli(p)
where mu = wloc, sigma = wscale, and p = prob. Because of the column-oriented representation, JITCNormalC is conceptually the transpose of a JITCNormalR matrix with swapped dimensions.
Key properties:
JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
More memory-efficient than dense matrices for sparse connectivity patterns, because only the distribution parameters, probability, and seed are stored rather than the individual weights
More efficient than JITCNormalR for column-based operations
Well suited to neural network connectivity matrices with normally distributed weights
The matrix is implicitly constructed from the probability and seed; the actual sparse structure is materialized only when needed
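The model W[i, j] = Normal(mu, sigma) * Bernoulli(p) and the "constructed from the seed" property can be illustrated with a small pure-Python sketch. `sample_jit_normal` is a hypothetical helper that uses Python's `random` module, not brainevent's generator. It shows that the same (loc, scale, prob, seed, shape) always rebuilds the same matrix, so nothing beyond those parameters needs to be kept.

```python
import random


def sample_jit_normal(shape, loc, scale, prob, seed):
    """Hypothetical sketch: materialize W[i, j] = Normal(loc, scale) * Bernoulli(prob).

    Only (loc, scale, prob, seed, shape) are needed; the matrix itself is
    not stored between calls, so it can be rebuilt whenever it is used.
    """
    n_rows, n_cols = shape
    rng = random.Random(seed)  # assumption: Python's Mersenne Twister stands in for the real RNG
    W = []
    for _ in range(n_rows):
        row = []
        for _ in range(n_cols):
            bernoulli = 1.0 if rng.random() < prob else 0.0  # Bernoulli(p) mask
            row.append(rng.gauss(loc, scale) * bernoulli)     # Normal(mu, sigma) weight
        W.append(row)
    return W


# Usage: the same parameters and seed always rebuild the same matrix.
W1 = sample_jit_normal((200, 200), loc=1.5, scale=0.2, prob=0.1, seed=42)
W2 = sample_jit_normal((200, 200), loc=1.5, scale=0.2, prob=0.1, seed=42)
assert W1 == W2
nonzeros = [w for row in W1 for w in row if w != 0.0]
print(f"density ~ {len(nonzeros) / (200 * 200):.3f}, mean weight ~ {sum(nonzeros) / len(nonzeros):.3f}")
```

The fraction of non-zero entries should be close to prob, and the mean of the non-zero weights should be close to loc.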
- todense()
Convert the sparse column-oriented normal-weight matrix to dense format.
Generates a full dense representation where each non-zero entry is drawn from Normal(wloc, wscale) at positions determined by the probability and seed. The generated dense matrix always has shape self.shape.
- Parameters:
None
- Returns:
A dense matrix with the same shape as the sparse matrix. The data type matches the weight's data type, and if the weight has units (is a u.Quantity), the returned array has the same units.
- Return type:
Array | Quantity
- Raises:
None
See also
jitn: The underlying function that materializes the matrix.
Notes
Conceptually, the dense matrix is generated element-wise as:
dense[i, j] = Normal(mu, sigma) if hash(seed, i, j) < p else 0
where hash(seed, i, j) denotes a deterministic pseudo-random value in [0, 1).
Examples
>>> from brainevent import JITCNormalC
>>> sparse_matrix = JITCNormalC((1.5, 0.2, 0.5, 42), shape=(3, 10))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape
(3, 10)
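The element-wise rule in the Notes above can be made concrete with a sketch that leaves `hash(seed, i, j)` abstract in principle but has to pick something. This version assumes SHA-256 over the string "seed:i:j:stream" to produce uniform values and uses the Box-Muller transform for the normal draw. Both choices, and the names `uniform_hash`, `entry`, and `todense_sketch`, are hypothetical and are not brainevent's actual kernels. The property being shown is that every entry depends only on (seed, i, j), so any entry can be computed independently and in any order.

```python
import hashlib
import math


def uniform_hash(seed, i, j, stream):
    """Hypothetical hash: map (seed, i, j, stream) to a deterministic float in [0, 1)."""
    key = f"{seed}:{i}:{j}:{stream}".encode()
    digest = hashlib.sha256(key).digest()
    # Keep the top 53 bits so the result is an exact float strictly below 1.0.
    return (int.from_bytes(digest[:8], "big") >> 11) / 2**53


def entry(seed, i, j, loc, scale, prob):
    """Compute dense[i, j] without touching any other entry."""
    if uniform_hash(seed, i, j, 0) >= prob:
        return 0.0  # the hash fails the Bernoulli(p) test, so the entry is zero
    # Box-Muller transform: two independent uniforms -> one standard normal sample.
    u1 = 1.0 - uniform_hash(seed, i, j, 1)  # in (0, 1], so log(u1) is finite
    u2 = uniform_hash(seed, i, j, 2)
    z = math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)
    return loc + scale * z


def todense_sketch(shape, loc, scale, prob, seed):
    """Materialize the whole matrix by evaluating every entry independently."""
    n_rows, n_cols = shape
    return [[entry(seed, i, j, loc, scale, prob) for j in range(n_cols)]
            for i in range(n_rows)]


# Usage: any single entry can be recomputed on its own.
dense = todense_sketch((3, 10), loc=1.5, scale=0.2, prob=0.5, seed=42)
assert dense[2][7] == entry(42, 2, 7, loc=1.5, scale=0.2, prob=0.5)
```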
- transpose(axes=None)
Transposes the column-oriented matrix into a row-oriented matrix.
This method returns a row-oriented matrix (JITCNormalR) with rows and columns swapped, preserving the same weight parameters (wloc, wscale), probability, and seed. The transpose operation effectively converts between the column-oriented and row-oriented sparse matrix formats.
- Parameters:
axes (None) – Not supported. This parameter exists for compatibility with the NumPy API, but only None is accepted.
- Returns:
A new row-oriented normal distribution matrix with transposed dimensions.
- Return type:
JITCNormalR
- Raises:
AssertionError – If axes is not None, since custom axis permutations are not supported.
See also
JITCNormalR: Row-oriented counterpart.
Notes
The transpose preserves the mathematical identity:
JITCNormalC(shape=(m, n)).transpose().todense() == JITCNormalC(shape=(m, n)).todense().T
Examples
>>> from brainevent import JITCNormalC, JITCNormalR
>>> # Create a column-oriented matrix
>>> col_matrix = JITCNormalC((1.5, 0.2, 0.5, 42), shape=(3, 5))
>>> print(col_matrix.shape)
(3, 5)
>>> # Transpose to a row-oriented matrix
>>> row_matrix = col_matrix.transpose()
>>> print(row_matrix.shape)
(5, 3)
>>> isinstance(row_matrix, JITCNormalR)
True
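The transpose identity in the Notes above can be illustrated with a small sketch. It relies on the fact that both orientations share the same (loc, scale, prob, seed): the transposed object only swaps the shape and reads entry (r, c) as entry (c, r) of the original generator. The classes `ColNormalSketch` and `RowNormalSketch` and the helper `_entry` are hypothetical stand-ins for JITCNormalC and JITCNormalR, and the per-entry seeding via Python's `random` module is an assumption.

```python
import random
from dataclasses import dataclass


def _entry(seed, i, j, loc, scale, prob):
    """Hypothetical generator: deterministic W[i, j] of the original (m, n) matrix."""
    rng = random.Random(f"{seed}:{i}:{j}")  # assumption: per-entry seeding from a string key
    keep = rng.random() < prob
    weight = rng.gauss(loc, scale)
    return weight if keep else 0.0


@dataclass(frozen=True)
class ColNormalSketch:
    """Hypothetical stand-in for JITCNormalC: stores only parameters and shape."""
    loc: float
    scale: float
    prob: float
    seed: int
    shape: tuple

    def todense(self):
        m, n = self.shape
        return [[_entry(self.seed, i, j, self.loc, self.scale, self.prob)
                 for j in range(n)] for i in range(m)]

    def transpose(self, axes=None):
        # Mirrors the documented contract: only axes=None is accepted.
        assert axes is None, "only full transposition is supported"
        m, n = self.shape
        return RowNormalSketch(self.loc, self.scale, self.prob, self.seed, (n, m))


@dataclass(frozen=True)
class RowNormalSketch:
    """Hypothetical stand-in for JITCNormalR with the same parameters."""
    loc: float
    scale: float
    prob: float
    seed: int
    shape: tuple

    def todense(self):
        # Entry (r, c) of the transpose is entry (c, r) of the original matrix,
        # so the same parameters and seed regenerate the swapped element.
        rows, cols = self.shape
        return [[_entry(self.seed, c, r, self.loc, self.scale, self.prob)
                 for c in range(cols)] for r in range(rows)]


# Usage: transpose().todense() equals todense() transposed.
col = ColNormalSketch(1.5, 0.2, 0.5, 42, (3, 5))
row = col.transpose()
assert row.shape == (5, 3)
assert row.todense() == [list(r) for r in zip(*col.todense())]
```

Because no weights are copied, transposing in this sketch costs nothing beyond creating a new parameter record.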