brainevent.binary_jitsmm
- brainevent.binary_jitsmm = <NameScope(brainevent.binary_jitsmm)>
Perform the \(y = M @ B\) or \(y = M^T @ B\) operation, where \(M\) is a random matrix generated just in time with the same scalar weight at every connected position.
In this operation, \(M\) is a random matrix in which each entry is connected with probability prob, and every connection carries the same scalar weight. When transpose=True, the operation performed is \(y = M^T @ B\).

Note
The matrix \(M\) generated just in time with transpose=False is different from the \(M^T\) generated with transpose=True. If you need the same \(M\) and \(M^T\) when generating the matrix just in time, set corder=True, at the cost of speed compared with corder=False.
- Parameters:
  - weight (Array | ndarray | Quantity | Number) – The scalar value of each connection in the random matrix.
  - prob (float) – The connection probability.
  - B (Array | ndarray | Quantity | Number) – The input matrix.
  - transpose (bool) – Whether to transpose the random matrix.
  - corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
    - True: sample indices along the column dimension.
    - False: sample indices along the row dimension.
  - backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.
- Returns:
  out – The output of \(y = M @ B\) if transpose=False, or of \(y = M^T @ B\) if transpose=True.
- Return type:
  Array | ndarray | Quantity | Number
- Raises:
  - ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
  - AssertionError – If the shape of the matrix and the dimensions of the input matrix B are incompatible.
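The checks in the Raises section can be sketched as follows. This is a minimal illustration under assumptions, not brainevent's code: check_binary_jitsmm_inputs is a hypothetical helper, and it assumes that shape (the keyword used in the example at the end of this page) gives the dimensions of \(M\) before any transpose, which is consistent with shape=(3, 5) and a 5 x 2 B yielding a (3, 2) result there.

```python
import math

import numpy as np


def check_binary_jitsmm_inputs(prob, shape, b_shape, transpose=False):
    """Hypothetical helper mirroring the documented Raises section.

    Returns the expected output shape of y = M @ B (or y = M^T @ B).
    """
    # prob must be a finite scalar in [0, 1]; anything else is a ValueError.
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    prob = float(prob)
    if not math.isfinite(prob):
        raise ValueError(f"prob must be finite, got {prob}")
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be in [0, 1], got {prob}")

    # Assumption: shape = (n_rows, n_cols) of M before any transpose.
    n_rows, n_cols = shape
    # M @ B contracts over the columns of M; M^T @ B contracts over its rows.
    inner = n_rows if transpose else n_cols
    assert len(b_shape) == 2 and b_shape[0] == inner, (
        f"B must have shape ({inner}, k), got {tuple(b_shape)}"
    )
    out_rows = n_cols if transpose else n_rows
    return (out_rows, b_shape[1])


# Shapes from the example: M is 3 x 5 and B is 5 x 2.
print(check_binary_jitsmm_inputs(0.5, (3, 5), (5, 2)))                  # (3, 2)
print(check_binary_jitsmm_inputs(0.5, (3, 5), (3, 2), transpose=True))  # (5, 2)
```

Keeping the prob checks as ValueError and the shape check as an assertion follows the split documented in the Raises section.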
See also
binary_jitsmv – Event-driven matrix-vector multiplication with scalar weight.
jitsmm – Float matrix-matrix multiplication with scalar weight.
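To illustrate the difference from jitsmm, the sketch below compares event-driven and float semantics on the same matrix. It assumes both operations see an identical connectivity mask; the 3 x 4 mask and the input values are made up for illustration, whereas the real functions generate \(M\) just in time from a seed. Under the event-driven semantics each True or positive entry of B contributes exactly one spike, so its magnitude is ignored and non-positive values contribute nothing; the float product uses the values as they are.

```python
import numpy as np

w = 0.5
# Illustrative fixed mask standing in for the just-in-time connectivity (3 x 4).
mask = np.array([[1, 0, 1, 1],
                 [0, 1, 0, 1],
                 [1, 1, 0, 0]], dtype=np.float64)
M = w * mask

# A float input matrix (4 x 2) with positive, zero, and negative entries.
B = np.array([[2.5, 0.0],
              [-1.0, 3.0],
              [1.0, 1.0],
              [0.0, -2.0]])

# Event-driven semantics: an entry is a spike iff it is > 0; a spike counts as 1.
spikes = (B > 0).astype(np.float64)
y_binary = M @ spikes

# Float semantics: the entries of B are multiplied in directly.
y_float = M @ B

print(y_binary)
print(y_float)
```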
Notes
This function computes an event-driven (spike-based) matrix-matrix product where the connectivity matrix \(M\) has the structure

\[M_{ij} = w \cdot \mathrm{Bernoulli}(\mathrm{prob}),\]

and the input matrix \(B\) is treated as a binary event matrix (each column is a spike vector). Each output element is

\[Y_{ik} = \sum_{j \in C(i)} w \cdot \mathrm{spike}_{jk},\]

where \(C(i)\) is the deterministic random connection set for row \(i\), and \(\mathrm{spike}_{jk}\) is 1 if \(B_{jk}\) is True (boolean \(B\)) or \(B_{jk} > 0\) (float \(B\)), and 0 otherwise.

This is equivalent to performing binary_jitsmv independently for each column of \(B\), but is implemented more efficiently as a single kernel.

Examples
>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.binary import binary_jitsmm
>>> weight = 0.5
>>> B = jnp.array([[True, False], [False, True], [True, True],
...                [False, False], [True, False]])
>>> result = binary_jitsmm(weight, 0.5, B, seed=42,
...                        shape=(3, 5))
>>> result.shape  # (3, 2)
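The formula in the Notes, and its equivalence to column-wise binary_jitsmv calls, can be checked with a small dense NumPy sketch. It is only a reference under stated assumptions: sample_connections, to_spikes, event_mv, and event_mm are hypothetical helpers, the connection sets \(C(i)\) come from NumPy's generator rather than brainevent's (so the numbers will not match binary_jitsmm for the same seed), and only the transpose=False case is shown.

```python
import numpy as np


def sample_connections(shape, prob, seed):
    """Hypothetical stand-in for the JIT connectivity: C(i) for every row i.

    Uses NumPy's RNG, so it does NOT reproduce brainevent's matrix for a seed.
    """
    n_rows, n_cols = shape
    rng = np.random.default_rng(seed)
    # Each (i, j) pair is connected independently with probability prob.
    return [np.flatnonzero(rng.random(n_cols) < prob) for _ in range(n_rows)]


def to_spikes(x):
    """Event semantics from the Notes: True (bool) or > 0 (float) is a spike."""
    x = np.asarray(x)
    return (x if x.dtype == np.bool_ else x > 0).astype(np.float64)


def event_mv(weight, conn, v):
    """y[i] = sum_{j in C(i)} w * spike[j], for a single spike vector."""
    s = to_spikes(v)
    return np.array([weight * s[cols].sum() for cols in conn])


def event_mm(weight, conn, B):
    """Y[i, k] = sum_{j in C(i)} w * spike[j, k], for all columns at once."""
    S = to_spikes(B)
    return np.stack([weight * S[cols].sum(axis=0) for cols in conn])


B = np.array([[True, False], [False, True], [True, True],
              [False, False], [True, False]])
conn = sample_connections((3, 5), prob=0.5, seed=42)
Y = event_mm(0.5, conn, B)

# Column-by-column matrix-vector products reproduce the single mm result.
Y_cols = np.stack([event_mv(0.5, conn, B[:, k]) for k in range(B.shape[1])], axis=1)
assert Y.shape == (3, 2)
assert np.allclose(Y, Y_cols)
```

In this sketch event_mm visits each \(C(i)\) once for all columns, while the column-wise version revisits it per column; that is one way to read the Notes' remark that the single-kernel form is more efficient.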