Synapse
- class brainpy.state.Synapse(in_size, name=None)
Base class for synapse dynamics.
This class serves as the foundation for all synapse models in the BrainPy framework, providing a common interface for implementing various types of synaptic connectivity and transmission mechanisms. Synapses model the transmission of signals (typically spikes) between neurons, including temporal dynamics, plasticity, and neurotransmitter effects.
All specific synapse implementations (like Expon, Alpha, DualExpon, AMPA, GABAa, etc.) should inherit from this class and implement the required methods for state management and dynamics update.
- Parameters:
  - in_size (Size) – Size of the presynaptic input. Can be an integer for 1D input or a tuple for multi-dimensional input (e.g., 100 or (10, 10)).
  - name (str, optional) – Name identifier for the synapse layer. If None, an automatic name will be generated. Useful for debugging and model inspection.
Notes
Synaptic Dynamics
Synapses implement temporal filtering of presynaptic signals. The dynamics are typically described by differential equations that govern how synaptic conductance or current evolves over time in response to presynaptic spikes.
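The temporal filtering described above can be sketched without any BrainPy machinery. The snippet below (plain NumPy; the dt and tau values are illustrative assumptions) integrates the simplest such equation, dg/dt = -g/tau, with an exact exponential-Euler update, injecting a single presynaptic spike at t = 0:

```python
import numpy as np

# Conceptual sketch (not BrainPy API): exponential synaptic filtering.
# A single presynaptic spike at t=0 produces a conductance trace that
# decays as g(t) = g0 * exp(-t / tau), the solution of dg/dt = -g/tau.
dt = 0.1      # ms, integration step (assumed)
tau = 8.0     # ms, synaptic time constant (assumed)
decay = np.exp(-dt / tau)

g = 0.0
trace = []
for step in range(100):
    spike = 1.0 if step == 0 else 0.0   # one spike at the first step
    g = g * decay + spike               # exact exponential-Euler update
    trace.append(g)

# After tau ms (80 steps) the conductance has decayed to 1/e of its peak
print(trace[0], trace[80])
```

Built-in models such as Expon implement exactly this kind of low-pass filter, while Alpha and DualExpon add a finite rise time via a second state variable.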
State Variables
Synapse models typically maintain state variables (e.g., conductance g, gating variables) as brainstate.HiddenState or brainstate.ShortTermState objects, depending on whether they need to be preserved across simulation episodes.
Integration with Neurons
Synapses are commonly used in conjunction with projection layers or connectivity matrices to model synaptic transmission between neuron populations:
In feedforward networks: Linear layer → Synapse → Neuron
In recurrent networks: Neuron → Linear layer → Synapse → Neuron
Alignment Patterns
Some synapse models inherit from AlignPost to enable event-driven computation, where synaptic variables are aligned with postsynaptic neurons. This is particularly efficient for sparse connectivity patterns.
Examples
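As a conceptual warm-up (plain NumPy, not the actual AlignPost API; all names here are illustrative), the event-driven idea is that with sparse spiking only the weight rows of the presynaptic neurons that actually fired need to be accumulated into the postsynaptic-aligned variable, which avoids a dense matrix-vector product:

```python
import numpy as np

# Conceptual sketch (not BrainPy's AlignPost API): event-driven update.
# With sparse spiking, only rows of W belonging to spiking presynaptic
# neurons contribute, so the dense matmul can be replaced by a gather-sum
# over postsynaptic-aligned variables.
rng = np.random.default_rng(0)
n_pre, n_post = 100, 50
W = rng.normal(size=(n_pre, n_post))
spikes = rng.random(n_pre) < 0.05        # sparse presynaptic spikes

# Dense update: full matrix-vector product every step
g_dense = spikes.astype(float) @ W

# Event-driven update: sum only the rows of W for neurons that spiked
g_event = W[spikes].sum(axis=0)

print(np.allclose(g_dense, g_event))     # both give the same result
```

At a 5% firing rate, the event-driven form touches roughly 5% of the weight matrix per step, which is where the efficiency gain for sparse connectivity comes from.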
Creating a Custom Synapse Model
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> import braintools
>>>
>>> class SimpleSynapse(brainpy.state.Synapse):
...     def __init__(self, in_size, tau=5.0*u.ms, **kwargs):
...         super().__init__(in_size, **kwargs)
...         self.tau = braintools.init.param(tau, self.varshape)
...         self.g_init = braintools.init.Constant(0.*u.mS)
...
...     def init_state(self, batch_size=None, **kwargs):
...         self.g = brainstate.HiddenState(
...             braintools.init.param(self.g_init, self.varshape, batch_size))
...
...     def reset_state(self, batch_size=None, **kwargs):
...         self.g.value = braintools.init.param(self.g_init, self.varshape, batch_size)
...
...     def update(self, x=None):
...         # Simple exponential decay: dg/dt = -g/tau + x
...         dg = lambda g: -g / self.tau
...         self.g.value = brainstate.nn.exp_euler_step(dg, self.g.value)
...         if x is not None:
...             self.g.value += x
...         return self.g.value
Using Built-in Synapse Models
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> import jax
>>>
>>> # Create an exponential synapse
>>> synapse = brainpy.state.Expon(in_size=100, tau=8.0*u.ms)
>>>
>>> # Initialize state
>>> synapse.init_state(batch_size=32)
>>>
>>> # Update with presynaptic spikes
>>> spikes = jax.random.bernoulli(
...     jax.random.PRNGKey(0),
...     p=0.1,
...     shape=(32, 100)
... )
>>> conductance = synapse.update(spikes * 1.0*u.mS)
>>> print(conductance.shape)
(32, 100)
Building a Feedforward Spiking Network
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> class SynapticNetwork(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         # Input layer
...         self.input_neurons = brainpy.state.LIF(784, tau=5*u.ms)
...         # First hidden layer with synaptic filtering
...         self.fc1 = brainstate.nn.Linear(784, 256)
...         self.syn1 = brainpy.state.Expon(256, tau=8*u.ms)
...         self.hidden1 = brainpy.state.LIF(256, tau=10*u.ms)
...         # Second hidden layer with AMPA synapse
...         self.fc2 = brainstate.nn.Linear(256, 128)
...         self.syn2 = brainpy.state.AMPA(128)
...         self.hidden2 = brainpy.state.LIF(128, tau=10*u.ms)
...         # Output layer
...         self.fc3 = brainstate.nn.Linear(128, 10)
...         self.output_neurons = brainpy.state.LIF(10, tau=8*u.ms)
...
...     def __call__(self, x):
...         # Input layer
...         spikes0 = self.input_neurons.update(x)
...         # First hidden layer
...         current1 = self.fc1(spikes0)
...         g1 = self.syn1.update(current1)
...         spikes1 = self.hidden1.update(g1)
...         # Second hidden layer
...         current2 = self.fc2(spikes1)
...         g2 = self.syn2.update(current2)
...         spikes2 = self.hidden2.update(g2)
...         # Output layer
...         current3 = self.fc3(spikes2)
...         output_spikes = self.output_neurons.update(current3)
...         return output_spikes
Recurrent Network with Inhibition
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> class EINetwork(brainstate.nn.Module):
...     def __init__(self, n_exc=800, n_inh=200):
...         super().__init__()
...         # Excitatory population
...         self.exc_neurons = brainpy.state.LIF(n_exc, tau=10*u.ms)
...         self.exc_syn = brainpy.state.AMPA(n_exc)
...         # Inhibitory population
...         self.inh_neurons = brainpy.state.LIF(n_inh, tau=8*u.ms)
...         self.inh_syn = brainpy.state.GABAa(n_inh)
...         # Connectivity
...         self.exc_to_exc = brainstate.nn.Linear(n_exc, n_exc)
...         self.exc_to_inh = brainstate.nn.Linear(n_exc, n_inh)
...         self.inh_to_exc = brainstate.nn.Linear(n_inh, n_exc)
...         self.inh_to_inh = brainstate.nn.Linear(n_inh, n_inh)
...
...     def __call__(self, ext_input):
...         # Excitatory neurons receive external input and recurrent excitation/inhibition
...         exc_current = (ext_input +
...                        self.exc_to_exc(self.exc_syn.g.value) -
...                        self.inh_to_exc(self.inh_syn.g.value))
...         exc_spikes = self.exc_neurons.update(exc_current)
...         self.exc_syn.update(exc_spikes)
...         # Inhibitory neurons receive excitatory input and recurrent inhibition
...         inh_current = (self.exc_to_inh(self.exc_syn.g.value) -
...                        self.inh_to_inh(self.inh_syn.g.value))
...         inh_spikes = self.inh_neurons.update(inh_current)
...         self.inh_syn.update(inh_spikes)
...         return exc_spikes, inh_spikes