JITCUniformC

class brainevent.JITCUniformC(low, high=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)

Just-In-Time Connectivity matrix with Column-oriented representation for uniform weight distributions.

This class implements a column-oriented sparse matrix designed to work with JAX transformations. Conceptually it follows the Compressed Sparse Column (CSC) layout, but instead of storing the non-zero elements explicitly, it regenerates them on demand: the weights of existing connections are drawn from a uniform distribution with lower and upper bounds (wlow, whigh), while the connection probability and seed determine the sparsity pattern.

The class targets neural network connectivity in which weights follow a uniform distribution while the connections themselves are sparse and stochastic. Its column-oriented structure favors column-based operations over row-based ones, which makes it the transpose-oriented counterpart of JITCUniformR.
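A minimal NumPy sketch of the "regenerate instead of store" idea for a column-oriented layout: each column is rebuilt from the seed only when it is needed, and a matrix-vector product accumulates v[j] * W[:, j] column by column. The helpers `jit_column` and `column_matvec`, and the per-column seeding scheme `default_rng([seed, j])`, are hypothetical illustrations; brainevent uses its own PRNG and kernels, so these numbers will not match JITCUniformC.

```python
import numpy as np


def jit_column(low, high, prob, seed, n_rows, j):
    """Hypothetical: regenerate column j from the stream keyed on (seed, j)."""
    rng = np.random.default_rng([seed, j])
    mask = rng.random(n_rows) < prob              # B[:, j] ~ Bernoulli(prob)
    weights = rng.uniform(low, high, n_rows)      # U[:, j] ~ Uniform(low, high)
    return np.where(mask, weights, 0.0)


def column_matvec(low, high, prob, seed, shape, v):
    """Compute W @ v as sum_j v[j] * W[:, j] without ever storing W."""
    n_rows, n_cols = shape
    y = np.zeros(n_rows)
    for j in range(n_cols):
        if v[j] != 0:  # a zero input contributes nothing, so its column is never generated
            y += v[j] * jit_column(low, high, prob, seed, n_rows, j)
    return y


v = np.array([1.0, 0.0, 2.0, 0.0, 0.5])
y = column_matvec(0.1, 0.5, 0.2, 42, (4, 5), v)
print(y)
```

Iterating over columns makes it easy to skip columns whose input entry is zero, which pays off when inputs are sparse, for example binary spike vectors.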

wlow

The lower bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.

Type:

Union[jax.Array, u.Quantity]

whigh

The upper bound of the uniform distribution for non-zero elements. Can be a plain JAX array or a quantity with units.

Type:

Union[jax.Array, u.Quantity]

prob

Connection probability determining the sparsity of the matrix. Values range from 0 (no connections) to 1 (fully connected).

Type:

Union[float, jax.Array]

seed

Random seed controlling the specific pattern of connections. Using the same seed produces identical connectivity patterns.

Type:

Union[int, jax.Array]

shape

Tuple specifying the dimensions of the matrix as (rows, columns).

Type:

MatrixShape

corder

Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).

Type:

bool

dtype

The data type of the matrix elements (property inherited from parent).

Examples

>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCUniformC, JITCUniformR

# Create a uniform matrix with bounds [0.1, 0.5], probability 0.2, and seed 42
>>> uniform_matrix = JITCUniformC((0.1, 0.5, 0.2, 42), shape=(10, 10))
>>> uniform_matrix
JITCUniformC(shape=(10, 10), wlow=0.1, whigh=0.5, prob=0.2, seed=42, corder=False)

# Create a uniform matrix with units
>>> uniform_matrix_mv = JITCUniformC((0.1 * u.mV, 0.5 * u.mV, 0.2, 42), shape=(10, 10))

# Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = uniform_matrix @ vec
>>> # Each element in result is a weighted sum using uniformly distributed weights

# Apply scalar operation (scales both lower and upper bounds)
>>> scaled = uniform_matrix * 2.0
>>> print(scaled.wlow, scaled.whigh)  # 0.2 1.0

# Convert to dense representation
>>> dense_matrix = uniform_matrix.todense()
>>> # dense_matrix has shape (10, 10) with ~20% non-zero elements
>>> # each non-zero element is uniformly distributed between 0.1 and 0.5

# Transpose operation returns a JITCUniformR instance
>>> row_matrix = uniform_matrix.transpose()
>>> isinstance(row_matrix, JITCUniformR)  # True

# Update bounds while preserving connectivity pattern
>>> updated = uniform_matrix.with_data(0.2, 0.8)
>>> print(updated.wlow, updated.whigh)  # 0.2 0.8

# Use with JAX transformations
>>> @jax.jit
... def matrix_vector_product(mat, vec):
...     return mat @ vec
>>> result_jit = matrix_vector_product(uniform_matrix, vec)

# Matrix-matrix multiplication
>>> mat = jax.numpy.ones((10, 5))
>>> result_mat = uniform_matrix @ mat
>>> result_mat.shape  # (10, 5)

# Right matrix multiplication
>>> mat = jax.numpy.ones((5, 10))
>>> result_rmat = mat @ uniform_matrix
>>> result_rmat.shape  # (5, 10)
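The scalar-multiplication and `with_data` examples rely on two properties of the model: for c > 0, multiplying a Uniform(low, high) draw by c gives the same value as a Uniform(c * low, c * high) draw built from the same base random number, and the Bernoulli mask depends only on the seed and prob, not on the bounds. A minimal NumPy sketch, assuming a hypothetical `reference_dense` helper that draws the mask before the weights from one seeded stream (brainevent's actual generator differs):

```python
import numpy as np


def reference_dense(low, high, prob, seed, shape):
    """Hypothetical reference: W = U * B with U and B drawn from one seeded stream."""
    rng = np.random.default_rng(seed)
    b = rng.random(shape) < prob   # connectivity mask: depends only on seed and prob
    r = rng.random(shape)          # base draws in [0, 1)
    u = low + (high - low) * r     # affine map of r onto Uniform(low, high)
    return np.where(b, u, 0.0)


W = reference_dense(0.1, 0.5, 0.2, 42, (10, 10))

# Multiplying the matrix by 2.0 matches doubling both bounds.
scaled = reference_dense(0.2, 1.0, 0.2, 42, (10, 10))
print(np.allclose(2.0 * W, scaled))

# New bounds with the same seed keep exactly the same non-zero pattern.
updated = reference_dense(0.2, 0.8, 0.2, 42, (10, 10))
print(np.array_equal(W != 0, updated != 0))
```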

Notes

The mathematical model for JITCUniformC is:

W[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)

With probability prob, each entry W[i, j] is independently drawn from the continuous uniform distribution on [w_low, w_high]; otherwise it is zero. More precisely:

W[i, j] = U[i, j] * B[i, j]

where U[i, j] ~ Uniform(w_low, w_high) and B[i, j] ~ Bernoulli(prob) are independent random variables, both determined by seed.
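Because U[i, j] and B[i, j] are independent, E[W[i, j]] = E[U] * E[B] = prob * (w_low + w_high) / 2, and the expected number of non-zero entries is prob * rows * columns. A quick NumPy check of both figures, drawing U and B with NumPy's generator rather than brainevent's PRNG (an assumption made only for illustration):

```python
import numpy as np

w_low, w_high, prob, seed = 0.1, 0.5, 0.2, 42
n_rows, n_cols = 1000, 1000

rng = np.random.default_rng(seed)
B = rng.random((n_rows, n_cols)) < prob              # B[i, j] ~ Bernoulli(prob)
U = rng.uniform(w_low, w_high, (n_rows, n_cols))     # U[i, j] ~ Uniform(w_low, w_high)
W = U * B                                            # W[i, j] = U[i, j] * B[i, j]

expected_mean = prob * (w_low + w_high) / 2          # E[U] * E[B]
expected_nnz = prob * n_rows * n_cols                # expected number of connections

print(f"mean: {W.mean():.4f} (expected {expected_mean:.4f})")
print(f"nnz:  {np.count_nonzero(W)} (expected {expected_nnz:.0f})")
```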

The column-oriented representation is the transpose dual of JITCUniformR. Internally, operations on JITCUniformC are delegated to the transposed JITCUniformR form: if W is a JITCUniformC matrix, then W.T is a JITCUniformR matrix and W @ v == (W.T).T @ v == v @ W.T.
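The duality rests on the identity W @ v == v @ W.T, which holds for any matrix. A dense NumPy check, where plain arrays stand in for JITCUniformC and its JITCUniformR transpose (no brainevent calls are involved):

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 6

# Dense stand-in for a column-oriented matrix W of shape (m, n)
mask = rng.random((m, n)) < 0.5
W = np.where(mask, rng.uniform(0.1, 0.5, (m, n)), 0.0)
v = rng.standard_normal(n)

W_T = W.T                  # plays the role of the row-oriented dual, shape (n, m)
y_direct = W @ v           # shape (m,)
y_via_dual = v @ W_T       # the same product expressed through the transpose
print(np.allclose(y_direct, y_via_dual))
```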

Key properties:

  • JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)

  • More memory-efficient than dense matrices for sparse connectivity, since only the bounds, probability, and seed are stored rather than individual weights

  • Well-suited for neural network connectivity matrices with uniformly distributed weights

  • The column-oriented structure makes column-slicing operations more efficient

  • Optimized for matrix-vector operations common in neural simulations

  • The actual matrix elements are never explicitly stored, only generated during operations

  • Using the same seed always produces the same random connectivity pattern and weights
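A back-of-the-envelope comparison makes the memory claim concrete. The layer size, prob = 0.1, float32 weights, and int32 indices below are assumptions chosen for illustration, and the JIT figure counts only the four scalar parameters (ignoring Python object overhead and any backend buffers):

```python
n_rows, n_cols = 10_000, 10_000
prob = 0.1

# Dense float32 storage: every entry is materialized
dense_bytes = n_rows * n_cols * 4

# Explicit CSC: float32 value + int32 row index per non-zero,
# plus n_cols + 1 int32 column pointers
nnz = round(prob * n_rows * n_cols)
csc_bytes = nnz * (4 + 4) + (n_cols + 1) * 4

# JIT storage: only low, high, prob (float32) and seed (int32)
jit_bytes = 4 * 4

print(f"dense: {dense_bytes / 1e6:.1f} MB")   # dense: 400.0 MB
print(f"CSC:   {csc_bytes / 1e6:.1f} MB")     # CSC:   80.0 MB
print(f"JIT:   {jit_bytes} bytes")            # JIT:   16 bytes
```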

todense()

Convert the sparse column-oriented uniform matrix to a dense array.

Generates a full dense representation of the sparse matrix by sampling Uniform(w_low, w_high) values for all connections determined by the probability and seed.

Returns:

A dense matrix with the same shape as the sparse matrix. Its data type matches that of the weight bounds, and if the bounds carry units (are u.Quantity), the returned array has the same units.

Return type:

Array | Quantity

Raises:

None – This method does not raise exceptions under normal use.

See also

JITCUniformR.todense

Row-oriented variant.

jitu

Standalone function to materialize JIT uniform matrices.

Notes

The dense matrix is generated according to:

dense[i, j] = Uniform(w_low, w_high) * Bernoulli(prob)

for each (i, j) pair, where the random draws are determined by seed.
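A NumPy sketch of this materialization, assuming a hypothetical scheme in which column j is generated from its own stream keyed on (seed, j). Under that assumption a single column can be regenerated without touching the rest of the matrix, which is one way a column-oriented layout keeps column access cheap; `jit_column` and `jit_todense` are illustrative helpers, not brainevent's implementation:

```python
import numpy as np


def jit_column(low, high, prob, seed, n_rows, j):
    """Hypothetical per-column generator: column j depends only on (seed, j)."""
    rng = np.random.default_rng([seed, j])
    mask = rng.random(n_rows) < prob                  # Bernoulli(prob) per entry
    return np.where(mask, rng.uniform(low, high, n_rows), 0.0)


def jit_todense(low, high, prob, seed, shape):
    """dense[i, j] = Uniform(low, high) * Bernoulli(prob), built column by column."""
    n_rows, n_cols = shape
    cols = [jit_column(low, high, prob, seed, n_rows, j) for j in range(n_cols)]
    return np.stack(cols, axis=1)


dense = jit_todense(0.1, 0.5, 0.2, 42, (3, 10))
col_7 = jit_column(0.1, 0.5, 0.2, 42, 3, 7)   # regenerated alone, without the other columns
print(dense.shape, np.array_equal(dense[:, 7], col_7))
```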

Examples

>>> from brainevent import JITCUniformC
>>>
>>> mat = JITCUniformC((0.1, 0.5, 0.2, 42), shape=(3, 10))
>>> dense = mat.todense()
>>> dense.shape
(3, 10)

transpose(axes=None)

Transpose the column-oriented matrix into a row-oriented matrix.

Returns a row-oriented matrix (JITCUniformR) with rows and columns swapped, preserving the same weight bounds, probability, and seed values.

Parameters:

axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.

Returns:

A new row-oriented uniform matrix with transposed dimensions.

Return type:

JITCUniformR

Raises:

AssertionError – If axes is not None, since partial axis transposition is not supported.

See also

JITCUniformR.transpose

The inverse operation.

Notes

The transpose satisfies W.T[j, i] = W[i, j]. The corder flag is flipped during transposition to maintain consistency with the underlying PRNG state ordering.
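Why the generation order has to flip can be seen in a small NumPy sketch: if a column-oriented matrix draws column j from a stream keyed on (seed, j), then its transpose is exactly the row-oriented matrix that draws row j from the same stream. Both helpers and the per-line seeding scheme are hypothetical stand-ins for brainevent's PRNG layout, used only to illustrate W.T[j, i] = W[i, j]:

```python
import numpy as np


def col_oriented_dense(low, high, prob, seed, shape):
    """Hypothetical: column j is generated from the stream (seed, j)."""
    n_rows, n_cols = shape
    cols = []
    for j in range(n_cols):
        rng = np.random.default_rng([seed, j])
        mask = rng.random(n_rows) < prob
        cols.append(np.where(mask, rng.uniform(low, high, n_rows), 0.0))
    return np.stack(cols, axis=1)


def row_oriented_dense(low, high, prob, seed, shape):
    """Hypothetical: row i is generated from the stream (seed, i)."""
    n_rows, n_cols = shape
    rows = []
    for i in range(n_rows):
        rng = np.random.default_rng([seed, i])
        mask = rng.random(n_cols) < prob
        rows.append(np.where(mask, rng.uniform(low, high, n_cols), 0.0))
    return np.stack(rows, axis=0)


W = col_oriented_dense(0.1, 0.5, 0.2, 42, (3, 5))     # column-oriented, shape (3, 5)
W_T = row_oriented_dense(0.1, 0.5, 0.2, 42, (5, 3))   # row-oriented, swapped shape
print(np.array_equal(W.T, W_T))
```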

Examples

>>> from brainevent import JITCUniformC, JITCUniformR
>>>
>>> col_matrix = JITCUniformC((0.1, 0.5, 0.2, 42), shape=(3, 5))
>>> col_matrix.shape
(3, 5)
>>> row_matrix = col_matrix.transpose()
>>> row_matrix.shape
(5, 3)
>>> isinstance(row_matrix, JITCUniformR)
True