BinaryArray

class brainevent.BinaryArray(value)

Array wrapper for binary (0/1) event vectors and matrices.

BinaryArray represents a boolean or 0/1 array and provides event-driven matrix multiplication via the @ operator. When a BinaryArray is multiplied with a dense weight matrix, only the rows/columns corresponding to non-zero (active) elements are accumulated, which is mathematically equivalent to standard matrix multiplication but can exploit sparsity for efficiency.

Parameters:

value (array_like) – The input binary array data. Typically a boolean JAX array, but any array whose non-zero pattern encodes binary events is accepted.

Notes

Given a binary spike vector s of shape (k,) and a weight matrix W of shape (k, n), the forward multiplication s @ W computes:

y[j] = sum_{i : s[i] != 0} W[i, j]

For boolean or 0/1 inputs this is equivalent to s.astype(float) @ W, but the implementation only iterates over the non-zero entries of s.
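The identity above can be checked with a minimal reference sketch in plain NumPy. The function `event_matvec` is a hypothetical helper written only for illustration; it is not part of brainevent and makes no claim about how binary_densemv is actually implemented. It gathers and sums only the rows of W whose spike is active, then compares the result with the dense product.

```python
import numpy as np


def event_matvec(spikes: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Hypothetical reference for s @ W that touches only active rows.

    spikes:  shape (k,), boolean or 0/1 values
    weights: shape (k, n)
    returns: shape (n,), y[j] = sum of weights[i, j] over i with spikes[i] != 0
    """
    active = np.flatnonzero(spikes)              # indices i where s[i] != 0
    out = np.zeros(weights.shape[1], dtype=weights.dtype)
    for i in active:                             # visit active entries only
        out += weights[i]
    return out


spikes = np.array([True, False, True, False, True])
weights = np.arange(15, dtype=np.float32).reshape(5, 3)

y_event = event_matvec(spikes, weights)
y_dense = spikes.astype(np.float32) @ weights
print(y_event)                                   # [18. 21. 24.] (rows 0, 2, 4)
assert np.allclose(y_event, y_dense)
```

The loop only does work proportional to the number of active spikes, which is where the sparsity saving comes from when most entries of s are zero.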

The class is registered as a JAX PyTree node, so it is compatible with jax.jit, jax.grad, jax.vmap, and other transformations.

See also

binary_densemv

Underlying primitive for binary vector-matrix multiply.

binary_densemm

Underlying primitive for binary matrix-matrix multiply.

Examples

>>> import jax.numpy as jnp
>>> import brainevent as be
>>> spikes = be.BinaryArray(jnp.array([True, False, True]))
>>> W = jnp.ones((3, 4))
>>> spikes @ W  # sums rows 0 and 2 of W
Array([2., 2., 2., 2.], dtype=float32)

property T

Transpose of the underlying array.

Returns:

The transposed raw array (not wrapped in BinaryArray).

Return type:

jax.Array

bitpack()

Pack binary values into uint32 words along every axis.

Each uint32 word stores 32 binary values. Bit b of word w corresponds to element w * 32 + b along the packed axis.

Creates and returns a new BitPackedBinary instance. The packed representation can be used with FCN sparse matrices for improved GPU cache utilisation.
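The bit layout stated above (bit b of word w holds element w * 32 + b) can be illustrated for a single 1-D axis with a minimal NumPy sketch. `pack_bits_1d` and `unpack_bits_1d` are hypothetical helpers written for illustration; they assume the tail is zero-padded up to a multiple of 32 and do not reproduce the internal layout of BitPackedBinary beyond the rule given here.

```python
import numpy as np


def pack_bits_1d(events: np.ndarray) -> np.ndarray:
    """Pack a 1-D boolean/0-1 array into uint32 words.

    Bit b of word w holds element w * 32 + b. Zero-padding of the
    last word is an assumption made for this sketch.
    """
    bits = (np.asarray(events) != 0).astype(np.uint32)
    n_words = -(-bits.size // 32)                  # ceil(len / 32)
    padded = np.zeros(n_words * 32, dtype=np.uint32)
    padded[: bits.size] = bits
    shifts = np.arange(32, dtype=np.uint32)        # bit position b in a word
    # One row per word: shift each bit into place, then OR the row together.
    return np.bitwise_or.reduce(padded.reshape(n_words, 32) << shifts, axis=1)


def unpack_bits_1d(words: np.ndarray, length: int) -> np.ndarray:
    """Inverse of pack_bits_1d: recover the first `length` booleans."""
    shifts = np.arange(32, dtype=np.uint32)
    bits = (words[:, None] >> shifts) & np.uint32(1)
    return bits.reshape(-1)[:length].astype(bool)


events = np.zeros(40, dtype=bool)
events[[0, 3, 33]] = True          # word 0: bits 0 and 3; word 1: bit 1
words = pack_bits_1d(events)
print(words)                       # [9 2]
assert np.array_equal(unpack_bits_1d(words, events.size), events)
```

Forty booleans collapse into two 32-bit words here, which is the kind of memory reduction that makes a packed form attractive for cache-bound kernels.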

Returns:

A new bit-packed event representation with one packed array per axis.

Return type:

BitPackedBinary

Examples

>>> import jax.numpy as jnp
>>> import brainevent as be
>>> ba = be.BinaryArray(jnp.array([True, False, True]))
>>> bp = ba.bitpack()
>>> type(bp)
<class 'brainevent.BitPackedBinary'>

transpose(*axes)

Return the underlying array with axes permuted.

Parameters:

*axes (int, optional) – Axis permutation. If omitted, reverses the axis order.

Returns:

The transposed raw array (not wrapped in BinaryArray).

Return type:

jax.Array

tree_flatten()

Flatten this instance for JAX PyTree serialisation.

Returns:

  • children (tuple) – A single-element tuple (value,) containing the dynamic array leaf.

  • aux_data (dict) – Empty dictionary (no static metadata).

classmethod tree_unflatten(aux_data, flat_contents)

Reconstruct a BinaryArray from its PyTree representation.

Parameters:

  • aux_data (dict) – Static metadata produced by tree_flatten.

  • flat_contents (tuple) – Dynamic leaves — the underlying array.

Returns:

A new instance wrapping the given array.

Return type:

BinaryArray
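The flatten/unflatten contract described above can be sketched with a hypothetical stand-in class, `ToyBinary`, registered through `jax.tree_util.register_pytree_node_class`. This is a minimal illustration that assumes the documented behaviour (a single array child, empty dictionary aux_data); it is not BinaryArray's source code.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyBinary:
    """Hypothetical stand-in mirroring the documented PyTree contract."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        children = (self.value,)   # the single dynamic array leaf
        aux_data = {}              # no static metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        # Rebuild the wrapper around the (possibly transformed) leaf.
        return cls(*flat_contents)


spikes = ToyBinary(jnp.array([True, False, True]))

leaves = jax.tree_util.tree_leaves(spikes)
print(len(leaves))                 # 1: only the wrapped array is a leaf

# tree_map flattens, applies the function to the leaf, then calls tree_unflatten.
flipped = jax.tree_util.tree_map(jnp.logical_not, spikes)
print(type(flipped).__name__, flipped.value)
```

Because the wrapped array is the only leaf, JAX transformations see a plain array inside the wrapper and hand back a new instance of the same class.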