# Source code for brainpy_state._nest.vogels_sprekeler_synapse

import math
from typing import Any

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike

from ._base import NESTSynapse

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-

# Public API of this module: only the synapse model class is exported.
__all__ = ['vogels_sprekeler_synapse']


class vogels_sprekeler_synapse(NESTSynapse):
    r"""NEST-compatible ``vogels_sprekeler_synapse`` connection model.

    This class reproduces the connection-level semantics of NEST
    ``models/vogels_sprekeler_synapse.{h,cpp}``. It implements inhibitory
    spike-timing-dependent plasticity (iSTDP) following Vogels et al. (2011).
    The rule combines symmetric STDP with a constant depression applied at
    every presynaptic spike. Symmetric here means facilitation for both
    pre-before-post and post-before-pre pairings. The rule is designed to
    maintain excitatory-inhibitory balance in recurrent networks.

    **1. Mathematical Model**

    The synaptic weight :math:`w` changes at pre- and postsynaptic spikes:

    .. math::

        \Delta w = \begin{cases}
        \eta\,\bigl(K_-(t^{\text{pre}}) - \alpha\bigr) & \text{at each presynaptic spike} \\
        \eta\,K_+(t^{\text{post}}) & \text{at each postsynaptic spike}
        \end{cases}

    where:

    - :math:`K_+` is the presynaptic trace and :math:`K_-` the postsynaptic
      trace, both decaying exponentially with time constant :math:`\tau`
    - :math:`\eta` is the learning rate
    - :math:`\alpha` is the constant depression factor

    Every increment goes through the sign-aware clipping functions of
    Section 3. Contributions of postsynaptic spikes are applied lazily, at the
    next presynaptic spike (Section 2).

    Traces evolve as:

    .. math::

        \frac{dK_+}{dt} = -\frac{K_+}{\tau} + \sum_i \delta(t - t_i^{\text{pre}})

        \frac{dK_-}{dt} = -\frac{K_-}{\tau} + \sum_j \delta(t - t_j^{\text{post}})

    **2. NEST Send-Order Update Sequence**

    For one presynaptic spike at time :math:`t` with dendritic delay :math:`d`:

    1. **History lookup**: Retrieve postsynaptic spikes in :math:`(t_{\text{last}} - d, t - d]`
    2. **Pairwise facilitation**: For each postsynaptic spike :math:`t_j` in history:

       .. math::

           w \leftarrow \operatorname{facilitate}\!\left(
           w, K_+ \exp\left(\frac{t_{\text{last}} - (t_j + d)}{\tau}\right)\right)

    3. **Current postsynaptic trace facilitation**:

       .. math::

           w \leftarrow \operatorname{facilitate}(w, K_-(t - d))

    4. **Constant depression**:

       .. math::

           w \leftarrow \operatorname{depress}(w)

    5. **Emit spike event** using the updated weight
    6. **Update presynaptic trace**:

       .. math::

           K_+ \leftarrow K_+ \exp\left(\frac{t_{\text{last}} - t}{\tau}\right) + 1

    7. **Update timestamp**: :math:`t_{\text{last}} \leftarrow t`

    Steps 1-4 are computed on local copies. ``weight``, ``Kplus`` and
    ``t_last_spike_ms`` are committed together only after all inputs have
    been validated and the target has been queried. An exception therefore
    leaves the synapse unchanged.

    **3. Weight Clipping Rules**

    Facilitate and depress operations are sign-aware via :math:`W_{\max}`:

    .. math::

        \operatorname{facilitate}(w, k) =
        \operatorname{copysign}\left(\min(|w| + \eta k,\ |W_{\max}|), W_{\max}\right)

    .. math::

        \operatorname{depress}(w) =
        \operatorname{copysign}\left(\max(|w| - \alpha\eta,\ 0), W_{\max}\right)

    After each operation :math:`|w| \in [0, |W_{\max}|]`, and the sign of
    :math:`w` is that of :math:`W_{\max}`.

    **4. Biological Motivation**

    This rule implements the iSTDP mechanism proposed by Vogels et al. for
    inhibitory synapses. Near-coincident pre- and postsynaptic spikes
    potentiate the weight regardless of their order. Each presynaptic spike
    also depresses the weight by :math:`\alpha\eta`.

    In the original paper the depression factor is :math:`\alpha = 2\rho_0\tau`.
    This sets a target postsynaptic rate :math:`\rho_0`. Inhibition
    strengthens when the postsynaptic neuron fires above that rate and weakens
    below it. As a result, inhibitory weights track excitatory input without
    fine-tuned parameters.

    Parameters
    ----------
    weight : float, array-like, optional
        Synaptic weight (unitless). Can be positive (excitatory) or negative
        (inhibitory). Must have the same sign as ``Wmax`` if non-zero.
        Default: ``0.5``.
    delay : float, array-like, optional
        Dendritic delay in milliseconds used for spike history lookup. Must be
        positive. Determines the time window for postsynaptic spike detection.
        Default: ``1.0`` ms.
    delay_steps : int, array-like, optional
        Event delivery delay in integer simulation steps (>=1). Reported in the
        emitted event dictionary.
        Default: ``1``.
    tau : float, array-like, optional
        STDP time constant in milliseconds (>0). Governs the exponential decay
        of the presynaptic (:math:`K_+`) and postsynaptic (:math:`K_-`) traces.
        Default: ``20.0`` ms.
    alpha : float, array-like, optional
        Constant depression factor (unitless). Scales the per-spike weight
        reduction :math:`\alpha \eta`. Setting :math:`\alpha = 0` disables the
        constant depression, leaving purely symmetric facilitation.
        Default: ``0.12``.
    eta : float, array-like, optional
        Learning rate (unitless). Scales both facilitation and depression.
        Default: ``0.001``.
    Wmax : float, array-like, optional
        Signed weight bound (unitless). ``|Wmax|`` caps the weight magnitude,
        and the sign of ``Wmax`` is imposed on the weight after every update.
        Must have the same sign as ``weight`` (if ``weight != 0``).
        Default: ``1.0``.
    Kplus : float, array-like, optional
        Initial presynaptic STDP trace value (unitless, >=0).
        Default: ``0.0``.
    t_last_spike_ms : float, array-like, optional
        Timestamp of the last presynaptic spike in milliseconds. Used for trace
        decay and for the lower edge of the history lookup window.
        Default: ``0.0`` ms.
    name : str, optional
        Model instance name for identification.
        Default: ``None``.

    See Also
    --------
    stdp_synapse : Classical asymmetric STDP rule
    stdp_dopamine_synapse : Reward-modulated STDP

    Notes
    -----
    **Parameter Mapping**

    +-----------------+---------------------+-------------------+
    | NEST Parameter  | brainpy.state       | Unit              |
    +=================+=====================+===================+
    | ``weight``      | ``weight``          | unitless          |
    +-----------------+---------------------+-------------------+
    | ``delay``       | ``delay``           | ms                |
    +-----------------+---------------------+-------------------+
    | ``delay_steps`` | ``delay_steps``     | steps             |
    +-----------------+---------------------+-------------------+
    | ``tau``         | ``tau``             | ms                |
    +-----------------+---------------------+-------------------+
    | ``alpha``       | ``alpha``           | unitless          |
    +-----------------+---------------------+-------------------+
    | ``eta``         | ``eta``             | unitless          |
    +-----------------+---------------------+-------------------+
    | ``Wmax``        | ``Wmax``            | unitless          |
    +-----------------+---------------------+-------------------+
    | ``Kplus``       | ``Kplus``           | unitless          |
    +-----------------+---------------------+-------------------+
    | ``t_lastspike`` | ``t_last_spike_ms`` | ms                |
    +-----------------+---------------------+-------------------+

    **Target Interface Requirements**

    The ``send()`` method requires postsynaptic targets to implement:

    - ``get_history(t1, t2)`` -- Returns spike history entries in the time
      window ``(t1, t2]`` (exclusive-inclusive). Entries must expose the spike
      time via attribute ``t_`` or ``t``, dict key ``'t_'`` or ``'t'``, or the
      first tuple/list element.

    - ``get_K_value(t)`` or ``get_k_value(t)`` -- Returns the postsynaptic STDP
      trace :math:`K_-` at time ``t`` (in ms). The value must be convertible
      with ``float()``.

    1. **Dendritic delay semantics**: The ``delay`` parameter shifts the
       history lookup window to :math:`(t_{\text{last}} - d, t - d]` and the
       time at which :math:`K_-` is sampled to :math:`t - d`. This follows
       NEST's dendritic delay convention.

    2. **Sign constraints**: If ``weight != 0``, both ``weight`` and ``Wmax``
       must have the same sign. Attempting to set opposite signs raises
       ``ValueError``.

    3. **Trace positivity**: ``Kplus`` must be non-negative. Negative values
       raise ``ValueError`` during initialization or ``set_status()``.

    4. **Spike timing**: ``send()`` uses ``t_spike_ms`` exactly as given.
       NEST ignores precise (off-grid) spike offsets for this rule, so callers
       wishing to match NEST should pass grid-aligned spike times.

    5. **Event multiplicity**: The ``multiplicity`` argument of ``send()`` is
       validated and reported in the event, but does not affect the weight
       update.

    6. **Failure atomicity**: ``set_status()`` and ``send()`` validate
       everything before modifying the synapse. When they raise, the synapse
       state is left untouched.

    References
    ----------
    .. [1] Vogels, T. P., Sprekeler, H., Zenke, F., Clopath, C., & Gerstner, W.
           (2011). Inhibitory plasticity balances excitation and inhibition in
           sensory pathways and memory networks. *Science*, 334(6062),
           1569-1573. https://doi.org/10.1126/science.1211095

    .. [2] NEST Initiative. Vogels-Sprekeler Synapse Model.
           NEST Simulator Documentation.
           https://nest-simulator.readthedocs.io/en/stable/models/vogels_sprekeler_synapse.html

    Examples
    --------
    Basic synapse creation with default parameters:

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> syn = bp.vogels_sprekeler_synapse()
        >>> status = syn.get_status()
        >>> status['weight'], status['tau'], status['Wmax']
        (0.5, 20.0, 1.0)

    Configure an inhibitory synapse and update a parameter afterwards:

    .. code-block:: python

        >>> syn = bp.vogels_sprekeler_synapse(
        ...     weight=-0.75,
        ...     Wmax=-2.0,
        ...     alpha=0.25,
        ...     eta=0.0625,
        ... )
        >>> syn.weight
        -0.75
        >>> syn.set_status(tau=32.0)
        >>> syn.tau
        32.0

    Process a presynaptic spike with a mock postsynaptic target:

    .. code-block:: python

        >>> class MockNeuron:
        ...     def get_history(self, t1, t2):
        ...         # One postsynaptic spike at 10 ms, filtered to (t1, t2]
        ...         return [{'t_': t} for t in (10.0,) if t1 < t <= t2]
        ...     def get_K_value(self, t):
        ...         return 0.5  # Postsynaptic trace K-
        >>> target = MockNeuron()
        >>> event = syn.send(t_spike_ms=15.0, target=target)
        >>> event['weight']  # -(0.75 + 0.0625*0.5 - 0.25*0.0625)
        -0.765625
        >>> syn.Kplus  # Updated presynaptic trace
        1.0

    Simulate a spike train:

    .. code-block:: python

        >>> events = syn.simulate_pre_spike_train(
        ...     pre_spike_times_ms=[25.0, 35.0],
        ...     target=target,
        ... )
        >>> [e['weight'] for e in events]
        [-0.78125, -0.796875]
    """

    __module__ = 'brainpy.state'

    HAS_DELAY = True
    IS_PRIMARY = True
    SUPPORTS_HPC = True
    SUPPORTS_LBL = True
    SUPPORTS_WFR = True

    def __init__(
        self,
        weight: ArrayLike = 0.5,
        delay: ArrayLike = 1.0,
        delay_steps: ArrayLike = 1,
        tau: ArrayLike = 20.0,
        alpha: ArrayLike = 0.12,
        eta: ArrayLike = 0.001,
        Wmax: ArrayLike = 1.0,
        Kplus: ArrayLike = 0.0,
        t_last_spike_ms: ArrayLike = 0.0,
        name: str | None = None,
    ):
        r"""Create the synapse and validate its initial parameters and state.

        See the class docstring for the meaning of each argument.

        Raises
        ------
        ValueError
            If any value is non-scalar or non-finite, ``delay <= 0``,
            ``tau <= 0``, ``delay_steps < 1``, ``Kplus < 0``, or a non-zero
            ``weight`` has the opposite sign of ``Wmax``.
        """
        super().__init__(in_size=1, name=name)

        self.weight = self._to_float_scalar(weight, name='weight')
        self.delay = self._validate_positive_delay(delay)
        self.delay_steps = self._validate_delay_steps(delay_steps)
        self.tau = self._validate_positive_tau(tau)
        self.alpha = self._to_float_scalar(alpha, name='alpha')
        self.eta = self._to_float_scalar(eta, name='eta')
        self.Wmax = self._to_float_scalar(Wmax, name='Wmax')
        self.Kplus = self._to_float_scalar(Kplus, name='Kplus')
        self.t_last_spike_ms = self._to_float_scalar(t_last_spike_ms, name='t_last_spike_ms')

        self._check_constraints()

    @property
    def properties(self) -> dict[str, Any]:
        r"""Return the model property flags.

        The flags mirror NEST's connection-model properties for
        ``vogels_sprekeler_synapse``.

        Returns
        -------
        dict[str, Any]
            Dictionary with keys:

            - ``'has_delay'`` (bool): the model carries a synaptic delay.
            - ``'is_primary'`` (bool): the model is a primary (spike-event)
              connection.
            - ``'supports_hpc'`` (bool): NEST offers an ``_hpc`` variant
              (index-based target identifiers).
            - ``'supports_lbl'`` (bool): NEST offers an ``_lbl`` (labeled
              synapse) variant.
            - ``'supports_wfr'`` (bool): compatible with waveform relaxation.
        """
        return {
            'has_delay': self.HAS_DELAY,
            'is_primary': self.IS_PRIMARY,
            'supports_hpc': self.SUPPORTS_HPC,
            'supports_lbl': self.SUPPORTS_LBL,
            'supports_wfr': self.SUPPORTS_WFR,
        }

    def get_status(self) -> dict[str, Any]:
        r"""Retrieve current parameter and state values.

        Returns all synapse parameters, STDP trace state, and model property
        flags as a dictionary. Follows NEST ``GetStatus`` semantics.

        Returns
        -------
        dict[str, Any]
            Dictionary containing:

            - ``'weight'`` (float): Current synaptic weight (unitless).
            - ``'delay'`` (float): Dendritic delay (ms).
            - ``'delay_steps'`` (int): Event delivery delay (steps).
            - ``'tau'`` (float): STDP time constant (ms).
            - ``'alpha'`` (float): Constant depression factor (unitless).
            - ``'eta'`` (float): Learning rate (unitless).
            - ``'Wmax'`` (float): Signed weight bound (unitless).
            - ``'Kplus'`` (float): Presynaptic STDP trace (unitless).
            - ``'t_last_spike_ms'`` (float): Last presynaptic spike time (ms).
            - ``'size_of'`` (int): ``self.__sizeof__()`` in bytes.
            - ``'has_delay'``, ``'is_primary'``, ``'supports_hpc'``,
              ``'supports_lbl'``, ``'supports_wfr'`` (bool): model property
              flags (see ``properties``).

        Examples
        --------
        .. code-block:: python

            >>> syn = bp.vogels_sprekeler_synapse(weight=0.75, tau=25.0)
            >>> status = syn.get_status()
            >>> status['weight']
            0.75
            >>> status['tau']
            25.0
        """
        return {
            'weight': float(self.weight),
            'delay': float(self.delay),
            'delay_steps': int(self.delay_steps),
            'tau': float(self.tau),
            'alpha': float(self.alpha),
            'eta': float(self.eta),
            'Wmax': float(self.Wmax),
            'Kplus': float(self.Kplus),
            't_last_spike_ms': float(self.t_last_spike_ms),
            'size_of': int(self.__sizeof__()),
            'has_delay': self.HAS_DELAY,
            'is_primary': self.IS_PRIMARY,
            'supports_hpc': self.SUPPORTS_HPC,
            'supports_lbl': self.SUPPORTS_LBL,
            'supports_wfr': self.SUPPORTS_WFR,
        }

    def set_status(self, status: dict[str, Any] | None = None, **kwargs):
        r"""Update synapse parameters and state variables.

        Modifies the synapse configuration following NEST ``SetStatus``
        semantics. Every recognised key is parsed and validated first. The
        resulting state is then checked as a whole: ``Kplus >= 0``, and
        ``weight``/``Wmax`` share a sign. Only if everything passes are the
        new values assigned, so a failed call leaves the synapse unchanged.
        Unrecognised keys are ignored.

        Parameters
        ----------
        status : dict[str, Any], optional
            Dictionary of parameter updates. Valid keys: ``'weight'``,
            ``'delay'``, ``'delay_steps'``, ``'tau'``, ``'alpha'``, ``'eta'``,
            ``'Wmax'``, ``'Kplus'``, ``'t_last_spike_ms'``.
        **kwargs
            Additional parameter updates as keyword arguments. Merged with
            ``status`` (kwargs take precedence).

        Raises
        ------
        ValueError
            If ``delay <= 0``, ``tau <= 0``, ``delay_steps < 1``,
            ``Kplus < 0``, a value is non-scalar or non-finite, or ``weight``
            and ``Wmax`` have opposite signs (when ``weight != 0``).

        Notes
        -----
        Constraints are checked against the combined new state. Setting
        ``weight`` and ``Wmax`` to new, mutually consistent signs in a single
        call is therefore allowed.

        Examples
        --------
        .. code-block:: python

            >>> syn = bp.vogels_sprekeler_synapse()
            >>> syn.set_status({'alpha': 0.25})
            >>> syn.alpha
            0.25
            >>> syn.set_status({'eta': 0.125, 'tau': 30.0})
            >>> syn.eta, syn.tau
            (0.125, 30.0)
            >>> syn.set_status(weight=0.75, Wmax=2.0)
            >>> syn.weight, syn.Wmax
            (0.75, 2.0)

        Invalid updates raise errors and change nothing:

        .. code-block:: python

            >>> syn.set_status({'tau': -5.0})  # doctest: +SKIP
            ValueError: tau must be > 0.
            >>> syn.set_status(weight=-0.5)  # doctest: +SKIP
            ValueError: Weight and Wmax must have same sign.
            >>> syn.weight
            0.75
        """
        updates: dict[str, Any] = {}
        if status is not None:
            updates.update(status)
        updates.update(kwargs)

        # Key -> parser; the order fixes which invalid key is reported first.
        parsers = (
            ('weight', lambda v: self._to_float_scalar(v, name='weight')),
            ('delay', self._validate_positive_delay),
            ('delay_steps', self._validate_delay_steps),
            ('tau', self._validate_positive_tau),
            ('alpha', lambda v: self._to_float_scalar(v, name='alpha')),
            ('eta', lambda v: self._to_float_scalar(v, name='eta')),
            ('Wmax', lambda v: self._to_float_scalar(v, name='Wmax')),
            ('Kplus', lambda v: self._to_float_scalar(v, name='Kplus')),
            ('t_last_spike_ms', lambda v: self._to_float_scalar(v, name='t_last_spike_ms')),
        )
        staged = {key: parse(updates[key]) for key, parse in parsers if key in updates}

        # Validate the would-be state before touching any attribute.
        self._validate_state(
            staged.get('weight', self.weight),
            staged.get('Wmax', self.Wmax),
            staged.get('Kplus', self.Kplus),
        )
        for key, value in staged.items():
            setattr(self, key, value)

    def get(self, key: str = 'status'):
        r"""Retrieve a specific parameter or the full status dictionary.

        Parameters
        ----------
        key : str, optional
            Parameter name, or ``'status'`` for the full dictionary. Valid keys
            are ``'status'`` and every key returned by ``get_status()``.
            Default: ``'status'``.

        Returns
        -------
        Any
            If ``key == 'status'``, the full status dictionary (see
            ``get_status()``). Otherwise, the requested value.

        Raises
        ------
        KeyError
            If ``key`` is not recognized.

        Examples
        --------
        .. code-block:: python

            >>> syn = bp.vogels_sprekeler_synapse(weight=0.625, tau=25.0)
            >>> syn.get('weight')
            0.625
            >>> syn.get('tau')
            25.0
            >>> 'weight' in syn.get('status')
            True
        """
        status = self.get_status()
        if key == 'status':
            return status
        if key in status:
            return status[key]
        raise KeyError(f'Unsupported key "{key}" for vogels_sprekeler_synapse.get().')

    def send(
        self,
        t_spike_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> dict[str, Any]:
        r"""Process one presynaptic spike and return the emitted event payload.

        Implements the Vogels-Sprekeler update sequence for a single
        presynaptic spike. It retrieves the postsynaptic spike history,
        applies pairwise facilitation, facilitates with the current
        postsynaptic trace, applies constant depression, updates the
        presynaptic trace, and returns the spike event.

        Parameters
        ----------
        t_spike_ms : float, array-like
            Presynaptic spike time in milliseconds (scalar).
        target : Any
            Postsynaptic target object. Must implement:

            - ``get_history(t1, t2)`` -- iterable of spike entries in
              ``(t1, t2]``. Each entry must expose its spike time via
              attribute ``t_``/``t``, dict key ``'t_'``/``'t'``, or the first
              tuple/list element.
            - ``get_K_value(t)`` or ``get_k_value(t)`` -- postsynaptic trace
              :math:`K_-` at time ``t`` (convertible to float).
        receptor_type : int, array-like, optional
            Postsynaptic receptor port index; must be integer-valued. Included
            in the returned event. Default: ``0``.
        multiplicity : float, array-like, optional
            Spike event multiplicity (>=0). Validated and reported, but not
            used in the weight update. Default: ``1.0``.
        delay : float, array-like, optional
            Override dendritic delay (ms, >0). If ``None``, uses ``self.delay``.
            Default: ``None``.
        delay_steps : int, array-like, optional
            Override event delivery delay (steps, >=1). If ``None``, uses
            ``self.delay_steps``. Default: ``None``.

        Returns
        -------
        dict[str, Any]
            Spike event dictionary with keys:

            - ``'weight'`` (float): Updated synaptic weight after plasticity.
            - ``'delay'`` (float): Effective dendritic delay (ms).
            - ``'delay_steps'`` (int): Event delivery delay (steps).
            - ``'receptor_type'`` (int): Postsynaptic receptor index.
            - ``'multiplicity'`` (float): Event multiplicity.
            - ``'t_spike_ms'`` (float): Presynaptic spike time (ms).
            - ``'Kminus'`` (float): Postsynaptic trace at ``t_spike - delay``.
            - ``'Kplus_pre'`` (float): Presynaptic trace **before** update.
            - ``'Kplus_post'`` (float): Presynaptic trace **after** update.

        Raises
        ------
        ValueError
            If ``delay <= 0``, ``delay_steps < 1``, ``multiplicity < 0``,
            ``receptor_type`` is not integer-valued, or any argument is
            non-scalar/non-finite.
        AttributeError
            If ``target`` lacks ``get_history()`` or
            ``get_K_value()``/``get_k_value()``.
        TypeError
            If a history entry does not expose a spike time in a supported way.

        Notes
        -----
        1. **State updates are persistent**: on success this method updates
           ``self.weight``, ``self.Kplus`` and ``self.t_last_spike_ms``. The
           three are committed together only after all validation and target
           queries succeed. If an exception is raised, the synapse is left
           unchanged.
        2. **History lookup window**: Retrieves postsynaptic spikes in
           :math:`(t_{\text{last}} - d, t_{\text{spike}} - d]`, where :math:`d`
           is the dendritic delay.
        3. **Trace decay**: The presynaptic trace is updated as
           :math:`K_+ \leftarrow K_+ \exp((t_{\text{last}} - t) / \tau) + 1`.
        4. **Weight clipping**: Facilitation and depression keep
           :math:`|w| \le |W_{\max}|` and impose the sign of :math:`W_{\max}`.

        Examples
        --------
        Process a single spike with a mock target:

        .. code-block:: python

            >>> import brainpy.state as bp
            >>> class MockTarget:
            ...     def get_history(self, t1, t2):
            ...         # One postsynaptic spike at 12 ms, filtered to (t1, t2]
            ...         return [{'t_': t} for t in (12.0,) if t1 < t <= t2]
            ...     def get_K_value(self, t):
            ...         return 0.4
            >>> syn = bp.vogels_sprekeler_synapse(weight=0.5, tau=20.0, eta=0.01)
            >>> target = MockTarget()
            >>> event = syn.send(t_spike_ms=15.0, target=target)
            >>> event['weight']  # 0.5 + 0.01*0.4 - 0.12*0.01
            0.502...
            >>> event['Kplus_post']  # Updated presynaptic trace
            1.0

        Override the delay for a specific spike:

        .. code-block:: python

            >>> event = syn.send(
            ...     t_spike_ms=20.0,
            ...     target=target,
            ...     delay=2.5,
            ...     delay_steps=3,
            ... )
            >>> event['delay']
            2.5
            >>> event['delay_steps']
            3
            >>> event['Kminus']  # Postsynaptic trace at t_spike - delay
            0.4
        """
        # 1. Validate every caller-supplied value before any state change.
        t_spike = self._to_float_scalar(t_spike_ms, name='t_spike_ms')
        dendritic_delay = self.delay if delay is None else self._validate_positive_delay(delay)
        event_delay_steps = (
            self.delay_steps if delay_steps is None else self._validate_delay_steps(delay_steps)
        )
        rport = self._to_int_scalar(receptor_type, name='receptor_type')
        event_multiplicity = self._validate_multiplicity(multiplicity)

        # 2. Query the target: history first, then K-, as in NEST's send().
        t_last = self.t_last_spike_ms
        history_entries = self._get_history(
            target,
            t_last - dendritic_delay,
            t_spike - dendritic_delay,
        )
        post_spike_times = [self._extract_history_time(entry) for entry in history_entries]
        kminus = self._get_k_value(target, t_spike - dendritic_delay)

        # 3. Compute the plastic update on local copies.
        kplus_pre = self.Kplus
        weight = self.weight
        for t_post in post_spike_times:
            # Pre trace evaluated at the dendritically delayed post spike;
            # minus_dt < 0 when the target honours the (t1, t2] window.
            minus_dt = t_last - (t_post + dendritic_delay)
            weight = self._facilitate(weight, kplus_pre * math.exp(minus_dt / self.tau))
        # K- is already sampled at t_spike - d, so no extra decay is applied.
        weight = self._facilitate(weight, kminus)
        weight = self._depress(weight)
        kplus_post = kplus_pre * math.exp((t_last - t_spike) / self.tau) + 1.0

        # 4. Commit all state together.
        self.weight = weight
        self.Kplus = kplus_post
        self.t_last_spike_ms = t_spike

        return {
            'weight': float(weight),
            'delay': float(dendritic_delay),
            'delay_steps': int(event_delay_steps),
            'receptor_type': rport,
            'multiplicity': event_multiplicity,
            't_spike_ms': float(t_spike),
            'Kminus': float(kminus),
            'Kplus_pre': float(kplus_pre),
            'Kplus_post': float(kplus_post),
        }

    def to_spike_event(
        self,
        t_spike_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> dict[str, Any]:
        r"""Alias for ``send()``.

        Identical to ``send()`` with the same parameters, return value, side
        effects and exceptions.

        See Also
        --------
        send : Primary spike processing method (full documentation).
        """
        return self.send(
            t_spike_ms=t_spike_ms,
            target=target,
            receptor_type=receptor_type,
            multiplicity=multiplicity,
            delay=delay,
            delay_steps=delay_steps,
        )

    def simulate_pre_spike_train(
        self,
        pre_spike_times_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> list[dict[str, Any]]:
        r"""Process a sequence of presynaptic spikes and return the event list.

        Calls ``send()`` once per spike time, in the order given. Weight and
        trace updates therefore accumulate across the train.

        Parameters
        ----------
        pre_spike_times_ms : array-like
            Presynaptic spike times in milliseconds (list, array, or scalar).
            Flattened to 1D and converted to ``brainstate.environ.dftype()``.
        target : Any
            Postsynaptic target (see ``send()`` for interface requirements).
        receptor_type : int, array-like, optional
            Postsynaptic receptor port (see ``send()``). Default: ``0``.
        multiplicity : float, array-like, optional
            Event multiplicity (see ``send()``). Default: ``1.0``.
        delay : float, array-like, optional
            Dendritic delay override (ms, see ``send()``). Default: ``None``.
        delay_steps : int, array-like, optional
            Delivery delay override (steps, see ``send()``). Default: ``None``.

        Returns
        -------
        list[dict[str, Any]]
            One event dictionary per input spike, in input order. Each has the
            structure returned by ``send()``.

        Notes
        -----
        Event ``i`` depends on events ``0`` through ``i-1`` because ``send()``
        updates ``weight``, ``Kplus`` and ``t_last_spike_ms``. If ``send()``
        raises for some spike, the spikes before it remain applied.

        Examples
        --------
        .. code-block:: python

            >>> import numpy as np
            >>> import brainpy.state as bp
            >>> class MockTarget:
            ...     def get_history(self, t1, t2):
            ...         # One postsynaptic spike at 10 ms, filtered to (t1, t2]
            ...         return [{'t_': t} for t in (10.0,) if t1 < t <= t2]
            ...     def get_K_value(self, t):
            ...         return 0.5
            >>> target = MockTarget()
            >>> syn = bp.vogels_sprekeler_synapse(weight=0.5, alpha=0.25, eta=0.0625)
            >>> events = syn.simulate_pre_spike_train(np.array([5.0, 15.0, 25.0]), target)
            >>> len(events)
            3
            >>> [round(e['weight'], 4) for e in events]
            [0.5156, 0.5776, 0.5932]
            >>> [round(e['Kplus_post'], 4) for e in events]
            [1.0, 1.6065, 1.9744]

        Strong constant depression dominates facilitation:

        .. code-block:: python

            >>> syn2 = bp.vogels_sprekeler_synapse(weight=1.0, alpha=2.0, eta=0.0625)
            >>> events2 = syn2.simulate_pre_spike_train([30.0, 40.0, 50.0], target)
            >>> [e['weight'] for e in events2]
            [0.875, 0.78125, 0.6875]
        """
        dftype = brainstate.environ.dftype()
        times = np.asarray(u.math.asarray(pre_spike_times_ms), dtype=dftype).reshape(-1)
        return [
            self.send(
                t_spike_ms=float(t),
                target=target,
                receptor_type=receptor_type,
                multiplicity=multiplicity,
                delay=delay,
                delay_steps=delay_steps,
            )
            for t in times
        ]

    def _facilitate(self, w: float, kplus: float) -> float:
        # |w| grows by eta*k, capped at |Wmax|; the sign of Wmax is imposed.
        new_w = abs(w) + self.eta * kplus
        return math.copysign(min(new_w, abs(self.Wmax)), self.Wmax)

    def _depress(self, w: float) -> float:
        # |w| shrinks by alpha*eta, floored at 0; the sign of Wmax is imposed.
        new_w = abs(w) - self.alpha * self.eta
        return math.copysign(max(new_w, 0.0), self.Wmax)

    def _check_constraints(self):
        # Validate the current attributes (used after __init__).
        self._validate_state(self.weight, self.Wmax, self.Kplus)

    @staticmethod
    def _validate_state(weight: float, Wmax: float, Kplus: float):
        # Same checks, order and messages as NEST's set_status; signbit-based
        # comparison so that -0.0 counts as negative for Wmax.
        if Kplus < 0.0:
            raise ValueError('State Kplus must be positive.')
        if weight != 0.0 and (math.copysign(1.0, weight) != math.copysign(1.0, Wmax)):
            raise ValueError('Weight and Wmax must have same sign.')

    @staticmethod
    def _get_history(target: Any, t1: float, t2: float):
        # Delegate to the target's spike history for the window (t1, t2].
        if hasattr(target, 'get_history'):
            return target.get_history(float(t1), float(t2))
        raise AttributeError(
            'Target must provide get_history(t1, t2) for vogels_sprekeler_synapse.'
        )

    @staticmethod
    def _extract_history_time(entry: Any) -> float:
        # Accept attribute, mapping, or sequence representations of a spike time.
        if hasattr(entry, 't_'):
            return float(entry.t_)
        if hasattr(entry, 't'):
            return float(entry.t)
        if isinstance(entry, dict):
            if 't_' in entry:
                return float(entry['t_'])
            if 't' in entry:
                return float(entry['t'])
        if isinstance(entry, (tuple, list)) and len(entry) >= 1:
            return float(entry[0])
        raise TypeError(
            'History entry must expose a time as attribute t_/t, mapping key t_/t, or first tuple element.'
        )

    @staticmethod
    def _get_k_value(target: Any, t: float) -> float:
        # Prefer NEST's spelling get_K_value, falling back to get_k_value.
        if hasattr(target, 'get_K_value'):
            return float(target.get_K_value(float(t)))
        if hasattr(target, 'get_k_value'):
            return float(target.get_k_value(float(t)))
        raise AttributeError(
            'Target must provide get_K_value(t) or get_k_value(t) for vogels_sprekeler_synapse.'
        )

    @staticmethod
    def _to_float_scalar(value: ArrayLike, name: str) -> float:
        # NOTE(review): units of a Quantity are stripped, not converted (e.g.
        # a delay given in seconds would be read as ms) -- confirm intended.
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        v = float(arr[0])
        if not np.isfinite(v):
            raise ValueError(f'{name} must be finite.')
        return v

    @staticmethod
    def _to_int_scalar(value: ArrayLike, name: str) -> int:
        # Parse as a finite float scalar, then require an integral value.
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        v = float(arr[0])
        if not np.isfinite(v):
            raise ValueError(f'{name} must be finite.')
        iv = int(round(v))
        if abs(v - iv) > 1e-12:
            raise ValueError(f'{name} must be an integer value.')
        return iv

    @classmethod
    def _validate_positive_delay(cls, value: ArrayLike) -> float:
        d = cls._to_float_scalar(value, name='delay')
        if d <= 0.0:
            raise ValueError('delay must be > 0.')
        return d

    @classmethod
    def _validate_delay_steps(cls, value: ArrayLike) -> int:
        d = cls._to_int_scalar(value, name='delay_steps')
        if d < 1:
            raise ValueError('delay_steps must be >= 1.')
        return d

    @classmethod
    def _validate_positive_tau(cls, value: ArrayLike) -> float:
        tau = cls._to_float_scalar(value, name='tau')
        if tau <= 0.0:
            raise ValueError('tau must be > 0.')
        return tau

    @classmethod
    def _validate_multiplicity(cls, value: ArrayLike) -> float:
        m = cls._to_float_scalar(value, name='multiplicity')
        if m < 0.0:
            raise ValueError('multiplicity must be >= 0.')
        return m