brainevent.jitsmm


brainevent.jitsmm = <NameScope(brainevent.jitsmm)>

Compute \(y = M @ B\) or \(y = M^T @ B\), where \(M\) is a random matrix generated just in time, with the same scalar weight at every connected position.

In this operation, \(M\) is a random matrix in which each entry is connected with probability prob, and every connection carries the same scalar weight. When transpose=True, the operation is \(y = M^T @ B\).

Note

The \(M^T\) generated just in time with transpose=True is not, in general, the transpose of the \(M\) generated with transpose=False. If you need both calls to use the same underlying \(M\), set corder=True, at the cost of lower speed than corder=False.

Parameters:
  • weight (Array | ndarray | Quantity | Number) – The scalar weight assigned to every connection in the random matrix.

  • prob (float) – The connection probability.

  • B (Array | ndarray | Quantity | Number) – The input matrix on the right-hand side of the product.

  • seed (int | None) – The random number generation seed.

  • shape (Tuple[int, int]) – The shape of the random matrix \(M\), as (rows, columns).

  • transpose (bool) – Transpose the random matrix or not.

  • corder (bool) – Controls whether the parallelization order is oriented along the matrix columns. True: sample indices along the column dimension. False: sample indices along the row dimension.

  • backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.

Returns:

out – The output of \(y = M @ B\) if transpose=False, or the output of \(y = M^T @ B\) if transpose=True.

Return type:

Array | ndarray | Quantity | Number

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • AssertionError – If the matrix shape and input matrix B dimensions are incompatible.
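
The two error conditions above can be sketched in plain Python. The helper below, check_args, is hypothetical and is not part of brainevent; it only mirrors the documented rules (prob must be a finite scalar in [0, 1], and the first dimension of B must match the contracted dimension of \(M\)). The library's own checks may differ in detail.

```python
import math


def check_args(prob, shape, B_shape, transpose=False):
    """Hypothetical helper mirroring the documented error conditions.

    Returns the expected output shape when the arguments are consistent.
    """
    # prob must be a finite scalar in [0, 1].
    if (
        not isinstance(prob, (int, float))
        or not math.isfinite(prob)
        or not (0.0 <= prob <= 1.0)
    ):
        raise ValueError(f"prob must be a finite scalar in [0, 1], got {prob!r}")

    # M @ B contracts over the columns of M; M.T @ B contracts over its rows.
    contracted = shape[0] if transpose else shape[1]
    assert B_shape[0] == contracted, (
        f"B has {B_shape[0]} rows, but the product needs {contracted}"
    )

    out_rows = shape[1] if transpose else shape[0]
    return (out_rows, B_shape[1])


# A (100, 50) matrix times a (50, 10) matrix gives a (100, 10) result.
print(check_args(0.1, (100, 50), (50, 10)))
# With transpose=True, B must have 100 rows and the result has 50 rows.
print(check_args(0.1, (100, 50), (100, 10), transpose=True))
```

Under these assumptions, passing prob=1.5 or prob=float("nan") raises ValueError, and pairing shape=(100, 50) with a B of shape (40, 10) fails the assertion.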

See also

jits

Generate the full JIT scalar matrix as a dense array.

jitsmv

Matrix-vector product with JIT-generated scalar matrix.

binary_jitsmm

Event-driven (binary) variant of this operation.

Notes

The operation computes:

Y[i, k] = sum_{j in C(i)} w * B[j, k]

where w is the scalar weight, B is the input matrix, and C(i) is the deterministic random connection set for row i. This is equivalent to Y = M @ B where M[i, j] = w * Bernoulli(prob).

This is mathematically equivalent to performing jitsmv for each column of B, but is implemented more efficiently as a single kernel.
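
The formula and the column-wise equivalence above can be checked with a small dense NumPy stand-in. This is a sketch only: dense_reference is a hypothetical helper that materializes \(M\) in full, and its random stream is NumPy's, so its values will not match what jitsmm generates for the same seed. It shows the arithmetic, not the kernel.

```python
import numpy as np


def dense_reference(weight, prob, B, seed, shape, transpose=False):
    """Dense stand-in for Y = M @ B (or M.T @ B); not brainevent's kernel."""
    rng = np.random.default_rng(seed)
    # M[i, j] = weight * Bernoulli(prob), built explicitly for clarity.
    mask = rng.random(shape) < prob
    M = weight * mask.astype(B.dtype)
    Y = (M.T if transpose else M) @ B
    return Y, M


w = 0.01
B = np.ones((50, 10))
Y, M = dense_reference(w, 0.1, B, seed=42, shape=(100, 50))

# Row i of Y is w times the sum of the rows of B selected by C(i).
i = 3
C_i = np.flatnonzero(M[i])          # connection set of row i
row_i = w * B[C_i].sum(axis=0)

# One matrix-vector product per column of B gives the same matrix.
Y_cols = np.stack([M @ B[:, k] for k in range(B.shape[1])], axis=1)

print(Y.shape)
print(np.allclose(Y[i], row_i), np.allclose(Y, Y_cols))
```

The dense version needs memory proportional to the full shape of \(M\), which is what just-in-time generation is meant to avoid.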

Examples

>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.float import jitsmm
>>> B = jnp.ones((50, 10))
>>> result = jitsmm(0.01, 0.1, B, seed=42, shape=(100, 50))
>>> result.shape
(100, 10)