# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
import brainstate
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike
from .static_synapse import static_synapse
__all__ = [
'stdp_synapse',
]
_UNSET = object()
_STDP_EPS = 1.0e-6
class stdp_synapse(static_synapse):
r"""NEST-compatible ``stdp_synapse`` connection model.
``stdp_synapse`` implements pair-based spike-timing dependent plasticity (STDP)
following Guetig et al. (2003) and the NEST reference implementation from
``models/stdp_synapse.h``. The model supports asymmetric Hebbian learning with
configurable weight-dependent potentiation and depression exponents.
The synapse maintains three dynamic state variables per connection:
- ``weight``: current synaptic efficacy (plastic, updated on each presynaptic spike)
- ``Kplus``: presynaptic eligibility trace (exponentially decays with time constant ``tau_plus``)
- ``t_lastspike``: timestamp of the most recent presynaptic spike
Postsynaptic spikes are archived internally on the synapse, together with the
postsynaptic trace ``K^-`` (time constant ``tau_minus``). In NEST, ``tau_minus``
is a parameter of the postsynaptic neuron (``ArchivingNode``); here it is stored
on the synapse so that STDP can be simulated without requiring postsynaptic
neurons to implement archiving APIs.
**1. Mathematical Model**
State Variables
---------------
- ``w``: Synaptic weight (plastic, bounded to :math:`[0, W_{\max}]` or :math:`[W_{\max}, 0]`)
- ``K^+``: Presynaptic eligibility trace (decays with :math:`\tau_+`)
- ``K^-``: Postsynaptic eligibility trace (decays with :math:`\tau_-`)
**Continuous-time dynamics (between spikes):**
.. math::
\frac{dK^+}{dt} = -\frac{K^+}{\tau_+}
\frac{dK^-}{dt} = -\frac{K^-}{\tau_-}
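Because both trace equations are linear, a trace can be advanced between spike
times exactly, :math:`K(t_1) = K(t_0)\, e^{-(t_1 - t_0)/\tau}`, so no per-step
integration is needed. A minimal sketch in plain Python (the helper
``decay_trace`` is hypothetical, times in ms), compared against forward Euler:

```python
import math


def decay_trace(k: float, t_from: float, t_to: float, tau: float) -> float:
    # Exact solution of dK/dt = -K / tau from t_from to t_to
    return k * math.exp((t_from - t_to) / tau)


tau = 20.0
k_exact = decay_trace(1.0, 0.0, tau, tau)  # one time constant later: exp(-1)

# Forward Euler with dt = 0.01 ms over the same 20 ms, for comparison
k_euler, dt = 1.0, 0.01
for _ in range(2000):
    k_euler += -k_euler / tau * dt

print(f'exact={k_exact:.6f}  euler={k_euler:.6f}')
```

The exact form is what makes event-driven evaluation possible: traces only need
to be computed at spike times.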
**Upon presynaptic spike at time** :math:`t_{\text{pre}}`:
Let :math:`d` denote the dendritic (synaptic) delay. The NEST ``stdp_synapse::send``
method performs the following sequence:
**Step 1: Facilitation (potentiation) from past postsynaptic spikes**
For each postsynaptic spike :math:`t_{\text{post}}` in the window
:math:`(t_{\text{last}} - d,\, t_{\text{pre}} - d]`, where
:math:`t_{\text{last}}` is the timestamp of the previous presynaptic spike:
.. math::
\Delta t = (t_{\text{post}} + d) - t_{\text{last}}
K^+_{\text{eff}} = K^+ \cdot e^{(t_{\text{last}} - (t_{\text{post}} + d)) / \tau_+}
\hat{w} \leftarrow \hat{w} + \lambda (1 - \hat{w})^{\mu_+} K^+_{\text{eff}}
where :math:`\hat{w} = w / W_{\max}` is the normalized weight.
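Step 1 can be sketched as a loop over archived postsynaptic spike times. This is
a simplified illustration, not the class's code path: it omits the ``_STDP_EPS``
tolerance used when comparing times, and the function names are hypothetical.

```python
import math


def facilitate(w, kplus, lam, mu_plus, wmax):
    # Normalized update w_hat += lam * (1 - w_hat)**mu_plus * K+, capped at 1
    w_hat = w / wmax + lam * math.pow(1.0 - w / wmax, mu_plus) * kplus
    return w_hat * wmax if w_hat < 1.0 else wmax


def step1_facilitation(w, kplus, t_last, t_pre, post_times, d,
                       lam=0.01, mu_plus=1.0, wmax=100.0, tau_plus=20.0):
    # Post spikes in the half-open window (t_last - d, t_pre - d]
    window = [t for t in post_times if t_last - d < t <= t_pre - d]
    for t_post in window:
        # K+ decayed from the previous pre spike to t_post + d
        kplus_eff = kplus * math.exp((t_last - (t_post + d)) / tau_plus)
        w = facilitate(w, kplus_eff, lam, mu_plus, wmax)
    return w


# Previous pre spike at 10 ms (K+ = 1 just after it), post spike at 14 ms,
# current pre spike at 30 ms, delay d = 1 ms.
w_new = step1_facilitation(w=50.0, kplus=1.0, t_last=10.0, t_pre=30.0,
                           post_times=[14.0], d=1.0)
print(w_new)
```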
**Step 2: Depression from current presynaptic spike**
Retrieve the postsynaptic trace :math:`K^-` at time :math:`t_{\text{pre}} - d`,
using only postsynaptic spikes strictly earlier than that time:
.. math::
K^-_{\text{eff}} = K^-(t_{\text{pre}} - d)
\hat{w} \leftarrow \hat{w} - \alpha \lambda \hat{w}^{\mu_-} K^-_{\text{eff}}
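Step 2 needs :math:`K^-` just before :math:`t_{\text{pre}} - d`. The sketch below
mirrors the strict-inequality lookup of ``_get_K_value`` on a list of
``(t_post, K^- after that spike)`` pairs; the function names and the list layout
are illustrative assumptions.

```python
import math


def k_minus_before(t, post_hist, tau_minus, eps=1e-6):
    # Latest post spike strictly earlier than t, with its trace decayed to t
    for t_post, k_after in reversed(post_hist):
        if t - t_post > eps:
            return k_after * math.exp((t_post - t) / tau_minus)
    return 0.0


def depress(w, kminus, lam, alpha, mu_minus, wmax):
    # Normalized update w_hat -= alpha * lam * w_hat**mu_minus * K-, floored at 0
    w_hat = w / wmax - alpha * lam * math.pow(w / wmax, mu_minus) * kminus
    return w_hat * wmax if w_hat > 0.0 else 0.0


# One post spike at 20 ms; pre spike at 26 ms with delay d = 1 ms
hist = [(20.0, 1.0)]
k_eff = k_minus_before(26.0 - 1.0, hist, tau_minus=20.0)
w_new = depress(50.0, k_eff, lam=0.01, alpha=1.0, mu_minus=1.0, wmax=100.0)
print(w_new)
```

A postsynaptic spike exactly at :math:`t_{\text{pre}} - d` is excluded, so it
cannot depress the synapse on the same presynaptic spike.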
**Step 3: Deliver spike event**
Send event with updated weight ``w`` to the postsynaptic receiver.
**Step 4: Update presynaptic trace**
.. math::
K^+ \leftarrow K^+ \cdot e^{(t_{\text{last}} - t_{\text{pre}}) / \tau_+} + 1
t_{\text{last}} \leftarrow t_{\text{pre}}
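Step 4 in isolation: successive presynaptic spikes accumulate into :math:`K^+`,
and spikes sharing a time stamp each add 1 without decay. A minimal sketch with a
hypothetical helper name:

```python
import math


def update_kplus(kplus, t_last, t_pre, tau_plus):
    # Decay K+ from the previous pre spike to the current one, then add 1
    return kplus * math.exp((t_last - t_pre) / tau_plus) + 1.0


kplus, t_last = 0.0, 0.0
for t_pre in [10.0, 30.0, 30.0]:  # the last entry repeats the same time stamp
    kplus = update_kplus(kplus, t_last, t_pre, tau_plus=20.0)
    t_last = t_pre
print(kplus, t_last)
```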
**Upon postsynaptic spike at time** :math:`t_{\text{post}}`:
.. math::
K^- \leftarrow K^- \cdot e^{(t_{\text{last\_post}} - t_{\text{post}}) / \tau_-} + 1
t_{\text{last\_post}} \leftarrow t_{\text{post}}
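The postsynaptic side can be sketched as a small archive that stores, for every
spike, its time and the value of :math:`K^-` just after it, which is what Steps 1
and 2 later query. The class name is hypothetical; the initial ``-1.0`` time
stamp follows this model's reset state.

```python
import math


class PostTraceArchive:
    # Hypothetical stand-in for the synapse's internal postsynaptic archive
    def __init__(self, tau_minus: float):
        self.tau_minus = tau_minus
        self.k_minus = 0.0
        self.t_last_post = -1.0  # initial stamp, as after clear_post_history()
        self.history = []        # list of (t_post, K- just after the spike)

    def record(self, t_post: float) -> None:
        decay = math.exp((self.t_last_post - t_post) / self.tau_minus)
        self.k_minus = self.k_minus * decay + 1.0
        self.t_last_post = t_post
        self.history.append((t_post, self.k_minus))


archive = PostTraceArchive(tau_minus=20.0)
for t in (5.0, 25.0):
    archive.record(t)
print(archive.history)
```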
**Weight Update Functions:**
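Potentiation (pre-before-post: a postsynaptic spike that follows the previous
presynaptic spike, :math:`\Delta t > 0` in Step 1):
.. math::
\hat{w} \leftarrow \hat{w} + \lambda (1 - \hat{w})^{\mu_+} K^+
Depression (post-before-pre: postsynaptic spikes that precede the current
presynaptic spike, summarized by :math:`K^-` in Step 2):
.. math::
\hat{w} \leftarrow \hat{w} - \alpha \lambda \hat{w}^{\mu_-} K^-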
Final weight is clipped to :math:`[0, W_{\max}]` for positive weights, or
:math:`[W_{\max}, 0]` for negative weights (inhibitory synapses).
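Both updates act on the normalized weight :math:`\hat{w} = w / W_{\max}`, so the
same arithmetic clips excitatory weights to :math:`[0, W_{\max}]` and inhibitory
weights to :math:`[W_{\max}, 0]`. A sketch of the two update functions, written
to follow ``_facilitate`` and ``_depress``, driven with extreme traces to expose
the bounds:

```python
import math


def facilitate(w, kplus, lam, mu_plus, wmax):
    w_hat = w / wmax + lam * math.pow(1.0 - w / wmax, mu_plus) * kplus
    return w_hat * wmax if w_hat < 1.0 else wmax


def depress(w, kminus, lam, alpha, mu_minus, wmax):
    w_hat = w / wmax - alpha * lam * math.pow(w / wmax, mu_minus) * kminus
    return w_hat * wmax if w_hat > 0.0 else 0.0


# Excitatory: a huge K+ saturates at Wmax, a huge K- floors at 0
print(facilitate(90.0, 1000.0, 0.01, 0.0, 100.0))    # 100.0
print(depress(10.0, 1000.0, 0.01, 1.0, 0.0, 100.0))   # 0.0

# Inhibitory: the normalized arithmetic clips to [Wmax, 0]
print(facilitate(-90.0, 1000.0, 0.01, 0.0, -100.0))  # -100.0
```

For an inhibitory synapse, potentiation moves the weight toward the more
negative bound, i.e. it strengthens inhibition.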
**2. Update Ordering and NEST Compatibility**
This implementation replicates the exact update sequence from NEST
``models/stdp_synapse.h::send()``:
1. Query postsynaptic spike history in window :math:`(t_{\text{last}} - d,\, t_{\text{pre}} - d]`
2. Apply facilitation for each retrieved postsynaptic spike
3. Compute postsynaptic trace :math:`K^-` at :math:`t_{\text{pre}} - d`
4. Apply depression based on :math:`K^-`
5. Schedule weighted spike event for delivery after delay :math:`d`
6. Update presynaptic trace :math:`K^+` and timestamp :math:`t_{\text{last}}`
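Putting the six steps together, one presynaptic spike can be processed as
follows. This is a plain-Python sketch under simplifying assumptions (no
``_STDP_EPS`` tolerance, event delivery omitted, state held in a dict);
``stdp_send`` and the dict keys are hypothetical, not this class's API.

```python
import math


def stdp_send(state, t_pre, d, p):
    # state: {'w', 'kplus', 't_last', 'post'}; 'post' holds (t_post, K- after spike)
    # p: lam, alpha, mu_plus, mu_minus, wmax, tau_plus, tau_minus
    w, kplus, t_last = state['w'], state['kplus'], state['t_last']
    wmax = p['wmax']
    # 1-2. Facilitation for post spikes in (t_last - d, t_pre - d]
    for t_post, _ in state['post']:
        if t_last - d < t_post <= t_pre - d:
            k_eff = kplus * math.exp((t_last - (t_post + d)) / p['tau_plus'])
            w_hat = w / wmax + p['lam'] * math.pow(1.0 - w / wmax, p['mu_plus']) * k_eff
            w = w_hat * wmax if w_hat < 1.0 else wmax
    # 3. K- from the latest post spike strictly before t_pre - d
    k_minus = 0.0
    for t_post, k_after in reversed(state['post']):
        if t_post < t_pre - d:
            k_minus = k_after * math.exp((t_post - (t_pre - d)) / p['tau_minus'])
            break
    # 4. Depression
    w_hat = w / wmax - p['alpha'] * p['lam'] * math.pow(w / wmax, p['mu_minus']) * k_minus
    w = w_hat * wmax if w_hat > 0.0 else 0.0
    # 5. An event with weight w would be delivered after delay d (omitted)
    # 6. Presynaptic trace and time stamp
    state['kplus'] = kplus * math.exp((t_last - t_pre) / p['tau_plus']) + 1.0
    state['t_last'] = t_pre
    state['w'] = w
    return w


params = dict(lam=0.01, alpha=1.0, mu_plus=1.0, mu_minus=1.0,
              wmax=100.0, tau_plus=20.0, tau_minus=20.0)
state = {'w': 50.0, 'kplus': 0.0, 't_last': 0.0, 'post': []}
stdp_send(state, t_pre=10.0, d=1.0, p=params)  # no post spikes yet: w stays 50
state['post'].append((14.0, 1.0))              # post spike at 14 ms, K- = 1 after it
w_after = stdp_send(state, t_pre=30.0, d=1.0, p=params)
print(w_after)
```

Here the postsynaptic spike at 14 ms follows the presynaptic spike at 10 ms and
potentiates; it also precedes the presynaptic spike at 30 ms and contributes a
smaller depression through :math:`K^-`, giving a small net increase.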
**3. Event Timing Semantics**
As in NEST, this model uses **on-grid spike timestamps** and ignores precise
sub-step offsets. Spike times are discretized to simulation time steps:
- Presynaptic spike detected at step ``n`` → stamped at :math:`t_{\text{spike}} = t + dt`
- Postsynaptic spike recorded at step ``n`` → stamped at :math:`t_{\text{spike}} = t + dt`
- Inter-spike intervals computed from discrete timestamps
This differs from continuous-time STDP but matches NEST's default behavior.
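A sketch of the on-grid convention, assuming the simulation time at step ``n`` is
``t = n * dt`` (the helper name is hypothetical): every spike detected during a
step receives the step's end time, so intervals between stamps are whole
multiples of ``dt`` up to floating-point rounding.

```python
def on_grid_stamp(step: int, dt: float) -> float:
    # Spike detected during step n (t = n * dt) is stamped at t + dt
    return (step + 1) * dt


dt = 0.1  # ms
stamps = [on_grid_stamp(n, dt) for n in (0, 49, 99)]
isi = stamps[2] - stamps[1]  # inter-spike interval from discrete stamps
n_steps = isi / dt
print(stamps, isi)
```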
**4. Stability Constraints and Computational Implications**
**Parameter Constraints:**
- :math:`\tau_+ > 0`, :math:`\tau_- > 0` (time constants must be positive)
- :math:`\lambda > 0` (learning rate; negative values would invert plasticity)
- :math:`\alpha \geq 0` (depression scaling; typically 1.0)
- :math:`\mu_+ \geq 0`, :math:`\mu_- \geq 0` (exponents; 1.0 for linear, 0.0 for additive)
- :math:`W_{\max} \neq 0` and :math:`\text{sign}(w) = \text{sign}(W_{\max})`, where zero counts as positive (NEST convention); only the sign condition is validated
- :math:`K^+ \geq 0` (trace must be non-negative)
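The constraints above can be gathered into one check. In this sketch, the sign
rule (zero counts as positive, as in ``_nest_sign``) and the non-negative
``Kplus`` mirror what the class validates; the ``Wmax != 0`` and positive
time-constant checks are added assumptions, and ``check_params`` is hypothetical.

```python
def nest_sign(x: float) -> int:
    # NEST convention ((x >= 0) - (x < 0)): zero counts as positive
    return int(x >= 0.0) - int(x < 0.0)


def check_params(weight, Wmax, Kplus, tau_plus, tau_minus):
    if nest_sign(weight) != nest_sign(Wmax):
        raise ValueError('Weight and Wmax must have same sign.')
    if Kplus < 0.0:
        raise ValueError('Kplus must be non-negative.')
    # Assumed extra checks (not shown to be enforced by the class):
    if Wmax == 0.0:
        raise ValueError('Wmax must be non-zero (weights are normalized by Wmax).')
    if tau_plus <= 0.0 or tau_minus <= 0.0:
        raise ValueError('tau_plus and tau_minus must be positive.')


check_params(0.0, 100.0, 0.0, 20.0, 20.0)  # zero weight counts as positive
check_params(-1.0, -5.0, 0.0, 20.0, 20.0)  # inhibitory synapse
try:
    check_params(1.0, -5.0, 0.0, 20.0, 20.0)  # mixed signs
    mixed_rejected = False
except ValueError:
    mixed_rejected = True
print(mixed_rejected)
```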
Numerical Considerations
------------------------
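- Parameters are parsed at the precision of ``brainstate.environ.dftype()`` and
stored as Python ``float`` (double precision)
- Exponential decays are evaluated exactly between spike times with ``math.exp()``
- Power functions use ``math.pow()``, which raises ``ValueError`` for a negative
base with a non-integer exponent (e.g. when the initial ``weight`` exceeds
``Wmax`` in magnitude, so that :math:`1 - \hat{w} < 0`)
- Per-spike cost: :math:`O(H)`, where :math:`H` is the number of stored
postsynaptic spikes; the facilitation-window query scans the entire history
list, not only the spikes inside the window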
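A short demonstration of the ``math.pow()`` domain issue: the sign check alone
allows an initial weight larger than ``Wmax``, which makes the potentiation base
:math:`1 - \hat{w}` negative. With a non-integer :math:`\mu_+` this raises; with
an integer exponent it does not raise but flips the sign of the potentiation
term. Clamping :math:`\hat{w}` into :math:`[0, 1]` first is one possible guard,
shown here as an illustration rather than something the class does.

```python
import math

wmax, mu_plus = 100.0, 0.5
w = 120.0  # same sign as Wmax, but larger in magnitude

try:
    math.pow(1.0 - w / wmax, mu_plus)  # negative base, non-integer exponent
    raised = False
except ValueError:  # 'math domain error'
    raised = True

# Integer exponent: no exception, but the potentiation term is negative
term_mu1 = math.pow(1.0 - w / wmax, 1.0)

# Illustrative guard: clamp the normalized weight before taking powers
w_hat = min(max(w / wmax, 0.0), 1.0)
factor = math.pow(1.0 - w_hat, mu_plus)  # 0.0: no potentiation at the bound
print(raised, term_mu1, factor)
```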
**Behavioral Regimes:**
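- **Additive STDP** (:math:`\mu_+ = \mu_- = 0`):
Weight-independent updates, kept in range only by the hard bounds at 0 and
:math:`W_{\max}` (Song et al., 2000)
- **Multiplicative STDP** (:math:`\mu_+ = \mu_- = 1`, the default):
Soft bounds; updates shrink as the weight approaches a bound, which
stabilizes weight distributions (Guetig et al., 2003)
- **Mixed STDP** (:math:`\mu_+ = 0`, :math:`\mu_- = 1`):
Weight-independent potentiation with weight-proportional depression,
as in van Rossum et al. (2000)
- **Intermediate exponents** (:math:`0 < \mu_\pm < 1`):
Interpolate between the additive and multiplicative regimes (Guetig et al., 2003)
- **Asymmetric depression** (:math:`\alpha > 1`):
Stronger depression relative to potentiation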
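The regimes differ only in how the step size depends on the normalized weight.
A sketch of the two normalized increments (hypothetical helper names, unit
traces, :math:`\lambda = 0.01`): with :math:`\mu_\pm = 0` both steps are constant,
while with :math:`\mu_\pm = 1` and equal traces they balance where
:math:`\lambda (1 - \hat{w}) = \alpha \lambda \hat{w}`, i.e. at
:math:`\hat{w} = 1/(1 + \alpha)`.

```python
import math


def dw_plus(w_hat, mu_plus, lam=0.01, kplus=1.0):
    # Normalized potentiation step: lam * (1 - w_hat)**mu_plus * K+
    return lam * math.pow(1.0 - w_hat, mu_plus) * kplus


def dw_minus(w_hat, mu_minus, lam=0.01, alpha=1.0, kminus=1.0):
    # Normalized depression step: alpha * lam * w_hat**mu_minus * K-
    return alpha * lam * math.pow(w_hat, mu_minus) * kminus


for w_hat in (0.1, 0.5, 0.9):
    additive = (dw_plus(w_hat, 0.0), dw_minus(w_hat, 0.0))
    multiplicative = (dw_plus(w_hat, 1.0), dw_minus(w_hat, 1.0))
    print(f'w_hat={w_hat}: additive={additive}, multiplicative={multiplicative}')

# Balance point of the multiplicative rule for alpha = 1.5
alpha = 1.5
w_star = 1.0 / (1.0 + alpha)
```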
Failure Modes
-------------
- ``weight`` and ``Wmax`` must have the same sign; otherwise ``ValueError`` on init or set
- ``Kplus`` must be non-negative; otherwise ``ValueError`` on init or set
- Postsynaptic spike history grows unbounded if not cleared; use
``clear_post_history()`` periodically for long simulations
Parameters
----------
weight : float, array-like, or Quantity, optional
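Initial synaptic weight :math:`w`, dimensionless or with receiver-specific
units (e.g., pA, nS). A Quantity is reduced to its magnitude in its own unit and
stored as a plain float, so give ``weight`` and ``Wmax`` in the same unit.
Must have the same sign as ``Wmax``. Default: ``1.0`` (dimensionless).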
delay : float, array-like, or Quantity, optional
Synaptic transmission delay :math:`d` in milliseconds. Must be ``> 0``.
Quantized to integer time steps per :class:`static_synapse` conventions.
Default: ``1.0 * u.ms``.
receptor_type : int, optional
Receptor port identifier on the postsynaptic neuron. Non-negative integer
specifying which input channel receives the event. Default: ``0``.
tau_plus : float, array-like, or Quantity, optional
Presynaptic trace time constant :math:`\tau_+` in milliseconds. Must be ``> 0``.
Controls the width of the potentiation (pre-before-post) window.
Default: ``20.0 * u.ms``.
tau_minus : float, array-like, or Quantity, optional
Postsynaptic trace time constant :math:`\tau_-` in milliseconds. Must be ``> 0``.
In NEST, this is a postsynaptic neuron parameter; here it is stored on the
synapse for standalone compatibility. Controls the width of the depression
(post-before-pre) window. Default: ``20.0 * u.ms``.
lambda_ : float, array-like, or Quantity, optional
Learning rate parameter :math:`\lambda` (dimensionless). Scales both
potentiation and depression updates. Typical values: 0.001–0.1.
Default: ``0.01``.
alpha : float, array-like, or Quantity, optional
Asymmetry parameter :math:`\alpha` (dimensionless). Scales depression
relative to potentiation. :math:`\alpha = 1.0` applies the same learning rate
:math:`\lambda` to both; :math:`\alpha > 1.0` strengthens depression. Default: ``1.0``.
mu_plus : float, array-like, or Quantity, optional
Potentiation exponent :math:`\mu_+` (dimensionless). Controls weight
dependence of potentiation. :math:`\mu_+ = 0`: additive; :math:`\mu_+ = 1`:
multiplicative (soft upper bound). Default: ``1.0``.
mu_minus : float, array-like, or Quantity, optional
Depression exponent :math:`\mu_-` (dimensionless). Controls weight
dependence of depression. :math:`\mu_- = 0`: additive; :math:`\mu_- = 1`:
multiplicative (soft lower bound). Default: ``1.0``.
Wmax : float, array-like, or Quantity, optional
Maximum weight bound :math:`W_{\max}` (same units as ``weight``). Weights are
clipped to :math:`[0, W_{\max}]` for excitatory synapses or :math:`[W_{\max}, 0]`
for inhibitory synapses. Must have the same sign as ``weight``.
Default: ``100.0`` (dimensionless).
Kplus : float, array-like, or Quantity, optional
Initial presynaptic trace value :math:`K^+` (dimensionless). Must be
non-negative. Typically initialized to ``0.0`` (no presynaptic history).
Default: ``0.0``.
post : Dynamics, optional
Default postsynaptic receiver object. If provided, :meth:`send` and
:meth:`update` will target this receiver unless overridden. Must implement
either ``add_delta_input`` or ``add_current_input`` methods.
Default: ``None`` (must provide receiver explicitly in method calls).
name : str, optional
Unique identifier for this synapse instance. Default: auto-generated.
Parameter Mapping
-----------------
NEST ``stdp_synapse`` parameters map to this implementation as follows:
================== ==================== ========================================
NEST Parameter     brainpy.state Param  Notes
================== ==================== ========================================
``weight``         ``weight``           Plastic, updated on each pre-spike
``delay``          ``delay``            Converted to ms, discretized to steps
``receptor_type``  ``receptor_type``    Integer ≥ 0
``tau_plus``       ``tau_plus``         Presynaptic trace decay (ms)
``tau_minus``      (neuron param)       Here: synapse param ``tau_minus`` (ms)
``lambda``         ``lambda_``          Learning rate (underscore to avoid keyword)
``alpha``          ``alpha``            Depression asymmetry factor
``mu_plus``        ``mu_plus``          Potentiation exponent
``mu_minus``       ``mu_minus``         Depression exponent
``Wmax``           ``Wmax``             Weight upper bound (or lower for inhib.)
``Kplus``          ``Kplus``            Presynaptic trace state variable
================== ==================== ========================================
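The renaming in the table can be applied mechanically when porting a NEST status
dictionary. A hypothetical helper (not part of this module) that maps ``lambda``
to ``lambda_`` and drops the read-only ``synapse_model`` key; NEST's
``tau_minus`` would still have to be read from the postsynaptic neuron, and
whether a bare float ``delay`` is interpreted in ms here is an assumption to
verify.

```python
# Hypothetical helper: NEST status keys -> keyword names used by this class
NEST_TO_KWARG = {'lambda': 'lambda_'}
READ_ONLY_KEYS = {'synapse_model'}


def nest_status_to_kwargs(status: dict) -> dict:
    kwargs = {}
    for key, value in status.items():
        if key in READ_ONLY_KEYS:
            continue  # model identifier, not a settable parameter
        kwargs[NEST_TO_KWARG.get(key, key)] = value
    return kwargs


status = {'weight': 0.5, 'delay': 1.0, 'lambda': 0.01, 'tau_plus': 20.0,
          'synapse_model': 'stdp_synapse'}
kwargs = nest_status_to_kwargs(status)
print(kwargs)
```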
Attributes
----------
weight : float
Current synaptic weight (plastic, updated during simulation).
Kplus : float
Current presynaptic trace value.
t_lastspike : float
Timestamp (ms) of the most recent presynaptic spike.
tau_plus : float
Presynaptic trace time constant (ms).
tau_minus : float
Postsynaptic trace time constant (ms).
lambda_ : float
Learning rate.
alpha : float
Depression asymmetry factor.
mu_plus : float
Potentiation exponent.
mu_minus : float
Depression exponent.
Wmax : float
Maximum weight bound.
See Also
--------
static_synapse : Base class for non-plastic synapses
tsodyks_synapse : Short-term plasticity (depression/facilitation)
stdp_synapse_hom : Homogeneous-weight variant with shared weight across connections
Notes
-----
- The model transmits spike-like events only (``event_type='spike'``).
- ``update(pre_spike=..., post_spike=...)`` accepts both presynaptic and
postsynaptic spike multiplicities for standalone STDP simulation.
- ``record_post_spike(...)`` can be used to manually feed postsynaptic spikes
when the postsynaptic model does not expose NEST archiving APIs.
- Postsynaptic spike history grows unbounded; call ``clear_post_history()``
periodically in long simulations to prevent memory issues.
References
----------
.. [1] NEST source code: ``models/stdp_synapse.h`` and ``models/stdp_synapse.cpp``.
https://github.com/nest/nest-simulator
.. [2] Guetig R, Aharonov R, Rotter S, Sompolinsky H (2003).
Learning input correlations through nonlinear temporally asymmetric
Hebbian plasticity. *Journal of Neuroscience*, 23(9):3697-3714.
DOI: `10.1523/JNEUROSCI.23-09-03697.2003 <https://doi.org/10.1523/JNEUROSCI.23-09-03697.2003>`_
.. [3] Song S, Miller KD, Abbott LF (2000). Competitive Hebbian learning
through spike-timing-dependent synaptic plasticity.
*Nature Neuroscience*, 3(9):919-926.
DOI: `10.1038/78829 <https://doi.org/10.1038/78829>`_
.. [4] van Rossum MCW, Bi G-Q, Turrigiano GG (2000). Stable Hebbian learning
from spike timing-dependent plasticity.
*Journal of Neuroscience*, 20(23):8812-8821.
DOI: `10.1523/JNEUROSCI.20-23-08812.2000 <https://doi.org/10.1523/JNEUROSCI.20-23-08812.2000>`_
Examples
--------
**Basic STDP synapse with default parameters:**
.. code-block:: python
>>> import brainpy.state as bst
>>> import saiunit as u
>>> syn = bst.stdp_synapse(weight=0.5, delay=1.0 * u.ms)
>>> syn.get()
{'weight': 0.5, 'delay': 1.0, 'receptor_type': 0, 'tau_plus': 20.0,
'tau_minus': 20.0, 'lambda': 0.01, 'alpha': 1.0, 'mu_plus': 1.0,
'mu_minus': 1.0, 'Wmax': 100.0, 'Kplus': 0.0, 'synapse_model': 'stdp_synapse'}
**Asymmetric STDP (stronger depression):**
.. code-block:: python
>>> syn = bst.stdp_synapse(
... weight=1.0,
... tau_plus=16.8 * u.ms,
... tau_minus=33.7 * u.ms,
... lambda_=0.005,
... alpha=1.05, # 5% stronger depression
... Wmax=2.0,
... )
**Additive STDP (weight-independent updates):**
.. code-block:: python
>>> syn = bst.stdp_synapse(
... weight=0.5,
... mu_plus=0.0, # additive potentiation
... mu_minus=0.0, # additive depression
... lambda_=0.001,
... Wmax=1.0,
... )
**Manual postsynaptic spike recording:**
.. code-block:: python
>>> import brainstate
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     post_neuron = bst.LIF(1)
...     syn = bst.stdp_synapse(weight=1.0, post=post_neuron)
...     syn.init_state()
...     post_neuron.init_state()
...     # Record a postsynaptic spike at t = 5.0 ms
...     _ = syn.record_post_spike(multiplicity=1, t_spike_ms=5.0)
...     # Send a presynaptic spike, stamped on-grid at t + dt. Only postsynaptic
...     # spikes in (t_last - d, t_pre - d] potentiate; postsynaptic spikes
...     # strictly before t_pre - d depress the weight through K^-.
...     _ = syn.send(multiplicity=1)
...     print(f"Updated weight: {syn.weight:.6f}")
"""
__module__ = 'brainpy.state'
def __init__(
self,
weight: ArrayLike = 1.0,
delay: ArrayLike = 1.0 * u.ms,
receptor_type: int = 0,
tau_plus: ArrayLike = 20.0 * u.ms,
tau_minus: ArrayLike = 20.0 * u.ms,
lambda_: ArrayLike = 0.01,
alpha: ArrayLike = 1.0,
mu_plus: ArrayLike = 1.0,
mu_minus: ArrayLike = 1.0,
Wmax: ArrayLike = 100.0,
Kplus: ArrayLike = 0.0,
post=None,
name: str | None = None,
):
weight_value = self._to_scalar_float(weight, name='weight')
super().__init__(
weight=weight_value,
delay=delay,
receptor_type=receptor_type,
post=post,
event_type='spike',
name=name,
)
self.tau_plus = self._to_scalar_time_ms(tau_plus, name='tau_plus')
self.tau_minus = self._to_scalar_time_ms(tau_minus, name='tau_minus')
self.lambda_ = self._to_scalar_float(lambda_, name='lambda')
self.alpha = self._to_scalar_float(alpha, name='alpha')
self.mu_plus = self._to_scalar_float(mu_plus, name='mu_plus')
self.mu_minus = self._to_scalar_float(mu_minus, name='mu_minus')
self.Wmax = self._to_scalar_float(Wmax, name='Wmax')
self.Kplus = self._to_scalar_float(Kplus, name='Kplus')
self._validate_weight_wmax_sign(weight_value, self.Wmax)
self._validate_non_negative(self.Kplus, name='Kplus')
self._Kplus0 = float(self.Kplus)
self._t_lastspike0 = 0.0
self.t_lastspike = float(self._t_lastspike0)
self._post_kminus = 0.0
self._last_post_spike = -1.0
self._post_hist_t: list[float] = []
self._post_hist_kminus: list[float] = []
@staticmethod
def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
dftype = brainstate.environ.dftype()
if isinstance(value, u.Quantity):
unit = u.get_unit(value)
arr = np.asarray(value.to_decimal(unit), dtype=dftype)
else:
arr = np.asarray(u.math.asarray(value), dtype=dftype)
if arr.size != 1:
raise ValueError(f'{name} must be scalar.')
v = float(arr.reshape(()))
if not np.isfinite(v):
raise ValueError(f'{name} must be finite.')
return v
@staticmethod
def _nest_sign(value: float) -> int:
# Matches NEST set_status sign check:
# ((x >= 0) - (x < 0)), so zero counts as positive.
return int(value >= 0.0) - int(value < 0.0)
@staticmethod
def _validate_non_negative(value: float, *, name: str):
if value < 0.0:
raise ValueError(f'{name} must be non-negative.')
@classmethod
def _validate_weight_wmax_sign(cls, weight: float, Wmax: float):
if cls._nest_sign(weight) != cls._nest_sign(Wmax):
raise ValueError('Weight and Wmax must have same sign.')
@staticmethod
def _to_non_negative_int_count(value: ArrayLike, *, name: str) -> int:
dftype = brainstate.environ.dftype()
arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
if arr.size != 1:
raise ValueError(f'{name} must be scalar.')
v = float(arr.reshape(()))
if not np.isfinite(v):
raise ValueError(f'{name} must be finite.')
if v < 0.0:
raise ValueError(f'{name} must be non-negative.')
rounded = int(round(v))
if not math.isclose(v, float(rounded), rel_tol=0.0, abs_tol=1e-12):
raise ValueError(f'{name} must be an integer spike count.')
return rounded
def _facilitate(self, w: float, kplus: float) -> float:
norm_w = (w / self.Wmax) + (self.lambda_ * math.pow(1.0 - (w / self.Wmax), self.mu_plus) * kplus)
return norm_w * self.Wmax if norm_w < 1.0 else self.Wmax
def _depress(self, w: float, kminus: float) -> float:
norm_w = (w / self.Wmax) - (self.alpha * self.lambda_ * math.pow(w / self.Wmax, self.mu_minus) * kminus)
return norm_w * self.Wmax if norm_w > 0.0 else 0.0
def clear_post_history(self):
r"""Clear internal postsynaptic spike history and reset trace state.
Resets all postsynaptic STDP state to initial conditions:
- Clears spike history buffer (timestamps and trace values)
- Resets postsynaptic trace ``K^-`` to zero
- Resets last postsynaptic spike timestamp to ``-1.0``
This method should be called periodically in long simulations to prevent
unbounded growth of the spike history buffer. Typical usage: clear history
at the start of each trial or after weight convergence phases.
See Also
--------
init_state : Reinitialize all synapse state including weights and traces
record_post_spike : Record postsynaptic spikes into the history buffer
"""
self._post_kminus = 0.0
self._last_post_spike = -1.0
self._post_hist_t = []
self._post_hist_kminus = []
def _record_post_spike_at(self, t_spike_ms: float):
self._post_kminus = (
self._post_kminus * math.exp((self._last_post_spike - t_spike_ms) / self.tau_minus) + 1.0
)
self._last_post_spike = float(t_spike_ms)
self._post_hist_t.append(float(t_spike_ms))
self._post_hist_kminus.append(float(self._post_kminus))
def record_post_spike(
self,
multiplicity: ArrayLike = 1.0,
*,
t_spike_ms: ArrayLike | None = None,
) -> int:
r"""Record postsynaptic spikes into the internal STDP history buffer.
This method updates the postsynaptic eligibility trace :math:`K^-` and stores
the spike timestamp for later use by :meth:`send` when processing presynaptic
spikes. Each recorded spike increments :math:`K^-` by 1.0 after exponential
decay from the previous postsynaptic spike.
The trace update follows:
.. math::
K^- \leftarrow K^- \cdot e^{(t_{\text{last\_post}} - t_{\text{spike}}) / \tau_-} + 1
where :math:`t_{\text{last\_post}}` is the timestamp of the previous postsynaptic
spike and :math:`\tau_-` is the postsynaptic trace time constant.
Multiple spikes can be recorded by setting ``multiplicity > 1``. This is
equivalent to calling the method ``multiplicity`` times at the same timestamp.
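The equivalence can be checked with a plain-Python sketch of the update
(hypothetical helper; the ``-1.0`` initial time stamp matches this model's reset
state): the first spike decays :math:`K^-` from the previous stamp, and each
further spike at the same stamp adds exactly 1 because :math:`e^{0} = 1`.

```python
import math


def record_post(k_minus, t_last_post, t_spike, tau_minus, multiplicity=1):
    # Apply the K- update 'multiplicity' times at one time stamp
    for _ in range(multiplicity):
        k_minus = k_minus * math.exp((t_last_post - t_spike) / tau_minus) + 1.0
        t_last_post = t_spike
    return k_minus, t_last_post


k, t_last = record_post(0.0, -1.0, 5.0, tau_minus=20.0, multiplicity=3)
print(k, t_last)  # 3.0 5.0
k_later, _ = record_post(k, t_last, 25.0, tau_minus=20.0)
```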
Parameters
----------
multiplicity : float, array-like, or Quantity, optional
Number of postsynaptic spikes to record at this timestamp. Must be a
non-negative integer-valued scalar (fractional values will be rejected).
Use ``multiplicity=0`` to skip recording (returns immediately).
Default: ``1`` (single spike).
t_spike_ms : float, array-like, or Quantity, optional
Spike timestamp in milliseconds. Must be a scalar float with or without
time units. If ``None``, uses the current on-grid spike stamp
:math:`t + dt` where :math:`t` is the current simulation time and
:math:`dt` is the simulation time step. Default: ``None`` (on-grid).
Returns
-------
int
Number of spikes successfully recorded (equal to ``multiplicity``).
Raises
------
ValueError
- If ``multiplicity`` is negative or not integer-valued.
- If ``t_spike_ms`` is not a finite scalar.
See Also
--------
clear_post_history : Clear all postsynaptic spike history
send : Process presynaptic spike and apply STDP weight updates
Examples
--------
**Record single postsynaptic spike at current time:**
.. code-block:: python
>>> import brainstate
>>> import brainpy.state as bst
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     syn = bst.stdp_synapse(weight=1.0)
...     syn.init_state()
...     count = syn.record_post_spike()  # uses t + dt
...     print(count)
1
**Record postsynaptic spike at explicit timestamp:**
.. code-block:: python
>>> syn.record_post_spike(multiplicity=1, t_spike_ms=5.0)
1
**Record burst of 3 postsynaptic spikes:**
.. code-block:: python
>>> syn.record_post_spike(multiplicity=3)
3
"""
count = self._to_non_negative_int_count(multiplicity, name='post_spike')
if count == 0:
return 0
if t_spike_ms is None:
dt_ms = self._refresh_delay_if_needed()
t_value = self._current_time_ms() + dt_ms
else:
t_value = self._to_scalar_float(t_spike_ms, name='t_spike_ms')
for _ in range(count):
self._record_post_spike_at(float(t_value))
return count
def _get_post_history_times(self, t1_ms: float, t2_ms: float) -> list[float]:
t1_lim = float(t1_ms + _STDP_EPS)
t2_lim = float(t2_ms + _STDP_EPS)
selected = []
for t_post in self._post_hist_t:
if t_post >= t1_lim and t_post < t2_lim:
selected.append(t_post)
return selected
def _get_K_value(self, t_ms: float) -> float:
# Return trace strictly before t, matching ArchivingNode::get_K_value.
for idx in range(len(self._post_hist_t) - 1, -1, -1):
t_post = self._post_hist_t[idx]
if (t_ms - t_post) > _STDP_EPS:
return self._post_hist_kminus[idx] * math.exp((t_post - t_ms) / self.tau_minus)
return 0.0
def init_state(self, batch_size: int | None = None, **kwargs):
r"""Initialize synapse state for simulation.
Resets all dynamic state variables to their initial values:
- ``Kplus``: presynaptic trace → initial value (default ``0.0``)
- ``t_lastspike``: last presynaptic spike time → ``0.0`` ms
- Postsynaptic spike history and trace → cleared
This method should be called before starting a new simulation or trial.
Inherits delay queue initialization from :class:`static_synapse`.
Parameters
----------
batch_size : int, optional
Ignored (provided for API compatibility with batched models).
**kwargs
Ignored (provided for API compatibility).
See Also
--------
clear_post_history : Clear only postsynaptic history without resetting other state
set : Update parameters without reinitializing state
"""
del batch_size, kwargs
super().init_state()
self.Kplus = float(self._Kplus0)
self.t_lastspike = float(self._t_lastspike0)
self.clear_post_history()
def get(self) -> dict:
r"""Return current public parameters and mutable state.
Retrieves all NEST-compatible parameters and dynamic state variables in a
dictionary format suitable for inspection, logging, or serialization. Includes
both inherited parameters from :class:`static_synapse` (``weight``, ``delay``,
``receptor_type``) and STDP-specific parameters.
Returns
-------
dict
Dictionary containing:
- ``'weight'`` : float – Current synaptic weight (plastic)
- ``'delay'`` : float – Transmission delay (ms)
- ``'receptor_type'`` : int – Postsynaptic receptor port
- ``'tau_plus'`` : float – Presynaptic trace time constant (ms)
- ``'tau_minus'`` : float – Postsynaptic trace time constant (ms)
- ``'lambda'`` : float – Learning rate (key name without underscore)
- ``'alpha'`` : float – Depression asymmetry factor
- ``'mu_plus'`` : float – Potentiation exponent
- ``'mu_minus'`` : float – Depression exponent
- ``'Wmax'`` : float – Maximum weight bound
- ``'Kplus'`` : float – Current presynaptic trace value
- ``'synapse_model'`` : str – Always ``'stdp_synapse'`` (NEST identifier)
See Also
--------
set : Update parameters and state
init_state : Reinitialize state to defaults
Examples
--------
.. code-block:: python
>>> syn = bst.stdp_synapse(weight=0.5, lambda_=0.01)
>>> params = syn.get()
>>> params['weight']
0.5
>>> params['lambda']
0.01
>>> params['synapse_model']
'stdp_synapse'
"""
params = super().get()
params['tau_plus'] = float(self.tau_plus)
params['tau_minus'] = float(self.tau_minus)
params['lambda'] = float(self.lambda_)
params['alpha'] = float(self.alpha)
params['mu_plus'] = float(self.mu_plus)
params['mu_minus'] = float(self.mu_minus)
params['Wmax'] = float(self.Wmax)
params['Kplus'] = float(self.Kplus)
params['synapse_model'] = 'stdp_synapse'
return params
def set(
self,
*,
weight: ArrayLike | object = _UNSET,
delay: ArrayLike | object = _UNSET,
receptor_type: ArrayLike | object = _UNSET,
tau_plus: ArrayLike | object = _UNSET,
tau_minus: ArrayLike | object = _UNSET,
lambda_: ArrayLike | object = _UNSET,
alpha: ArrayLike | object = _UNSET,
mu_plus: ArrayLike | object = _UNSET,
mu_minus: ArrayLike | object = _UNSET,
Wmax: ArrayLike | object = _UNSET,
Kplus: ArrayLike | object = _UNSET,
post: object = _UNSET,
):
r"""Set NEST-style public parameters and mutable state.
Updates one or more synapse parameters and state variables without reinitializing
the full simulation state. Mimics NEST's ``SetStatus`` API. All parameters are
validated before assignment to ensure consistency (e.g., ``weight`` and ``Wmax``
must have the same sign).
Only specified parameters are updated; unspecified parameters retain their
current values. To reset all state to initial conditions, use :meth:`init_state`
instead.
Parameters
----------
weight : float, array-like, or Quantity, optional
New synaptic weight. Must have the same sign as ``Wmax`` (or new ``Wmax``
if both are specified). Validated before assignment. Default: unchanged.
delay : float, array-like, or Quantity, optional
New transmission delay (ms). Must be positive. Will be discretized to
integer time steps on next usage. Default: unchanged.
receptor_type : int, optional
New receptor port identifier. Must be non-negative. Default: unchanged.
tau_plus : float, array-like, or Quantity, optional
New presynaptic trace time constant (ms). Must be positive.
Default: unchanged.
tau_minus : float, array-like, or Quantity, optional
New postsynaptic trace time constant (ms). Must be positive.
Default: unchanged.
lambda_ : float, array-like, or Quantity, optional
New learning rate. Typically positive. Default: unchanged.
alpha : float, array-like, or Quantity, optional
New depression asymmetry factor. Typically non-negative. Default: unchanged.
mu_plus : float, array-like, or Quantity, optional
New potentiation exponent. Must be non-negative. Default: unchanged.
mu_minus : float, array-like, or Quantity, optional
New depression exponent. Must be non-negative. Default: unchanged.
Wmax : float, array-like, or Quantity, optional
New maximum weight bound. Must have the same sign as ``weight`` (or new
``weight`` if both are specified). Default: unchanged.
Kplus : float, array-like, or Quantity, optional
New presynaptic trace value. Must be non-negative. Typically used to
restore saved state rather than manipulate during simulation.
Default: unchanged.
post : Dynamics, optional
New default postsynaptic receiver. Default: unchanged.
Raises
------
ValueError
- If ``weight`` and ``Wmax`` have different signs.
- If ``Kplus`` is negative.
- If any parameter has non-finite values or incorrect shape.
See Also
--------
get : Retrieve current parameters and state
init_state : Reinitialize all state to defaults
Examples
--------
**Update learning rate during simulation:**
.. code-block:: python
>>> syn = bst.stdp_synapse(weight=1.0, lambda_=0.01)
>>> syn.set(lambda_=0.001) # reduce learning rate
>>> syn.get()['lambda']
0.001
**Update multiple parameters atomically:**
.. code-block:: python
>>> syn.set(
... weight=0.5,
... Wmax=2.0,
... alpha=1.1,
... )
**Restore saved state:**
.. code-block:: python
>>> saved_params = syn.get()
>>> # ... simulation ...
>>> restore = {k: v for k, v in saved_params.items() if k != 'synapse_model'}
>>> restore['lambda_'] = restore.pop('lambda')  # get() uses 'lambda', set() takes 'lambda_'
>>> syn.set(**restore)
"""
new_weight = self.weight if weight is _UNSET else self._to_scalar_float(weight, name='weight')
new_tau_plus = (
self.tau_plus
if tau_plus is _UNSET
else self._to_scalar_time_ms(tau_plus, name='tau_plus')
)
new_tau_minus = (
self.tau_minus
if tau_minus is _UNSET
else self._to_scalar_time_ms(tau_minus, name='tau_minus')
)
new_lambda = (
self.lambda_
if lambda_ is _UNSET
else self._to_scalar_float(lambda_, name='lambda')
)
new_alpha = self.alpha if alpha is _UNSET else self._to_scalar_float(alpha, name='alpha')
new_mu_plus = (
self.mu_plus
if mu_plus is _UNSET
else self._to_scalar_float(mu_plus, name='mu_plus')
)
new_mu_minus = (
self.mu_minus
if mu_minus is _UNSET
else self._to_scalar_float(mu_minus, name='mu_minus')
)
new_Wmax = self.Wmax if Wmax is _UNSET else self._to_scalar_float(Wmax, name='Wmax')
new_Kplus = self.Kplus if Kplus is _UNSET else self._to_scalar_float(Kplus, name='Kplus')
self._validate_weight_wmax_sign(float(new_weight), float(new_Wmax))
self._validate_non_negative(float(new_Kplus), name='Kplus')
super_kwargs = {}
if weight is not _UNSET:
super_kwargs['weight'] = float(new_weight)
if delay is not _UNSET:
super_kwargs['delay'] = delay
if receptor_type is not _UNSET:
super_kwargs['receptor_type'] = receptor_type
if post is not _UNSET:
super_kwargs['post'] = post
if super_kwargs:
super().set(**super_kwargs)
self.tau_plus = float(new_tau_plus)
self.tau_minus = float(new_tau_minus)
self.lambda_ = float(new_lambda)
self.alpha = float(new_alpha)
self.mu_plus = float(new_mu_plus)
self.mu_minus = float(new_mu_minus)
self.Wmax = float(new_Wmax)
self.Kplus = float(new_Kplus)
if Kplus is not _UNSET:
self._Kplus0 = float(self.Kplus)
def send(
self,
multiplicity: ArrayLike = 1.0,
*,
post=None,
receptor_type: ArrayLike | None = None,
) -> bool:
r"""Schedule one outgoing event with NEST ``stdp_synapse`` dynamics.
Processes a presynaptic spike, applies STDP weight updates (facilitation from
past postsynaptic spikes and depression from current spike), then schedules
the event for delayed delivery to the postsynaptic receiver. This method
replicates the exact update sequence from NEST ``models/stdp_synapse.h::send()``.
**Update sequence:**
1. Compute inter-spike interval :math:`h = t_{\text{spike}} - t_{\text{last}}`
2. Retrieve postsynaptic spikes in window :math:`(t_{\text{last}} - d,\, t_{\text{spike}} - d]`
3. Apply facilitation for each retrieved postsynaptic spike (pre-before-post)
4. Compute postsynaptic trace :math:`K^-` at :math:`t_{\text{spike}} - d`
5. Apply depression based on :math:`K^-` (post-before-pre)
6. Schedule weighted event for delivery at :math:`t_{\text{spike}} + \text{delay}`
7. Update presynaptic trace: :math:`K^+ \leftarrow K^+ \cdot e^{-h/\tau_+} + 1`
8. Update last spike timestamp: :math:`t_{\text{last}} \leftarrow t_{\text{spike}}`
The final delivered weight is :math:`w_{\text{eff}} = w \cdot \text{multiplicity}`
where :math:`w` is the plasticity-updated weight.
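How the delay shifts the windows follows from interval arithmetic alone. A sketch
with hypothetical helper names, ignoring the ``_STDP_EPS`` tolerance: with
:math:`t_{\text{last}} = 10`, :math:`t_{\text{spike}} = 20` and :math:`d = 2` ms,
facilitation uses postsynaptic spikes in :math:`(8, 18]` and :math:`K^-` is
evaluated just before 18 ms.

```python
def stdp_windows(t_last: float, t_spike: float, d: float):
    # Facilitation window (t_last - d, t_spike - d] and the K- query time
    return (t_last - d, t_spike - d), t_spike - d


def in_window(t_post: float, window) -> bool:
    lo, hi = window
    return lo < t_post <= hi


window, k_query = stdp_windows(t_last=10.0, t_spike=20.0, d=2.0)
print(window, k_query)          # (8.0, 18.0) 18.0
print(in_window(17.5, window))  # True: facilitates
print(in_window(19.0, window))  # False: considered at the next presynaptic spike
```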
Parameters
----------
multiplicity : float, array-like, or Quantity, optional
Presynaptic spike multiplicity (event magnitude). Scalar value, typically
``1.0`` for a single spike or ``0.0`` to skip transmission. The delivered
payload is scaled by this factor. Default: ``1.0``.
post : Dynamics, optional
Postsynaptic receiver object for this event. If ``None``, uses the default
receiver specified during initialization. Must implement ``add_delta_input``
or handle spike events. Default: ``None`` (use default receiver).
receptor_type : int, optional
Receptor port override for this event. If ``None``, uses the synapse's
default ``receptor_type``. Default: ``None`` (use default receptor).
Returns
-------
bool
``True`` if an event was scheduled (``multiplicity != 0``), ``False`` otherwise.
Raises
------
ValueError
- If ``multiplicity`` is not a finite scalar.
- If ``receptor_type`` is negative or not an integer.
- If no postsynaptic receiver is available (neither ``post`` argument nor
default receiver specified).
See Also
--------
update : High-level method combining event delivery, post-spike recording, and sending
record_post_spike : Manually record postsynaptic spikes for STDP
Notes
-----
- Spike timestamp uses on-grid time :math:`t + dt` (NEST convention).
- Dendritic delay :math:`d` shifts the STDP causality window backward in time.
- Postsynaptic spike history is never cleared by this method; call
:meth:`clear_post_history` periodically to prevent memory growth.
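The delay-shifted, half-open retrieval window of step 2 and the history-growth note can be illustrated with a sorted list and Python's ``bisect`` module. This is a hypothetical sketch: ``PostHistory``, ``in_window`` and ``prune_before`` are illustrative names, not this module's API. It assumes spike times are appended in non-decreasing order, as happens when they are recorded at successive on-grid steps.

```python
import bisect


class PostHistory:
    # Hypothetical postsynaptic spike archive kept as a sorted list (ms).
    def __init__(self):
        self.times = []

    def record(self, t):
        # In a forward simulation, spikes arrive in non-decreasing time order.
        if self.times and t < self.times[-1]:
            raise ValueError("spike times must be non-decreasing")
        self.times.append(t)

    def in_window(self, t_last, t_spike, d):
        # Spikes in (t_last - d, t_spike - d]: the causality window shifted back by d.
        lo = bisect.bisect_right(self.times, t_last - d)
        hi = bisect.bisect_right(self.times, t_spike - d)
        return self.times[lo:hi]

    def prune_before(self, t_cut):
        # Drop spikes at or before t_cut and return how many were removed.
        # Dropped spikes no longer contribute to any window or trace built from this list.
        k = bisect.bisect_right(self.times, t_cut)
        del self.times[:k]
        return k


h = PostHistory()
for t in (2.0, 5.0, 9.0, 12.0):
    h.record(t)
# Previous pre spike at 6 ms, current one at 14 ms, delay 2 ms: window (4, 12].
print(h.in_window(6.0, 14.0, 2.0))  # [5.0, 9.0, 12.0]
```

Because ``t_last`` only increases, spikes at or before the current ``t_last - d`` can never re-enter the facilitation window. They may still carry a small decayed contribution to :math:`K^-`, so when to prune is a trade-off the caller has to choose.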
Examples
--------
**Send single presynaptic spike:**
.. code-block:: python
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... post_neuron = bst.LIF(1)
... syn = bst.stdp_synapse(weight=1.0, post=post_neuron)
... syn.init_state()
... post_neuron.init_state()
... success = syn.send(multiplicity=1.0)
... print(success)
True
**Send presynaptic spike with receptor override:**
.. code-block:: python
>>> syn.send(multiplicity=1.0, receptor_type=1) # target receptor port 1
True
**Skip transmission (zero multiplicity):**
.. code-block:: python
>>> syn.send(multiplicity=0.0)
False
"""
# A zero multiplicity carries no spike: skip plasticity and transmission.
if not self._is_nonzero(multiplicity):
return False
dt_ms = self._refresh_delay_if_needed()
current_step = self._curr_step(dt_ms)
# NEST uses on-grid event stamps in this model.
t_spike = self._current_time_ms() + dt_ms
dendritic_delay = float(self.delay)
# Facilitation due to postsynaptic spikes in
# (t_lastspike - dendritic_delay, t_spike - dendritic_delay].
t1 = self.t_lastspike - dendritic_delay
t2 = t_spike - dendritic_delay
for t_post in self._get_post_history_times(t1, t2):
minus_dt = self.t_lastspike - (t_post + dendritic_delay)
# Spikes in the window satisfy t_post + d > t_lastspike, so minus_dt is negative.
assert minus_dt < (-1.0 * _STDP_EPS)
kplus_term = self.Kplus * math.exp(minus_dt / self.tau_plus)
self.weight = float(self._facilitate(float(self.weight), float(kplus_term)))
# Depression due to current presynaptic spike.
kminus_value = self._get_K_value(t_spike - dendritic_delay)
self.weight = float(self._depress(float(self.weight), float(kminus_value)))
# Resolve the target receiver and receptor port for this event.
receiver = self._resolve_receiver(post)
rport = self.receptor_type if receptor_type is None else self._to_receptor_type(receptor_type)
# Scale the event by the plasticity-updated weight and queue it after the delay.
weighted_payload = multiplicity * float(self.weight)
delivery_step = int(current_step + int(self._delay_steps))
self._queue[delivery_step].append((receiver, weighted_payload, int(rport), 'spike'))
# Decay K+ from the previous presynaptic spike to t_spike, then add this spike.
self.Kplus = float(self.Kplus * math.exp((self.t_lastspike - t_spike) / self.tau_plus) + 1.0)
self.t_lastspike = float(t_spike)
return True
def update(
self,
pre_spike: ArrayLike = 0.0,
*,
post_spike: ArrayLike = 0.0,
post=None,
receptor_type: ArrayLike | None = None,
) -> int:
r"""Deliver due events, update post history, then process pre spikes.
High-level update method combining all STDP synapse operations for a single
simulation time step. This method is typically called once per time step in
network simulations and handles:
1. Delivery of delayed events from previous time steps
2. Recording of postsynaptic spikes into the STDP history buffer
3. Aggregation of presynaptic inputs (from ``current_inputs`` and ``delta_inputs``)
4. STDP weight update and event scheduling via :meth:`send`
The update order preserves causality. Events already in the delay queue are
delivered before new spikes are processed. Postsynaptic spikes are recorded before
presynaptic spikes, so the history queried by :meth:`send` is up to date for this step.
**Update sequence:**
**Step 1: Deliver due events**
Check the internal delay queue and deliver all events scheduled for the
current simulation step to their target receivers.
**Step 2: Record postsynaptic spikes**
If ``post_spike > 0``, record ``post_spike`` postsynaptic spikes at
timestamp :math:`t + dt` into the STDP history buffer. This updates the
postsynaptic trace :math:`K^-`.
**Step 3: Aggregate presynaptic inputs**
Sum inputs from:
- ``pre_spike`` argument (explicit input)
- ``current_inputs`` dict (accumulated continuous inputs)
- ``delta_inputs`` dict (accumulated spike inputs)
**Step 4: Process presynaptic spike**
If aggregated input is non-zero, call :meth:`send` to apply STDP weight
updates and schedule a new delayed event.
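The four steps above can be sketched with a toy delay queue. The sketch also shows why the return value excludes the event scheduled in the same step.

``ToySynapse`` and its methods are hypothetical. Plasticity is reduced to a fixed weight so that only the ordering is visible: delivery, post-spike recording, input aggregation, then scheduling. The ``extra_inputs`` argument stands in for the accumulated ``current_inputs`` and ``delta_inputs``.

```python
from collections import defaultdict


class ToySynapse:
    # Hypothetical, plasticity-free stand-in that mirrors only the call order.
    def __init__(self, weight, delay_steps, dt):
        self.weight = weight
        self.delay_steps = delay_steps
        self.dt = dt
        self.queue = defaultdict(list)  # delivery step -> list of payloads
        self.post_history = []          # recorded postsynaptic spike times (ms)
        self.received = []              # (step, payload) pairs seen by the target

    def send(self, multiplicity, step):
        # Schedule the weighted event delay_steps in the future.
        self.queue[step + self.delay_steps].append(multiplicity * self.weight)

    def update(self, step, pre_spike=0.0, post_spike=0, extra_inputs=()):
        # Step 1: deliver events due at this step.
        due = self.queue.pop(step, [])
        self.received.extend((step, x) for x in due)
        # Step 2: record postsynaptic spikes at the on-grid time t + dt.
        t_post = step * self.dt + self.dt
        self.post_history.extend([t_post] * int(post_spike))
        # Step 3: aggregate explicit and accumulated presynaptic inputs.
        total_pre = pre_spike + sum(extra_inputs)
        # Step 4: schedule a new event only for a non-zero total input.
        if total_pre != 0:
            self.send(total_pre, step)
        # Count only events delivered in step 1, not the one just scheduled.
        return len(due)


syn = ToySynapse(weight=2.0, delay_steps=2, dt=0.1)
counts = [syn.update(0, pre_spike=1.0), syn.update(1, post_spike=1), syn.update(2)]
print(counts)        # [0, 0, 1]
print(syn.received)  # [(2, 2.0)]
```

The presynaptic spike in step 0 is counted only in step 2, when its delay expires. This matches the documented return value below.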
Parameters
----------
pre_spike : float, array-like, or Quantity, optional
Presynaptic spike multiplicity (explicit input). Added to accumulated
inputs from ``current_inputs`` and ``delta_inputs``. Typically ``0.0``
(no explicit input) or ``1.0`` (single spike). Default: ``0.0``.
post_spike : float, array-like, or Quantity, optional
Postsynaptic spike multiplicity to record. Must be a non-negative
integer-valued scalar. If ``> 0``, records the specified number of
postsynaptic spikes at the current on-grid timestamp :math:`t + dt`.
Default: ``0.0`` (no postsynaptic spikes).
post : Dynamics, optional
Postsynaptic receiver object for event delivery. If ``None``, uses the
default receiver specified during initialization. Default: ``None``.
receptor_type : int, optional
Receptor port override for event delivery. If ``None``, uses the synapse's
default ``receptor_type``. Default: ``None``.
Returns
-------
int
Number of events delivered during this time step (from the delay queue).
Does **not** include the newly scheduled event from this time step's
presynaptic spike (that event will be counted in a future time step).
Raises
------
ValueError
- If ``post_spike`` is negative or not integer-valued.
- If ``pre_spike`` or aggregated inputs are not finite scalars.
See Also
--------
send : Low-level method for processing a single presynaptic spike
record_post_spike : Record postsynaptic spikes without other update operations
Notes
-----
- This method modifies synapse state (``weight``, ``Kplus``, ``t_lastspike``,
postsynaptic history) and should be called exactly once per time step.
- The returned delivery count reflects past events, not the current time step's
transmission.
- For standalone STDP testing without a network, manually call :meth:`record_post_spike`
and :meth:`send` instead of relying on :meth:`update`.
Examples
--------
**Typical usage in network simulation loop:**
.. code-block:: python
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... pre = bst.LIF(1)
... post = bst.LIF(1)
... syn = bst.stdp_synapse(weight=1.0, post=post)
... pre.init_state()
... post.init_state()
... syn.init_state()
... # Simulation step: presynaptic spike, no postsynaptic spike
... delivered = syn.update(pre_spike=1.0, post_spike=0.0)
... # Simulation step: no presynaptic spike, postsynaptic spike
... delivered = syn.update(pre_spike=0.0, post_spike=1.0)
**Standalone STDP test with explicit spike times:**
.. code-block:: python
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... syn = bst.stdp_synapse(weight=1.0, tau_plus=20.0*u.ms, tau_minus=20.0*u.ms)
... syn.init_state()
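... # With simulation time near 0, the post spike at 5.0 ms follows the pre spike
... # at t + dt (pre-before-post); it can only potentiate once a later pre spike is sent.
... syn.record_post_spike(multiplicity=1, t_spike_ms=5.0)
... syn.send(multiplicity=1)  # pre spike at the on-grid time t + dt
... print(f"Weight after the first pre spike: {syn.weight:.6f}")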
"""
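dt_ms = self._refresh_delay_if_needed()
step = self._curr_step(dt_ms)
# Step 1: deliver events whose delay expires at this step.
delivered = self._deliver_due_events(step)
# Step 2: record postsynaptic spikes at the on-grid time t + dt.
post_count = self._to_non_negative_int_count(post_spike, name='post_spike')
if post_count > 0:
t_post = self._current_time_ms() + dt_ms
for _ in range(post_count):
self._record_post_spike_at(float(t_post))
# Step 3: add accumulated current and delta inputs to the explicit pre_spike.
total_pre = self.sum_current_inputs(pre_spike)
total_pre = self.sum_delta_inputs(total_pre)
# Step 4: apply STDP and schedule a new event for a non-zero presynaptic input.
if self._is_nonzero(total_pre):
self.send(total_pre, post=post, receptor_type=receptor_type)
return delivered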