brainevent.jits

brainevent.jits = <NameScope(brainevent.jits)>

Generate a homogeneous sparse random matrix on-the-fly.

This function creates a sparse random matrix in which every non-zero entry has the same (homogeneous) weight. Rather than storing the full matrix in memory, it represents the matrix in a form that works with JAX transformations, including jit(), vmap(), grad(), and pmap().

Parameters:
  • weight (Array | ndarray | Quantity | Number) – The value to use for all non-zero entries in the matrix. Can be a scalar, an Array, ndarray, or a Quantity with units.

  • prob (float) – Connection probability for the matrix (between 0 and 1). Determines the sparsity of the generated matrix.

  • seed (int) – Random seed for reproducible matrix generation.

  • shape (Tuple[int, int]) – The shape of the matrix as a tuple (num_rows, num_cols).

  • transpose (bool) – If True, return the transposed random matrix.

  • corder (bool) – Controls the dimension along which sampling is parallelized. If True, the sampling index runs along the column dimension; if False, it runs along the row dimension.

  • backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.

Returns:

The generated sparse random matrix. If transpose is False, the output shape is shape. If transpose is True, the matrix is transposed and the output shape is (shape[1], shape[0]).

Return type:

Array | ndarray | Quantity | Number

Raises:

ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
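The three conditions above can be sketched as a standalone check. This is a minimal illustration, not the library's own validation code: check_prob is a hypothetical helper name, and the error messages are invented for the example.

```python
import math

import numpy as np


def check_prob(prob):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # Reject anything that is not a scalar (lists, arrays with ndim > 0).
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    value = float(prob)
    # Reject NaN and +/- infinity.
    if not math.isfinite(value):
        raise ValueError(f"prob must be finite, got {value}")
    # Reject values outside the closed interval [0, 1].
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"prob must lie in [0, 1], got {value}")
    return value
```

Validating prob up front like this keeps a bad value from silently producing an empty or fully dense matrix.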

See also

jitsmv

Matrix-vector product with JIT-generated scalar matrix.

jitsmm

Matrix-matrix product with JIT-generated scalar matrix.

Notes

The matrix W is defined element-wise as:

W[i, j] = w * B[i, j]

where w is the scalar weight and B[i, j] ~ Bernoulli(prob) is a binary mask fully determined by the seed. The mask is generated using a deterministic PRNG that, for a given (seed, i, j) triple, always produces the same outcome.
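For intuition, the definition above can be written as a dense NumPy reference. This is only a sketch: dense_reference is a hypothetical name, it materializes the full matrix (which jits avoids), and it uses NumPy's default generator rather than the PRNG used by brainevent, so its mask will not match the one jits produces for the same seed.

```python
import numpy as np


def dense_reference(weight, prob, seed, shape):
    """Dense illustration of W[i, j] = w * B[i, j] (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # B[i, j] ~ Bernoulli(prob): True with probability prob.
    mask = rng.random(shape) < prob
    # Every non-zero entry carries the same weight w.
    return weight * mask


# The same seed always reproduces the same matrix.
a = dense_reference(0.01, prob=0.1, seed=42, shape=(1000, 500))
b = dense_reference(0.01, prob=0.1, seed=42, shape=(1000, 500))
same = np.array_equal(a, b)
```

A dense reference like this is mainly useful for reasoning about the statistics of W on small shapes; it defeats the memory savings that on-the-fly generation is meant to provide.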

The expected number of non-zeros is prob * m * n where (m, n) is the matrix shape. The connection length parameter clen = 2 / prob controls the average stride between successive non-zero entries during the sampling loop.
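One way to read clen = 2 / prob is as the upper bound of a stride distribution: if strides are drawn uniformly from the integers 1 to clen - 1, the mean stride is clen / 2 = 1 / prob, so roughly a fraction prob of positions are hit. The sketch below illustrates that reading for a single row. The uniform stride distribution, the function name stride_sample_row, and the use of NumPy's generator are all assumptions made for illustration; this page does not specify the actual sampling loop.

```python
import numpy as np


def stride_sample_row(n_cols, prob, rng):
    """Pick non-zero column indices by jumping random strides (assumed scheme)."""
    clen = int(round(2.0 / prob))  # connection length, clen = 2 / prob
    cols = []
    j = int(rng.integers(0, clen))  # starting offset in [0, clen)
    while j < n_cols:
        cols.append(j)
        # Stride uniform on {1, ..., clen - 1}; its mean is clen / 2 = 1 / prob.
        j += int(rng.integers(1, clen))
    return np.asarray(cols, dtype=np.int64)


rng = np.random.default_rng(0)
n_cols, prob = 100_000, 0.1
cols = stride_sample_row(n_cols, prob, rng)
density = cols.size / n_cols  # should land close to prob
```

Under this scheme the cost of generating a row scales with the number of non-zeros rather than with the number of columns, which is the usual motivation for stride-based sampling over testing every position.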

When using corder=True (default), the matrix generated with transpose=True will generally be different from the transpose of the matrix generated with transpose=False. Set corder=False if exact correspondence between these two cases is required.

Examples

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> from brainevent._jit_scalar.float import jits
>>> # Generate a 1000x500 sparse matrix with 10% connection probability
>>> matrix = jits(0.01, prob=0.1, seed=42, shape=(1000, 500))
>>> matrix.shape
(1000, 500)
>>> # With units
>>> matrix_u = jits(0.01 * u.mA, prob=0.1, seed=42, shape=(1000, 500))