import math
from typing import Any
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike
from ._base import NESTSynapse
# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
__all__ = [
'clopath_synapse',
]
class clopath_synapse(NESTSynapse):
r"""NEST-compatible voltage-based STDP synapse following the Clopath plasticity rule.
This synapse model implements the voltage-based spike-timing-dependent plasticity (STDP)
rule described by Clopath et al. (2010). Unlike traditional pair-based STDP, weight updates
depend on postsynaptic membrane voltage traces archived by the target neuron, enabling
voltage-dependent learning rules that capture homeostatic regulation and triplet interactions.
**1. Model Overview**
The Clopath synapse is a connection-level model that modulates synaptic weights based on:
- Presynaptic spike times and a presynaptic trace :math:`\bar{x}(t)`
- Postsynaptic voltage-derived traces for long-term potentiation (LTP) and depression (LTD)
- Hard bounds :math:`[W_\text{min}, W_\text{max}]` on synaptic weight
The model requires the postsynaptic neuron to maintain voltage-dependent history buffers
(e.g., ``aeif_psc_delta_clopath`` in NEST).
**2. State Variables**
Each connection maintains:
- :math:`w` -- Current synaptic weight
- :math:`\bar{x}` -- Presynaptic trace (low-pass filtered spike train)
- :math:`t_\text{last}` -- Timestamp of most recent presynaptic spike (milliseconds)
- :math:`\tau_x` -- Time constant for presynaptic trace decay (milliseconds)
- :math:`W_\text{min}, W_\text{max}` -- Hard lower/upper weight bounds
**3. Plasticity Update Sequence**
On each presynaptic spike at time :math:`t` (with dendritic delay :math:`d` and previous
spike time :math:`t_\text{last}`), the following steps are executed in order:
**(a) Long-Term Potentiation (LTP):**
Retrieve all LTP history entries from the postsynaptic neuron in the interval
:math:`(t_\text{last} - d,\, t - d]`. For each entry with timestamp :math:`t_i` and
amplitude :math:`\text{dw}_i`:
.. math::
w \leftarrow \min\left(W_\text{max},\, w + \text{dw}_i \cdot \bar{x}
\exp\left(\frac{t_\text{last} - (t_i + d)}{\tau_x}\right)\right)
This potentiates the weight in proportion to the presynaptic trace evaluated at the effective
time :math:`t_i + d`, accounting for its exponential decay since :math:`t_\text{last}`.
**(b) Long-Term Depression (LTD):**
Query the postsynaptic LTD value at the effective time :math:`t - d`:
.. math::
w \leftarrow \max(W_\text{min},\, w - \text{dw}_\text{LTD}(t - d))
This depresses the weight by the current LTD trace magnitude, clamped to :math:`W_\text{min}`.
**(c) Event Emission:**
A spike event is generated with the updated weight and delay parameters.
**(d) Presynaptic Trace Update:**
The presynaptic trace is updated to account for the new spike:
.. math::
\bar{x} \leftarrow \bar{x} \exp\left(\frac{t_\text{last} - t}{\tau_x}\right)
+ \frac{1}{\tau_x}
**(e) Timestamp Update:**
The last spike time is updated: :math:`t_\text{last} \leftarrow t`.
**4. Mathematical Foundations**
The presynaptic trace :math:`\bar{x}(t)` is a low-pass filter of the presynaptic spike
train :math:`S_\text{pre}(t) = \sum_k \delta(t - t_k)`:
.. math::
\tau_x \frac{d\bar{x}}{dt} = -\bar{x} + S_\text{pre}(t)
At each spike time :math:`t_k`, the exact solution yields the jump condition:
.. math::
\bar{x}(t_k^+) = \bar{x}(t_k^-) e^{-(t_k - t_{k-1})/\tau_x} + \frac{1}{\tau_x}
This exact event-driven update is implemented in step (d).
**5. Postsynaptic Interface Requirements**
The target neuron must provide:
- ``get_LTP_history(t1, t2)`` or ``get_ltp_history(t1, t2)``:
Returns iterable of LTP events in :math:`(t_1, t_2]`
- ``get_LTD_value(t)`` or ``get_ltd_value(t)``:
Returns scalar LTD amplitude at time :math:`t`
Each LTP history entry must support extraction of:
- Time field: ``t_``, ``t``, ``time_ms``, or ``time``
- Weight change field: ``dw_``, ``dw``, ``delta_w``, or ``weight_change``
Supported entry formats:
- Object with attributes ``t_`` and ``dw_``
- Object with attributes ``t`` and ``dw``
- Dictionary with keys ``'t'``/``'t_'`` and ``'dw'``/``'dw_'``
- 2-tuple ``(t, dw)``
Parameters
----------
weight : float, optional
Initial synaptic weight (dimensionless). Must satisfy sign consistency with ``Wmin``
and ``Wmax``. Default: ``1.0``.
delay : float, optional
Dendritic propagation delay in milliseconds. Must be positive. Default: ``1.0``.
delay_steps : int, optional
Integer delay in simulation time steps for event delivery. Must be >= 1. Default: ``1``.
x_bar : float, optional
Initial presynaptic trace value (dimensionless). Typically initialized to ``0.0`` before
any spikes. Default: ``0.0``.
tau_x : float, optional
Time constant for presynaptic trace exponential decay (milliseconds). Must be positive
and non-zero. Controls the temporal window of LTP. Typical values: 10-20 ms. Default: ``15.0``.
Wmin : float, optional
Hard lower bound on synaptic weight (dimensionless). Must have same sign as ``weight``
according to NEST's internal sign checks. Default: ``0.0``.
Wmax : float, optional
Hard upper bound on synaptic weight (dimensionless). Must have same sign as ``weight``
according to NEST's internal sign checks. Default: ``100.0``.
t_last_spike_ms : float, optional
Timestamp of the most recent presynaptic spike (milliseconds). Initialized to ``0.0``
before the first spike. Default: ``0.0``.
name : str or None, optional
Optional identifier for this connection instance. Default: ``None``.
Parameter Mapping
-----------------
This table shows the correspondence between brainpy.state and NEST parameter names:
===================== ===================== ============= ===================
brainpy.state NEST Unit Description
===================== ===================== ============= ===================
``weight`` ``weight`` (unitless) Synaptic weight
``delay`` ``delay`` ms Dendritic delay
``delay_steps`` (runtime) steps Event delivery delay
``x_bar`` ``x_bar`` (unitless) Presynaptic trace
``tau_x`` ``tau_x`` ms Presynaptic time constant
``Wmin`` ``Wmin`` (unitless) Minimum weight
``Wmax`` ``Wmax`` (unitless) Maximum weight
``t_last_spike_ms`` (internal state) ms Last spike timestamp
===================== ===================== ============= ===================
Attributes
----------
HAS_DELAY : bool
Connection supports propagation delay. Always ``True``.
IS_PRIMARY : bool
Connection is a primary connection type. Always ``True``.
REQUIRES_CLOPATH_ARCHIVING : bool
Connection requires voltage trace archiving from postsynaptic neuron. Always ``True``.
SUPPORTS_HPC : bool
Model supports high-performance computing infrastructure. Always ``True``.
SUPPORTS_LBL : bool
Model supports label-based lookup. Always ``True``.
SUPPORTS_WFR : bool
Model supports waveform relaxation iteration. Always ``True``.
Raises
------
ValueError
If ``weight``, ``Wmin``, and ``Wmax`` do not satisfy sign consistency constraints.
NEST enforces: ``sign(weight) == sign(Wmin)`` and ``sign(weight) == sign(Wmax)``,
where sign tests use different comparison operators for min vs max bounds.
ValueError
If ``delay`` <= 0 or ``delay_steps`` < 1.
ValueError
If any parameter is non-finite (NaN or Inf).
ValueError
If ``tau_x`` is zero (division by zero in trace updates).
AttributeError
If target neuron does not provide required ``get_LTP_history`` and ``get_LTD_value``
methods during ``send()`` call.
See Also
--------
aeif_psc_delta_clopath : Adaptive exponential IF neuron with Clopath archiving (NEST)
hh_psc_alpha_clopath : Hodgkin-Huxley neuron with Clopath archiving (NEST)
stdp_synapse : Traditional pair-based STDP synapse
Notes
-----
**Implementation Details:**
- All internal computations use 64-bit floating point (``float64``) to match NEST precision.
- Precise sub-grid spike timing offsets are ignored; all spike times are treated as exact
multiples of the simulation time step.
- The update sequence strictly follows ``clopath_synapse::send()`` in NEST to ensure
numerical equivalence.
- Sign constraints use NEST's asymmetric comparison operators: ``Wmin`` uses ``>=`` vs ``<``
while ``Wmax`` uses ``>`` vs ``<=``.
**Biological Interpretation:**
The Clopath rule captures key experimental observations:
- LTP depends on presynaptic activity (spike trace :math:`\bar{x}`) and postsynaptic
depolarization (voltage-derived LTP trace).
- LTD depends on presynaptic spikes paired with postsynaptic voltage without strong
depolarization.
- The voltage dependence enables homeostatic regulation: neurons with high baseline firing
rates have reduced LTP, preventing runaway excitation.
- The model reproduces triplet STDP effects without explicit triplet terms.
**Computational Considerations:**
- Memory overhead scales with the number of LTP history entries archived by the postsynaptic
neuron (typically bounded by a sliding time window).
- For large fan-in networks, the LTP history query in step (a) may become a bottleneck.
Consider using sparse indexing or binned histograms for postsynaptic traces.
- The exponential decay calculations use ``math.exp`` for scalar operations. For vectorized
implementations, replace with ``jax.numpy.exp`` or equivalent.
References
----------
.. [1] Clopath, C., Büsing, L., Vasilaki, E., & Gerstner, W. (2010). Connectivity reflects
coding: a model of voltage-based STDP with homeostasis. *Nature Neuroscience*, 13(3),
344-352. DOI: 10.1038/nn.2479
.. [2] NEST Initiative (2024). NEST Simulator Documentation: clopath_synapse.
https://nest-simulator.readthedocs.io/
.. [3] NEST source code: ``models/clopath_synapse.h`` and ``models/clopath_synapse.cpp``.
https://github.com/nest/nest-simulator
Examples
--------
**Basic Usage:**
Create a Clopath synapse with default parameters:
.. code-block:: python
>>> import brainpy.state as bst
>>> synapse = bst.clopath_synapse(weight=0.5, tau_x=15.0, Wmin=0.0, Wmax=1.0)
>>> synapse.get_status()
{'weight': 0.5, 'tau_x': 15.0, 'Wmin': 0.0, 'Wmax': 1.0, ...}
**Simulating Presynaptic Spike Train:**
Assuming a postsynaptic neuron with Clopath archiving:
.. code-block:: python
>>> # Mock target neuron with required interface
>>> class ClopathNeuron:
... def get_ltp_history(self, t1, t2):
... # Return LTP events in (t1, t2]
... return [(10.5, 0.05), (12.3, 0.08)] # (time_ms, dw)
... def get_ltd_value(self, t):
... # Return LTD amplitude at time t
... return 0.02
>>> target = ClopathNeuron()
>>> synapse = bst.clopath_synapse(weight=1.0, tau_x=10.0, Wmin=0.0, Wmax=5.0)
>>> # Process spike train
>>> spike_times = [10.0, 20.0, 30.0]
>>> events = synapse.simulate_pre_spike_train(spike_times, target)
>>> print(f"Final weight: {synapse.weight:.3f}")
Final weight: 1.123
**Weight Evolution with Voltage-Dependent Plasticity:**
.. code-block:: python
>>> synapse = bst.clopath_synapse(weight=1.0, tau_x=15.0, Wmin=-2.0, Wmax=2.0)
>>> # Simulate pairing protocol: pre before post (LTP)
>>> for t in [10, 20, 30]:
... event = synapse.send(t_spike_ms=t, target=target)
>>> print(f"Weight after LTP protocol: {synapse.weight:.3f}")
Weight after LTP protocol: 1.450
**Sign Constraint Validation:**
.. code-block:: python
>>> # Valid: all same sign (positive)
>>> synapse = bst.clopath_synapse(weight=1.0, Wmin=0.0, Wmax=5.0)
>>> # Invalid: mixed signs
>>> try:
... synapse = bst.clopath_synapse(weight=1.0, Wmin=-1.0, Wmax=5.0)
... except ValueError as e:
... print(e)
Weight and Wmin must have same sign.
"""
# Public module path reported to users (Sphinx/docs compatibility).
__module__ = 'brainpy.state'
# NEST-style capability flags; mirrored by the ``properties`` accessor.
HAS_DELAY = True  # connection supports a propagation delay
IS_PRIMARY = True  # primary connection type
REQUIRES_CLOPATH_ARCHIVING = True  # target must archive voltage traces
SUPPORTS_HPC = True  # usable with HPC infrastructure
SUPPORTS_LBL = True  # label-based lookup supported
SUPPORTS_WFR = True  # waveform-relaxation capable
def __init__(
    self,
    weight: ArrayLike = 1.0,
    delay: ArrayLike = 1.0,
    delay_steps: ArrayLike = 1,
    x_bar: ArrayLike = 0.0,
    tau_x: ArrayLike = 15.0,
    Wmin: ArrayLike = 0.0,
    Wmax: ArrayLike = 100.0,
    t_last_spike_ms: ArrayLike = 0.0,
    name: str | None = None,
):
    """Initialize the Clopath synapse with validated scalar parameters.

    Every value is coerced to a float scalar via base-class helpers;
    ``delay`` must be positive and ``delay_steps`` >= 1. The NEST sign
    consistency check between ``weight``, ``Wmin`` and ``Wmax`` runs last,
    so it sees the fully initialized state.
    """
    super().__init__(in_size=1, name=name)
    # Validation order matters for which error surfaces first; keep the
    # original sequence: weight, delay, delay_steps, then the float fields.
    self.weight = self._to_float_scalar(weight, name='weight')
    self.delay = self._validate_positive_delay(delay)
    self.delay_steps = self._validate_delay_steps(delay_steps)
    self.x_bar = self._to_float_scalar(x_bar, name='x_bar')
    self.tau_x = self._to_float_scalar(tau_x, name='tau_x')
    self.Wmin = self._to_float_scalar(Wmin, name='Wmin')
    self.Wmax = self._to_float_scalar(Wmax, name='Wmax')
    self.t_last_spike_ms = self._to_float_scalar(t_last_spike_ms, name='t_last_spike_ms')
    # NEST-compatible asymmetric sign tests (see _check_weight_sign_constraints).
    self._check_weight_sign_constraints()
@property
def properties(self) -> dict[str, Any]:
    r"""Capability flags of this connection model.

    Returns
    -------
    dict[str, Any]
        Keys ``'has_delay'``, ``'is_primary'``, ``'requires_clopath_archiving'``,
        ``'supports_hpc'``, ``'supports_lbl'`` and ``'supports_wfr'``, each
        mapped to the corresponding class-level boolean flag.
    """
    flag_pairs = (
        ('has_delay', self.HAS_DELAY),
        ('is_primary', self.IS_PRIMARY),
        ('requires_clopath_archiving', self.REQUIRES_CLOPATH_ARCHIVING),
        ('supports_hpc', self.SUPPORTS_HPC),
        ('supports_lbl', self.SUPPORTS_LBL),
        ('supports_wfr', self.SUPPORTS_WFR),
    )
    return dict(flag_pairs)
def get_status(self) -> dict[str, Any]:
    r"""Return current parameters, state, and capability flags (NEST ``GetStatus`` style).

    Returns
    -------
    dict[str, Any]
        Parameter/state entries (``'weight'``, ``'delay'``, ``'delay_steps'``,
        ``'x_bar'``, ``'tau_x'``, ``'Wmin'``, ``'Wmax'``, ``'t_last_spike_ms'``),
        the memory footprint ``'size_of'`` in bytes, and the boolean capability
        flags (``'has_delay'``, ``'is_primary'``, ``'requires_clopath_archiving'``,
        ``'supports_hpc'``, ``'supports_lbl'``, ``'supports_wfr'``).

    Notes
    -----
    All values are converted to native Python types (float/int/bool) rather
    than array types, matching NEST's status-dictionary conventions.
    """
    # NOTE: the stray Sphinx "[docs]" artifact that preceded this method in the
    # scraped source has been removed; it was a NameError at class-body time.
    return {
        'weight': float(self.weight),
        'delay': float(self.delay),
        'delay_steps': int(self.delay_steps),
        'x_bar': float(self.x_bar),
        'tau_x': float(self.tau_x),
        'Wmin': float(self.Wmin),
        'Wmax': float(self.Wmax),
        't_last_spike_ms': float(self.t_last_spike_ms),
        'size_of': int(self.__sizeof__()),
        'has_delay': self.HAS_DELAY,
        'is_primary': self.IS_PRIMARY,
        'requires_clopath_archiving': self.REQUIRES_CLOPATH_ARCHIVING,
        'supports_hpc': self.SUPPORTS_HPC,
        'supports_lbl': self.SUPPORTS_LBL,
        'supports_wfr': self.SUPPORTS_WFR,
    }
def set_status(self, status: dict[str, Any] | None = None, **kwargs):
    r"""Update connection parameters and state variables (NEST ``SetStatus`` style).

    Parameters
    ----------
    status : dict[str, Any] or None, optional
        Name/value pairs to update. Recognized keys: ``'weight'``, ``'delay'``,
        ``'delay_steps'``, ``'x_bar'``, ``'tau_x'``, ``'Wmin'``, ``'Wmax'``,
        ``'t_last_spike_ms'``. Default: ``None``.
    **kwargs
        Additional updates, merged over ``status``; keyword arguments win.

    Raises
    ------
    ValueError
        If a value is invalid (non-finite, ``delay`` <= 0, ``delay_steps`` < 1)
        or if the resulting ``weight``/``Wmin``/``Wmax`` signs are inconsistent.

    Notes
    -----
    Sign constraints are re-checked once after ALL updates are applied, so a
    batch that is jointly consistent is accepted even if an intermediate state
    would not be. Unrecognized keys are silently ignored.
    NOTE(review): NEST itself errors on unknown status keys — confirm whether
    silent ignoring is intended here.
    """
    updates: dict[str, Any] = dict(status or {})
    updates.update(kwargs)
    # Per-field validators, in the original validation order.
    validators = {
        'weight': lambda v: self._to_float_scalar(v, name='weight'),
        'delay': self._validate_positive_delay,
        'delay_steps': self._validate_delay_steps,
        'x_bar': lambda v: self._to_float_scalar(v, name='x_bar'),
        'tau_x': lambda v: self._to_float_scalar(v, name='tau_x'),
        'Wmin': lambda v: self._to_float_scalar(v, name='Wmin'),
        'Wmax': lambda v: self._to_float_scalar(v, name='Wmax'),
        't_last_spike_ms': lambda v: self._to_float_scalar(v, name='t_last_spike_ms'),
    }
    for key, validate in validators.items():
        if key in updates:
            setattr(self, key, validate(updates[key]))
    # Atomic re-validation after the whole batch, exactly as before.
    self._check_weight_sign_constraints()
def get(self, key: str = 'status'):
    r"""Retrieve the full status dictionary or a single entry.

    Parameters
    ----------
    key : str, optional
        ``'status'`` for the full dictionary, or any key present in the
        status dictionary (e.g. ``'weight'``, ``'tau_x'``). Default: ``'status'``.

    Returns
    -------
    dict[str, Any] or float or int or bool
        The full status dict for ``'status'``; otherwise the requested value.

    Raises
    ------
    KeyError
        If ``key`` is neither ``'status'`` nor a known status entry.
    """
    # Build the status dict once and serve both query forms from it.
    status = self.get_status()
    if key == 'status':
        return status
    try:
        return status[key]
    except KeyError:
        # Preserve the original error message; suppress chaining noise.
        raise KeyError(f'Unsupported key "{key}" for clopath_synapse.get().') from None
def set_weight(self, weight: ArrayLike):
    r"""Set the synaptic weight (convenience for ``set_status(weight=...)``).

    Parameters
    ----------
    weight : float or array-like
        New scalar weight; must be finite.

    Raises
    ------
    ValueError
        If ``weight`` is non-scalar or non-finite.

    Notes
    -----
    Unlike :meth:`set_status`, sign consistency with ``Wmin``/``Wmax`` is NOT
    re-checked here (assumed to hold since initialization).
    """
    self.weight = self._to_float_scalar(weight, name='weight')
def set_delay(self, delay: ArrayLike):
    r"""Set the dendritic propagation delay.

    Parameters
    ----------
    delay : float or array-like
        New delay in milliseconds (scalar). Must be positive.

    Raises
    ------
    ValueError
        If ``delay`` is non-scalar, non-finite, or <= 0.
    """
    self.delay = self._validate_positive_delay(delay)
def set_delay_steps(self, delay_steps: ArrayLike):
    r"""Set the integer event-delivery delay in simulation time steps.

    Parameters
    ----------
    delay_steps : int or array-like
        New delay in steps (scalar integer). Must be >= 1.

    Raises
    ------
    ValueError
        If ``delay_steps`` is non-scalar, non-finite, non-integer, or < 1.
    """
    self.delay_steps = self._validate_delay_steps(delay_steps)
def send(
    self,
    t_spike_ms: ArrayLike,
    target: Any,
    receptor_type: ArrayLike = 0,
    multiplicity: ArrayLike = 1.0,
    delay: ArrayLike | None = None,
    delay_steps: ArrayLike | None = None,
) -> dict[str, Any]:
    r"""Process one presynaptic spike with Clopath plasticity; return the event payload.

    The update order mirrors NEST's ``clopath_synapse::send()`` exactly:

    1. **LTP** — for every postsynaptic LTP entry :math:`(t_i, \text{dw}_i)`
       in :math:`(t_\text{last} - d,\ t - d]`, facilitate
       :math:`w \leftarrow \min(W_\text{max},\, w + \text{dw}_i \,\bar{x}\,
       e^{(t_\text{last} - (t_i + d))/\tau_x})`.
    2. **LTD** — depress
       :math:`w \leftarrow \max(W_\text{min},\, w - \text{dw}_\text{LTD}(t - d))`.
    3. **Event** — build the spike payload with the updated weight.
    4. **Trace** — update
       :math:`\bar{x} \leftarrow \bar{x}\, e^{(t_\text{last} - t)/\tau_x} + 1/\tau_x`
       and set :math:`t_\text{last} \leftarrow t`.

    Parameters
    ----------
    t_spike_ms : float or array-like
        Current presynaptic spike time in milliseconds (finite scalar).
    target : object
        Postsynaptic neuron exposing ``get_LTP_history``/``get_ltp_history``
        and ``get_LTD_value``/``get_ltd_value``.
    receptor_type : int or array-like, optional
        Receptor port index on the target. Default: ``0``.
    multiplicity : float or array-like, optional
        Spike multiplicity (>= 0). Default: ``1.0``.
    delay : float or array-like or None, optional
        Per-spike override of the dendritic delay (ms); ``None`` uses
        ``self.delay``. Default: ``None``.
    delay_steps : int or array-like or None, optional
        Per-event override of the integer delay; ``None`` uses
        ``self.delay_steps``. Default: ``None``.

    Returns
    -------
    dict[str, Any]
        Keys ``'weight'``, ``'delay'``, ``'delay_steps'``, ``'receptor_type'``,
        ``'multiplicity'``, ``'t_spike_ms'``.

    Raises
    ------
    ValueError
        If ``tau_x`` is zero, or any scalar argument fails validation.
    AttributeError
        If ``target`` lacks the required LTP/LTD query methods.

    Notes
    -----
    Mutates ``self.weight``, ``self.x_bar`` and ``self.t_last_spike_ms``.
    NOTE(review): ``receptor_type``/``multiplicity`` are validated only while
    building the payload, i.e. after the weight was already updated; an invalid
    value therefore leaves the plasticity applied. This preserves the original
    ordering — confirm it is intended.
    """
    t_spike = self._to_float_scalar(t_spike_ms, name='t_spike_ms')
    if self.tau_x == 0.0:
        # Guards the divisions in the exponential decay terms below.
        raise ValueError('tau_x must be non-zero.')
    dendritic_delay = self.delay if delay is None else self._validate_positive_delay(delay)
    event_delay_steps = (
        self.delay_steps
        if delay_steps is None
        else self._validate_delay_steps(delay_steps)
    )
    # Step 1: LTP over the archived history since the previous pre spike,
    # shifted by the dendritic delay -> interval (t_last - d, t - d].
    ltp_entries = self._get_ltp_history(
        target,
        self.t_last_spike_ms - dendritic_delay,
        t_spike - dendritic_delay,
    )
    for entry in ltp_entries:
        t_hist, dw = self._extract_history_entry(entry)
        # Decay the presynaptic trace from t_last to the effective time t_hist + d.
        minus_dt = self.t_last_spike_ms - (t_hist + dendritic_delay)
        self.weight = self._facilitate(
            self.weight,
            dw,
            self.x_bar * math.exp(minus_dt / self.tau_x),
        )
    # Step 2: LTD using the postsynaptic trace at the effective time t - d.
    ltd_dw = self._get_ltd_value(target, t_spike - dendritic_delay)
    self.weight = self._depress(self.weight, ltd_dw)
    # Step 3: spike event carrying the post-plasticity weight.
    event = {
        'weight': float(self.weight),
        'delay': float(dendritic_delay),
        'delay_steps': int(event_delay_steps),
        'receptor_type': self._to_int_scalar(receptor_type, name='receptor_type'),
        'multiplicity': self._validate_multiplicity(multiplicity),
        't_spike_ms': float(t_spike),
    }
    # Step 4: exact event-driven trace update, then advance the timestamp.
    self.x_bar = self.x_bar * math.exp((self.t_last_spike_ms - t_spike) / self.tau_x) + 1.0 / self.tau_x
    self.t_last_spike_ms = t_spike
    return event
def to_spike_event(
    self,
    t_spike_ms: ArrayLike,
    target: Any,
    receptor_type: ArrayLike = 0,
    multiplicity: ArrayLike = 1.0,
    delay: ArrayLike | None = None,
    delay_steps: ArrayLike | None = None,
) -> dict[str, Any]:
    r"""Alias of :meth:`send` with identical parameters, behavior, and return value.

    Kept for callers that expect event-factory naming conventions.

    See Also
    --------
    send : Full documentation of the plasticity update and payload.
    """
    return self.send(
        t_spike_ms=t_spike_ms,
        target=target,
        receptor_type=receptor_type,
        multiplicity=multiplicity,
        delay=delay,
        delay_steps=delay_steps,
    )
def simulate_pre_spike_train(
    self,
    spike_times_ms: ArrayLike,
    target: Any,
    receptor_type: ArrayLike = 0,
    multiplicity: ArrayLike = 1.0,
    delay: ArrayLike | None = None,
    delay_steps: ArrayLike | None = None,
) -> list[dict[str, Any]]:
    r"""Process a sequence of presynaptic spikes, returning one event per spike.

    Spikes are processed in array order and plasticity is cumulative: every
    :meth:`send` call updates ``weight``, ``x_bar`` and ``t_last_spike_ms``
    before the next spike is handled. For physically meaningful weight
    trajectories, pass times sorted in ascending order — out-of-order spikes
    yield incorrect exponential decay factors.

    Parameters
    ----------
    spike_times_ms : array-like
        Presynaptic spike times in milliseconds. Flattened to 1-D and cast
        to the environment's default float dtype; values must be finite.
    target : object
        Postsynaptic neuron providing ``get_LTP_history`` and ``get_LTD_value``
        (or their lowercase variants).
    receptor_type : int or array-like, optional
        Receptor port index, forwarded to every ``send()``. Default: ``0``.
    multiplicity : float or array-like, optional
        Spike multiplicity for all events. Default: ``1.0``.
    delay : float or array-like or None, optional
        Dendritic delay override (ms); ``None`` uses ``self.delay``.
    delay_steps : int or array-like or None, optional
        Integer delay override; ``None`` uses ``self.delay_steps``.

    Returns
    -------
    list[dict[str, Any]]
        One ``send()`` payload per input spike time, in input order.

    Raises
    ------
    ValueError
        If any spike time is non-finite or ``tau_x`` is zero.
    AttributeError
        If ``target`` lacks the required query methods.
    """
    dftype = brainstate.environ.dftype()
    times = np.asarray(u.math.asarray(spike_times_ms), dtype=dftype).reshape(-1)
    # Sequential by construction: each send() mutates connection state that
    # the next iteration depends on.
    return [
        self.send(
            t_spike_ms=float(t),
            target=target,
            receptor_type=receptor_type,
            multiplicity=multiplicity,
            delay=delay,
            delay_steps=delay_steps,
        )
        for t in times
    ]
def _check_weight_sign_constraints(self):
    r"""Validate sign consistency of ``weight`` against ``Wmin`` and ``Wmax``.

    Raises
    ------
    ValueError
        If ``weight`` and ``Wmin`` disagree under the ``>= 0`` / ``< 0`` test,
        or ``weight`` and ``Wmax`` disagree under the ``> 0`` / ``<= 0`` test.

    Notes
    -----
    Reproduces NEST ``clopath_synapse::set_status`` exactly, including the
    asymmetric treatment of zero: ``weight=0, Wmin=0, Wmax=0`` is valid,
    while ``weight=0, Wmin=-1`` is not.
    """
    # Wmin uses the non-strict test (>= 0 vs < 0) ...
    if self._sign_like_wmin(self.weight) != self._sign_like_wmin(self.Wmin):
        raise ValueError('Weight and Wmin must have same sign.')
    # ... while Wmax uses the strict test (> 0 vs <= 0).
    if self._sign_like_wmax(self.weight) != self._sign_like_wmax(self.Wmax):
        raise ValueError('Weight and Wmax must have same sign.')
@staticmethod
def _sign_like_wmin(x: float) -> int:
    r"""NEST sign test for the ``Wmin`` constraint: +1 if non-negative, -1 if negative.

    The two comparisons are evaluated independently and subtracted, exactly
    as NEST subtracts boolean tests; a NaN input therefore yields 0.
    """
    non_negative = 1 if x >= 0.0 else 0
    negative = 1 if x < 0.0 else 0
    return non_negative - negative
@staticmethod
def _sign_like_wmax(x: float) -> int:
    r"""NEST sign test for the ``Wmax`` constraint: +1 if positive, -1 if non-positive.

    The two comparisons are evaluated independently and subtracted, exactly
    as NEST subtracts boolean tests; a NaN input therefore yields 0.
    """
    positive = 1 if x > 0.0 else 0
    non_positive = 1 if x <= 0.0 else 0
    return positive - non_positive
def _depress(self, w: float, dw: float) -> float:
    r"""LTD update with hard lower bound: ``max(Wmin, w - dw)``.

    Parameters
    ----------
    w : float
        Current weight.
    dw : float
        Depression magnitude (a positive value reduces the weight).

    Returns
    -------
    float
        The depressed weight, clamped from below at ``self.Wmin``.
    """
    # Python's two-argument max(a, b) returns a unless b > a, which matches
    # the original conditional (including its handling of NaN -> Wmin).
    return max(self.Wmin, w - float(dw))
def _facilitate(self, w: float, dw: float, x_trace: float) -> float:
r"""Apply LTP weight facilitation with hard upper bound clipping.
Parameters
----------
w : float
Current weight.
dw : float
Facilitation magnitude from postsynaptic LTP trace.
x_trace : float
Decayed presynaptic trace value at effective time.
Returns
-------
float
Updated weight: ``min(Wmax, w + dw * x_trace)``.
"""
w_new = w + float(dw) * float(x_trace)
return w_new if w_new < self.Wmax else self.Wmax
def _get_ltp_history(self, target: Any, t1: float, t2: float):
r"""Query postsynaptic LTP history entries in time interval (t1, t2].
Parameters
----------
target : object
Postsynaptic neuron with ``get_LTP_history`` or ``get_ltp_history`` method.
t1 : float
Start time (exclusive) in milliseconds.
t2 : float
End time (inclusive) in milliseconds.
Returns
-------
list or iterable
LTP history entries in the interval. Returns empty list if method returns ``None``.
Raises
------
AttributeError
If target does not provide ``get_LTP_history`` or ``get_ltp_history`` callable method.
"""
fn = getattr(target, 'get_LTP_history', None)
if fn is None:
fn = getattr(target, 'get_ltp_history', None)
if fn is None or not callable(fn):
raise AttributeError(
'Target must provide get_LTP_history(t1, t2) or get_ltp_history(t1, t2).'
)
history = fn(float(t1), float(t2))
if history is None:
return []
return history
def _get_ltd_value(self, target: Any, t: float) -> float:
r"""Query postsynaptic LTD depression amplitude at given time.
Parameters
----------
target : object
Postsynaptic neuron with ``get_LTD_value`` or ``get_ltd_value`` method.
t : float
Query time in milliseconds.
Returns
-------
float
LTD depression magnitude at time ``t``.
Raises
------
AttributeError
If target does not provide ``get_LTD_value`` or ``get_ltd_value`` callable method.
"""
fn = getattr(target, 'get_LTD_value', None)
if fn is None:
fn = getattr(target, 'get_ltd_value', None)
if fn is None or not callable(fn):
raise AttributeError(
'Target must provide get_LTD_value(t) or get_ltd_value(t).'
)
return float(fn(float(t)))
@staticmethod
def _extract_history_entry(entry: Any) -> tuple[float, float]:
r"""Extract time and weight change from LTP history entry.
Supports multiple entry formats:
- Dictionary: keys ``'t'``/``'t_'``/``'time'``/``'time_ms'`` and
``'dw'``/``'dw_'``/``'delta_w'``/``'weight_change'``
- 2-tuple or list: ``(time, dw)``
- Object: attributes ``t``/``t_`` and ``dw``/``dw_``
Parameters
----------
entry : dict or tuple or object
LTP history entry from postsynaptic neuron.
Returns
-------
tuple[float, float]
Extracted ``(time_ms, dw)`` as floats.
Raises
------
ValueError
If entry does not provide extractable time and weight change values.
"""
t = None
dw = None
if isinstance(entry, dict):
t = entry.get('t_', entry.get('t', entry.get('time_ms', entry.get('time', None))))
dw = entry.get('dw_', entry.get('dw', entry.get('delta_w', entry.get('weight_change', None))))
elif isinstance(entry, (tuple, list)) and len(entry) >= 2:
t, dw = entry[0], entry[1]
else:
t = getattr(entry, 't_', getattr(entry, 't', None))
dw = getattr(entry, 'dw_', getattr(entry, 'dw', None))
if t is None or dw is None:
raise ValueError('Each LTP history entry must provide both time and dw values.')
return float(t), float(dw)
@staticmethod
def _to_float_scalar(value: ArrayLike, name: str) -> float:
    r"""Coerce *value* to a validated scalar ``float``.

    A saiunit ``Quantity`` is unwrapped to its mantissa first; the result is
    cast to the brainstate default float dtype and must hold exactly one
    finite number.

    Parameters
    ----------
    value : array-like
        Scalar-like input (saiunit Quantity, NumPy/JAX array, or Python number).
    name : str
        Parameter name used in error messages.

    Returns
    -------
    float
        The validated scalar value.

    Raises
    ------
    ValueError
        If the input holds more than one element, or the value is NaN/Inf.
    """
    raw = u.get_mantissa(value) if isinstance(value, u.Quantity) else value
    flat = np.asarray(u.math.asarray(raw), dtype=brainstate.environ.dftype()).reshape(-1)
    if flat.size != 1:
        raise ValueError(f'{name} must be scalar.')
    result = float(flat[0])
    if not np.isfinite(result):
        raise ValueError(f'{name} must be finite.')
    return result
@staticmethod
def _to_int_scalar(value: ArrayLike, name: str) -> int:
    r"""Coerce *value* to a validated scalar ``int``.

    A saiunit ``Quantity`` is unwrapped to its mantissa first; the result is
    cast to the brainstate default float dtype, must hold exactly one finite
    number, and must be integer-valued within an absolute tolerance of 1e-12.

    Parameters
    ----------
    value : array-like
        Scalar-like input (saiunit Quantity, array, or Python number).
    name : str
        Parameter name used in error messages.

    Returns
    -------
    int
        The rounded integer value.

    Raises
    ------
    ValueError
        If the input is non-scalar, non-finite, or not integer-valued.
    """
    raw = u.get_mantissa(value) if isinstance(value, u.Quantity) else value
    flat = np.asarray(u.math.asarray(raw), dtype=brainstate.environ.dftype()).reshape(-1)
    if flat.size != 1:
        raise ValueError(f'{name} must be scalar.')
    as_float = float(flat[0])
    if not np.isfinite(as_float):
        raise ValueError(f'{name} must be finite.')
    rounded = int(round(as_float))
    if abs(as_float - rounded) > 1e-12:
        raise ValueError(f'{name} must be integer-valued.')
    return rounded
def _validate_positive_delay(self, value: ArrayLike) -> float:
    r"""Coerce ``value`` to a strictly positive scalar delay in milliseconds.

    Parameters
    ----------
    value : array-like
        Delay value in milliseconds.

    Returns
    -------
    float
        The validated delay (strictly greater than zero).

    Raises
    ------
    ValueError
        If the delay is <= 0, non-scalar, or non-finite.
    """
    delay_ms = self._to_float_scalar(value, name='delay')
    if delay_ms <= 0.0:
        raise ValueError('delay must be > 0.')
    return delay_ms
def _validate_delay_steps(self, value: ArrayLike) -> int:
    r"""Coerce ``value`` to a positive integer delay in simulation steps.

    Parameters
    ----------
    value : array-like
        Integer delay in simulation time steps.

    Returns
    -------
    int
        The validated step count (at least 1).

    Raises
    ------
    ValueError
        If the value is < 1, non-scalar, non-integer-valued, or non-finite.
    """
    steps = self._to_int_scalar(value, name='delay_steps')
    if steps < 1:
        raise ValueError('delay_steps must be >= 1.')
    return steps
def _validate_multiplicity(self, value: ArrayLike) -> float:
    r"""Coerce ``value`` to a non-negative scalar spike multiplicity.

    Parameters
    ----------
    value : array-like
        Spike multiplicity value.

    Returns
    -------
    float
        The validated multiplicity (zero or greater).

    Raises
    ------
    ValueError
        If the multiplicity is < 0, non-scalar, or non-finite.
    """
    mult = self._to_float_scalar(value, name='multiplicity')
    if mult < 0.0:
        raise ValueError('multiplicity must be >= 0.')
    return mult