brainevent documentation
BrainEvent provides a set of data structures and algorithms for event-driven computation, which can be used to model brain dynamics in a more efficient and biologically plausible way.
Installation
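Install the extra that matches your hardware (the quotes keep shells such as zsh from expanding the brackets):
pip install -U "brainevent[cpu]"     # CPU only
pip install -U "brainevent[cuda12]"  # NVIDIA GPU, CUDA 12
pip install -U "brainevent[cuda13]"  # NVIDIA GPU, CUDA 13
pip install -U "brainevent[tpu]"     # TPU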
What is BrainEvent?
The brain is fundamentally an event-driven system, where discrete spiking events are the primary units of computation. Traditional dense matrix operations process all array elements, even zeros, leading to significant computational waste in sparse spike-based scenarios where only a small fraction of neurons are active at any given time.
BrainEvent addresses this challenge through:
Event-driven processing: Computations skip zero elements and touch only the neurons that fire spikes
Hardware acceleration: Optimized custom kernels for CPU, GPU, and TPU
Seamless JAX integration: Full support for automatic differentiation, JIT compilation, and vmap
Biological plausibility: The computation mirrors the sparse, event-driven nature of real neural systems
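To make "processing only active events" concrete, the sketch below compares a dense vector-matrix product with an event-driven one that sums only the rows of neurons that spiked. It is a plain NumPy illustration, not BrainEvent's kernel, and the function names `dense_matvec` and `event_matvec` are hypothetical. Because spikes are binary, the product reduces to adding up the weight rows of the active neurons, so a step with 2% activity touches about 2% of the rows.

```python
import numpy as np

def dense_matvec(spikes, weights):
    # Dense product: every row of `weights` is read, including rows of silent neurons.
    return spikes @ weights

def event_matvec(spikes, weights):
    # Event-driven product: only rows of active (spiking) neurons are accumulated.
    out = np.zeros(weights.shape[1], dtype=weights.dtype)
    for i in np.flatnonzero(spikes):
        out += weights[i]
    return out

rng = np.random.default_rng(0)
weights = rng.normal(size=(1000, 200))
# Roughly 2% of the 1000 presynaptic neurons spike in this step.
spikes = (rng.random(1000) < 0.02).astype(weights.dtype)

assert np.allclose(dense_matvec(spikes, weights), event_matvec(spikes, weights))
print(f"rows read: dense={weights.shape[0]}, event-driven={int(spikes.sum())}")
```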
Core Components
1. Event Representation
BrainEvent provides specialized array types for representing neural events:
BinaryArray: Binary arrays representing spike events (1 = spike, 0 = no spike)
2. Sparse Data Structures
Multiple sparse matrix formats optimized for event-driven computation:
COO (Coordinate format): Flexible format for constructing sparse matrices
CSR/CSC (Compressed Sparse Row/Column): Fast row/column-oriented operations
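The difference between the two layouts can be sketched in plain NumPy: COO stores one (row, col, value) triple per synapse, which is easy to build, while CSR groups each row's synapses contiguously behind an `indptr` offset array, so an event-driven product can jump straight to the rows of spiking neurons. This is an illustration of the formats under those standard definitions, not BrainEvent's `COO`/`CSR` classes; `coo_to_csr` and `csr_event_matvec` are hypothetical helpers.

```python
import numpy as np

# COO: one (row, col, value) triple per nonzero synapse; rows are presynaptic neurons.
rows = np.array([0, 0, 1, 3, 3, 3])
cols = np.array([1, 4, 2, 0, 2, 4])
vals = np.array([0.5, 1.0, -0.3, 0.2, 0.7, 0.1])
shape = (4, 5)  # (num_pre, num_post)

def coo_to_csr(rows, cols, vals, num_rows):
    # Stable sort by row so each row's synapses become contiguous.
    order = np.argsort(rows, kind="stable")
    indices = cols[order]
    data = vals[order]
    # Row i occupies indices[indptr[i]:indptr[i + 1]] and data[indptr[i]:indptr[i + 1]].
    counts = np.bincount(rows, minlength=num_rows)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    return data, indices, indptr

def csr_event_matvec(spikes, data, indices, indptr, num_cols):
    # Visit only the rows of spiking neurons and scatter-add their weights.
    out = np.zeros(num_cols, dtype=data.dtype)
    for i in np.flatnonzero(spikes):
        start, end = indptr[i], indptr[i + 1]
        np.add.at(out, indices[start:end], data[start:end])
    return out

data, indices, indptr = coo_to_csr(rows, cols, vals, shape[0])
spikes = np.array([1, 0, 0, 1])  # presynaptic neurons 0 and 3 fire
dense = np.zeros(shape)
dense[rows, cols] = vals
assert np.allclose(csr_event_matvec(spikes, data, indices, indptr, shape[1]), spikes @ dense)
```

A CSC layout is the same idea with rows and columns swapped, which makes column (postsynaptic) access cheap instead.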
3. Just-In-Time Connectivity
Generate connectivity matrices on the fly without storing full weight matrices (memory-efficient for large networks):
JITCScalarR/JITCScalarC: Scalar (constant) weights
JITCNormalR/JITCNormalC: Normally distributed weights
JITCUniformR/JITCUniformC: Uniformly distributed weights
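The core trick behind just-in-time connectivity is that a row of the matrix can be regenerated deterministically from a seed whenever it is needed, so only the seed, probability, and weight are stored. The sketch below assumes a simple, hypothetical scheme in which row `i` is drawn from a NumPy generator seeded with `(seed, i)`; BrainEvent's actual random-number scheme may differ, and `jitc_row`, `jitc_scalar_event_matvec`, and `jitc_scalar_dense` are illustrative names only.

```python
import numpy as np

def jitc_row(seed, i, num_post, prob):
    # Hypothetical scheme: row i is regenerated from the pair (seed, i), so the same
    # row comes out every time and the full matrix never has to be stored.
    rng = np.random.default_rng([seed, int(i)])
    return np.flatnonzero(rng.random(num_post) < prob)

def jitc_scalar_event_matvec(spikes, num_post, prob, weight, seed):
    # Scalar weights: every existing synapse carries the same value `weight`.
    out = np.zeros(num_post)
    for i in np.flatnonzero(spikes):
        out[jitc_row(seed, i, num_post, prob)] += weight
    return out

def jitc_scalar_dense(num_pre, num_post, prob, weight, seed):
    # For checking only: materialize the matrix that the event-driven path implies.
    dense = np.zeros((num_pre, num_post))
    for i in range(num_pre):
        dense[i, jitc_row(seed, i, num_post, prob)] = weight
    return dense

spikes = (np.random.default_rng(1).random(1000) < 0.05).astype(float)
out = jitc_scalar_event_matvec(spikes, num_post=1000, prob=0.1, weight=0.5, seed=0)
assert np.allclose(out, spikes @ jitc_scalar_dense(1000, 1000, 0.1, 0.5, 0))
```

Normally or uniformly distributed weights would follow the same pattern, drawing each synapse's weight from the row's generator instead of using a constant.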
4. Fixed Connectivity Patterns
Specialized structures for biologically realistic fixed-degree connectivity:
FixedPostNumConn: Fixed number of post-synaptic connections per pre-synaptic neuron
FixedPreNumConn: Fixed number of pre-synaptic connections per post-synaptic neuron
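Fixed-degree connectivity is convenient because every row has the same length: with a fixed number of postsynaptic targets per presynaptic neuron, the targets fit in a dense `(num_pre, conn_num)` index array and no offset array is needed. The NumPy sketch below illustrates that layout; it is not the `FixedPostNumConn` implementation, and `fixed_post_num_conn` and `fixed_post_event_matvec` are hypothetical helpers.

```python
import numpy as np

def fixed_post_num_conn(num_pre, num_post, conn_num, weight, seed):
    # Each presynaptic neuron gets exactly `conn_num` distinct postsynaptic targets,
    # so the connectivity is a dense (num_pre, conn_num) index array.
    rng = np.random.default_rng(seed)
    indices = np.stack([rng.choice(num_post, size=conn_num, replace=False)
                        for _ in range(num_pre)])
    data = np.full((num_pre, conn_num), weight)
    return data, indices

def fixed_post_event_matvec(spikes, data, indices, num_post):
    # Gather the rows of the active presynaptic neurons and scatter-add them.
    active = np.flatnonzero(spikes)
    out = np.zeros(num_post)
    np.add.at(out, indices[active].ravel(), data[active].ravel())
    return out

data, indices = fixed_post_num_conn(1000, 1000, conn_num=100, weight=0.5, seed=0)
spikes = (np.random.default_rng(1).random(1000) < 0.05).astype(float)
out = fixed_post_event_matvec(spikes, data, indices, 1000)
# Each active neuron contributes conn_num * weight to the total input.
assert np.isclose(out.sum(), spikes.sum() * 100 * 0.5)
```

The fixed-pre-number case is the mirror image: each postsynaptic neuron stores exactly `conn_num` presynaptic sources.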
5. Custom Kernel Framework
Extensible system for defining high-performance custom operators:
Numba: CPU-optimized operations with the @numba_kernel decorator
Warp: NVIDIA GPU operations using the Warp language
Pallas: TPU/GPU operations using JAX Pallas
XLA Integration: XLACustomKernel for custom XLA operators
6. Synaptic Plasticity
Built-in support for learning and plasticity rules:
update_csr_on_binary_pre/update_csr2csc_on_binary_post: CSR-based plasticity updates
update_coo_on_binary_pre/update_coo_on_binary_post: COO-based plasticity updates
update_dense_on_binary_pre/update_dense_on_binary_post: Dense matrix plasticity updates
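A presynaptic-spike-triggered update on a CSR matrix only needs to touch the rows of neurons that fired, which is what makes it event-driven. The sketch below assumes a simple, hypothetical additive rule (each synapse `(i, j)` in a spiking row changes by `post_trace[j]`, then is clipped to `[w_min, w_max]`); the exact rule and signature of `update_csr_on_binary_pre` are defined by BrainEvent's API and are not reproduced here. `csr_update_on_pre_spike` is an illustrative name.

```python
import numpy as np

def csr_update_on_pre_spike(data, indices, indptr, pre_spike, post_trace,
                            w_min=0.0, w_max=1.0):
    # Hypothetical additive rule: when presynaptic neuron i spikes, each synapse
    # (i, j) stored in row i changes by post_trace[j]; silent rows are untouched.
    new_data = data.copy()
    for i in np.flatnonzero(pre_spike):
        start, end = indptr[i], indptr[i + 1]
        new_data[start:end] += post_trace[indices[start:end]]
    # Keep weights inside [w_min, w_max].
    return np.clip(new_data, w_min, w_max)

# 3 presynaptic x 3 postsynaptic neurons with 5 stored synapses.
indptr = np.array([0, 2, 3, 5])
indices = np.array([1, 2, 0, 0, 2])
data = np.full(5, 0.5)
pre_spike = np.array([1, 0, 1])          # presynaptic neurons 0 and 2 fire
post_trace = np.array([0.1, -0.2, 0.3])  # e.g. a decaying postsynaptic activity trace
print(csr_update_on_pre_spike(data, indices, indptr, pre_spike, post_trace))
```

Postsynaptic-spike-triggered updates need the column view instead, since they touch all synapses onto a given postsynaptic neuron.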
7. Unit-Aware Computation
Fully compatible with BrainUnit for physical unit tracking and dimensional analysis.
Quick Start
Basic Usage
To use event-driven computation, wrap your spike arrays with BinaryArray:
import brainevent
import jax.numpy
# Create spike events (binary array)
spikes = brainevent.BinaryArray(jax.numpy.array([1, 0, 1, 0, 1]))
# Create a sparse connectivity matrix with 5 rows (one per presynaptic neuron);
# constructor arguments are elided here
conn = brainevent.CSR(...)
# Event-driven matrix multiplication
output = spikes @ conn
When a BinaryArray is involved, BrainEvent automatically optimizes the computation, processing only the active (non-zero) events.
Working with Different Data Structures
import brainevent
import jax.numpy
# Sparse matrices
csr_matrix = brainevent.CSR(...)
coo_matrix = brainevent.COO(...)
# Just-in-time connectivity (memory efficient)
jitc_conn = brainevent.JITCScalarR(num_pre=1000, num_post=1000,
                                   prob=0.1, weight=0.5, seed=0)
# Fixed connectivity patterns
fixed_conn = brainevent.FixedPostNumConn(num_pre=1000, num_post=1000,
                                         conn_num=100, weight=0.5, seed=0)
# Event-driven computations work with all structures
spikes = brainevent.BinaryArray(jax.numpy.array([...]))  # length must equal num_pre (1000)
output = spikes @ jitc_conn # Only active spikes are processed
See also the ecosystem
brainevent is one part of our brain modeling ecosystem.