brainevent.binary_jitsmm


brainevent.binary_jitsmm = <NameScope(brainevent.binary_jitsmm)>

Perform the \(y = M @ B\) or \(y = M^T @ B\) operation, where \(M\) is a sparse random matrix generated just in time, with the same scalar weight at every connected position.

In this operation, \(M\) is a random matrix in which each entry is nonzero (connected) with probability prob, and every connection carries the same scalar weight. When transpose=True, the operation is \(y = M^T @ B\).
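For intuition, the operation can be mirrored with an explicit dense matrix. The sketch below is a hypothetical NumPy reference (dense_reference_jitsmm is not part of brainevent): it draws M with NumPy's own generator, so its connectivity will not match what brainevent produces for the same seed, but the arithmetic of \(y = M @ B\) and \(y = M^T @ B\) is the same.

```python
import numpy as np


def dense_reference_jitsmm(weight, prob, B, seed, shape, transpose=False):
    """Dense y = M @ B (or M^T @ B) with M[i, j] = weight * Bernoulli(prob).

    Hypothetical reference: NumPy's generator is used, so the sampled M will
    not match brainevent's just-in-time matrix for the same seed.
    """
    rng = np.random.default_rng(seed)
    # Each entry of M is connected independently with probability `prob`.
    M = weight * (rng.random(shape) < prob).astype(float)
    # Binary event semantics: True (bool) or > 0 (float) counts as a spike.
    events = (np.asarray(B) > 0).astype(float)
    return (M.T if transpose else M) @ events


B = np.array([[True, False], [False, True], [True, True],
              [False, False], [True, False]])
print(dense_reference_jitsmm(0.5, 0.5, B, seed=42, shape=(3, 5)).shape)  # (3, 2)
print(dense_reference_jitsmm(0.5, 0.5, B, seed=42, shape=(5, 3),
                             transpose=True).shape)  # (3, 2)
```

Materializing M costs memory proportional to its full size, which is exactly what the just-in-time approach avoids; this reference is only useful for checking shapes and semantics on small inputs.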

Note

Note that the \(M\) generated just in time with transpose=False differs from the \(M^T\) generated with transpose=True: the latter is not, in general, the transpose of the former. If you need the same \(M\) and \(M^T\) across both settings, set corder=True, at the cost of slower execution compared with corder=False.
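Why sampling order matters can be seen with a toy generator. The sketch below assumes, purely for illustration, that each row (or each column) of M is drawn from its own stream seeded by (seed, index); row_ordered_mask and column_ordered_mask are hypothetical helpers, and brainevent's actual sampler may work differently. With the same seed, shape, and probability, the two orders generally produce different matrices.

```python
import numpy as np


def row_ordered_mask(seed, shape, prob):
    """Toy sampler: row i of M is drawn from a stream seeded by (seed, i)."""
    n_rows, n_cols = shape
    rows = [np.random.default_rng([seed, i]).random(n_cols) < prob
            for i in range(n_rows)]
    return np.stack(rows)


def column_ordered_mask(seed, shape, prob):
    """Toy sampler: column j of M is drawn from a stream seeded by (seed, j)."""
    n_rows, n_cols = shape
    cols = [np.random.default_rng([seed, j]).random(n_rows) < prob
            for j in range(n_cols)]
    return np.stack(cols, axis=1)


M_rows = row_ordered_mask(42, (16, 32), 0.5)
M_cols = column_ordered_mask(42, (16, 32), 0.5)
# Same seed, shape, and probability, but a different sampling order.
print(M_rows.shape, M_cols.shape)
print("identical:", np.array_equal(M_rows, M_cols))
```

In this toy, both settings can only agree if they share a single sampling order, which is the kind of consistency the note attributes to corder=True.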

Parameters:
  • weight (Array | ndarray | Quantity | Number) – The scalar value shared by every nonzero entry of the random matrix.

  • prob (float) – The connection probability, i.e. the probability that each entry of the random matrix is nonzero.

  • B (Array | ndarray | Quantity | Number) – The input matrix, treated as binary events: True entries (boolean) or entries > 0 (float) count as spikes.

  • seed (int | None) – The random number generation seed.

  • shape (Tuple[int, int]) – The shape of the random matrix M (before any transposition).

  • transpose (bool) – If True, multiply by \(M^T\) instead of \(M\).

  • corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
      - True: sampling index along the column dimension.
      - False: sampling index along the row dimension.

  • backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.

Returns:

out – The output of \(y = M @ B\) if transpose=False, or the output of \(y = M^T @ B\) if transpose=True.

Return type:

Array | ndarray | Quantity | Number

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • AssertionError – If the matrix shape and input matrix B dimensions are incompatible.
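The documented error conditions can be summarized as a small validator. check_jitsmm_args below is a hypothetical helper written for this page, not a brainevent function. It assumes shape is the shape of M, so B must have shape[1] rows when transpose=False (as in the example, shape=(3, 5) with a 5-row B) and shape[0] rows when transpose=True; it returns the expected output shape.

```python
import math
from numbers import Real


def check_jitsmm_args(prob, shape, b_shape, transpose=False):
    """Hypothetical validator mirroring the documented Raises section.

    Returns the expected output shape: (rows of op(M), columns of B).
    """
    # ValueError: prob must be a finite scalar in [0, 1].
    if not isinstance(prob, Real):
        raise ValueError(f"prob must be a scalar, got {prob!r}")
    if not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be finite and in [0, 1], got {prob}")
    n_rows, n_cols = shape
    # op(M) is M (n_rows x n_cols) or M^T (n_cols x n_rows).
    out_rows, inner = (n_cols, n_rows) if transpose else (n_rows, n_cols)
    # AssertionError: the inner dimensions of op(M) @ B must agree.
    assert b_shape[0] == inner, (
        f"B has {b_shape[0]} rows but op(M) has {inner} columns"
    )
    return (out_rows, b_shape[1])


print(check_jitsmm_args(0.5, (3, 5), (5, 2)))                  # (3, 2)
print(check_jitsmm_args(0.5, (5, 3), (5, 2), transpose=True))  # (3, 2)
```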

See also

binary_jitsmv

Event-driven matrix-vector multiplication with scalar weight.

jitsmm

Float matrix-matrix multiplication with scalar weight.

Notes

This function computes an event-driven (spike-based) matrix-matrix product where the connectivity matrix M has the structure:

M[i, j] = w * Bernoulli(prob)

and the input matrix B is treated as a binary event matrix (each column is a spike vector). For each output element:

Y[i, k] = sum_{j in C(i)} w * spike[j, k]

where C(i) is the set of columns connected to row i, drawn pseudo-randomly but fully determined by seed, and spike[j, k] is 1 if B[j, k] is True (for boolean B) or > 0 (for float B), and 0 otherwise.
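The per-element formula can be evaluated directly from connection lists, without materializing M. In the sketch below, event_driven_matmul is a hypothetical helper and conn is an explicit list of the sets C(i), assumed here for illustration (brainevent regenerates connectivity from the seed instead of storing it). The result is checked against the dense product using a float B, where only entries > 0 count as spikes.

```python
import numpy as np


def event_driven_matmul(weight, conn, B):
    """Y[i, k] = sum_{j in C(i)} weight * spike[j, k], with conn[i] = C(i)."""
    spikes = np.asarray(B) > 0  # True (bool) or > 0 (float) is a spike
    Y = np.zeros((len(conn), spikes.shape[1]))
    for i, cols in enumerate(conn):
        # Count spiking inputs of row i in every column k, then scale by w.
        Y[i] = weight * spikes[cols].sum(axis=0)
    return Y


# Hypothetical connectivity: C(i) = indices of the nonzeros in row i.
rng = np.random.default_rng(0)
mask = rng.random((3, 5)) < 0.5
conn = [np.flatnonzero(row) for row in mask]

B = np.array([[1.0, -0.2],
              [0.0, 2.0],
              [0.3, 0.7],
              [0.0, 0.0],
              [1.5, 0.0]])
Y = event_driven_matmul(0.5, conn, B)
dense = (0.5 * mask.astype(float)) @ (B > 0).astype(float)
print(np.allclose(Y, dense))  # True
```

Only the connected inputs of each row are visited, and the float values of B are never multiplied in: they are reduced to spike / no-spike before summation.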

This is equivalent to performing binary_jitsmv independently for each column of B, but is implemented more efficiently as a single kernel.
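The column-wise equivalence can be checked with a toy just-in-time generator. row_connections, toy_matvec, and toy_matmat below are hypothetical: row_connections regenerates C(i) from (seed, i) on demand, mimicking the idea of never storing M, but it is not brainevent's sampler. A matrix-vector product per column of B gives the same result as one batched pass; in this toy, the batched pass regenerates each C(i) once rather than once per column.

```python
import numpy as np


def row_connections(seed, i, n_cols, prob):
    """Hypothetical JIT generator: C(i) depends only on (seed, i)."""
    rng = np.random.default_rng([seed, i])
    return np.flatnonzero(rng.random(n_cols) < prob)


def toy_matvec(weight, prob, v, seed, shape):
    """One spike vector: every C(i) is regenerated for this call."""
    n_rows, n_cols = shape
    spikes = np.asarray(v) > 0
    return np.array([weight * spikes[row_connections(seed, i, n_cols, prob)].sum()
                     for i in range(n_rows)])


def toy_matmat(weight, prob, B, seed, shape):
    """All spike vectors at once: each C(i) is regenerated only once."""
    n_rows, n_cols = shape
    spikes = np.asarray(B) > 0
    Y = np.zeros((n_rows, spikes.shape[1]))
    for i in range(n_rows):
        Y[i] = weight * spikes[row_connections(seed, i, n_cols, prob)].sum(axis=0)
    return Y


B = np.array([[True, False], [False, True], [True, True],
              [False, False], [True, False]])
per_column = np.stack([toy_matvec(0.5, 0.5, B[:, k], 42, (3, 5))
                       for k in range(B.shape[1])], axis=1)
batched = toy_matmat(0.5, 0.5, B, 42, (3, 5))
print(np.allclose(per_column, batched))  # True
```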

Examples

>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.binary import binary_jitsmm
>>> weight = 0.5
>>> B = jnp.array([[True, False], [False, True], [True, True],
...                [False, False], [True, False]])
>>> result = binary_jitsmm(weight, 0.5, B, seed=42,
...                        shape=(3, 5))
>>> result.shape
(3, 2)