Projections#
Projections are brainpy.state's mechanism for connecting neural populations.
They implement the Communication-Synapse-Output (Comm-Syn-Out) architecture,
which separates connectivity, synaptic dynamics, and output computation into modular components.
This guide provides a comprehensive understanding of projections in brainpy.state.
Overview#
What are Projections?#
A projection connects a presynaptic population to a postsynaptic population through:
Communication (Comm): How spikes propagate through connections
Synapse (Syn): Temporal filtering and synaptic dynamics
Output (Out): How synaptic currents affect postsynaptic neurons
Key benefits:
Modular design (swap components independently)
Biologically realistic (separate connectivity and dynamics)
Efficient (optimized sparse operations)
Flexible (combine components in different ways)
The Comm-Syn-Out Architecture#
import brainstate
import braintools
import saiunit as u
import numpy as np
import brainpy.state
brainstate.environ.set(dt=0.1 * u.ms)
Presynaptic      Communication      Synapse         Output        Postsynaptic
Population  ──►  (Connectivity) ──► (Dynamics)  ──► (Current) ──► Population

Spikes      ──►  Weight matrix  ──► g(t)        ──► I_syn     ──► Neurons
                 Sparse/Dense       Expon/Alpha     CUBA/COBA
Flow:
Presynaptic spikes arrive
Communication: Spikes propagate through connectivity matrix
Synapse: Temporal dynamics filter the signal
Output: Convert to current/conductance
Postsynaptic neurons receive input
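This five-step flow can be sketched for a single timestep in plain NumPy (a conceptual illustration only; `W`, `tau`, and the 5% spike rate are ad hoc choices here, not the brainpy.state API):

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post = 100, 50
dt, tau = 0.1, 5.0                          # ms

# 1-2. Communication: presynaptic spikes propagate through a weight matrix
W = rng.uniform(0.0, 0.5, size=(n_post, n_pre))
spikes = (rng.random(n_pre) < 0.05).astype(float)

# 3. Synapse: exponential filtering of the weighted spike input
g = np.zeros(n_post)
g = g * np.exp(-dt / tau) + W @ spikes      # decay, then accumulate

# 4-5. Output (current-based): g is delivered to the neurons as I_syn
I_syn = g
assert I_syn.shape == (n_post,)
```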
Types of Projections#
BrainPy provides two main projection types:
AlignPostProj
Align synaptic states with postsynaptic neurons
Most common for standard neural networks
Efficient memory layout
AlignPreProj
Align synaptic states with presynaptic neurons
Useful for certain learning rules
Different memory organization
For most use cases, use AlignPostProj.
Communication Layer#
The Communication layer defines how spikes propagate through connections.
Dense Connectivity#
Every presynaptic neuron is potentially connected to every postsynaptic neuron (though individual weights may be zero).
Use case: Small networks, fully connected layers
# Dense linear transformation
comm = brainstate.nn.Linear(
100, # in_size
50, # out_size
w_init=braintools.init.KaimingNormal(),
b_init=None # No bias for synapses
)
Characteristics:
Memory: O(n_pre × n_post)
Computation: Full matrix multiplication
Best for: Small networks, fully connected architectures
Sparse Connectivity#
Only a subset of connections exist (biologically realistic).
Use case: Large networks, biological connectivity patterns
Event-Based Fixed Probability#
Connect neurons with fixed probability.
# Sparse random connectivity (2% connection probability)
comm = brainstate.nn.EventFixedProb(
1000, # pre_size
800, # post_size
conn_num=0.02, # 2% connectivity
conn_weight=0.5 # Synaptic weight (unitless for event-based)
)
Characteristics:
Memory: O(n_pre × n_post × prob)
Computation: Only active connections
Best for: Large-scale networks, biological models
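The memory scaling above can be made concrete with quick arithmetic (illustrative numbers only; real sparse formats also store connection indices, which adds overhead):

```python
n_pre, n_post, prob = 1000, 800, 0.02
bytes_per_weight = 4                                          # float32

dense_bytes = n_pre * n_post * bytes_per_weight               # full matrix
sparse_bytes = int(n_pre * n_post * prob) * bytes_per_weight  # stored weights only

print(dense_bytes // 1024, "KiB dense")    # 3125 KiB
print(sparse_bytes // 1024, "KiB sparse")  # 62 KiB
```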
Event-Based All-to-All#
All neurons connected (but stored sparsely).
# All-to-all sparse (event-driven)
comm = brainstate.nn.AllToAll(
100, # pre_size
100, # post_size
0.3 # Unitless weight
)
Event-Based One-to-One#
One-to-one mapping (same size populations).
size = 100
weight = 1.0
# One-to-one connections
comm = brainstate.nn.OneToOne(
size,
weight # Unitless weight
)
Use case: Feedforward pathways, identity mappings
Synapse Layer#
The Synapse layer defines temporal dynamics of synaptic transmission.
Exponential Synapse#
Single exponential decay (most common).
Dynamics: τ · dg/dt = −g, with g → g + w on each presynaptic spike (single-exponential decay).
Implementation:
# Exponential synapse with 5ms time constant
syn = brainpy.state.Expon(
in_size=100, # Postsynaptic population size
tau=5.0 * u.ms # Decay time constant
)
Characteristics:
Single time constant
Fast computation
Good for most applications
When to use: Default choice for most models
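The single time constant is easy to see in isolation. A NumPy sketch of the exponential update, using the exact per-step decay factor exp(−dt/τ) (independent of brainpy.state):

```python
import numpy as np

dt, tau = 0.1, 5.0          # ms, matching the example above
g = 1.0                     # conductance immediately after a spike
decay = np.exp(-dt / tau)   # exact per-step decay factor

for _ in range(50):         # integrate for one time constant (5 ms)
    g *= decay

# After exactly one time constant the trace has decayed to 1/e
assert abs(g - np.exp(-1.0)) < 1e-9
```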
Alpha Synapse#
Dual exponential with rise and decay.
Dynamics: τ · dh/dt = −h and τ · dg/dt = −g + h, giving the kernel g(t) ∝ (t/τ) · e^(1−t/τ), which rises smoothly and peaks at t = τ.
Implementation:
# Alpha synapse
syn = brainpy.state.Alpha(
in_size=100,
tau=10.0 * u.ms # Characteristic time
)
Characteristics:
Realistic rise time
Smoother response
Slightly slower computation
When to use: When rise time matters, more biological realism
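Written in closed form, the alpha kernel is g(t) = (t/τ) · exp(1 − t/τ), normalized so its peak value is 1; a quick NumPy check confirms the peak sits at t = τ (a conceptual illustration, not the brainpy.state implementation):

```python
import numpy as np

tau = 10.0                                # ms, as in the example above
t = np.linspace(0.0, 100.0, 10001)        # 0.01 ms resolution
g = (t / tau) * np.exp(1.0 - t / tau)     # alpha kernel, peak value 1

t_peak = t[np.argmax(g)]
assert abs(t_peak - tau) < 0.05           # maximum at t = tau
assert abs(g.max() - 1.0) < 1e-6
```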
NMDA Synapse#
Voltage-dependent NMDA receptors.
Dynamics: slow transmitter-binding kinetics, gated multiplicatively by the voltage-dependent Mg²⁺ block described under MgBlock below.
Implementation:
# NMDA receptor
syn = brainpy.state.BioNMDA(
in_size=100,
T_dur=100.0 * u.ms, # Slow decay
T=2.0 * u.ms, # Fast rise
alpha1=0.5 / u.mM, # Mg²⁺ sensitivity
g_initializer=1.2 * u.mM # Initial gating value
)
Characteristics:
Voltage-dependent
Slow kinetics
Important for plasticity
When to use: Long-term potentiation, working memory models
AMPA Synapse#
Fast glutamatergic transmission.
# AMPA receptor (fast excitation)
syn = brainpy.state.AMPA(
in_size=100,
beta=0.5 / u.ms, # Fast decay (~2ms)
)
When to use: Fast excitatory transmission
GABA Synapse#
Inhibitory transmission.
GABAa (fast):
# GABAa receptor (fast inhibition)
syn = brainpy.state.GABAa(
in_size=100,
beta=0.16 / u.ms, # ~6ms decay
)
GABAb (slow):
# GABAb receptor (slow inhibition)
syn = brainpy.state.GABAb(
in_size=100,
T_dur=150.0 * u.ms, # Very slow
T=3.5 * u.ms
)
When to use:
GABAa: Fast inhibition, cortical networks
GABAb: Slow inhibition, rhythm generation
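In these kinetic models β is a closing rate, so the decay time constant is roughly τ ≈ 1/β — which is where the "~2 ms" (AMPA) and "~6 ms" (GABAa) comments above come from:

```python
beta_ampa = 0.5      # 1/ms, AMPA closing rate
beta_gabaa = 0.16    # 1/ms, GABAa closing rate

tau_ampa = 1.0 / beta_ampa      # 2.0 ms
tau_gabaa = 1.0 / beta_gabaa    # 6.25 ms

assert tau_ampa == 2.0
assert abs(tau_gabaa - 6.25) < 1e-9
```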
Custom Synapses#
Create custom synaptic dynamics by subclassing Synapse.
import jax.numpy as jnp

class DoubleExpSynapse(brainpy.state.Synapse):
"""Custom synapse with two time constants."""
def __init__(self, size, tau_fast=2 * u.ms, tau_slow=10 * u.ms, **kwargs):
super().__init__(size, **kwargs)
self.tau_fast = tau_fast
self.tau_slow = tau_slow
# State variables
self.g_fast = brainstate.ShortTermState(jnp.zeros(size))
self.g_slow = brainstate.ShortTermState(jnp.zeros(size))
def reset_state(self, batch_size=None):
shape = self.varshape if batch_size is None else (batch_size, *self.varshape)
self.g_fast.value = jnp.zeros(shape)
self.g_slow.value = jnp.zeros(shape)
def update(self, x):
dt = brainstate.environ.get_dt()
# Fast component
dg_fast = -self.g_fast.value / self.tau_fast.to_decimal(u.ms)
self.g_fast.value += dg_fast * dt.to_decimal(u.ms) + x * 0.7
# Slow component
dg_slow = -self.g_slow.value / self.tau_slow.to_decimal(u.ms)
self.g_slow.value += dg_slow * dt.to_decimal(u.ms) + x * 0.3
return self.g_fast.value + self.g_slow.value
Output Layer#
The Output layer defines how synaptic conductance affects neurons.
CUBA (Current-Based)#
Synaptic conductance directly becomes current.
Model: I_syn = g(t) — the conductance trace is injected directly as current.
Implementation:
# Current-based output: no parameters needed
out = brainpy.state.CUBA()
Characteristics:
Simple and fast
No voltage dependence
Good for rate-based models
When to use:
Abstract models
When voltage dependence not important
Faster computation needed
COBA (Conductance-Based)#
Synaptic conductance with reversal potential.
Model: I_syn = g(t) · (E − V), where E is the reversal potential and V the membrane potential.
Implementation:
# Excitatory conductance-based
out_exc = brainpy.state.COBA(E=0.0 * u.mV)
# Inhibitory conductance-based
out_inh = brainpy.state.COBA(E=-80.0 * u.mV)
Characteristics:
Voltage-dependent
Biologically realistic
Self-limiting (saturates near reversal)
When to use:
Biologically detailed models
When voltage dependence matters
Shunting inhibition needed
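The contrast between CUBA and COBA is just the driving-force term. A NumPy sketch (units stripped; illustrative values only):

```python
import numpy as np

g = 0.5                                    # synaptic conductance (arbitrary units)
E_exc = 0.0                                # excitatory reversal potential (mV)
V = np.array([-80.0, -65.0, -20.0, 0.0])   # membrane voltages (mV)

I_cuba = np.full_like(V, g)                # CUBA: I = g, independent of V
I_coba = g * (E_exc - V)                   # COBA: I = g * (E - V)

assert np.all(I_cuba == g)                 # same current at every voltage
assert I_coba[-1] == 0.0                   # COBA current vanishes at E (self-limiting)
```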
MgBlock (NMDA)#
Voltage-dependent magnesium block for NMDA.
# NMDA with Mg²⁺ block
out_nmda = brainpy.state.MgBlock(
E=0.0 * u.mV,
cc_Mg=1.2 * u.mM,
alpha=0.062 / u.mV,
beta=3.57
)
When to use: NMDA receptors, voltage-dependent plasticity
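The voltage dependence can be sketched with the standard Jahr-Stevens unblock factor B(V) = 1 / (1 + ([Mg²⁺]/β) · exp(−αV)), using the parameter values from the snippet above (units stripped; a conceptual illustration, not the MgBlock source):

```python
import numpy as np

cc_Mg = 1.2     # mM, extracellular Mg2+ concentration
alpha = 0.062   # 1/mV, voltage sensitivity
beta = 3.57     # mM, concentration scale

def mg_unblock(V):
    """Fraction of NMDA channels free of Mg2+ block at voltage V (mV)."""
    return 1.0 / (1.0 + (cc_Mg / beta) * np.exp(-alpha * V))

assert mg_unblock(-65.0) < 0.1   # strong block near rest
assert mg_unblock(0.0) > 0.7     # block relieved by depolarization
```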
Complete Projection Examples#
Example 1: Simple Feedforward#
# Create populations
pre = brainpy.state.LIF(100, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
post = brainpy.state.LIF(50, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
# Create projection: 100 → 50 neurons
proj = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(
100, # pre_size
50, # post_size
conn_num=0.1, # 10% connectivity
conn_weight=0.5 * u.mS # Weight
),
syn=brainpy.state.Expon(
in_size=50, # Postsynaptic size
tau=5.0 * u.ms
),
out=brainpy.state.CUBA(),
post=post # Postsynaptic population
)
# Initialize
brainstate.nn.init_all_states([pre, post, proj])
# Simulate
def step(t, i, inp):
with brainstate.environ.context(t=t, i=i):
# Update neurons
pre(inp)
# Get presynaptic spikes
pre_spikes = pre.get_spike()
# Update projection
proj(pre_spikes)
post(0.0 * u.nA) # Projection provides input
return pre.get_spike(), post.get_spike()
indices = np.arange(1000)
times = indices * brainstate.environ.get_dt()
inputs = brainstate.random.uniform(30., 50., indices.shape) * u.nA
_ = brainstate.transform.for_loop(step, times, indices, inputs)
Example 2: Excitatory-Inhibitory Network#
class EINetwork(brainstate.nn.Module):
def __init__(self, n_exc=800, n_inh=200):
super().__init__()
# Populations
self.E = brainpy.state.LIF(n_exc, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=15 * u.ms)
self.I = brainpy.state.LIF(n_inh, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
# E → E projection (AMPA, excitatory)
self.E2E = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_exc, n_exc, conn_num=0.02, conn_weight=0.6 * u.mS),
syn=brainpy.state.Expon(n_exc, tau=2. * u.ms),
out=brainpy.state.COBA(E=0.0 * u.mV),
post=self.E
)
# E → I projection (AMPA, excitatory)
self.E2I = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_exc, n_inh, conn_num=0.02, conn_weight=0.6 * u.mS),
syn=brainpy.state.Expon(n_inh, tau=2. * u.ms),
out=brainpy.state.COBA(E=0.0 * u.mV),
post=self.I
)
# I → E projection (GABAa, inhibitory)
self.I2E = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_inh, n_exc, conn_num=0.02, conn_weight=6.7 * u.mS),
syn=brainpy.state.Expon(n_exc, tau=6. * u.ms),
out=brainpy.state.COBA(E=-80.0 * u.mV),
post=self.E
)
# I → I projection (GABAa, inhibitory)
self.I2I = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_inh, n_inh, conn_num=0.02, conn_weight=6.7 * u.mS),
syn=brainpy.state.Expon(n_inh, tau=6. * u.ms),
out=brainpy.state.COBA(E=-80.0 * u.mV),
post=self.I
)
def update(self, i, inp_e, inp_i):
t = brainstate.environ.get_dt() * i
with brainstate.environ.context(t=t, i=i):
# Get spikes BEFORE updating neurons
spk_e = self.E.get_spike()
spk_i = self.I.get_spike()
# Update all projections
self.E2E(spk_e)
self.E2I(spk_e)
self.I2E(spk_i)
self.I2I(spk_i)
# Update neurons (projections provide synaptic input)
self.E(inp_e)
self.I(inp_i)
return spk_e, spk_i
net = EINetwork()
brainstate.nn.init_all_states(net)
_ = brainstate.transform.for_loop(net.update, indices, inputs, inputs)
Example 3: Multi-Timescale Synapses#
Combine AMPA (fast) and NMDA (slow) for realistic excitation.
class DualExcitatory(brainstate.nn.Module):
"""E → E with both AMPA and NMDA."""
def __init__(self, n_pre=100, n_post=100):
super().__init__()
self.post = brainpy.state.LIF(n_post, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
# Fast AMPA component
self.ampa_proj = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
syn=brainpy.state.AMPA(n_post, tau=2.0 * u.ms),
out=brainpy.state.COBA(E=0.0 * u.mV),
post=self.post
)
# Slow NMDA component
self.nmda_proj = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
syn=brainpy.state.NMDA(n_post, tau_decay=100.0 * u.ms, tau_rise=2.0 * u.ms),
out=brainpy.state.MgBlock(E=0.0 * u.mV, cc_Mg=1.2 * u.mM),
post=self.post
)
def update(self, t, i, pre_spikes):
with brainstate.environ.context(t=t, i=i):
# Both projections share same presynaptic spikes
self.ampa_proj(pre_spikes)
self.nmda_proj(pre_spikes)
# Post receives combined input
self.post(0.0 * u.nA)
return self.post.get_spike()
Example 4: Delay Projections#
Add synaptic delays to projections.
# To implement delay, use a separate Delay module
delay_time = 5.0 * u.ms
# Create a network with delay
class DelayedProjection(brainstate.nn.Module):
def __init__(self, pre_size, post_size):
super().__init__()
self.post = brainpy.state.LIF(post_size, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
self.delay = self.post.output_delay(delay_time)
# Standard projection
self.proj = brainpy.state.AlignPostProj(
comm=brainstate.nn.EventFixedProb(pre_size, post_size, conn_num=0.1, conn_weight=0.5 * u.mS),
syn=brainpy.state.Expon(post_size, tau=5.0 * u.ms),
out=brainpy.state.CUBA(),
post=self.post
)
def update(self, inp=0. * u.nA):
# Retrieve spikes emitted delay_time ago (delayed recurrent input)
delayed_spikes = self.delay()
# Update projection with delayed spikes
self.proj(delayed_spikes)
self.post(inp)
# Store current spikes in delay buffer
self.delay(self.post.get_spike())
def step_run(self, i, inp):
t = brainstate.environ.get_dt() * i
with brainstate.environ.context(t=t, i=i):
# Update post neurons
self.update(inp)
return self.post.get_spike()
net = DelayedProjection(100, 100)
brainstate.nn.init_all_states(net)
_ = brainstate.transform.for_loop(net.step_run, indices, inputs)
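Under the hood, a spike delay behaves like a ring buffer holding the last `delay_steps` spike vectors: each step reads the oldest slot and then overwrites it. A NumPy sketch of the mechanism (conceptual only, not the brainstate delay implementation):

```python
import numpy as np

size = 100
dt, delay_time = 0.1, 5.0                   # ms
delay_steps = int(round(delay_time / dt))   # 50 steps

buffer = np.zeros((delay_steps, size))
head = 0

def delay_read():
    # The slot about to be overwritten holds the oldest spikes
    return buffer[head]

def delay_write(spikes):
    global head
    buffer[head] = spikes
    head = (head + 1) % delay_steps

# A spike written now reappears exactly delay_steps writes later
spike = np.zeros(size)
spike[0] = 1.0
delay_write(spike)
for _ in range(delay_steps - 1):
    delay_write(np.zeros(size))
assert delay_read()[0] == 1.0
```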