JITCUniformR
- class brainevent.JITCUniformR(low, high=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)
Just-In-Time Connectivity (JITC) matrix with a row-oriented representation and uniformly distributed weights.
This class implements a row-oriented sparse matrix designed to work under JAX transformations. Conceptually it plays the role of a Compressed Sparse Row (CSR) matrix. It does not store the non-zero elements explicitly. Instead, it draws each connection's weight from a uniform distribution with lower and upper bounds (wlow, whigh), and uses the connection probability and seed to determine the sparse structure.
The class is designed for efficient neural network connectivity patterns where weights follow a uniform distribution but connectivity is sparse and stochastic. The actual sparse structure and uniform weight values are generated just-in-time during operations.
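The "just-in-time" idea can be made concrete with a minimal NumPy sketch. This is not brainevent's implementation. It assumes that each row is regenerated from a per-row random stream derived from `(seed, i)`, so a matrix-vector product touches one row at a time and never stores `W`. The helper names `generate_row` and `jit_matvec` are hypothetical, and NumPy's `default_rng` stands in for whatever PRNG the library actually uses.

```python
import numpy as np


def generate_row(seed, i, n_cols, wlow, whigh, prob):
    """Regenerate row i of W from (seed, i) alone (hypothetical helper)."""
    # Per-row stream (assumed scheme): the same (seed, i) always yields the same row.
    rng = np.random.default_rng([seed, i])
    weights = rng.uniform(wlow, whigh, size=n_cols)  # U[i, :] ~ Uniform(wlow, whigh)
    mask = rng.random(n_cols) < prob                 # B[i, :] ~ Bernoulli(prob)
    return np.where(mask, weights, 0.0)              # W[i, :] = U[i, :] * B[i, :]


def jit_matvec(seed, shape, wlow, whigh, prob, v):
    """Compute W @ v row by row; only one row of W exists at any time."""
    n_rows, n_cols = shape
    out = np.empty(n_rows)
    for i in range(n_rows):
        out[i] = generate_row(seed, i, n_cols, wlow, whigh, prob) @ v
    return out


v = np.ones(10)
y = jit_matvec(42, (10, 10), 0.1, 0.5, 0.2, v)
print(y.shape)  # (10,)
```

Each row is a pure function of `(seed, i)`. As a result, calling `jit_matvec` again with the same arguments returns the same vector, and the working memory beyond `v` and the output is a single row of length `n_cols`.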
- wlow
The lower bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.
- Type:
Union[jax.Array, u.Quantity]
- whigh
The upper bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.
- Type:
Union[jax.Array, u.Quantity]
- prob
Connection probability determining the sparsity of the matrix. Values range from 0 (no connections) to 1 (fully connected).
- Type:
Union[float, jax.Array]
- seed
Random seed controlling the specific pattern of connections. Using the same seed produces identical connectivity patterns.
- Type:
Union[int, jax.Array]
- shape
Tuple specifying the dimensions of the matrix as (rows, columns).
- Type:
MatrixShape
- corder
Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).
- Type:
bool
- dtype
The data type of the matrix elements (property inherited from parent).
Examples
>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCUniformR, JITCUniformC
>>> # Create a uniform matrix with bounds [0.1, 0.5], probability 0.2, and seed 42
>>> uniform_matrix = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(10, 10))
>>> uniform_matrix
JITCUniformR(shape=(10, 10), wlow=0.1, whigh=0.5, prob=0.2, seed=42, corder=False)
>>> # Create a uniform matrix with units
>>> uniform_matrix_mv = JITCUniformR((0.1 * u.mV, 0.5 * u.mV, 0.2, 42), shape=(10, 10))
>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = uniform_matrix @ vec
>>> # Each element in result is a weighted sum using uniformly distributed weights
>>> # Apply a scalar operation (scales both the lower and upper bounds)
>>> scaled = uniform_matrix * 2.0
>>> print(scaled.wlow, scaled.whigh)  # 0.2 1.0
>>> # Convert to a dense representation
>>> dense_matrix = uniform_matrix.todense()
>>> # dense_matrix has shape (10, 10) with ~20% non-zero elements;
>>> # each non-zero element is uniformly distributed between 0.1 and 0.5
>>> # Transposition returns a JITCUniformC instance
>>> col_matrix = uniform_matrix.transpose()
>>> isinstance(col_matrix, JITCUniformC)
True
>>> # Update the bounds while preserving the connectivity pattern
>>> updated = uniform_matrix.with_data(0.2, 0.8)
>>> print(updated.wlow, updated.whigh)  # 0.2 0.8
>>> # Use with JAX transformations
>>> @jax.jit
... def matrix_vector_product(mat, vec):
...     return mat @ vec
>>> result_jit = matrix_vector_product(uniform_matrix, vec)
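The `@jax.jit` example works because the matrix is a JAX PyTree (see the key properties in the Notes). Below is a minimal sketch of that pattern. `ToyJITUniformR` is a hypothetical class written for illustration, not brainevent's implementation. Its key-derivation scheme (`fold_in` per row, then `split`) is an assumption standing in for the library's own PRNG layout. Array-valued fields become traced leaves, while `shape` stays static in the tree definition.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyJITUniformR:
    """Hypothetical stand-in for a JIT uniform matrix; not brainevent's class."""

    def __init__(self, wlow, whigh, prob, seed, shape):
        self.wlow, self.whigh, self.prob, self.seed = wlow, whigh, prob, seed
        self.shape = shape  # static: stored in the treedef, not as a leaf

    def tree_flatten(self):
        # Bounds, probability and seed are leaves, so jit can trace them.
        return (self.wlow, self.whigh, self.prob, self.seed), self.shape

    @classmethod
    def tree_unflatten(cls, shape, leaves):
        return cls(*leaves, shape=shape)

    def __matmul__(self, v):
        n_rows, n_cols = self.shape
        key = jax.random.PRNGKey(self.seed)

        def row_dot(i):
            # Derive a per-row key (assumed scheme), then regenerate U[i, :] and B[i, :].
            k_u, k_b = jax.random.split(jax.random.fold_in(key, i))
            u = jax.random.uniform(k_u, (n_cols,), minval=self.wlow, maxval=self.whigh)
            b = jax.random.bernoulli(k_b, self.prob, (n_cols,))
            return jnp.dot(jnp.where(b, u, 0.0), v)

        # lax.map evaluates rows sequentially, so the full W is never built.
        return jax.lax.map(row_dot, jnp.arange(n_rows))


mat = ToyJITUniformR(0.1, 0.5, 0.2, 42, shape=(10, 10))
vec = jnp.ones(10)


@jax.jit
def matrix_vector_product(mat, vec):
    return mat @ vec


print(matrix_vector_product(mat, vec).shape)  # (10,)
```

Keeping `shape` out of the leaves matters. Under `jit`, leaves become abstract tracers, but the number of rows and columns must be concrete Python integers to build the loop and the random arrays.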
Notes
The mathematical model for JITCUniformR is:
W[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)
Each entry W[i, j] is independently drawn from the continuous uniform distribution on [w_low, w_high] with probability prob, and is zero otherwise. More precisely, the entry is computed as:
W[i, j] = U[i, j] * B[i, j]
where U[i, j] ~ Uniform(w_low, w_high) and B[i, j] ~ Bernoulli(prob) are independent random variables, both determined by seed.
The row-oriented representation means that the random number generator state is seeded per row (or per column, depending on corder), making row-based operations (W @ v) the natural direction.
Key properties:
- JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)
- More memory-efficient than dense matrices for sparse connectivity patterns
- Well-suited for neural network connectivity matrices with uniformly distributed weights
- Optimized for matrix-vector operations common in neural simulations
- The actual matrix elements are never explicitly stored, only generated during operations
- Using the same seed always produces the same random connectivity pattern and weights
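Because U[i, j] and B[i, j] are independent, the model implies E[W[i, j]] = prob * (w_low + w_high) / 2. For an all-ones vector, each entry of W @ v therefore has expected value n_cols * prob * (w_low + w_high) / 2. The NumPy sketch below checks this by sampling the model directly on a large matrix. It draws from a single NumPy stream rather than brainevent's PRNG, so it illustrates the distribution only, not the library's exact values.

```python
import numpy as np

wlow, whigh, prob = 0.1, 0.5, 0.2
n_rows, n_cols = 1000, 1000

rng = np.random.default_rng(0)
U = rng.uniform(wlow, whigh, size=(n_rows, n_cols))  # U[i, j] ~ Uniform(wlow, whigh)
B = rng.random((n_rows, n_cols)) < prob               # B[i, j] ~ Bernoulli(prob)
W = U * B                                              # W[i, j] = U[i, j] * B[i, j]

expected_entry = prob * (wlow + whigh) / 2             # 0.2 * 0.3 = 0.06
row_sums = W @ np.ones(n_cols)                         # (W @ v)[i] with v = 1

print(W.mean(), expected_entry)
print(row_sums.mean(), n_cols * expected_entry)        # both close to 60
```

The same arithmetic explains the "~20% non-zero" remark in the class examples. A 10 x 10 matrix with prob = 0.2 has 0.2 * 100 = 20 non-zero entries on average.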
- todense()
Convert the sparse uniform matrix to a dense array.
Generates a full dense representation of the sparse matrix by sampling Uniform(w_low, w_high) values for all connections determined by the probability and seed. The resulting dense matrix holds exactly the entries that the sparse representation generates implicitly during operations.
- Returns:
A dense matrix with the same shape as the sparse matrix. The data type matches the weight's data type, and if the weight has units (is a u.Quantity), the returned array has the same units.
- Return type:
Array | Quantity
- Raises:
None – This method does not raise exceptions under normal use.
See also
JITCUniformC.todense: Column-oriented variant.
jitu: Standalone function to materialize JIT uniform matrices.
Notes
The dense matrix is generated according to:
dense[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)
for each (i, j) pair, where the random draws are determined by seed.
Examples
>>> import jax
>>> from brainevent import JITCUniformR
>>> mat = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(4, 6))
>>> dense = mat.todense()
>>> dense.shape
(4, 6)
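Materialization can be pictured as stacking regenerated rows into one array. This NumPy sketch uses a hypothetical per-row seeding scheme, `default_rng([seed, i])`, which is an assumption and not brainevent's PRNG. Its values will therefore differ from `JITCUniformR.todense()`. What it does show is the structure of the result: zeros wherever B[i, j] is 0, values in [w_low, w_high] elsewhere, and roughly a fraction prob of the entries non-zero.

```python
import numpy as np


def sketch_todense(seed, shape, wlow, whigh, prob):
    """Hypothetical dense materialization: one independent stream per row."""
    n_rows, n_cols = shape
    rows = []
    for i in range(n_rows):
        rng = np.random.default_rng([seed, i])     # same (seed, i) -> same row
        u = rng.uniform(wlow, whigh, size=n_cols)  # U[i, :]
        b = rng.random(n_cols) < prob              # B[i, :]
        rows.append(np.where(b, u, 0.0))           # dense[i, :] = U[i, :] * B[i, :]
    return np.stack(rows)


dense = sketch_todense(42, (400, 600), 0.1, 0.5, 0.2)
nonzero = dense[dense != 0]
print(dense.shape)                                  # (400, 600)
print(nonzero.min() >= 0.1, nonzero.max() <= 0.5)   # True True
print((dense != 0).mean())                          # close to 0.2
```

Rebuilding with the same seed and shape reproduces the array exactly, which mirrors the determinism that `todense` relies on.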
- transpose(axes=None)
Transpose the row-oriented matrix into a column-oriented matrix.
Returns a column-oriented matrix (JITCUniformC) with rows and columns swapped, preserving the same weight bounds, probability, and seed values. The transpose operation effectively converts between row-oriented and column-oriented sparse matrix formats.
- Parameters:
axes (None) – Not supported. This parameter exists for compatibility with the NumPy API, but only None is accepted.
- Returns:
A new column-oriented uniform matrix with transposed dimensions.
- Return type:
JITCUniformC
- Raises:
AssertionError – If axes is not None, since partial axis transposition is not supported.
See also
JITCUniformC.transpose: The inverse operation.
Notes
The transpose satisfies W.T[j, i] = W[i, j]. Since both the connectivity pattern and the uniform weights are deterministic functions of seed, the transposed matrix produces results identical to materializing W and transposing the dense array.
The corder flag is flipped during transposition to maintain consistency with the underlying PRNG state ordering.
Examples
>>> from brainevent import JITCUniformR, JITCUniformC
>>> row_matrix = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(30, 5))
>>> row_matrix.shape
(30, 5)
>>> col_matrix = row_matrix.transpose()
>>> col_matrix.shape
(5, 30)
>>> isinstance(col_matrix, JITCUniformC)
True
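The transpose Notes say the corder flag is flipped so that the transposed matrix keeps drawing from the same PRNG streams. The NumPy sketch below shows why that works. It assumes one stream per row of W keyed by `(seed, i)`, a hypothetical scheme that is not brainevent's. The column-oriented view regenerates column i of W.T from the very stream that produced row i of W. No new randomness is drawn, and W.T[j, i] == W[i, j] holds exactly. The helpers `stream_vector` and `materialize` are illustrative only.

```python
import numpy as np


def stream_vector(seed, outer, length, wlow, whigh, prob):
    """Entries produced by the stream keyed on (seed, outer) (hypothetical)."""
    rng = np.random.default_rng([seed, outer])
    u = rng.uniform(wlow, whigh, size=length)
    b = rng.random(length) < prob
    return np.where(b, u, 0.0)


def materialize(seed, shape, wlow, whigh, prob, streams_along_rows):
    """Build a matrix whose streams run along rows (True) or columns (False)."""
    n_rows, n_cols = shape
    if streams_along_rows:
        # Row-oriented: row i is produced by the stream keyed on (seed, i).
        return np.stack([stream_vector(seed, i, n_cols, wlow, whigh, prob)
                         for i in range(n_rows)])
    # Column-oriented: column j is produced by the stream keyed on (seed, j).
    return np.stack([stream_vector(seed, j, n_rows, wlow, whigh, prob)
                     for j in range(n_cols)], axis=1)


W = materialize(42, (30, 5), 0.1, 0.5, 0.2, streams_along_rows=True)
# Transpose: swap the shape and flip the stream orientation, reusing the seed.
W_T = materialize(42, (5, 30), 0.1, 0.5, 0.2, streams_along_rows=False)
print(W_T.shape, np.array_equal(W_T, W.T))  # (5, 30) True
```

Swapping the shape without flipping the orientation would key the streams on the wrong axis and produce an unrelated matrix. The flag flip plays the same role in this sketch.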