Overview#
brainpy.state introduces a modern, state-based architecture built on top of brainstate. This overview will help you understand the key concepts and design philosophy.
What’s New#
brainpy.state has been completely rewritten to provide:
State-based programming: Built on brainstate for efficient state management
Modular architecture: Clear separation of concerns (communication, dynamics, outputs)
Physical units: Integration with saiunit for scientifically accurate simulations
Modern API: Cleaner, more intuitive interfaces
Better performance: Optimized JIT compilation and memory management
Key Architectural Components#
brainpy.state is organized around several core concepts:
1. State Management#
Everything in brainpy.state revolves around states. States are variables that persist across time steps:
brainstate.State: Base state container
brainstate.ParamState: Trainable parameters
brainstate.ShortTermState: Temporary variables
States enable:
Automatic differentiation for training
Efficient memory management
Batching and parallelization
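Conceptually, a state container is just a named, mutable value that survives between update steps, with subclasses marking how the value should be treated. The following is a minimal sketch of that idea in plain Python, not brainstate's actual implementation:

```python
# Minimal sketch of the state-container idea (illustrative only; brainstate's
# real State classes add tracing, batching, and JAX integration).
class State:
    """Base container: holds a value that persists across updates."""
    def __init__(self, value):
        self.value = value

class ParamState(State):
    """Trainable parameter: collected when computing gradients."""

class ShortTermState(State):
    """Temporary variable: typically reset between trials."""

# A neuron model would keep its membrane potential in a state and
# mutate it in place on each time step:
V = ShortTermState(0.0)
V.value += 0.5
```

Because every stateful quantity lives in such a container, the framework can find, reset, batch, or differentiate through all of them without the model author wiring anything up by hand.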
2. Neurons#
Neurons are the fundamental computational units:
import brainpy.state
import saiunit as u
# Create a population of 100 LIF neurons
neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)
Key neuron models:
brainpy.state.IF: Integrate-and-Fire
brainpy.state.LIF: Leaky Integrate-and-Fire
brainpy.state.LIFRef: LIF with refractory period
brainpy.state.ALIF: Adaptive LIF
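Under the hood, a LIF neuron integrates tau * dV/dt = (V_rest - V) + R*I and resets on crossing threshold. A hedged numpy sketch of one Euler step (parameter names and values here are illustrative, not brainpy.state's exact API):

```python
import numpy as np

# One Euler step of the LIF dynamics: tau * dV/dt = (V_rest - V) + R * I,
# with a reset to V_reset whenever V crosses V_th. Illustrative sketch only.
def lif_step(V, I, dt=0.1, tau=10.0, V_rest=-65.0, V_th=-50.0,
             V_reset=-65.0, R=1.0):
    V = V + dt * (V_rest - V + R * I) / tau
    spike = V >= V_th
    V = np.where(spike, V_reset, V)
    return V, spike

# Drive one neuron with a constant suprathreshold current until it fires:
V = np.array([-65.0])
fired = False
for _ in range(1000):
    V, s = lif_step(V, I=30.0)
    fired = fired or bool(s[0])
```

With R*I = 30 the steady-state voltage (-35) sits above threshold (-50), so the neuron fires and resets periodically.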
3. Synapses#
Synapses model the dynamics of neural connections:
# Exponential synapse
synapse = brainpy.state.Expon(100, tau=5*u.ms)
# Alpha synapse (more realistic)
synapse = brainpy.state.Alpha(100, tau=5*u.ms)
Synapse models:
brainpy.state.Expon: Single exponential decay
brainpy.state.Alpha: Double exponential (alpha function)
brainpy.state.AMPA: Excitatory receptor dynamics
brainpy.state.GABAa: Inhibitory receptor dynamics
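The single-exponential model is the simplest of these: the synaptic variable decays as dg/dt = -g/tau, and each presynaptic spike bumps it up. A small numpy sketch of that rule (names and increment size are illustrative):

```python
import numpy as np

# Single-exponential synapse sketch: dg/dt = -g/tau, with each presynaptic
# spike adding a unit increment to the conductance-like variable g.
def expon_step(g, spikes, dt=0.1, tau=5.0):
    g = g * np.exp(-dt / tau)   # exact decay over one time step
    return g + spikes            # spikes arrive as 0/1 increments

g = np.zeros(3)
g = expon_step(g, spikes=np.array([1.0, 0.0, 0.0]))  # one spike on unit 0
for _ in range(50):                                  # 5 ms of pure decay
    g = expon_step(g, spikes=np.zeros(3))
```

After one time constant (5 ms at dt = 0.1 ms) the response has decayed to 1/e of its peak, which is exactly the temporal filtering the Expon model provides.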
4. Projections#
Projections connect neural populations:
import brainstate

N_pre = 100
N_post = 50
prob = 0.1
weight = 0.5

projection = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(N_pre, N_post, prob, weight),
    syn=brainpy.state.Expon.desc(N_post, tau=5*u.ms),
    out=brainpy.state.CUBA.desc(),
    post=neurons
)
The projection architecture separates:
Communication: How spikes are transmitted (connectivity, weights)
Synaptic dynamics: How synapses respond (temporal filtering)
Output mechanism: How synaptic currents affect neurons (CUBA/COBA)
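The difference between the two output mechanisms is easy to state in code. A hedged sketch (function names are illustrative, not the brainpy.state implementation): CUBA injects the synaptic variable directly as a current, while COBA treats it as a conductance scaled by the driving force.

```python
# Illustrative sketch of the two output mechanisms named above.
def cuba_current(g, V=None):
    """CUBA (current-based): synaptic variable acts directly as a current."""
    return g

def coba_current(g, V, E_rev=0.0):
    """COBA (conductance-based): current scales with the driving force."""
    return g * (E_rev - V)

# With an excitatory reversal potential (0 mV), a COBA current shrinks as
# the membrane potential approaches E_rev:
i_far = coba_current(1.0, V=-65.0)   # large driving force
i_near = coba_current(1.0, V=-5.0)   # small driving force
```

This is why COBA models saturate naturally near the reversal potential, while CUBA models are cheaper but voltage-independent.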
5. Networks#
Networks combine neurons and projections:
class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.E = brainpy.state.LIF(800)
        self.I = brainpy.state.LIF(200)
        self.E2E = brainpy.state.AlignPostProj(...)
        self.E2I = brainpy.state.AlignPostProj(...)
        # ... more projections

    def update(self, input):
        # Define network dynamics
        pass
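To make the composition concrete, here is a toy excitatory-to-inhibitory loop in plain numpy. Everything here is illustrative (sizes, weights, wiring, and the hand-rolled lif_step); a real EINet would use the LIF populations and AlignPostProj projections described above instead:

```python
import numpy as np

# Toy numpy sketch of E -> I composition: an excitatory population driven by
# external current, whose spikes are projected onto an inhibitory population
# through fixed-probability random weights. Illustrative only.
rng = np.random.default_rng(0)
n_e, n_i = 80, 20
W_ei = (rng.random((n_e, n_i)) < 0.1) * 0.5   # fixed-probability E -> I weights

def lif_step(V, I, dt=0.1, tau=10.0, V_rest=-65.0, V_th=-50.0, V_reset=-65.0):
    V = V + dt * (V_rest - V + I) / tau
    spikes = V >= V_th
    return np.where(spikes, V_reset, V), spikes

V_e = np.full(n_e, -65.0)
V_i = np.full(n_i, -65.0)
e_fired = False
for _ in range(300):
    V_e, s_e = lif_step(V_e, I=25.0)                      # external drive to E
    V_i, s_i = lif_step(V_i, I=s_e.astype(float) @ W_ei)  # E spikes drive I
    e_fired = e_fired or bool(s_e.any())
```

The projection machinery in brainpy.state replaces the `@ W_ei` line with an event-driven, sparse, unit-aware equivalent, but the dataflow is the same: presynaptic spikes in, postsynaptic currents out.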
Computational Model#
Time-Stepped Simulation#
BrainPy uses discrete time steps for simulation:
# Set simulation time step
brainstate.environ.set(dt=0.1 * u.ms)
# Create a simple neuron for demonstration
neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)
# Initialize all states
brainstate.nn.init_all_states(neurons)
# Run simulation
def step(t, i):
    with brainstate.environ.context(t=t, i=i):
        # Provide input current to the neurons
        neurons.update(5 * u.nA)
        return neurons.get_spike()
times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)
results = brainstate.transform.for_loop(step, times, indices)
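Conceptually, `for_loop` calls the step function once per (t, i) pair and stacks the per-step outputs along a leading time axis. A plain-Python sketch of that contract (the real transform compiles the loop with JAX; `naive_for_loop` is not part of any library):

```python
import numpy as np

# Conceptual stand-in for brainstate.transform.for_loop: apply `step` to each
# (t, i) pair and stack the outputs along a new leading time axis.
def naive_for_loop(step, times, indices):
    return np.stack([step(t, i) for t, i in zip(times, indices)])

outs = naive_for_loop(lambda t, i: np.array([t, i]),
                      times=np.arange(0.0, 0.5, 0.1),
                      indices=np.arange(5))
```

So for a 1000 ms run at dt = 0.1 ms, `results` above has shape (10000, num_neurons): one row of spikes per time step.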
JIT Compilation#
BrainPy leverages JAX for Just-In-Time compilation:
# Create a simple network for demonstration
network = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)
brainstate.nn.init_all_states(network)
# Define input current
input_current = 5 * u.nA
# JIT-compiled simulation function
@brainstate.transform.jit
def simulate(t, i):
    with brainstate.environ.context(t=t, i=i):
        network.update(input_current)
        return network.get_spike()
# First call compiles, subsequent calls are fast
times = u.math.arange(0*u.ms, 100*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)
result = brainstate.transform.for_loop(simulate, times, indices)
Benefits:
Near-C performance
Automatic GPU/TPU dispatch
Optimized memory usage
Physical Units#
brainpy.state integrates saiunit for scientific accuracy:
import saiunit as u
# Define parameters with units
tau = 10 * u.ms
V_threshold = -50 * u.mV
current = 5 * u.nA
# Units are checked automatically
neurons = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)
This prevents unit-related bugs and makes code self-documenting.
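The principle is that every quantity carries its unit, and mismatched units fail loudly instead of silently corrupting a result. A toy illustration of that principle (saiunit's real Quantity is far more capable; this class exists only for this sketch):

```python
# Toy unit-carrying value, illustrating why unit checking catches bugs.
# saiunit implements this properly with full SI support and conversions.
class Quantity:
    def __init__(self, value, unit):
        self.value, self.unit = value, unit
    def __add__(self, other):
        if self.unit != other.unit:
            raise TypeError(f"unit mismatch: {self.unit} vs {other.unit}")
        return Quantity(self.value + other.value, self.unit)

tau = Quantity(10, "ms") + Quantity(5, "ms")        # fine: 15 ms
try:
    bad = Quantity(10, "ms") + Quantity(-50, "mV")  # caught immediately
except TypeError as e:
    error_msg = str(e)
```

Accidentally adding a time constant to a voltage, or passing milliseconds where seconds are expected, becomes a run-time error at the line where the mistake happens, not a silently wrong simulation.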
Training and Learning#
brainpy.state supports gradient-based training:
import braintools
# Create a simple network for training
net = brainpy.state.LIF(10, tau=10*u.ms, V_th=-50*u.mV)
brainstate.nn.init_all_states(net)
# Define optimizer
optimizer = braintools.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))
# Prepare dummy data for demonstration
num_steps = 100
inputs = u.math.ones((num_steps,)) * 5 * u.nA
targets = u.math.zeros((num_steps, 10)) # dummy target
# Define loss function
def loss_fn():
    def step(t, i, inp):
        with brainstate.environ.context(t=t, i=i):
            net.update(inp)
            return net.spike.value

    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())
    indices = u.math.arange(times.size)
    predictions = brainstate.transform.for_loop(step, times, indices, inputs)
    # Simple MSE loss
    return u.math.mean((predictions.astype(float) - targets) ** 2)
# Training step
@brainstate.transform.jit
def train_step():
    grads, loss_value = brainstate.transform.grad(
        loss_fn,
        net.states(brainstate.ParamState),
        return_value=True
    )()
    optimizer.update(grads)
    return loss_value
Key features:
Surrogate gradients for spiking neurons
Automatic differentiation
Various optimizers (Adam, SGD, etc.)
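Surrogate gradients deserve a word: the spike function H(V - V_th) has zero derivative almost everywhere, so backpropagation replaces it with a smooth stand-in. A numpy sketch of one common choice, the derivative of a sigmoid (the function name and the particular surrogate are illustrative; braintools.surrogate offers several):

```python
import numpy as np

# Sigmoid-derivative surrogate for the non-differentiable spike function:
# large near threshold, vanishing far from it. Illustrative sketch only.
def surrogate_grad(V, V_th=-50.0, width=1.0):
    x = (V - V_th) / width
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s) / width

g_at_th = surrogate_grad(-50.0)   # maximal exactly at threshold
g_far = surrogate_grad(-70.0)     # effectively zero far below threshold
```

During the forward pass the network still emits hard binary spikes; only the backward pass uses the smooth surrogate, which is what makes gradient-based training of spiking networks possible at all.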
Ecosystem Components#
brainpy.state is part of a larger ecosystem:
brainstate#
The foundation for state management and compilation:
State-based IR construction
JIT compilation
Program augmentation (batching, etc.)
saiunit#
Physical units system:
SI units support
Automatic unit checking
Unit conversions
braintools#
Utilities and tools:
Optimizers (braintools.optim)
Initialization (braintools.init)
Metrics and losses (braintools.metric)
Surrogate gradients (braintools.surrogate)
Visualization (braintools.visualize)
Design Philosophy#
brainpy.state follows these principles:
Explicit over implicit: Clear, readable code
Modular composition: Build complex models from simple components
Performance by default: JIT compilation and optimization built-in
Scientific accuracy: Physical units and biologically realistic models
Extensibility: Easy to add custom components
Next Steps#
Now that you understand the core concepts:
Try the 5-minute tutorial to get hands-on experience
Read the detailed BrainPy-style modeling guide
Explore the examples in the repository to learn each component
Check out the examples gallery for real-world models