JITCNormalC

class brainevent.JITCNormalC(loc, scale=None, prob=None, seed=None, *, shape, corder=False, backend=None, buffers=None)

Just-In-Time Connectivity Normal distribution matrix with Column-oriented representation.

This class represents a column-oriented sparse matrix whose non-zero elements follow a normal distribution, designed for use with JAX transformations. Conceptually it follows the Compressed Sparse Column (CSC) format, but it stores only the location (mean) and scale (standard deviation) of the normal distribution, together with the connection probability and the seed that determine the sparse structure.

The column-oriented structure makes column-based operations more efficient than row-based ones. This class is the transpose-oriented counterpart to JITCNormalR.

wloc

The location (mean) parameter of the normal distribution for non-zero elements.

Type:

Union[jax.Array, u.Quantity]

wscale

The scale (standard deviation) parameter of the normal distribution for non-zero elements.

Type:

Union[jax.Array, u.Quantity]

prob

Connection probability: the probability that each potential connection (matrix entry) is non-zero.

Type:

Union[float, jax.Array]

seed

Random seed used for initialization of the sparse structure.

Type:

Union[int, jax.Array]

shape

The shape of the matrix as a tuple (rows, cols).

Type:

MatrixShape

corder

Flag indicating the memory layout order of the matrix. False (default) for Fortran-order (column-major), True for C-order (row-major).

Type:

bool

dtype

The data type of the matrix elements (property inherited from parent).

Examples

>>> import jax
>>> import brainunit as u
>>> from brainevent import JITCNormalC

>>> # Create a normal distribution matrix with mean 1.5, std 0.2, probability 0.1, and seed 42
>>> normal_matrix = JITCNormalC((1.5, 0.2, 0.1, 42), shape=(10, 10))
>>> normal_matrix
JITCNormalC(shape=(10, 10), wloc=1.5, wscale=0.2, prob=0.1, seed=42, corder=False)

>>> # Perform matrix-vector multiplication
>>> vec = jax.numpy.ones(10)
>>> result = normal_matrix @ vec

>>> # Apply scalar operation
>>> scaled = normal_matrix * 2.0

>>> # Convert to dense representation
>>> dense_matrix = normal_matrix.todense()

>>> # Transpose operation returns a JITCNormalR instance
>>> row_matrix = normal_matrix.transpose()

Notes

The mathematical model is the same as JITCNormalR:

W[i, j] = Normal(mu, sigma) * Bernoulli(p)

where mu = wloc, sigma = wscale, and p = prob. The column-oriented representation means that JITCNormalC is conceptually the transpose of a JITCNormalR matrix with swapped dimensions.
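
As a rough illustration of this model, the following NumPy sketch builds a dense reference matrix by drawing a Bernoulli mask and a normal weight for every entry. It is not how brainevent generates the matrix: the name `dense_reference` is hypothetical, and `numpy.random.default_rng` stands in for the library's own random number generation, so the values will not match `todense()` for the same seed.

```python
import numpy as np


def dense_reference(loc, scale, prob, seed, shape):
    """Dense sample of W[i, j] = Normal(loc, scale) * Bernoulli(prob).

    Hypothetical helper for illustration only; it does not reproduce
    the random stream used by brainevent.
    """
    rng = np.random.default_rng(seed)
    # Bernoulli(prob): True where a connection exists.
    mask = rng.random(shape) < prob
    # Normal(loc, scale): a candidate weight for every position.
    weights = rng.normal(loc, scale, size=shape)
    # Keep the weight where connected, zero elsewhere.
    return np.where(mask, weights, 0.0)


W = dense_reference(1.5, 0.2, 0.1, seed=42, shape=(200, 200))
density = np.count_nonzero(W) / W.size  # fraction of connected entries
mean_weight = W[W != 0].mean()          # average of the non-zero weights
```

For a matrix of this size, `density` should land close to `prob` and `mean_weight` close to `loc`.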

Key properties:

  • JAX PyTree compatible for use with JAX transformations (jit, grad, vmap)

  • More memory-efficient than dense matrices for sparse connectivity patterns

  • More efficient than JITCNormalR for column-based operations

  • Well-suited for neural network connectivity matrices with normally distributed weights

  • The matrix is implicitly constructed based on the probability and seed; the actual sparse structure is materialized only when needed
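
The last property above can be sketched in plain NumPy: each column is regenerated from the seed whenever it is needed, so a matrix-vector product never stores the full matrix. This is an illustration under assumptions, not the library's implementation. The names `column` and `matvec` are hypothetical, and seeding a generator with `[seed, j]` is a stand-in for whatever per-column generation brainevent performs.

```python
import numpy as np


def column(j, loc, scale, prob, seed, n_rows):
    """Regenerate column j on demand from (seed, j). Hypothetical helper."""
    rng = np.random.default_rng([seed, j])  # assumed per-column stream
    mask = rng.random(n_rows) < prob
    weights = rng.normal(loc, scale, size=n_rows)
    return np.where(mask, weights, 0.0)


def matvec(vec, loc, scale, prob, seed, shape):
    """Compute W @ vec one column at a time, without storing W."""
    n_rows, n_cols = shape
    out = np.zeros(n_rows)
    for j in range(n_cols):
        # Only one column is held in memory at any time.
        out += column(j, loc, scale, prob, seed, n_rows) * vec[j]
    return out


params = dict(loc=1.5, scale=0.2, prob=0.1, seed=42)
shape = (10, 10)
vec = np.ones(shape[1])
result = matvec(vec, shape=shape, **params)

# Materializing the same columns yields the same product.
dense = np.stack(
    [column(j, n_rows=shape[0], **params) for j in range(shape[1])],
    axis=1,
)
```

Because every column depends only on the seed and its index, `result` agrees with `dense @ vec` up to floating-point rounding.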

todense()

Convert the sparse column-oriented normal-weight matrix to dense format.

Generates a full dense representation where each non-zero entry is drawn from Normal(wloc, wscale) at positions determined by the probability and seed. The generated dense matrix always has self.shape.

Parameters:

None

Returns:

A dense matrix with the same shape as the sparse matrix. The data type will match the weight’s data type, and if the weight has units (is a u.Quantity), the returned array will have the same units.

Return type:

Array | Quantity

Raises:

None

See also

jitn

The underlying function that materializes the matrix.

Notes

The dense matrix is generated element-wise as:

dense[i, j] = Normal(mu, sigma)  if  hash(seed, i, j) < p  else  0
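
A minimal sketch of this rule, assuming a random generator keyed on (seed, i, j) as a stand-in for the hash. The helpers `entry` and `todense_sketch` are hypothetical and are not the library's jitn function; the point is only that each entry is a deterministic function of the seed and its indices.

```python
import numpy as np


def entry(i, j, loc, scale, prob, seed):
    """Value of dense[i, j], computed from (seed, i, j) alone."""
    # Stand-in for hash(seed, i, j): a generator keyed on the triple.
    rng = np.random.default_rng([seed, i, j])
    if rng.random() < prob:   # connection exists
        return rng.normal(loc, scale)
    return 0.0                # no connection


def todense_sketch(loc, scale, prob, seed, shape):
    """Materialize every entry of the matrix. Hypothetical helper."""
    rows, cols = shape
    return np.array(
        [[entry(i, j, loc, scale, prob, seed) for j in range(cols)]
         for i in range(rows)]
    )


dense = todense_sketch(1.5, 0.2, 0.5, seed=42, shape=(3, 10))
```

Calling `todense_sketch` twice with the same arguments returns identical arrays, which is the property that lets the matrix be regenerated instead of stored.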

Examples

>>> from brainevent import JITCNormalC
>>> sparse_matrix = JITCNormalC((1.5, 0.2, 0.5, 42), shape=(3, 10))
>>> dense_matrix = sparse_matrix.todense()
>>> dense_matrix.shape  # (3, 10)

transpose(axes=None)

Transposes the column-oriented matrix into a row-oriented matrix.

This method returns a row-oriented matrix (JITCNormalR) with rows and columns swapped, preserving the same weight parameters (wloc, wscale), probability, and seed values. The transpose operation effectively converts between column-oriented and row-oriented sparse matrix formats.

Parameters:

axes (None) – Not supported. This parameter exists for compatibility with the NumPy API but only None is accepted.

Returns:

A new row-oriented normal distribution matrix with transposed dimensions.

Return type:

JITCNormalR

Raises:

AssertionError – If axes is not None, since partial axis transposition is not supported.

See also

JITCNormalR

Row-oriented counterpart.

Notes

The transpose preserves the mathematical identity:

JITCNormalC(shape=(m, n)).transpose().todense() == JITCNormalC(shape=(m, n)).todense().T
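
One way to see why this identity can hold without storing any entries: if every entry is a function of the seed and its indices, the transposed matrix only needs to swap the index roles while keeping the same parameters and seed. The NumPy sketch below illustrates that reasoning. The helper names are hypothetical, and the generator keyed on (seed, i, j) is an assumed stand-in for the library's internal hashing.

```python
import numpy as np


def entry(i, j, loc, scale, prob, seed):
    """Entry (i, j) of the original matrix, keyed on (seed, i, j)."""
    rng = np.random.default_rng([seed, i, j])  # assumed stand-in for the hash
    if rng.random() < prob:
        return rng.normal(loc, scale)
    return 0.0


def dense_original(params, shape):
    """Dense form of the (m, n) matrix."""
    m, n = shape
    return np.array([[entry(i, j, **params) for j in range(n)] for i in range(m)])


def dense_transposed(params, shape):
    """Dense form of the transposed matrix, with shape (n, m).

    Entry (j, i) of the result reuses the key of entry (i, j) of the
    original, so the parameters and seed stay unchanged.
    """
    m, n = shape
    return np.array([[entry(i, j, **params) for i in range(m)] for j in range(n)])


params = dict(loc=1.5, scale=0.2, prob=0.5, seed=42)
original = dense_original(params, shape=(3, 5))
transposed = dense_transposed(params, shape=(3, 5))
identity_holds = np.array_equal(transposed, original.T)  # True by construction
```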

Examples

>>> from brainevent import JITCNormalC, JITCNormalR
>>>
>>> # Create a column-oriented matrix
>>> col_matrix = JITCNormalC((1.5, 0.2, 0.5, 42), shape=(3, 5))
>>> print(col_matrix.shape)  # (3, 5)
>>>
>>> # Transpose to row-oriented matrix
>>> row_matrix = col_matrix.transpose()
>>> print(row_matrix.shape)  # (5, 3)
>>> isinstance(row_matrix, JITCNormalR)  # True