brainpy.state documentation
brainpy.state provides spiking neural network models built on JAX and the brainstate state-management system. It is the point-neuron modeling layer of the BrainX ecosystem.
API maturity
brainpy.state ships two model families with different maturity levels:
Stable: base classes (Dynamics, Neuron, Synapse) and BrainPy-style models (45+ neurons, synapses, projections, readouts, and input generators). The public API is stable for the 0.0.x series and is recommended for production use and surrogate-gradient training.
Experimental (in development): NEST-compatible models (119+ neurons, synapses, plasticity rules, and devices). These are under active development. Parameter names, defaults, and numerical behavior may change without notice. Use them for exploration and validation, but pin your dependency version and expect breaking changes.
See the NEST-style status page for what is currently available and what users should not rely on yet.
What’s in the library
Base classes: Dynamics, Neuron, and Synapse, the abstract foundation every model inherits from.
BrainPy-style models (Stable, 45+): high-level, composable neurons (LIF, ALIF, AdEx, HH, Izhikevich, …), synapses (Expon, Alpha, AMPA, GABA, NMDA), projections, readouts, input generators, and short-term plasticity, designed in the tradition of BrainPy.
NEST-compatible models (Experimental, 119+): JAX re-implementations of the NEST simulator's neuron, synapse, plasticity (STDP, STP), and device models, using NEST-compatible parameter names.
All parameters carry physical units via saiunit; BrainPy-style neurons support surrogate-gradient training out of the box.
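The idea behind surrogate-gradient training can be shown in a few lines of plain NumPy. This is a hypothetical sketch, not the brainpy.state or braintools implementation. The function names and the sigmoid-shaped surrogate with sharpness `alpha` are illustrative assumptions. In an autodiff framework such as JAX, the Heaviside step is used in the forward pass and the smooth derivative replaces its zero-almost-everywhere gradient in the backward pass.

```python
import numpy as np


def spike_forward(v, v_th):
    """Forward pass: Heaviside step, 1.0 where the membrane potential
    has reached threshold. Its true derivative is zero almost everywhere."""
    return (v >= v_th).astype(np.float64)


def spike_surrogate_grad(v, v_th, alpha=4.0):
    """Backward pass: derivative of a sigmoid centred on the threshold,
    used in place of the Heaviside derivative. A larger `alpha` gives a
    sharper, narrower surrogate."""
    s = 1.0 / (1.0 + np.exp(-alpha * (v - v_th)))
    return alpha * s * (1.0 - s)


v = np.array([-55.0, -50.0, -45.0])          # membrane potentials (mV)
print(spike_forward(v, v_th=-50.0))          # [0. 1. 1.]
print(spike_surrogate_grad(v, v_th=-50.0))   # largest at the threshold (-50 mV)
```

Because the surrogate is nonzero near threshold, neurons that almost spiked still receive a learning signal.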
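Installation
Install brainpy.state with the extra that matches your hardware. The quotes keep shells such as zsh from interpreting the square brackets.
```bash
pip install -U "brainpy.state[cpu]"     # CPU only
pip install -U "brainpy.state[cuda12]"  # NVIDIA GPU, CUDA 12
pip install -U "brainpy.state[cuda13]"  # NVIDIA GPU, CUDA 13
pip install -U "brainpy.state[tpu]"     # Google TPU
```
Alternatively, install the whole BrainX ecosystem:
```bash
pip install -U BrainX
```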
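After installing, you can confirm that the modules used in the Quick Example can be found. The helper below is a hypothetical convenience function, not part of brainpy.state. It relies only on the standard library's `importlib.util.find_spec`. When JAX is importable, it also prints `jax.devices()`, which shows whether the accelerator you installed for is visible.

```python
import importlib.util


def check_install(names=("brainpy.state", "brainstate", "braintools",
                         "saiunit", "jax")):
    """Return {module_name: bool} telling whether each module can be located."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ImportError:
            # find_spec("a.b") imports the parent package "a" first and raises
            # ModuleNotFoundError (a subclass of ImportError) if it is missing.
            status[name] = False
    return status


if __name__ == "__main__":
    status = check_install()
    for name, ok in status.items():
        print(f"{name:15s} {'found' if ok else 'MISSING'}")
    if status["jax"]:
        import jax
        # Lists the devices JAX can use (CPU, GPU or TPU).
        print(jax.devices())
```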
Quick Example
A small excitatory–inhibitory balanced network using the Stable BrainPy-style API. See the 5-minute tutorial for the full walkthrough.
```python
import brainpy.state
import brainstate
import braintools
import saiunit as u


class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # 3200 excitatory + 800 inhibitory neurons in one population
        self.n_exc, self.n_inh = 3200, 800
        self.num = self.n_exc + self.n_inh

        # Leaky integrate-and-fire neurons with a refractory period
        self.N = brainpy.state.LIFRef(
            self.num,
            V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
            tau=20. * u.ms, tau_ref=5. * u.ms,
            V_initializer=braintools.init.Normal(-55., 2., unit=u.mV),
        )

        # Excitatory projection: sparse random connectivity -> exponential
        # conductance -> conductance-based (COBA) current, reversal at 0 mV
        self.E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(self.n_exc, self.num,
                                              conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon.desc(self.num, tau=5. * u.ms),
            out=brainpy.state.COBA.desc(E=0. * u.mV),
            post=self.N,
        )

        # Inhibitory projection: same structure, stronger weights, reversal at -80 mV
        self.I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(self.n_inh, self.num,
                                              conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon.desc(self.num, tau=10. * u.ms),
            out=brainpy.state.COBA.desc(E=-80. * u.mV),
            post=self.N,
        )

    def update(self, t, inp):
        with brainstate.environ.context(t=t):
            # Spikes emitted in the previous step drive both projections
            spk = self.N.get_spike() != 0.
            self.E(spk[:self.n_exc])  # first n_exc neurons are excitatory
            self.I(spk[self.n_exc:])  # remaining neurons are inhibitory
            # Integrate the membrane dynamics with the external input
            self.N(inp)
            return self.N.get_spike()
```
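The NumPy sketch below spells out roughly what one simulation step of EINet computes. It is illustrative only and is not how brainpy.state implements these models. Several simplifications are assumptions:

- Units are dropped (mV, ms, and mS become plain floats), with unit membrane resistance.
- Integration uses forward Euler.
- Dense boolean masks replace the event-driven sparse connectivity.
- `conn_num=0.02` is read as a 2% connection probability.

`init_state` and `step` are hypothetical names.

```python
import numpy as np

# Hypothetical unit-free constants mirroring the EINet parameters above.
V_REST, V_TH, V_RESET, TAU, TAU_REF = -60.0, -50.0, -60.0, 20.0, 5.0
E_EXC, E_INH, TAU_EXC, TAU_INH = 0.0, -80.0, 5.0, 10.0
W_EXC, W_INH = 0.6, 6.7


def init_state(num, rng):
    """Membrane potentials start around -55 mV, as with Normal(-55, 2) above."""
    return {
        "V": rng.normal(-55.0, 2.0, size=num),
        "t_last": np.full(num, -1e7),      # time of each neuron's last spike
        "g_exc": np.zeros(num),            # excitatory conductance
        "g_inh": np.zeros(num),            # inhibitory conductance
        "spike": np.zeros(num, dtype=bool),
    }


def step(s, t, inp, C_exc, C_inh, n_exc, dt=0.1):
    """Advance the network by one step of length dt (ms); returns the spikes."""
    # 1. Spikes from the previous step add W to each target's conductance.
    s["g_exc"] += W_EXC * C_exc[s["spike"][:n_exc]].sum(axis=0)
    s["g_inh"] += W_INH * C_inh[s["spike"][n_exc:]].sum(axis=0)
    # 2. COBA output: current is proportional to the distance from reversal.
    I_syn = s["g_exc"] * (E_EXC - s["V"]) + s["g_inh"] * (E_INH - s["V"])
    # 3. Expon synapses: conductances decay (Euler step of exponential decay).
    s["g_exc"] -= dt * s["g_exc"] / TAU_EXC
    s["g_inh"] -= dt * s["g_inh"] / TAU_INH
    # 4. LIF update, frozen for neurons still inside their refractory period.
    refractory = (t - s["t_last"]) < TAU_REF
    dV = dt * (-(s["V"] - V_REST) + I_syn + inp) / TAU
    s["V"] = np.where(refractory, s["V"], s["V"] + dV)
    # 5. Threshold crossing: emit a spike, reset, remember the spike time.
    s["spike"] = s["V"] >= V_TH
    s["V"] = np.where(s["spike"], V_RESET, s["V"])
    s["t_last"] = np.where(s["spike"], t, s["t_last"])
    return s["spike"]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_exc, n_inh, p, dt = 3200, 800, 0.02, 0.1
    num = n_exc + n_inh
    # Dense boolean masks stand in for the sparse fixed-probability connectivity.
    C_exc = rng.random((n_exc, num)) < p
    C_inh = rng.random((n_inh, num)) < p
    s = init_state(num, rng)
    n_steps = 1000  # 100 ms of simulated time
    total = sum(int(step(s, i * dt, 20.0, C_exc, C_inh, n_exc, dt).sum())
                for i in range(n_steps))
    print(f"mean rate: {total / num / (n_steps * dt / 1000.0):.1f} Hz")
```

The ordering mirrors `update` above: the projections consume the previous step's spikes before the neurons are integrated.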
Learn more
See also the ecosystem
brainpy.state is one part of the BrainX ecosystem:
brainstate — state management for JAX-based brain modeling
saiunit — physical units for neuroscience
brainevent — event-driven sparse operators
braintools — surrogate gradients, analysis, and utilities
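Event-driven operators exploit the fact that, in a given time step, only a small fraction of neurons spike. The following hypothetical pure-NumPy sketch illustrates that idea; it is not brainevent's API. A fixed connectivity is stored as per-neuron target lists, and only the neurons that actually spiked are visited. The result should match the dense matrix-vector product.

```python
import numpy as np


def to_target_lists(C):
    """Convert a boolean (n_pre, n_post) connectivity mask into per-neuron
    arrays of postsynaptic target indices (a CSR-like layout)."""
    return [np.flatnonzero(row) for row in C]


def event_driven_input(spikes, targets, weight, n_post):
    """Sum `weight` onto the targets of presynaptic neurons that spiked.

    Only spiking neurons are visited, so the work grows with
    (number of spikes x fan-out) rather than (n_pre x n_post)."""
    out = np.zeros(n_post)
    for i in np.flatnonzero(spikes):
        # Indices within one target list are unique, so fancy-index += is safe.
        out[targets[i]] += weight
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n_pre, n_post = 3200, 4000
    C = rng.random((n_pre, n_post)) < 0.02    # roughly 2% connectivity
    targets = to_target_lists(C)
    spikes = rng.random(n_pre) < 0.01         # about 1% of neurons fire this step
    sparse = event_driven_input(spikes, targets, 0.6, n_post)
    dense = 0.6 * (spikes.astype(float) @ C)  # dense reference computation
    print(np.allclose(sparse, dense))
```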