Overview

brainpy.state represents a complete architectural redesign built on top of the brainstate framework. This document explains the design principles and architectural components that make brainpy.state powerful and flexible.

Design Philosophy

brainpy.state is built around several core principles:

State-Based Programming: all dynamical variables are managed as explicit states, enabling automatic differentiation, efficient compilation, and clear data flow.

Modular Composition: complex models are built by composing simple, reusable components, each with a well-defined interface and responsibility.

Scientific Accuracy: integration with saiunit enforces dimensional consistency and prevents unit-related errors.

Performance by Default: JIT compilation and optimization are built into the framework, not added as an afterthought.

Extensibility: adding new neuron models, synapse types, or learning rules is straightforward and follows clear patterns.

Architectural Layers

brainpy.state is organized into several layers:

┌─────────────────────────────────────────┐
│         User Models & Networks          │  ← Your code
├─────────────────────────────────────────┤
│      BrainPy Components Layer           │  ← Neurons, Synapses, Projections
├─────────────────────────────────────────┤
│       BrainState Framework              │  ← State management, compilation
├─────────────────────────────────────────┤
│       JAX + XLA Backend                 │  ← JIT compilation, autodiff
└─────────────────────────────────────────┘

1. JAX + XLA Backend

The foundation layer provides:

  • Just-In-Time (JIT) compilation

  • Automatic differentiation

  • Hardware acceleration (CPU/GPU/TPU)

  • Functional transformations (vmap, grad, etc.)

2. BrainState Framework

Built on JAX, brainstate provides:

  • State management system

  • Module composition

  • Compilation and optimization

  • Program transformations (for_loop, etc.)

3. BrainPy Components

High-level neuroscience-specific components:

  • Neuron models (LIF, ALIF, etc.)

  • Synapse models (Expon, Alpha, etc.)

  • Projection architectures

  • Learning rules and plasticity

4. User Models

Your custom networks and experiments built using BrainPy components.

State Management System

The Foundation: brainstate.State

Everything in brainpy.state revolves around states:

import brainpy.state
import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp

# Create a state
voltage = brainstate.State(0.0)  # Single value
weights = brainstate.State(jnp.array([[0.1, 0.2], [0.3, 0.4]]))  # 2x2 matrix

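States are containers that:

  • Hold values that persist and change across simulation steps

  • Can be selected as targets for automatic differentiation

  • Are read and written by brainstate transformations (jit, for_loop, grad), so stateful code can still be compiled

  • Support batched execution, e.g. under vectorizing transformations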
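A toy, framework-independent sketch shows why explicit states help compilation. It assumes nothing about brainstate's internals. The hypothetical helpers ToyState and make_pure turn a function that mutates a state into a pure function. That pure function receives the state's value as an argument and returns the updated value, which is the form JAX transformations such as jit and grad operate on.

```python
class ToyState:
    """Hypothetical minimal mutable container standing in for a state object."""

    def __init__(self, value):
        self.value = value


def make_pure(step, states):
    """Hypothetical helper: wrap a stateful `step` into a pure function.

    The returned function takes the state values explicitly, runs `step`,
    returns the updated values alongside the output, and restores the
    original values, so no mutation leaks out to the caller.
    """

    def pure(values, *args):
        saved = [s.value for s in states]
        try:
            for s, v in zip(states, values):
                s.value = v
            out = step(*args)
            new_values = [s.value for s in states]
        finally:
            for s, v in zip(states, saved):
                s.value = v
        return new_values, out

    return pure


# A stateful accumulator: reads and writes the state `v` as a side effect
v = ToyState(0.0)


def accumulate(x):
    v.value = v.value + x
    return v.value


pure_accumulate = make_pure(accumulate, [v])
new_values, out = pure_accumulate([1.0], 2.5)
print(new_values, out, v.value)  # [3.5] 3.5 0.0
```

Because pure_accumulate has no hidden side effects, a tracer could treat the state values as ordinary inputs and outputs. brainstate performs this bookkeeping for you, and its real mechanism is more elaborate than this sketch.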

State Types

BrainPy uses different state types for different purposes:

ParamState (trainable parameters): used for weights, time constants, and other trainable parameters.

class MyNeuron(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.tau = brainstate.ParamState(10.0)  # Trainable
        self.weight = brainstate.ParamState(jnp.array([[0.1, 0.2]]))

ShortTermState (dynamical variables): used for membrane potentials, synaptic currents, and other quantities that evolve during a simulation.

class MyNeuron(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.V = brainstate.ShortTermState(jnp.zeros(size))  # Dynamic
        self.spike = brainstate.ShortTermState(jnp.zeros(size))

State Initialization

States can be initialized with various strategies:

# Define example size and shape
size = 100  # Number of neurons
shape = (100, 50)  # Weight matrix shape

# Constant initialization
V = brainstate.ShortTermState(
    braintools.init.Constant(-65.0, unit=u.mV)(size)
)

# Normal distribution
V = brainstate.ShortTermState(
    braintools.init.Normal(-65.0, 5.0, unit=u.mV)(size)
)

# Uniform distribution
weights = brainstate.ParamState(
    braintools.init.Uniform(0.0, 1.0)(shape)
)
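As a rough mental model, the three initializers above behave like the following NumPy calls with units stripped. This is an assumption for illustration, not braintools' implementation. In particular, reading Normal(-65.0, 5.0) as (mean, standard deviation) and Uniform as a half-open interval are both assumptions.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
size = 100          # number of neurons
shape = (100, 50)   # weight matrix shape

# ~ Constant(-65.0, unit=u.mV): every element equals -65 (mV, unit dropped)
V_const = np.full(size, -65.0)

# ~ Normal(-65.0, 5.0, unit=u.mV): assumed mean -65 and standard deviation 5
V_normal = rng.normal(loc=-65.0, scale=5.0, size=size)

# ~ Uniform(0.0, 1.0): values drawn from the half-open interval [0, 1)
weights = rng.uniform(low=0.0, high=1.0, size=shape)

print(V_const.shape, V_normal.shape, weights.shape)  # (100,) (100,) (100, 50)
```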

Module System

Base Class: brainstate.nn.Module

All BrainPy components inherit from brainstate.nn.Module:

class MyComponent(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        # Initialize states
        self.state1 = brainstate.ShortTermState(jnp.zeros(size))
        self.param1 = brainstate.ParamState(jnp.ones(size))

    def update(self, input):
        # Define dynamics
        pass

Benefits of Module:

  • Automatic state registration

  • Nested module support

  • State collection and filtering

  • Serialization support
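Automatic registration and filtering can be pictured with the toy sketch below. A module scans its own attributes, recurses into submodules, and returns the states of a requested type keyed by their attribute path. The classes ToyState, ToyParam, ToyShortTerm, and ToyModule are hypothetical stand-ins. brainstate.nn.Module uses its own machinery, exposed through calls such as network.states(brainstate.ParamState).

```python
class ToyState:
    """Hypothetical base container for a value."""

    def __init__(self, value):
        self.value = value


class ToyParam(ToyState):
    """Hypothetical stand-in for a trainable parameter (cf. ParamState)."""


class ToyShortTerm(ToyState):
    """Hypothetical stand-in for a dynamical variable (cf. ShortTermState)."""


class ToyModule:
    """Hypothetical module that discovers states by scanning its attributes."""

    def states(self, kind=ToyState, prefix=""):
        found = {}
        for name, attr in vars(self).items():
            path = prefix + name
            if isinstance(attr, kind):
                found[path] = attr             # a state of the requested type
            elif isinstance(attr, ToyModule):  # recurse into nested modules
                found.update(attr.states(kind, prefix=path + "."))
        return found


class Population(ToyModule):
    def __init__(self, size):
        self.V = ToyShortTerm([0.0] * size)
        self.tau = ToyParam(10.0)


class Net(ToyModule):
    def __init__(self):
        self.exc = Population(4)
        self.inh = Population(1)
        self.w = ToyParam([[0.1, 0.2]])


net = Net()
print(sorted(net.states(ToyParam)))      # ['exc.tau', 'inh.tau', 'w']
print(sorted(net.states(ToyShortTerm)))  # ['exc.V', 'inh.V']
```

Filtering by type lets an optimizer receive only the trainable parameters, while simulation code works with the dynamical variables.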

Module Composition

Modules can contain other modules:


class Network(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
        self.synapse = brainpy.state.Expon(100, tau=5*u.ms)
        self.projection = brainpy.state.AlignPostProj(...)  # Example - requires more setup

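    def update(self, input):
        # Route the neurons' spikes from the previous step through the projection
        spikes = self.neurons.get_spike()
        self.projection(spikes)  # Example
        # Advance the neurons with the external input
        self.neurons(input)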

Component Architecture

Neurons

Neurons model the dynamics of neural populations:

class Neuron(brainstate.nn.Module):
    def __init__(self, size, **kwargs):
        super().__init__()
        # Membrane potential
        self.V = brainstate.ShortTermState(jnp.zeros(size))
        # Spike output
        self.spike = brainstate.ShortTermState(jnp.zeros(size))

    def update(self, input_current):
        # Update membrane potential
        # Generate spikes
        pass

Key responsibilities:

  • Maintain membrane potential

  • Generate spikes when threshold is crossed

  • Reset after spiking

  • Integrate input currents
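The four responsibilities above fit in a few lines. The NumPy sketch below uses the standard leaky integrate-and-fire equation, tau dV/dt = -(V - V_rest) + R I, with forward-Euler integration. The helper name lif_step, its parameter defaults, and expressing R·I directly in mV are all assumptions for illustration. This is not brainpy.state.LIF's implementation.

```python
import numpy as np


def lif_step(V, I, dt=0.1, tau=10.0, V_rest=-65.0, V_th=-50.0, V_reset=-65.0, R=1.0):
    """One forward-Euler step of a leaky integrate-and-fire population.

    Hypothetical helper with assumed defaults; times in ms, voltages in mV,
    and R * I expressed directly in mV.
    """
    # Integrate input currents: tau * dV/dt = -(V - V_rest) + R * I
    V = V + dt / tau * (-(V - V_rest) + R * I)
    # Generate spikes where the threshold is crossed
    spike = V >= V_th
    # Reset the neurons that spiked
    V = np.where(spike, V_reset, V)
    return V, spike.astype(float)


V = np.full(3, -65.0)              # maintained membrane potential
I = np.array([0.0, 10.0, 20.0])    # constant drive; R * I in mV
n_spikes = np.zeros(3)
for _ in range(1000):              # 100 ms at dt = 0.1 ms
    V, spike = lif_step(V, I)
    n_spikes += spike
print(n_spikes)
```

With R·I = 10 mV the steady-state potential is -55 mV, below the -50 mV threshold, so the second neuron never fires. With 20 mV it relaxes toward -45 mV and fires repeatedly, resetting to -65 mV after each spike.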

Synapses

Synapses model temporal filtering of spike trains:

class Synapse(brainstate.nn.Module):
    def __init__(self, size, tau, **kwargs):
        super().__init__()
        # Synaptic conductance/current
        self.g = brainstate.ShortTermState(jnp.zeros(size))
        self.tau = tau

    def update(self, spike_input):
        # Update synaptic variable
        # Return filtered output
        pass

Key responsibilities:

  • Filter spike inputs temporally

  • Model synaptic dynamics (exponential, alpha, etc.)

  • Provide smooth currents to postsynaptic neurons
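Here is a minimal sketch of the temporal filtering an exponential synapse performs. It assumes a discretization that applies the exact exponential decay each step and adds every incoming spike as a unit jump. The function expon_step and its defaults are hypothetical, and brainpy.state.Expon may discretize differently.

```python
import math

import numpy as np


def expon_step(g, spikes, dt=0.1, tau=5.0):
    """Hypothetical exponential-synapse update (times in ms).

    Applies the exact decay of dg/dt = -g / tau over one step, then adds
    each incoming spike as a unit jump.
    """
    return g * math.exp(-dt / tau) + spikes


g = np.zeros(2)
g = expon_step(g, np.array([1.0, 0.0]))  # one spike arrives at synapse 0
for _ in range(50):                      # 50 steps * 0.1 ms = 5 ms = one tau
    g = expon_step(g, np.zeros(2))
print(g)  # synapse 0 has decayed to about exp(-1) ≈ 0.368
```

After one time constant the response to a single spike has fallen to 1/e of its peak. Summing such decaying traces is what turns a spike train into a smooth current.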

Projections: The Comm-Syn-Out Pattern

Projections connect populations using a three-stage architecture:

Presynaptic Spikes → [Comm] → [Syn] → [Out] → Postsynaptic Neurons
                      │         │       │
                  Connectivity  │    Current
                  & Weights   Dynamics  Injection

Communication (Comm): handles spike transmission, connectivity, and weights.

# Define population sizes
pre_size = 100
post_size = 50

# Connection probability and synaptic weight
prob = 0.1
weight = 0.5

comm = brainstate.nn.EventFixedProb(
    pre_size, post_size, prob, weight
)
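Conceptually, an event-driven fixed-probability connection samples a random connectivity pattern once. On every step it then does work only for presynaptic neurons that actually spiked. The NumPy sketch below illustrates this with a single shared scalar weight, as in the example above. The helper event_comm and the sampling scheme are assumptions, not brainstate.nn.EventFixedProb's implementation.

```python
import numpy as np

rng = np.random.default_rng(seed=42)
pre_size, post_size = 100, 50
prob, weight = 0.1, 0.5

# Fixed random connectivity, sampled once: each pre->post pair exists with probability `prob`
mask = rng.random((pre_size, post_size)) < prob
# For each presynaptic neuron, the indices of its postsynaptic targets
targets = [np.flatnonzero(row) for row in mask]


def event_comm(spikes):
    """Hypothetical event-driven propagation: only neurons that spiked do work."""
    out = np.zeros(post_size)
    for i in np.flatnonzero(spikes):
        out[targets[i]] += weight  # target indices are unique, so += is safe
    return out


spikes = (rng.random(pre_size) < 0.05).astype(float)
event_out = event_comm(spikes)
dense_out = spikes @ (mask * weight)  # same result via a dense matrix product
print(np.allclose(event_out, dense_out))  # True
```

Spikes are sparse, so the event loop's cost scales with the number of spikes times the fan-out rather than with the full pre × post matrix. It still gives the same result as the dense product.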

Synaptic Dynamics (Syn): temporal filtering of transmitted spikes.

post_size = 50  # Postsynaptic population size

syn = brainpy.state.Expon(post_size, tau=5*u.ms)

Output Mechanism (Out): how synaptic variables affect postsynaptic neurons.

# Current-based output
out = brainpy.state.CUBA()

# Or conductance-based output with reversal potential E
out = brainpy.state.COBA(E=0*u.mV)
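The two outputs differ in how the synaptic variable g becomes a current. Under the standard textbook definitions, assumed here, CUBA injects g directly. COBA instead scales g by the driving force E - V, so the same conductance pushes harder the farther V is from the reversal potential E. The helper functions are hypothetical; check the brainpy.state API reference for its exact conventions.

```python
import numpy as np


def cuba_output(g):
    """Current-based: the synaptic variable is injected directly as a current."""
    return g


def coba_output(g, V, E):
    """Conductance-based: I = g * (E - V), which drives V toward E."""
    return g * (E - V)


g = np.array([0.5, 0.5, 0.5])       # same synaptic conductance for all three
V = np.array([-70.0, -50.0, 0.0])   # membrane potentials (mV)
I_cuba = cuba_output(g)             # 0.5 for every neuron, independent of V
I_coba = coba_output(g, V, E=0.0)   # 35, 25 and 0: shrinks as V approaches E
print(I_cuba, I_coba)
```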

Complete Projection: the three stages are combined and attached to the postsynaptic population.

# Define postsynaptic neurons
postsynaptic_neurons = brainpy.state.LIF(50, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)

# Create complete projection
projection = brainpy.state.AlignPostProj(
    comm=comm,
    syn=syn,
    out=out,
    post=postsynaptic_neurons
)

This separation provides:

  • Clear responsibility boundaries

  • Easy component swapping

  • Reusable building blocks

  • Better testing and debugging

Compilation and Execution

Time-Stepped Simulation

Simulations advance in discrete time steps of size dt, set globally through brainstate.environ:

# Example: create a simple network
class SimpleNetwork(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.neurons = brainpy.state.LIF(100, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
    
    def update(self, t, i):
        # Generate constant input current
        inp = jnp.ones(100) * 5.0 * u.nA
        with brainstate.environ.context(t=t, i=i):
            self.neurons(inp)
            return self.neurons.get_spike()

network = SimpleNetwork()
brainstate.nn.init_all_states(network)

# Set global time step
brainstate.environ.set(dt=0.1 * u.ms)

# Define simulation duration
times = u.math.arange(0*u.ms, 1000*u.ms, brainstate.environ.get_dt())
indices = u.math.arange(times.size)

# Run simulation
results = brainstate.transform.for_loop(
    network.update,
    times,
    indices,
    pbar=brainstate.transform.ProgressBar(10)
)
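With dt = 0.1 ms and a 1000 ms duration, the simulation above runs 1000 / 0.1 = 10,000 steps. The plain-Python sketch below spells out the loop semantics the example relies on: the step function is called once per (t, i) pair, in order, and its outputs are collected. The helper toy_for_loop is hypothetical. brainstate.transform.for_loop additionally compiles the loop and stacks the per-step outputs into arrays.

```python
dt = 0.1           # ms
duration = 1000.0  # ms
n_steps = round(duration / dt)  # 1000 / 0.1 = 10000 steps


def toy_for_loop(step, *sequences):
    """Hypothetical sequential loop: call `step` on each tuple of elements
    and collect the outputs in order (no compilation)."""
    return [step(*args) for args in zip(*sequences)]


# Build times from integer indices to avoid floating-point drift
indices = list(range(n_steps))
times = [i * dt for i in indices]


def step(t, i):
    # Stand-in for network.update(t, i): flag every 1 ms boundary
    return i % 10 == 0


outputs = toy_for_loop(step, times, indices)
print(n_steps, len(outputs), sum(outputs))  # 10000 10000 1000
```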

JIT Compilation

Wrapping a function with brainstate.transform.jit compiles it with XLA on its first call:

@brainstate.transform.jit
def simulate_step(t, i):
    # network.update(t, i) sets the (t, i) environment context itself
    return network.update(t, i)

# First call: trace and compile (slow)
result = simulate_step(0.0*u.ms, 0)

# Later calls with the same argument shapes and dtypes reuse the compiled code (fast)
result = simulate_step(0.1*u.ms, 1)

Compilation benefits:

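  • Large speedups over step-by-step Python execution (the exact factor depends on model size and hardware)

  • Execution on whichever backend JAX is configured to use (CPU, GPU, or TPU)

  • Memory optimization by the XLA compiler

  • Fusion of operations into fewer kernels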

Gradient Computation

For training, gradients of a loss with respect to trainable states are computed with brainstate.transform.grad:

# Example: Define mock functions for demonstration
def compute_loss(predictions, targets):
    return jnp.mean((predictions.astype(float) - targets) ** 2)

# Mock targets
num_steps = 100
targets = jnp.zeros((num_steps, 100))

def loss_fn():
    # Run network for multiple timesteps
    def step(t, i):
        with brainstate.environ.context(t=t, i=i):
            return network.update(t, i)
    
    times = u.math.arange(0*u.ms, num_steps*brainstate.environ.get_dt(), brainstate.environ.get_dt())
    indices = u.math.arange(times.size)
    predictions = brainstate.transform.for_loop(step, times, indices)
    return compute_loss(predictions, targets)

# Get trainable parameters
params = network.states(brainstate.ParamState)

# Compute gradients
if len(params) > 0:
    optimizer = braintools.optim.Adam(lr=1e-3)
    # Tell the optimizer which states it should update
    optimizer.register_trainable_weights(params)
    grads, loss = brainstate.transform.grad(
        loss_fn,
        grad_states=params,
        return_value=True
    )()
    print(f"Loss: {loss}")
    # Apply one optimization step
    optimizer.update(grads)
else:
    # The LIF network above has no ParamState, so only the loss is evaluated
    loss = loss_fn()
    print(f"Loss (no trainable params): {loss}")

# Output: Loss (no trainable params): 0.0

Physical Units System

Integration with saiunit

brainpy.state integrates saiunit for scientific accuracy:

# Define with units
tau = 10 * u.ms
threshold = -50 * u.mV
current = 5 * u.nA

# Units are checked automatically
neuron = brainpy.state.LIF(100, tau=tau, V_th=threshold)

Benefits:

  • Prevents unit errors (e.g., ms vs s)

  • Self-documenting code

  • Automatic unit conversions

  • Dimensionally consistent model equations
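The core idea behind this kind of checking fits in a few lines. Each quantity carries a value in SI base units plus a tuple of dimension exponents. Addition requires matching dimensions, and division subtracts exponents. The Quantity class below is a hypothetical toy, far simpler than saiunit, which also tracks display scales such as ms and mV.

```python
class Quantity:
    """Hypothetical toy quantity: a value in SI base units plus dimension exponents.

    Dimensions are exponents of (metre, kilogram, second, ampere).
    Far simpler than saiunit; for illustration only.
    """

    def __init__(self, value, dim):
        self.value = value
        self.dim = dim

    def __add__(self, other):
        # Addition is only defined between quantities of the same dimension
        if self.dim != other.dim:
            raise TypeError(f"cannot add dimensions {self.dim} and {other.dim}")
        return Quantity(self.value + other.value, self.dim)

    def __truediv__(self, other):
        # Division subtracts dimension exponents
        dim = tuple(a - b for a, b in zip(self.dim, other.dim))
        return Quantity(self.value / other.value, dim)


TIME = (0, 0, 1, 0)     # s
VOLT = (2, 1, -3, -1)   # V = kg m^2 s^-3 A^-1
AMP = (0, 0, 0, 1)      # A

tau = Quantity(10e-3, TIME)        # 10 ms
V = Quantity(-65e-3, VOLT)         # -65 mV
I = Quantity(2e-9, AMP)            # 2 nA

total = tau + Quantity(0.5, TIME)  # allowed: both are times (0.51 s)
R = V / I                          # resistance: dimension (2, 1, -3, -2), i.e. ohm

try:
    tau + V                        # a time plus a voltage is rejected
    rejected = False
except TypeError as err:
    rejected = True
    print("unit error:", err)
```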

Unit Operations

# Arithmetic with units
total_time = 100 * u.ms + 0.5 * u.second  # → 600 ms

# Unit conversion
time_in_seconds = (100 * u.ms).to_decimal(u.second)  # → 0.1

# Unit checking (automatic in BrainPy operations)
voltage = -65 * u.mV
current = 2 * u.nA
resistance = voltage / current  # mV / nA has the dimension of resistance: -32.5 MΩ
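The values in the comments above follow from ordinary unit arithmetic:

```latex
\begin{aligned}
100\,\mathrm{ms} + 0.5\,\mathrm{s} &= 100\,\mathrm{ms} + 500\,\mathrm{ms} = 600\,\mathrm{ms} \\
100\,\mathrm{ms} &= 100 \times 10^{-3}\,\mathrm{s} = 0.1\,\mathrm{s} \\
R = \frac{V}{I} = \frac{-65\,\mathrm{mV}}{2\,\mathrm{nA}}
  &= \frac{-65 \times 10^{-3}\,\mathrm{V}}{2 \times 10^{-9}\,\mathrm{A}}
   = -3.25 \times 10^{7}\,\Omega = -32.5\,\mathrm{M}\Omega
\end{aligned}
```

The sign is negative only because the example divides a membrane potential, not a voltage drop, by a current. The point is the inferred dimension, not a physical resistance.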

Ecosystem Integration

brainpy.state integrates tightly with its ecosystem:

braintools

Utilities and tools:

# Optimizers
optimizer = braintools.optim.Adam(lr=1e-3)

# Initializers
init = braintools.init.KaimingNormal()

# Surrogate gradients
spike_fn = braintools.surrogate.ReluGrad()

# Metrics (example with dummy data)
# pred = jnp.array([0.1, 0.9])
# target = jnp.array([0, 1])
# loss = braintools.metric.cross_entropy(pred, target)
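Spike generation is a step function whose true derivative is zero almost everywhere. Gradient-based training therefore replaces it with a smooth pseudo-derivative in the backward pass. The NumPy sketch below shows the two halves separately. The triangular (ReLU-shaped) form and the alpha and width defaults are assumptions suggested by the name ReluGrad, not a statement of braintools' exact formula. In an autodiff framework the two halves are joined by a custom gradient rule (in JAX, for example, jax.custom_vjp).

```python
import numpy as np


def spike_forward(x):
    """Forward pass: Heaviside step, 1 where x >= 0 (e.g. x = V - V_th)."""
    return (x >= 0).astype(float)


def relu_surrogate_grad(x, alpha=0.3, width=1.0):
    """Assumed ReLU-shaped pseudo-derivative used in the backward pass:
    max(0, alpha * (width - |x|)), a triangle centred on the threshold."""
    return np.maximum(0.0, alpha * (width - np.abs(x)))


x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(spike_forward(x))        # 0, 0, 1, 1, 1
print(relu_surrogate_grad(x))  # 0, 0.15, 0.3, 0.15, 0
```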

saiunit

Physical units:

# All standard SI units
time = 10 * u.ms
voltage = -65 * u.mV
current = 2 * u.nA

brainstate

Core framework (used automatically):

import brainstate

# Module system
class Net(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        pass

# Compilation
@brainstate.transform.jit
def fn():
    return 0

# Transformations
# result = brainstate.transform.for_loop(...)