brainevent.binary_jitsmv

brainevent.binary_jitsmv = <NameScope(brainevent.binary_jitsmv)>

Perform the \(y=M@v\) or \(y=M^T@v\) operation, where \(M\) is a matrix randomly generated just in time, with the same scalar weight at every connected position.

In this operation, \(M\) is a random matrix whose entries are connected with probability prob, and every connection carries the same scalar weight.

When transpose=True, the operation computes \(y=M^T@v\) instead.
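For small sizes, the semantics of both directions can be checked against an explicitly materialized matrix. The sketch below uses NumPy and a hypothetical helper, dense_reference_jitsmv, which is not part of brainevent. Its random pattern comes from NumPy's generator and will not match the pattern brainevent regenerates, so only shapes and structure are comparable.

```python
import numpy as np


def dense_reference_jitsmv(weight, prob, vector, seed, shape, transpose=False):
    """Hypothetical dense reference for y = M @ v or y = M^T @ v.

    M is stored explicitly: each entry equals `weight` with probability
    `prob` and 0 otherwise. Illustrates the semantics only; the random
    pattern differs from the one brainevent generates on the fly.
    """
    rng = np.random.default_rng(seed)
    mask = rng.random(shape) < prob            # Bernoulli(prob) connectivity
    M = weight * mask.astype(np.float64)       # same scalar weight at every connection
    v = np.asarray(vector, dtype=np.float64)
    return M.T @ v if transpose else M @ v


events = np.array([True, False, True, True, False])
y = dense_reference_jitsmv(0.5, 0.5, events, seed=42, shape=(3, 5))
print(y.shape)    # (3,): M is 3 x 5, so M @ v has 3 entries

ones = np.ones(3)
y_t = dense_reference_jitsmv(0.5, 0.5, ones, seed=42, shape=(3, 5), transpose=True)
print(y_t.shape)  # (5,): M^T is 5 x 3, so the vector must have length 3
```

Because this dense version stores M once, its transpose=True result is exactly M.T @ v for the same M; the just-in-time operation does not automatically guarantee that (see the Note).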

Note

Even with the same seed, the just-in-time generated \(M\) used when transpose=False is in general different from the \(M\) whose transpose \(M^T\) is generated when transpose=True.

To obtain a consistent \(M\) and \(M^T\) during just-in-time matrix generation, set corder=True, at the cost of speed compared with corder=False.
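The note above can be illustrated with a toy model. Assume, hypothetically, that a JIT kernel keys its random number generator by (seed, index) along whichever axis it iterates over. Then the axis used for keying determines the matrix, so the same seed can produce different matrices. The helpers below (connections, build_row_keyed, build_col_keyed, transpose_mv_row_keyed) are illustrative names, not brainevent API, and this is not a description of how corder is actually implemented.

```python
import numpy as np


def connections(seed, owner, n_other, prob):
    """Hypothetical sampler: indices (out of n_other) connected to `owner`.

    The RNG is keyed by (seed, owner), so the same call always returns the
    same indices and nothing needs to be stored. Illustration only.
    """
    rng = np.random.default_rng([seed, owner])
    return np.flatnonzero(rng.random(n_other) < prob)


def build_row_keyed(seed, shape, prob):
    # Each row i of M draws its own column indices (key runs over rows).
    n_rows, n_cols = shape
    M = np.zeros(shape, dtype=bool)
    for i in range(n_rows):
        M[i, connections(seed, i, n_cols, prob)] = True
    return M


def build_col_keyed(seed, shape, prob):
    # Each column j of M draws its own row indices (key runs over columns).
    n_rows, n_cols = shape
    M = np.zeros(shape, dtype=bool)
    for j in range(n_cols):
        M[connections(seed, j, n_rows, prob), j] = True
    return M


def transpose_mv_row_keyed(seed, shape, prob, v):
    # y = M^T @ v while generation stays keyed by rows of M:
    # row i scatters v[i] into every column it connects to.
    n_rows, n_cols = shape
    y = np.zeros(n_cols)
    for i in range(n_rows):
        y[connections(seed, i, n_cols, prob)] += v[i]
    return y


shape = (8, 10)
M_rows = build_row_keyed(seed=0, shape=shape, prob=0.5)
M_cols = build_col_keyed(seed=0, shape=shape, prob=0.5)
# Same seed, different keying axis: almost surely a different matrix.
print(np.array_equal(M_rows, M_cols))

v = np.arange(shape[0], dtype=float)
# Keeping the row keying for the transposed product reproduces M_rows.T @ v.
print(np.allclose(transpose_mv_row_keyed(0, shape, 0.5, v), M_rows.T @ v))
```

In this toy model, keeping the keying fixed makes the transposed product a scatter (each input index writes to many outputs), which in a parallel kernel typically needs atomic updates. That is one plausible reason consistency can cost speed; the source does not state brainevent's actual reason.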

Parameters:
  • weight (Array | ndarray | Quantity | Number) – The scalar value assigned to every connection (nonzero entry) of the random matrix.

  • prob (float) – The connection probability.

  • vector (Array | ndarray | Quantity | Number) – The vector.

  • seed (int | None) – The random number generation seed.

  • shape (Tuple[int, int]) – The matrix shape.

  • transpose (bool) – Transpose the random matrix or not.

  • corder (bool) – Controls whether the parallelization order is oriented along the matrix columns:
      - True: sample indices along the column dimension.
      - False: sample indices along the row dimension.

  • backend (str | None) – The computation backend to use. If None, the default backend is selected automatically.

Returns:

out – The output of \(y = M @ v\) if transpose=False, or the output of \(y = M^T @ v\) if transpose=True.

Return type:

Array | ndarray | Quantity | Number

Raises:
  • ValueError – If prob is not a scalar, is not finite, or is outside [0, 1].

  • AssertionError – If the matrix shape and vector length are incompatible.
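The documented error conditions can be mirrored in a small validation sketch. The function check_jitsmv_inputs is hypothetical, not brainevent's implementation; it encodes the prob rule listed under Raises and the shape rule implied by the definition: \(y=M@v\) needs a vector of length shape[1], while \(y=M^T@v\) needs length shape[0].

```python
import math
import numbers


def check_jitsmv_inputs(prob, shape, vector_len, transpose):
    """Hypothetical checks mirroring the documented ValueError/AssertionError cases.

    Returns the length of the output vector y.
    """
    # prob must be a finite real scalar in [0, 1].
    if not isinstance(prob, numbers.Real) or not math.isfinite(prob) or not 0.0 <= prob <= 1.0:
        raise ValueError(f"prob must be a finite scalar in [0, 1], got {prob!r}")
    n_rows, n_cols = shape
    # y = M @ v needs len(v) == n_cols; y = M^T @ v needs len(v) == n_rows.
    expected = n_rows if transpose else n_cols
    assert vector_len == expected, f"vector length {vector_len} != expected {expected}"
    return n_cols if transpose else n_rows


print(check_jitsmv_inputs(0.5, (3, 5), 5, transpose=False))  # 3
print(check_jitsmv_inputs(0.5, (3, 5), 3, transpose=True))   # 5
```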

See also

binary_jitsmm

Event-driven matrix-matrix multiplication with scalar weight.

jitsmv

Float matrix-vector multiplication with scalar weight.

Notes

This function computes an event-driven (spike-based) matrix-vector product where the connectivity matrix M has the structure:

M[i, j] = w * Bernoulli(prob)

and the input vector is treated as a binary event vector (spikes). The output for each element is:

y[i] = sum_{j in C(i)} w * spike[j]

where C(i) is the deterministic random connection set for row i (determined by the seed), and spike[j] is 1 if vector[j] is True (for boolean) or > 0 (for float).

Since the input is binary, the operation reduces to counting the number of active (spiking) presynaptic neurons that connect to each postsynaptic neuron, then scaling by w:

y[i] = w * |{j in C(i) : spike[j] = 1}|
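The reduction from a weighted sum to a scaled count can be checked numerically. The NumPy sketch below uses a random boolean mask as a stand-in for the connection sets C(i) and a hypothetical helper, to_spikes, that applies the spike rule stated above (True for boolean inputs, > 0 for float inputs).

```python
import numpy as np


def to_spikes(vector):
    # Boolean inputs: True is a spike. Float inputs: values > 0 are spikes.
    v = np.asarray(vector)
    return v if v.dtype == np.bool_ else v > 0


rng = np.random.default_rng(0)
mask = rng.random((3, 5)) < 0.5          # stand-in for C(i): mask[i, j] is True if j in C(i)
w = 0.5
spike = to_spikes(np.array([1.2, 0.0, -0.3, 2.0, 0.7]))   # [True, False, False, True, True]

# y[i] = sum_{j in C(i)} w * spike[j]
y_sum = (w * mask) @ spike.astype(float)
# y[i] = w * |{j in C(i) : spike[j] = 1}|
y_count = w * np.count_nonzero(mask & spike, axis=1)

print(np.allclose(y_sum, y_count))  # True: both forms give the same result
```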

The matrix is never materialized in memory. The connectivity pattern is regenerated on-the-fly using the seed and connection length parameter clen = 2 / prob.
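The source gives clen = 2 / prob but not how clen is consumed. One common way to use such a parameter in JIT connectivity kernels, assumed here rather than confirmed for brainevent, is as the exclusive upper bound of a random skip along a row: the integer gaps 1 .. clen-1 average clen / 2 = 1 / prob, so each row receives about n_cols * prob connections. The function row_connections_skip is hypothetical and requires prob > 0.

```python
import numpy as np


def row_connections_skip(seed, row, n_cols, prob):
    """Hypothetical skip-sampler for one row, using clen = 2 / prob.

    Gaps are drawn uniformly from the integers 1 .. clen-1 (mean clen / 2,
    i.e. about 1 / prob), so roughly n_cols * prob columns are visited.
    Requires prob > 0.
    """
    clen = int(round(2.0 / prob))          # rounded so it can bound an integer draw
    rng = np.random.default_rng([seed, row])
    cols = []
    j = int(rng.integers(0, clen))         # random starting offset in [0, clen)
    while j < n_cols:
        cols.append(j)
        j += int(rng.integers(1, clen))    # jump ahead by a gap in [1, clen)
    return np.array(cols, dtype=np.int64)


cols = row_connections_skip(seed=42, row=0, n_cols=10_000, prob=0.1)
print(len(cols))  # close to 10_000 * 0.1 = 1000
```

This sampler matches the Bernoulli(prob) density only on average: true per-entry Bernoulli sampling would produce geometrically distributed gaps, while these gaps are uniform. Regenerating a row needs only (seed, row), which is what lets the matrix stay unmaterialized.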

Examples

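>>> import jax.numpy as jnp
>>> import brainevent
>>> weight = 0.5
>>> events = jnp.array([True, False, True, True, False])
>>> result = brainevent.binary_jitsmv(weight, 0.5, events, seed=42,
...                                   shape=(3, 5))
>>> result.shape
(3,)
>>> # transpose=True computes M^T @ v, so the vector length must be shape[0]
>>> result_t = brainevent.binary_jitsmv(weight, 0.5, jnp.ones(3, dtype=bool),
...                                     seed=42, shape=(3, 5), transpose=True)
>>> result_t.shape
(5,)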