Quickstart
This page takes you from zero to a working event-driven matrix multiplication in a couple of minutes. For the motivation behind it, see What is event-driven computation?
The core idea
The brain computes with spikes — sparse, binary events. brainevent exploits that
sparsity: wrap a spike vector in BinaryArray, and any matrix
multiplication against it skips the zeros and processes only the neurons that fired.
```python
import brainevent
import jax.numpy as jnp

# 1 = spike, 0 = no spike
spikes = brainevent.BinaryArray(jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32))
```
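Here is what "skips the zeros" means concretely. For a binary spike vector `s` and a weight matrix `W`, the product `s @ W` is just the sum of the rows of `W` that belong to neurons that fired. The sketch below is plain JAX and does not use brainevent. It is not brainevent's kernel. It only checks that identity. The helper `event_matvec` is hypothetical, and it assumes spikes are exactly 0 or 1.

```python
import jax.numpy as jnp

def event_matvec(spikes, weights):
    """Hypothetical reference: sum the weight rows of the neurons that fired.

    Assumes binary spikes (0 or 1); spike magnitudes are ignored.
    """
    # Indices of active presynaptic neurons. jnp.nonzero has a data-dependent
    # output size, so this eager helper cannot be wrapped in jax.jit as written.
    active = jnp.nonzero(spikes)[0]
    # Gather only the active rows; rows of silent neurons are never read.
    return weights[active].sum(axis=0)

s = jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32)
W = jnp.arange(15, dtype=jnp.float32).reshape(5, 3)

dense = s @ W               # multiplies every row, including the zeros
event = event_matvec(s, W)  # rows 0 + 2 + 4 = [18, 21, 24]
```

When only a few neurons fire, the gather-and-sum reads far fewer rows than the dense product, and that difference is the saving an event-driven kernel targets.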
Multiply by a connectivity matrix
A BinaryArray multiplies against dense arrays and any of brainevent’s sparse
connectivity structures. The operation is event-driven in every case:
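```python
import jax.numpy as jnp

# Dense weights (jax/numpy array)
weights = jnp.ones((5, 3))
out_dense = spikes @ weights

# CSR sparse matrix (illustrative values). Row i (presynaptic neuron i) stores
# data[indptr[i]:indptr[i+1]] at the columns indices[indptr[i]:indptr[i+1]].
data = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
indices = jnp.array([0, 2, 1, 0, 1, 2])
indptr = jnp.array([0, 2, 3, 4, 5, 6])
csr = brainevent.CSR((data, indices, indptr), shape=(5, 3))
out_csr = spikes @ csr

# Just-in-time connectivity: never materialises the full matrix
jitc = brainevent.JITCScalarR(num_pre=5, num_post=3, prob=0.5, weight=0.2, seed=0)
out_jitc = spikes @ jitc

# Fixed fan-out connectivity
fixed = brainevent.FixedPostNumConn(num_pre=5, num_post=3, conn_num=2, weight=0.5, seed=0)
out_fixed = spikes @ fixed
```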
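For sparse and just-in-time connectivity, being event-driven means visiting only the rows of presynaptic neurons that fired. The plain-JAX sketch below shows both cases under these assumptions:

- The CSR layout follows the usual convention.
- The example values are made up for illustration.
- The per-row seeding rule, `jax.random.fold_in` on the row index, is an assumption.

The helpers `csr_event_matvec` and `jit_event_matvec` are hypothetical. They are not brainevent's kernels. brainevent's random-number scheme is likely different, so the JIT sketch will not reproduce the weights of `JITCScalarR`.

```python
import jax
import jax.numpy as jnp

def csr_event_matvec(spikes, data, indices, indptr, num_post):
    """Hypothetical reference for (binary spikes) @ CSR, visiting fired rows only."""
    out = jnp.zeros(num_post, dtype=data.dtype)
    for i in jnp.nonzero(spikes)[0]:
        i = int(i)
        start, end = int(indptr[i]), int(indptr[i + 1])
        # Scatter-add row i's weights into its postsynaptic columns.
        out = out.at[indices[start:end]].add(data[start:end])
    return out

def jit_event_matvec(spikes, num_post, prob, weight, seed):
    """Hypothetical sketch: rebuild each fired row from the seed; never store the matrix."""
    key = jax.random.PRNGKey(seed)
    out = jnp.zeros(num_post, dtype=jnp.float32)
    for i in jnp.nonzero(spikes)[0]:
        # Assumed rule: row i depends only on (seed, i), so it can be
        # regenerated whenever neuron i fires instead of being kept in memory.
        row_key = jax.random.fold_in(key, int(i))
        connected = jax.random.bernoulli(row_key, prob, (num_post,))
        out = out + weight * connected
    return out

# Illustrative 5x3 CSR matrix: row i holds data[indptr[i]:indptr[i+1]]
# at columns indices[indptr[i]:indptr[i+1]].
data = jnp.array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
indices = jnp.array([0, 2, 1, 0, 1, 2])
indptr = jnp.array([0, 2, 3, 4, 5, 6])
s = jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32)

out_csr = csr_event_matvec(s, data, indices, indptr, num_post=3)  # rows 0, 2, 4 -> [1.0, 0.0, 1.0]
out_jit = jit_event_matvec(s, num_post=3, prob=0.5, weight=0.2, seed=0)
```

Both helpers loop over the fired neurons in Python. They are reading aids, not fast or jit-compatible code. The JIT variant trades memory for recomputation: a row costs random-number generation each time its neuron fires, but the full `num_pre x num_post` matrix never exists.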
Works inside JAX transformations
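Everything composes with jax.jit, jax.grad, and jax.vmap. For example, under jax.jit:
```python
import jax

@jax.jit
def step(spikes, csr):
    # Traced and compiled on the first call; later calls reuse the compiled code.
    return spikes @ csr

out = step(spikes, csr)
```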
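The page says jax.grad and jax.vmap also apply. The sketch below illustrates what they compute, using an ordinary dense product in plain JAX as a stand-in for the event-driven one. It makes no claim about how brainevent implements its own transformation rules. It shows two things:

- For a summed output, the gradient with respect to the weights is the spike vector broadcast across columns. Rows of silent neurons therefore get zero gradient.
- `jax.vmap` maps the product over a batch of spike vectors. The batch here is hypothetical, one spike vector per time step.

```python
import jax
import jax.numpy as jnp

W = jnp.ones((5, 3))
s = jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32)

def loss(W, s):
    # Dense stand-in for the event-driven product, summed to a scalar.
    return (s @ W).sum()

# d loss / d W[i, j] = s[i]: only rows of neurons that fired get a gradient.
grad_W = jax.grad(loss)(W, s)

# Hypothetical batch: one spike vector per time step.
batch = jnp.array([[1, 0, 1, 0, 1],
                   [0, 1, 0, 0, 0]], dtype=jnp.float32)
batched_out = jax.vmap(lambda v: v @ W)(batch)  # shape (2, 3)
```

Because the weight gradient inherits the sparsity of the spikes, the backward pass has the same event-driven structure as the forward pass.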
Next steps
Learn step by step: work through the tutorial notebooks.
Solve a specific task: jump to a how-to recipe.
Understand the model: read the conceptual background.
Look up an API: browse the reference.