brainpy.state documentation
brainpy.state is the point-neuron modeling layer of the BrainX ecosystem. It provides spiking neural network models built on JAX and the brainstate state-management system.
One differentiable substrate, two worlds. brainpy.state is designed so
that the same models serve both brain simulation (biophysical networks,
spike rasters, conductance dynamics) and brain-inspired computing
(surrogate-gradient training, online learning). The bridge is a small set of
ideas: explicit State, physical units, and the distinctive
AlignPre / AlignPost synaptic projection design. Together they keep memory
linear in the number of neurons while remaining fully differentiable. See
the keystone chapter, AlignPre / AlignPost: the keystone.
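The memory claim can be checked in a few lines of plain JAX. The sketch below is a minimal illustration only. It assumes exponentially decaying synapses with one shared time constant and a dense weight matrix `W`, and the functions `per_synapse` and `align_post` are illustrative names that do not reproduce brainpy.state's AlignPostProj. The decay is linear and identical for every synapse onto a neuron. Consequently, keeping one trace per synapse (n_pre × n_post values) and keeping one trace per post-synaptic neuron (n_post values) that receives the weighted spikes directly give the same summed conductance.

```python
import jax
import jax.numpy as jnp

def per_synapse(spikes, W, decay):
    """Reference: one exponential trace per synapse, state shape (n_pre, n_post)."""
    def step(g, s):
        g = g * decay + s[:, None] * W   # every synapse decays, then adds its own weight
        return g, g.sum(axis=0)          # summed conductance seen by each post neuron
    _, out = jax.lax.scan(step, jnp.zeros_like(W), spikes)
    return out

def align_post(spikes, W, decay):
    """Post-aligned: one trace per post-synaptic neuron, state shape (n_post,)."""
    def step(g, s):
        g = g * decay + s @ W            # weighted spikes land directly on the post trace
        return g, g
    _, out = jax.lax.scan(step, jnp.zeros(W.shape[1]), spikes)
    return out

key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
n_pre, n_post, n_steps = 50, 20, 200
W = jax.random.uniform(key_w, (n_pre, n_post))
spikes = (jax.random.uniform(key_s, (n_steps, n_pre)) < 0.05).astype(jnp.float32)
decay = jnp.exp(-0.1 / 5.0)              # exp(-dt / tau) with dt = 0.1 ms, tau = 5 ms

g_ref = per_synapse(spikes, W, decay)
g_post = align_post(spikes, W, decay)
print(jnp.max(jnp.abs(g_ref - g_post)))  # only float32 round-off remains
```

Only the post-aligned state scales with the number of neurons. The per-synapse state grows with n_pre × n_post. The collapse relies on linear synaptic dynamics that are the same for every synapse onto a neuron.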
Fig. 1 Two worlds on one substrate. The same State-based, unit-aware models,
wired with AlignPre / AlignPost projections, drive both biophysical brain
simulation and gradient-trained brain-inspired computing. Surrogate
gradients and linear-memory online learning are the bridge between them.
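Surrogate gradients make spiking models trainable. The forward pass keeps the non-differentiable Heaviside spike, and the backward pass substitutes a smooth pseudo-derivative. What follows is a minimal plain-JAX sketch using `jax.custom_jvp`. The sigmoid-derivative surrogate and the sharpness `alpha = 4` are illustrative choices, not the specific surrogate functions shipped in braintools.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def spike_fn(v):
    """Heaviside forward pass: 1.0 where the threshold-shifted voltage v >= 0."""
    return (v >= 0.).astype(v.dtype)

@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (v,), (v_dot,) = primals, tangents
    alpha = 4.0                                      # illustrative surrogate sharpness
    sig = jax.nn.sigmoid(alpha * v)
    surrogate = alpha * sig * (1. - sig)             # sigmoid derivative replaces the Dirac delta
    primal_out = (v >= 0.).astype(v.dtype)           # forward value is still the exact spike
    return primal_out, surrogate * v_dot

v = jnp.array([-1.0, -0.1, 0.0, 0.2, 1.5])           # membrane potential minus threshold
spikes = spike_fn(v)
grads = jax.grad(lambda x: spike_fn(x).sum())(v)
print(spikes)   # 0, 0, 1, 1, 1
print(grads)    # peaks (1.0) at threshold and decays smoothly away from it
```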
Two model families
brainpy.state ships two complementary, production-ready model families on a
shared substrate:
BrainPy-style models: high-level, composable neurons (LIF, ALIF, AdEx, HH, Izhikevich, …), synapses (Expon, Alpha, AMPA, GABAa, BioNMDA), projections (AlignPostProj, DeltaProj, …), readouts, input generators, and short-term plasticity, in the tradition of BrainPy. The idiomatic entry point; start here.
NEST-compatible models: JAX re-implementations of NEST simulator neuron, synapse, plasticity (STDP, STP), and device models with NEST parameter names, validated against live NEST with formal tolerance bands. The migration path for NEST users.
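The two quick examples below use different names for the same leaky integrate-and-fire quantities. The BrainPy-style LIFRef uses `V_rest`, `tau`, and `tau_ref`. The NEST-compatible `iaf_psc_alpha` uses NEST's `E_L`, `tau_m`, and `t_ref`. A minimal, hypothetical renaming helper follows. `BRAINPY_TO_NEST` and `to_nest_names` are illustrative and not part of brainpy.state. The values are bare numbers (mV and ms) to keep the sketch dependency-free. NEST's iaf models also take a membrane capacitance `C_m`, which the LIFRef example does not specify, so a full translation has to choose it separately.

```python
# Hypothetical mapping from the BrainPy-style LIFRef parameter names used in the
# quick example to the NEST iaf_* names used by the NEST-compatible models.
BRAINPY_TO_NEST = {
    "V_rest": "E_L",       # resting (leak reversal) potential
    "V_th": "V_th",        # spike threshold
    "V_reset": "V_reset",  # post-spike reset potential
    "tau": "tau_m",        # membrane time constant
    "tau_ref": "t_ref",    # absolute refractory period
}

def to_nest_names(params):
    """Rename keys; an unmapped key raises KeyError so nothing is silently dropped."""
    return {BRAINPY_TO_NEST[name]: value for name, value in params.items()}

# Values from the BrainPy-style quick example, as bare numbers in mV and ms
lif = {"V_rest": -60.0, "V_th": -50.0, "V_reset": -60.0, "tau": 20.0, "tau_ref": 5.0}
print(to_nest_names(lif))
```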
All parameters carry physical units via brainunit, and every model compiles through JAX to CPU, GPU, and TPU.
Installation
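Pick the extra that matches your hardware (quoting keeps shells such as zsh from expanding the brackets):
pip install -U "brainpy.state[cpu]"      # CPU only
pip install -U "brainpy.state[cuda12]"   # NVIDIA GPU, CUDA 12
pip install -U "brainpy.state[cuda13]"   # NVIDIA GPU, CUDA 13
pip install -U "brainpy.state[tpu]"      # Google TPU
pip install -U BrainX                    # or install the whole BrainX ecosystem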
Quick example
A small excitatory–inhibitory balanced network, built two ways: with the
high-level BrainPy-style API, and with the NEST-compatible Simulator. See the
5-minute tour for the full walkthrough.
BrainPy-style API:

import brainpy
import brainstate
import braintools
import brainunit as u


class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.n_exc, self.n_inh = 3200, 800
        self.num = self.n_exc + self.n_inh
        # One population of 4000 LIF neurons with a refractory period:
        # indices [0, 3200) are excitatory, [3200, 4000) inhibitory.
        self.N = brainpy.state.LIFRef(
            self.num,
            V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
            tau=20. * u.ms, tau_ref=5. * u.ms,
            V_initializer=braintools.init.Normal(-55., 2., unit=u.mV),
        )
        # Excitatory projection: sparse random connectivity (probability 0.02),
        # exponential conductance stored per post-synaptic neuron, reversal at 0 mV.
        self.E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(self.n_exc, self.num,
                                              conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon.desc(self.num, tau=5. * u.ms),
            out=brainpy.state.COBA.desc(E=0. * u.mV),
            post=self.N,
        )
        # Inhibitory projection: same structure, slower decay, reversal at -80 mV.
        self.I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(self.n_inh, self.num,
                                              conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon.desc(self.num, tau=10. * u.ms),
            out=brainpy.state.COBA.desc(E=-80. * u.mV),
            post=self.N,
        )

    def update(self, t, inp):
        with brainstate.environ.context(t=t):
            # Spikes emitted on the previous step drive both projections ...
            spk = self.N.get_spike() != 0.
            self.E(spk[:self.n_exc])
            self.I(spk[self.n_exc:])
            # ... then the neurons integrate synaptic plus external input.
            self.N(inp)
            return self.N.get_spike()


net = EINet()
brainstate.nn.init_all_states(net)

# Simulate 1 s at dt = 0.1 ms with a constant external input to every neuron.
with brainstate.environ.context(dt=0.1 * u.ms):
    times = u.math.arange(0. * u.ms, 1000. * u.ms, brainstate.environ.get_dt())
    spikes = brainstate.transform.for_loop(lambda t: net.update(t, 20. * u.mA), times)
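`brainstate.transform.for_loop` stacks the value returned by `net.update` at every step along a new leading axis. As a result, `spikes` is a raster of shape (number of time steps, 4000). Below is a minimal sketch for reducing such a raster to a mean population firing rate. It is written in plain JAX so it runs without brainpy. `population_rate` is a hypothetical helper that takes the step size as a bare number in milliseconds rather than a brainunit quantity. For the example above you would call it as `population_rate(spikes, 0.1)`.

```python
import jax.numpy as jnp

def population_rate(spikes, dt_ms):
    """Mean firing rate in Hz of a (n_steps, n_neurons) spike raster (hypothetical helper)."""
    n_steps, n_neurons = spikes.shape
    duration_s = n_steps * dt_ms * 1e-3          # simulated time in seconds
    n_spikes = (spikes != 0).sum()               # works for boolean or float rasters
    return n_spikes / (n_neurons * duration_s)

# Tiny hand-made raster: 4 steps x 2 neurons, 3 spikes in total
toy = jnp.array([[1., 0.], [0., 0.], [1., 1.], [0., 0.]])
print(population_rate(toy, dt_ms=0.1))           # 3 / (2 * 0.0004 s) = 3750 Hz
```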
NEST-compatible Simulator:

import brainstate
import brainunit as u
from brainpy.state import (Simulator, iaf_psc_alpha, poisson_generator,
                           spike_recorder, all_to_all, fixed_indegree)

# Brunel (2000) sparse balanced random network, built with the explicit Simulator
order = 400
NE, NI = 4 * order, 1 * order          # 1600 excitatory + 400 inhibitory
CE, CI = int(0.1 * NE), int(0.1 * NI)  # 10% fixed in-degree: 160 and 40 inputs
J = 20. * u.pA                         # excitatory PSC amplitude
g = 5.                                 # relative inhibitory strength; inhibition is -g * J
neuron = dict(C_m=250. * u.pF, tau_m=20. * u.ms, t_ref=2. * u.ms,
              E_L=0. * u.mV, V_reset=0. * u.mV, V_th=20. * u.mV)

sim = Simulator(dt=0.1 * u.ms)
ne = sim.create(iaf_psc_alpha, NE, params=neuron)          # excitatory population
ni = sim.create(iaf_psc_alpha, NI, params=neuron)          # inhibitory population
noise = sim.create(poisson_generator, rate=15000. * u.Hz)  # external drive
rec = sim.create(spike_recorder)

sim.connect(noise, ne + ni, weight=J, delay=1.5 * u.ms, rule=all_to_all)
sim.connect(ne, ne + ni, weight=J, delay=1.5 * u.ms,
            rule=fixed_indegree(CE), comm='sparse', seed=1)
sim.connect(ni, ne + ni, weight=-g * J, delay=1.5 * u.ms,
            rule=fixed_indegree(CI), comm='sparse', seed=2)
sim.connect(ne[:50], rec)                                  # record 50 excitatory neurons

res = sim.simulate(1000. * u.ms)
print(res.rate(rec.segments[0].population))                # excitatory firing rate
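The connection parameters above follow Brunel's (2000) bookkeeping. The inhibitory weight is g times the excitatory one with the opposite sign (here g = 5), and γ = CI / CE. With g > CE / CI = 4 (equivalently gγ > 1), the network is in what Brunel calls the inhibition-dominated regime. If all neurons fire at a common rate, the net recurrent input per neuron is proportional to CE − g·CI and is therefore inhibitory. The arithmetic for this page's values, in plain Python:

```python
# Brunel (2000) bookkeeping for the parameters used in the Simulator example
order = 400
NE, NI = 4 * order, 1 * order          # 1600 excitatory, 400 inhibitory neurons
CE, CI = int(0.1 * NE), int(0.1 * NI)  # 160 excitatory and 40 inhibitory inputs per neuron
g = 5.0                                # inhibitory weight is -g * J

gamma = CI / CE                        # ratio of inhibitory to excitatory in-degree
net_recurrent = CE - g * CI            # weighted in-degree in units of J
print(gamma, g * gamma, net_recurrent) # 0.25 1.25 -40.0
```

Because gγ = 1.25 exceeds 1, recurrent inhibition outweighs recurrent excitation, and the Poisson input supplies the net excitatory drive.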
See also the ecosystem
brainpy.state is one part of the BrainX ecosystem:
brainstate — state management for JAX-based brain modeling
brainunit — physical units for neuroscience
brainevent — event-driven sparse operators
braintools — surrogate gradients, analysis, and utilities
braintrace — linear-memory online learning for spiking networks