JITCScalarR
- class brainevent.JITCScalarR(weight, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)
Just-In-Time Connectivity (JITC) matrix with a homogeneous scalar weight and a row-oriented representation.
This class represents a row-oriented, homogeneous sparse matrix designed to work with JAX transformations. Conceptually it follows the Compressed Sparse Row (CSR) layout, but it does not store index arrays. It keeps a single weight shared by all non-zero elements, plus the connection probability and random seed from which the sparse structure is regenerated when needed.
The class is designed for efficient neural network connectivity patterns where weights are homogeneous (identical) but connectivity is sparse and stochastically determined. The row-oriented structure makes row-based operations more efficient than column-based ones.
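To make on-the-fly, row-oriented generation concrete, here is a minimal pure-JAX sketch. It uses a hypothetical helper, `row_connections`, which derives row i's connection mask from only the seed and the row index by calling `jax.random.fold_in`. This is an assumption for illustration. brainevent's kernels use their own PRNG, so the pattern below will not match that of a real `JITCScalarR`.

```python
import jax
import jax.numpy as jnp


def row_connections(seed, row, n_cols, prob):
    """Hypothetical helper: Bernoulli(prob) connection mask for one row.

    The key depends only on (seed, row), so any row can be regenerated
    independently and deterministically without storing column indices.
    """
    key = jax.random.fold_in(jax.random.PRNGKey(seed), row)
    return jax.random.bernoulli(key, prob, (n_cols,))


# Regenerating the same row twice yields exactly the same pattern.
mask_a = row_connections(42, 3, 1000, 0.1)
mask_b = row_connections(42, 3, 1000, 0.1)
print(bool(jnp.all(mask_a == mask_b)))  # True
print(float(mask_a.mean()))             # fraction of connections, close to 0.1
```

Because each row has its own key, a kernel can process rows independently. This is one way to see why row-based operations suit this layout.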
- weight
The homogeneous value used for all non-zero elements in the matrix. Can be a plain JAX array or a quantity with units.
- Type:
Union[jax.Array, u.Quantity]
- prob
Probability of each potential connection. Controls the sparsity level: 0.0 means no connections and 1.0 means all possible connections.
- Type:
Union[float, jax.Array]
- seed
Random seed used to generate the sparse structure. The same seed, with the same probability and shape, produces an identical connectivity pattern.
- Type:
Union[int, jax.Array]
- shape
The shape of the matrix as a tuple (rows, cols).
- Type:
MatrixShape
- corder
Flag indicating the memory layout order of the matrix: False (default) for Fortran order (column-major), True for C order (row-major).
- Type:
bool
- dtype
The data type of the matrix elements (property inherited from the parent class).
Examples
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCScalarR, JITCScalarC
>>> # Create a homogeneous matrix with weight 1.5, probability 0.1, and seed 42
>>> homo_matrix = JITCScalarR(1.5, 0.1, 42, shape=(10, 10))
>>> homo_matrix
JITCHomoR(shape=(10, 10), weight=1.5, prob=0.1, seed=42, corder=False)
>>> # Create a matrix with units
>>> weighted_matrix = JITCScalarR(1.5 * u.mV, 0.1, 42, shape=(10, 10))
>>> weighted_matrix
JITCHomoR(shape=(10, 10), weight=1.5 mV, prob=0.1, seed=42, corder=False)
>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = homo_matrix @ vec
>>> result.shape
(10,)
>>> # Apply scalar operations
>>> scaled = homo_matrix * 2.0
>>> scaled.weight  # 3.0
>>> # Arithmetic operations maintain the sparse structure
>>> neg_matrix = -homo_matrix
>>> neg_matrix.weight  # -1.5
>>> # Convert to dense representation
>>> dense_matrix = homo_matrix.todense()
>>> dense_matrix.shape
(10, 10)
>>> # Transpose operation returns a column-oriented matrix
>>> col_matrix = homo_matrix.transpose()
>>> isinstance(col_matrix, JITCScalarC)
True
>>> col_matrix.shape
(10, 10)
Notes
The mathematical model for this matrix is:
$$W[i, j] = w \cdot \mathrm{Bernoulli}(p)$$

where $w$ is the scalar weight (self.weight), $p$ is the connection probability (self.prob), and the Bernoulli draw is determined by a deterministic hash of the seed. The expected value of each element is $\mathbb{E}[W[i, j]] = w p$ and its variance is $\mathrm{Var}[W[i, j]] = w^2 p (1 - p)$.
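The two moment formulas can be checked numerically. This sketch draws i.i.d. samples with `jax.random.bernoulli`, which stands in for brainevent's deterministic hash (an assumption). Only the statistics carry over, not the exact connection pattern.

```python
import jax
import jax.numpy as jnp

w, p = 1.5, 0.1
n_samples = 200_000

# Draw many independent entries W = w * Bernoulli(p).
key = jax.random.PRNGKey(0)
samples = w * jax.random.bernoulli(key, p, (n_samples,)).astype(jnp.float32)

empirical_mean = float(samples.mean())
empirical_var = float(samples.var())

print(empirical_mean, w * p)              # both close to 0.15
print(empirical_var, w**2 * p * (1 - p))  # both close to 0.2025
```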
For a matrix-vector product $y = W x$:

$$y[i] = \sum_{j \in C(i)} w \, x[j]$$

where $C(i)$ is the deterministic random connection set for row $i$, with $|C(i)| \sim \mathrm{Binomial}(n_{\mathrm{cols}}, p)$.

Key properties:
- JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
- More memory-efficient than dense matrices for sparse connectivity patterns
- Well-suited for neural network connectivity matrices with uniform weights
- Optimized for matrix-vector operations common in neural simulations
- The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed
- When used with units (e.g., u.mV), units are preserved through operations
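The matrix-vector formula $y[i] = \sum_{j \in C(i)} w \, x[j]$ can be evaluated without ever building $W$. The sketch below regenerates each row's connection set from (seed, i) and reduces it against x. `row_connections` and `implicit_matvec` are hypothetical names. Using `jax.random.fold_in` as the per-row generator is an assumption, not brainevent's actual kernel.

```python
import jax
import jax.numpy as jnp


def row_connections(seed, row, n_cols, prob):
    # Hypothetical per-row generator: the key depends only on (seed, row).
    key = jax.random.fold_in(jax.random.PRNGKey(seed), row)
    return jax.random.bernoulli(key, prob, (n_cols,))


def implicit_matvec(weight, prob, seed, shape, x):
    """Compute y[i] = sum over j in C(i) of weight * x[j] without building W."""
    n_rows, n_cols = shape

    def one_row(i):
        # Regenerate row i's connection set C(i) from (seed, i).
        mask = row_connections(seed, i, n_cols, prob)
        return weight * jnp.sum(jnp.where(mask, x, 0.0))

    # lax.map processes rows sequentially, so the full matrix is never built.
    return jax.lax.map(one_row, jnp.arange(n_rows))


x = jnp.ones(50)
y = implicit_matvec(1.5, 0.1, 42, (20, 50), x)

# Dense reference built from the same generator, for comparison only.
dense = jnp.stack([1.5 * row_connections(42, i, 50, 0.1) for i in range(20)])
print(y.shape)                           # (20,)
print(bool(jnp.allclose(y, dense @ x)))  # True
```

The memory cost of this sketch is one row mask plus the inputs, whereas the dense reference stores all n_rows * n_cols entries. This is the trade-off behind the memory-efficiency point above.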
See also
JITCScalarC: Column-oriented counterpart of this class.
JITCScalarMatrix: Base class providing shared functionality.
- todense()
Convert the sparse scalar-weight matrix to dense format.
Generates a full dense representation of the sparse matrix by materializing all entries W[i, j] = w * Bernoulli(p), as determined by the probability and seed.
- Parameters:
None
- Returns:
A dense matrix with the same shape as the sparse matrix. The data type matches the weight's data type, and if the weight has units (is a u.Quantity), the returned array has the same units.
- Return type:
Array | Quantity
- Raises:
None
See also
jits: The underlying function that materializes the matrix.
Notes
The dense matrix is generated by iterating over all (i, j) positions and placing the scalar weight w at each position where the deterministic PRNG indicates a connection:

dense[i, j] = w if hash(seed, i, j) < p else 0

Examples
>>> import brainunit as u
>>> from brainevent import JITCScalarR
>>> sparse_matrix = JITCScalarR(1.5 * u.mV, 0.5, 42, shape=(10, 4))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape
(10, 4)
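The rule dense[i, j] = w if hash(seed, i, j) < p else 0 can be sketched in pure JAX. Here `hash_uniform` is a hypothetical stand-in for the deterministic PRNG. It folds the flat index i * n_cols + j into a seed-derived key and draws one uniform number per position. brainevent's real hash differs, so this will not reproduce the pattern of an actual JITCScalarR, and this sketch ignores units.

```python
import jax
import jax.numpy as jnp


def hash_uniform(seed, i, j, n_cols):
    # Hypothetical deterministic "hash": one uniform draw in [0, 1)
    # per (i, j), keyed on the seed and the flat element index.
    key = jax.random.fold_in(jax.random.PRNGKey(seed), i * n_cols + j)
    return jax.random.uniform(key)


def reference_todense(weight, prob, seed, shape):
    n_rows, n_cols = shape
    rows, cols = jnp.meshgrid(jnp.arange(n_rows), jnp.arange(n_cols), indexing="ij")
    # Evaluate the hash at every (i, j) position of the matrix.
    u = jax.vmap(jax.vmap(lambda i, j: hash_uniform(seed, i, j, n_cols)))(rows, cols)
    # dense[i, j] = w if hash(seed, i, j) < p else 0
    return jnp.where(u < prob, weight, 0.0)


dense = reference_todense(1.5, 0.5, 42, (10, 4))
print(dense.shape)  # (10, 4)
```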
- transpose(axes=None)
Transposes the row-oriented matrix into a column-oriented matrix.
This method returns a column-oriented matrix (JITCScalarC) with rows and columns swapped, preserving the same weight, probability, and seed values. The transpose operation effectively converts between row-oriented and column-oriented sparse matrix formats.
- Parameters:
axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.
- Returns:
A new column-oriented homogeneous matrix with transposed dimensions.
- Return type:
JITCScalarC
- Raises:
AssertionError – If axes is not None, since partial axis transposition is not supported.
Examples
>>> from brainevent import JITCScalarR, JITCScalarC
>>> # Create a row-oriented matrix
>>> row_matrix = JITCScalarR(1.5, 0.5, 42, shape=(30, 5))
>>> print(row_matrix.shape)
(30, 5)
>>> # Transpose to a column-oriented matrix
>>> col_matrix = row_matrix.transpose()
>>> print(col_matrix.shape)
(5, 30)
>>> isinstance(col_matrix, JITCScalarC)
True
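Why transposition can keep the same weight, probability, and seed is easiest to see through index swapping. No connectivity needs to be copied if the transposed view evaluates the same deterministic hash with (row, col) swapped. The sketch below is hypothetical (`hash_uniform`, `entry_row_oriented`, `entry_transposed` and `materialize` are illustrative names). It is not a description of how JITCScalarC is implemented internally.

```python
import jax
import jax.numpy as jnp


def hash_uniform(seed, i, j, n_cols):
    # Hypothetical deterministic hash keyed on the ORIGINAL (row, col) indices.
    key = jax.random.fold_in(jax.random.PRNGKey(seed), i * n_cols + j)
    return jax.random.uniform(key)


def entry_row_oriented(weight, prob, seed, shape, i, j):
    # W[i, j] for the row-oriented matrix of the given shape.
    return jnp.where(hash_uniform(seed, i, j, shape[1]) < prob, weight, 0.0)


def entry_transposed(weight, prob, seed, shape, a, b):
    # W.T[a, b] == W[b, a]: same weight, prob and seed; only the indices swap.
    return entry_row_oriented(weight, prob, seed, shape, b, a)


def materialize(entry_fn, weight, prob, seed, shape, out_shape):
    rows, cols = jnp.meshgrid(
        jnp.arange(out_shape[0]), jnp.arange(out_shape[1]), indexing="ij"
    )

    def f(r, c):
        return entry_fn(weight, prob, seed, shape, r, c)

    return jax.vmap(jax.vmap(f))(rows, cols)


shape = (30, 5)
W = materialize(entry_row_oriented, 1.5, 0.5, 42, shape, shape)
WT = materialize(entry_transposed, 1.5, 0.5, 42, shape, (5, 30))
print(WT.shape)                        # (5, 30)
print(bool(jnp.array_equal(WT, W.T)))  # True
```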