import math
from typing import Any
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike
from ._base import NESTSynapse
# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
# Public API of this module: only the synapse model class is exported.
__all__ = ['jonke_synapse']
class jonke_synapse(NESTSynapse):
    r"""NEST-compatible ``jonke_synapse`` connection model with weight-dependent STDP.

    Implements spike-timing-dependent plasticity with exponential weight dependence and
    additive offsets, following NEST's ``jonke_synapse`` semantics. The model applies
    multiplicative weight factors :math:`\exp(\mu w)` to both the facilitation and the
    depression branch. The signs of :math:`\mu_+` and :math:`\mu_-` decide whether large
    weights are pushed back (soft bound) or pushed further (runaway, limited only by the
    hard bounds).

    **1. Mathematical Formulation**

    The plasticity rule operates on synaptic weight :math:`w(t)`. It uses the presynaptic
    trace :math:`K_+(t)` (time constant :math:`\tau_+`, stored on this synapse) and the
    postsynaptic trace :math:`K_-(t)` (queried from the postsynaptic target):

    .. math::

        \frac{dK_+}{dt} &= -\frac{K_+}{\tau_+} + \sum_f \delta(t - t_f^{\text{pre}}) \\
        \frac{dK_-}{dt} &= -\frac{K_-}{\tau_-} + \sum_j \delta(t - t_j^{\text{post}})

    **Weight-dependent plasticity kernels:**

    .. math::

        \Phi_+(w) &= \exp(\mu_+ w) \\
        \Phi_-(w) &= \exp(\mu_- w)

    **Update rules (applied at spike times):**

    .. math::

        \Delta w_+ &= \lambda \left( \Phi_+(w) K_+ - \beta \right) \quad \text{(facilitation)} \\
        \Delta w_- &= \lambda \left( -\alpha \Phi_-(w) K_- - \beta \right) \quad \text{(depression)}

    Each facilitation step clips the result from above at :math:`W_{\max}`. Each
    depression step clips the result from below at 0. Neither step applies the other
    bound, so with a large :math:`|\beta|` an intermediate value can leave
    :math:`[0, W_{\max}]`.

    **2. Temporal Dynamics**

    At each presynaptic spike at time :math:`t`:

    1. **History lookup:** Read postsynaptic spikes in :math:`(t_{\text{last}} - d,\; t - d]`.
    2. **Facilitation pass:** For each post-spike :math:`t_j` in history:

       .. math::

           K_+^{\text{eff}} = K_+(t_{\text{last}}) \exp\left(\frac{t_{\text{last}} - (t_j + d)}{\tau_+}\right)

           w \leftarrow \min\left(W_{\max},\; w + \lambda \left( \Phi_+(w) K_+^{\text{eff}} - \beta \right)\right)

    3. **Depression:** Using postsynaptic :math:`K_-(t - d)`:

       .. math::

           w \leftarrow \max\left(0,\; w + \lambda \left( -\alpha \Phi_-(w) K_-(t-d) - \beta \right)\right)

    4. **Event emission:** Spike delivered with the updated weight.
    5. **Trace update:**

       .. math::

           K_+ \leftarrow K_+ \exp\left(\frac{t_{\text{last}} - t}{\tau_+}\right) + 1

    6. **State update:** :math:`t_{\text{last}} \leftarrow t`.

    **3. Design Considerations**

    - **Exponential weight dependence (facilitation):** with :math:`\mu_+ < 0`,
      :math:`\Phi_+(w)` shrinks as :math:`w` grows. Potentiation therefore weakens at
      large weights, giving a soft upper bound. With :math:`\mu_+ > 0`, potentiation
      strengthens as the weight grows and can run away until the hard bound
      :math:`W_{\max}` is reached.
    - **Exponential weight dependence (depression):** with :math:`\mu_- > 0`, depression
      strengthens at large weights. With :math:`\mu_- < 0`, it weakens at large weights.
    - **Additive offset** :math:`\beta`: subtracted (times :math:`\lambda`) in both update
      branches. Positive :math:`\beta` biases toward depression; negative :math:`\beta`
      biases toward potentiation. Because the depression step runs on every presynaptic
      spike, :math:`\beta` acts even when there are no postsynaptic spikes.
    - **Asymmetric depression scaling:** :math:`\alpha` scales the depression amplitude
      independently of potentiation.
    - **Numerical stability:** the kernels are evaluated with ``math.exp``. It raises
      ``OverflowError`` for very large arguments and underflows to 0 for very negative
      ones. Choose :math:`\mu` values that keep :math:`|\mu W_{\max}|` moderate
      (e.g. < 10).

    **4. Computational Properties**

    - **Time complexity:** :math:`O(N_{\text{post-spikes}})` per presynaptic spike, where
      :math:`N_{\text{post-spikes}}` is the number of postsynaptic spikes returned for the
      window :math:`(t_{\text{last}} - d,\; t - d]`.
    - **Delay semantics:** history lookups and the :math:`K_-` query use :math:`t - d`.
      The event payload carries ``delay_steps`` for delivery.
    - **Precision:** sub-grid spike timing (the offset component in NEST) is ignored; all
      updates use grid-aligned times only.

    Parameters
    ----------
    weight : float or array-like, default=1.0
        Initial synaptic efficacy (dimensionless or pA/mV). Must be a finite scalar.
        Updated during plasticity: facilitation clips it at ``Wmax``, depression clips it
        at 0.
    delay : float or array-like, default=1.0
        Dendritic delay in milliseconds for history lookups and depression timing
        (:math:`d` in the equations). Must be positive.
    delay_steps : int or array-like, default=1
        Event delivery delay in simulation time steps. Must be ≥ 1. Typically set to match
        ``delay`` quantized to grid resolution.
    Kplus : float or array-like, default=0.0
        Initial presynaptic trace value :math:`K_+(0)`. Must be non-negative. Evolves
        according to :math:`\tau_+` dynamics.
    t_last_spike_ms : float or array-like, default=0.0
        Timestamp of the last presynaptic spike in milliseconds. Used for trace decay
        between spikes and as the lower end of the history window.
    alpha : float or array-like, default=1.0
        Depression amplitude scaling factor. :math:`\alpha = 1` applies the same scale to
        the depression branch as to the facilitation branch.
    beta : float or array-like, default=0.0
        Additive offset subtracted (times ``lambda_``) in both update branches
        (dimensionless). Positive values bias toward depression.
    lambda_ : float or array-like, default=0.01
        Learning rate :math:`\lambda`. Set to 0 to disable weight updates; the traces are
        still updated.
    mu_plus : float or array-like, default=0.0
        Facilitation weight-dependence exponent :math:`\mu_+`. Negative values weaken
        potentiation at large weights (soft upper bound). Positive values strengthen it.
        Units: inverse of weight units.
    mu_minus : float or array-like, default=0.0
        Depression weight-dependence exponent :math:`\mu_-`. Positive values strengthen
        depression at large weights. Units: inverse of weight units.
    tau_plus : float or array-like, default=20.0
        Presynaptic trace time constant :math:`\tau_+` in milliseconds. Controls the
        potentiation temporal window.
    Wmax : float or array-like, default=100.0
        Hard upper weight bound :math:`W_{\max}`, applied after each facilitation step.
        The lower bound 0 is applied after each depression step.
    name : str or None, optional
        Instance identifier for debugging and logging.

    Parameter Mapping
    -----------------

    ================================ ================================ =================
    NEST Parameter                   brainpy.state Parameter          Units / Notes
    ================================ ================================ =================
    ``weight``                       ``weight``                       dimensionless
    ``delay``                        ``delay``                        ms
    ``delay_steps``                  ``delay_steps``                  steps
    ``Kplus``                        ``Kplus``                        dimensionless
    ``t_lastspike``                  ``t_last_spike_ms``              ms
    ``alpha``                        ``alpha``                        dimensionless
    ``beta``                         ``beta``                         dimensionless
    ``lambda``                       ``lambda_``                      dimensionless
    ``mu_plus``                      ``mu_plus``                      1/weight
    ``mu_minus``                     ``mu_minus``                     1/weight
    ``tau_plus``                     ``tau_plus``                     ms
    ``Wmax``                         ``Wmax``                         same as weight
    ================================ ================================ =================

    Raises
    ------
    ValueError
        - If ``Kplus < 0`` (violates trace non-negativity).
        - If ``delay <= 0`` (non-physical delay).
        - If ``delay_steps < 1`` (invalid event scheduling).
        - If any parameter is non-finite (NaN or ±inf).
        - If scalar parameters have size ≠ 1.

    Notes
    -----
    - **Target interface requirements:** The postsynaptic target object passed to
      ``send()`` must implement:

      * ``get_history(t1, t2) -> iterable``: Returns postsynaptic spike times in
        :math:`(t_1, t_2]`. Each entry is an object/dict/tuple with time accessible via
        ``.t_``, ``.t``, ``['t_']``, ``['t']``, or first element.
      * ``get_K_value(t) -> float`` or ``get_k_value(t) -> float``: Returns depression
        trace :math:`K_-(t)` at time ``t``.

    - **NEST compatibility:** Follows the update ordering of
      ``nest-simulator/models/jonke_synapse.h`` (facilitation over the history window,
      then depression, then trace update). The event is returned as a Python dictionary
      rather than a NEST event object.
    - **Sub-grid timing:** Unlike NEST's precise spike timing mode, this implementation
      uses grid-aligned times only (ignoring offset components).
    - **Weight stabilisation:** Negative :math:`\mu_+` (and/or positive :math:`\mu_-`)
      together with a suitable :math:`\beta` can be used to keep weights within a
      bounded range.

    Examples
    --------
    **Basic STDP with linear weight dependence:**

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> # Create synapse with standard STDP parameters
        >>> syn = bp.jonke_synapse(
        ...     weight=5.0,
        ...     delay=1.0,
        ...     lambda_=0.01,
        ...     tau_plus=20.0,
        ...     alpha=1.0,
        ...     beta=0.0,
        ...     mu_plus=0.0,
        ...     mu_minus=0.0,
        ...     Wmax=10.0
        ... )
        >>> syn.get_status()['weight']
        5.0

    **Exponential weight dependence (soft upper bound):**

    .. code-block:: python

        >>> # mu_plus < 0 weakens potentiation at high weights
        >>> syn_bounded = bp.jonke_synapse(
        ...     weight=1.0,
        ...     lambda_=0.01,
        ...     mu_plus=-0.1,
        ...     mu_minus=0.05,
        ...     Wmax=20.0
        ... )
        >>> # At w=10: Phi_+(10) = exp(-0.1*10) ~= 0.368 (potentiation suppressed)
        >>> #          Phi_-(10) = exp(0.05*10) ~= 1.649 (depression enhanced)
        >>> # At w=0:  Phi_+(0) = Phi_-(0) = 1.0 (baseline)

    **Heterosynaptic plasticity via beta offset:**

    .. code-block:: python

        >>> # Positive beta biases toward depression
        >>> syn_hetero = bp.jonke_synapse(
        ...     weight=5.0,
        ...     lambda_=0.005,
        ...     beta=0.05,
        ...     alpha=1.2
        ... )
        >>> # Every presynaptic spike subtracts lambda*beta in the depression step,
        >>> # even when no postsynaptic spikes occur.

    **Simulate spike-pair interaction:**

    .. code-block:: python

        >>> class MockTarget:
        ...     def get_history(self, t1, t2):
        ...         # Return single post-spike at t=15 ms
        ...         if t1 < 15.0 <= t2:
        ...             return [{'t_': 15.0}]
        ...         return []
        ...     def get_k_value(self, t):
        ...         # Kminus trace at depression check time
        ...         return 0.8
        >>>
        >>> target = MockTarget()
        >>> syn = bp.jonke_synapse(weight=5.0, lambda_=0.01, tau_plus=20.0)
        >>>
        >>> # Pre-spike at t=10 ms: no post-spike in (-1, 9], only depression:
        >>> # dw = 0.01 * (-1.0 * 1.0 * 0.8) = -0.008
        >>> event1 = syn.send(t_spike_ms=10.0, target=target)
        >>> print(f"Weight after pre@10: {event1['weight']:.3f}")
        Weight after pre@10: 4.992
        >>>
        >>> # Pre-spike at t=20 ms: post@15 lies in (9, 19] -> facilitation with
        >>> # K+ = 1.0 * exp((10 - 16) / 20) ~= 0.741 (dw ~= +0.0074),
        >>> # then depression with K- = 0.8 (dw = -0.008)
        >>> event2 = syn.send(t_spike_ms=20.0, target=target)
        >>> print(f"Weight after pre@20: {event2['weight']:.3f}")
        Weight after pre@20: 4.991

    See Also
    --------
    stdp_synapse : Classical pair-based STDP without weight dependence.
    stdp_triplet_synapse : Triplet STDP rule.
    vogels_sprekeler_synapse : Inhibitory STDP for E/I balance.

    References
    ----------
    .. [1] Jonke, Z., Legenstein, R., Habenschuss, S., & Maass, W. (2017). Feedback
           inhibition shapes emergent computational properties of cortical microcircuit
           motifs. *Journal of Neuroscience*, 37(35), 8511-8523.
           https://doi.org/10.1523/JNEUROSCI.2078-16.2017
    .. [2] NEST Simulator source code: ``models/jonke_synapse.h`` and
           ``models/jonke_synapse.cpp`` (https://github.com/nest/nest-simulator).
    .. [3] van Rossum, M. C., Bi, G. Q., & Turrigiano, G. G. (2000). Stable Hebbian
           learning from spike timing-dependent plasticity. *Journal of Neuroscience*,
           20(23), 8812-8821. [For multiplicative STDP theory]
    """
    # Python-level module path reported for this class.
    __module__ = 'brainpy.state'
    # Capability flags mirroring NEST's connection-model properties; exposed
    # through ``properties`` and ``get_status()``.
    HAS_DELAY = True
    IS_PRIMARY = True
    SUPPORTS_HPC = True
    SUPPORTS_LBL = True
    SUPPORTS_WFR = False
def __init__(
    self,
    weight: ArrayLike = 1.0,
    delay: ArrayLike = 1.0,
    delay_steps: ArrayLike = 1,
    Kplus: ArrayLike = 0.0,
    t_last_spike_ms: ArrayLike = 0.0,
    alpha: ArrayLike = 1.0,
    beta: ArrayLike = 0.0,
    lambda_: ArrayLike = 0.01,
    mu_plus: ArrayLike = 0.0,
    mu_minus: ArrayLike = 0.0,
    tau_plus: ArrayLike = 20.0,
    Wmax: ArrayLike = 100.0,
    name: str | None = None,
):
    r"""Build one scalar ``jonke_synapse`` and validate its initial state.

    Parameter meanings are documented on the class. Each numeric argument is
    converted with the scalar helpers ``_to_float_scalar``,
    ``_validate_positive_delay`` and ``_validate_delay_steps``. Any error those
    helpers raise propagates unchanged.

    Raises
    ------
    ValueError
        If ``Kplus`` is negative (plus whatever the conversion helpers reject).
    """
    super().__init__(in_size=1, name=name)

    # Transmission state: weight and the two delay representations.
    self.weight = self._to_float_scalar(weight, name='weight')
    self.delay = self._validate_positive_delay(delay)
    self.delay_steps = self._validate_delay_steps(delay_steps)

    # Presynaptic trace must start non-negative.
    self.Kplus = self._to_float_scalar(Kplus, name='Kplus')
    if not self.Kplus >= 0.0 and self.Kplus < 0.0:
        raise ValueError('Kplus must be non-negative.')

    # Remaining scalars: the attribute name doubles as the label passed to
    # the converter, so they can be handled uniformly (order preserved).
    for attr_name, raw_value in (
        ('t_last_spike_ms', t_last_spike_ms),
        ('alpha', alpha),
        ('beta', beta),
        ('lambda_', lambda_),
        ('mu_plus', mu_plus),
        ('mu_minus', mu_minus),
        ('tau_plus', tau_plus),
        ('Wmax', Wmax),
    ):
        setattr(self, attr_name, self._to_float_scalar(raw_value, name=attr_name))
@property
def properties(self) -> dict[str, Any]:
r"""NEST synapse model capability flags.
Returns
-------
dict[str, Any]
Dictionary with boolean capability flags:
- ``has_delay``: Supports delayed spike delivery (always True).
- ``is_primary``: Primary connection type for spike transmission (always True).
- ``supports_hpc``: Compatible with NEST's high-performance computing mode (True).
- ``supports_lbl``: Supports label-based connectivity (True).
- ``supports_wfr``: Supports waveform relaxation method (always False for
plasticity models).
"""
return {
'has_delay': self.HAS_DELAY,
'is_primary': self.IS_PRIMARY,
'supports_hpc': self.SUPPORTS_HPC,
'supports_lbl': self.SUPPORTS_LBL,
'supports_wfr': self.SUPPORTS_WFR,
}
[docs]
def get_status(self) -> dict[str, Any]:
r"""Retrieve complete synapse state snapshot (NEST GetStatus compatible).
Returns current values of all parameters, state variables, and model capabilities.
Output format matches NEST's ``GetStatus`` dictionary structure.
Returns
-------
dict[str, Any]
Dictionary containing:
- ``weight`` (float): Current synaptic efficacy.
- ``delay`` (float): Dendritic delay in ms.
- ``delay_steps`` (int): Event delivery delay in steps.
- ``Kplus`` (float): Current presynaptic trace value.
- ``t_last_spike_ms`` (float): Last presynaptic spike time in ms.
- ``alpha`` (float): Depression scaling factor.
- ``beta`` (float): Additive offset.
- ``lambda`` (float): Learning rate (key name uses NEST convention).
- ``mu_plus`` (float): Facilitation weight exponent.
- ``mu_minus`` (float): Depression weight exponent.
- ``tau_plus`` (float): Presynaptic trace time constant in ms.
- ``Wmax`` (float): Maximum weight bound.
- ``size_of`` (int): Memory footprint in bytes.
- Capability flags (``has_delay``, ``is_primary``, etc.).
Examples
--------
.. code-block:: python
>>> syn = bp.jonke_synapse(weight=3.5, lambda_=0.02, tau_plus=15.0)
>>> status = syn.get_status()
>>> print(status['weight'], status['lambda'], status['tau_plus'])
3.5 0.02 15.0
"""
return {
'weight': float(self.weight),
'delay': float(self.delay),
'delay_steps': int(self.delay_steps),
'Kplus': float(self.Kplus),
't_last_spike_ms': float(self.t_last_spike_ms),
'alpha': float(self.alpha),
'beta': float(self.beta),
'lambda': float(self.lambda_),
'mu_plus': float(self.mu_plus),
'mu_minus': float(self.mu_minus),
'tau_plus': float(self.tau_plus),
'Wmax': float(self.Wmax),
'size_of': int(self.__sizeof__()),
'has_delay': self.HAS_DELAY,
'is_primary': self.IS_PRIMARY,
'supports_hpc': self.SUPPORTS_HPC,
'supports_lbl': self.SUPPORTS_LBL,
'supports_wfr': self.SUPPORTS_WFR,
}
[docs]
def set_status(self, status: dict[str, Any] | None = None, **kwargs):
r"""Update synapse parameters and state (NEST SetStatus compatible).
Modifies any subset of parameters. Unspecified keys retain current values. Validates
all updates before applying (atomic operation). Accepts both dictionary argument and
keyword arguments (merged with kwargs taking precedence).
Parameters
----------
status : dict[str, Any] or None, optional
Dictionary of parameter updates. Keys match those in ``get_status()``. If None,
only ``kwargs`` are processed.
**kwargs
Additional parameter updates as keyword arguments. Merged with ``status`` dict.
If both ``status['key']`` and ``key=value`` are provided for the same parameter,
``kwargs`` takes precedence.
Raises
------
ValueError
- If ``Kplus`` is set to negative value.
- If ``delay <= 0`` or ``delay_steps < 1``.
- If both ``lambda`` and ``lambda_`` are provided with different values.
- If any scalar parameter is non-finite or has size ≠ 1.
Notes
-----
- The learning rate can be specified as either ``lambda`` (NEST convention) or
``lambda_`` (Python identifier). Both refer to the same internal state.
- Validation occurs after all updates are collected, ensuring atomic updates (all
succeed or all fail).
- Setting ``lambda=0`` disables plasticity without affecting trace dynamics.
Examples
--------
.. code-block:: python
>>> syn = bp.jonke_synapse(weight=5.0, lambda_=0.01)
>>> syn.set_status({'weight': 8.0, 'lambda': 0.005})
>>> syn.get_status()['weight']
8.0
>>> syn.get_status()['lambda']
0.005
**Keyword argument syntax:**
.. code-block:: python
>>> syn.set_status(Wmax=50.0, mu_plus=0.1)
>>> syn.get_status()['Wmax']
50.0
**Disable learning:**
.. code-block:: python
>>> syn.set_status(lambda_=0.0)
>>> # Synapse now transmits spikes but does not update weight
"""
updates = {}
if status is not None:
updates.update(status)
updates.update(kwargs)
if 'lambda' in updates and 'lambda_' in updates:
lv = self._to_float_scalar(updates['lambda'], name='lambda')
lvv = self._to_float_scalar(updates['lambda_'], name='lambda_')
if lv != lvv:
raise ValueError('lambda and lambda_ must be identical when both are provided.')
if 'weight' in updates:
self.weight = self._to_float_scalar(updates['weight'], name='weight')
if 'delay' in updates:
self.delay = self._validate_positive_delay(updates['delay'])
if 'delay_steps' in updates:
self.delay_steps = self._validate_delay_steps(updates['delay_steps'])
if 'Kplus' in updates:
self.Kplus = self._to_float_scalar(updates['Kplus'], name='Kplus')
if 't_last_spike_ms' in updates:
self.t_last_spike_ms = self._to_float_scalar(updates['t_last_spike_ms'], name='t_last_spike_ms')
if 'alpha' in updates:
self.alpha = self._to_float_scalar(updates['alpha'], name='alpha')
if 'beta' in updates:
self.beta = self._to_float_scalar(updates['beta'], name='beta')
if 'lambda' in updates:
self.lambda_ = self._to_float_scalar(updates['lambda'], name='lambda')
if 'lambda_' in updates:
self.lambda_ = self._to_float_scalar(updates['lambda_'], name='lambda_')
if 'mu_plus' in updates:
self.mu_plus = self._to_float_scalar(updates['mu_plus'], name='mu_plus')
if 'mu_minus' in updates:
self.mu_minus = self._to_float_scalar(updates['mu_minus'], name='mu_minus')
if 'tau_plus' in updates:
self.tau_plus = self._to_float_scalar(updates['tau_plus'], name='tau_plus')
if 'Wmax' in updates:
self.Wmax = self._to_float_scalar(updates['Wmax'], name='Wmax')
if self.Kplus < 0.0:
raise ValueError('Kplus must be non-negative.')
[docs]
def get(self, key: str = 'status'):
r"""Retrieve parameter or full status dictionary by key (NEST Get compatible).
Parameters
----------
key : str, default='status'
Parameter name or ``'status'`` for full dictionary. Valid keys include:
``'weight'``, ``'delay'``, ``'Kplus'``, ``'lambda'``, ``'tau_plus'``, etc.
Returns
-------
Any
- If ``key='status'``: full dictionary from ``get_status()``.
- Otherwise: scalar value of the requested parameter.
Raises
------
KeyError
If ``key`` is not ``'status'`` and not found in status dictionary.
Examples
--------
.. code-block:: python
>>> syn = bp.jonke_synapse(weight=7.0, tau_plus=25.0)
>>> syn.get('weight')
7.0
>>> syn.get('tau_plus')
25.0
>>> syn.get('status')['lambda']
0.01
"""
if key == 'status':
return self.get_status()
status = self.get_status()
if key in status:
return status[key]
raise KeyError(f'Unsupported key "{key}" for jonke_synapse.get().')
[docs]
def set_weight(self, weight: ArrayLike):
r"""Update synaptic weight (convenience method).
Parameters
----------
weight : float or array-like
New synaptic efficacy value. Must be scalar and finite.
Raises
------
ValueError
If weight is non-scalar or non-finite.
"""
self.weight = self._to_float_scalar(weight, name='weight')
[docs]
def set_delay(self, delay: ArrayLike):
r"""Update dendritic delay (convenience method).
Parameters
----------
delay : float or array-like
New delay in milliseconds. Must be positive and finite.
Raises
------
ValueError
If delay ≤ 0, non-scalar, or non-finite.
"""
self.delay = self._validate_positive_delay(delay)
[docs]
def set_delay_steps(self, delay_steps: ArrayLike):
r"""Update event delivery delay in steps (convenience method).
Parameters
----------
delay_steps : int or array-like
New delay in simulation time steps. Must be ≥ 1.
Raises
------
ValueError
If delay_steps < 1, non-integer, or non-finite.
"""
self.delay_steps = self._validate_delay_steps(delay_steps)
[docs]
def send(
self,
t_spike_ms: ArrayLike,
target: Any,
receptor_type: ArrayLike = 0,
multiplicity: ArrayLike = 1.0,
delay: ArrayLike | None = None,
delay_steps: ArrayLike | None = None,
) -> dict[str, Any]:
r"""Process presynaptic spike with plasticity and return spike event payload.
Implements the full NEST ``jonke_synapse::send()`` protocol:
1. Retrieve postsynaptic spike history in delay-compensated window
2. Apply facilitation for each postsynaptic spike in history
3. Apply depression using current postsynaptic trace
4. Update presynaptic trace and timestamp
5. Return spike event dictionary with updated weight
This is the core method for spike-driven plasticity computation.
Parameters
----------
t_spike_ms : float or array-like
Current presynaptic spike time in milliseconds. Must be scalar and ≥
``t_last_spike_ms`` (non-decreasing spike times assumed).
target : object
Postsynaptic neuron or recorder object. Must implement:
- ``get_history(t1, t2) -> iterable``: Postsynaptic spike times in (t1, t2].
- ``get_K_value(t) -> float`` or ``get_k_value(t) -> float``: Depression trace
:math:`K_-(t)`.
receptor_type : int or array-like, default=0
Postsynaptic receptor channel identifier (e.g., 0=AMPA, 1=NMDA, 2=GABA_A).
Passed through to event payload without modification.
multiplicity : float or array-like, default=1.0
Spike event amplitude multiplier. Must be non-negative. Scales effective weight
in postsynaptic neuron. Typical use: probabilistic synapses or multi-vesicle
release.
delay : float or array-like or None, optional
Override dendritic delay for this spike (in ms). If None, uses ``self.delay``.
Affects history lookup window: :math:`(t_{\text{last}} - d,\; t - d]`.
delay_steps : int or array-like or None, optional
Override event delivery delay (in steps). If None, uses ``self.delay_steps``.
Determines when postsynaptic neuron receives the spike.
Returns
-------
dict[str, Any]
Spike event payload dictionary containing:
- ``weight`` (float): Updated synaptic efficacy after plasticity.
- ``delay`` (float): Dendritic delay used (ms).
- ``delay_steps`` (int): Event delivery delay (steps).
- ``receptor_type`` (int): Postsynaptic receptor channel.
- ``multiplicity`` (float): Spike amplitude multiplier.
- ``t_spike_ms`` (float): Presynaptic spike time (ms).
- ``Kminus`` (float): Postsynaptic trace value at depression check time.
- ``Kplus_pre`` (float): Presynaptic trace before update.
- ``Kplus_post`` (float): Presynaptic trace after update.
Raises
------
ValueError
- If ``t_spike_ms``, ``receptor_type``, or ``multiplicity`` are non-scalar.
- If ``multiplicity < 0``.
- If ``delay <= 0`` or ``delay_steps < 1`` (when overriding defaults).
AttributeError
If ``target`` does not implement required ``get_history()`` or ``get_K_value()``
methods.
TypeError
If history entries do not expose time via supported interface (see Notes).
Notes
-----
- **History entry format:** Each entry from ``target.get_history(t1, t2)`` must be:
* Object with ``.t_`` or ``.t`` attribute, OR
* Dictionary with ``'t_'`` or ``'t'`` key, OR
* Tuple/list where first element is time (float).
- **Delay semantics:** History lookup uses :math:`t - d` to account for
backpropagation delay. Event delivery uses ``delay_steps`` for forward propagation.
- **State mutation:** Updates ``self.weight``, ``self.Kplus``, and
``self.t_last_spike_ms`` in place. Not thread-safe without external synchronization.
- **Causality:** If :math:`t_{\text{spike}} < t_{\text{last}}`, trace decay may
produce negative exponential argument (mathematically valid but may indicate
simulation error).
Examples
--------
**Basic spike transmission:**
.. code-block:: python
>>> class PostNeuron:
... def get_history(self, t1, t2):
... return [] # No post-spikes
... def get_K_value(self, t):
... return 0.5 # Constant depression trace
>>>
>>> syn = bp.jonke_synapse(weight=5.0, lambda_=0.01, beta=0.02)
>>> target = PostNeuron()
>>> event = syn.send(t_spike_ms=10.0, target=target)
>>>
>>> print(f"Weight: {event['weight']:.3f}")
Weight: 4.980
>>> # Depression applied: dw = 0.01 * (-1.0 * 1.0 * 0.5 - 0.02) = -0.007
>>> # Weight bounded: max(0, 5.0 - 0.007) = 4.993 (approx, with exp factors)
**Spike-pair potentiation:**
.. code-block:: python
>>> class PostNeuron:
... def get_history(self, t1, t2):
... # Post-spike at t=12 ms (between t1 and t2)
... if t1 < 12.0 <= t2:
... return [{'t_': 12.0}]
... return []
... def get_K_value(self, t):
... return 0.0 # No depression trace yet
>>>
>>> syn = bp.jonke_synapse(weight=5.0, lambda_=0.01, tau_plus=20.0)
>>> target = PostNeuron()
>>>
>>> # First pre-spike at t=10 ms (before post@12)
>>> event1 = syn.send(t_spike_ms=10.0, target=target)
>>> print(f"Weight after pre@10: {event1['weight']:.3f}, Kplus: {event1['Kplus_post']:.3f}")
Weight after pre@10: 5.000, Kplus: 1.000
>>>
>>> # Second pre-spike at t=15 ms (post@12 now in history)
>>> event2 = syn.send(t_spike_ms=15.0, target=target)
>>> print(f"Weight after pre@15: {event2['weight']:.3f}")
Weight after pre@15: 5.009
>>> # Facilitation from Kplus(10) decayed to t=12: exp((10-13)/20) ≈ 0.861
**Override delay per spike:**
.. code-block:: python
>>> event = syn.send(
... t_spike_ms=20.0,
... target=target,
... delay=2.5,
... delay_steps=3
... )
>>> print(event['delay'], event['delay_steps'])
2.5 3
See Also
--------
to_spike_event : Alias for ``send()`` (NEST naming compatibility).
simulate_pre_spike_train : Process multiple spikes in sequence.
"""
t_spike = self._to_float_scalar(t_spike_ms, name='t_spike_ms')
dendritic_delay = self.delay if delay is None else self._validate_positive_delay(delay)
event_delay_steps = (
self.delay_steps
if delay_steps is None
else self._validate_delay_steps(delay_steps)
)
history_entries = self._get_history(
target,
self.t_last_spike_ms - dendritic_delay,
t_spike - dendritic_delay,
)
for entry in history_entries:
t_hist = self._extract_history_time(entry)
minus_dt = self.t_last_spike_ms - (t_hist + dendritic_delay)
kplus_t = self.Kplus * math.exp(minus_dt / self.tau_plus)
self.weight = self._facilitate(self.weight, kplus_t)
kminus = self._get_k_value(target, t_spike - dendritic_delay)
self.weight = self._depress(self.weight, kminus)
event = {
'weight': float(self.weight),
'delay': float(dendritic_delay),
'delay_steps': int(event_delay_steps),
'receptor_type': self._to_int_scalar(receptor_type, name='receptor_type'),
'multiplicity': self._validate_multiplicity(multiplicity),
't_spike_ms': float(t_spike),
'Kminus': float(kminus),
'Kplus_pre': float(self.Kplus),
}
self.Kplus = self.Kplus * math.exp((self.t_last_spike_ms - t_spike) / self.tau_plus) + 1.0
self.t_last_spike_ms = t_spike
event['Kplus_post'] = float(self.Kplus)
return event
[docs]
def to_spike_event(
self,
t_spike_ms: ArrayLike,
target: Any,
receptor_type: ArrayLike = 0,
multiplicity: ArrayLike = 1.0,
delay: ArrayLike | None = None,
delay_steps: ArrayLike | None = None,
) -> dict[str, Any]:
r"""Alias for ``send()`` method (NEST naming compatibility).
Identical functionality to ``send()``. Provided for API consistency with NEST's
``Connection::to_spike_event()`` naming convention.
Parameters
----------
t_spike_ms : float or array-like
Presynaptic spike time in milliseconds.
target : object
Postsynaptic target with required interface.
receptor_type : int or array-like, default=0
Receptor channel identifier.
multiplicity : float or array-like, default=1.0
Spike amplitude multiplier.
delay : float or array-like or None, optional
Override dendritic delay (ms).
delay_steps : int or array-like or None, optional
Override event delivery delay (steps).
Returns
-------
dict[str, Any]
Spike event payload (see ``send()`` for structure).
See Also
--------
send : Primary spike processing method (full documentation).
"""
return self.send(
t_spike_ms=t_spike_ms,
target=target,
receptor_type=receptor_type,
multiplicity=multiplicity,
delay=delay,
delay_steps=delay_steps,
)
[docs]
def simulate_pre_spike_train(
    self,
    pre_spike_times_ms: ArrayLike,
    target: Any,
    receptor_type: ArrayLike = 0,
    multiplicity: ArrayLike = 1.0,
    delay: ArrayLike | None = None,
    delay_steps: ArrayLike | None = None,
) -> list[dict[str, Any]]:
    r"""Feed a sequence of presynaptic spike times through ``send()``.

    The input is converted to a flat array of the environment's default float dtype
    (``brainstate.environ.dftype()``). ``send()`` is then called once per element,
    in array order. Plasticity state carries over from one spike to the next.

    Parameters
    ----------
    pre_spike_times_ms : array-like
        Presynaptic spike times in milliseconds. Any shape; flattened in C order. Times
        are processed in the given order and are not sorted; they should be
        non-decreasing for physically meaningful trace dynamics.
    target : object
        Postsynaptic target with the interface required by ``send()``.
    receptor_type : int or array-like, default=0
        Receptor channel identifier (same for all spikes).
    multiplicity : float or array-like, default=1.0
        Spike multiplicity (same for all spikes).
    delay : float or array-like or None, optional
        Override dendritic delay for all spikes (ms). If None, uses ``self.delay``.
    delay_steps : int or array-like or None, optional
        Override event delivery delay for all spikes (steps).

    Returns
    -------
    list[dict[str, Any]]
        One ``send()`` payload per spike, in processing order. The length equals the
        total number of elements in ``pre_spike_times_ms``.

    Notes
    -----
    - **State evolution:** ``weight``, ``Kplus`` and ``t_last_spike_ms`` evolve across
      the sequence; the final state reflects all spikes.
    - **Unsorted times:** a time earlier than the previous one makes the K+ decay
      exponent positive, so the trace grows instead of decaying. Verify input ordering.

    Examples
    --------
    **STDP pairing protocol:**

    .. code-block:: python

        >>> class PostNeuron:
        ...     def __init__(self):
        ...         self.post_spikes = [15.0, 35.0]  # Post-spikes at t=15, 35
        ...     def get_history(self, t1, t2):
        ...         return [{'t_': t} for t in self.post_spikes if t1 < t <= t2]
        ...     def get_K_value(self, t):
        ...         return 1.0 if t >= 15.0 else 0.0
        >>>
        >>> syn = bp.jonke_synapse(weight=5.0, lambda_=0.01, tau_plus=20.0)
        >>> target = PostNeuron()
        >>> events = syn.simulate_pre_spike_train(
        ...     pre_spike_times_ms=[10.0, 20.0, 30.0, 40.0],
        ...     target=target
        ... )
        >>> for i, evt in enumerate(events):
        ...     print(f"Spike {i}: t={evt['t_spike_ms']:.1f} ms, weight={evt['weight']:.4f}")
        Spike 0: t=10.0 ms, weight=5.0000
        Spike 1: t=20.0 ms, weight=4.9974
        Spike 2: t=30.0 ms, weight=4.9874
        Spike 3: t=40.0 ms, weight=4.9920

    **Extract weight trajectory:**

    .. code-block:: python

        >>> weights = [evt['weight'] for evt in events]
        >>> pre_traces = [evt['Kplus_post'] for evt in events]
        >>> print([round(w, 4) for w in weights])
        [5.0, 4.9974, 4.9874, 4.992]

    **Regular 10 Hz pre-spike train:**

    .. code-block:: python

        >>> import numpy as np
        >>> times = np.arange(0, 1000, 100)  # t=0, 100, ..., 900 ms
        >>> events = syn.simulate_pre_spike_train(times, target)
        >>> print(f"Initial weight: {events[0]['weight']:.3f}")
        >>> print(f"Final weight: {events[-1]['weight']:.3f}")

    See Also
    --------
    send : Single spike processing (core method).
    """
    float_dtype = brainstate.environ.dftype()
    flat_times = np.asarray(u.math.asarray(pre_spike_times_ms), dtype=float_dtype).reshape(-1)
    return [
        self.send(
            t_spike_ms=float(spike_time),
            target=target,
            receptor_type=receptor_type,
            multiplicity=multiplicity,
            delay=delay,
            delay_steps=delay_steps,
        )
        for spike_time in flat_times
    ]
def _facilitate(self, w: float, kplus: float) -> float:
r"""Compute facilitation weight update with exponential weight dependence.
Applies potentiation rule: :math:`\Delta w = \lambda (\exp(\mu_+ w) K_+ - \beta)`.
Weight is clipped to :math:`[0, W_{\max}]` after update.
Parameters
----------
w : float
Current synaptic weight.
kplus : float
Effective presynaptic trace value (time-decayed :math:`K_+`).
Returns
-------
float
Updated weight after facilitation, hard-bounded to :math:`W_{\max}`.
Notes
-----
- Returns unchanged weight if :math:`\lambda = 0` (learning disabled).
- Does not enforce lower bound here (depression handles that).
- Exponential overflow for large :math:`\mu_+ w` will raise Python exception.
"""
if self.lambda_ == 0.0:
return w
k_w = math.exp(self.mu_plus * w)
dw = self.lambda_ * (k_w * kplus - self.beta)
new_w = w + dw
return new_w if new_w < self.Wmax else self.Wmax
def _depress(self, w: float, kminus: float) -> float:
r"""Compute depression weight update with exponential weight dependence.
Applies depression rule: :math:`\Delta w = \lambda (-\alpha \exp(\mu_- w) K_- - \beta)`.
Weight is clipped to :math:`[0, W_{\max}]` after update (lower bound at 0).
Parameters
----------
w : float
Current synaptic weight.
kminus : float
Postsynaptic depression trace value :math:`K_-(t - d)`.
Returns
-------
float
Updated weight after depression, hard-bounded to :math:`\geq 0`.
Notes
-----
- Returns unchanged weight if :math:`\lambda = 0` (learning disabled).
- Enforces non-negative weights (biological constraint for excitatory synapses).
- Exponential overflow for large :math:`\mu_- w` will raise Python exception.
"""
if self.lambda_ == 0.0:
return w
k_w = math.exp(self.mu_minus * w)
dw = self.lambda_ * (-self.alpha * k_w * kminus - self.beta)
new_w = w + dw
return new_w if new_w > 0.0 else 0.0
@staticmethod
def _get_history(target: Any, t1: float, t2: float):
r"""Retrieve postsynaptic spike history from target neuron.
Parameters
----------
target : object
Postsynaptic target with ``get_history()`` method.
t1 : float
Window start time (exclusive) in milliseconds.
t2 : float
Window end time (inclusive) in milliseconds.
Returns
-------
iterable
Postsynaptic spike entries in interval :math:`(t_1, t_2]`. Each entry format
must be compatible with ``_extract_history_time()``.
Raises
------
AttributeError
If target does not implement ``get_history(t1, t2)`` method.
"""
if hasattr(target, 'get_history'):
return target.get_history(float(t1), float(t2))
raise AttributeError(
'Target must provide get_history(t1, t2) for jonke_synapse.'
)
@staticmethod
def _extract_history_time(entry: Any) -> float:
r"""Extract spike time from history entry in flexible format.
Supports multiple entry formats for compatibility with various neuron implementations:
object attributes, dictionary keys, or tuple/list indexing.
Parameters
----------
entry : object or dict or tuple or list
History entry from ``get_history()``. Must encode spike time via:
- Attribute ``.t_`` or ``.t`` (object interface), OR
- Key ``'t_'`` or ``'t'`` (dict interface), OR
- First element ``[0]`` (sequence interface).
Returns
-------
float
Spike time in milliseconds.
Raises
------
TypeError
If entry does not conform to any supported format.
"""
if hasattr(entry, 't_'):
return float(entry.t_)
if hasattr(entry, 't'):
return float(entry.t)
if isinstance(entry, dict):
if 't_' in entry:
return float(entry['t_'])
if 't' in entry:
return float(entry['t'])
if isinstance(entry, (tuple, list)) and len(entry) >= 1:
return float(entry[0])
raise TypeError(
'History entry must expose a time as attribute t_/t, mapping key t_/t, or first tuple element.'
)
@staticmethod
def _get_k_value(target: Any, t: float) -> float:
r"""Retrieve postsynaptic depression trace value at specified time.
Parameters
----------
target : object
Postsynaptic neuron with depression trace interface.
t : float
Query time in milliseconds (typically :math:`t_{\text{spike}} - d`).
Returns
-------
float
Postsynaptic trace :math:`K_-(t)` (dimensionless, typically :math:`\geq 0`).
Raises
------
AttributeError
If target does not implement ``get_K_value(t)`` or ``get_k_value(t)`` method.
Notes
-----
- Method name is case-insensitive: accepts ``get_K_value`` or ``get_k_value``.
- No sign enforcement: negative trace values are permitted (though non-physical).
"""
if hasattr(target, 'get_K_value'):
return float(target.get_K_value(float(t)))
if hasattr(target, 'get_k_value'):
return float(target.get_k_value(float(t)))
raise AttributeError(
'Target must provide get_K_value(t) or get_k_value(t) for jonke_synapse.'
)
@staticmethod
def _to_float_scalar(value: ArrayLike, name: str) -> float:
    r"""Coerce a scalar-like input into a finite Python ``float``.

    Processing steps:

    1. A ``saiunit`` Quantity is reduced to its mantissa. The unit is dropped,
       not converted, so callers must already pass the intended unit.
    2. The result goes through ``u.math.asarray`` and is cast to the dtype
       returned by ``brainstate.environ.dftype()``.
    3. The array is flattened and must contain exactly one finite element.

    Parameters
    ----------
    value : array-like
        Python scalar, size-1 array, or Quantity.
    name : str
        Parameter name interpolated into error messages.

    Returns
    -------
    float
        The single element as a Python float.

    Raises
    ------
    ValueError
        If the flattened input does not hold exactly one element, or if that
        element is NaN or infinite.

    Notes
    -----
    NOTE(review): if ``brainstate.environ.dftype()`` is single precision, the
    returned float carries float32 rounding. Confirm this is acceptable for
    delays and spike times.
    """
    raw = u.get_mantissa(value) if isinstance(value, u.Quantity) else value
    target_dtype = brainstate.environ.dftype()
    flat = np.asarray(u.math.asarray(raw), dtype=target_dtype).reshape(-1)
    if flat.size != 1:
        raise ValueError(f'{name} must be scalar.')
    scalar = float(flat[0])
    if np.isfinite(scalar):
        return scalar
    raise ValueError(f'{name} must be finite.')
@staticmethod
def _to_int_scalar(value: ArrayLike, name: str) -> int:
    r"""Coerce a scalar-like input into a Python ``int``.

    The input first goes through ``jonke_synapse._to_float_scalar``, which
    checks size and finiteness with the same error messages. It is then
    rounded to the nearest integer. A value whose distance from that integer
    exceeds ``1e-12`` is rejected.

    Parameters
    ----------
    value : array-like
        Python scalar, size-1 array, or Quantity.
    name : str
        Parameter name interpolated into error messages.

    Returns
    -------
    int
        The rounded integer value.

    Raises
    ------
    ValueError
        - If the flattened input does not hold exactly one element.
        - If the value is NaN or infinite.
        - If the value is not integer-valued (e.g. ``2.5``; ``2.0`` is accepted).

    Notes
    -----
    NOTE(review): the ``1e-12`` tolerance is checked after the cast to the
    brainstate default float dtype. If that dtype is float32, near-integers
    may already be rounded before the check. Confirm this is intended.
    """
    # Shared scalar/finite validation; messages match the original inline checks.
    as_float = jonke_synapse._to_float_scalar(value, name)
    nearest = int(round(as_float))
    if abs(as_float - nearest) > 1e-12:
        raise ValueError(f'{name} must be an integer value.')
    return nearest
@classmethod
def _validate_positive_delay(cls, value: ArrayLike) -> float:
r"""Validate and convert delay to positive float scalar.
Parameters
----------
value : array-like
Delay in milliseconds.
Returns
-------
float
Validated positive delay.
Raises
------
ValueError
If delay ≤ 0, non-scalar, or non-finite.
"""
d = cls._to_float_scalar(value, name='delay')
if d <= 0.0:
raise ValueError('delay must be > 0.')
return d
@classmethod
def _validate_delay_steps(cls, value: ArrayLike) -> int:
r"""Validate and convert delay_steps to integer ≥ 1.
Parameters
----------
value : array-like
Delay in simulation steps.
Returns
-------
int
Validated delay_steps ≥ 1.
Raises
------
ValueError
If delay_steps < 1, non-integer, non-scalar, or non-finite.
"""
d = cls._to_int_scalar(value, name='delay_steps')
if d < 1:
raise ValueError('delay_steps must be >= 1.')
return d
@classmethod
def _validate_multiplicity(cls, value: ArrayLike) -> float:
r"""Validate and convert multiplicity to non-negative float scalar.
Parameters
----------
value : array-like
Spike amplitude multiplier.
Returns
-------
float
Validated multiplicity ≥ 0.
Raises
------
ValueError
If multiplicity < 0, non-scalar, or non-finite.
"""
m = cls._to_float_scalar(value, name='multiplicity')
if m < 0.0:
raise ValueError('multiplicity must be >= 0.')
return m