Overview

brainpy.state introduces a modern, state-based architecture built on top of brainstate. This overview will help you understand the key concepts and design philosophy.

What’s New

brainpy.state has been completely rewritten to provide:

  • State-based programming: Built on brainstate for efficient state management

  • Modular architecture: Clear separation of concerns (communication, dynamics, outputs)

  • Physical units: Integration with saiunit for scientifically accurate simulations

  • Modern API: Cleaner, more intuitive interfaces

  • Better performance: Optimized JIT compilation and memory management

Key Architectural Components

brainpy.state is organized around several core concepts:

1. State Management

Everything in brainpy.state revolves around states. States are variables that persist across time steps:

  • brainstate.State: Base state container

  • brainstate.ParamState: Trainable parameters

  • brainstate.ShortTermState: Temporary variables

States enable:

  • Automatic differentiation for training

  • Efficient memory management

  • Batching and parallelization
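To make the state idea concrete, here is a framework-free toy in plain Python. It is not brainstate's implementation: the classes `State`, `ParamState`, `ShortTermState`, `ToyModule`, and `ToyNeuron` below are hypothetical stand-ins that only mimic the pattern of mutable `.value` containers owned by a module and selectable by type, which is what a call like `net.states(brainstate.ParamState)` relies on.

```python
class State:
    """Toy mutable container: its value survives from one call to the next."""
    def __init__(self, value):
        self.value = value


class ParamState(State):
    """Toy trainable parameter."""


class ShortTermState(State):
    """Toy temporary variable, such as a membrane potential."""


class ToyModule:
    def states(self, kind=State):
        # Collect the attributes that are states of the requested type
        return {name: s for name, s in vars(self).items() if isinstance(s, kind)}


class ToyNeuron(ToyModule):
    def __init__(self):
        self.w = ParamState(0.5)       # trainable weight
        self.V = ShortTermState(0.0)   # updated at every step

    def update(self, x):
        # Writing to .value mutates the state in place, so the new V persists
        self.V.value = 0.9 * self.V.value + self.w.value * x
        return self.V.value


neuron = ToyNeuron()
for _ in range(3):
    neuron.update(1.0)

print(neuron.V.value)
print(list(neuron.states(ParamState)))  # only the trainable state: ['w']
```

Filtering by state type is what lets an optimizer see only the trainable parameters while the simulation keeps updating everything else.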

2. Neurons

Neurons are the fundamental computational units:

import brainpy.state
import saiunit as u

# Create a population of 100 LIF neurons
neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)

Key neuron models:

  • brainpy.state.IF: Integrate-and-Fire

  • brainpy.state.LIF: Leaky Integrate-and-Fire

  • brainpy.state.LIFRef: LIF with refractory period

  • brainpy.state.ALIF: Adaptive LIF
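All four models build on leaky integration toward a resting potential plus a firing threshold. The NumPy sketch below integrates the textbook LIF equation tau * dV/dt = -(V - V_rest) + R * I with forward Euler and a hard reset. The parameter values and the helper name `simulate_lif` are illustrative assumptions; `brainpy.state.LIF` may use different defaults, units, or reset rules.

```python
import numpy as np


def simulate_lif(I, dt=0.1, tau=10.0, V_rest=-65.0, V_reset=-65.0,
                 V_th=-50.0, R=1.0):
    """Forward-Euler LIF. Times in ms, voltages in mV, R*I expressed in mV."""
    V = V_rest
    spikes = np.zeros(len(I), dtype=bool)
    trace = np.empty(len(I))
    for k, I_k in enumerate(I):
        # Leak toward V_rest plus the input drive
        V = V + dt / tau * (-(V - V_rest) + R * I_k)
        if V >= V_th:        # threshold crossing emits a spike
            spikes[k] = True
            V = V_reset      # hard reset after the spike
        trace[k] = V
    return trace, spikes


# 100 ms of constant drive: the steady state V_rest + R*I = -45 mV lies above
# threshold, so the neuron fires periodically
trace, spikes = simulate_lif(np.full(1000, 20.0))
print(spikes.sum(), trace.max() < -50.0)
```

Roughly speaking, IF drops the leak term, LIFRef adds a period after each spike during which V is held, and ALIF adds an adaptation variable that raises the effective threshold after spiking.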

3. Synapses

Synapses model the dynamics of neural connections:

# Exponential synapse
synapse = brainpy.state.Expon(100, tau=5*u.ms)

# Alpha synapse (smooth rise to a peak, then decay)
synapse = brainpy.state.Alpha(100, tau=5*u.ms)

Synapse models:

  • brainpy.state.Expon: Single exponential decay

  • brainpy.state.Alpha: Alpha function (rise and decay set by a single time constant)

  • brainpy.state.AMPA: Excitatory receptor dynamics

  • brainpy.state.GABAa: Inhibitory receptor dynamics
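The difference between Expon and Alpha is easiest to see in their response to a single presynaptic spike at t = 0. This NumPy sketch evaluates the standard normalized kernels: the exponential g(t) = exp(-t/tau) jumps instantly and then decays, while the alpha function g(t) = (t/tau) * exp(1 - t/tau) rises smoothly and peaks at t = tau. The normalization is an assumption for illustration; brainpy.state's models may scale their kernels differently.

```python
import numpy as np

tau = 5.0                        # ms, as in the examples above
t = np.arange(0.0, 50.0, 0.1)    # ms after the spike

g_expon = np.exp(-t / tau)                    # instantaneous rise, then decay
g_alpha = (t / tau) * np.exp(1.0 - t / tau)   # smooth rise, peak value 1 at t = tau

print("Expon peak at", t[np.argmax(g_expon)], "ms")
print("Alpha peak at", t[np.argmax(g_alpha)], "ms")
```

The finite rise time of the alpha kernel is why it is often preferred when the time course of postsynaptic currents matters.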

4. Projections

Projections connect neural populations:

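import brainstate

N_pre = 100    # presynaptic population size
N_post = 100   # must match the size of the postsynaptic population `neurons`
prob = 0.1     # connection probability
weight = 0.5   # synaptic weight

projection = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(N_pre, N_post, prob, weight),
    syn=brainpy.state.Expon.desc(N_post, tau=5*u.ms),
    out=brainpy.state.CUBA.desc(),
    post=neurons
)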

The projection architecture separates:

  • Communication: How spikes are transmitted (connectivity, weights)

  • Synaptic dynamics: How synapses respond (temporal filtering)

  • Output mechanism: How synaptic currents affect neurons (CUBA/COBA)
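The three stages can be sketched without any framework. The NumPy class below is a conceptual illustration, not brainpy.state's implementation, and the name `ToyProjection` is hypothetical: a random fixed-probability weight matrix plays the communication layer, an exponentially decaying variable with one entry per postsynaptic neuron plays the synapse (mirroring how `Expon.desc(N_post, ...)` is sized in the example above), and the output stage turns it into a current either CUBA-style (independent of the membrane potential) or COBA-style (scaled by the driving force E - V). Weights and reversal potentials are illustrative assumptions.

```python
import numpy as np


class ToyProjection:
    """Conceptual comm -> syn -> out pipeline (illustrative only)."""

    def __init__(self, n_pre, n_post, prob, weight, tau=5.0, dt=0.1, seed=0):
        rng = np.random.default_rng(seed)
        # Communication: fixed-probability connectivity with one shared weight
        self.W = (rng.random((n_pre, n_post)) < prob) * weight
        # Synaptic dynamics: one decaying variable per postsynaptic neuron
        self.g = np.zeros(n_post)
        self.decay = np.exp(-dt / tau)

    def update(self, pre_spikes, V_post, mode="cuba", E=0.0):
        # Decay the synaptic variable, then add this step's spike input
        self.g = self.g * self.decay + pre_spikes.astype(float) @ self.W
        # Output mechanism: convert the synaptic variable into a current
        if mode == "cuba":
            return self.g.copy()        # independent of membrane potential
        return self.g * (E - V_post)    # COBA: scaled by the driving force


proj = ToyProjection(n_pre=100, n_post=50, prob=0.1, weight=0.5)
spikes = np.random.default_rng(1).random(100) < 0.2   # random presynaptic volley
V = np.full(50, -65.0)                                 # mV
I = proj.update(spikes, V, mode="coba", E=0.0)         # excitatory reversal at 0 mV
print(I.shape, (I >= 0).all())
```

Because the stages are separate, swapping CUBA for COBA or Expon for Alpha changes one component without touching the connectivity.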

5. Networks

Networks combine neurons and projections:

class EINet(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.E = brainpy.state.LIF(800)
        self.I = brainpy.state.LIF(200)
        self.E2E = brainpy.state.AlignPostProj(...)
        self.E2I = brainpy.state.AlignPostProj(...)
        # ... more projections

    def update(self, input):
        # Define network dynamics
        pass
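The EINet skeleton above leaves the projections and the update rule as placeholders. The NumPy sketch below fills in one plausible update order under stated assumptions: deliver the previous step's spikes through the projections, integrate the neurons, then apply threshold and reset. `ToyEINet`, its parameter values, and its CUBA-style inputs measured in mV are all hypothetical, and for brevity each source population has a single projection onto all neurons instead of separate E2E/E2I projections.

```python
import numpy as np


class ToyEINet:
    """Illustrative E/I network in NumPy (not brainpy.state code)."""

    def __init__(self, n_e=800, n_i=200, prob=0.02, dt=0.1, seed=0):
        rng = np.random.default_rng(seed)
        n = n_e + n_i
        self.n_e, self.dt = n_e, dt
        self.tau_m, self.tau_e, self.tau_i = 20.0, 5.0, 10.0          # ms
        self.V_rest, self.V_th, self.V_reset = -60.0, -50.0, -60.0    # mV
        # Projections from E and from I onto every neuron (weights in mV)
        self.W_e = (rng.random((n_e, n)) < prob) * 0.6
        self.W_i = (rng.random((n_i, n)) < prob) * -2.0
        self.V = rng.uniform(self.V_reset, self.V_th, n)
        self.g_e = np.zeros(n)
        self.g_i = np.zeros(n)
        self.spike = np.zeros(n, dtype=bool)

    def update(self, inp):
        # 1. Projections: deliver the previous step's spikes to the synapses
        s = self.spike.astype(float)
        self.g_e = self.g_e * np.exp(-self.dt / self.tau_e) + s[:self.n_e] @ self.W_e
        self.g_i = self.g_i * np.exp(-self.dt / self.tau_i) + s[self.n_e:] @ self.W_i
        # 2. Neurons: integrate synaptic and external input
        drive = self.g_e + self.g_i + inp
        self.V = self.V + self.dt / self.tau_m * (self.V_rest - self.V + drive)
        # 3. Threshold and reset
        self.spike = self.V >= self.V_th
        self.V = np.where(self.spike, self.V_reset, self.V)
        return self.spike


net = ToyEINet()
spike_counts = [net.update(inp=12.0).sum() for _ in range(1000)]   # 100 ms
print(sum(spike_counts))
```

Reading spikes from the previous step before updating the neurons gives every projection a one-step transmission delay, which keeps the update order unambiguous.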

Computational Model

Time-Stepped Simulation

BrainPy advances a simulation in discrete, fixed time steps of size dt, which is set through brainstate.environ:

# Set simulation time step
brainstate.environ.set(dt=0.1 * u.ms)

# Create a simple neuron for demonstration
neurons = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)

# Initialize all states
brainstate.nn.init_all_states(neurons)

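# One simulation step: set the current time t and step index i, drive the
# neurons with a constant input current, and return their spikes
def step(t, i):
    with brainstate.environ.context(t=t, i=i):
        neurons.update(5 * u.nA)
        return neurons.get_spike()

# Simulate 1000 ms; results stacks the spikes of every step along a new
# leading time axis, giving shape (n_steps, 100)
times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)
results = brainstate.transform.for_loop(step, times, indices)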
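Conceptually, for_loop calls the step function once per element of the aligned input sequences and stacks the returned values along a new leading axis. The pure-Python `for_loop_sketch` below is a hypothetical stand-in that captures only those semantics; the real function traces the step once and compiles the whole loop instead of running it in Python.

```python
import numpy as np


def for_loop_sketch(step, *xs):
    """Call step(xs[0][k], xs[1][k], ...) for each k and stack the results."""
    n = len(xs[0])
    if any(len(x) != n for x in xs):
        raise ValueError("all input sequences must have the same length")
    outputs = [step(*(x[k] for x in xs)) for k in range(n)]
    return np.stack(outputs)   # shape: (n_steps, *output_shape)


dt = 0.1
indices = np.arange(10)
times = indices * dt           # build times from integer indices: exactly 10 entries
out = for_loop_sketch(lambda t, i: np.array([t, i]), times, indices)
print(out.shape)               # (10, 2)
```

Applied to the simulation above, each step returns a spike vector of length 100, so the stacked result has one row per time step.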

JIT Compilation

BrainPy leverages JAX for Just-In-Time compilation:

# Create a simple network for demonstration
network = brainpy.state.LIF(100, tau=10*u.ms, V_th=-50*u.mV)
brainstate.nn.init_all_states(network)

# Define input current
input_current = 5 * u.nA

# JIT-compile the whole simulation loop, not just a single step
@brainstate.transform.jit
def simulate(times, indices):
    def step(t, i):
        with brainstate.environ.context(t=t, i=i):
            network.update(input_current)
            return network.get_spike()
    return brainstate.transform.for_loop(step, times, indices)

times = u.math.arange(0*u.ms, 100*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)

# The first call traces and compiles; later calls with inputs of the same
# shape and dtype reuse the compiled program
result = simulate(times, indices)

Benefits:

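  • Compiled XLA programs, with no Python overhead at each time step

  • The same code runs on CPU, GPU, or TPU, depending on the installed jaxlib backend

  • Operation fusion by XLA, which reduces intermediate memory buffers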
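The "compile once, reuse afterwards" behavior comes from JAX, on which brainstate.transform.jit builds. This plain-JAX sketch counts traces with a Python side effect, which runs only while JAX traces the function; `lif_step` is a hypothetical toy update, not a brainpy.state function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def lif_step(V, I):
    global trace_count
    trace_count += 1   # Python side effect: executes only during tracing
    return V + 0.01 * (-(V + 65.0) + I)


V = jnp.full(100, -65.0)
V = lif_step(V, 20.0)    # first call: trace and compile
V = lif_step(V, 20.0)    # same shapes and dtypes: cached program, no retrace
print(trace_count)       # 1

W = lif_step(jnp.full(50, -65.0), 20.0)   # a new input shape triggers a retrace
print(trace_count)       # 2
```

Keeping array shapes fixed across calls (for example, a fixed number of time steps) is therefore what lets repeated simulations skip compilation.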

Physical Units

brainpy.state integrates saiunit for scientific accuracy:

import saiunit as u

# Define parameters with units
tau = 10 * u.ms
V_threshold = -50 * u.mV
current = 5 * u.nA

# Units are checked automatically
neurons = brainpy.state.LIF(100, tau=tau, V_th=V_threshold)

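Mixing incompatible dimensions (for example, adding a time to a voltage) raises an error instead of silently producing wrong numbers, and the units document what each parameter means.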

Training and Learning

brainpy.state supports gradient-based training:

import braintools

# Create a simple network for training
net = brainpy.state.LIF(10, tau=10*u.ms, V_th=-50*u.mV)
brainstate.nn.init_all_states(net)

# Define optimizer over the network's trainable parameters.
# Gradients are taken only with respect to ParamState objects, so a model
# needs at least one (for example, the weights of a brainstate.nn.Linear
# layer) for training to change anything.
optimizer = braintools.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))

# Prepare dummy data for demonstration
num_steps = 100
inputs = u.math.ones((num_steps,)) * 5 * u.nA
targets = u.math.zeros((num_steps, 10))  # dummy target

# Define loss function
def loss_fn():
    def step(t, i, inp):
        with brainstate.environ.context(t=t, i=i):
            net.update(inp)
            return net.get_spike()

    # Build times from integer indices so there are exactly num_steps entries
    indices = u.math.arange(num_steps)
    times = indices * brainstate.environ.get_dt()
    predictions = brainstate.transform.for_loop(step, times, indices, inputs)
    # Simple MSE loss
    return u.math.mean((predictions.astype(float) - targets) ** 2)

# Training step
@brainstate.transform.jit
def train_step():
    grads, loss_value = brainstate.transform.grad(
        loss_fn,
        net.states(brainstate.ParamState),
        return_value=True
    )()
    optimizer.update(grads)
    return loss_value

# Run a few optimization steps
for epoch in range(10):
    loss = train_step()

Key features:

  • Surrogate gradients for spiking neurons

  • Automatic differentiation

  • Various optimizers (Adam, SGD, etc.)
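Surrogate gradients address a basic obstacle: a spike is a step function of the membrane potential, whose derivative is zero almost everywhere. The usual workaround keeps the hard threshold in the forward pass and substitutes a smooth pseudo-derivative in the backward pass. The NumPy sketch below uses the derivative of a sigmoid as the surrogate; the sharpness `alpha=4.0` and the function names are illustrative assumptions, and braintools.surrogate provides its own set of surrogate functions.

```python
import numpy as np


def spike_forward(v, v_th=0.0):
    # Forward pass: hard threshold (Heaviside step)
    return (v >= v_th).astype(float)


def spike_surrogate_grad(v, v_th=0.0, alpha=4.0):
    # Backward pass: derivative of sigmoid(alpha * (v - v_th)), used in place
    # of the step function's zero-almost-everywhere derivative
    s = 1.0 / (1.0 + np.exp(-alpha * (v - v_th)))
    return alpha * s * (1.0 - s)


v = np.linspace(-2.0, 2.0, 9)
print(spike_forward(v))
print(spike_surrogate_grad(v))
```

The surrogate is largest at the threshold, so learning signals flow mainly through neurons whose potentials are close to firing.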

Ecosystem Components

brainpy.state is part of a larger ecosystem:

brainstate

The foundation for state management and compilation:

  • State-based IR construction

  • JIT compilation

  • Program augmentation (batching, etc.)

saiunit

Physical units system:

  • SI units support

  • Automatic unit checking

  • Unit conversions

braintools

Utilities and tools:

  • Optimizers (braintools.optim)

  • Initialization (braintools.init)

  • Metrics and losses (braintools.metric)

  • Surrogate gradients (braintools.surrogate)

  • Visualization (braintools.visualize)

Design Philosophy

brainpy.state follows these principles:

  1. Explicit over implicit: Clear, readable code

  2. Modular composition: Build complex models from simple components

  3. Performance by default: JIT compilation and optimization built-in

  4. Scientific accuracy: Physical units and biologically realistic models

  5. Extensibility: Easy to add custom components

Next Steps

Now that you understand the core concepts: