brainevent.jitsmv

brainevent.jitsmv = <NameScope(brainevent.jitsmv)>

Compute \(y = M @ v\) or \(y = M^T @ v\), where \(M\) is a random matrix generated just in time with the same scalar weight at every connected position.

Each entry of \(M\) is connected with probability prob. Every connected entry holds the scalar weight, and all other entries are zero.

When transpose=True, the operation computes \(y = M^T @ v\) instead.
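To make the semantics concrete, here is a minimal dense reference in plain JAX. It assumes only that each entry of \(M\) equals weight with probability prob and is zero otherwise. The helpers `dense_scalar_matrix` and `reference_jitsmv` are hypothetical, and `jax.random` will not reproduce the matrix brainevent generates for a given seed. Because it materializes the full matrix, it is only meant for checking semantics on small shapes.

```python
import jax
import jax.numpy as jnp


def dense_scalar_matrix(key, weight, prob, shape):
    """Hypothetical helper: entries are `weight` with probability `prob`, else 0.

    Uses jax.random, so it does NOT reproduce brainevent's matrix for a seed.
    """
    mask = jax.random.bernoulli(key, p=prob, shape=shape)
    return jnp.where(mask, weight, 0.0)


def reference_jitsmv(key, weight, prob, vector, shape, transpose=False):
    """Hypothetical dense reference: y = M @ v, or y = M.T @ v if transpose."""
    M = dense_scalar_matrix(key, weight, prob, shape)
    return M.T @ vector if transpose else M @ vector


key = jax.random.PRNGKey(42)
# M has shape (100, 50): M @ v needs len(v) == 50, M.T @ v needs len(v) == 100.
y = reference_jitsmv(key, 0.01, 0.1, jnp.ones(50), shape=(100, 50))
yt = reference_jitsmv(key, 0.01, 0.1, jnp.ones(100), shape=(100, 50), transpose=True)
print(y.shape, yt.shape)  # (100,) (50,)
```

The output length follows the dimension that is not contracted: shape[0] without transpose and shape[1] with it.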

Note

The \(M\) generated just in time for transpose=False is not the same matrix whose transpose is applied when transpose=True.

If you need \(M\) and \(M^T\) to match across the two modes, set corder=True, at the cost of speed compared with corder=False.
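One plausible way such a mismatch can arise is sketched below as a hypothetical illustration, not as brainevent's actual random number generator. Suppose each parallel worker draws one row (or one column) of \(M\) from a key derived from the seed and its own index. Under that scheme, row-oriented and column-oriented generation produce different matrices from the same seed. The helper `oriented_mask` and the per-line key scheme are assumptions.

```python
import jax
import jax.numpy as jnp


def oriented_mask(seed, prob, shape, by_column):
    """Hypothetical generator: each row (or column) gets its own key derived
    from `seed` and its index, mimicking one random stream per parallel worker.
    This is NOT brainevent's actual RNG."""
    base = jax.random.PRNGKey(seed)
    n_rows, n_cols = shape
    n_lines, line_len = (n_cols, n_rows) if by_column else (n_rows, n_cols)
    # One key per row (or per column), derived from the base key and the index.
    keys = jax.vmap(lambda i: jax.random.fold_in(base, i))(jnp.arange(n_lines))
    lines = jax.vmap(lambda k: jax.random.bernoulli(k, prob, (line_len,)))(keys)
    # Column-oriented draws have shape (n_cols, n_rows); transpose back to (n_rows, n_cols).
    return lines.T if by_column else lines


row_m = oriented_mask(0, 0.5, (20, 30), by_column=False)
col_m = oriented_mask(0, 0.5, (20, 30), by_column=True)
print(row_m.shape, col_m.shape)  # (20, 30) (20, 30)
print(bool((row_m == col_m).all()))  # False: same seed, different matrices
```

Under this assumption, getting the same \(M\) in both modes requires fixing the sampling orientation independently of transpose, which appears to be what the corder flag controls.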

Parameters:
  • weight (Array | ndarray | Quantity | Number) – The scalar value assigned to every connected entry of the random matrix.

  • prob (float) – The connection probability.

  • vector (Array | ndarray | Quantity | Number) – The input vector \(v\). Its length must equal shape[1] when transpose=False and shape[0] when transpose=True.

  • seed (int | None) – The random number generation seed.

  • shape (Tuple[int, int]) – The matrix shape.

  • transpose (bool) – If True, compute \(y = M^T @ v\); otherwise compute \(y = M @ v\).

  • corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
      - True: sample indices along the column dimension.
      - False: sample indices along the row dimension.

  • backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.

Returns:

out – The output of \(y = M @ v\) if transpose=False, or the output of \(y = M^T @ v\) if transpose=True.

Return type:

Array | ndarray | Quantity | Number

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • AssertionError – If the matrix shape and vector length are incompatible.
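The documented error conditions can be summarized as a small validation routine. This is a hypothetical sketch: `validate_jitsmv_args` is not part of brainevent, and the library's own checks may differ. For example, this version treats only Python and NumPy real numbers as scalars. It raises ValueError for an invalid prob and fails an assertion when the vector length does not match the dimension that \(M\) or \(M^T\) contracts over.

```python
import math
import numbers


def validate_jitsmv_args(prob, vector_len, shape, transpose):
    """Hypothetical checks mirroring the documented error conditions.

    Returns the length of the output y.
    """
    # prob must be a finite real scalar in [0, 1]; otherwise raise ValueError.
    if not isinstance(prob, numbers.Real):
        raise ValueError(f"prob must be a scalar, got {type(prob).__name__}")
    prob = float(prob)
    if not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be finite and in [0, 1], got {prob}")
    # y = M @ v contracts over columns; y = M.T @ v contracts over rows.
    n_rows, n_cols = shape
    expected = n_rows if transpose else n_cols
    assert vector_len == expected, (
        f"vector length {vector_len} is incompatible with shape {shape} "
        f"(transpose={transpose}); expected {expected}"
    )
    return n_cols if transpose else n_rows


print(validate_jitsmv_args(0.1, 50, (100, 50), transpose=False))  # 100
print(validate_jitsmv_args(0.1, 100, (100, 50), transpose=True))  # 50
```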

See also

jits

Generate the full JIT scalar matrix as a dense array.

jitsmm

Matrix-matrix product with JIT-generated scalar matrix.

binary_jitsmv

Event-driven (binary) variant of this operation.

Notes

The operation computes:

\[ y_i = \sum_{j \in C(i)} w \, v_j \]

where \(w\) is the scalar weight, \(v\) is the input vector, and \(C(i)\) is the random set of columns connected to row \(i\), fixed deterministically by the seed and the connection probability. This is equivalent to \(y = M @ v\) with \(M_{ij} = w \, b_{ij}\) and \(b_{ij} \sim \mathrm{Bernoulli}(\text{prob})\).
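The equivalence above can be checked numerically by building the connection sets \(C(i)\) explicitly and comparing the sum-over-connections form with the dense product. This is an illustration under stated assumptions: `connection_sets` and `sparse_form` are hypothetical helpers, and NumPy's generator stands in for brainevent's seeded just-in-time generation. The connections will therefore not match those that jitsmv uses for the same seed.

```python
import numpy as np


def connection_sets(rng, prob, shape):
    """Hypothetical: draw C(i), the column indices connected to each row i."""
    n_rows = shape[0]
    mask = rng.random(shape) < prob  # entry (i, j) is connected with probability prob
    return [np.flatnonzero(mask[i]) for i in range(n_rows)], mask


def sparse_form(weight, conn, vector):
    # y[i] = sum_{j in C(i)} w * v[j], with the shared weight factored out of the sum.
    return np.array([weight * vector[c].sum() for c in conn])


rng = np.random.default_rng(0)
conn, mask = connection_sets(rng, 0.1, (100, 50))
v = rng.standard_normal(50)
w = 0.01

y_sparse = sparse_form(w, conn, v)
y_dense = (w * mask) @ v  # M[i, j] = w * b[i, j] with b[i, j] ~ Bernoulli(prob)
print(np.allclose(y_sparse, y_dense))  # True
```

Because every connection shares the same weight, each output is simply \(w\) times the sum of the connected inputs. Only the connection pattern has to be generated.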

The weight \(w\) and vector \(v\) may carry physical units from brainunit; the output carries the product of their units.

Examples

>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.float import jitsmv
>>> v = jnp.ones(50)
>>> result = jitsmv(0.01, 0.1, v, seed=42, shape=(100, 50))
>>> result.shape
(100,)