brainevent.jitsmv
- brainevent.jitsmv = <NameScope(brainevent.jitsmv)>
Perform the \(y = M @ v\) or \(y = M^T @ v\) operation, where \(M\) is randomly generated just in time, with a scalar weight at each connection.
Each entry of \(M\) is connected with probability prob, and every connection carries the same scalar weight.
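As a point of reference, the same product can be written with an explicitly materialized matrix. This is a minimal sketch, assuming the model \(M_{ij} = w \cdot \mathrm{Bernoulli}(\text{prob})\) stated in the Notes. The helper dense_scalar_matvec is hypothetical and not part of brainevent: it draws from jax.random, so it will not reproduce brainevent's matrix for a given seed, and it stores all of \(M\) in memory.

```python
import jax
import jax.numpy as jnp


def dense_scalar_matvec(weight, prob, vector, key, shape, transpose=False):
    """Hypothetical dense reference for y = M @ v (or y = M.T @ v).

    Assumes M[i, j] = weight * Bernoulli(prob). Unlike a just-in-time
    kernel, this builds the whole (n_rows, n_cols) matrix in memory.
    """
    # Boolean connectivity mask: each entry is True with probability `prob`.
    mask = jax.random.bernoulli(key, p=prob, shape=shape)
    # Every connection carries the same scalar weight; absent entries are 0.
    matrix = jnp.where(mask, weight, 0.0)
    if transpose:
        return matrix.T @ vector  # vector length must equal shape[0]
    return matrix @ vector  # vector length must equal shape[1]


key = jax.random.PRNGKey(42)
y = dense_scalar_matvec(0.01, 0.1, jnp.ones(50), key, shape=(100, 50))
print(y.shape)  # (100,)
```

Materializing \(M\) costs memory proportional to rows × columns, which is what generating connections just in time is meant to avoid; the dense form is only useful for checking small cases.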
When transpose=True, the operation computes \(y = M^T @ v\) instead.
Note
The just-in-time generated \(M\) (transpose=False) differs from the generated \(M^T\) (transpose=True): the matrix used in transpose mode is not, in general, the transpose of the one used otherwise.
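Why the two modes can disagree is easiest to see with a toy generator. In this hypothetical sketch (not brainevent's actual random-number scheme), each row or each column gets its own key via jax.random.fold_in. Walking rows versus walking columns from the same seed produces two different masks, while repeating the same walk reproduces a mask exactly.

```python
import jax
import jax.numpy as jnp


def row_order_mask(key, shape, prob):
    """Toy generator: row i is drawn from fold_in(key, i)."""
    n_rows, n_cols = shape

    def draw_row(i):
        return jax.random.bernoulli(jax.random.fold_in(key, i), p=prob, shape=(n_cols,))

    return jax.vmap(draw_row)(jnp.arange(n_rows))  # shape (n_rows, n_cols)


def col_order_mask(key, shape, prob):
    """Toy generator: column j is drawn from fold_in(key, j)."""
    n_rows, n_cols = shape

    def draw_col(j):
        return jax.random.bernoulli(jax.random.fold_in(key, j), p=prob, shape=(n_rows,))

    # vmap stacks the columns along axis 0, so transpose to (n_rows, n_cols).
    return jax.vmap(draw_col)(jnp.arange(n_cols)).T


key = jax.random.PRNGKey(0)
a = row_order_mask(key, (20, 30), 0.5)
b = col_order_mask(key, (20, 30), 0.5)
print(bool(jnp.array_equal(a, b)))  # compares the two walks
print(bool(jnp.array_equal(a, row_order_mask(key, (20, 30), 0.5))))  # True: same walk, same seed
```

Under this toy model, keeping the two modes consistent means always walking the same dimension, whichever product is computed. That is one plausible reading of the speed trade-off attributed to corder below, not a description of brainevent's kernels.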
To obtain matching \(M\) and \(M^T\) from the just-in-time generation, set corder=True, at some cost in speed compared with corder=False.
- Parameters:
weight (Array | ndarray | Quantity | Number) – The scalar value assigned to every connection of the random matrix.
prob (float) – The connection probability, in [0, 1].
vector (Array | ndarray | Quantity | Number) – The input vector.
transpose (bool) – Whether to transpose the random matrix, i.e. compute \(y = M^T @ v\) instead of \(y = M @ v\).
corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
  - True: sample indices along the column dimension.
  - False: sample indices along the row dimension.
backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.
- Returns:
out – The output of \(y = M @ v\) if transpose=False, or of \(y = M^T @ v\) if transpose=True.
- Return type:
Array | ndarray | Quantity | Number
- Raises:
ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].
AssertionError – If the matrix shape and vector length are incompatible.
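The documented error conditions can be sketched as a standalone check. check_jitsmv_inputs is a hypothetical helper, not part of brainevent, and its messages are invented; it assumes shape is (n_rows, n_cols), as in the example's shape=(100, 50), and that the vector length must equal n_cols (or n_rows when transpose=True).

```python
import math

import numpy as np


def check_jitsmv_inputs(prob, shape, vector_len, transpose=False):
    """Hypothetical validator mirroring the documented Raises section."""
    # prob must be a scalar: reject anything with ndim > 0.
    if np.ndim(prob) != 0:
        raise ValueError(f"prob must be a scalar, got shape {np.shape(prob)}")
    prob = float(prob)
    # Reject NaN and +/-inf before the range check.
    if not math.isfinite(prob):
        raise ValueError(f"prob must be finite, got {prob}")
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be in [0, 1], got {prob}")
    n_rows, n_cols = shape
    # y = M @ v needs len(v) == n_cols; y = M.T @ v needs len(v) == n_rows.
    expected = n_rows if transpose else n_cols
    assert vector_len == expected, (
        f"vector length {vector_len} incompatible with shape {shape} "
        f"(transpose={transpose})"
    )
    return prob


print(check_jitsmv_inputs(0.1, (100, 50), 50))  # 0.1
```

Using assert for the shape check mirrors the documented AssertionError, but assertions are skipped when Python runs with -O, so a ValueError would be the sturdier choice in new code.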
See also
jits – Generate the full JIT scalar matrix as a dense array.
jitsmm – Matrix-matrix product with a JIT-generated scalar matrix.
binary_jitsmv – Event-driven (binary) variant of this operation.
Notes
The operation computes
\[ y_i = \sum_{j \in C(i)} w \, v_j, \]
where \(w\) is the scalar weight, \(v\) is the input vector, and \(C(i)\) is the deterministic random connection set for row \(i\) (determined by the seed and the connection probability). This is equivalent to \(y = M @ v\) with \(M_{ij} = w \, b_{ij}\) and \(b_{ij} \sim \mathrm{Bernoulli}(\text{prob})\).
The weight \(w\) and vector \(v\) may carry physical units from brainunit; the output then carries the product of their units.
Examples
>>> import jax.numpy as jnp
>>> from brainevent._jit_scalar.float import jitsmv
>>> v = jnp.ones(50)
>>> result = jitsmv(0.01, 0.1, v, seed=42, shape=(100, 50))
>>> result.shape
(100,)
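Assuming, as in the Notes, that each entry of \(M\) is present independently with probability prob, the expected output follows from linearity of expectation. The arithmetic below applies it to the example's values (\(w = 0.01\), prob = 0.1, \(v\) a vector of 50 ones) as a sanity check on the model, not as a statement about brainevent's sampler:

\[ \mathbb{E}[y_i] = \sum_{j} w \, \mathbb{E}[b_{ij}] \, v_j = w \cdot \text{prob} \cdot \sum_{j} v_j = 0.01 \times 0.1 \times 50 = 0.05 \]
\[ \operatorname{Var}[y_i] = w^2 \, \text{prob} \, (1 - \text{prob}) \sum_{j} v_j^2 = 10^{-4} \times 0.09 \times 50 = 4.5 \times 10^{-4} \]

Because \(v\) is all ones, each \(y_i\) equals \(w\) times the number of connections in row \(i\), so every entry of result is 0.01 times an integer count (up to floating-point rounding), about five connections per row on average.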