JITCNormalR

class brainevent.JITCNormalR(loc, scale=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)

Just-In-Time Connectivity (JITC) matrix with normally distributed weights, in row-oriented representation.

This class represents a row-oriented sparse matrix whose non-zero elements are drawn from a normal distribution and which is compatible with JAX transformations. It is conceptually similar to the Compressed Sparse Row (CSR) format, but instead of storing explicit indices and values it stores the location (mean) and scale (standard deviation) of the normal distribution, together with a connection probability and a random seed that determine the sparse structure.

It is intended for neural network connectivity in which weights are normally distributed and connections are sparse and stochastic.

wloc

The location (mean) parameter of the normal distribution for non-zero elements.

Type:

Union[jax.Array, u.Quantity]

wscale

The scale (standard deviation) parameter of the normal distribution for non-zero elements.

Type:

Union[jax.Array, u.Quantity]

prob

Probability that each potential connection (i, j) is present, i.e. the parameter p of the Bernoulli connectivity mask.

Type:

Union[float, jax.Array]

seed

Random seed that determines both the sparse structure and the sampled weight values.

Type:

Union[int, jax.Array]

shape

The shape of the matrix as a tuple (rows, cols).

Type:

MatrixShape

corder

Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).

Type:

bool

dtype

The data type of the matrix elements (property inherited from parent).

Examples

>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCNormalR

>>> # Create a normal distribution matrix with mean 1.5, std 0.2, probability 0.1, and seed 42
>>> normal_matrix = JITCNormalR((1.5, 0.2, 0.1, 42), shape=(10, 10))
>>> normal_matrix
JITCNormalR(shape=(10, 10), wloc=1.5, wscale=0.2, prob=0.1, seed=42, corder=False)

>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = normal_matrix @ vec

>>> # Apply scalar operation
>>> scaled = normal_matrix * 2.0

>>> # Convert to dense representation
>>> dense_matrix = normal_matrix.todense()

>>> # Transpose operation returns a JITCNormalC instance
>>> col_matrix = normal_matrix.transpose()

Notes

The mathematical model for this matrix is:

W[i, j] = Normal(mu, sigma) * Bernoulli(p)

where mu is wloc, sigma is wscale, p is prob, and the Bernoulli and Normal draws are both determined by the seed. For a matrix-vector product y = W @ x:

y[i] = sum_{j in C(i)} N_ij * x[j]

where C(i) is the set of columns connected to row i (pseudo-random, but fully determined by the seed) and N_ij ~ Normal(mu, sigma) is the weight of connection (i, j).
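To make the model concrete, the sketch below builds a small dense reference matrix directly from the formula and checks that W @ x equals the sum over each row's connection set C(i). The helper `reference_dense` is hypothetical and draws from jax.random directly; this is not brainevent's internal PRNG scheme, so the sampled values will not match JITCNormalR.todense().

```python
import jax
import jax.numpy as jnp


def reference_dense(key, shape, mu, sigma, p):
    """Dense reference for W[i, j] = Normal(mu, sigma) * Bernoulli(p).

    Hypothetical helper: uses jax.random directly, NOT brainevent's generator,
    so values differ from JITCNormalR.todense().
    """
    k_norm, k_mask = jax.random.split(key)
    weights = mu + sigma * jax.random.normal(k_norm, shape)  # N_ij ~ Normal(mu, sigma)
    mask = jax.random.bernoulli(k_mask, p, shape)            # Bernoulli(p) connectivity
    return jnp.where(mask, weights, 0.0), mask, weights


key = jax.random.PRNGKey(42)
W, mask, N = reference_dense(key, (10, 10), mu=1.5, sigma=0.2, p=0.1)
x = jnp.arange(10, dtype=jnp.float32)

# Dense product: y[i] = sum_j W[i, j] * x[j]
y = W @ x

# Same product written as a sum over the connection set C(i) of each row.
y_sets = jnp.array([
    sum(float(N[i, j]) * float(x[j]) for j in jnp.flatnonzero(mask[i]).tolist())
    for i in range(10)
])
```

Both forms agree because every unconnected entry of W is exactly zero, so only the columns in C(i) contribute to y[i].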

Key properties:

  • JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)

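  • More memory-efficient than dense matrices for sparse connectivity: the connectivity is described by the distribution parameters, probability and seed rather than by rows × cols stored entries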

  • Well-suited for neural network connectivity matrices with normally distributed weights

  • Optimized for matrix-vector operations common in neural simulations

  • The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed
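The last property, materializing the structure only when needed, can be illustrated with a matrix-free matrix-vector product. This is a minimal sketch assuming a hypothetical per-row scheme in which row i is regenerated from jax.random.fold_in(key, i); `row_weights` and `matvec_on_the_fly` are illustrative names, not brainevent functions, and the library's own kernels use a different generator.

```python
import jax
import jax.numpy as jnp


def row_weights(seed, i, n_cols, mu, sigma, p):
    """Regenerate row i of W on demand from (seed, i) alone.

    Hypothetical scheme: one PRNG key per row via jax.random.fold_in.
    Not brainevent's kernel; values will not match the library.
    """
    row_key = jax.random.fold_in(jax.random.PRNGKey(seed), i)
    k_norm, k_mask = jax.random.split(row_key)
    w = mu + sigma * jax.random.normal(k_norm, (n_cols,))
    keep = jax.random.bernoulli(k_mask, p, (n_cols,))
    return jnp.where(keep, w, 0.0)


def matvec_on_the_fly(seed, shape, mu, sigma, p, x):
    """Compute y = W @ x without storing W: each row is rebuilt, used, and discarded."""
    n_rows, n_cols = shape

    def one_row(i):
        return jnp.dot(row_weights(seed, i, n_cols, mu, sigma, p), x)

    # lax.map processes the rows sequentially instead of building the full matrix.
    return jax.lax.map(one_row, jnp.arange(n_rows))


x = jnp.ones(200)
y = matvec_on_the_fly(seed=42, shape=(50, 200), mu=1.5, sigma=0.2, p=0.1, x=x)
```

The only persistent state is (seed, shape, mu, sigma, p), which is why regenerating a row from its key can replace storing it.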

todense()

Convert the sparse normal-weight matrix to dense format.

Generates a full dense representation where each non-zero entry is drawn from Normal(wloc, wscale) at positions determined by the probability and seed.

Parameters:

None

Returns:

A dense matrix with the same shape as the sparse matrix. The data type will match the weight’s data type, and if the weight has units (is a u.Quantity), the returned array will have the same units.

Return type:

Array | Quantity

Raises:

None

See also

jitn

The underlying function that materializes the matrix.

Notes

The dense matrix is generated element-wise as:

dense[i, j] = Normal(mu, sigma)  if  hash(seed, i, j) < p  else  0

where mu = wloc, sigma = wscale, and p = prob.
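Under this element-wise formula each entry depends only on (seed, i, j), so any sub-block can be generated without the rest of the matrix. The sketch below assumes a hypothetical counter-based scheme standing in for hash(seed, i, j): i and j are folded into the seed key with jax.random.fold_in, and a Uniform(0, 1) draw from that key is compared against p. The helpers `entry` and `dense_block` are illustrative only; this is not the generator used by jitn.

```python
import jax
import jax.numpy as jnp


def entry(seed, i, j, mu, sigma, p):
    """dense[i, j] = Normal(mu, sigma) if u(seed, i, j) < p else 0.

    Hypothetical stand-in for hash(seed, i, j): a key derived by folding
    i and j into the seed, then a uniform draw u from that key.
    """
    key = jax.random.fold_in(jax.random.fold_in(jax.random.PRNGKey(seed), i), j)
    k_u, k_n = jax.random.split(key)
    u = jax.random.uniform(k_u)
    w = mu + sigma * jax.random.normal(k_n)
    return jnp.where(u < p, w, 0.0)


def dense_block(seed, rows, cols, mu, sigma, p):
    """Materialize W[rows][:, cols] using only (seed, i, j) for each entry."""
    per_col = lambda i: jax.vmap(lambda j: entry(seed, i, j, mu, sigma, p))(cols)
    return jax.vmap(per_col)(rows)


full = dense_block(42, jnp.arange(10), jnp.arange(4), mu=1.5, sigma=0.2, p=0.5)
# A sub-block generated on its own agrees with the corresponding slice of the full matrix.
block = dense_block(42, jnp.arange(3, 6), jnp.arange(1, 3), mu=1.5, sigma=0.2, p=0.5)
```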

Examples

>>> import brainunit as u
>>> from brainevent import JITCNormalR
>>> sparse_matrix = JITCNormalR((1.5, 0.2, 0.5, 42), shape=(10, 4))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape
(10, 4)

transpose(axes=None)

Transpose the row-oriented matrix into a column-oriented matrix.

Returns a column-oriented matrix (JITCNormalC) with rows and columns swapped, preserving the same weight parameters, probability, and seed.

Parameters:

axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.

Returns:

A new column-oriented normal-weight matrix with transposed dimensions.

Return type:

JITCNormalC

Raises:

AssertionError – If axes is not None, since partial axis transposition is not supported.

See also

JITCNormalC.transpose

The inverse operation.

Notes

The transpose swaps the shape and inverts the corder flag so that the same PRNG sequence is used, ensuring mat.transpose().todense() equals mat.todense().T.
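The role of the corder flip can be sketched with a hypothetical generator that draws one PRNG stream per row or per column. The helper `materialize` and its `stream_per_row` argument are illustrative stand-ins for corder, not brainevent's kernel, and which corder value corresponds to which stream axis is an assumption here. Swapping the shape and flipping the stream axis reuses exactly the same streams, so the result equals the transpose.

```python
import jax
import jax.numpy as jnp


def materialize(shape, seed, mu, sigma, p, stream_per_row):
    """Hypothetical generator with one PRNG stream per row or per column.

    stream_per_row=True: stream k fills row k (length shape[1]).
    stream_per_row=False: stream k fills column k (length shape[0]).
    Illustrates a flag like corder; NOT brainevent's actual kernel.
    """
    n_rows, n_cols = shape
    n_streams, length = (n_rows, n_cols) if stream_per_row else (n_cols, n_rows)
    base = jax.random.PRNGKey(seed)

    def stream(k):
        k_n, k_m = jax.random.split(jax.random.fold_in(base, k))
        w = mu + sigma * jax.random.normal(k_n, (length,))
        keep = jax.random.bernoulli(k_m, p, (length,))
        return jnp.where(keep, w, 0.0)

    lines = jax.vmap(stream)(jnp.arange(n_streams))  # shape (n_streams, length)
    return lines if stream_per_row else lines.T


# Row-oriented matrix of shape (30, 5): one stream per row.
W = materialize((30, 5), 42, 1.5, 0.2, 0.5, stream_per_row=True)
# "Transpose": swap the shape AND flip the stream axis, keeping the seed.
WT = materialize((5, 30), 42, 1.5, 0.2, 0.5, stream_per_row=False)
# Swapping the shape alone draws 5 streams of length 30 instead.
W_bad = materialize((5, 30), 42, 1.5, 0.2, 0.5, stream_per_row=True)
```

W_bad is an unrelated matrix rather than W.T, which is why the flag has to be inverted together with the shape.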

Examples

>>> from brainevent import JITCNormalR, JITCNormalC
>>> row_matrix = JITCNormalR((1.5, 0.2, 0.5, 42), shape=(30, 5))
>>> col_matrix = row_matrix.transpose()
>>> col_matrix.shape
(5, 30)
>>> isinstance(col_matrix, JITCNormalC)
True