JITCUniformR

class brainevent.JITCUniformR(low, high=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)

Just-in-time connectivity (JITC) matrix with a row-oriented representation and uniformly distributed weights.

This class implements a row-oriented sparse matrix designed to work with JAX transformations. It is conceptually analogous to the Compressed Sparse Row (CSR) format, but it stores neither indices nor values. Instead, it keeps the lower and upper bounds (wlow, whigh) of a uniform weight distribution together with a connection probability and a seed; the sparsity pattern and the weights are regenerated from these parameters whenever they are needed.

It is intended for neural-network connectivity in which weights follow a uniform distribution while connections are sparse and stochastic. Both the sparse structure and the uniform weight values are generated just in time, during each operation.
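The just-in-time idea can be sketched in plain NumPy. The helper `uniform_row`, the class `ToyJITUniformRow`, and the per-row seeding via `np.random.default_rng([seed, row])` are hypothetical illustrations, not brainevent's API or kernels; the only point is that a row can be regenerated from the stored parameters, used, and discarded, so the full matrix is never held in memory.

```python
import numpy as np


def uniform_row(wlow, whigh, prob, seed, row, n_cols):
    # Hypothetical scheme: row `row` depends only on (seed, row), so it can be
    # regenerated on demand instead of being stored.
    rng = np.random.default_rng([seed, row])
    r = rng.random(n_cols)               # standard uniform draws in [0, 1)
    mask = rng.random(n_cols) < prob     # Bernoulli(prob) connectivity
    return np.where(mask, wlow + (whigh - wlow) * r, 0.0)


class ToyJITUniformRow:
    """Hypothetical toy: keeps only the bounds, prob, seed and shape."""

    def __init__(self, wlow, whigh, prob, seed, shape):
        self.wlow, self.whigh, self.prob, self.seed = wlow, whigh, prob, seed
        self.shape = shape

    def row(self, i):
        return uniform_row(self.wlow, self.whigh, self.prob, self.seed, i, self.shape[1])

    def matvec(self, v):
        # Each row is generated just in time, used once, and discarded.
        return np.array([self.row(i) @ v for i in range(self.shape[0])])

    def todense(self):
        # Explicit materialization, only for checking the sketch.
        return np.stack([self.row(i) for i in range(self.shape[0])])


mat = ToyJITUniformRow(0.1, 0.5, 0.2, 42, shape=(10, 10))
v = np.ones(10)
y = mat.matvec(v)
```

Because `matvec` regenerates row i from `(seed, i)` alone, its result agrees with multiplying by the explicitly materialized matrix.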

wlow

The lower bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.

Type:

Union[jax.Array, u.Quantity]

whigh

The upper bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.

Type:

Union[jax.Array, u.Quantity]

prob

Connection probability determining the sparsity of the matrix. Values range from 0 (no connections) to 1 (fully connected).

Type:

Union[float, jax.Array]
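Under the Bernoulli model described in the Notes, a fraction prob of the entries is connected on average, so a matrix of shape (rows, columns) has prob × rows × columns non-zero entries in expectation. The NumPy sketch below checks this together with the two endpoints; drawing B[i, j] as "uniform draw < prob" is an assumed construction, not necessarily what brainevent's kernels do.

```python
import numpy as np

# Hypothetical stand-in for the connectivity draw: B[i, j] ~ Bernoulli(prob),
# implemented as "uniform draw in [0, 1) is below prob".
rng = np.random.default_rng(0)
shape = (1000, 1000)

# Fraction of connected entries for prob = 0 (none), 0.2, and 1 (all).
density = {p: (rng.random(shape) < p).mean() for p in (0.0, 0.2, 1.0)}

# Expected number of non-zeros at prob = 0.2.
expected_nnz = 0.2 * shape[0] * shape[1]
```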

seed

Random seed that determines both the connectivity pattern and the weight values. The same seed always produces an identical matrix.

Type:

Union[int, jax.Array]
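A minimal illustration of seed determinism, assuming the connectivity is a pure function of the seed, shape, and probability. The function `connectivity` is hypothetical and uses NumPy's `default_rng` rather than brainevent's generator, so the patterns themselves differ from brainevent's; only the same-seed, same-pattern property carries over.

```python
import numpy as np


def connectivity(seed, shape, prob):
    # Hypothetical: the mask is a pure function of (seed, shape, prob).
    rng = np.random.default_rng(seed)
    return rng.random(shape) < prob


a = connectivity(42, (10, 10), 0.2)
b = connectivity(42, (10, 10), 0.2)   # same seed: identical pattern
c = connectivity(7, (10, 10), 0.2)    # different seed: (almost surely) different pattern
```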

shape

Tuple specifying the dimensions of the matrix as (rows, columns).

Type:

MatrixShape

corder

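Flag indicating the element ordering: False (default) for Fortran order (column-major), True for C order (row-major). Because matrix elements are never stored, this ordering governs how the random connectivity is generated (see Notes).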

Type:

bool

dtype

The data type of the matrix elements (property inherited from parent).

Examples

>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCUniformR, JITCUniformC

>>> # Create a uniform matrix with bounds [0.1, 0.5], probability 0.2, and seed 42
>>> uniform_matrix = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(10, 10))
>>> uniform_matrix
JITCUniformR(shape=(10, 10), wlow=0.1, whigh=0.5, prob=0.2, seed=42, corder=False)

>>> # Create a uniform matrix whose bounds carry units
>>> uniform_matrix_mv = JITCUniformR((0.1 * u.mV, 0.5 * u.mV, 0.2, 42), shape=(10, 10))

>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = uniform_matrix @ vec
>>> # Since vec is all ones, result[i] is the sum of row i's uniformly distributed weights

>>> # Multiply by a scalar (scales both the lower and upper bounds)
>>> scaled = uniform_matrix * 2.0
>>> print(scaled.wlow, scaled.whigh)  # 0.2 1.0

>>> # Convert to a dense representation
>>> dense_matrix = uniform_matrix.todense()
>>> # dense_matrix has shape (10, 10) with ~20% non-zero elements;
>>> # each non-zero element is uniformly distributed between 0.1 and 0.5

>>> # Transposition returns a JITCUniformC instance
>>> col_matrix = uniform_matrix.transpose()
>>> isinstance(col_matrix, JITCUniformC)
True

>>> # Update the bounds while preserving the connectivity pattern
>>> updated = uniform_matrix.with_data(0.2, 0.8)
>>> print(updated.wlow, updated.whigh)  # 0.2 0.8

>>> # Use with JAX transformations
>>> @jax.jit
... def matrix_vector_product(mat, vec):
...     return mat @ vec
>>> result_jit = matrix_vector_product(uniform_matrix, vec)
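The scaling and `with_data` examples rely on the connectivity mask being independent of the bounds. Assuming each non-zero weight is generated as w_low + (w_high - w_low) * r with r drawn from [0, 1), and the mask from separate draws (an assumption about the generation scheme, not taken from brainevent), scaling both bounds by c scales every entry by c, and substituting new bounds leaves the zero pattern untouched. The `materialize` helper below is hypothetical.

```python
import numpy as np


def materialize(wlow, whigh, prob, seed, shape):
    # Hypothetical generator: r and the mask depend only on (seed, prob, shape),
    # never on the bounds.
    rng = np.random.default_rng(seed)
    r = rng.random(shape)                 # standard uniform draws in [0, 1)
    mask = rng.random(shape) < prob       # Bernoulli(prob) connectivity
    return np.where(mask, wlow + (whigh - wlow) * r, 0.0)


W = materialize(0.1, 0.5, 0.2, 42, (10, 10))
W_scaled = materialize(0.1 * 2.0, 0.5 * 2.0, 0.2, 42, (10, 10))  # analogous to `mat * 2.0`
W_new = materialize(0.2, 0.8, 0.2, 42, (10, 10))                 # analogous to `with_data(0.2, 0.8)`
```

Under this scheme, a linear map of the bounds is a linear map of every entry, which is why scaling the bounds is equivalent to scaling the matrix.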

Notes

The mathematical model for JITCUniformR is:

W[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)

Each entry W[i, j] is independently drawn from the continuous uniform distribution on [w_low, w_high] with probability prob, and zero otherwise. More precisely, the entry is computed as:

W[i, j] = U[i, j] * B[i, j]

where U[i, j] ~ Uniform(w_low, w_high) and B[i, j] ~ Bernoulli(prob) are independent random variables, both determined by seed.
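Because U[i, j] and B[i, j] are independent, the expected entry factorizes:

E[W[i, j]] = E[U[i, j]] * E[B[i, j]] = prob * (w_low + w_high) / 2

For the running example (w_low = 0.1, w_high = 0.5, prob = 0.2) this gives 0.2 * 0.3 = 0.06. The sketch below samples the model directly with NumPy's generator (a stand-in, not brainevent's PRNG) to check the formula and the support of W.

```python
import numpy as np

w_low, w_high, prob = 0.1, 0.5, 0.2
shape = (1000, 1000)
# NumPy stand-in for the model's randomness; not brainevent's generator.
rng = np.random.default_rng(42)

U = rng.uniform(w_low, w_high, size=shape)  # U[i, j] ~ Uniform(w_low, w_high)
B = rng.random(shape) < prob                # B[i, j] ~ Bernoulli(prob), independent of U
W = U * B                                   # W[i, j] = U[i, j] * B[i, j]

expected_mean = prob * (w_low + w_high) / 2  # E[U] * E[B] = 0.06
```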

The row-oriented representation means that the random number generator state is seeded per-row (or per-column, depending on corder), making row-based operations (W @ v) the natural direction.
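To see why W @ v is the natural direction, consider a hypothetical per-row scheme in which row i is generated from (seed, i). Each output of W @ v needs exactly one row stream (a gather), whereas u @ W needs a contribution from every row stream and must scatter-accumulate. The helper names and the `np.random.default_rng([seed, row])` seeding are illustrative assumptions, not brainevent's kernels.

```python
import numpy as np


def uniform_row(wlow, whigh, prob, seed, row, n_cols):
    # Hypothetical per-row seeding: row `row` is a function of (seed, row) only.
    rng = np.random.default_rng([seed, row])
    r = rng.random(n_cols)
    mask = rng.random(n_cols) < prob
    return np.where(mask, wlow + (whigh - wlow) * r, 0.0)


def matvec(params, shape, v):
    # W @ v: output i needs only row i, i.e. one RNG stream per output (gather).
    return np.array([uniform_row(*params, i, shape[1]) @ v for i in range(shape[0])])


def vecmat(params, shape, u):
    # u @ W: every output column needs an entry from every row stream, so we
    # walk the rows and scatter-accumulate their contributions.
    out = np.zeros(shape[1])
    for i in range(shape[0]):
        out += u[i] * uniform_row(*params, i, shape[1])
    return out


params = (0.1, 0.5, 0.2, 42)   # (wlow, whigh, prob, seed)
shape = (30, 5)
W = np.stack([uniform_row(*params, i, shape[1]) for i in range(shape[0])])
```

Both directions give the same answer as the dense product; only the access pattern differs.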

Key properties:

  • JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)

  • Memory footprint does not grow with the matrix dimensions: only the weight bounds, probability, seed, and shape are kept, whereas a dense matrix stores rows × columns values

  • Well-suited for neural network connectivity matrices with uniformly distributed weights

  • Optimized for matrix-vector operations common in neural simulations

  • The actual matrix elements are never explicitly stored, only generated during operations

  • Using the same seed always produces the same random connectivity pattern and weights
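A rough storage comparison for a 10,000 × 10,000 float32 matrix with prob = 0.1. The JIT footprint assumes three float32 scalars (wlow, whigh, prob) plus an int32 seed, and the CSR figure assumes float32 values with int32 column indices and row pointers; actual sizes depend on dtypes and on bookkeeping not counted here.

```python
import numpy as np

n_rows, n_cols = 10_000, 10_000
f32 = np.dtype(np.float32).itemsize   # 4 bytes
i32 = np.dtype(np.int32).itemsize     # 4 bytes

# Dense: every entry is stored.
dense_bytes = n_rows * n_cols * f32                            # 400,000,000 bytes

# CSR at prob = 0.1: a value and a column index per non-zero, plus row pointers.
expected_nnz = n_rows * n_cols // 10                           # 10,000,000 non-zeros
csr_bytes = expected_nnz * (f32 + i32) + (n_rows + 1) * i32    # 80,040,004 bytes

# JIT (assumed layout): wlow, whigh, prob as float32 and seed as int32.
jit_bytes = 3 * f32 + i32                                      # 16 bytes, independent of size
```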

todense()

Convert the sparse uniform matrix to a dense array.

Generates a full dense representation by sampling Uniform(w_low, w_high) values at the connections selected by prob and seed; all other entries are zero. The result contains the same entries that the implicit representation uses during matrix operations.

Returns:

A dense matrix with the same shape as the sparse matrix. Its data type matches that of the weight bounds, and if the bounds carry units (are a u.Quantity), the returned array has the same units.

Return type:

Array | Quantity

Raises:

None – This method does not raise exceptions under normal use.

See also

JITCUniformC.todense

Column-oriented variant.

jitu

Standalone function to materialize JIT uniform matrices.

Notes

The dense matrix is generated according to:

dense[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)

for each (i, j) pair, where the random draws are determined by seed.

Examples

>>> import jax
>>> from brainevent import JITCUniformR
>>>
>>> mat = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(4, 6))
>>> dense = mat.todense()
>>> dense.shape
(4, 6)

transpose(axes=None)

Transpose the row-oriented matrix into a column-oriented matrix.

Returns a column-oriented matrix (JITCUniformC) with rows and columns swapped, preserving the same weight bounds, probability, and seed values. The transpose operation effectively converts between row-oriented and column-oriented sparse matrix formats.

Parameters:

axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.

Returns:

A new column-oriented uniform matrix with transposed dimensions.

Return type:

JITCUniformC

Raises:

AssertionError – If axes is not None; arbitrary axis permutations are not supported.

See also

JITCUniformC.transpose

The inverse operation.

Notes

The transpose satisfies W.T[j, i] = W[i, j]. Since both the connectivity pattern and the uniform weights are deterministic functions of seed, the transposed matrix produces identical results to materializing W and transposing the dense array.

The corder flag is flipped during transposition to maintain consistency with the underlying PRNG state ordering.
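One way to obtain W.T[j, i] = W[i, j] exactly is to let both orientations share a single generator and have the column-oriented object materialize the row-oriented matrix of swapped shape, then transpose it. The classes `ToyRow` and `ToyCol` and the `materialize` helper are hypothetical NumPy stand-ins that only mirror the bookkeeping described above (same parameters, swapped shape, other class); they are not brainevent's implementation.

```python
import numpy as np


def materialize(wlow, whigh, prob, seed, shape):
    # Hypothetical generator for a row-oriented matrix of the given shape.
    rng = np.random.default_rng(seed)
    r = rng.random(shape)
    mask = rng.random(shape) < prob
    return np.where(mask, wlow + (whigh - wlow) * r, 0.0)


class ToyRow:
    def __init__(self, wlow, whigh, prob, seed, shape):
        self.params = (wlow, whigh, prob, seed)
        self.shape = shape

    def todense(self):
        return materialize(*self.params, self.shape)

    def transpose(self):
        # Same parameters, swapped shape, column-oriented class.
        return ToyCol(*self.params, shape=(self.shape[1], self.shape[0]))


class ToyCol:
    def __init__(self, wlow, whigh, prob, seed, shape):
        self.params = (wlow, whigh, prob, seed)
        self.shape = shape

    def todense(self):
        # Defined as the transpose of the row-oriented matrix of swapped shape
        # built from the same parameters.
        return materialize(*self.params, (self.shape[1], self.shape[0])).T

    def transpose(self):
        return ToyRow(*self.params, shape=(self.shape[1], self.shape[0]))


row = ToyRow(0.1, 0.5, 0.2, 42, shape=(30, 5))
col = row.transpose()
```

Transposing twice returns a row-oriented object whose dense form matches the original.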

Examples

>>> from brainevent import JITCUniformR, JITCUniformC
>>>
>>> row_matrix = JITCUniformR((0.1, 0.5, 0.2, 42), shape=(30, 5))
>>> row_matrix.shape
(30, 5)
>>> col_matrix = row_matrix.transpose()
>>> col_matrix.shape
(5, 30)
>>> isinstance(col_matrix, JITCUniformC)
True