The mental model
Who it’s for: everyone, before you go deeper. The four ideas below are the whole framework in miniature — they apply identically whether you simulate biophysical networks or train spiking networks with gradients.
What you’ll learn: (1) state-based programming, (2) physical units, (3) how neurons, synapses, and projections compose, and (4) how to drive a model with brainstate.transform. We close with the “two worlds, one substrate” idea that organizes the rest of the docs.
```python
import brainpy
import brainstate
import braintools
import brainunit as u
```
Idea 1 — State-based programming
JAX is functional: transformations like JIT and autodiff want pure functions with no hidden mutable variables. But a neuron’s membrane potential is mutable state that persists across time steps. brainstate resolves this tension by making every such variable an explicit State, owned by a module.
You rarely create raw states by hand — models declare their own. Your job is to construct the model and then initialize its states with brainstate.nn.init_all_states, which allocates and resets every dynamic variable to a clean starting point.
```python
neuron = brainpy.state.LIF(100, tau=10. * u.ms, V_th=-50. * u.mV)
brainstate.nn.init_all_states(neuron)  # allocate + reset its States
```

```
LIF(
  in_size=(100,),
  out_size=(100,),
  spk_reset=soft,
  spk_fun=ReluGrad(alpha=0.3, width=1.0),
  R=Quantity(1., "ohm"),
  tau=Quantity(10., "ms"),
  V_th=Quantity(-50., "mV"),
  V_rest=Quantity(0., "mV"),
  V_reset=Quantity(0., "mV"),
  V_initializer=Constant(value=0. mV),
  V=HiddenState(
    value=Quantity(~float32[100], "mV")
  )
)
```
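What State buys you is easiest to see by taking it away. The sketch below is plain Python, not brainstate code: a hypothetical pure function lif_step written in the functional style JAX expects, so the membrane potential has to be passed into every call and returned out of it. The forward-Euler update, the hard reset, and the resting and reset potential of -65 mV are illustrative assumptions, and units are left implicit (mV, ms, mA, ohm).

```python
def lif_step(V, I, dt=0.1, tau=10.0, R=1.0, V_rest=-65.0, V_th=-50.0, V_reset=-65.0):
    """One forward-Euler step of a leaky integrate-and-fire neuron (pure function).

    V comes in as an argument and the updated V goes out as a return value;
    nothing is mutated. Units are implicit: mV, ms, mA, ohm.
    """
    V = V + dt / tau * (-(V - V_rest) + R * I)
    spiked = V >= V_th
    if spiked:
        V = V_reset  # hard reset after a spike
    return V, spiked


# The caller owns the state and must thread it through every step by hand.
# These are plain Python floats, so an ordinary Python loop is fine here.
V = -65.0
spike_count = 0
for _ in range(2000):  # 200 ms at dt = 0.1 ms
    V, spiked = lif_step(V, I=20.0)
    spike_count += spiked
```

A State-owning module such as LIF keeps V inside the module and lets brainstate thread it through transforms for you, so the call site shrinks to neuron(current).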
Time-dependent context (the step size dt, the current time t) is supplied with brainstate.environ.context(...), so models read it without you threading it through every call:
```python
with brainstate.environ.context(dt=0.1 * u.ms):
    print('dt =', brainstate.environ.get_dt())
```

```
dt = 0.1 ms
```
Idea 2 — Physical units
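Parameters carry real physical units via brainunit: a time constant is in milliseconds, a threshold in millivolts, a current in milliamps. Because every value carries its dimension, combining incompatible units (a voltage where a time constant belongs, or millivolts added to milliseconds) raises a dimension error as soon as the quantities meet, while same-dimension scale differences such as ms vs s are converted automatically. Either way, a unit slip cannot silently corrupt a run.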
```python
tau = 10. * u.ms            # membrane time constant
V_threshold = -50. * u.mV   # spike threshold
current = 20. * u.mA        # input current

# Units flow straight into model construction.
neuron = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)
```
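The two behaviours described above, conversion within a dimension and rejection across dimensions, can be sketched with a toy class. ToyQuantity, its (time, voltage, current) exponent tuple and the unit names below are hypothetical. brainunit's real Quantity is far more complete, and this is not its implementation. The last product also shows why the LIF repr's R of 1 ohm times a 20 mA input is a voltage.

```python
class ToyQuantity:
    """Toy dimensioned value: magnitude in SI plus exponents of (time, voltage, current).

    Hypothetical illustration of dimension checking, not brainunit's implementation.
    """

    def __init__(self, si, dims):
        self.si, self.dims = si, dims

    def __add__(self, other):
        if self.dims != other.dims:
            raise TypeError(f"dimension mismatch: {self.dims} vs {other.dims}")
        return ToyQuantity(self.si + other.si, self.dims)

    def __mul__(self, other):
        if isinstance(other, ToyQuantity):
            dims = tuple(a + b for a, b in zip(self.dims, other.dims))
            return ToyQuantity(self.si * other.si, dims)
        return ToyQuantity(self.si * other, self.dims)  # plain number: scale only

    __rmul__ = __mul__

    def to(self, unit):
        if self.dims != unit.dims:
            raise TypeError("cannot convert between different dimensions")
        return self.si / unit.si


# Units are quantities whose magnitude is one of their own scale.
second, ms = ToyQuantity(1.0, (1, 0, 0)), ToyQuantity(1e-3, (1, 0, 0))
mV = ToyQuantity(1e-3, (0, 1, 0))
mA = ToyQuantity(1e-3, (0, 0, 1))
ohm = ToyQuantity(1.0, (0, 1, -1))  # ohm = volt / ampere

tau = 10. * ms
total = tau + 0.5 * second     # same dimension, different scale: 510 ms
drive = 1. * ohm * (20. * mA)  # resistance x current has voltage dimension: 20 mV
try:
    tau + (-50. * mV)          # time + voltage: rejected
    mixed_ok = True
except TypeError:
    mixed_ok = False
```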
Idea 3 — Compose neurons + synapses + projections
Networks are built by composition. The three building blocks:
Neurons (e.g. LIF, LIFRef, ALIF, HH) hold membrane state and emit spikes.
Synapses (e.g. Expon, Alpha, AMPA) filter incoming spikes into currents or conductances over time.
Projections wire populations together. A projection separates four roles: comm (the connection matrix, i.e. the connectivity), syn (the synapse dynamics), out (how it drives the target, COBA or CUBA), and post (the target population).
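Stripped of the library, one update of a projection is a short pipeline: comm turns presynaptic spikes into per-target input, syn integrates that input into synaptic state, out converts the state into a current for the target, and post integrates that current. The sketch below is plain Python with hypothetical function names (comm, syn, out_coba). It assumes an exponential synapse and the conductance-based rule I = g (E - V), and it keeps one conductance per post neuron, which is the post-aligned layout used in the AlignPostProj example.

```python
import math

# Hypothetical, library-free sketch of one projection step (pre -> post).
# None of these names are brainpy APIs.

W = [[0.5, 0.0],   # W[i][j]: weight (mS) from pre neuron i to post neuron j
     [0.0, 0.5],
     [0.5, 0.5]]


def comm(pre_spikes):
    """Connectivity: sum the weights of the pre neurons that spiked, per post neuron."""
    return [sum(W[i][j] for i, s in enumerate(pre_spikes) if s) for j in range(len(W[0]))]


def syn(g, inp, dt=0.1, tau=5.0):
    """Exponential synapse: decay the conductance, then add the new input."""
    decay = math.exp(-dt / tau)
    return [gj * decay + xj for gj, xj in zip(g, inp)]


def out_coba(g, V, E=0.0):
    """Conductance-based output: current = g * (E - V)."""
    return [gj * (E - Vj) for gj, Vj in zip(g, V)]


g = [0.0, 0.0]            # synaptic state, one value per post neuron
V_post = [-65.0, -60.0]   # post membrane potentials (mV)
g = syn(g, comm([True, False, True]))
I_post = out_coba(g, V_post)  # the current `post` would integrate on this step
```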
Here two populations are joined by a single AlignPostProj:
```python
pre = brainpy.state.LIF(100, tau=10. * u.ms, V_th=-50. * u.mV)
post = brainpy.state.LIF(50, tau=10. * u.ms, V_th=-50. * u.mV)

proj = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(100, 50, conn_num=0.1, conn_weight=0.5 * u.mS),
    syn=brainpy.state.Expon.desc(50, tau=5. * u.ms),
    out=brainpy.state.COBA.desc(E=0. * u.mV),
    post=post,
)
```
Which projection to reach for, and why aligning synaptic state to the post (or pre) population keeps memory linear in the number of neurons rather than in the number of synapses, is the subject of the keystone chapter, AlignPre / AlignPost — the keystone.
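Why post-alignment can be exact takes one step of reasoning. If every synapse's conductance obeys the same linear dynamics (here, exponential decay plus an additive jump on each presynaptic spike), then the sum of all conductances onto a post neuron obeys those same dynamics, driven by the summed input. One state per post neuron can therefore replace one state per synapse. The toy check below simulates both layouts in plain Python and compares them. It sketches the principle under that linearity assumption and is not brainpy code. The keystone chapter covers which synapse models can be aligned which way.

```python
import math
import random

# Toy check that post-aligned state matches per-synapse state for a linear
# (exponential) synapse. Plain Python, not brainpy code.
random.seed(0)
n_pre, n_post, steps = 20, 5, 200
decay = math.exp(-0.1 / 5.0)  # dt = 0.1 ms, tau = 5 ms
W = [[random.random() if random.random() < 0.1 else 0.0 for _ in range(n_post)]
     for _ in range(n_pre)]

# Dense here for simplicity; a sparse layout still needs one value per synapse.
g_syn = [[0.0] * n_post for _ in range(n_pre)]  # one state per (pre, post) pair
g_post = [0.0] * n_post                          # one state per post neuron

for _ in range(steps):
    spikes = [random.random() < 0.05 for _ in range(n_pre)]
    for i in range(n_pre):
        for j in range(n_post):
            g_syn[i][j] = g_syn[i][j] * decay + (W[i][j] if spikes[i] else 0.0)
    for j in range(n_post):
        g_post[j] = g_post[j] * decay + sum(W[i][j] for i in range(n_pre) if spikes[i])

# The post-aligned trace equals the sum of the per-synapse traces.
max_err = max(abs(g_post[j] - sum(g_syn[i][j] for i in range(n_pre)))
              for j in range(n_post))
state_counts = (n_pre * n_post, n_post)  # stored values: 100 vs 5
```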
Idea 4 — Drive with brainstate.transform
A model runs over many time steps. Don't drive it with a bare Python for/while loop: run eagerly, the loop executes op by op with no fusion across steps, and traced inside jit it unrolls into a program whose size and compile time grow with the number of steps. Instead, lower the whole loop into one compiled program with a brainstate.transform primitive:
single step / one-shot → brainstate.transform.jit
many steps, collect outputs → brainstate.transform.for_loop
many steps with an explicit carry → brainstate.transform.scan
long rollout under autograd (BPTT) → brainstate.transform.checkpointed_for_loop / checkpointed_scan
Running the neuron from Ideas 1 and 2 for 200 ms of constant input:
```python
brainstate.nn.init_all_states(neuron)

def step(t):
    with brainstate.environ.context(t=t):
        neuron(current)
    return neuron.get_spike()

with brainstate.environ.context(dt=0.1 * u.ms):
    times = u.math.arange(0. * u.ms, 200. * u.ms, brainstate.environ.get_dt())
    spikes = brainstate.transform.for_loop(step, times)

print('spikes shape:', spikes.shape)  # [time, neuron]
```

```
spikes shape: (2000, 100)
```
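The loop transforms save you from writing the carry by hand. For comparison, here is roughly the same 200 ms rollout written directly against jax.lax.scan, with the membrane potential passed explicitly as the carry and spikes collected as the per-step output. It is an illustrative sketch, not what brainstate generates. It uses plain floats instead of brainunit quantities, a hard reset, and an assumed resting and reset potential of -65 mV so the neuron starts below threshold.

```python
import jax
import jax.numpy as jnp

# Sketch: a 200 ms LIF rollout written directly against jax.lax.scan.
# Plain floats stand in for units (mV, ms, mA, ohm); parameters are assumptions.
dt, tau, R = 0.1, 10.0, 1.0
V_rest, V_th, V_reset = -65.0, -50.0, -65.0
n_steps, n_neurons = 2000, 100


def step(V, I):
    """carry = membrane potential, per-step input = current -> (new carry, output)."""
    V = V + dt / tau * (-(V - V_rest) + R * I)
    spike = V >= V_th
    V = jnp.where(spike, V_reset, V)  # hard reset where a spike occurred
    return V, spike.astype(jnp.float32)


V0 = jnp.full((n_neurons,), V_rest)
inputs = jnp.full((n_steps, n_neurons), 20.0)  # constant drive at every step
V_final, spikes = jax.jit(lambda v, x: jax.lax.scan(step, v, x))(V0, inputs)
```

Everything the State abstraction hides is visible here: the carry V, the step-to-step data flow, and the stacked [time, neuron] output.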
Two worlds, one substrate
Everything above is shared. The same state-based models, the same units, the same transform-driven loops power both:
Brain simulation: run biophysical E/I networks and analyze their dynamics (the 5-minute tour you may have just seen).
Brain-inspired computing: because the models are differentiable (neurons accept a surrogate spk_fun, and for_loop/scan are differentiable), you train them with gradients and scale them with linear-memory online learning.
The hinge between the two worlds is the AlignPre/AlignPost projection design: the same alignment that makes simulation memory-efficient is what makes gradient-based and online learning memory-efficient. That is why it is the keystone of the concept spine.
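The differentiable half of the story can be sketched in plain JAX. Below, a spike function is a Heaviside step on the forward pass and a triangular pseudo-derivative on the backward pass (a custom VJP), and jax.grad flows through a jax.lax.scan rollout of a LIF neuron to give the derivative of its spike count with respect to the input. The ALPHA and WIDTH constants echo the ReluGrad(alpha=0.3, width=1.0) shown in the LIF repr above, but the exact surrogate formula is an assumption, not brainstate's definition. Wrapping the step in jax.checkpoint recomputes its intermediates on the backward pass instead of storing them. That is the memory-for-compute trade the checkpointed loop transforms are named for, although this sketch does not show how brainstate places its checkpoints.

```python
import jax
import jax.numpy as jnp

# Assumed surrogate: Heaviside forward, triangular pseudo-derivative backward.
# ALPHA and WIDTH echo ReluGrad(alpha=0.3, width=1.0); the formula is an assumption.
ALPHA, WIDTH = 0.3, 1.0


@jax.custom_vjp
def spike_fn(x):
    return (x >= 0).astype(jnp.float32)


def spike_fwd(x):
    return spike_fn(x), x  # save x for the backward pass


def spike_bwd(x, g):
    return (g * ALPHA * jnp.maximum(WIDTH - jnp.abs(x), 0.0),)


spike_fn.defvjp(spike_fwd, spike_bwd)


def spike_count(w, steps=200, dt=0.1, tau=10.0, v_th=1.0):
    """Spike count of one LIF neuron under constant input w (dimensionless units)."""
    def step(v, _):
        v = v + dt / tau * (-v + w)
        s = spike_fn(v - v_th)
        return v - s * v_th, s  # soft reset: subtract the threshold

    # jax.checkpoint recomputes each step's internals on the backward pass;
    # scan still keeps the per-step carry.
    _, spikes = jax.lax.scan(jax.checkpoint(step), jnp.float32(0.0), None, length=steps)
    return spikes.sum()


grad_w = jax.grad(spike_count)(2.0)
```

With a true Heaviside the gradient would be zero almost everywhere; the surrogate is what makes grad_w informative.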
See also
AlignPre / AlignPost — the keystone — the keystone chapter the four ideas build toward.
Core Concepts — the full Core Concepts spine and a recommended reading order.
5-minute tour — see these ideas at work in a complete, runnable E/I network.