EventRepresentation

class brainevent.EventRepresentation(value)

Abstract base class for event-driven array representations.

EventRepresentation wraps an underlying JAX array (or brainunit.Quantity) and exposes array-like properties (shape, ndim, dtype, size). Subclasses must implement the @ (matrix multiplication) operator by defining __matmul__ and __rmatmul__.

The class is registered as a JAX PyTree node so that instances are compatible with jax.jit, jax.grad, and other JAX transformations.
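As a rough illustration of how a wrapper can expose array-like properties and still pass through jax.jit, here is a minimal sketch. The class name `ToyEvents` is hypothetical and this is not the brainevent source; it only assumes the pattern described on this page (one dynamic array leaf, no static metadata). The toy uses an empty tuple as its static metadata.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyEvents:
    """Hypothetical wrapper that mirrors the documented pattern."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    @property
    def value(self):
        return self._value

    # Array-like properties simply delegate to the wrapped array.
    @property
    def shape(self):
        return self._value.shape

    @property
    def ndim(self):
        return self._value.ndim

    @property
    def dtype(self):
        return self._value.dtype

    @property
    def size(self):
        return self._value.size

    def tree_flatten(self):
        # One dynamic leaf; no static metadata in this toy.
        return (self._value,), ()

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        # Bypass __init__ so JAX can hand in tracers without conversion.
        obj = object.__new__(cls)
        (obj._value,) = flat_contents
        return obj


@jax.jit
def count_events(ev):
    # Inside jit, `ev` is rebuilt from traced leaves via tree_unflatten.
    return jnp.sum(ev.value.astype(jnp.int32))


ev = ToyEvents(jnp.array([True, False, True]))
print(ev.shape, ev.ndim, ev.dtype, ev.size)  # (3,) 1 bool 3
print(count_events(ev))
```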

Parameters:

value (EventRepresentation | Array | ndarray | Quantity | list | tuple | int | float | bool) – The underlying array data. Accepted types include jax.Array, numpy.ndarray, brainunit.Quantity, Python lists/tuples and scalars (int, float, bool), or another EventRepresentation (whose inner value is extracted).

Notes

Event-driven computation exploits the sparsity of neural spike vectors. Given a spike vector s and a weight matrix W, the standard dense product y = W @ s visits every element of W, whereas an event-driven implementation accumulates only the columns of W for which s is non-zero:

y[i] = sum_{j : s[j] != 0} W[i, j] * s[j]
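The formula above can be checked with a small, deliberately naive sketch: it loops only over the indices j where s[j] is non-zero, accumulates the corresponding columns of W, and can be compared with the dense product. The function name `event_matvec` is hypothetical, and a real kernel would not use a Python loop.

```python
import numpy as np


def event_matvec(W, s):
    """Event-driven y = W @ s: only columns with s[j] != 0 are touched."""
    y = np.zeros(W.shape[0], dtype=np.result_type(W, s))
    for j in np.flatnonzero(s):  # indices of active (non-zero) events
        y += W[:, j] * s[j]
    return y


W = np.arange(12, dtype=np.float32).reshape(3, 4)
s = np.array([0.0, 2.0, 0.0, 0.5], dtype=np.float32)
y_event = event_matvec(W, s)
y_dense = W @ s  # visits every element of W
```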

For binary events (BinaryArray), s[j] is either 0 or 1, so the multiplication by s[j] can be dropped entirely: y[i] is simply the sum of W[i, j] over the active indices j.
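For the binary case, the accumulation reduces to summing the selected columns. A minimal JAX sketch, assuming a boolean spike vector; `binary_event_matvec` is a hypothetical name. It uses jnp.where so that shapes stay static under jax.jit, which means it still evaluates every entry: it illustrates the arithmetic (no multiplication by s[j]), not the sparsity savings of a real event-driven kernel.

```python
import jax
import jax.numpy as jnp


@jax.jit
def binary_event_matvec(W, spikes):
    """y[i] = sum of W[i, j] over j with spikes[j] == True."""
    # Zero out inactive columns, then sum across columns; no multiply by s[j].
    return jnp.where(spikes[None, :], W, 0).sum(axis=1)


W = jnp.arange(12.0).reshape(3, 4)
spikes = jnp.array([False, True, False, True])
y = binary_event_matvec(W, spikes)
dense = W @ spikes.astype(W.dtype)
```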

Subclasses must implement __matmul__ and __rmatmul__ to define the event-driven @ operator.
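A subclass might satisfy this contract roughly as follows. `ToyBinaryVector` is hypothetical and standalone (it does not inherit from brainevent). For a 1-D boolean vector s, `s @ W` sums the rows of W selected by s, and `W @ s` sums the selected columns.

```python
import jax.numpy as jnp


class ToyBinaryVector:
    """Hypothetical 1-D binary event vector with an event-driven @ operator."""

    def __init__(self, value):
        self.value = jnp.asarray(value, dtype=bool)

    def __matmul__(self, W):
        # s @ W: accumulate rows W[j, :] for active events j.
        W = jnp.asarray(W)
        return jnp.where(self.value[:, None], W, 0).sum(axis=0)

    def __rmatmul__(self, W):
        # W @ s: accumulate columns W[:, j] for active events j.
        W = jnp.asarray(W)
        return jnp.where(self.value[None, :], W, 0).sum(axis=1)


s = ToyBinaryVector([True, False, True])
W = jnp.arange(6.0).reshape(3, 2)   # shape (3, 2), used as s @ W
W2 = jnp.arange(6.0).reshape(2, 3)  # shape (2, 3), used as W2 @ s
left = s @ W
# What Python calls for `W2 @ s` when W2's own __matmul__ defers.
right = s.__rmatmul__(W2)
```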

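The __array_priority__ attribute is set to 100 so that, when a standard NumPy or JAX array appears on the left-hand side of a binary operation such as @, the array's operator defers (returns NotImplemented) and Python dispatches to the reflected method defined on this class, e.g. __rmatmul__.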
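The effect of __array_priority__ can be seen with plain NumPy. In this sketch, `Deferring` is a hypothetical stand-in class: because its priority exceeds that of numpy.ndarray, the ndarray's @ returns NotImplemented and Python falls back to `Deferring.__rmatmul__`. Without the attribute, NumPy would not defer and would attempt the multiplication itself.

```python
import numpy as np


class Deferring:
    """Hypothetical class that reports which method handled the @ operation."""

    __array_priority__ = 100  # higher than numpy.ndarray's default priority

    def __matmul__(self, other):
        return "Deferring.__matmul__"

    def __rmatmul__(self, other):
        return "Deferring.__rmatmul__"


arr = np.ones((2, 3))
print(arr @ Deferring())  # ndarray defers, so __rmatmul__ handles it
print(Deferring() @ arr)  # left operand handles it directly
```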

See also

BinaryArray

Concrete subclass for binary (0/1) event arrays.

Examples

>>> import jax.numpy as jnp
>>> import brainevent as be
>>> arr = be.BinaryArray(jnp.array([True, False, True]))
>>> arr.shape
(3,)
>>> arr.dtype
dtype('bool')

property dtype

Data type of the underlying array elements.

Returns:

The element data type, e.g. dtype('float32') or dtype('bool').

Return type:

numpy.dtype

property ndim: int

Number of array dimensions.

Returns:

The number of axes, e.g. 1 for a vector, 2 for a matrix.

Return type:

int

property shape: tuple[int, ...]

Shape of the underlying array.

Returns:

Dimension sizes, e.g. (n,) for a 1-D vector or (m, n) for a 2-D matrix.

Return type:

tuple of int

property size: int

Total number of elements in the underlying array.

Returns:

Product of all dimension sizes.

Return type:

int

tree_flatten()

Flatten this instance into its PyTree children (dynamic leaves) and auxiliary data (static metadata).

Returns:

  • children (tuple) – A single-element tuple (value,) containing the dynamic array leaf.

  • aux_data (dict) – An empty dictionary (subclasses may add static metadata here).

classmethod tree_unflatten(aux_data, flat_contents)

Reconstruct an instance from its PyTree representation.

Parameters:

  • aux_data (dict) – Static metadata produced by tree_flatten.

  • flat_contents (tuple) – Dynamic leaves, i.e. the underlying array.

Returns:

A new instance of the concrete subclass.

Return type:

EventRepresentation
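The flatten/unflatten contract described above can be exercised directly with jax.tree_util. The sketch below uses a hypothetical base class `ToyBase` and subclass `ToySpikes`; neither is part of brainevent. Because tree_unflatten is a classmethod that builds `cls`, round-tripping a subclass instance yields the subclass again. JAX's registry matches on exact type, so in this sketch the concrete subclass is the class that gets decorated.

```python
import jax
import jax.numpy as jnp


class ToyBase:
    """Hypothetical base mirroring the documented tree_flatten contract."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # children: the single dynamic leaf; aux_data: empty static metadata
        return (self.value,), {}

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        # Bypass __init__ so JAX can pass tracers or placeholder leaves.
        obj = object.__new__(cls)
        (obj.value,) = flat_contents
        return obj


@jax.tree_util.register_pytree_node_class
class ToySpikes(ToyBase):
    pass


x = ToySpikes(jnp.array([True, False, True]))
leaves, treedef = jax.tree_util.tree_flatten(x)
y = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(y).__name__, len(leaves))  # ToySpikes 1
```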

property value: Array | Quantity

The underlying array data.

Returns:

The raw array stored by this event representation.

Return type:

jax.Array or brainunit.Quantity

with_value(value)

Create a new instance of the same type with a different value.

Parameters:

value (EventRepresentation | Array | ndarray | Quantity | list | tuple | int | float | bool) – The new underlying array data.

Returns:

A fresh instance of the same concrete class wrapping value.

Return type:

Self

Examples

>>> import jax.numpy as jnp
>>> import brainevent as be
>>> a = be.BinaryArray(jnp.array([True, False]))
>>> b = a.with_value(jnp.array([False, True]))
>>> b.value
Array([False,  True], dtype=bool)
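One plausible way to implement with_value, together with the constructor behaviour described under Parameters (unwrapping another event representation), is to delegate to type(self). This is a hedged sketch with a hypothetical `ToyEventBase`/`ToyBinary` pair, not brainevent's implementation; brainunit.Quantity handling is deliberately left out.

```python
import jax.numpy as jnp


class ToyEventBase:
    """Hypothetical base class; Quantity support is omitted."""

    def __init__(self, value):
        if isinstance(value, ToyEventBase):
            value = value.value  # unwrap another event representation
        self.value = jnp.asarray(value)  # lists, tuples, scalars, ndarrays

    def with_value(self, value):
        # type(self) preserves the concrete subclass rather than the base.
        return type(self)(value)


class ToyBinary(ToyEventBase):
    pass


a = ToyBinary(jnp.array([True, False]))
b = a.with_value([False, True])
c = ToyBinary(b)  # wrapping an instance reuses its inner array data
```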