# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Optional
import brainstate
import braintools
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainpy_state._base import Synapse
# Names exported by ``from <module> import *``.
__all__ = ["STP", "STD"]
class STP(Synapse):
    r"""
    Synapse with short-term plasticity.

    This class implements a synapse model with short-term plasticity (STP), which captures
    activity-dependent changes in synaptic efficacy that occur over milliseconds to seconds.
    The model simultaneously accounts for both short-term facilitation and depression
    based on the formulation by Tsodyks, Pawelzik & Markram (1998) [2]_.

    The model is characterized by the following equations:

    $$
    \frac{du}{dt} = -\frac{u}{\tau_f} + U \cdot (1 - u) \cdot \delta(t - t_{spike})
    $$

    $$
    \frac{dx}{dt} = \frac{1 - x}{\tau_d} - u \cdot x \cdot \delta(t - t_{spike})
    $$

    $$
    g_{syn} = u \cdot x
    $$

    where:

    - $u$ represents the utilization of synaptic efficacy (facilitation variable)
    - $x$ represents the available synaptic resources (depression variable)
    - $\tau_f$ is the facilitation time constant
    - $\tau_d$ is the depression time constant
    - $U$ is the baseline utilization parameter
    - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
    - $g_{syn}$ is the effective synaptic conductance

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    U : ArrayLike, default=0.15
        Baseline utilization parameter (fraction of resources used per action potential).
        Also the initial value of ``u``.
    tau_f : ArrayLike, default=1500.*u.ms
        Time constant of short-term facilitation (a time quantity; the default is
        expressed in ``u.ms``).
    tau_d : ArrayLike, default=200.*u.ms
        Time constant of short-term depression, i.e. recovery of synaptic resources
        (a time quantity; the default is expressed in ``u.ms``).

    Attributes
    ----------
    u : HiddenState
        Utilization of synaptic efficacy (facilitation variable), created by ``init_state``.
    x : HiddenState
        Available synaptic resources (depression variable), created by ``init_state``.

    See Also
    --------
    STD : Short-term depression only model.

    Notes
    -----
    - Larger values of tau_f produce stronger facilitation effects.
    - Larger values of tau_d lead to slower recovery from depression.
    - The parameter U controls the initial release probability [1]_.
    - The effective synaptic strength is the product of u and x.
    - For a comprehensive treatment of short-term plasticity dynamics, see [2]_.

    References
    ----------
    .. [1] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.
    .. [2] Tsodyks, M., Pawelzik, K., & Markram, H. (1998). Neural networks with dynamic
           synapses. Neural Computation, 10(4), 821-835.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> # Create an STP synapse with facilitation-dominant parameters
        >>> stp = brainpy.state.STP(100, U=0.1, tau_f=1500.*u.ms, tau_d=200.*u.ms)
        >>> stp.init_state(batch_size=1)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters, materialized over the synapse's variable shape
        self.tau_f = braintools.init.param(tau_f, self.varshape)
        self.tau_d = braintools.init.param(tau_d, self.varshape)
        self.U = braintools.init.param(U, self.varshape)

    def _initial_x(self, batch_size: Optional[int]):
        # Resources start fully available (x = 1).
        return braintools.init.param(braintools.init.Constant(1.), self.varshape, batch_size)

    def _initial_u(self, batch_size: Optional[int]):
        # Utilization starts at its baseline value U.
        return braintools.init.param(braintools.init.Constant(self.U), self.varshape, batch_size)

    def init_state(self, batch_size: Optional[int] = None, **kwargs):
        """Create the hidden states ``x`` (initialized to 1) and ``u`` (initialized to ``U``)."""
        self.x = brainstate.HiddenState(self._initial_x(batch_size))
        self.u = brainstate.HiddenState(self._initial_u(batch_size))

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Reset ``x`` to 1 and ``u`` to ``U``."""
        self.x.value = self._initial_x(batch_size)
        self.u.value = self._initial_u(batch_size)

    def update(self, pre_spike):
        """
        Advance the facilitation (``u``) and depression (``x``) variables by one step.

        Parameters
        ----------
        pre_spike : ArrayLike
            Presynaptic spikes for this step. It is used multiplicatively, so it is
            expected to hold 0/1 (or boolean) values that broadcast against the state.

        Returns
        -------
        ArrayLike
            ``u * x * pre_spike`` computed from the updated state: the effective
            efficacy where a spike arrived, zero elsewhere.
        """
        u_prev = self.u.value
        x_prev = self.x.value

        # Continuous part (exponential Euler): u relaxes toward 0 with tau_f and
        # x recovers toward 1 with tau_d. Local names avoid shadowing the
        # module-level ``saiunit as u`` alias.
        u_next = brainstate.nn.exp_euler_step(lambda v: -v / self.tau_f, u_prev)
        x_next = brainstate.nn.exp_euler_step(lambda r: (1 - r) / self.tau_d, x_prev)

        # Spike-triggered jumps. Multiplying by ``pre_spike`` is the branch-free
        # equivalent of ``where(pre_spike, jumped, relaxed)`` for 0/1 spikes.
        # The jump sizes use the pre-step state (u_prev, x_prev), while the
        # depletion uses the already-facilitated u_next, as in the previous
        # ``where``-based brainpy implementation.
        u_next = u_next + pre_spike * self.U * (1 - u_prev)
        x_next = x_next - pre_spike * u_next * x_prev

        self.u.value = u_next
        self.x.value = x_next
        return u_next * x_next * pre_spike
class STD(Synapse):
    r"""
    Synapse with short-term depression.

    This class implements a synapse model with short-term depression (STD), which captures
    activity-dependent reduction in synaptic efficacy, typically caused by depletion of
    neurotransmitter vesicles following repeated stimulation.

    The model is characterized by the following equations:

    $$
    \frac{dx}{dt} = \frac{1 - x}{\tau} - U \cdot x \cdot \delta(t - t_{spike})
    $$

    $$
    g_{syn} = x
    $$

    where:

    - $x$ represents the available synaptic resources (depression variable)
    - $\tau$ is the depression recovery time constant
    - $U$ is the utilization parameter (fraction of resources depleted per spike)
    - $\delta(t - t_{spike})$ is the Dirac delta function representing presynaptic spikes
    - $g_{syn}$ is the effective synaptic conductance

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    tau : ArrayLike, default=200.*u.ms
        Time constant governing recovery of synaptic resources (a time quantity;
        the default is expressed in ``u.ms``).
    U : ArrayLike, default=0.07
        Utilization parameter (fraction of resources used per action potential).

    Attributes
    ----------
    x : HiddenState
        Available synaptic resources (depression variable), created by ``init_state``.

    See Also
    --------
    STP : Full short-term plasticity model with facilitation and depression.

    Notes
    -----
    - Larger values of tau lead to slower recovery from depression [1]_.
    - Larger values of U cause stronger depression with each spike.
    - This model is a simplified version of the STP model that only includes depression [2]_.

    References
    ----------
    .. [1] Abbott, L. F., Varela, J. A., Sen, K., & Nelson, S. B. (1997). Synaptic
           depression and cortical gain control. Science, 275(5297), 220-224.
    .. [2] Tsodyks, M. V., & Markram, H. (1997). The neural code between neocortical
           pyramidal neurons depends on neurotransmitter release probability.
           Proceedings of the National Academy of Sciences, 94(2), 719-723.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> # Create an STD synapse
        >>> std = brainpy.state.STD(100, tau=200.*u.ms, U=0.07)
        >>> std.init_state(batch_size=1)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters, materialized over the synapse's variable shape
        self.tau = braintools.init.param(tau, self.varshape)
        self.U = braintools.init.param(U, self.varshape)

    def _initial_x(self, batch_size: Optional[int]):
        # Resources start fully available (x = 1).
        return braintools.init.param(braintools.init.Constant(1.), self.varshape, batch_size)

    def init_state(self, batch_size: Optional[int] = None, **kwargs):
        """Create the hidden state ``x``, initialized to 1."""
        self.x = brainstate.HiddenState(self._initial_x(batch_size))

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Reset ``x`` to 1."""
        self.x.value = self._initial_x(batch_size)

    def update(self, pre_spike):
        """
        Advance the depression variable ``x`` by one step.

        Parameters
        ----------
        pre_spike : ArrayLike
            Presynaptic spikes for this step. It is used multiplicatively, so it is
            expected to hold 0/1 (or boolean) values that broadcast against the state.

        Returns
        -------
        ArrayLike
            ``x * pre_spike`` computed from the updated ``x``: the remaining
            resources where a spike arrived, zero elsewhere.
        """
        x_prev = self.x.value

        # Continuous part (exponential Euler): x recovers toward 1 with tau.
        x_next = brainstate.nn.exp_euler_step(lambda r: (1 - r) / self.tau, x_prev)

        # Spike-triggered depletion, sized by the pre-step state. Multiplying by
        # ``pre_spike`` is the branch-free equivalent of
        # ``where(pre_spike, x_next - U * x_prev, x_next)`` for 0/1 spikes.
        self.x.value = x_next - pre_spike * self.U * x_prev
        return self.x.value * pre_spike