Overview#
brainpy.state represents a complete architectural redesign built on top of the brainstate framework. This document explains the design principles and architectural components that make brainpy.state powerful and flexible.
Design Philosophy#
brainpy.state is built around several core principles:
State-Based Programming: All dynamical variables are managed as explicit states, enabling automatic differentiation, efficient compilation, and clear data flow.
Modular Composition: Complex models are built by composing simple, reusable components. Each component has a well-defined interface and responsibility.
Scientific Accuracy: Integration with saiunit ensures physical correctness and prevents unit-related errors.
Performance by Default: JIT compilation and optimization are built into the framework, not an afterthought.
Extensibility: Adding new neuron models, synapse types, or learning rules is straightforward and follows clear patterns.
Architectural Layers#
brainpy.state is organized into several layers:
┌─────────────────────────────────────────┐
│ User Models & Networks │ ← Your code
├─────────────────────────────────────────┤
│ BrainPy Components Layer │ ← Neurons, Synapses, Projections
├─────────────────────────────────────────┤
│ BrainState Framework │ ← State management, compilation
├─────────────────────────────────────────┤
│ JAX + XLA Backend │ ← JIT compilation, autodiff
└─────────────────────────────────────────┘
1. JAX + XLA Backend#
The foundation layer provides:
Just-In-Time (JIT) compilation
Automatic differentiation
Hardware acceleration (CPU/GPU/TPU)
Functional transformations (vmap, grad, etc.)
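The idea behind these functional transformations can be pictured without JAX itself. As a rough sketch, vmap-style batching is just mapping a per-sample function over a batch; the real transformation compiles this into a single vectorized kernel (`scale` and `batched` below are illustrative names, not JAX API):

```python
def scale(x):
    """Per-sample computation: double the input."""
    return 2.0 * x

def batched(fn):
    """Naive 'vmap': apply fn independently to every element of a batch."""
    def wrapper(batch):
        return [fn(x) for x in batch]
    return wrapper

scale_batch = batched(scale)
print(scale_batch([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```

The JAX version (`jax.vmap(scale)`) produces the same result but fuses the loop into one array operation that can run on GPU/TPU.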
2. BrainState Framework#
Built on JAX, brainstate provides:
State management system
Module composition
Compilation and optimization
Program transformations (for_loop, etc.)
3. BrainPy Components#
High-level neuroscience-specific components:
Neuron models (LIF, ALIF, etc.)
Synapse models (Expon, Alpha, etc.)
Projection architectures
Learning rules and plasticity
4. User Models#
Your custom networks and experiments built using BrainPy components.
State Management System#
The Foundation: brainstate.State#
Everything in brainpy.state revolves around states:
import brainpy.state
import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
# Create a state
voltage = brainstate.State(0.0) # Single value
weights = brainstate.State([[0.1, 0.2], [0.3, 0.4]]) # Matrix
States are special containers that:
Track their values across time
Support automatic differentiation
Enable efficient compilation
Handle batching automatically
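As a toy illustration of the idea (not the real brainstate implementation), a state is a mutable value holder that the framework can discover and update in place:

```python
class ToyState:
    """Minimal value container mimicking the idea behind brainstate.State."""
    def __init__(self, value):
        self.value = value

# A simulation step reads and writes the state's value in place,
# so the framework always knows where the mutable data lives.
voltage = ToyState(0.0)
for _ in range(3):
    voltage.value = voltage.value + 1.5  # e.g. integrate input current
print(voltage.value)  # 4.5
```

Because all mutation goes through `.value`, the framework can trace, batch, and differentiate through updates without hidden side effects.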
State Types#
BrainPy uses different state types for different purposes:
ParamState - Trainable Parameters Used for weights, time constants, and other trainable parameters.
class MyNeuron(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau = brainstate.ParamState(10.0)  # Trainable
        self.weight = brainstate.ParamState([[0.1, 0.2]])
ShortTermState - Temporary Variables Used for membrane potentials, synaptic currents, and other dynamics.
class MyNeuron(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.V = brainstate.ShortTermState(jnp.zeros(size))  # Dynamic
        self.spike = brainstate.ShortTermState(jnp.zeros(size))
State Initialization#
States can be initialized with various strategies:
# Define example size and shape
size = 100 # Number of neurons
shape = (100, 50) # Weight matrix shape
# Constant initialization
V = brainstate.ShortTermState(
    braintools.init.Constant(-65.0, unit=u.mV)(size)
)

# Normal distribution
V = brainstate.ShortTermState(
    braintools.init.Normal(-65.0, 5.0, unit=u.mV)(size)
)

# Uniform distribution
weights = brainstate.ParamState(
    braintools.init.Uniform(0.0, 1.0)(shape)
)
Module System#
Base Class: brainstate.nn.Module#
All BrainPy components inherit from brainstate.nn.Module:
class MyComponent(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        # Initialize states
        self.state1 = brainstate.ShortTermState(jnp.zeros(size))
        self.param1 = brainstate.ParamState(jnp.ones(size))

    def update(self, input):
        # Define dynamics
        pass
Benefits of Module:
Automatic state registration
Nested module support
State collection and filtering
Serialization support
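Automatic state registration can be sketched in plain Python: a base class intercepts attribute assignment and records every state it sees. This is a toy model of the idea, not brainstate's actual mechanism:

```python
class ToyState:
    def __init__(self, value):
        self.value = value

class ToyModule:
    """Records every ToyState assigned as an attribute, so states can
    later be collected for initialization, saving, or training."""
    def __init__(self):
        object.__setattr__(self, '_states', {})

    def __setattr__(self, name, attr):
        if isinstance(attr, ToyState):
            self._states[name] = attr   # automatic registration
        object.__setattr__(self, name, attr)

class Neuron(ToyModule):
    def __init__(self):
        super().__init__()
        self.V = ToyState(0.0)
        self.spike = ToyState(0.0)

print(sorted(Neuron()._states))  # ['V', 'spike']
```

Nested modules work the same way: a parent records child modules, so collecting states is a recursive walk over this registry.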
Module Composition#
Modules can contain other modules:
class Network(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
        self.synapse = brainpy.state.Expon(100, tau=5*u.ms)
        self.projection = brainpy.state.AlignPostProj(...)  # Example - requires more setup

    def update(self, spikes, input):
        # Compose behavior: deliver presynaptic spikes, then update neurons
        self.projection(spikes)  # Example
        self.neurons(input)
Component Architecture#
Neurons#
Neurons model the dynamics of neural populations:
class Neuron(brainstate.nn.Module):
    def __init__(self, size, **kwargs):
        super().__init__()
        # Membrane potential
        self.V = brainstate.ShortTermState(jnp.zeros(size))
        # Spike output
        self.spike = brainstate.ShortTermState(jnp.zeros(size))

    def update(self, input_current):
        # Update membrane potential
        # Generate spikes
        pass
Key responsibilities:
Maintain membrane potential
Generate spikes when threshold is crossed
Reset after spiking
Integrate input currents
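These responsibilities reduce to a few lines of arithmetic. A minimal Euler-integrated leaky integrate-and-fire step in plain Python (illustrative constants, not the library's defaults) looks like:

```python
def lif_step(V, I, dt=0.1, tau=10.0, V_rest=-65.0, V_th=-50.0):
    """One Euler step of a leaky integrate-and-fire neuron.
    Returns the updated potential and whether a spike occurred."""
    dV = (-(V - V_rest) + I) / tau   # leak toward rest + input drive
    V = V + dt * dV                  # integrate input currents
    spike = V >= V_th                # generate spike at threshold
    if spike:
        V = V_rest                   # reset after spiking
    return V, spike

# Drive the neuron with a constant suprathreshold current until it fires
V, spike = -65.0, False
for _ in range(2000):
    V, spike = lif_step(V, I=20.0)
    if spike:
        break
print(spike)  # True: a constant 20-unit drive crosses threshold
```

The library versions add units, batching, and surrogate gradients for the threshold, but the underlying update is this same integrate-threshold-reset loop.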
Synapses#
Synapses model temporal filtering of spike trains:
class Synapse(brainstate.nn.Module):
    def __init__(self, size, tau, **kwargs):
        super().__init__()
        # Synaptic conductance/current
        self.g = brainstate.ShortTermState(jnp.zeros(size))
        self.tau = tau

    def update(self, spike_input):
        # Update synaptic variable
        # Return filtered output
        pass
Key responsibilities:
Filter spike inputs temporally
Model synaptic dynamics (exponential, alpha, etc.)
Provide smooth currents to postsynaptic neurons
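An exponential synapse (the dynamics behind `Expon`) reduces to a one-line discrete update. A plain-Python sketch of the temporal filtering, with made-up constants, is:

```python
import math

def expon_step(g, spike, dt=0.1, tau=5.0):
    """Exponential synapse: decay the synaptic variable each step,
    and jump by 1 whenever a presynaptic spike arrives."""
    g = g * math.exp(-dt / tau)   # exponential decay toward zero
    if spike:
        g += 1.0                  # instantaneous rise on a spike
    return g

g = 0.0
trace = []
for step in range(5):
    g = expon_step(g, spike=(step == 0))  # single spike at t=0
    trace.append(g)
print(trace[0] > trace[1] > trace[4])  # True: smooth decay after the spike
```

The discrete spike train thus becomes a smooth, exponentially decaying current suitable for driving postsynaptic membrane dynamics.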
Projections: The Comm-Syn-Out Pattern#
Projections connect populations using a three-stage architecture:
Presynaptic Spikes → [Comm] → [Syn] → [Out] → Postsynaptic Neurons
                        │         │        │
                  Connectivity  Synaptic  Current
                   & Weights    Dynamics  Injection
Communication (Comm) Handles spike transmission, connectivity, and weights.
# Define population sizes
pre_size = 100
post_size = 50
# Define connection probability and weight
prob = 0.1
weight = 0.5

comm = brainstate.nn.EventFixedProb(
    pre_size, post_size, prob, weight
)
Synaptic Dynamics (Syn) Temporal filtering of transmitted spikes.
post_size = 50 # Postsynaptic population size
syn = brainpy.state.Expon(post_size, tau=5*u.ms)
Output Mechanism (Out) How synaptic variables affect postsynaptic neurons.
# Current-based output
out = brainpy.state.CUBA()
# Or conductance-based output
out = brainpy.state.COBA(E=0*u.mV)
Complete Projection
# Define postsynaptic neurons
postsynaptic_neurons = brainpy.state.LIF(50, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
# Create complete projection
projection = brainpy.state.AlignPostProj(
    comm=comm,
    syn=syn,
    out=out,
    post=postsynaptic_neurons
)
This separation provides:
Clear responsibility boundaries
Easy component swapping
Reusable building blocks
Better testing and debugging
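The three stages can be sketched end to end with plain Python lists. This is a conceptual mock, not the library's implementation: `comm` is a dense weight matrix multiply, `syn` an exponential trace, and `out` direct current injection (CUBA-style):

```python
import math

def comm(spikes, weights):
    """Comm: route spikes through connectivity & weights (2 pre x 3 post)."""
    n_post = len(weights[0])
    return [sum(w[j] * s for w, s in zip(weights, spikes))
            for j in range(n_post)]

def syn(g, transmitted, dt=0.1, tau=5.0):
    """Syn: temporally filter the transmitted spikes."""
    return [gj * math.exp(-dt / tau) + tj for gj, tj in zip(g, transmitted)]

def out(g):
    """Out: inject the synaptic variable as a current (CUBA-style)."""
    return g

spikes = [1.0, 0.0]              # two presynaptic neurons, first one fires
weights = [[0.5, 0.2, 0.0],      # 2 x 3 connectivity matrix
           [0.1, 0.0, 0.3]]
g = [0.0, 0.0, 0.0]
g = syn(g, comm(spikes, weights))
print(out(g))  # [0.5, 0.2, 0.0]
```

Swapping `out` for a conductance-based rule (scaling by the driving force, as COBA does) changes only the last stage, which is exactly the modularity the Comm-Syn-Out pattern buys.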
Compilation and Execution#
Time-Stepped Simulation#
BrainPy uses discrete time steps:
# Example: create a simple network
class SimpleNetwork(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)

    def update(self, t, i):
        # Generate constant input current
        inp = jnp.ones(100) * 5.0 * u.nA
        with brainstate.environ.context(t=t, i=i):
            self.neurons(inp)
        return self.neurons.get_spike()

network = SimpleNetwork()
brainstate.nn.init_all_states(network)

# Set global time step
brainstate.environ.set(dt=0.1 * u.ms)

# Define simulation duration
times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)

# Run simulation
results = brainstate.transform.for_loop(
    network.update,
    times,
    indices,
    pbar=brainstate.transform.ProgressBar(10)
)
JIT Compilation#
Functions are compiled for performance:
# Create example input
input_example = jnp.ones(100) * 2.0 * u.nA

@brainstate.transform.jit
def simulate_step(t, i, input_current):
    # Note: SimpleNetwork generates its own input, so input_current is
    # accepted here only to illustrate passing arrays to a jitted step.
    with brainstate.environ.context(t=t, i=i):
        return network.update(t, i)

# First call: compile
result = simulate_step(0.0*u.ms, 0, input_example)  # Slow (compilation)

# Subsequent calls: fast
result = simulate_step(0.1*u.ms, 1, input_example)  # Fast (compiled)
Compilation benefits:
10-100x speedup over Python
Automatic GPU/TPU dispatch
Memory optimization
Fusion of operations
Gradient Computation#
For training, gradients are computed automatically:
# Example: Define mock functions for demonstration
def compute_loss(predictions, targets):
    return jnp.mean((predictions.astype(float) - targets) ** 2)

# Mock targets
num_steps = 100
targets = jnp.zeros((num_steps, 100))

def loss_fn():
    # Run network for multiple timesteps
    def step(t, i):
        with brainstate.environ.context(t=t, i=i):
            return network.update(t, i)

    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())
    indices = u.math.arange(times.size)
    predictions = brainstate.transform.for_loop(step, times, indices)
    return compute_loss(predictions, targets)
# Get trainable parameters
params = network.states(brainstate.ParamState)

# Compute gradients
if len(params) > 0:
    optimizer = braintools.optim.Adam(lr=1e-3)
    grads, loss = brainstate.transform.grad(
        loss_fn,
        grad_states=params,
        return_value=True
    )()
    print(f"Loss: {loss}")
    # Update parameters with optimizer (if defined)
    optimizer.update(grads)
else:
    # If no trainable parameters, just compute loss
    loss = loss_fn()
    print(f"Loss (no trainable params): {loss}")
Physical Units System#
Integration with saiunit#
brainpy.state integrates saiunit for scientific accuracy:
# Define with units
tau = 10 * u.ms
threshold = -50 * u.mV
current = 5 * u.nA
# Units are checked automatically
neuron = brainpy.state.LIF(100, tau=tau, V_th=threshold)
Benefits:
Prevents unit errors (e.g., ms vs s)
Self-documenting code
Automatic unit conversions
Scientific correctness
Unit Operations#
# Arithmetic with units
total_time = 100 * u.ms + 0.5 * u.second # → 600 ms
# Unit conversion
time_in_seconds = (100 * u.ms).to_decimal(u.second) # → 0.1
# Unit checking (automatic in BrainPy operations)
voltage = -65 * u.mV
current = 2 * u.nA
resistance = voltage / current # Automatically gives MΩ
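The value of this checking can be seen in a toy dimensional-analysis sketch. This is a simplified stand-in for saiunit that tracks only a unit label, but it shows how mismatched units become loud errors instead of silent bugs:

```python
class Qty:
    """Toy quantity: a magnitude plus a unit label.
    Addition demands matching units, mimicking saiunit's checks."""
    def __init__(self, value, unit):
        self.value, self.unit = value, unit

    def __add__(self, other):
        if self.unit != other.unit:
            raise TypeError(f"cannot add {self.unit} to {other.unit}")
        return Qty(self.value + other.value, self.unit)

t1 = Qty(100, 'ms')
t2 = Qty(50, 'ms')
print((t1 + t2).value)  # 150

try:
    t1 + Qty(-65, 'mV')  # mixing time and voltage raises immediately
except TypeError as e:
    print('caught:', e)
```

The real library goes much further (automatic conversions, derived units such as MΩ from mV/nA), but the principle is the same: dimension errors surface at the operation that introduces them.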
Ecosystem Integration#
brainpy.state integrates tightly with its ecosystem:
braintools#
Utilities and tools:
# Optimizers
optimizer = braintools.optim.Adam(lr=1e-3)
# Initializers
init = braintools.init.KaimingNormal()
# Surrogate gradients
spike_fn = braintools.surrogate.ReluGrad()
# Metrics (example with dummy data)
# pred = jnp.array([0.1, 0.9])
# target = jnp.array([0, 1])
# loss = braintools.metric.cross_entropy(pred, target)
saiunit#
Physical units:
# All standard SI units
time = 10 * u.ms
voltage = -65 * u.mV
current = 2 * u.nA
brainstate#
Core framework (used automatically):
import brainstate
# Module system
class Net(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        pass

# Compilation
@brainstate.transform.jit
def fn():
    return 0
# Transformations
# result = brainstate.transform.for_loop(...)