Projections#

Projections are the mechanism brainpy.state uses to connect neural populations. They implement the Communication-Synapse-Output (Comm-Syn-Out) architecture, which separates connectivity, synaptic dynamics, and output computation into modular components.

This guide explains each of these components, the options available for each, and how to combine them into complete projections and networks.

Overview#

What are Projections?#

A projection connects a presynaptic population to a postsynaptic population through:

  1. Communication (Comm): How spikes propagate through connections

  2. Synapse (Syn): Temporal filtering and synaptic dynamics

  3. Output (Out): How the synaptic signal is converted into current for the postsynaptic neurons

Key benefits:

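  • Modular design: each component can be swapped without changing the others

  • Biologically grounded: connectivity and synaptic dynamics are modeled separately, as they are in real circuits

  • Efficient: event-driven sparse operations only process connections from neurons that spiked

  • Flexible: the same connectivity can be combined with different synapse and output models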

The Comm-Syn-Out Architecture#

import brainstate
import braintools
import saiunit as u
import numpy as np

import brainpy.state
brainstate.environ.set(dt=0.1 * u.ms)
Presynaptic        Communication         Synapse          Output        Postsynaptic
Population    ──►  (Connectivity)  ──►  (Dynamics)  ──►  (Current) ──►  Population

Spikes        ──►  Weight matrix   ──►  g(t)        ──►  I_syn     ──►  Neurons
                   Sparse/Dense         Expon/Alpha     CUBA/COBA

Flow:

  1. Presynaptic spikes arrive

  2. Communication: Spikes propagate through the connectivity (weight) matrix

  3. Synapse: Temporal dynamics filter the signal

  4. Output: The synaptic conductance is converted into a postsynaptic current

  5. Postsynaptic neurons receive input
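To make the three stages concrete, here is a framework-free NumPy sketch of a single time step, assuming a dense weight matrix, an exponential synapse, and a conductance-based output. All sizes and parameter values are hypothetical, and projection_step is an illustrative helper, not part of brainpy.state.

```python
import numpy as np

# Hypothetical sizes and parameters, for illustration only
n_pre, n_post = 4, 3
dt, tau = 0.1, 5.0   # time step and synaptic time constant (ms)
E_syn = 0.0          # excitatory reversal potential (mV)

rng = np.random.default_rng(0)
W = rng.uniform(0.0, 1.0, size=(n_pre, n_post))  # Comm: dense weight matrix
g = np.zeros(n_post)                             # Syn: one conductance per postsynaptic neuron
V = np.full(n_post, -65.0)                       # postsynaptic membrane potential (mV)


def projection_step(spikes, g, V):
    """Advance one time step through Comm -> Syn -> Out."""
    drive = spikes @ W                 # Comm: route spikes through the weights
    g = g * np.exp(-dt / tau) + drive  # Syn: exponential decay plus new input
    I_syn = g * (E_syn - V)            # Out: conductance-based current
    return g, I_syn


spikes = np.array([1.0, 0.0, 1.0, 0.0])  # presynaptic neurons 0 and 2 fire
g, I_syn = projection_step(spikes, g, V)
```

Note that the output stage needs the postsynaptic membrane potential V, not only the synaptic state.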

Types of Projections#

BrainPy provides two main projection types:

AlignPostProj

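  • Synaptic state variables are stored per postsynaptic neuron

  • Most common for standard neural networks

  • Memory for synaptic states scales with the postsynaptic population size rather than with the number of synapses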

AlignPreProj

  • Synaptic state variables are stored per presynaptic neuron

  • Useful for certain learning rules

  • Memory for synaptic states scales with the presynaptic population size

For most use cases, use AlignPostProj.
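Keeping a single synaptic state per postsynaptic neuron works because, for a linear synapse such as the exponential synapse, filtering every synapse separately and then summing gives the same result as summing the inputs first and filtering once. The NumPy sketch below checks this numerically on a hypothetical random network. It illustrates the principle only and is not brainpy.state's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pre, n_post, n_steps = 20, 5, 200
decay = np.exp(-0.1 / 5.0)  # per-step decay for dt = 0.1 ms, tau = 5 ms

W = rng.uniform(0.0, 1.0, size=(n_pre, n_post))
spikes = (rng.random((n_steps, n_pre)) < 0.05).astype(float)

# Per-synapse traces: one state for every (pre, post) pair, summed at the end
g_syn = np.zeros((n_pre, n_post))
# Post-aligned trace: one state per postsynaptic neuron
g_post = np.zeros(n_post)

for t in range(n_steps):
    g_syn = g_syn * decay + spikes[t][:, None] * W
    g_post = g_post * decay + spikes[t] @ W

max_abs_diff = np.abs(g_syn.sum(axis=0) - g_post).max()
```

The equivalence relies on linearity. A saturating kinetic synapse, whose update contains terms such as (1 - g), does not satisfy it in general.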

Communication Layer#

The Communication layer defines how spikes propagate through connections.

Dense Connectivity#

Every presynaptic neuron is connected to every postsynaptic neuron through a full weight matrix; absent connections are represented by zero weights.

Use case: Small networks, fully connected layers

# Dense linear transformation
comm = brainstate.nn.Linear(
    100,  # in_size
    50,  # out_size
    w_init=braintools.init.KaimingNormal(),
    b_init=None  # No bias for synapses
)

Characteristics:

  • Memory: O(n_pre × n_post)

  • Computation: Full matrix multiplication

  • Best for: Small networks, fully connected architectures
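On a binary spike vector, a dense layer without bias computes spikes @ W. The NumPy sketch below uses the sizes from the example above with illustrative random weights, and shows that the result is simply the sum of the weight rows of the neurons that fired. Most multiply-adds are therefore spent on silent neurons.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pre, n_post = 100, 50

# Kaiming (He) normal scale: std = sqrt(2 / fan_in)
W = rng.normal(0.0, np.sqrt(2.0 / n_pre), size=(n_pre, n_post))

spikes = np.zeros(n_pre)
spikes[[3, 17, 42]] = 1.0  # three presynaptic neurons fire this step

# Dense communication: a full (n_pre x n_post) matrix-vector product,
# even though only three rows of W contribute
post_input = spikes @ W
```

That observation is what event-driven sparse operators exploit.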

Sparse Connectivity#

Only a subset of the possible connections exists, as in biological circuits.

Use case: Large networks, biological connectivity patterns

Event-Based Fixed Probability#

Connect neurons with fixed probability.

# Sparse random connectivity (2% connection probability)
comm = brainstate.nn.EventFixedProb(
    1000,  # pre_size
    800,  # post_size
    conn_num=0.02,  # 2% connectivity
    conn_weight=0.5  # Synaptic weight; add units (e.g. u.mS) when a conductance is expected
)

Characteristics:

  • Memory: O(n_pre × n_post × prob)

  • Computation: Only active connections

  • Best for: Large-scale networks, biological models
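The idea behind event-driven fixed-probability connectivity can be sketched in plain NumPy. Each presynaptic neuron stores the indices of int(prob × n_post) distinct targets, and on each step only neurons that spiked add their weight to their targets. The sampling scheme and the helper names build_fixed_prob and event_driven_input are assumptions for illustration; EventFixedProb's internal layout may differ.

```python
import numpy as np


def build_fixed_prob(n_pre, n_post, prob, seed=0):
    """For each presynaptic neuron, draw int(prob * n_post) distinct targets."""
    rng = np.random.default_rng(seed)
    k = int(prob * n_post)
    return np.stack([rng.choice(n_post, size=k, replace=False) for _ in range(n_pre)])


def event_driven_input(spikes, targets, weight, n_post):
    """Add `weight` to the targets of each presynaptic neuron that spiked."""
    post_input = np.zeros(n_post)
    for i in np.flatnonzero(spikes):      # loop only over active neurons
        post_input[targets[i]] += weight  # targets within a row are distinct
    return post_input


n_pre, n_post, prob, weight = 1000, 800, 0.02, 0.5
targets = build_fixed_prob(n_pre, n_post, prob)  # shape (1000, 16)

spikes = np.zeros(n_pre)
spikes[[0, 5, 9]] = 1.0
post_input = event_driven_input(spikes, targets, weight, n_post)
```

Here the connectivity takes 1000 × 16 indices instead of 1000 × 800 weights, and the work per step scales with the number of spikes.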

Event-Based All-to-All#

Every presynaptic neuron connects to every postsynaptic neuron; a single scalar weight can be shared by all connections.

# All-to-all connectivity with one shared weight
comm = brainstate.nn.AllToAll(
    100,  # pre_size
    100,  # post_size
    0.3  # Scalar weight shared by every connection
)
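With one shared scalar weight w, all-to-all connectivity does not need a weight matrix at all: every postsynaptic neuron receives w times the number of presynaptic spikes. The NumPy check below confirms the equivalence against an explicit matrix. It illustrates the arithmetic only, not AllToAll's internals.

```python
import numpy as np

n_pre, n_post, w = 100, 100, 0.3
rng = np.random.default_rng(0)
spikes = (rng.random(n_pre) < 0.1).astype(float)

# Explicit all-to-all: a full matrix filled with the same weight
dense_input = spikes @ np.full((n_pre, n_post), w)

# Scalar shortcut: every postsynaptic neuron receives w * (number of spikes)
scalar_input = np.full(n_post, w * spikes.sum())
```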

Event-Based One-to-One#

Neuron i of the presynaptic population connects only to neuron i of the postsynaptic population, so both populations must have the same size.

size = 100
weight = 1.0

# One-to-one connections
comm = brainstate.nn.OneToOne(
    size,
    weight  # Unitless weight
)

Use case: Feedforward pathways, identity mappings

Synapse Layer#

The Synapse layer defines the temporal dynamics of synaptic transmission.

Exponential Synapse#

Single exponential decay (most common).

Dynamics:

\[ \tau \frac{dg}{dt} = -g + \sum_k \delta(t - t_k) \]

Implementation:

# Exponential synapse with 5ms time constant
syn = brainpy.state.Expon(
    in_size=100,  # Postsynaptic population size
    tau=5.0 * u.ms  # Decay time constant
)

Characteristics:

  • Single time constant

  • Fast computation (one state variable per neuron)

  • Good for most applications

When to use: Default choice for most models
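Between spikes the equation above gives pure exponential decay, so a simulator can advance g exactly by multiplying it by exp(-dt/τ) each step. The NumPy sketch below uses that exponential-Euler update with dt = 0.1 ms and τ = 5 ms as in the example above; the helper simulate_expon and the jump size w are hypothetical, chosen for illustration.

```python
import numpy as np


def simulate_expon(spike_steps, n_steps, dt=0.1, tau=5.0, w=1.0):
    """Exponential synapse: exact decay over each step, then a jump of w per spike."""
    decay = np.exp(-dt / tau)
    g = 0.0
    trace = np.zeros(n_steps)
    for t in range(n_steps):
        g = g * decay + (w if t in spike_steps else 0.0)
        trace[t] = g
    return trace


trace = simulate_expon(spike_steps={0}, n_steps=200)
# tau = 5 ms corresponds to 50 steps of 0.1 ms after the spike
ratio = trace[50] / trace[0]
```

After one time constant the conductance has fallen to e⁻¹ ≈ 37% of its peak.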

Alpha Synapse#

Alpha function: a single time constant τ sets both the rise and the decay, and the conductance peaks τ after each spike.

Dynamics:

\[\begin{split} \tau \frac{dg}{dt} = -g + h \\ \tau \frac{dh}{dt} = -h + \sum_k \delta(t - t_k) \end{split}\]

Implementation:

# Alpha synapse
syn = brainpy.state.Alpha(
    in_size=100,
    tau=10.0 * u.ms  # Sets both rise and decay; peak at t = tau
)

Characteristics:

  • Finite rise time instead of an instantaneous jump

  • Smoother response

  • Two state variables (g and h), so slightly more work per step

When to use: When rise time matters, more biological realism
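Integrating the two alpha equations above after a single spike shows the characteristic rise and fall. With h starting at 1, h decays as exp(-t/τ) and g follows (t/τ)·exp(-t/τ), which peaks at t = τ with value e⁻¹. The forward-Euler sketch below checks this for τ = 10 ms; the starting value h = 1 and the step size are illustrative choices.

```python
import numpy as np


def alpha_response(tau=10.0, dt=0.01, t_max=60.0):
    """Forward Euler for tau*dg/dt = -g + h, tau*dh/dt = -h after one spike."""
    n = int(round(t_max / dt))
    t = np.arange(n) * dt
    g = np.zeros(n)
    g_now, h_now = 0.0, 1.0  # the spike has just set h to 1 (illustrative jump size)
    for k in range(n):
        g[k] = g_now
        dg = (-g_now + h_now) / tau
        dh = -h_now / tau
        g_now += dt * dg
        h_now += dt * dh
    return t, g


t, g = alpha_response()
t_peak = t[np.argmax(g)]
```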

NMDA Synapse#

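Slow glutamatergic transmission whose current depends on the postsynaptic membrane potential.

Voltage dependence (Mg²⁺ block):

\[ g_{NMDA} = \frac{g}{1 + \eta [Mg^{2+}] e^{-\gamma V}} \]

where g is the synaptic conductance, [Mg²⁺] the extracellular magnesium concentration, V the postsynaptic membrane potential, and η and γ are constants. In brainpy.state the block is applied by the MgBlock output layer (see Output Layer), while the NMDA synapse models the slow rise and decay of g.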

Implementation:

# NMDA receptor kinetics
syn = brainpy.state.NMDA(
    in_size=100,
    tau_decay=100.0 * u.ms,  # Slow decay
    tau_rise=2.0 * u.ms,  # Fast rise
)

Characteristics:

  • Voltage-dependent (combine with the MgBlock output)

  • Slow kinetics

  • Important for plasticity

When to use: Long-term potentiation, working memory models

AMPA Synapse#

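Fast glutamatergic transmission. After transmitter release ends, the conductance decays with time constant 1/β, so beta=0.5/ms gives a decay of about 2 ms.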

# AMPA receptor (fast excitation)
syn = brainpy.state.AMPA(
    in_size=100,
    beta=0.5 / u.ms,  # Fast decay (~2ms)
)

When to use: Fast excitatory transmission
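The 1/β decay can be seen by integrating the first-order kinetic scheme commonly used for AMPA models, dg/dt = α[T](1 - g) - βg, where [T] is a square transmitter pulse. The sketch below assumes brainpy.state.AMPA follows this scheme; beta matches the example above, while alpha, T_max and T_dur are illustrative values.

```python
import numpy as np


def ampa_kinetics(alpha=0.98, beta=0.5, T_max=0.5, T_dur=0.5, dt=0.01, t_max=20.0):
    """Forward Euler for dg/dt = alpha*T*(1 - g) - beta*g with a square transmitter pulse.

    Units: alpha in 1/(mM*ms), beta in 1/ms, T_max in mM, times in ms.
    """
    n = int(round(t_max / dt))
    t = np.arange(n) * dt
    g = np.zeros(n)
    for k in range(1, n):
        T = T_max if t[k - 1] < T_dur else 0.0  # transmitter present only during the pulse
        g[k] = g[k - 1] + dt * (alpha * T * (1.0 - g[k - 1]) - beta * g[k - 1])
    return t, g


t, g = ampa_kinetics()
# After the pulse, g decays with time constant 1 / beta = 2 ms
i0 = int(round(1.0 / 0.01))  # t = 1 ms, pulse already over
i1 = int(round(3.0 / 0.01))  # t = 3 ms, one time constant later
decay_ratio = g[i1] / g[i0]
```

The same reasoning applies to the GABAa example below: beta = 0.16/ms gives a decay time of 1/0.16 ≈ 6.25 ms.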

GABA Synapse#

Inhibitory transmission.

GABAa (fast):

# GABAa receptor (fast inhibition)
syn = brainpy.state.GABAa(
    in_size=100,
    beta=0.16 / u.ms,  # ~6ms decay
)

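GABAb (slow), approximated here with the GABAa kinetic model and a small decay rate:

# Slow, GABAb-like inhibition: GABAa kinetics with a ~150 ms decay
syn = brainpy.state.GABAa(
    in_size=100,
    beta=1 / (150.0 * u.ms),  # 1/beta = 150 ms decay
)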

When to use:

  • GABAa: Fast inhibition, cortical networks

  • GABAb: Slow inhibition, rhythm generation

Custom Synapses#

Create custom synaptic dynamics by subclassing brainpy.state.Synapse. The example below sums a fast and a slow exponentially decaying component.

import jax.numpy as jnp  # needed for the state arrays below


class DoubleExpSynapse(brainpy.state.Synapse):
    """Custom synapse: sum of a fast and a slow exponentially decaying component."""

    def __init__(self, size, tau_fast=2 * u.ms, tau_slow=10 * u.ms, **kwargs):
        super().__init__(size, **kwargs)
        self.tau_fast = tau_fast
        self.tau_slow = tau_slow

        # One state variable per neuron for each component
        self.g_fast = brainstate.ShortTermState(jnp.zeros(size))
        self.g_slow = brainstate.ShortTermState(jnp.zeros(size))

    def reset_state(self, batch_size=None, **kwargs):
        shape = self.varshape if batch_size is None else (batch_size, *self.varshape)
        self.g_fast.value = jnp.zeros(shape)
        self.g_slow.value = jnp.zeros(shape)

    def update(self, x):
        dt = brainstate.environ.get_dt().to_decimal(u.ms)

        # Exact exponential decay over one step (stable for any dt)
        decay_fast = jnp.exp(-dt / self.tau_fast.to_decimal(u.ms))
        decay_slow = jnp.exp(-dt / self.tau_slow.to_decimal(u.ms))

        # Incoming input is split 70% / 30% between the fast and slow components
        self.g_fast.value = self.g_fast.value * decay_fast + 0.7 * x
        self.g_slow.value = self.g_slow.value * decay_slow + 0.3 * x

        return self.g_fast.value + self.g_slow.value

Output Layer#

The Output layer defines how the synaptic conductance is converted into input current for the postsynaptic neurons.

CUBA (Current-Based)#

The synaptic variable is passed to the neuron directly as current, independent of the membrane potential.

Model:

\[ I_{syn} = g_{syn} \]

Implementation:

# Current-based output: the synaptic variable is passed on unchanged as input current
out_cuba = brainpy.state.CUBA()

Characteristics:

  • Simple and fast

  • No voltage dependence

  • Good for rate-based models

When to use:

  • Abstract models

  • When voltage dependence is not important

  • Faster computation needed

COBA (Conductance-Based)#

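The synaptic conductance is multiplied by the driving force, the difference between the reversal potential E_syn and the postsynaptic membrane potential V_post.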

Model:

\[ I_{syn} = g_{syn} (E_{syn} - V_{post}) \]

Implementation:

# Excitatory conductance-based
out_exc = brainpy.state.COBA(E=0.0 * u.mV)

# Inhibitory conductance-based
out_inh = brainpy.state.COBA(E=-80.0 * u.mV)

Characteristics:

  • Voltage-dependent

  • Biologically realistic

  • Self-limiting: the current shrinks to zero as V_post approaches E_syn

When to use:

  • Biologically detailed models

  • When voltage dependence matters

  • Shunting inhibition needed
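The difference between the two output models is easiest to see by evaluating both formulas over a range of membrane potentials. The NumPy sketch below uses an illustrative conductance of 1 and the reversal potentials from the COBA example above; cuba_current and coba_current are hypothetical helpers.

```python
import numpy as np


def cuba_current(g):
    """CUBA: I_syn = g_syn, independent of the membrane potential."""
    return g


def coba_current(g, V, E):
    """COBA: I_syn = g_syn * (E_syn - V_post)."""
    return g * (E - V)


g = 1.0                                   # illustrative conductance
V = np.array([-80.0, -65.0, -50.0, 0.0])  # membrane potentials (mV)

I_cuba = cuba_current(np.full(V.shape, g))  # same value at every V
I_exc = coba_current(g, V, E=0.0)           # excitatory reversal at 0 mV
I_inh = coba_current(g, V, E=-80.0)         # inhibitory reversal at -80 mV
```

At V = -80 mV the inhibitory current vanishes even though the conductance is nonzero. An inhibitory synapse near its reversal potential passes little current but still raises the membrane conductance, which is the basis of shunting inhibition.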

MgBlock (NMDA)#

Voltage-dependent magnesium block for NMDA.

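# NMDA with Mg²⁺ block
out_nmda = brainpy.state.MgBlock(
    E=0.0 * u.mV,  # Reversal potential
    cc_Mg=1.2 * u.mM,  # Extracellular Mg²⁺ concentration
    alpha=0.062 / u.mV,  # Voltage sensitivity of the block
    beta=3.57 * u.mM  # In mM so that cc_Mg / beta is dimensionless
)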

When to use: NMDA receptors, voltage-dependent plasticity
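The parameter values above match the widely used Jahr and Stevens form of the magnesium block, B(V) = 1 / (1 + [Mg²⁺]/β · exp(-αV)). This is the formula from the NMDA section with η = 1/β and γ = α. The NumPy sketch below assumes MgBlock computes I = g · B(V) · (E - V); mg_block and nmda_current are hypothetical helpers.

```python
import numpy as np


def mg_block(V, cc_Mg=1.2, alpha=0.062, beta=3.57):
    """Fraction of NMDA channels not blocked by Mg2+ (V in mV, concentrations in mM)."""
    return 1.0 / (1.0 + (cc_Mg / beta) * np.exp(-alpha * V))


def nmda_current(g, V, E=0.0):
    """NMDA current with Mg2+ block: I = g * B(V) * (E - V)."""
    return g * mg_block(V) * (E - V)


V = np.array([-80.0, -65.0, -40.0, -20.0, 0.0])
B = mg_block(V)
```

Near rest only about 5% of the channels are unblocked, and the block is relieved as the neuron depolarizes. This is why NMDA currents act as detectors of coincident pre- and postsynaptic activity.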

Complete Projection Examples#

Example 1: Simple Feedforward#

# Create populations
pre = brainpy.state.LIF(100, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
post = brainpy.state.LIF(50, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

# Create projection: 100 → 50 neurons
proj = brainpy.state.AlignPostProj(
    comm=brainstate.nn.EventFixedProb(
        100,  # pre_size
        50,  # post_size
        conn_num=0.1,  # 10% connectivity
        conn_weight=0.5 * u.mS  # Weight
    ),
    syn=brainpy.state.Expon(
        in_size=50,  # Postsynaptic size
        tau=5.0 * u.ms
    ),
    out=brainpy.state.CUBA(),
    post=post  # Postsynaptic population
)

# Initialize
brainstate.nn.init_all_states([pre, post, proj])


# Simulate
def step(t, i, inp):
    with brainstate.environ.context(t=t, i=i):
        # Update neurons
        pre(inp)

        # Get presynaptic spikes
        pre_spikes = pre.get_spike()

        # Update projection
        proj(pre_spikes)

        post(0.0 * u.nA)  # Projection provides input

        return pre.get_spike(), post.get_spike()


indices = np.arange(1000)
times = indices * brainstate.environ.get_dt()
inputs = brainstate.random.uniform(30., 50., indices.shape) * u.nA
_ = brainstate.transform.for_loop(step, times, indices, inputs)

Example 2: Excitatory-Inhibitory Network#

class EINetwork(brainstate.nn.Module):
    def __init__(self, n_exc=800, n_inh=200):
        super().__init__()

        # Populations
        self.E = brainpy.state.LIF(n_exc, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=15 * u.ms)
        self.I = brainpy.state.LIF(n_inh, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

        # E → E projection (excitatory; fast exponential synapse, AMPA-like)
        self.E2E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_exc, conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon(n_exc, tau=2. * u.ms),
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.E
        )

        # E → I projection (excitatory; fast exponential synapse, AMPA-like)
        self.E2I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_exc, n_inh, conn_num=0.02, conn_weight=0.6 * u.mS),
            syn=brainpy.state.Expon(n_inh, tau=2. * u.ms),
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.I
        )

        # I → E projection (inhibitory; slower exponential synapse, GABAa-like)
        self.I2E = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_exc, conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon(n_exc, tau=6. * u.ms),
            out=brainpy.state.COBA(E=-80.0 * u.mV),
            post=self.E
        )

        # I → I projection (inhibitory; slower exponential synapse, GABAa-like)
        self.I2I = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_inh, n_inh, conn_num=0.02, conn_weight=6.7 * u.mS),
            syn=brainpy.state.Expon(n_inh, tau=6. * u.ms),
            out=brainpy.state.COBA(E=-80.0 * u.mV),
            post=self.I
        )

    def update(self, i, inp_e, inp_i):
        t = brainstate.environ.get_dt() * i
        with brainstate.environ.context(t=t, i=i):
            # Get spikes BEFORE updating neurons
            spk_e = self.E.get_spike()
            spk_i = self.I.get_spike()

            # Update all projections
            self.E2E(spk_e)
            self.E2I(spk_e)
            self.I2E(spk_i)
            self.I2I(spk_i)

            # Update neurons (projections provide synaptic input)
            self.E(inp_e)
            self.I(inp_i)

            return spk_e, spk_i


net = EINetwork()
brainstate.nn.init_all_states(net)
_ = brainstate.transform.for_loop(net.update, indices, inputs, inputs)

Example 3: Multi-Timescale Synapses#

Combine AMPA (fast) and NMDA (slow) for realistic excitation.

class DualExcitatory(brainstate.nn.Module):
    """E → E with both AMPA and NMDA."""

    def __init__(self, n_pre=100, n_post=100):
        super().__init__()

        self.post = brainpy.state.LIF(n_post, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)

        # Fast AMPA component
        self.ampa_proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
            syn=brainpy.state.AMPA(n_post, beta=0.5 / u.ms),  # Decay time ~1/beta = 2 ms
            out=brainpy.state.COBA(E=0.0 * u.mV),
            post=self.post
        )

        # Slow NMDA component
        self.nmda_proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),
            syn=brainpy.state.NMDA(n_post, tau_decay=100.0 * u.ms, tau_rise=2.0 * u.ms),
            out=brainpy.state.MgBlock(E=0.0 * u.mV, cc_Mg=1.2 * u.mM),
            post=self.post
        )

    def update(self, t, i, pre_spikes):
        with brainstate.environ.context(t=t, i=i):
            # Both projections share same presynaptic spikes
            self.ampa_proj(pre_spikes)
            self.nmda_proj(pre_spikes)

            # Post receives combined input
            self.post(0.0 * u.nA)

            return self.post.get_spike()

Example 4: Delay Projections#

Add synaptic delays to projections.

# Transmission delay applied to the spikes before they enter the projection
delay_time = 5.0 * u.ms


# Create a network with delay
class DelayedProjection(brainstate.nn.Module):
    def __init__(self, pre_size, post_size):
        super().__init__()

        # Postsynaptic population; it receives its own delayed spikes through self.proj,
        # so pre_size must equal post_size
        self.post = brainpy.state.LIF(post_size, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)
        self.delay = self.post.output_delay(delay_time)

        # Standard projection
        self.proj = brainpy.state.AlignPostProj(
            comm=brainstate.nn.EventFixedProb(pre_size, post_size, conn_num=0.1, conn_weight=0.5 * u.mS),
            syn=brainpy.state.Expon(post_size, tau=5.0 * u.ms),
            out=brainpy.state.CUBA(),
            post=self.post
        )

    def update(self, inp=0. * u.nA):
        # Retrieve delayed spikes
        delayed_spikes = self.delay()
        # Update projection with delayed spikes
        self.proj(delayed_spikes)
        self.post(inp)
        # Store current spikes in delay buffer
        self.delay(self.post.get_spike())

    def step_run(self, i, inp):
        t = brainstate.environ.get_dt() * i
        with brainstate.environ.context(t=t, i=i):
            # Update post neurons
            self.update(inp)
            return self.post.get_spike()


net = DelayedProjection(100, 100)
brainstate.nn.init_all_states(net)
_ = brainstate.transform.for_loop(net.step_run, indices, inputs)
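Conceptually, a 5 ms delay at dt = 0.1 ms means that a spike emitted at step t is delivered at step t + 50. The NumPy sketch below shows that bookkeeping with a circular buffer. RingDelay is a hypothetical helper that illustrates the idea and is not brainpy.state's delay implementation.

```python
import numpy as np


class RingDelay:
    """Delay line of n_delay time steps backed by a circular buffer (illustrative)."""

    def __init__(self, n_delay, size):
        assert n_delay >= 1
        self.buffer = np.zeros((n_delay, size))
        self.idx = 0

    def step(self, value):
        """Return the value stored n_delay steps ago, then store the current value."""
        out = self.buffer[self.idx].copy()
        self.buffer[self.idx] = value
        self.idx = (self.idx + 1) % len(self.buffer)
        return out


dt, delay = 0.1, 5.0                # ms
n_delay = int(round(delay / dt))    # 50 steps
line = RingDelay(n_delay, size=3)

outputs = []
for t in range(120):
    spikes = np.array([1.0, 0.0, 0.0]) if t == 0 else np.zeros(3)  # neuron 0 fires at step 0
    outputs.append(line.step(spikes))
outputs = np.array(outputs)
```

The buffer holds exactly n_delay entries, so memory grows linearly with the delay length divided by dt.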