brainevent.jitsmm
- brainevent.jitsmm = <NameScope(brainevent.jitsmm)>
Perform the \(y = M @ B\) or \(y = M^T @ B\) operation, where \(M\) is generated just in time at random, with a scalar weight at each connected position.
In this operation, \(M\) is a random matrix with connection probability prob, and every connection carries the same scalar weight. When transpose=True, the operation computes \(y = M^T @ B\).
Note
The \(M\) generated just in time when transpose=False is not the same matrix whose transpose \(M^T\) is generated when transpose=True. If you need both products to use the same \(M\), set corder=True, at the cost of slower execution than corder=False.
- Parameters:
  weight (Array | ndarray | Quantity | Number) – The scalar value assigned to every connection of the random matrix.
  prob (float) – The connection probability.
  B (Array | ndarray | Quantity | Number) – The input matrix.
  transpose (bool) – Whether to use the transposed random matrix, i.e. compute \(y = M^T @ B\) instead of \(y = M @ B\).
  corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
    - True: sampling indices along the column dimension.
    - False: sampling indices along the row dimension.
  backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.
- Returns:
  out – The output of \(y = M @ B\) if transpose=False, or of \(y = M^T @ B\) if transpose=True.
- Return type:
  Array | ndarray | Quantity | Number
- Raises:
  ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
  AssertionError – If the matrix shape and the dimensions of the input matrix B are incompatible.
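The Note on transpose and the corder parameter can be pictured with the NumPy sketch below. It assumes, purely for illustration, that a JIT kernel regenerates each row or each column of the matrix from its own seeded random stream; `sample_mask` is a hypothetical helper, and neither brainevent's actual generator nor its mapping of corder to a sampling order is shown here.

```python
import numpy as np

def sample_mask(shape, prob, seed, along_rows):
    """Hypothetical JIT-style sampler (not brainevent's). Each row
    (along_rows=True) or each column (along_rows=False) is drawn from its
    own random stream keyed by (seed, index), so it can be regenerated on
    demand without storing the matrix."""
    n_rows, n_cols = shape
    mask = np.zeros(shape, dtype=bool)
    if along_rows:
        for i in range(n_rows):
            mask[i, :] = np.random.default_rng([seed, i]).random(n_cols) < prob
    else:
        for j in range(n_cols):
            mask[:, j] = np.random.default_rng([seed, j]).random(n_rows) < prob
    return mask

w, prob, seed, shape = 0.01, 0.1, 42, (100, 50)

# A kernel producing the rows of y = M @ B naturally samples M row by row ...
M_for_forward = w * sample_mask(shape, prob, seed, along_rows=True)
# ... while one producing the rows of y = M^T @ B naturally samples rows of M^T,
# i.e. columns of M, so it implicitly works with a different M.
M_for_transpose = w * sample_mask(shape, prob, seed, along_rows=False)
print(np.array_equal(M_for_forward, M_for_transpose))  # False

# Keeping one sampling order for both products gives consistent M and M^T.
B = np.ones((100, 10))
y_T = M_for_forward.T @ B  # uses exactly the M of the forward product
```

Which order is cheaper depends on how a kernel parallelizes its output; per the Note above, corder=True trades speed for the guarantee that both products see the same \(M\).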
See also
jits – Generate the full JIT scalar matrix as a dense array.
jitsmv – Matrix-vector product with a JIT-generated scalar matrix.
binary_jitsmm – Event-driven (binary) variant of this operation.
Notes
For transpose=False, the operation computes
\(Y_{ik} = \sum_{j \in C(i)} w \, B_{jk}\)
where \(w\) is the scalar weight, \(B\) is the input matrix, and \(C(i)\) is the deterministic pseudo-random set of columns connected to row \(i\). This is equivalent to \(Y = M @ B\) with \(M_{ij} = w \cdot \xi_{ij}\), \(\xi_{ij} \sim \mathrm{Bernoulli}(\mathrm{prob})\).
The result is mathematically identical to applying jitsmv to each column of B, but it is computed more efficiently as a single kernel.
Examples
>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.float import jitsmm
>>> B = jnp.ones((50, 10))
>>> result = jitsmm(0.01, 0.1, B, seed=42, shape=(100, 50))
>>> result.shape
(100, 10)
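The formula in the Notes can be checked with a small, loop-based NumPy reference. This is a sketch, not brainevent's implementation: `row_connections`, `jitsmm_reference` and `jitsmv_reference` are hypothetical helpers, and drawing each row's connections from a NumPy stream keyed by `(seed, i)` is an assumption, so the values will not match what jitsmm returns for the same seed. Only the transpose=False case is covered.

```python
import numpy as np

def row_connections(i, n_cols, prob, seed):
    """Hypothetical C(i): columns connected to row i, regenerated on demand
    from a random stream keyed by (seed, i) instead of being stored."""
    rng = np.random.default_rng([seed, i])
    return np.flatnonzero(rng.random(n_cols) < prob)

def jitsmm_reference(weight, prob, B, seed, shape):
    """Y[i, k] = sum_{j in C(i)} w * B[j, k]  (transpose=False only)."""
    n_rows, n_cols = shape
    assert B.shape[0] == n_cols, "B must have shape[1] rows"
    Y = np.zeros((n_rows, B.shape[1]))
    for i in range(n_rows):
        cols = row_connections(i, n_cols, prob, seed)
        Y[i] = weight * B[cols].sum(axis=0)  # an empty C(i) gives a zero row
    return Y

def jitsmv_reference(weight, prob, v, seed, shape):
    """Matrix-vector counterpart: y[i] = sum_{j in C(i)} w * v[j]."""
    n_rows, n_cols = shape
    return np.array(
        [weight * v[row_connections(i, n_cols, prob, seed)].sum() for i in range(n_rows)]
    )

w, p, seed, shape = 0.01, 0.1, 42, (100, 50)
B = np.random.default_rng(0).standard_normal((50, 10))
Y = jitsmm_reference(w, p, B, seed, shape)

# 1) Same as the dense product Y = M @ B with M[i, j] = w for j in C(i), else 0.
M = np.zeros(shape)
for i in range(shape[0]):
    M[i, row_connections(i, shape[1], p, seed)] = w
assert np.allclose(Y, M @ B)

# 2) Same as one matrix-vector product per column of B.
Y_cols = np.stack(
    [jitsmv_reference(w, p, B[:, k], seed, shape) for k in range(B.shape[1])], axis=1
)
assert np.allclose(Y, Y_cols)

# 3) With B = ones (as in the example), each entry is w * |C(i)|, whose expected
#    value is w * prob * shape[1] = 0.01 * 0.1 * 50 = 0.05.
Y_ones = jitsmm_reference(w, p, np.ones((50, 10)), seed, shape)
print(Y_ones.shape)  # (100, 10)
counts = (M != 0).sum(axis=1)
assert np.allclose(Y_ones, w * counts[:, None])
```

Regenerating C(i) from the seed inside the loop, rather than materializing M, mirrors the just-in-time idea; the dense M is built here only to check the result.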