EventRepresentation
- class brainevent.EventRepresentation(value)
Abstract base class for event-driven array representations.
EventRepresentation wraps an underlying JAX array (or brainunit.Quantity) and exposes array-like properties (shape, ndim, dtype, size) while requiring subclasses to implement the @ (matrix multiplication) operator via __matmul__ and __rmatmul__.
The class is registered as a JAX PyTree node so that instances are compatible with jax.jit, jax.grad, and other JAX transformations.
- Parameters:
value (EventRepresentation | Array | ndarray | Quantity | list | tuple | int | float | bool) – The underlying array data. Accepted types include jax.Array, numpy.ndarray, brainunit.Quantity, Python lists/tuples, or another EventRepresentation (whose inner value will be extracted).
Notes
Event-driven computation exploits the sparsity of neural spike vectors. Given a spike vector s and a weight matrix W, the standard dense product y = W @ s visits every element of W. An event-driven implementation only accumulates columns of W where s is non-zero:
y[i] = sum_{j : s[j] != 0} W[i, j] * s[j]
For binary events (BinaryArray), s[j] is either 0 or 1, so the multiplication by s[j] is omitted entirely.
Subclasses must implement __matmul__ and __rmatmul__ to define the event-driven @ operator.
The __array_priority__ is set to 100 so that NumPy/JAX will defer to the operators defined on this class when a standard array appears on the left-hand side of a binary operation.
See also
BinaryArray: Concrete subclass for binary (0/1) event arrays.
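The accumulation rule from the Notes can be sketched in plain NumPy. This is an illustrative sketch only, not the brainevent implementation: the function names dense_matvec, event_matvec, and binary_event_matvec are hypothetical, and a real event-driven kernel would typically avoid materialising the column slice.

```python
import numpy as np

def dense_matvec(W, s):
    """Reference product: visits every element of W."""
    return W @ s

def event_matvec(W, s):
    """Accumulate only the columns of W where s is non-zero."""
    active = np.flatnonzero(s)        # indices j with s[j] != 0
    return W[:, active] @ s[active]   # sum over active j of W[:, j] * s[j]

def binary_event_matvec(W, spikes):
    """Binary events: s[j] is 0 or 1, so the multiplication is skipped."""
    active = np.flatnonzero(spikes)
    return W[:, active].sum(axis=1)   # plain column sum, no multiply

W = np.arange(12, dtype=np.float64).reshape(3, 4)
s = np.array([0.0, 2.0, 0.0, 0.5])
spikes = np.array([True, False, False, True])

y_dense = dense_matvec(W, s)
y_event = event_matvec(W, s)
y_binary = binary_event_matvec(W, spikes)
```

Here y_event matches y_dense while touching only two of the four columns of W, and y_binary equals the product of W with the 0/1 vector without performing any multiplication.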
Examples
>>> import jax.numpy as jnp
>>> import brainevent as be
>>> arr = be.BinaryArray(jnp.array([True, False, True]))
>>> arr.shape
(3,)
>>> arr.dtype
dtype('bool')
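The effect of __array_priority__ described in the Notes can be illustrated with a toy class and a NumPy array on the left-hand side. ToyBinaryEvents is a hypothetical stand-in written for this sketch and is not part of brainevent. The sketch assumes NumPy's legacy deferral rule, under which ndarray binary operators return NotImplemented for an operand with a higher __array_priority__ and no __array_ufunc__.

```python
import numpy as np

class ToyBinaryEvents:
    """Hypothetical stand-in for a binary event array (not the brainevent class)."""

    __array_priority__ = 100  # higher than a plain ndarray's priority

    def __init__(self, value):
        self.value = np.asarray(value, dtype=bool)

    def __matmul__(self, W):
        # events @ W: sum the rows of W selected by active events
        return np.asarray(W)[self.value].sum(axis=0)

    def __rmatmul__(self, W):
        # W @ events: sum the columns of W selected by active events
        return np.asarray(W)[:, self.value].sum(axis=1)

W = np.arange(6, dtype=np.float64).reshape(2, 3)
events = ToyBinaryEvents([True, False, True])

left = W @ events      # ndarray defers, so ToyBinaryEvents.__rmatmul__ runs
right = events @ W.T   # ToyBinaryEvents.__matmul__ runs directly
```

Without the priority attribute, NumPy would try to treat the object as array data rather than handing control to __rmatmul__.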
- property dtype
Data type of the underlying array elements.
- Returns:
The element data type (e.g. jnp.float32, jnp.bool_).
- Return type:
numpy.dtype
- property ndim: int
Number of array dimensions.
- Returns:
The number of axes, e.g. 1 for a vector, 2 for a matrix.
- Return type:
int
- property size: int
Total number of elements in the underlying array.
- Returns:
Product of all dimension sizes.
- Return type:
int
- tree_flatten()
Flatten this instance for JAX PyTree serialisation.
- Returns:
children (tuple) – A single-element tuple (value,) containing the dynamic array leaf.
aux_data (dict) – An empty dictionary (subclasses may add static metadata here).
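The flatten/unflatten pair follows JAX's custom PyTree protocol. The sketch below shows how a wrapper with these two methods becomes usable under jax.jit. ToyEvents and count_events are hypothetical names for this illustration, not brainevent code, and the sketch uses an empty tuple as its static metadata since it stores none.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class ToyEvents:
    """Hypothetical wrapper used only to illustrate the protocol."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # children: dynamic leaves traced by JAX; aux_data: static metadata
        return (self.value,), ()

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        (value,) = flat_contents
        return cls(value)

@jax.jit
def count_events(events):
    # The wrapper crosses the jit boundary because it is a registered PyTree
    return jnp.sum(events.value != 0)

events = ToyEvents(jnp.array([True, False, True]))
leaves, treedef = jax.tree_util.tree_flatten(events)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
n_active = count_events(events)
```

Flattening yields a single leaf (the wrapped array), and unflattening returns an instance of the same class, which is what lets JAX transformations rebuild the wrapper around traced values.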
- classmethod tree_unflatten(aux_data, flat_contents)
Reconstruct an instance from its PyTree representation.
- Parameters:
- Returns:
A new instance of the concrete subclass.
- Return type:
- property value: Array | Quantity
The underlying array data.
- Returns:
The raw array stored by this event representation.
- Return type:
jax.Array or brainunit.Quantity
- with_value(value)
Create a new instance of the same type with a different value.
- Parameters:
value (EventRepresentation | Array | ndarray | Quantity | list | tuple | int | float | bool) – The new underlying array data.
- Returns:
A fresh instance of the same concrete class wrapping value.
- Return type:
Self
Examples
>>> import jax.numpy as jnp
>>> import brainevent as be
>>> a = be.BinaryArray(jnp.array([True, False]))
>>> b = a.with_value(jnp.array([False, True]))
>>> b.value
Array([False,  True], dtype=bool)