BinaryArray
- class brainevent.BinaryArray(value)
Array wrapper for binary (0/1) event vectors and matrices.
BinaryArray represents a boolean or 0/1 array and provides event-driven matrix multiplication via the @ operator. When a BinaryArray is multiplied with a dense weight matrix, only the rows/columns corresponding to non-zero (active) elements are accumulated. The result is mathematically equivalent to standard matrix multiplication, but the implementation can exploit sparsity for efficiency.
- Parameters:
value (array_like) – The input binary array data. Typically a boolean JAX array, but any array whose non-zero pattern encodes binary events is accepted.
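A minimal plain-JAX sketch of the accumulation described above. It does not call brainevent and is only a reference formulation, not the library's kernel. For s @ W, only the rows of W selected by active entries of s are summed. For V @ s, only the columns of V are summed. The names V and active are illustrative.

```python
import jax.numpy as jnp

# Binary event vector s (k = 3) and dense weights for both multiplication sides.
s = jnp.array([True, False, True])
W = jnp.arange(12.0).reshape(3, 4)   # right operand: s @ W, shape (k, n)
V = jnp.arange(12.0).reshape(4, 3)   # left operand:  V @ s, shape (n, k)

# Indices of active events. Eager only: the output size depends on the data.
active = jnp.nonzero(s)[0]

# s @ W: sum only the rows of W picked out by active events.
y_rows = W[active].sum(axis=0)
# V @ s: sum only the columns of V picked out by active events.
y_cols = V[:, active].sum(axis=1)

# Both agree with ordinary dense products on a float copy of s.
assert jnp.allclose(y_rows, s.astype(jnp.float32) @ W)
assert jnp.allclose(y_cols, V @ s.astype(jnp.float32))
```

Because jnp.nonzero without a size argument produces a data-dependent shape, this gather form is not traceable under jax.jit. It is meant only to make the arithmetic concrete.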
Notes
Given a binary spike vector s of shape (k,) and a weight matrix W of shape (k, n), the forward multiplication s @ W computes:
$$y_j = \sum_{i \,:\, s_i \neq 0} W_{ij}$$
This is equivalent to s.astype(float) @ W, but the implementation only iterates over the non-zero entries of s.
The class is registered as a JAX PyTree node, so it is compatible with jax.jit, jax.grad, jax.vmap, and other transformations.
See also
binary_densemv – Underlying primitive for binary vector-matrix multiply.
binary_densemm – Underlying primitive for binary matrix-matrix multiply.
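For the matrix-matrix case, a jit-friendly reference formulation is sketched below. It assumes a binary S of shape (m, k) and a dense W of shape (k, n). Row i of the result sums the rows W[j] for which S[i, j] is active. This is not the binary_densemm kernel. The helper masked_binary_mm is hypothetical. It uses jnp.where so that shapes stay static under jax.jit, which means it computes the same result without actually skipping any work.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_binary_mm(S, W):
    """Hypothetical reference for S @ W with binary S (m, k) and dense W (k, n)."""
    # Broadcast to (m, k, n): keep W[j, :] where S[i, j] is active, else zero.
    selected = jnp.where(S[:, :, None] != 0, W[None, :, :], 0.0)
    # Sum the surviving rows of W for each row of S.
    return selected.sum(axis=1)

S = jnp.array([[True, False, True],
               [False, True, False]])
W = jnp.arange(12.0).reshape(3, 4)
out = masked_binary_mm(S, W)

# Matches the dense product on a float copy of S.
assert jnp.allclose(out, S.astype(jnp.float32) @ W)
```

The (m, k, n) intermediate makes this form memory-hungry. It is useful for checking results, not for performance.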
Examples
>>> import jax.numpy as jnp
>>> import brainevent as be
>>> spikes = be.BinaryArray(jnp.array([True, False, True]))
>>> W = jnp.ones((3, 4))
>>> spikes @ W  # sums rows 0 and 2 of W
Array([2., 2., 2., 2.], dtype=float32)
- property T
Transpose of the underlying array.
- Returns:
The transposed raw array (not wrapped in BinaryArray).
- Return type:
jax.Array
- bitpack()
Pack binary values into uint32 words along every axis.
Each uint32 word stores 32 binary values. Bit b of word w corresponds to element w * 32 + b along the packed axis.
Creates and returns a new BitPackedBinary instance. The packed representation can be used with FCN sparse matrices for improved GPU cache utilisation.
- Returns:
A new bit-packed event representation with one packed array per axis.
- Return type:
BitPackedBinary
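To make the bit layout concrete, here is a minimal 1-D sketch in plain JAX. Bit b of word w holds element w * 32 + b, and the tail is zero-padded to a whole word. The helpers pack_bits_1d and unpack_bits_1d are hypothetical illustrations, not brainevent functions. They handle a single axis only, whereas bitpack() packs along every axis.

```python
import jax.numpy as jnp

def pack_bits_1d(x):
    """Hypothetical: pack a 1-D 0/1 array so bit b of word w holds element w * 32 + b."""
    n = x.shape[0]
    n_words = (n + 31) // 32  # round up to whole uint32 words
    # Zero-pad to a multiple of 32 so every word has 32 slots.
    padded = jnp.zeros(n_words * 32, dtype=jnp.uint32).at[:n].set((x != 0).astype(jnp.uint32))
    bits = padded.reshape(n_words, 32)  # row w holds elements w*32 .. w*32+31
    shifts = jnp.arange(32, dtype=jnp.uint32)
    # Each bit occupies a distinct position, so summing shifted bits equals OR-ing them.
    return jnp.sum(bits << shifts, axis=1, dtype=jnp.uint32)

def unpack_bits_1d(words, n):
    """Hypothetical inverse of pack_bits_1d: recover the first n binary values."""
    shifts = jnp.arange(32, dtype=jnp.uint32)
    bits = (words[:, None] >> shifts) & jnp.uint32(1)
    return bits.reshape(-1)[:n].astype(bool)

x = jnp.array([True, False, True])
words = pack_bits_1d(x)  # a single word with bits 0 and 2 set
assert bool(jnp.all(unpack_bits_1d(words, x.shape[0]) == x))
```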
Examples
>>> import jax.numpy as jnp
>>> import brainevent as be
>>> ba = be.BinaryArray(jnp.array([True, False, True]))
>>> bp = ba.bitpack()
>>> type(bp)
<class 'brainevent.BitPackedBinary'>
- transpose(*axes)
Return the underlying array with axes permuted.
- Parameters:
*axes (int, optional) – Axis permutation. If omitted, reverses the axis order.
- Returns:
The transposed raw array (not wrapped in BinaryArray).
- Return type:
jax.Array
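A minimal sketch of the documented T / transpose(*axes) semantics, using a hypothetical stand-in class named TinyBinary rather than brainevent itself. Both return the raw array instead of a wrapper. Omitting the axes reverses the axis order, as jnp.transpose does with axes=None. The sketch assumes the axes are passed as separate integers, matching the *axes signature.

```python
import jax.numpy as jnp

class TinyBinary:
    """Hypothetical stand-in mirroring the documented T / transpose behaviour."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    @property
    def T(self):
        # Raw jax.Array with reversed axes; not re-wrapped.
        return self.value.T

    def transpose(self, *axes):
        # No axes given -> reverse axis order (jnp.transpose with axes=None).
        return jnp.transpose(self.value, axes if axes else None)

b = TinyBinary(jnp.zeros((2, 3, 4), dtype=bool))
assert b.T.shape == (4, 3, 2)
assert b.transpose(1, 0, 2).shape == (3, 2, 4)
```

Because the result is a plain array, a subsequent @ with a dense jax.Array is an ordinary dense matmul and no longer event-driven.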
- tree_flatten()
Flatten this instance for JAX PyTree serialisation.
- Returns:
children (tuple) – A single-element tuple (value,) containing the dynamic array leaf.
aux_data (dict) – Empty dictionary (no static metadata).
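The (children, aux_data) contract above is what lets a wrapper pass through JAX transformations. The sketch below shows the general pattern with jax.tree_util.register_pytree_node_class. It uses a hypothetical class named BinaryLeafWrapper, not brainevent's actual source. It returns one dynamic leaf with an empty dict as static metadata, and the matching tree_unflatten is an assumption, since only tree_flatten is documented here.

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class BinaryLeafWrapper:
    """Hypothetical wrapper illustrating the documented flatten contract."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # One dynamic leaf (the array) and no static metadata.
        return (self.value,), {}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the wrapper from the (possibly traced) leaf.
        return cls(*children)

@jax.jit
def count_active(b):
    # Under jit, b arrives as a BinaryLeafWrapper whose .value is a tracer.
    return jnp.sum(b.value != 0)

b = BinaryLeafWrapper(jnp.array([True, False, True]))
leaves, treedef = jax.tree_util.tree_flatten(b)
assert len(leaves) == 1
```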