The mental model

Who it’s for: everyone, before you go deeper. The four ideas below are the whole framework in miniature — they apply identically whether you simulate biophysical networks or train spiking networks with gradients.

What you’ll learn: (1) state-based programming, (2) physical units, (3) how neurons, synapses, and projections compose, and (4) how to drive a model with brainstate.transform. We close with the “two worlds, one substrate” idea that organizes the rest of the docs.

import brainpy
import brainstate
import braintools
import brainunit as u

Idea 1 — State-based programming

JAX is functional: transformations like JIT and autodiff want pure functions with no hidden mutable variables. But a neuron’s membrane potential is mutable state that persists across time steps. brainstate resolves this tension by making every such variable an explicit State, owned by a module.

You rarely create raw states by hand — models declare their own. Your job is to construct the model and then initialize its states with brainstate.nn.init_all_states, which allocates and resets every dynamic variable to a clean starting point.

neuron = brainpy.state.LIF(100, tau=10. * u.ms, V_th=-50. * u.mV)
brainstate.nn.init_all_states(neuron)   # allocate + reset its States
LIF(
  in_size=(100,),
  out_size=(100,),
  spk_reset=soft,
  spk_fun=ReluGrad(alpha=0.3, width=1.0),
  R=Quantity(1., "ohm"),
  tau=Quantity(10., "ms"),
  V_th=Quantity(-50., "mV"),
  V_rest=Quantity(0., "mV"),
  V_reset=Quantity(0., "mV"),
  V_initializer=Constant(value=0. mV),
  V=HiddenState(
    value=Quantity(~float32[100], "mV")
  )
)
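The toy, pure-Python sketch below shows what an explicit State buys you. The names State, ToyLIF, collect_states and as_pure_function are hypothetical, and the model is unitless. brainstate's actual machinery is more involved. Only the shape of the idea matters here:

- The module owns a mutable State.
- A wrapper exposes that State as an explicit input and output.
- As a result, one step looks like the pure function that functional transforms expect.

```python
# Toy version of the "explicit State owned by a module" pattern.
# State, ToyLIF, collect_states and as_pure_function are hypothetical
# illustrations; brainstate's real State/Module machinery is richer.


class State:
    """A mutable box holding one dynamic variable, e.g. a membrane potential."""

    def __init__(self, value):
        self.value = value


class ToyLIF:
    """LIF-like module (unitless) that declares its own State."""

    def __init__(self, n, tau=10.0, v_th=1.0, v_reset=0.0):
        self.n, self.tau, self.v_th, self.v_reset = n, tau, v_th, v_reset
        self.V = None  # allocated later, in the spirit of init_all_states

    def init_state(self):
        # Allocate and reset the dynamic variable to a clean starting point.
        self.V = State([self.v_reset] * self.n)

    def __call__(self, current, dt):
        # Euler step of tau * dV/dt = -V + current, then threshold and reset.
        v = [vi + dt / self.tau * (-vi + current) for vi in self.V.value]
        spikes = [vi >= self.v_th for vi in v]
        self.V.value = [self.v_reset if s else vi for vi, s in zip(v, spikes)]
        return spikes


def collect_states(module):
    """Return every State attribute of a module (one level deep)."""
    return {k: v for k, v in vars(module).items() if isinstance(v, State)}


def as_pure_function(module, dt):
    """Wrap a stateful module as pure (state_values, x) -> (new_state_values, y)."""
    states = collect_states(module)

    def pure(state_values, current):
        for name, value in state_values.items():
            states[name].value = value                      # load the caller's state
        out = module(current, dt)                           # run ordinary stateful code
        new_values = {k: s.value for k, s in states.items()}  # read state back out
        return new_values, out

    return pure


neuron = ToyLIF(3)
neuron.init_state()
pure_step = as_pure_function(neuron, dt=0.1)
values = {k: s.value for k, s in collect_states(neuron).items()}
history = []
for _ in range(6):
    values, spikes = pure_step(values, 20.0)  # state goes in and comes out explicitly
    history.append(spikes)
```

With a constant drive of 20 the toy neurons first cross the threshold of 1.0 on the sixth step and are reset to 0.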

Time-dependent context (the step size dt, the current time t) is supplied with brainstate.environ.context(...), so models read it without you threading it through every call:

with brainstate.environ.context(dt=0.1 * u.ms):
    print('dt =', brainstate.environ.get_dt())
dt = 0.1 ms

Idea 2 — Physical units

Parameters carry real physical units via brainunit: a time constant is in milliseconds, a threshold in millivolts, a current in milliamps. Because units travel with the values and are checked, a ms-vs-s or mV-vs-V slip cannot silently corrupt a run. Compatible scales are converted consistently, and mismatched dimensions raise an error immediately.

tau = 10. * u.ms          # membrane time constant
V_threshold = -50. * u.mV  # spike threshold
current = 20. * u.mA       # input current

# Units flow straight into model construction.
neuron = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)
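A deliberately tiny Quantity class shows why carrying units helps. It is hypothetical and far simpler than brainunit:

- Each value is stored in SI base units together with its dimension exponents.
- Addition refuses mismatched dimensions.
- Multiplication combines dimensions, so R * I comes out as a voltage.

```python
class Quantity:
    """Toy value-with-dimensions (hypothetical; brainunit's Quantity is far richer).

    value is stored in SI base units; dims maps base-unit names to exponents.
    """

    def __init__(self, value, dims):
        self.value = value
        self.dims = dims

    def __add__(self, other):
        # Adding a time to a voltage is meaningless, so refuse it outright.
        if self.dims != other.dims:
            raise TypeError(f"dimension mismatch: {self.dims} vs {other.dims}")
        return Quantity(self.value + other.value, self.dims)

    def __mul__(self, other):
        if isinstance(other, Quantity):
            dims = dict(self.dims)
            for name, power in other.dims.items():
                dims[name] = dims.get(name, 0) + power
            dims = {k: p for k, p in dims.items() if p != 0}  # drop cancelled units
            return Quantity(self.value * other.value, dims)
        return Quantity(self.value * other, self.dims)      # plain number: scale only

    __rmul__ = __mul__

    def to(self, unit):
        """Express this quantity as a plain number in the given unit."""
        if self.dims != unit.dims:
            raise TypeError(f"cannot convert {self.dims} to {unit.dims}")
        return self.value / unit.value


# A few units, each defined by its size in SI base units.
second = Quantity(1.0, {"s": 1})
ms = Quantity(1e-3, {"s": 1})
mV = Quantity(1e-3, {"V": 1})
mA = Quantity(1e-3, {"A": 1})
ohm = Quantity(1.0, {"V": 1, "A": -1})

tau = 10.0 * ms + 0.01 * second      # mixed scales are converted, not misread
print(tau.to(ms))                    # 20 ms, up to float rounding
drive = (1.0 * ohm) * (20.0 * mA)    # R * I has the dimension of a voltage
print(drive.dims, drive.to(mV))
try:
    tau + (-50.0 * mV)               # time + voltage
except TypeError as err:
    print("caught:", err)
```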

Idea 3 — Compose neurons + synapses + projections

Networks are built by composition. The three building blocks:

  • Neurons (e.g. LIF, LIFRef, ALIF, HH) hold membrane state and emit spikes.

  • Synapses (e.g. Expon, Alpha, AMPA) filter incoming spikes into currents/conductances over time.

  • Projections wire populations together. A projection separates four roles: comm (the connection matrix / connectivity), syn (the synapse dynamics), out (how it drives the target — COBA/CUBA), and post (the target population).

Here two populations are joined by a single AlignPostProj:

pre = brainpy.state.LIF(100, tau=10. * u.ms, V_th=-50. * u.mV)
post = brainpy.state.LIF(50, tau=10. * u.ms, V_th=-50. * u.mV)

proj = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(100, 50, conn_num=0.1, conn_weight=0.5 * u.mS),
    syn=brainpy.state.Expon.desc(50, tau=5. * u.ms),
    out=brainpy.state.COBA.desc(E=0. * u.mV),
    post=post,
)
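The four roles can be traced by hand for one time step. The NumPy sketch below is an assumption-laden illustration, not how AlignPostProj is implemented:

- comm is a fixed-number random connectivity with one shared weight.
- syn is an exponential conductance kept per post neuron.
- out applies the COBA driving force E - V.
- post supplies the target's membrane potential.

Units follow the example above (ms, mS, mV).

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post = 100, 50
dt, tau_syn = 0.1, 5.0          # ms
w = 0.5                          # mS, one weight shared by every connection
E = 0.0                          # mV, reversal potential used by the COBA output

# comm: assumed here to give each presynaptic neuron a fixed number of random
# targets (10% of the post population); the library's exact sampling may differ.
k = n_post // 10
targets = np.stack([rng.choice(n_post, size=k, replace=False) for _ in range(n_pre)])


def comm(pre_spikes):
    """Event-driven: only rows of neurons that spiked contribute weight."""
    drive = np.zeros(n_post)
    for i in np.flatnonzero(pre_spikes):
        drive[targets[i]] += w
    return drive


def syn(g, drive):
    """Expon-style conductance, one value per post neuron: decay, then add input."""
    return g * np.exp(-dt / tau_syn) + drive


def out(g, V_post):
    """COBA: the current depends on the driving force E - V of the target."""
    return g * (E - V_post)       # mS * mV = uA


# post: the target population's membrane potentials (held fixed for this one step).
V_post = np.full(n_post, -65.0)  # mV
g = np.zeros(n_post)
pre_spikes = rng.random(n_pre) < 0.2   # stand-in for pre's spikes at this step
g = syn(g, comm(pre_spikes))
I_syn = out(g, V_post)
```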

Which projection to reach for is the subject of the chapter AlignPre / AlignPost — the keystone. That chapter also explains why aligning synaptic state to the post (or pre) population keeps memory linear in the number of neurons rather than in the number of connections.
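The memory argument rests on linearity, which a short NumPy experiment can check. It assumes exponential (Expon-like) synapses. There are two ways to compute the conductance:

- Decay every synapse separately and sum afterwards.
- Sum the inputs first and decay one trace per target neuron.

Both give the same result. The weights and spike trains are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pre, n_post, steps = 20, 5, 200
decay = np.exp(-0.1 / 5.0)   # exp(-dt / tau) with dt = 0.1 ms, tau = 5 ms

# Sparse random weights and Bernoulli spike trains (placeholders).
W = rng.random((n_pre, n_post)) * (rng.random((n_pre, n_post)) < 0.3)
spikes = (rng.random((steps, n_pre)) < 0.05).astype(float)

g_syn = np.zeros((n_pre, n_post))   # per-synapse layout: n_pre * n_post traces
g_post = np.zeros(n_post)           # post-aligned layout: one trace per target

for s in spikes:
    g_syn = g_syn * decay + s[:, None] * W   # each synapse decays and gets its own input
    g_post = g_post * decay + s @ W          # merge inputs first, then decay once per target

print(np.allclose(g_syn.sum(axis=0), g_post))   # True
print(g_syn.size, g_post.size)                   # 100 5
```

The per-synapse array grows with the number of connections, while the merged trace grows only with the number of target neurons. The equivalence relies on the synaptic dynamics being linear, which this sketch assumes.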

Idea 4 — Drive with brainstate.transform

A model runs over many time steps. Never drive it with a bare Python for/while loop: that dispatches each operation eagerly from Python, step after step, and forfeits compiler fusion across the loop. Instead, lower the whole loop into one compiled program with a brainstate.transform primitive:

  • single step / one-shot: brainstate.transform.jit

  • many steps, collect outputs: brainstate.transform.for_loop

  • many steps with an explicit carry: brainstate.transform.scan

  • long rollout under autograd (BPTT): brainstate.transform.checkpointed_for_loop / checkpointed_scan
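The contract of the two looping primitives can be written as plain Python reference semantics, in the style of the jax.lax.scan documentation. These loops only show what goes in and what comes out. The real primitives compile the body once rather than running a Python loop, and their full signatures have more options than shown here.

```python
def scan(f, init, xs):
    """Reference semantics of a scan: f(carry, x) -> (carry, y); returns (final carry, ys)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def for_loop(f, xs):
    """Reference semantics of a for_loop: no explicit carry (state lives in States)."""
    return [f(x) for x in xs]


# scan threads an explicit carry: here a running sum.
total, partial = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
print(total, partial)              # 10 [1, 3, 6, 10]

# for_loop relies on state held elsewhere; a dict stands in for a module-owned State.
box = {"n": 0}


def tick(x):
    box["n"] += x
    return box["n"]


print(for_loop(tick, [1, 1, 1]))   # [1, 2, 3]
```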

Running the neuron from Ideas 1–2 for 200 ms of constant input:

brainstate.nn.init_all_states(neuron)

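def step(t):
    # Expose the current time t to the model for this step.
    with brainstate.environ.context(t=t):
        neuron(current)             # advance the LIF state by one dt
        return neuron.get_spike()   # this step's spikes, shape [neuron]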

with brainstate.environ.context(dt=0.1 * u.ms):
    times = u.math.arange(0. * u.ms, 200. * u.ms, brainstate.environ.get_dt())
    spikes = brainstate.transform.for_loop(step, times)

print('spikes shape:', spikes.shape)   # [time, neuron]
spikes shape: (2000, 100)
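A [time, neuron] spike array is typically reduced to firing rates. The minimal NumPy sketch below uses a synthetic random raster of the same shape rather than the LIF output above:

```python
import numpy as np

dt_ms = 0.1
rng = np.random.default_rng(0)
spikes = rng.random((2000, 100)) < 0.002   # synthetic raster, NOT the LIF output above

duration_s = spikes.shape[0] * dt_ms / 1000.0              # 2000 steps * 0.1 ms = 0.2 s
rates_hz = spikes.sum(axis=0) / duration_s                 # per-neuron firing rate
population_rate = spikes.mean(axis=1) / (dt_ms / 1000.0)   # population rate at each step
print(rates_hz.shape, population_rate.shape)               # (100,) (2000,)
```

Averaged over time, the population rate equals the mean of the per-neuron rates. Either view can be used depending on whether you care about individual neurons or collective dynamics.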

Two worlds, one substrate

Everything above is shared. The same state-based models, the same units, the same transform-driven loops power both:

  • Brain simulation — run biophysical E/I networks and analyze their dynamics (the 5-minute tour you may have just seen).

  • Brain-inspired computing — because the models are differentiable (neurons accept a surrogate spk_fun, and for_loop/scan are differentiable), you train them with gradients and scale them with linear-memory online learning.
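The NumPy sketch below shows what a surrogate spk_fun does. The forward pass keeps the hard threshold, while the backward pass substitutes a pseudo-derivative that is nonzero near threshold so gradients can flow. Note the following assumptions:

- The triangular form is meant to resemble ReluGrad(alpha=0.3, width=1.0) from the LIF printout. It is not taken from brainstate's source.
- The chain rule is applied by hand rather than with autodiff.

```python
import numpy as np


def spike_forward(x):
    """Forward pass: Heaviside step of the (scaled) distance to threshold."""
    return (x >= 0).astype(float)


def spike_surrogate_grad(x, alpha=0.3, width=1.0):
    """Backward pass: a triangular pseudo-derivative, nonzero only near threshold.

    Assumed form alpha * max(width - |x|, 0), chosen to echo ReluGrad(alpha, width);
    the library's exact definition may differ.
    """
    return alpha * np.maximum(width - np.abs(x), 0.0)


x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(spike_forward(x))          # [0. 0. 1. 1. 1.]
print(spike_surrogate_grad(x))   # zero far from threshold, peak alpha at x = 0

# Chain rule for one weight: spike = H(w * I - theta), so d spike / d w = H'(.) * I.
# The true H' is zero almost everywhere; the surrogate supplies a usable signal.
w, I, theta = 0.6, 2.0, 1.0
grad_w = spike_surrogate_grad(w * I - theta) * I
```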

The hinge between the two worlds is the AlignPre/AlignPost projection design: the same alignment that makes simulation memory-efficient is what makes gradient-based and online learning memory-efficient. That is why it is the keystone of the concept spine.

See also