JITCScalarR

class brainevent.JITCScalarR(weight, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)

Just-In-Time Connectivity (JITC) matrix with a homogeneous scalar weight and a row-oriented representation.

This class represents a row-oriented sparse matrix whose non-zero elements all share the same weight, and it is designed to be used inside JAX transformations. Conceptually it follows the Compressed Sparse Row (CSR) format, but rather than storing column indices and values explicitly, it stores only the uniform weight together with a connection probability and a seed from which the sparse structure is regenerated.

The class is designed for efficient neural network connectivity patterns where weights are homogeneous (identical) but connectivity is sparse and stochastically determined. The row-oriented structure makes row-based operations more efficient than column-based ones.
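To make the contrast with an explicit CSR matrix concrete, the following is a minimal pure-Python sketch, not brainevent's implementation. It assumes a hypothetical per-entry draw `is_connected` (a private `random.Random` seeded by the seed and the position) in place of the library's actual PRNG, materializes the implicit matrix into CSR arrays, and counts how many numbers each form has to store.

```python
import random


def is_connected(seed, i, j, prob):
    # Hypothetical per-entry draw: a private PRNG seeded by (seed, i, j).
    # This is NOT brainevent's generator; it only mimics "deterministic from the seed".
    return random.Random(f"{seed}:{i}:{j}").random() < prob


def to_csr(weight, prob, seed, shape):
    # Materialize the implicit matrix into explicit CSR arrays (indptr, indices, data).
    n_rows, n_cols = shape
    indptr, indices, data = [0], [], []
    for i in range(n_rows):
        for j in range(n_cols):
            if is_connected(seed, i, j, prob):
                indices.append(j)
                data.append(weight)
        indptr.append(len(indices))
    return indptr, indices, data


indptr, indices, data = to_csr(1.5, 0.1, 42, (100, 100))
csr_stored = len(indptr) + len(indices) + len(data)
jitc_stored = 3  # weight, prob, seed
print(f"explicit CSR stores {csr_stored} numbers; the implicit form stores {jitc_stored}")
```

The CSR arrays grow with the number of connections, while the implicit form always needs the same three values; the trade-off is that the structure has to be regenerated whenever it is used.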

weight#

The homogeneous value used for all non-zero elements in the matrix. Can be a plain JAX array or a quantity with units.

Type:

Union[jax.Array, u.Quantity]

prob#

Probability of each potential connection. It controls the sparsity level, with 0.0 meaning no connections and 1.0 meaning all possible connections.

Type:

Union[float, jax.Array]

seed#

Random seed used to generate the sparse structure. Using the same seed produces an identical connectivity pattern.

Type:

Union[int, jax.Array]
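The reproducibility guarantee can be illustrated with a small sketch. Here `connectivity` is a hypothetical helper that regenerates the boolean pattern from the seed using Python's `random.Random`; brainevent uses its own generator, so the concrete patterns will differ from the library's, but the seed-to-pattern relationship is the same idea.

```python
import random


def connectivity(seed, prob, shape):
    # Hypothetical: regenerate the boolean connectivity pattern from the seed alone.
    n_rows, n_cols = shape
    return [
        [random.Random(f"{seed}:{i}:{j}").random() < prob for j in range(n_cols)]
        for i in range(n_rows)
    ]


a = connectivity(42, 0.5, (10, 10))
b = connectivity(42, 0.5, (10, 10))
c = connectivity(7, 0.5, (10, 10))
print(a == b)  # True: same seed, same pattern
print(a == c)  # compare against a pattern built from a different seed
```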

shape#

The shape of the matrix as a tuple (rows, cols).

Type:

MatrixShape

corder#

Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).

Type:

bool

dtype#

The data type of the matrix elements (property inherited from parent).

Examples

>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCScalarR, JITCScalarC

# Create a homogeneous matrix with value 1.5, probability 0.1, and seed 42
>>> homo_matrix = JITCScalarR((1.5, 0.1, 42), shape=(10, 10))
>>> homo_matrix
JITCScalarR(shape=(10, 10), weight=1.5, prob=0.1, seed=42, corder=False)

# Create a matrix with units
>>> weighted_matrix = JITCScalarR((1.5 * u.mV, 0.1, 42), shape=(10, 10))
>>> weighted_matrix
JITCScalarR(shape=(10, 10), weight=1.5 mV, prob=0.1, seed=42, corder=False)

# Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = homo_matrix @ vec
>>> result.shape  # (10,)

# Apply scalar operations
>>> scaled = homo_matrix * 2.0
>>> scaled.weight  # 3.0

# Arithmetic operations maintain the sparse structure
>>> neg_matrix = -homo_matrix
>>> neg_matrix.weight  # -1.5

# Convert to dense representation
>>> dense_matrix = homo_matrix.todense()
>>> dense_matrix.shape  # (10, 10)

# Transpose operation returns a column-oriented matrix
>>> col_matrix = homo_matrix.transpose()
>>> isinstance(col_matrix, JITCScalarC)  # True
>>> col_matrix.shape  # (10, 10)
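The scalar examples above rely on multiplying or negating the matrix changing only the stored weight, while the probability and seed (and hence the sparsity pattern) stay the same. A minimal sketch under that assumption, using a hypothetical `todense` helper and Python's `random.Random` as a stand-in for the library's PRNG, checks that regenerating with the new weight matches scaling the dense matrix elementwise:

```python
import random


def todense(weight, prob, seed, shape):
    # Hypothetical materialization: weight where the seeded draw says "connected", else 0.
    n_rows, n_cols = shape
    return [
        [weight if random.Random(f"{seed}:{i}:{j}").random() < prob else 0.0
         for j in range(n_cols)]
        for i in range(n_rows)
    ]


w, p, seed, shape = 1.5, 0.1, 42, (10, 10)
dense = todense(w, p, seed, shape)

# Scaling or negating only touches the stored weight; prob and seed are unchanged,
# so the regenerated matrix equals the elementwise-scaled original.
scaled = todense(w * 2.0, p, seed, shape)
negated = todense(-w, p, seed, shape)
assert scaled == [[2.0 * x for x in row] for row in dense]
assert negated == [[-x for x in row] for row in dense]
```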

Notes

The mathematical model for this matrix is:

W[i, j] = w * Bernoulli(p)

where w is the scalar weight (self.weight), p is the connection probability (self.prob), and the Bernoulli draw is determined by a deterministic hash from the seed. The expected value of each element is E[W[i, j]] = w * p and the variance is Var[W[i, j]] = w^2 * p * (1 - p).
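The expectation and variance formulas can be checked empirically. The sketch below assumes a hypothetical per-entry draw built on `random.Random` (not brainevent's generator), materializes a 200 x 200 matrix with w = 1.5 and p = 0.1, and compares the sample mean and population variance of its entries with w * p and w^2 * p * (1 - p):

```python
import random
import statistics

w, p, seed = 1.5, 0.1, 42
n_rows, n_cols = 200, 200

# Hypothetical seeded draw per entry (a stand-in for the library's PRNG).
samples = [
    w if random.Random(f"{seed}:{i}:{j}").random() < p else 0.0
    for i in range(n_rows)
    for j in range(n_cols)
]

mean = statistics.fmean(samples)
var = statistics.pvariance(samples)
print(mean, w * p)                  # empirical mean vs. E[W[i, j]] = w * p
print(var, w**2 * p * (1 - p))      # empirical variance vs. w^2 * p * (1 - p)
```

Because the entries are independent Bernoulli draws scaled by w, the empirical values approach the analytic ones as the matrix grows.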

For a matrix-vector product y = W @ x:

y[i] = sum_{j in C(i)} w * x[j]

where C(i) is the deterministic random connection set for row i, with |C(i)| ~ Binomial(n_cols, p).
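A minimal sketch of this row-wise product, assuming a hypothetical `connection_set` helper that regenerates C(i) from the seed with Python's `random.Random` (brainevent's kernels use their own generator and do not necessarily enumerate C(i) this way):

```python
import random


def connection_set(seed, i, n_cols, prob):
    # Hypothetical C(i): the columns j whose seeded draw falls below prob.
    return [j for j in range(n_cols)
            if random.Random(f"{seed}:{i}:{j}").random() < prob]


def matvec(weight, prob, seed, shape, x):
    # y[i] = sum_{j in C(i)} w * x[j] = w * sum_{j in C(i)} x[j]
    n_rows, n_cols = shape
    return [weight * sum(x[j] for j in connection_set(seed, i, n_cols, prob))
            for i in range(n_rows)]


w, p, seed, shape = 1.5, 0.1, 42, (10, 10)
x = [1.0] * shape[1]
y = matvec(w, p, seed, shape, x)

# With x = ones, y[i] is just w * |C(i)|.
sizes = [len(connection_set(seed, i, shape[1], p)) for i in range(shape[0])]
assert y == [w * k for k in sizes]
```

Factoring w out of the sum means each row needs a single multiplication after the selected entries of x are summed, which is one reason a homogeneous weight is cheap to apply.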

Key properties:

  • JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)

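  • Described by the weight, probability, and seed rather than by n_rows * n_cols stored entries, which makes it far more memory-efficient than a dense matrix for sparse connectivity patterns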

  • Well-suited for neural network connectivity matrices with uniform weights

  • Optimized for matrix-vector operations common in neural simulations

  • The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed

  • When used with units (e.g., u.mV), units are preserved through operations

See also

JITCScalarC

Column-oriented counterpart of this class.

JITCScalarMatrix

Base class providing shared functionality.

todense()

Convert the sparse scalar-weight matrix to dense format.

Generates a full dense representation of the sparse matrix by materializing all entries W[i, j] = w * Bernoulli(p) determined by the probability and seed.

Parameters:

None

Returns:

A dense matrix with the same shape as the sparse matrix. The data type will match the weight’s data type, and if the weight has units (is a u.Quantity), the returned array will have the same units.

Return type:

Array | Quantity

Raises:

None

See also

jits

The underlying function that materializes the matrix.

Notes

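The dense matrix is generated by iterating over all (i, j) positions and placing the scalar weight w wherever the deterministic PRNG indicates a connection. Writing hash(seed, i, j) for the pseudo-random value in [0, 1) that the PRNG derives from the seed and the position: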

dense[i, j] = w  if  hash(seed, i, j) < p  else  0
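The rule above translates directly into a nested loop. In this sketch `uniform` is a hypothetical stand-in for hash(seed, i, j) that returns a value in [0, 1) from a seeded `random.Random`; it is not the function brainevent uses, so the resulting pattern differs from the library's, but the structure of the computation is the same:

```python
import random


def uniform(seed, i, j):
    # Hypothetical stand-in for hash(seed, i, j), mapped to [0, 1).
    return random.Random(f"{seed}:{i}:{j}").random()


def todense(weight, prob, seed, shape):
    # dense[i][j] = w if uniform(seed, i, j) < p else 0
    n_rows, n_cols = shape
    return [[weight if uniform(seed, i, j) < prob else 0.0 for j in range(n_cols)]
            for i in range(n_rows)]


dense = todense(1.5, 0.5, 42, (10, 4))
print(len(dense), len(dense[0]))  # 10 4
```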

Examples

>>> import brainunit as u
>>> from brainevent import JITCScalarR
>>> sparse_matrix = JITCScalarR((1.5 * u.mV, 0.5, 42), shape=(10, 4))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape  # (10, 4)

transpose(axes=None)

Transposes the row-oriented matrix into a column-oriented matrix.

This method returns a column-oriented matrix (JITCScalarC) with rows and columns swapped, preserving the same weight, probability, and seed values. The transpose operation effectively converts between row-oriented and column-oriented sparse matrix formats.
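Because the weight, probability, and seed are reused, the column-oriented result equals the exact transpose only if it reproduces the same random draws with the roles of rows and columns swapped. The following sketch assumes exactly that; `draw`, `dense_row_oriented`, and `dense_col_oriented` are hypothetical helpers, and `random.Random` stands in for the library's PRNG:

```python
import random


def draw(seed, i, j):
    # Hypothetical seeded draw for entry (i, j) of the ORIGINAL row-oriented matrix.
    return random.Random(f"{seed}:{i}:{j}").random()


def dense_row_oriented(w, p, seed, shape):
    n_rows, n_cols = shape
    return [[w if draw(seed, i, j) < p else 0.0 for j in range(n_cols)]
            for i in range(n_rows)]


def dense_col_oriented(w, p, seed, shape):
    # Column-oriented view of shape (n_cols, n_rows): entry (r, c) reuses draw(seed, c, r),
    # i.e. the same random numbers with rows and columns swapped.
    n_r, n_c = shape
    return [[w if draw(seed, c, r) < p else 0.0 for c in range(n_c)]
            for r in range(n_r)]


w, p, seed = 1.5, 0.5, 42
W = dense_row_oriented(w, p, seed, (30, 5))
WT = dense_col_oriented(w, p, seed, (5, 30))
assert WT == [list(col) for col in zip(*W)]  # WT is exactly the transpose of W
```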

Parameters:

axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.

Returns:

A new column-oriented homogeneous matrix with transposed dimensions.

Return type:

JITCScalarC

Raises:

AssertionError – If axes is not None, since partial axis transposition is not supported.

Examples

>>> from brainevent import JITCScalarR, JITCScalarC
>>>
>>> # Create a row-oriented matrix
>>> row_matrix = JITCScalarR((1.5, 0.5, 42), shape=(30, 5))
>>> print(row_matrix.shape)  # (30, 5)
>>>
>>> # Transpose to column-oriented matrix
>>> col_matrix = row_matrix.transpose()
>>> print(col_matrix.shape)  # (5, 30)
>>> isinstance(col_matrix, JITCScalarC)  # True