# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from collections.abc import Mapping
import brainstate
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike
from .static_synapse import static_synapse
# Public API of this module.
__all__ = [
    'stdp_pl_synapse_hom',
]
# Sentinel distinguishing "argument not supplied" from an explicit value
# (including None) in keyword-only setters such as set().
_UNSET = object()
# Tolerance in ms for spike-time comparisons in the postsynaptic
# history-window lookups (_get_post_history_times / _get_K_value).
_STDP_EPS = 1.0e-6
class stdp_pl_synapse_hom(static_synapse):
r"""NEST-compatible ``stdp_pl_synapse_hom`` connection model.
``stdp_pl_synapse_hom`` implements the power-law spike-timing-dependent
plasticity (STDP) rule from Morrison et al. (2007) with homogeneous
plasticity parameters. This synapse exhibits asymmetric potentiation and
depression with non-linear, power-law weight dependence, making it suitable
for modeling balanced networks with realistic weight distributions.
The model replicates NEST ``models/stdp_pl_synapse_hom.h`` exactly, including
propagator computation, update ordering, and event timing semantics. Delay
scheduling and receiver delivery inherit from :class:`static_synapse`.
**1. Mathematical Model**
State Variables
---------------
- ``weight`` (:math:`w`): Synaptic efficacy (current/conductance units or dimensionless)
- ``Kplus`` (:math:`K^+`): Presynaptic eligibility trace (dimensionless)
- ``t_lastspike`` (:math:`t_{\mathrm{last}}`): Timestamp of previous presynaptic spike (ms)
- Internal postsynaptic history buffer: ``(t_post, K^-(t_post))`` pairs
**Continuous-time dynamics (between spikes):**
Presynaptic trace decay:
.. math::
\frac{dK^+}{dt} = -\frac{K^+}{\tau_+}
Postsynaptic trace decay (maintained in internal buffer):
.. math::
\frac{dK^-}{dt} = -\frac{K^-}{\tau_-}
where:
- :math:`\tau_+ > 0` -- Potentiation time constant (ms)
- :math:`\tau_- > 0` -- Depression time constant (ms)
**Upon presynaptic spike at time** :math:`t_{\mathrm{pre}}` **with dendritic delay** :math:`d`:
**Step 1: Facilitation (Potentiation)** — Process all postsynaptic spikes in the causal window:
For each postsynaptic spike :math:`t_{\mathrm{post}}` in the interval
:math:`(t_{\mathrm{last}} - d,\, t_{\mathrm{pre}} - d]`:
.. math::
K^+_{\mathrm{eff}} = K^+ \cdot \exp\left(\frac{t_{\mathrm{last}} - (t_{\mathrm{post}} + d)}{\tau_+}\right)
w \leftarrow w + \lambda \, w^\mu \, K^+_{\mathrm{eff}}
where:
- :math:`\lambda` -- Learning rate (dimensionless)
- :math:`\mu` -- Power-law exponent for potentiation (:math:`\mu \in [0, 1]` typical)
**Interpretation:** The presynaptic trace :math:`K^+` is back-propagated to
the time of the postsynaptic spike (:math:`t_{\mathrm{post}} + d`, accounting
for dendritic delay), producing a smaller effective trace for older postsynaptic
spikes. Potentiation is **multiplicative** and **sub-linear** in weight
(:math:`w^\mu` with :math:`\mu < 1`), promoting stable weight distributions.
**Step 2: Depression** — Apply depression based on the postsynaptic trace at the pre-spike time:
.. math::
K^-_{\mathrm{eff}} = K^-\left(t_{\mathrm{pre}} - d\right)
w \leftarrow w - \alpha \lambda \, w \, K^-_{\mathrm{eff}}
w \leftarrow \max(w, 0)
where :math:`\alpha` is the depression scaling factor.
**Interpretation:** Depression is **linear** in weight and occurs when a
presynaptic spike is preceded by postsynaptic activity. The weight is clipped
to zero to prevent negative values.
**Step 3: Event Transmission** — Schedule the weighted event with updated ``weight``.
**Step 4: Presynaptic Trace Update:**
.. math::
K^+ \leftarrow K^+ \cdot \exp\left(\frac{t_{\mathrm{last}} - t_{\mathrm{pre}}}{\tau_+}\right) + 1
t_{\mathrm{last}} \leftarrow t_{\mathrm{pre}}
**Postsynaptic spike handling (via internal buffer):**
Upon postsynaptic spike at :math:`t_{\mathrm{post}}`:
.. math::
K^- \leftarrow K^- \cdot \exp\left(\frac{t_{\mathrm{last\_post}} - t_{\mathrm{post}}}{\tau_-}\right) + 1
Stored as ``(t_post, K^-)`` in history buffer for future lookups.
**2. Update Ordering and NEST Compatibility**
This implementation preserves the exact update sequence from NEST
``models/stdp_pl_synapse_hom.h::send()``:
1. Read postsynaptic spike history in :math:`(t_{\mathrm{last}} - d,\, t_{\mathrm{pre}} - d]`
2. For each retrieved postsynaptic spike, compute back-propagated :math:`K^+_{\mathrm{eff}}`
3. Apply facilitation: :math:`w \leftarrow w + \lambda w^\mu K^+_{\mathrm{eff}}`
4. Retrieve depression trace :math:`K^-_{\mathrm{eff}}` at :math:`t_{\mathrm{pre}} - d`
5. Apply depression: :math:`w \leftarrow \max(w - \alpha \lambda w K^-_{\mathrm{eff}}, 0)`
6. Schedule weighted spike event
7. Update presynaptic trace: :math:`K^+ \leftarrow K^+ e^{(t_{\mathrm{last}} - t_{\mathrm{pre}})/\tau_+} + 1`
8. Update timestamp: :math:`t_{\mathrm{last}} \leftarrow t_{\mathrm{pre}}`
**3. Homogeneous-Property Semantics**
In NEST, ``tau_plus``, ``lambda``, ``alpha``, and ``mu`` are **common model properties**
shared by all synapses of this type, while ``weight`` and ``Kplus``
are **per-connection state**.
This implementation enforces NEST connect-time semantics:
- Common properties (``tau_plus``, ``lambda``, ``alpha``, ``mu``) are set
at model instantiation or via ``SetDefaults()`` / ``CopyModel()``
- Per-connection properties (``weight``, ``Kplus``) can be set via
``Connect(..., syn_spec={...})``
- :meth:`check_synapse_params` rejects attempts to override common properties
in connection specifications
**4. Event Timing Semantics**
NEST evaluates this model using on-grid spike time stamps and ignores precise
sub-step offsets. This implementation follows the same convention:
- Presynaptic spike detected at simulation step ``n``
- Spike time stamp: :math:`t_{\mathrm{spike}} = t_n + dt`
- Dendritic arrival time: :math:`t_{\mathrm{arrival}} = t_{\mathrm{spike}} - d`
- Delivery time: :math:`t_{\mathrm{delivery}} = t_{\mathrm{spike}} + \mathrm{delay}`
**5. Stability Constraints and Computational Implications**
**Parameter Constraints:**
- :math:`\tau_+ > 0` (enforced in ``__init__`` and ``set``)
- :math:`\tau_- > 0` (recommended, not enforced)
- :math:`\lambda \geq 0` (learning rate)
- :math:`\alpha \geq 0` (depression scaling)
- :math:`\mu \in [0, 1]` (typical range; not enforced)
- :math:`K^+ \geq 0` (initial presynaptic trace; typically zero)
- :math:`w \geq 0` (maintained via clipping in depression)
Numerical Considerations
------------------------
- Trace propagation uses ``math.exp()`` for exponential decay
- Power-law computation uses ``numpy.power()`` with float64 precision
- Postsynaptic history is stored as Python lists ``_post_hist_t`` and
``_post_hist_kminus``; lookups are :math:`O(n)` where :math:`n` is the
number of stored postsynaptic spikes
- Per-call cost: :math:`O(n_{\mathrm{post}})` where :math:`n_{\mathrm{post}}`
is the number of postsynaptic spikes in the causal window
- All state variables are Python floats (``float64`` precision)
**Behavioral Regimes:**
- **Power-law stabilization** (:math:`\mu < 1`): Potentiation is sub-linear in
weight, preventing runaway growth and promoting log-normal weight distributions
(Morrison et al., 2007)
- **Balanced networks**: The combination of power-law potentiation and linear
depression naturally regulates weight distributions in recurrent networks
- **Weight clamping**: Depression clipping at :math:`w = 0` prevents negative
weights; no upper bound is enforced (unlike ``stdp_synapse`` with ``Wmax``)
Failure Modes
-------------
- **Non-finite weights**: Power-law computation :math:`w^\mu` can produce
``inf`` or ``nan`` for extreme weights; users should monitor weight distributions
- **Trace overflow**: Large spike trains can accumulate unbounded :math:`K^+`
or :math:`K^-` values (not a practical issue for typical firing rates)
- **History buffer growth**: Postsynaptic spike history is not pruned; long
simulations with high postsynaptic firing rates may consume memory
Parameters
----------
weight : ArrayLike, optional
Initial synaptic weight :math:`w` (dimensionless or with receiver-specific units).
Scalar float or array-like. Must be non-negative. Default: ``1.0``.
delay : ArrayLike, optional
Synaptic transmission delay :math:`d` in milliseconds. Must be ``> 0``.
Quantized to integer time steps per :class:`static_synapse` conventions.
Default: ``1.0 * u.ms``.
receptor_type : int, optional
Receiver port/receptor identifier (non-negative integer).
Default: ``0``.
tau_plus : ArrayLike, optional
Potentiation time constant :math:`\tau_+` in milliseconds. Must be ``> 0``.
Scalar float or saiunit ``Quantity``. **Common property** (not per-connection).
Default: ``20.0 * u.ms``.
tau_minus : ArrayLike, optional
Depression trace time constant :math:`\tau_-` in milliseconds. Must be ``> 0``.
Scalar float or saiunit ``Quantity``.
In NEST, this parameter belongs to the postsynaptic ``ArchivingNode``; here
it is stored on the synapse for standalone compatibility.
Default: ``20.0 * u.ms``.
lambda_ : ArrayLike, optional
Learning rate :math:`\lambda` (dimensionless). Must be non-negative.
**Common property** (not per-connection). Default: ``0.1``.
alpha : ArrayLike, optional
Depression scaling factor :math:`\alpha` (dimensionless). Must be non-negative.
Controls the relative strength of depression vs. potentiation.
**Common property** (not per-connection). Default: ``1.0``.
mu : ArrayLike, optional
Power-law exponent :math:`\mu` for potentiation (dimensionless).
Typical range: :math:`[0, 1]`. Values :math:`< 1` produce sub-linear
potentiation; :math:`\mu = 0` disables weight dependence.
**Common property** (not per-connection). Default: ``0.4``.
Kplus : ArrayLike, optional
Initial presynaptic eligibility trace :math:`K^+` (dimensionless).
Must be non-negative. Scalar float or array-like.
**Per-connection state**. Default: ``0.0``.
post : object, optional
Default receiver object (typically a neuron or neuron group).
Can be overridden in ``send()`` and ``update()`` calls. Default: ``None``.
name : str, optional
Object name for identification and debugging. Default: ``None``.
Parameter Mapping
-----------------
The following table maps NEST parameter names to this implementation:
======================== ======================== =================
NEST Parameter brainpy.state Parameter Type
======================== ======================== =================
``weight`` ``weight`` per-connection
``delay`` ``delay`` per-connection
``receptor_type`` ``receptor_type`` per-connection
``tau_plus`` ``tau_plus`` common property
``tau_minus`` ``tau_minus`` common property
``lambda`` ``lambda_`` common property
``alpha`` ``alpha`` common property
``mu`` ``mu`` common property
``Kplus`` ``Kplus`` per-connection
======================== ======================== =================
Notes
-----
- The model transmits spike-like events only (no graded signals).
- ``update(pre_spike=..., post_spike=...)`` accepts both presynaptic and
postsynaptic spike multiplicities (integer counts) for standalone STDP
simulation without explicit neuron models.
- ``record_post_spike(multiplicity, t_spike_ms=None)`` can be used to
manually feed postsynaptic spikes when the postsynaptic model does not
expose NEST ``ArchivingNode`` APIs.
- Postsynaptic spike history is **not automatically pruned**; users may call
``clear_post_history()`` to reset internal buffers if needed.
- Unlike ``stdp_synapse``, this model has **no upper weight bound** (``Wmax``);
weight stability relies on power-law potentiation dynamics.
See Also
--------
stdp_synapse : Classical pair-based STDP with separate potentiation/depression exponents
stdp_triplet_synapse : Triplet STDP rule (Pfister-Gerstner)
static_synapse : Base class for event scheduling and delay handling
References
----------
.. [1] NEST source: ``models/stdp_pl_synapse_hom.h`` and
``models/stdp_pl_synapse_hom.cpp``.
.. [2] Morrison A, Aertsen A, Diesmann M (2007). Spike-timing dependent
plasticity in balanced random networks.
Neural Computation, 19(6):1437-1467.
DOI: 10.1162/neco.2007.19.6.1437
Examples
--------
**Basic standalone STDP simulation:**
.. code-block:: python
>>> import brainpy.state as bps
>>> import saiunit as u
>>>
>>> # Create synapse with power-law STDP
>>> syn = bps.stdp_pl_synapse_hom(
... weight=1.0,
... tau_plus=20*u.ms,
... tau_minus=20*u.ms,
... lambda_=0.1,
... alpha=1.0,
... mu=0.4,
... )
>>> syn.init_state()
>>>
>>> # Simulate pre-before-post pairing (potentiation)
>>> syn.record_post_spike(t_spike_ms=10.0) # post spike at 10 ms
>>> syn.send(1.0) # pre spike at 11 ms (assuming dt=1ms, t=10ms)
>>> print(f"Weight after potentiation: {syn.weight:.4f}")
Weight after potentiation: 1.0xxx
>>>
>>> # Simulate post-before-pre pairing (depression)
>>> syn.record_post_spike(t_spike_ms=20.0) # post spike at 20 ms
>>> syn.send(1.0) # pre spike at 21 ms (causally follows post)
>>> print(f"Weight after depression: {syn.weight:.4f}")
Weight after depression: 0.9xxx
**Enforcing homogeneous-property semantics:**
.. code-block:: python
>>> import brainpy.state as bps
>>>
>>> syn = bps.stdp_pl_synapse_hom(lambda_=0.05)
>>>
>>> # Allowed: per-connection properties
>>> syn.check_synapse_params({'weight': 2.0, 'Kplus': 0.5}) # OK
>>>
>>> # Disallowed: common properties in connection specs
>>> try:
... syn.check_synapse_params({'lambda': 0.1})
... except ValueError as e:
... print(e)
lambda cannot be specified in connect-time synapse parameters...
"""
__module__ = 'brainpy.state'
def __init__(
    self,
    weight: ArrayLike = 1.0,
    delay: ArrayLike = 1.0 * u.ms,
    receptor_type: int = 0,
    tau_plus: ArrayLike = 20.0 * u.ms,
    tau_minus: ArrayLike = 20.0 * u.ms,
    lambda_: ArrayLike = 0.1,
    alpha: ArrayLike = 1.0,
    mu: ArrayLike = 0.4,
    Kplus: ArrayLike = 0.0,
    post=None,
    name: str | None = None,
):
    """Create a power-law STDP synapse with homogeneous plasticity parameters.

    Weight storage, delay quantization and event scheduling are delegated to
    ``static_synapse``; this constructor adds the STDP-specific common
    properties (``tau_plus``, ``tau_minus``, ``lambda_``, ``alpha``, ``mu``),
    the per-connection presynaptic trace ``Kplus``, and the internal
    postsynaptic trace/history buffers. See the class docstring for full
    parameter documentation.
    """
    super().__init__(
        weight=weight,
        delay=delay,
        receptor_type=receptor_type,
        post=post,
        event_type='spike',
        name=name,
    )
    # Common (homogeneous) plasticity properties, stored as plain floats
    # (time constants in ms, plain scalars otherwise).
    self.tau_plus = self._to_scalar_time_ms(tau_plus, name='tau_plus')
    self.tau_minus = self._to_scalar_time_ms(tau_minus, name='tau_minus')
    self.lambda_ = self._to_scalar_float(lambda_, name='lambda')
    self.alpha = self._to_scalar_float(alpha, name='alpha')
    self.mu = self._to_scalar_float(mu, name='mu')
    # Per-connection presynaptic eligibility trace K^+.
    self.Kplus = self._to_scalar_float(Kplus, name='Kplus')
    self._validate_tau_plus(self.tau_plus)
    # Initial values restored by init_state().
    self._Kplus0 = float(self.Kplus)
    self._t_lastspike0 = 0.0
    self.t_lastspike = float(self._t_lastspike0)
    # Postsynaptic depression trace K^- and its (time, K^-) spike history.
    self._post_kminus = 0.0
    self._last_post_spike = -1.0
    self._post_hist_t: list[float] = []
    self._post_hist_kminus: list[float] = []
@staticmethod
def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
    """Coerce ``value`` to a finite scalar Python float.

    A saiunit ``Quantity`` is reduced to its magnitude in its own unit;
    any other array-like is converted via the environment's default float
    dtype.

    Raises
    ------
    ValueError
        If ``value`` has more than one element or is not finite.
    """
    dftype = brainstate.environ.dftype()
    if isinstance(value, u.Quantity):
        magnitude = value.to_decimal(u.get_unit(value))
        arr = np.asarray(magnitude, dtype=dftype)
    else:
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
    if arr.size != 1:
        raise ValueError(f'{name} must be scalar.')
    scalar = float(arr.reshape(()))
    if not math.isfinite(scalar):
        raise ValueError(f'{name} must be finite.')
    return scalar
@staticmethod
def _to_non_negative_int_count(value: ArrayLike, *, name: str) -> int:
    """Convert ``value`` to a non-negative integer spike count.

    Raises
    ------
    ValueError
        If ``value`` is not scalar, not finite, negative, or not within
        ``1e-12`` of an integer.
    """
    dftype = brainstate.environ.dftype()
    arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
    if arr.size != 1:
        raise ValueError(f'{name} must be scalar.')
    scalar = float(arr.reshape(()))
    if not math.isfinite(scalar):
        raise ValueError(f'{name} must be finite.')
    if scalar < 0.0:
        raise ValueError(f'{name} must be non-negative.')
    count = int(round(scalar))
    # Absolute tolerance of 1e-12 (no relative term), as in the original
    # math.isclose(..., rel_tol=0.0, abs_tol=1e-12) check.
    if abs(scalar - float(count)) > 1e-12:
        raise ValueError(f'{name} must be an integer spike count.')
    return count
@staticmethod
def _validate_tau_plus(value: float):
    """Reject a non-positive potentiation time constant (ms)."""
    # tau_plus is used as a divisor in exponential decay terms, so it
    # must be strictly positive.
    if value <= 0.0:
        raise ValueError('tau_plus must be > 0.')
def _facilitate(self, w: float, kplus: float) -> float:
    """Power-law potentiation: return ``w + lambda * w**mu * kplus``.

    The power term is evaluated in float64 for precision, matching the
    NEST reference implementation.
    """
    w_pow_mu = float(np.power(np.float64(w), np.float64(self.mu)))
    increment = self.lambda_ * w_pow_mu * kplus
    return w + increment
def _depress(self, w: float, kminus: float) -> float:
    """Linear depression ``w - alpha * lambda * w * kminus``, clipped at zero.

    A conditional (rather than ``max()``) is used deliberately so that a
    NaN result also maps to ``0.0``, preserving the original comparison
    semantics.
    """
    depressed = w - (self.lambda_ * self.alpha * w * kminus)
    return depressed if depressed > 0.0 else 0.0
def clear_post_history(self):
    r"""Clear internal postsynaptic STDP history state.

    Resets the postsynaptic spike history buffer and depression trace to
    their initial conditions:

    - ``_post_kminus``: depression trace, reset to ``0.0``
    - ``_last_post_spike``: last postsynaptic spike time, reset to ``-1.0``
    - ``_post_hist_t``: spike-time history, reset to an empty list
    - ``_post_hist_kminus``: trace-value history, reset to an empty list

    Presynaptic state (``Kplus``, ``t_lastspike``) and ``weight`` are
    **not** affected.

    Notes
    -----
    - Called automatically by ``init_state()``.
    - Postsynaptic history is not pruned during simulation; call this
      manually to reclaim memory in very long runs.

    See Also
    --------
    init_state : Full state initialization including history clearing
    record_post_spike : Add postsynaptic spikes to history buffer
    """
    # Fix: removed a stray Sphinx "[docs]" extraction artifact that
    # preceded this method and would raise NameError at class creation.
    self._post_kminus = 0.0
    self._last_post_spike = -1.0
    self._post_hist_t = []
    self._post_hist_kminus = []
def _record_post_spike_at(self, t_spike_ms: float):
    """Fold one postsynaptic spike at ``t_spike_ms`` into the K^- trace.

    Decays the trace from the previous post-spike time, increments it by
    one, and appends the ``(t_spike, K^-)`` pair to the history buffers.
    """
    decay = math.exp((self._last_post_spike - t_spike_ms) / self.tau_minus)
    self._post_kminus = self._post_kminus * decay + 1.0
    self._last_post_spike = float(t_spike_ms)
    self._post_hist_t.append(float(t_spike_ms))
    self._post_hist_kminus.append(float(self._post_kminus))
def record_post_spike(
    self,
    multiplicity: ArrayLike = 1.0,
    *,
    t_spike_ms: ArrayLike | None = None,
) -> int:
    r"""Record postsynaptic spikes into the internal STDP history buffer.

    Intended for standalone STDP simulation when the postsynaptic neuron
    does not expose NEST ``ArchivingNode`` APIs. For each spike the
    depression trace is decayed and incremented,

    .. math::
        K^- \leftarrow K^- \exp((t_{\mathrm{last}} - t_{\mathrm{spike}})/\tau_-) + 1,

    and the ``(t_spike, K^-)`` pair is stored in the history buffer.

    Parameters
    ----------
    multiplicity : ArrayLike, optional
        Number of spikes to record (non-negative integer count).
        If ``< 1.0``, no spikes are recorded. Default: ``1.0``.
    t_spike_ms : ArrayLike or None, optional
        Spike time stamp in milliseconds (scalar float or saiunit
        ``Quantity``). If ``None``, the current simulation time plus one
        time step (:math:`t + dt`) is used. Default: ``None``.

    Returns
    -------
    int
        Number of spikes actually recorded.

    Raises
    ------
    ValueError
        If ``multiplicity`` is not a scalar, not finite, negative, or not
        close to an integer; or if ``t_spike_ms`` is provided but not a
        finite scalar.

    Notes
    -----
    - Simultaneous spikes are recorded sequentially, updating the trace
      after each one (matches NEST behavior for multiplicities > 1).
    - The history buffer grows unbounded; call ``clear_post_history()``
      to reclaim memory if needed.
    - This method does **not** trigger STDP weight updates; those occur
      during presynaptic spike processing in ``send()``.

    See Also
    --------
    clear_post_history : Reset postsynaptic spike history buffer
    send : Presynaptic spike processing (applies STDP weight updates)
    """
    # Fix: removed a stray Sphinx "[docs]" extraction artifact that
    # preceded this method and would raise NameError at class creation.
    count = self._to_non_negative_int_count(multiplicity, name='post_spike')
    if count == 0:
        return 0
    if t_spike_ms is None:
        # Auto-stamp with t + dt, matching the on-grid timing convention.
        dt_ms = self._refresh_delay_if_needed()
        t_value = self._current_time_ms() + dt_ms
    else:
        t_value = self._to_scalar_float(t_spike_ms, name='t_spike_ms')
    for _ in range(count):
        self._record_post_spike_at(float(t_value))
    return count
def _get_post_history_times(self, t1_ms: float, t2_ms: float) -> list[float]:
    """Return recorded post-spike times inside the eps-shifted window.

    Selects ``t_post`` with ``t1 + eps <= t_post < t2 + eps``, i.e. an
    epsilon-shifted half-open interval over the stored spike times.
    """
    lower = float(t1_ms + _STDP_EPS)
    upper = float(t2_ms + _STDP_EPS)
    return [t_post for t_post in self._post_hist_t if lower <= t_post < upper]
def _get_K_value(self, t_ms: float) -> float:
    """Return K^-(t) decayed from the newest post spike strictly before t.

    Scans the history from most recent to oldest and decays the stored
    trace value to time ``t_ms`` (matches ``ArchivingNode::get_K_value``).
    Returns ``0.0`` when no earlier post spike exists.
    """
    # History lists are appended in lockstep, so they share one length.
    history = zip(reversed(self._post_hist_t), reversed(self._post_hist_kminus))
    for t_post, kminus in history:
        if (t_ms - t_post) > _STDP_EPS:
            return kminus * math.exp((t_post - t_ms) / self.tau_minus)
    return 0.0
def init_state(self, batch_size: int | None = None, **kwargs):
    r"""Initialize synapse state variables to default values.

    Resets all mutable state to initial conditions:

    - ``weight`` and the event delivery queue (via ``static_synapse``)
    - ``Kplus``: presynaptic trace, reset to ``_Kplus0``
    - ``t_lastspike``: last presynaptic spike time, reset to ``_t_lastspike0``
    - postsynaptic spike history (via ``clear_post_history()``)

    Parameters
    ----------
    batch_size : int, optional
        Unused; this synapse operates in scalar mode only. Kept for API
        compatibility. Default: ``None``.
    **kwargs : dict, optional
        Unused; accepted for API compatibility.

    Notes
    -----
    - Must be called before simulation begins.
    - Does **not** reset common properties
      (``tau_plus``, ``lambda_``, ``alpha``, ``mu``).

    See Also
    --------
    clear_post_history : Clear postsynaptic spike history only
    set : Update parameters and initial state values
    """
    # Fix: removed a stray Sphinx "[docs]" extraction artifact that
    # preceded this method and would raise NameError at class creation.
    del batch_size, kwargs
    super().init_state()
    self.Kplus = float(self._Kplus0)
    self.t_lastspike = float(self._t_lastspike0)
    self.clear_post_history()
def get(self) -> dict:
    r"""Return current public parameters and mutable state as a dict.

    Extends ``static_synapse.get()`` with the STDP-specific entries:

    - ``'tau_plus'``, ``'tau_minus'``: time constants in ms (float)
    - ``'lambda'``: learning rate (NEST key name, not ``'lambda_'``)
    - ``'alpha'``: depression scaling factor (float)
    - ``'mu'``: power-law exponent (float)
    - ``'Kplus'``: current presynaptic trace value (float)
    - ``'synapse_model'``: model identifier ``'stdp_pl_synapse_hom'``

    Returns
    -------
    dict
        Parameter/state names mapped to their current values.

    Notes
    -----
    - Internal state (``t_lastspike``, postsynaptic history) is **not**
      included.
    - The returned dictionary can be fed back to ``set(**params)`` for
      restoration (after dropping non-settable keys).

    See Also
    --------
    set : Update parameters and state
    init_state : Reset state to initial values
    """
    # Fix: removed a stray Sphinx "[docs]" extraction artifact that
    # preceded this method and would raise NameError at class creation.
    params = super().get()
    params['tau_plus'] = float(self.tau_plus)
    params['tau_minus'] = float(self.tau_minus)
    params['lambda'] = float(self.lambda_)
    params['alpha'] = float(self.alpha)
    params['mu'] = float(self.mu)
    params['Kplus'] = float(self.Kplus)
    params['synapse_model'] = 'stdp_pl_synapse_hom'
    return params
def check_synapse_params(self, syn_spec: Mapping[str, object] | None):
    r"""Validate a connect-time synapse parameter specification.

    Enforces NEST's homogeneous-property semantics: common model
    properties (``'tau_plus'``, ``'lambda'``, ``'alpha'``, ``'mu'``) are
    shared by all synapses of this type and may not be overridden per
    connection. Per-connection keys (``'weight'``, ``'delay'``,
    ``'receptor_type'``, ``'Kplus'``) are allowed.

    Parameters
    ----------
    syn_spec : Mapping[str, object] or None
        Synapse specification, typically from
        ``Connect(..., syn_spec={...})``. ``None`` skips validation.

    Raises
    ------
    ValueError
        If ``syn_spec`` contains any disallowed common-property key.

    Notes
    -----
    To change common properties, use ``set(...)`` on the model instance,
    or NEST-style ``SetDefaults()`` / ``CopyModel()`` APIs.

    See Also
    --------
    set : Update model parameters (common and per-connection)
    """
    # Fix: removed a stray Sphinx "[docs]" extraction artifact that
    # preceded this method and would raise NameError at class creation.
    if syn_spec is None:
        return
    disallowed = ('tau_plus', 'lambda', 'alpha', 'mu')
    for key in disallowed:
        if key in syn_spec:
            raise ValueError(
                f'{key} cannot be specified in connect-time synapse parameters '
                'for stdp_pl_synapse_hom; set common properties on the model '
                'itself (for example via CopyModel()/SetDefaults()).'
            )
def set(
    self,
    *,
    weight: ArrayLike | object = _UNSET,
    delay: ArrayLike | object = _UNSET,
    receptor_type: ArrayLike | object = _UNSET,
    tau_plus: ArrayLike | object = _UNSET,
    tau_minus: ArrayLike | object = _UNSET,
    lambda_: ArrayLike | object = _UNSET,
    alpha: ArrayLike | object = _UNSET,
    mu: ArrayLike | object = _UNSET,
    Kplus: ArrayLike | object = _UNSET,
    post: object = _UNSET,
):
    r"""Set NEST-style public parameters and mutable state.

    Supports partial updates: only arguments that are explicitly supplied
    (i.e. not left at the ``_UNSET`` sentinel) are modified. All new
    values are validated before any state is mutated.

    Parameters
    ----------
    weight : ArrayLike, optional
        New synaptic weight (scalar).
    delay : ArrayLike, optional
        New synaptic delay in ms (validated by ``static_synapse.set``).
    receptor_type : int, optional
        New receiver port ID (validated by ``static_synapse.set``).
    tau_plus : ArrayLike, optional
        New potentiation time constant in ms. Must be ``> 0``.
    tau_minus : ArrayLike, optional
        New depression time constant in ms (positivity not enforced).
    lambda_ : ArrayLike, optional
        New learning rate.
    alpha : ArrayLike, optional
        New depression scaling factor.
    mu : ArrayLike, optional
        New power-law exponent.
    Kplus : ArrayLike, optional
        New presynaptic trace value. When supplied, updates both the
        current state (``self.Kplus``) and the initial value restored by
        ``init_state()`` (``self._Kplus0``); when omitted, both are left
        unchanged.
    post : object, optional
        New default receiver object.

    Raises
    ------
    ValueError
        If ``tau_plus`` is provided and ``<= 0``, or if any provided
        value fails scalar/finiteness validation.

    Notes
    -----
    This method does **not** clear postsynaptic spike history or reset
    ``t_lastspike``.

    See Also
    --------
    get : Retrieve current parameter values
    init_state : Reset state to initial values
    """
    # Fix 1: removed stray Sphinx "[docs]" extraction artifacts around
    # this method that would raise NameError at class creation.
    # Validate everything first so a bad argument leaves state untouched.
    new_tau_plus = (
        self.tau_plus
        if tau_plus is _UNSET
        else self._to_scalar_time_ms(tau_plus, name='tau_plus')
    )
    self._validate_tau_plus(float(new_tau_plus))
    new_tau_minus = (
        self.tau_minus
        if tau_minus is _UNSET
        else self._to_scalar_time_ms(tau_minus, name='tau_minus')
    )
    new_lambda = (
        self.lambda_
        if lambda_ is _UNSET
        else self._to_scalar_float(lambda_, name='lambda')
    )
    new_alpha = self.alpha if alpha is _UNSET else self._to_scalar_float(alpha, name='alpha')
    new_mu = self.mu if mu is _UNSET else self._to_scalar_float(mu, name='mu')
    new_Kplus = self.Kplus if Kplus is _UNSET else self._to_scalar_float(Kplus, name='Kplus')
    # Per-connection basics are delegated to static_synapse.set().
    super_kwargs = {}
    if weight is not _UNSET:
        super_kwargs['weight'] = self._normalize_scalar_weight(weight)
    if delay is not _UNSET:
        super_kwargs['delay'] = delay
    if receptor_type is not _UNSET:
        super_kwargs['receptor_type'] = receptor_type
    if post is not _UNSET:
        super_kwargs['post'] = post
    if super_kwargs:
        super().set(**super_kwargs)
    self.tau_plus = float(new_tau_plus)
    self.tau_minus = float(new_tau_minus)
    self.lambda_ = float(new_lambda)
    self.alpha = float(new_alpha)
    self.mu = float(new_mu)
    self.Kplus = float(new_Kplus)
    if Kplus is not _UNSET:
        # Fix 2: previously _Kplus0 was rebased to the *current* trace on
        # every set() call, even when Kplus was not supplied — so e.g.
        # set(weight=...) mid-simulation silently changed the value that
        # init_state() restores, contradicting the documented
        # "preserved if unset" semantics. Only an explicit Kplus update
        # redefines the initial trace now.
        self._Kplus0 = float(self.Kplus)
def send(
    self,
    multiplicity: ArrayLike = 1.0,
    *,
    post=None,
    receptor_type: ArrayLike | None = None,
) -> bool:
    r"""Process one presynaptic spike with power-law STDP and schedule delivery.

    Mirrors NEST ``models/stdp_pl_synapse_hom.h::send()`` step for step:

    1. Stamp the spike on the grid at :math:`t_{\mathrm{spike}} = t + dt`.
    2. Facilitate once per postsynaptic spike inside the causal window
       :math:`(t_{\mathrm{last}} - d,\, t_{\mathrm{spike}} - d]`, using the
       presynaptic trace back-propagated to each postsynaptic spike time.
    3. Depress against the postsynaptic trace retrieved at
       :math:`t_{\mathrm{spike}} - d`.
    4. Queue the weighted event ``weight * multiplicity`` for delivery after
       the transmission delay (the freshly updated weight is delivered).
    5. Decay the presynaptic trace to :math:`t_{\mathrm{spike}}`, add 1, and
       record :math:`t_{\mathrm{spike}}` as the last presynaptic spike time.

    The dendritic delay shifts the STDP causal window only; event delivery
    timing is governed separately by the scheduling delay steps.

    Parameters
    ----------
    multiplicity : ArrayLike, optional
        Presynaptic spike multiplicity (typically ``1.0``). When it tests as
        zero, no event is scheduled and ``False`` is returned.
        Default: ``1.0``.
    post : object or None, optional
        Receiver override; ``None`` selects the default receiver configured
        on the synapse. Default: ``None``.
    receptor_type : ArrayLike or None, optional
        Receiver port override (non-negative integer); ``None`` selects
        ``self.receptor_type``. Default: ``None``.

    Returns
    -------
    bool
        ``True`` if an event was scheduled, ``False`` if ``multiplicity``
        was zero.

    Raises
    ------
    ValueError
        If ``receptor_type`` is given but is not a valid non-negative
        integer.
    RuntimeError
        If no receiver can be resolved (``post`` is ``None`` and no default
        receiver is configured).

    See Also
    --------
    update : Combined pre/post spike processing per simulation step.
    record_post_spike : Manually append postsynaptic spikes to the history.
    """
    if not self._is_nonzero(multiplicity):
        return False

    dt_ms = self._refresh_delay_if_needed()
    step_now = self._curr_step(dt_ms)
    # This model uses on-grid event stamps, exactly as NEST does.
    t_pre = self._current_time_ms() + dt_ms
    d_dend = float(self.delay)

    # -- Facilitation: every postsynaptic spike in
    #    (t_lastspike - d, t_pre - d] potentiates the weight.
    window_lo = self.t_lastspike - d_dend
    window_hi = t_pre - d_dend
    for t_post in self._get_post_history_times(window_lo, window_hi):
        pair_dt = self.t_lastspike - (t_post + d_dend)
        assert pair_dt < -_STDP_EPS
        # Presynaptic trace back-propagated to this postsynaptic spike.
        trace = self.Kplus * math.exp(pair_dt / self.tau_plus)
        self.weight = float(self._facilitate(float(self.weight), float(trace)))

    # -- Depression against the postsynaptic trace at t_pre - d.
    k_minus = self._get_K_value(window_hi)
    self.weight = float(self._depress(float(self.weight), float(k_minus)))

    # -- Schedule the weighted event; the updated weight is what gets sent.
    target = self._resolve_receiver(post)
    if receptor_type is None:
        port = self.receptor_type
    else:
        port = self._to_receptor_type(receptor_type)
    payload = multiplicity * float(self.weight)
    due_step = int(step_now + int(self._delay_steps))
    self._queue[due_step].append((target, payload, int(port), 'spike'))

    # -- Presynaptic trace update comes last, after the STDP computation.
    decay = math.exp((self.t_lastspike - t_pre) / self.tau_plus)
    self.Kplus = float(self.Kplus * decay + 1.0)
    self.t_lastspike = float(t_pre)
    return True
def update(
    self,
    pre_spike: ArrayLike = 0.0,
    *,
    post_spike: ArrayLike = 0.0,
    post=None,
    receptor_type: ArrayLike | None = None,
) -> int:
    r"""Run one simulation step: deliver events, record post spikes, send pre spikes.

    The three phases run in NEST's event-driven order, so that postsynaptic
    history is up to date before presynaptic spikes arriving in the same
    step are processed:

    1. Deliver every queued event whose delay expires at the current step.
    2. Append ``post_spike`` occurrences (stamped on-grid at
       :math:`t + dt`) to the internal postsynaptic STDP history; multiple
       simultaneous spikes are recorded one at a time.
    3. Accumulate presynaptic drive from the ``pre_spike`` argument plus
       all registered current and delta inputs, and hand any non-zero
       total to :meth:`send` for STDP updates and event scheduling.

    Parameters
    ----------
    pre_spike : ArrayLike, optional
        Explicit presynaptic spike count, summed with registered
        current/delta inputs before processing. Default: ``0.0``.
    post_spike : ArrayLike, optional
        Postsynaptic spike count for this step; validated as a
        non-negative integer-valued scalar. Default: ``0.0``.
    post : object or None, optional
        Receiver override, forwarded to :meth:`send`. Default: ``None``.
    receptor_type : ArrayLike or None, optional
        Port override, forwarded to :meth:`send`. Default: ``None``.

    Returns
    -------
    int
        Number of events delivered to the postsynaptic receiver during
        this step.

    Raises
    ------
    ValueError
        If ``post_spike`` fails validation (non-scalar, non-finite,
        negative, or not integer-valued), or if ``receptor_type`` is not a
        valid non-negative integer.
    RuntimeError
        If a presynaptic spike is triggered but no receiver is available.

    See Also
    --------
    send : Presynaptic spike processing and STDP weight updates.
    record_post_spike : Manually record postsynaptic spikes.
    add_current_input : Register presynaptic input sources.
    """
    dt_ms = self._refresh_delay_if_needed()
    n_delivered = self._deliver_due_events(self._curr_step(dt_ms))

    # Phase 2: postsynaptic history is updated before any presynaptic
    # spike from the same step is processed (NEST semantics).
    n_post = self._to_non_negative_int_count(post_spike, name='post_spike')
    if n_post > 0:
        stamp = float(self._current_time_ms() + dt_ms)
        for _ in range(n_post):
            self._record_post_spike_at(stamp)

    # Phase 3: gather presynaptic drive from all sources, then send.
    drive = self.sum_delta_inputs(self.sum_current_inputs(pre_spike))
    if self._is_nonzero(drive):
        self.send(drive, post=post, receptor_type=receptor_type)
    return n_delivered