Synapse


class brainpy.state.Synapse(in_size, name=None)

Base class for synapse dynamics.

This class is the foundation for all synapse models in BrainPy and provides a common interface for implementing synaptic transmission mechanisms. Synapses model how signals (typically spikes) are transmitted between neurons, including their temporal dynamics, plasticity, and neurotransmitter effects.

All specific synapse implementations (such as Expon, Alpha, DualExpon, AMPA, and GABAa) should inherit from this class and implement the methods for state management (init_state, reset_state) and for the dynamics update (update).

Parameters:
  • in_size (Size) – Size of the presynaptic input. Can be an integer for 1D input or a tuple for multi-dimensional input (e.g., 100 or (10, 10)).

  • name (str, optional) – Name identifier for the synapse layer. If None, an automatic name will be generated. Useful for debugging and model inspection.

varshape

Shape of the synaptic state variables, derived from in_size.

Type:

tuple
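To make the relationship between in_size and varshape concrete, here is a plain-Python sketch of the normalization it implies: an integer becomes a one-element tuple, a tuple is kept as-is, and a batch dimension, when requested, is prepended (consistent with the (32, 100) output of the built-in example further down). The helpers to_varshape and state_shape are hypothetical and are not part of BrainPy; the library's own conversion may differ in details such as type checking.

```python
from typing import Optional, Sequence, Tuple, Union

Size = Union[int, Sequence[int]]

def to_varshape(in_size: Size) -> Tuple[int, ...]:
    """Hypothetical helper: normalize an int or a sequence of ints to a shape tuple."""
    if isinstance(in_size, int):
        return (in_size,)
    return tuple(int(n) for n in in_size)

def state_shape(in_size: Size, batch_size: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: shape of one state array, with an optional leading batch axis."""
    varshape = to_varshape(in_size)
    return varshape if batch_size is None else (batch_size,) + varshape

print(to_varshape(100))        # (100,)
print(to_varshape((10, 10)))   # (10, 10)
print(state_shape(100, 32))    # (32, 100)
```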

See also

Expon

Simple first-order exponential decay synapse model

DualExpon

Dual exponential synapse model with separate rise and decay

Alpha

Alpha function synapse model

AMPA

AMPA receptor-mediated excitatory synapse

GABAa

GABAa receptor-mediated inhibitory synapse

Notes

Synaptic Dynamics

Synapses implement temporal filtering of presynaptic signals. The dynamics are typically described by differential equations that govern how synaptic conductance or current evolves over time in response to presynaptic spikes.
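For the simplest case, a first-order exponential synapse, the equation is dg/dt = -g/tau, with each presynaptic input added to g as an instantaneous jump. Because the decay is linear, one step of length dt multiplies g by exp(-dt/tau) exactly. The sketch below is plain Python written for illustration under that assumption; simulate_expon is a hypothetical name, not a BrainPy function, and the built-in models may order the decay and the input differently.

```python
import math
from typing import List, Sequence

def simulate_expon(inputs: Sequence[float], tau: float, dt: float, g0: float = 0.0) -> List[float]:
    """Hypothetical helper: exponentially decaying trace driven by instantaneous inputs.

    Between steps g obeys dg/dt = -g / tau, whose exact solution over one
    step is g * exp(-dt / tau); each input value is then added as a jump.
    """
    decay = math.exp(-dt / tau)
    g = g0
    trace = []
    for x in inputs:
        g = g * decay + x  # decay over one step, then add this step's input
        trace.append(g)
    return trace

# A single input of 1.0 at step 0, then silence (tau = 5 ms, dt = 0.1 ms).
trace = simulate_expon([1.0] + [0.0] * 50, tau=5.0, dt=0.1)
print(trace[0])             # 1.0
print(round(trace[50], 4))  # 0.3679, i.e. exp(-1)
```

A single unit input therefore decays to exp(-1) of its value after one time constant (here 50 steps of 0.1 ms with tau = 5 ms).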

State Variables

Synapse models typically maintain state variables (e.g., conductance g, gating variables) as brainstate.HiddenState or brainstate.ShortTermState objects depending on whether they need to be preserved across simulation episodes.

Integration with Neurons

Synapses are commonly used in conjunction with projection layers or connectivity matrices to model synaptic transmission between neuron populations:

  • In feedforward networks: Linear layer → Synapse → Neuron

  • In recurrent networks: Neuron → Linear layer → Synapse → Neuron
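The feedforward chain above can be illustrated without BrainPy by stepping three stand-in components in a loop. In this sketch, linear, ExpSyn, and LIFNeuron are hypothetical, dimensionless simplifications (no physical units, forward Euler for the membrane, reset to 0 after a spike); they only show the order of operations within one time step, not the behavior of the brainpy.state models.

```python
import math
import random
from typing import List

def linear(x: List[float], w: List[List[float]]) -> List[float]:
    """Dense layer without bias: out[j] = sum_i x[i] * w[i][j]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

class ExpSyn:
    """First-order exponential synapse: decay by exp(-dt/tau), then add the input."""

    def __init__(self, n: int, tau: float, dt: float):
        self.decay = math.exp(-dt / tau)
        self.g = [0.0] * n

    def update(self, x: List[float]) -> List[float]:
        self.g = [g * self.decay + xi for g, xi in zip(self.g, x)]
        return self.g

class LIFNeuron:
    """Dimensionless leaky integrate-and-fire: dv/dt = (-v + I) / tau, reset to 0 at threshold."""

    def __init__(self, n: int, tau: float, dt: float, v_th: float = 1.0):
        self.tau, self.dt, self.v_th = tau, dt, v_th
        self.v = [0.0] * n

    def update(self, current: List[float]) -> List[float]:
        spikes = []
        for k, i_k in enumerate(current):
            v = self.v[k] + self.dt * (-self.v[k] + i_k) / self.tau  # forward Euler step
            fired = v >= self.v_th
            self.v[k] = 0.0 if fired else v
            spikes.append(1.0 if fired else 0.0)
        return spikes

# Linear -> Synapse -> Neuron, advanced one time step per loop iteration.
random.seed(0)
n_in, n_out, dt = 20, 5, 0.1
w = [[random.uniform(0.0, 0.5) for _ in range(n_out)] for _ in range(n_in)]
syn = ExpSyn(n_out, tau=5.0, dt=dt)
neu = LIFNeuron(n_out, tau=10.0, dt=dt)

spike_count = 0.0
for step in range(200):
    pre = [1.0 if random.random() < 0.2 else 0.0 for _ in range(n_in)]  # presynaptic spikes
    g = syn.update(linear(pre, w))
    spike_count += sum(neu.update(g))
print(spike_count)
```

For the recurrent variant, the spikes produced at one step would be passed back through a weight matrix at the next step before entering the synapse.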

Alignment Patterns

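Some synapse models inherit from AlignPost so that their state variables are aligned with the postsynaptic neurons: one variable is kept per postsynaptic neuron rather than one per synapse, and presynaptic spikes are accumulated into it as events. This enables event-driven computation and is particularly efficient for sparse connectivity patterns.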
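Post-alignment works because, for synapses whose dynamics are linear in their input, the sum of many per-synapse traces onto one postsynaptic neuron equals a single trace driven by the summed weighted input. The plain-Python sketch below checks this numerically and shows the event-driven update, which touches only the weight rows of presynaptic neurons that fired. The function names are hypothetical and this is not BrainPy's AlignPost implementation; it assumes linear exponential decay, and nonlinear (for example saturating) synapse dynamics do not decompose this way.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]

def per_synapse(spike_trains: Sequence[Sequence[bool]], w: Matrix, decay: float) -> List[float]:
    """One trace per (pre, post) synapse, summed onto each postsynaptic neuron at the end."""
    n_pre, n_post = len(w), len(w[0])
    s = [[0.0] * n_post for _ in range(n_pre)]
    for spikes in spike_trains:
        for i in range(n_pre):
            for j in range(n_post):
                s[i][j] = s[i][j] * decay + (w[i][j] if spikes[i] else 0.0)
    return [sum(s[i][j] for i in range(n_pre)) for j in range(n_post)]

def post_aligned(spike_trains: Sequence[Sequence[bool]], w: Matrix, decay: float) -> List[float]:
    """One trace per postsynaptic neuron, updated only by presynaptic neurons that fired."""
    g = [0.0] * len(w[0])
    for spikes in spike_trains:
        g = [gj * decay for gj in g]
        for i, fired in enumerate(spikes):
            if fired:  # event-driven: silent presynaptic neurons cost nothing
                g = [gj + wij for gj, wij in zip(g, w[i])]
    return g

random.seed(1)
n_pre, n_post, steps = 30, 4, 100
decay = math.exp(-0.1 / 5.0)  # dt = 0.1 ms, tau = 5 ms
w = [[random.gauss(0.0, 1.0) for _ in range(n_post)] for _ in range(n_pre)]
trains = [[random.random() < 0.05 for _ in range(n_pre)] for _ in range(steps)]

a = per_synapse(trains, w, decay)
b = post_aligned(trains, w, decay)
print(max(abs(x - y) for x, y in zip(a, b)))  # only floating-point rounding error
```

The per-synapse version stores n_pre × n_post values, while the post-aligned version stores only n_post.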

Examples

Creating a Custom Synapse Model

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> import braintools
>>>
>>> class SimpleSynapse(brainpy.state.Synapse):
...     def __init__(self, in_size, tau=5.0*u.ms, **kwargs):
...         super().__init__(in_size, **kwargs)
...         self.tau = braintools.init.param(tau, self.varshape)
...         self.g_init = braintools.init.Constant(0.*u.mS)
...
...     def init_state(self, batch_size=None, **kwargs):
...         self.g = brainstate.HiddenState(braintools.init.param(self.g_init, self.varshape, batch_size))
...
...     def reset_state(self, batch_size=None, **kwargs):
...         self.g.value = braintools.init.param(self.g_init, self.varshape, batch_size)
...
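...     def update(self, x=None):
...         # Exponential decay dg/dt = -g/tau, integrated with the exponential
...         # Euler method (the step size dt is read from brainstate.environ)
...         dg = lambda g: -g / self.tau
...         self.g.value = brainstate.nn.exp_euler_step(dg, self.g.value)
...         # Presynaptic input is added as an instantaneous increment
...         if x is not None:
...             self.g.value += x
...         return self.g.value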

Using Built-in Synapse Models

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> import jax
>>>
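>>> # Set the integration time step used by the synapse dynamics
>>> brainstate.environ.set(dt=0.1*u.ms)
>>>
>>> # Create an exponential synapse
>>> synapse = brainpy.state.Expon(in_size=100, tau=8.0*u.ms)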
>>>
>>> # Initialize state
>>> synapse.init_state(batch_size=32)
>>>
>>> # Update with presynaptic spikes
>>> spikes = jax.random.bernoulli(
...     jax.random.PRNGKey(0),
...     p=0.1,
...     shape=(32, 100)
... )
>>> conductance = synapse.update(spikes * 1.0*u.mS)
>>> print(conductance.shape)
(32, 100)

Building a Feedforward Spiking Network

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> class SynapticNetwork(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         # Input layer
...         self.input_neurons = brainpy.state.LIF(784, tau=5*u.ms)
...         # First hidden layer with synaptic filtering
...         self.fc1 = brainstate.nn.Linear(784, 256)
...         self.syn1 = brainpy.state.Expon(256, tau=8*u.ms)
...         self.hidden1 = brainpy.state.LIF(256, tau=10*u.ms)
...         # Second hidden layer with AMPA synapse
...         self.fc2 = brainstate.nn.Linear(256, 128)
...         self.syn2 = brainpy.state.AMPA(128)
...         self.hidden2 = brainpy.state.LIF(128, tau=10*u.ms)
...         # Output layer
...         self.fc3 = brainstate.nn.Linear(128, 10)
...         self.output_neurons = brainpy.state.LIF(10, tau=8*u.ms)
...
...     def __call__(self, x):
...         # Input layer
...         spikes0 = self.input_neurons.update(x)
...         # First hidden layer
...         current1 = self.fc1(spikes0)
...         g1 = self.syn1.update(current1)
...         spikes1 = self.hidden1.update(g1)
...         # Second hidden layer
...         current2 = self.fc2(spikes1)
...         g2 = self.syn2.update(current2)
...         spikes2 = self.hidden2.update(g2)
...         # Output layer
...         current3 = self.fc3(spikes2)
...         output_spikes = self.output_neurons.update(current3)
...         return output_spikes

Recurrent Network with Inhibition

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> class EINetwork(brainstate.nn.Module):
...     def __init__(self, n_exc=800, n_inh=200):
...         super().__init__()
...         # Excitatory population
...         self.exc_neurons = brainpy.state.LIF(n_exc, tau=10*u.ms)
...         self.exc_syn = brainpy.state.AMPA(n_exc)
...         # Inhibitory population
...         self.inh_neurons = brainpy.state.LIF(n_inh, tau=8*u.ms)
...         self.inh_syn = brainpy.state.GABAa(n_inh)
...         # Connectivity
...         self.exc_to_exc = brainstate.nn.Linear(n_exc, n_exc)
...         self.exc_to_inh = brainstate.nn.Linear(n_exc, n_inh)
...         self.inh_to_exc = brainstate.nn.Linear(n_inh, n_exc)
...         self.inh_to_inh = brainstate.nn.Linear(n_inh, n_inh)
...
...     def __call__(self, ext_input):
...         # Excitatory neurons receive external input and recurrent excitation/inhibition
...         exc_current = (ext_input +
...                       self.exc_to_exc(self.exc_syn.g.value) -
...                       self.inh_to_exc(self.inh_syn.g.value))
...         exc_spikes = self.exc_neurons.update(exc_current)
...         self.exc_syn.update(exc_spikes)
...         # Inhibitory neurons receive excitatory input and recurrent inhibition
...         inh_current = (self.exc_to_inh(self.exc_syn.g.value) -
...                       self.inh_to_inh(self.inh_syn.g.value))
...         inh_spikes = self.inh_neurons.update(inh_current)
...         self.inh_syn.update(inh_spikes)
...         return exc_spikes, inh_spikes

References