Quickstart

This page takes you from zero to a working event-driven matrix multiplication in a couple of minutes. For the motivation behind it, see What is event-driven computation?

The core idea

The brain computes with spikes: sparse, binary events. brainevent exploits that sparsity. Wrap a spike vector in BinaryArray, and any matrix multiplication against it skips the zeros, doing work only for the neurons that fired.

import brainevent
import jax.numpy as jnp

# 1 = spike, 0 = no spike
spikes = brainevent.BinaryArray(jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32))
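To make "skipping the zeros" concrete, here is a minimal plain-JAX reference for what the product computes when the input is binary. It illustrates the semantics only and is not brainevent's kernel; `event_matvec_reference`, `spike_vec`, and `w_example` are hypothetical names. For a 0/1 vector, `spike_vec @ W` is simply the sum of the rows of `W` that belong to neurons that fired, so the rows of silent neurons never need to be read.

```python
import jax.numpy as jnp
import numpy as np

def event_matvec_reference(spike_vec, weights):
    """Hypothetical reference: add up the weight rows of neurons that fired.

    Uses NumPy to find the active indices, so it runs eagerly and is not
    meant for use under jax.jit (the number of active neurons varies).
    """
    active = np.flatnonzero(np.asarray(spike_vec))   # indices of neurons that spiked
    return jnp.asarray(weights)[active].sum(axis=0)  # read only those rows

spike_vec = jnp.array([1, 0, 1, 0, 1], dtype=jnp.float32)
w_example = jnp.arange(15, dtype=jnp.float32).reshape(5, 3)

dense = spike_vec @ w_example                    # ordinary dense product
event = event_matvec_reference(spike_vec, w_example)
print(dense)                       # [18. 21. 24.]
print(jnp.allclose(dense, event))  # True
```

The reference relies on eager NumPy indexing because the number of active neurons changes between calls; the BinaryArray operators are the version meant to run inside jax.jit, as shown under "Works inside JAX transformations".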

Multiply by a connectivity matrix

A BinaryArray multiplies against dense arrays and any of brainevent’s sparse connectivity structures. The operation is event-driven in every case:

import jax.numpy as jnp

# Dense weights (jax/numpy array)
weights = jnp.ones((5, 3))
out_dense = spikes @ weights

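# CSR sparse matrix: row i's nonzero weights are data[indptr[i]:indptr[i+1]],
# stored in the columns listed in indices[indptr[i]:indptr[i+1]]
data = jnp.array([0.5, 1.0, 2.0, 0.5, 1.5, 3.0], dtype=jnp.float32)
indices = jnp.array([0, 2, 1, 0, 2, 1], dtype=jnp.int32)
indptr = jnp.array([0, 2, 3, 3, 5, 6], dtype=jnp.int32)
csr = brainevent.CSR((data, indices, indptr), shape=(5, 3))
out_csr = spikes @ csr

# Just-in-time connectivity: never materialises the full matrix
jitc = brainevent.JITCScalarR(num_pre=5, num_post=3, prob=0.5, weight=0.2, seed=0)
out_jitc = spikes @ jitc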

# Fixed fan-out connectivity
fixed = brainevent.FixedPostNumConn(num_pre=5, num_post=3, conn_num=2, weight=0.5, seed=0)
out_fixed = spikes @ fixed
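With a sparse structure such as CSR, sparsity pays off twice: only nonzero weights are stored, and a binary input means only the rows of fired neurons are visited. The sketch below illustrates this in plain NumPy, assuming `(data, indices, indptr)` follow the standard CSR layout (the same one SciPy uses). It is not brainevent's implementation, and `csr_event_matvec` is a hypothetical helper.

```python
import numpy as np

# A 5x3 matrix in standard CSR layout:
#   row i's nonzeros are data[indptr[i]:indptr[i+1]],
#   their columns are indices[indptr[i]:indptr[i+1]].
data = np.array([0.5, 1.0, 2.0, 0.5, 1.5, 3.0], dtype=np.float32)
indices = np.array([0, 2, 1, 0, 2, 1])
indptr = np.array([0, 2, 3, 3, 5, 6])
n_pre, n_post = 5, 3

def csr_event_matvec(spikes, data, indices, indptr, n_post):
    """Hypothetical event-driven spikes @ A for a binary spike vector."""
    out = np.zeros(n_post, dtype=data.dtype)
    for i in np.flatnonzero(spikes):          # loop over fired neurons only
        start, end = indptr[i], indptr[i + 1]
        # Binary input: no multiplication needed, just accumulate the row.
        np.add.at(out, indices[start:end], data[start:end])
    return out

spike_vec = np.array([1, 0, 1, 0, 1], dtype=np.float32)
result = csr_event_matvec(spike_vec, data, indices, indptr, n_post)

# Cross-check against a dense reconstruction of the same matrix.
dense = np.zeros((n_pre, n_post), dtype=np.float32)
for i in range(n_pre):
    dense[i, indices[indptr[i]:indptr[i + 1]]] = data[indptr[i]:indptr[i + 1]]
print(np.allclose(result, spike_vec @ dense))  # True
```

The loop reads only the stored weights of fired neurons, so the work scales with the number of spikes times the per-neuron fan-out rather than with the full matrix size.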

Works inside JAX transformations

These operations compose with jax.jit, jax.grad, and jax.vmap. For example, under jax.jit:

import jax

@jax.jit
def step(spikes, csr):
    return spikes @ csr

out = step(spikes, csr)
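A common pattern is to batch many spike vectors (for example, several time steps) with jax.vmap and to differentiate a loss with respect to the weights with jax.grad. The sketch below uses a plain dense `spikes @ weights` as a stand-in so it runs without brainevent and assumes nothing about how brainevent implements gradients; `forward`, `loss`, and `spike_batch` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def forward(spikes, weights):
    # Dense stand-in for an event-driven product: one spike vector -> outputs
    return spikes @ weights

# Map over the leading (batch) axis of spikes; the weights are shared.
batched_forward = jax.vmap(forward, in_axes=(0, None))

@jax.jit
def loss(weights, spike_batch):
    out = batched_forward(spike_batch, weights)  # shape (batch, num_post)
    return jnp.sum(out ** 2)

spike_batch = jnp.array([[1, 0, 1, 0, 1],
                         [0, 1, 0, 1, 0]], dtype=jnp.float32)  # 2 time steps
weights = jnp.ones((5, 3))
grads = jax.grad(loss)(weights, spike_batch)  # gradient w.r.t. weights
print(grads.shape)  # (5, 3)
```

Rows of `grads` for neurons that never fired in the batch would be zero, which mirrors the event-driven view: only active inputs influence the output.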

Next steps

- Learn step by step: work through the tutorial notebooks in Data structures & operators.
- Solve a specific task: jump to a how-to recipe in Working with data structures.
- Understand the model: read the conceptual background in What is event-driven computation?
- Look up an API: browse the reference in Python API.