# -*- coding: utf-8 -*-
# Source code for brainpy_state._nest.urbanczik_synapse
#
# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math
from typing import Any

import brainstate
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike

from ._base import NESTSynapse

__all__ = [
    'urbanczik_synapse',
]


class urbanczik_synapse(NESTSynapse):
    r"""NEST-compatible ``urbanczik_synapse`` connection model.

    Plastic synapse implementing the Urbanczik-Senn dendritic prediction error
    learning rule for supervised learning in multi-compartment neurons. This
    synapse requires target neurons that archive dendritic prediction errors
    (e.g., ``pp_cond_exp_mc_urbanczik``).

    **1. Mathematical Model**

    This implementation reproduces the connection-level semantics of NEST
    ``models/urbanczik_synapse.{h,cpp}``. The learning rule combines presynaptic spike traces
    with postsynaptic dendritic prediction errors to update synaptic weights through a
    low-pass filtered plasticity signal.

    **1.1. Presynaptic Traces**

    Two exponential traces track presynaptic spiking activity with different time constants:

    .. math::
       \tau_L^\mathrm{tr}(t) = \tau_L^\mathrm{tr}(t_{last}) \exp\left(\frac{t_{last}-t}{\tau_L}\right) + 1

    .. math::
       \tau_s^\mathrm{tr}(t) = \tau_s^\mathrm{tr}(t_{last}) \exp\left(\frac{t_{last}-t}{\tau_s}\right) + 1

    where :math:`t` is the current spike time, :math:`t_{last}` the previous spike
    time, :math:`\tau_L = C_m / g_L` the membrane time constant, and :math:`\tau_s`
    the synaptic time constant (``tau_syn_ex`` for excitatory weights,
    ``tau_syn_in`` for inhibitory ones).
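
    As a minimal standalone sketch (a hypothetical helper, not part of this
    class; the numbers are illustrative only), the trace recursion reads:

    .. code-block:: python

       import math

       def update_trace(trace, t_last, t, tau):
           # Decay the trace over the interval (t_last, t], then add 1 for the new spike.
           return trace * math.exp((t_last - t) / tau) + 1.0

       tr1 = update_trace(0.0, -1.0, 10.0, 20.0)   # first spike: 0 decayed + 1 = 1.0
       tr2 = update_trace(tr1, 10.0, 20.0, 20.0)   # 10 ms later: 1*exp(-0.5) + 1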

    **1.2. Plasticity Signal**

    For each postsynaptic dendritic prediction error entry :math:`(t_i, \Delta w_i)` in the
    history window :math:`(t_{last} - d, t - d]` (where :math:`d` is the dendritic delay):

    .. math::
       \Pi_i = \left[\tau_L^\mathrm{tr}\exp\left(\frac{t_{last}-(t_i+d)}{\tau_L}\right)
               - \tau_s^\mathrm{tr}\exp\left(\frac{t_{last}-(t_i+d)}{\tau_s}\right)\right] \Delta w_i

    Two integrals accumulate plasticity contributions:

    .. math::
       \Pi_\mathrm{int} \leftarrow \Pi_\mathrm{int} + \sum_i \Pi_i

    .. math::
       \Pi_\mathrm{exp} \leftarrow \exp\left(\frac{t_{last}-t}{\tau_\Delta}\right)\Pi_\mathrm{exp}
                                    + \sum_i \exp\left(\frac{(t_i+d)-t}{\tau_\Delta}\right)\Pi_i

    where :math:`\tau_\Delta` is the low-pass filter time constant for weight changes.
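
    A worked numeric sketch of one history entry's contribution (all values are
    hypothetical, chosen only to exercise the formulas above):

    .. code-block:: python

       import math

       tau_L, tau_s, tau_Delta, d = 20.0, 2.0, 100.0, 1.0   # ms
       tr_L, tr_s = 1.5, 1.2                                # current trace states
       t_last, t = 10.0, 20.0                               # previous/current spike (ms)
       t_i, dw = 12.0, 0.1                                  # one entry in (t_last-d, t-d]

       up = t_last - (t_i + d)                              # -3.0 ms
       PI = (tr_L * math.exp(up / tau_L) - tr_s * math.exp(up / tau_s)) * dw

       # Low-pass accumulator, starting from PI_exp = 0:
       PI_exp = math.exp((t_last - t) / tau_Delta) * 0.0 \
                + math.exp(((t_i + d) - t) / tau_Delta) * PI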

    **1.3. Weight Update**

    The synaptic weight is updated using the filtered difference of integrals:

    .. math::
       w \leftarrow \mathrm{clip}\left(w_0 + \frac{15\,C_m\,\tau_s\,\eta}{g_L(\tau_L-\tau_s)}
                                        (\Pi_\mathrm{int} - \Pi_\mathrm{exp}), W_{min}, W_{max}\right)

    where :math:`w_0` is ``init_weight``, :math:`\eta` is the learning rate, and
    :math:`C_m`, :math:`g_L` are the membrane capacitance and leak conductance of
    the dendritic compartment.
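
    For concreteness, the clipped update can be sketched with hypothetical
    compartment and integral values (none of these come from a real neuron model):

    .. code-block:: python

       C_m, g_L = 200.0, 10.0          # pF, nS (hypothetical dendritic compartment)
       tau_L, tau_s = 20.0, 2.0        # ms
       eta = 0.07
       w0, Wmin, Wmax = 1.0, 0.0, 100.0
       PI_int, PI_exp = 0.1, 0.09      # illustrative integral values

       scale = 15.0 * C_m * tau_s * eta / (g_L * (tau_L - tau_s))
       w = min(max(w0 + scale * (PI_int - PI_exp), Wmin), Wmax)
       # scale = 420/180, so w = 1.0 + (420/180) * 0.01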

    **2. NEST Implementation Fidelity**

    This class preserves NEST's exact send-ordering in ``urbanczik_synapse::send(...)``:

    1. Read archived history in :math:`(t_{last} - d, t - d]`
    2. Update :math:`\Pi_\mathrm{int}` and :math:`\Pi_\mathrm{exp}` integrals
    3. Compute new weight with clipping
    4. Emit spike event with updated weight
    5. Update :math:`\tau_L^\mathrm{tr}` and :math:`\tau_s^\mathrm{tr}` traces
    6. Set :math:`t_{last} = t`

    **3. Computational Considerations**

    - **Synaptic time constant selection**: Uses ``tau_syn_ex`` when ``weight > 0``, otherwise
      ``tau_syn_in``, matching NEST's current-weight-dependent branching
    - **Numerical precision**: Sub-grid timestamp offsets are ignored as in NEST
    - **Weight bounds**: Hard clipping to [Wmin, Wmax] after each update
    - **Sign constraints**: Weight, Wmin, and Wmax must all share the same sign (enforced at init
      and status updates)

    **4. Target Neuron Requirements**

    The target neuron must implement the Urbanczik archiving interface:

    - ``get_urbanczik_history(t1, t2, comp)``: Returns prediction error entries in :math:`(t1, t2]`
      for compartment ``comp`` (default 1 for dendritic)
    - ``get_g_L(comp)``, ``get_tau_L(comp)`` or ``get_C_m(comp)/get_g_L(comp)``
    - ``get_C_m(comp)``, ``get_tau_syn_ex(comp)``, ``get_tau_syn_in(comp)``

    History entries support multiple formats: objects with ``t_``/``dw_`` attributes (NEST-style),
    objects with ``t``/``dw`` attributes, dicts with those keys, or 2-tuples ``(t, dw)``.
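
    A sketch of how such polymorphic entries could be normalized (illustrative
    only; the class uses its own ``_extract_history_entry`` helper internally):

    .. code-block:: python

       def extract_entry(entry):
           # Accept NEST-style objects, plain objects, dicts, or (t, dw) tuples.
           if hasattr(entry, 't_'):
               return float(entry.t_), float(entry.dw_)
           if hasattr(entry, 't'):
               return float(entry.t), float(entry.dw)
           if isinstance(entry, dict):
               return float(entry['t']), float(entry['dw'])
           t, dw = entry
           return float(t), float(dw)

       extract_entry({'t': 8.0, 'dw': 0.1})   # (8.0, 0.1)
       extract_entry((8.0, 0.1))              # (8.0, 0.1)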

    Parameters
    ----------
    weight : float or ArrayLike, optional
        Initial synaptic weight (dimensionless). Must share sign with Wmin and Wmax.
        Default: ``1.0``
    delay : float or ArrayLike, optional
        Dendritic delay in milliseconds used for history lookup. Must be ``> 0``.
        Default: ``1.0`` ms
    delay_steps : int or ArrayLike, optional
        Event delivery delay in simulation time steps. Must be ``>= 1``.
        Default: ``1``
    tau_Delta : float or ArrayLike, optional
        Time constant in milliseconds for low-pass filtering of weight changes. Controls
        the temporal smoothing of plasticity signals. Larger values produce slower, more
        stable learning. Default: ``100.0`` ms
    eta : float or ArrayLike, optional
        Learning rate (dimensionless). Scales the magnitude of weight updates. Typical range
        0.01–0.1 for cortical models. Default: ``0.07``
    Wmin : float or ArrayLike, optional
        Lower bound of synaptic weight (hard clipping). Must share sign with weight and Wmax.
        Default: ``0.0``
    Wmax : float or ArrayLike, optional
        Upper bound of synaptic weight (hard clipping). Must share sign with weight and Wmin.
        Default: ``100.0``
    PI_integral : float or ArrayLike, optional
        Initial value of unfiltered accumulated plasticity integral :math:`\Pi_\mathrm{int}`.
        Default: ``0.0``
    PI_exp_integral : float or ArrayLike, optional
        Initial value of exponentially filtered plasticity integral :math:`\Pi_\mathrm{exp}`.
        Default: ``0.0``
    tau_L_trace : float or ArrayLike, optional
        Initial state of :math:`\tau_L` presynaptic trace. Default: ``0.0``
    tau_s_trace : float or ArrayLike, optional
        Initial state of :math:`\tau_s` presynaptic trace. Default: ``0.0``
    t_last_spike_ms : float or ArrayLike, optional
        Last presynaptic spike time in milliseconds. Default: ``-1.0`` (no previous spike)
    name : str, optional
        Instance name for debugging and logging. Default: ``None``

    Parameter Mapping
    -----------------
    Mapping from NEST parameters to attributes of this implementation:

    =============================  =========================================================
    NEST Parameter                 brainpy.state Attribute
    =============================  =========================================================
    ``weight``                     ``weight`` (current synaptic weight)
    ``delay``                      ``delay`` (dendritic delay for history)
    ``tau_Delta``                  ``tau_Delta`` (low-pass time constant)
    ``eta``                        ``eta`` (learning rate)
    ``Wmin``                       ``Wmin`` (lower weight bound)
    ``Wmax``                       ``Wmax`` (upper weight bound)
    ``init_weight``                ``init_weight`` (baseline weight for updates)
    ``receptor_type``              passed per spike event
    ``t_lastspike``                ``t_last_spike_ms``
    =============================  =========================================================

    Raises
    ------
    ValueError
        If ``delay`` is not positive, ``delay_steps < 1``, or weight/Wmin/Wmax have
        inconsistent signs
    AttributeError
        If target neuron does not implement required Urbanczik archiving interface methods

    Notes
    -----
    - **set_status() behavior**: Following NEST, ``set_status()`` always resets ``init_weight``
      to the current ``weight`` unless explicitly provided in the status dict
    - **Multiplicity**: Spike multiplicity is validated but not used in plasticity computation
      (NEST compatibility)
    - **Sub-grid timing**: Precise spike time offsets within a time step are ignored in this
      plasticity rule (consistent with NEST implementation)

    Examples
    --------
    Basic synapse creation and spike processing:

    .. code-block:: python

       >>> import brainpy.state as bp
       >>> # Create synapse with moderate learning rate
       >>> syn = bp.urbanczik_synapse(
       ...     weight=0.5,
       ...     delay=1.0,
       ...     tau_Delta=80.0,
       ...     eta=0.05,
       ...     Wmin=0.0,
       ...     Wmax=10.0
       ... )
       >>>
       >>> # Check initial status
       >>> status = syn.get_status()
       >>> print(f"Initial weight: {status['weight']}")
       Initial weight: 0.5
       >>> print(f"Learning rate: {status['eta']}")
       Learning rate: 0.05

    Processing spike trains with mock target neuron:

    .. code-block:: python

       >>> class MockUrbanczikNeuron:
       ...     def __init__(self):
       ...         self.history = []
       ...
       ...     def get_urbanczik_history(self, t1, t2, comp):
       ...         # Return prediction errors in (t1, t2]
       ...         return [(t, dw) for t, dw in self.history if t1 < t <= t2]
       ...
       ...     def get_g_L(self, comp): return 10.0  # nS
       ...     def get_tau_L(self, comp): return 20.0  # ms
       ...     def get_C_m(self, comp): return 200.0  # pF
       ...     def get_tau_syn_ex(self, comp): return 2.0  # ms
       ...     def get_tau_syn_in(self, comp): return 5.0  # ms
       >>>
       >>> target = MockUrbanczikNeuron()
       >>>
       >>> # Simulate presynaptic spike at t=10 ms
       >>> event = syn.send(t_spike_ms=10.0, target=target)
       >>> print(f"Weight after spike: {event['weight']:.3f}")
       Weight after spike: 0.500
       >>>
       >>> # Add dendritic prediction error and process another spike
       >>> target.history.append((8.0, 0.1))  # (time_ms, delta_w)
       >>> event = syn.send(t_spike_ms=20.0, target=target)
       >>> print(f"Weight after learning: {event['weight']:.3f}")
       Weight after learning: 0.503

    Weight bound enforcement:

    .. code-block:: python

       >>> syn = bp.urbanczik_synapse(weight=5.0, Wmin=0.0, Wmax=10.0)
       >>> syn.set_status({'weight': 12.0})  # Exceeds Wmax
       >>> print(syn.get('weight'))
       10.0
       >>> syn.set_status({'weight': -1.0})  # Violates sign constraint
       Traceback (most recent call last):
       ValueError: Weight and Wmax must have same sign.

    References
    ----------
    .. [1] Urbanczik R, Senn W (2014). Learning by the dendritic prediction of somatic spiking.
           *Neuron* 81(3):521-528. DOI: 10.1016/j.neuron.2013.11.030
    .. [2] Jordan J, Sacramento J, Wybo WAM, et al. (2021). Conductance-based dendrites perform
           reliability-weighted opinion pooling. *arXiv* 2109.02040.
    .. [3] NEST source: ``models/urbanczik_synapse.h`` and ``models/urbanczik_synapse.cpp``
           (NEST Simulator version 3.9+)

    See Also
    --------
    pp_cond_exp_mc_urbanczik : Multi-compartment neuron supporting Urbanczik archiving
    stdp_synapse : Classical spike-timing dependent plasticity synapse
    """

    __module__ = 'brainpy.state'

    HAS_DELAY = True
    IS_PRIMARY = True
    REQUIRES_URBANCZIK_ARCHIVING = True
    SUPPORTS_HPC = True
    SUPPORTS_LBL = True
    SUPPORTS_WFR = True

    DENDRITIC_COMPARTMENT = 1

    def __init__(
        self,
        weight: ArrayLike = 1.0,
        delay: ArrayLike = 1.0,
        delay_steps: ArrayLike = 1,
        tau_Delta: ArrayLike = 100.0,
        eta: ArrayLike = 0.07,
        Wmin: ArrayLike = 0.0,
        Wmax: ArrayLike = 100.0,
        PI_integral: ArrayLike = 0.0,
        PI_exp_integral: ArrayLike = 0.0,
        tau_L_trace: ArrayLike = 0.0,
        tau_s_trace: ArrayLike = 0.0,
        t_last_spike_ms: ArrayLike = -1.0,
        name: str | None = None,
    ):
        super().__init__(in_size=1, name=name)

        self.weight = self._to_float_scalar(weight, name='weight')
        self.delay = self._validate_positive_delay(delay)
        self.delay_steps = self._validate_delay_steps(delay_steps)

        self.tau_Delta = self._to_float_scalar(tau_Delta, name='tau_Delta')
        self.eta = self._to_float_scalar(eta, name='eta')
        self.Wmin = self._to_float_scalar(Wmin, name='Wmin')
        self.Wmax = self._to_float_scalar(Wmax, name='Wmax')

        self.PI_integral = self._to_float_scalar(PI_integral, name='PI_integral')
        self.PI_exp_integral = self._to_float_scalar(PI_exp_integral, name='PI_exp_integral')
        self.tau_L_trace = self._to_float_scalar(tau_L_trace, name='tau_L_trace')
        self.tau_s_trace = self._to_float_scalar(tau_s_trace, name='tau_s_trace')
        self.t_last_spike_ms = self._to_float_scalar(t_last_spike_ms, name='t_last_spike_ms')

        # NEST initializes init_weight_ from weight_.
        self.init_weight = float(self.weight)

        self._check_weight_sign_constraints()

    @property
    def properties(self) -> dict[str, Any]:
        r"""Return synapse capability flags.

        Returns
        -------
        dict[str, Any]
            Dictionary with boolean flags:
            - ``has_delay``: Connection supports transmission delays (always True)
            - ``is_primary``: This is a primary connection type (always True)
            - ``requires_urbanczik_archiving``: Target must implement Urbanczik history (always True)
            - ``supports_hpc``: Compatible with high-performance computing features (always True)
            - ``supports_lbl``: Supports local branching levels in dendritic trees (always True)
            - ``supports_wfr``: Supports waveform relaxation methods (always True)

        Notes
        -----
        These flags match NEST synapse property conventions for integration with NEST-compatible
        simulation infrastructure.
        """
        return {
            'has_delay': self.HAS_DELAY,
            'is_primary': self.IS_PRIMARY,
            'requires_urbanczik_archiving': self.REQUIRES_URBANCZIK_ARCHIVING,
            'supports_hpc': self.SUPPORTS_HPC,
            'supports_lbl': self.SUPPORTS_LBL,
            'supports_wfr': self.SUPPORTS_WFR,
        }

    def get_status(self) -> dict[str, Any]:
        r"""Return current synapse state and parameters.

        Returns
        -------
        dict[str, Any]
            Complete synapse state dictionary. Keys include: ``weight`` (float),
            ``delay`` (float), ``delay_steps`` (int), ``tau_Delta`` (float),
            ``eta`` (float), ``Wmin`` (float), ``Wmax`` (float), ``init_weight``
            (float), ``PI_integral`` (float), ``PI_exp_integral`` (float),
            ``tau_L_trace`` (float), ``tau_s_trace`` (float), ``t_last_spike_ms``
            (float), ``size_of`` (int), and capability flags (``has_delay``,
            ``is_primary``, etc.).

        Notes
        -----
        Compatible with NEST ``GetStatus()`` semantics. All floating-point values
        are guaranteed finite (no NaN or infinity).

        Examples
        --------
        .. code-block:: python

           >>> syn = bp.urbanczik_synapse(weight=2.5, eta=0.08)
           >>> status = syn.get_status()
           >>> print(f"Weight: {status['weight']}, Learning rate: {status['eta']}")
           Weight: 2.5, Learning rate: 0.08
        """
        return {
            'weight': float(self.weight),
            'delay': float(self.delay),
            'delay_steps': int(self.delay_steps),
            'tau_Delta': float(self.tau_Delta),
            'eta': float(self.eta),
            'Wmin': float(self.Wmin),
            'Wmax': float(self.Wmax),
            'init_weight': float(self.init_weight),
            'PI_integral': float(self.PI_integral),
            'PI_exp_integral': float(self.PI_exp_integral),
            'tau_L_trace': float(self.tau_L_trace),
            'tau_s_trace': float(self.tau_s_trace),
            't_last_spike_ms': float(self.t_last_spike_ms),
            'size_of': int(self.__sizeof__()),
            'has_delay': self.HAS_DELAY,
            'is_primary': self.IS_PRIMARY,
            'requires_urbanczik_archiving': self.REQUIRES_URBANCZIK_ARCHIVING,
            'supports_hpc': self.SUPPORTS_HPC,
            'supports_lbl': self.SUPPORTS_LBL,
            'supports_wfr': self.SUPPORTS_WFR,
        }
    def set_status(self, status: dict[str, Any] | None = None, **kwargs):
        r"""Update synapse parameters and state.

        Parameters
        ----------
        status : dict[str, Any], optional
            Dictionary of parameter updates. Keys match those returned by
            ``get_status()``.
        **kwargs
            Additional parameter updates as keyword arguments. These override any
            values in the ``status`` dict if both are provided.

        Raises
        ------
        ValueError
            If updated parameters violate sign constraints (weight, Wmin, Wmax
            must share sign), if delay is not positive, if delay_steps < 1, or if
            any value is non-finite.

        Notes
        -----
        - **NEST compatibility**: Following NEST ``SetStatus()`` semantics,
          ``init_weight`` is automatically reset to the current ``weight`` after
          updates unless explicitly provided in the update dict
        - All updatable parameters: ``weight``, ``delay``, ``delay_steps``,
          ``tau_Delta``, ``eta``, ``Wmin``, ``Wmax``, ``PI_integral``,
          ``PI_exp_integral``, ``tau_L_trace``, ``tau_s_trace``,
          ``t_last_spike_ms``, ``init_weight``
        - Sign constraint validation occurs after all updates are applied

        Examples
        --------
        Update a single parameter:

        .. code-block:: python

           >>> syn = bp.urbanczik_synapse(weight=1.0)
           >>> syn.set_status(eta=0.1)
           >>> print(syn.get('eta'))
           0.1

        Batch update with a dict:

        .. code-block:: python

           >>> syn.set_status({'weight': 2.0, 'tau_Delta': 120.0})
           >>> status = syn.get_status()
           >>> print(f"Weight: {status['weight']}, tau_Delta: {status['tau_Delta']}")
           Weight: 2.0, tau_Delta: 120.0

        Keyword arguments override dict values:

        .. code-block:: python

           >>> syn.set_status({'eta': 0.05}, eta=0.08)
           >>> print(syn.get('eta'))
           0.08
        """
        updates = {}
        if status is not None:
            updates.update(status)
        updates.update(kwargs)

        if 'weight' in updates:
            self.weight = self._to_float_scalar(updates['weight'], name='weight')
        if 'delay' in updates:
            self.delay = self._validate_positive_delay(updates['delay'])
        if 'delay_steps' in updates:
            self.delay_steps = self._validate_delay_steps(updates['delay_steps'])
        if 'tau_Delta' in updates:
            self.tau_Delta = self._to_float_scalar(updates['tau_Delta'], name='tau_Delta')
        if 'eta' in updates:
            self.eta = self._to_float_scalar(updates['eta'], name='eta')
        if 'Wmin' in updates:
            self.Wmin = self._to_float_scalar(updates['Wmin'], name='Wmin')
        if 'Wmax' in updates:
            self.Wmax = self._to_float_scalar(updates['Wmax'], name='Wmax')
        if 'PI_integral' in updates:
            self.PI_integral = self._to_float_scalar(updates['PI_integral'], name='PI_integral')
        if 'PI_exp_integral' in updates:
            self.PI_exp_integral = self._to_float_scalar(
                updates['PI_exp_integral'], name='PI_exp_integral'
            )
        if 'tau_L_trace' in updates:
            self.tau_L_trace = self._to_float_scalar(updates['tau_L_trace'], name='tau_L_trace')
        if 'tau_s_trace' in updates:
            self.tau_s_trace = self._to_float_scalar(updates['tau_s_trace'], name='tau_s_trace')
        if 't_last_spike_ms' in updates:
            self.t_last_spike_ms = self._to_float_scalar(
                updates['t_last_spike_ms'], name='t_last_spike_ms'
            )
        if 'init_weight' in updates:
            self.init_weight = self._to_float_scalar(updates['init_weight'], name='init_weight')
        else:
            # NEST set_status() always syncs init_weight_ to current weight_.
            self.init_weight = float(self.weight)

        self._check_weight_sign_constraints()
    def get(self, key: str = 'status'):
        r"""Retrieve a synapse parameter or the full status.

        Parameters
        ----------
        key : str, optional
            Parameter name or ``'status'`` for the complete state dict. Valid
            keys match those in the ``get_status()`` return dict.
            Default: ``'status'``

        Returns
        -------
        Any
            If ``key='status'``, returns the full status dict. Otherwise returns
            the scalar value of the requested parameter.

        Raises
        ------
        KeyError
            If ``key`` is not a recognized parameter name.

        Examples
        --------
        .. code-block:: python

           >>> syn = bp.urbanczik_synapse(weight=1.5, eta=0.06)
           >>> syn.get('weight')
           1.5
           >>> syn.get('eta')
           0.06
           >>> full_status = syn.get('status')
           >>> print(type(full_status))
           <class 'dict'>
        """
        if key == 'status':
            return self.get_status()
        status = self.get_status()
        if key in status:
            return status[key]
        raise KeyError(f'Unsupported key "{key}" for urbanczik_synapse.get().')
    def set_weight(self, weight: ArrayLike):
        r"""Update the synaptic weight with sign constraint validation.

        Parameters
        ----------
        weight : float or ArrayLike
            New synaptic weight. Must be a finite scalar and share sign with
            Wmin/Wmax.

        Raises
        ------
        ValueError
            If the weight violates sign constraints or is non-finite/non-scalar.
        """
        self.weight = self._to_float_scalar(weight, name='weight')
        self._check_weight_sign_constraints()
    def set_delay(self, delay: ArrayLike):
        r"""Update the dendritic delay.

        Parameters
        ----------
        delay : float or ArrayLike
            New dendritic delay in milliseconds. Must be ``> 0``.

        Raises
        ------
        ValueError
            If the delay is not positive, non-finite, or non-scalar.
        """
        self.delay = self._validate_positive_delay(delay)
    def set_delay_steps(self, delay_steps: ArrayLike):
        r"""Update the event delivery delay in time steps.

        Parameters
        ----------
        delay_steps : int or ArrayLike
            New event delivery delay. Must be ``>= 1``.

        Raises
        ------
        ValueError
            If delay_steps is less than 1 or not an integer value.
        """
        self.delay_steps = self._validate_delay_steps(delay_steps)
    def send(
        self,
        t_spike_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> dict[str, Any]:
        r"""Process a presynaptic spike, update the weight via dendritic
        prediction errors, and emit an event.

        This method implements the core Urbanczik-Senn plasticity computation. It
        retrieves postsynaptic prediction error history from the target neuron,
        updates presynaptic traces and plasticity integrals, computes the new
        synaptic weight, and returns the spike event payload.

        **Computation Order (NEST-exact)**:

        1. Query the target's ``get_urbanczik_history()`` for entries in
           :math:`(t_{last} - d, t - d]`
        2. For each history entry, compute :math:`\Pi_i` using current trace states
        3. Update the :math:`\Pi_\mathrm{int}` and :math:`\Pi_\mathrm{exp}` accumulators
        4. Compute the new weight with clipping to [Wmin, Wmax]
        5. Create the spike event dict with the updated weight
        6. Update the :math:`\tau_L^\mathrm{tr}` and :math:`\tau_s^\mathrm{tr}` traces
        7. Set :math:`t_{last} = t`

        Parameters
        ----------
        t_spike_ms : float or ArrayLike
            Presynaptic spike time in milliseconds. Must be a finite scalar.
        target : Any
            Postsynaptic neuron implementing the Urbanczik archiving interface.
            Must provide: ``get_urbanczik_history(t1, t2, comp)``,
            ``get_g_L(comp)``, ``get_tau_L(comp)`` (or ``get_C_m(comp)``),
            ``get_tau_syn_ex(comp)``, ``get_tau_syn_in(comp)``.
        receptor_type : int or ArrayLike, optional
            Receptor channel index on the target neuron. Default: ``0``
        multiplicity : float or ArrayLike, optional
            Spike event multiplicity (validated but not used in plasticity).
            Must be ``>= 0``. Default: ``1.0``
        delay : float or ArrayLike, optional
            Override dendritic delay for this spike (milliseconds, must be
            ``> 0``). If ``None``, uses ``self.delay``. Default: ``None``
        delay_steps : int or ArrayLike, optional
            Override event delivery delay for this spike (steps, must be
            ``>= 1``). If ``None``, uses ``self.delay_steps``. Default: ``None``

        Returns
        -------
        dict[str, Any]
            Spike event dictionary. Keys: ``weight`` (float, updated synaptic
            weight), ``delay`` (float, dendritic delay in ms), ``delay_steps``
            (int), ``receptor_type`` (int), ``multiplicity`` (float),
            ``t_spike_ms`` (float, spike time in ms), ``tau_s_ms`` (float,
            synaptic time constant used), ``PI_integral`` (float),
            ``PI_exp_integral`` (float), ``tau_L_trace_post`` (float),
            ``tau_s_trace_post`` (float).

        Raises
        ------
        AttributeError
            If the target does not implement the required Urbanczik archiving
            interface methods.
        ValueError
            If any parameter is non-finite, the delay is not positive,
            delay_steps < 1, or multiplicity < 0.

        Notes
        -----
        - **Synaptic time constant selection**: Uses ``tau_syn_ex`` if the
          current ``weight > 0``, otherwise ``tau_syn_in``. This branching
          matches NEST's current-weight-dependent logic.
        - **Sub-grid precision**: Sub-timestep spike offsets are ignored in the
          plasticity computation (NEST compatibility).
        - **History window**: The query interval :math:`(t_{last} - d, t - d]` is
          open on the left and closed on the right, matching NEST's
          ``get_history()`` semantics.
        - **Multiplicity**: Validated but not incorporated into the weight update
          (NEST behavior).

        Examples
        --------
        Process a single spike:

        .. code-block:: python

           >>> syn = bp.urbanczik_synapse(weight=1.0, eta=0.05)
           >>> event = syn.send(t_spike_ms=10.0, target=mock_neuron)
           >>> print(f"Updated weight: {event['weight']:.3f}")
           Updated weight: 1.003

        Override the delay for a specific spike:

        .. code-block:: python

           >>> event = syn.send(t_spike_ms=20.0, target=mock_neuron, delay=2.0)
           >>> print(f"Delay used: {event['delay']} ms")
           Delay used: 2.0 ms

        Access trace states post-update:

        .. code-block:: python

           >>> event = syn.send(t_spike_ms=30.0, target=mock_neuron)
           >>> print(f"tau_L trace: {event['tau_L_trace_post']:.3f}")
           tau_L trace: 1.105
        """
        t_spike = self._to_float_scalar(t_spike_ms, name='t_spike_ms')
        dendritic_delay = self.delay if delay is None else self._validate_positive_delay(delay)
        event_delay_steps = (
            self.delay_steps if delay_steps is None else self._validate_delay_steps(delay_steps)
        )
        comp = self.DENDRITIC_COMPARTMENT

        history_entries = self._get_urbanczik_history(
            target,
            self.t_last_spike_ms - dendritic_delay,
            t_spike - dendritic_delay,
            comp=comp,
        )

        g_L = self._get_compartment_value(target, ['get_g_L', 'get_g_l'], comp=comp, field='g_L')
        tau_L = self._get_tau_L(target, comp=comp)
        C_m = self._get_compartment_value(target, ['get_C_m', 'get_c_m'], comp=comp, field='C_m')
        tau_syn_ex = self._get_compartment_value(
            target, ['get_tau_syn_ex', 'get_tau_syn_exc'], comp=comp, field='tau_syn_ex',
        )
        tau_syn_in = self._get_compartment_value(
            target, ['get_tau_syn_in', 'get_tau_syn_inh'], comp=comp, field='tau_syn_in',
        )
        tau_s = tau_syn_ex if self.weight > 0.0 else tau_syn_in

        dPI_exp_integral = 0.0
        for entry in history_entries:
            t_hist, dw = self._extract_history_entry(entry)
            t_up = t_hist + dendritic_delay
            minus_delta_t_up = self.t_last_spike_ms - t_up
            minus_t_down = t_up - t_spike
            PI = (
                self.tau_L_trace * math.exp(minus_delta_t_up / tau_L)
                - self.tau_s_trace * math.exp(minus_delta_t_up / tau_s)
            ) * dw
            self.PI_integral += PI
            dPI_exp_integral += math.exp(minus_t_down / self.tau_Delta) * PI

        self.PI_exp_integral = (
            math.exp((self.t_last_spike_ms - t_spike) / self.tau_Delta) * self.PI_exp_integral
            + dPI_exp_integral
        )

        self.weight = self.PI_integral - self.PI_exp_integral
        self.weight = self.init_weight + (
            self.weight * 15.0 * C_m * tau_s * self.eta / (g_L * (tau_L - tau_s))
        )
        if self.weight > self.Wmax:
            self.weight = self.Wmax
        elif self.weight < self.Wmin:
            self.weight = self.Wmin

        event = {
            'weight': float(self.weight),
            'delay': float(dendritic_delay),
            'delay_steps': int(event_delay_steps),
            'receptor_type': self._to_int_scalar(receptor_type, name='receptor_type'),
            'multiplicity': self._validate_multiplicity(multiplicity),
            't_spike_ms': float(t_spike),
            'tau_s_ms': float(tau_s),
            'PI_integral': float(self.PI_integral),
            'PI_exp_integral': float(self.PI_exp_integral),
        }

        self.tau_L_trace = (
            self.tau_L_trace * math.exp((self.t_last_spike_ms - t_spike) / tau_L) + 1.0
        )
        self.tau_s_trace = (
            self.tau_s_trace * math.exp((self.t_last_spike_ms - t_spike) / tau_s) + 1.0
        )
        self.t_last_spike_ms = t_spike

        event['tau_L_trace_post'] = float(self.tau_L_trace)
        event['tau_s_trace_post'] = float(self.tau_s_trace)
        return event
    def to_spike_event(
        self,
        t_spike_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> dict[str, Any]:
        r"""Alias for the ``send()`` method.

        This method provides an alternative API name for spike event generation,
        matching common naming conventions in event-based simulators.

        Parameters
        ----------
        t_spike_ms : float or ArrayLike
            Presynaptic spike time in milliseconds.
        target : Any
            Postsynaptic neuron with the Urbanczik archiving interface.
        receptor_type : int or ArrayLike, optional
            Receptor channel index. Default: ``0``
        multiplicity : float or ArrayLike, optional
            Spike event multiplicity. Default: ``1.0``
        delay : float or ArrayLike, optional
            Override dendritic delay (milliseconds). Default: ``None``
            (use ``self.delay``)
        delay_steps : int or ArrayLike, optional
            Override event delivery delay (steps). Default: ``None``
            (use ``self.delay_steps``)

        Returns
        -------
        dict[str, Any]
            Spike event dictionary identical to the ``send()`` return value.

        See Also
        --------
        send : Primary spike processing method with full documentation.
        """
        return self.send(
            t_spike_ms=t_spike_ms,
            target=target,
            receptor_type=receptor_type,
            multiplicity=multiplicity,
            delay=delay,
            delay_steps=delay_steps,
        )
    def simulate_pre_spike_train(
        self,
        pre_spike_times_ms: ArrayLike,
        target: Any,
        receptor_type: ArrayLike = 0,
        multiplicity: ArrayLike = 1.0,
        delay: ArrayLike | None = None,
        delay_steps: ArrayLike | None = None,
    ) -> list[dict[str, Any]]:
        r"""Process a sequence of presynaptic spikes and return the event history.

        Convenience method for batch processing of spike trains. Each spike is
        processed sequentially using ``send()``, with synapse state (weight,
        traces, integrals) updating after each spike.

        Parameters
        ----------
        pre_spike_times_ms : array_like
            1D array or sequence of presynaptic spike times in milliseconds.
            Will be flattened if multidimensional. Order matters: earlier spikes
            affect later ones through trace dynamics.
        target : Any
            Postsynaptic neuron with the Urbanczik archiving interface.
        receptor_type : int or ArrayLike, optional
            Receptor channel index applied to all spikes. Default: ``0``
        multiplicity : float or ArrayLike, optional
            Spike multiplicity applied to all spikes. Default: ``1.0``
        delay : float or ArrayLike, optional
            Override dendritic delay for all spikes (milliseconds).
            Default: ``None``
        delay_steps : int or ArrayLike, optional
            Override event delivery delay for all spikes (steps).
            Default: ``None``

        Returns
        -------
        list[dict[str, Any]]
            List of spike event dictionaries, one per input spike, in
            chronological order. Each dict has the same structure as the
            ``send()`` return value.

        Notes
        -----
        - **Stateful processing**: Synapse internal state (weight, traces)
          persists across spikes in the train. The final state reflects
          cumulative plasticity effects.
        - **Performance**: For large spike trains (>10000 spikes), consider
          batching or vectorization depending on the target neuron
          implementation.
        - **Temporal ordering**: Input spikes should typically be sorted in
          ascending time order for biologically realistic plasticity dynamics.

        Examples
        --------
        Process a spike train:

        .. code-block:: python

           >>> syn = bp.urbanczik_synapse(weight=1.0, eta=0.05)
           >>> spike_times = [10.0, 15.0, 20.0, 25.0]
           >>> events = syn.simulate_pre_spike_train(spike_times, target=mock_neuron)
           >>> weights = [e['weight'] for e in events]
           >>> print(f"Weight trajectory: {weights}")
           Weight trajectory: [1.002, 1.005, 1.008, 1.011]

        Extract trace evolution:

        .. code-block:: python

           >>> import numpy as np
           >>> spike_times = np.arange(0, 100, 5.0)
           >>> events = syn.simulate_pre_spike_train(spike_times, target=mock_neuron)
           >>> tau_L_traces = [e['tau_L_trace_post'] for e in events]
           >>> print(f"Final tau_L trace: {tau_L_traces[-1]:.3f}")
           Final tau_L trace: 2.456

        See Also
        --------
        send : Process a single spike with full control over parameters.
        """
        dftype = brainstate.environ.dftype()
        times = np.asarray(u.math.asarray(pre_spike_times_ms), dtype=dftype).reshape(-1)
        events = []
        for t in times:
            events.append(
                self.send(
                    t_spike_ms=float(t),
                    target=target,
                    receptor_type=receptor_type,
                    multiplicity=multiplicity,
                    delay=delay,
                    delay_steps=delay_steps,
                )
            )
        return events
    def _check_weight_sign_constraints(self):
        r"""Validate that weight, Wmin, and Wmax share consistent signs.

        Raises
        ------
        ValueError
            If weight and Wmin have different signs, or if weight and Wmax
            have different signs.

        Notes
        -----
        The sign-check logic and error messages exactly match NEST
        ``urbanczik_synapse::set_status()``. Uses NEST's specific sign
        comparison semantics via ``_sign_like_wmax()``.
        """
        # Keep sign checks/message text aligned with NEST urbanczik_synapse::set_status.
        if bool(np.signbit(self.weight)) != bool(np.signbit(self.Wmin)):
            raise ValueError('Weight and Wmin must have same sign.')
        if self._sign_like_wmax(self.weight) != self._sign_like_wmax(self.Wmax):
            raise ValueError('Weight and Wmax must have same sign.')

    @staticmethod
    def _sign_like_wmax(x: float) -> int:
        r"""NEST-compatible sign function for weight bound validation.

        Parameters
        ----------
        x : float
            Value whose sign is checked.

        Returns
        -------
        int
            ``1`` if x > 0, ``-1`` if x <= 0 (matching NEST's sign semantics
            for Wmax).

        Notes
        -----
        This function differs from a standard sign() by treating zero as
        negative, consistent with NEST's weight bound validation logic.
        """
        return int((x > 0.0) - (x <= 0.0))

    @staticmethod
    def _get_urbanczik_history(target: Any, t1: float, t2: float, comp: int):
        r"""Retrieve the dendritic prediction error history from the target neuron.

        Parameters
        ----------
        target : Any
            Postsynaptic neuron object implementing a
            ``get_urbanczik_history()`` method.
        t1 : float
            Start time in milliseconds (exclusive).
        t2 : float
            End time in milliseconds (inclusive).
        comp : int
            Compartment index (typically 1 for the dendritic compartment).

        Returns
        -------
        list
            List of history entries in the interval (t1, t2]. Each entry can be:

            - an object with ``t_``/``dw_`` attributes (NEST-style),
            - an object with ``t``/``dw`` attributes,
            - a dict with ``'t'``/``'dw'`` keys, or
            - a 2-tuple ``(time, delta_w)``.

            Returns an empty list if there are no entries or the target
            returns None.

        Raises
        ------
        AttributeError
            If the target does not provide a ``get_urbanczik_history()``
            method.

        Notes
        -----
        First attempts the call with the ``(t1, t2, comp)`` signature. If that
        raises TypeError (the target does not accept a comp argument), falls
        back to the ``(t1, t2)`` signature.
        """
        fn = getattr(target, 'get_urbanczik_history', None)
        if fn is None or not callable(fn):
            raise AttributeError(
                'Target must provide get_urbanczik_history(t1, t2, comp) for urbanczik_synapse.'
            )
        try:
            history = fn(float(t1), float(t2), int(comp))
        except TypeError:
            history = fn(float(t1), float(t2))
        if history is None:
            return []
        return history

    @classmethod
    def _get_tau_L(cls, target: Any, comp: int) -> float:
        r"""Retrieve the membrane time constant tau_L from the target neuron.

        Parameters
        ----------
        target : Any
            Postsynaptic neuron object.
        comp : int
            Compartment index.

        Returns
        -------
        float
            Membrane time constant in milliseconds (tau_L = C_m / g_L).

        Raises
        ------
        AttributeError
            If the target provides neither ``get_tau_L()`` nor both
            ``get_C_m()`` and ``get_g_L()``.

        Notes
        -----
        Tries the following in order:

        1. A direct ``get_tau_L(comp)`` or ``get_tau_l(comp)`` method.
        2. If the method does not accept a comp argument, ``get_tau_L()``
           without arguments.
        3. Computing tau_L = C_m / g_L from the compartment values.

        Handles both uppercase (``get_tau_L``, ``get_C_m``, ``get_g_L``) and
        lowercase (``get_tau_l``, ``get_c_m``, ``get_g_l``) method naming
        conventions.
        """
        fn = getattr(target, 'get_tau_L', None)
        if fn is None:
            fn = getattr(target, 'get_tau_l', None)
        if fn is not None and callable(fn):
            try:
                return float(fn(int(comp)))
            except TypeError:
                return float(fn())
        c_m = cls._get_compartment_value(target, ['get_C_m', 'get_c_m'], comp=comp, field='C_m')
        g_l = cls._get_compartment_value(target, ['get_g_L', 'get_g_l'], comp=comp, field='g_L')
        return float(c_m / g_l)

    @staticmethod
    def _get_compartment_value(target: Any, names: list[str], comp: int, field: str) -> float:
        r"""Retrieve a compartment-specific parameter from the target neuron.

        Parameters
        ----------
        target : Any
            Postsynaptic neuron object.
        names : list[str]
            Method names to try in order (e.g., ``['get_C_m', 'get_c_m']``).
        comp : int
            Compartment index.
        field : str
            Field name for the error message (e.g., ``'C_m'``, ``'g_L'``).

        Returns
        -------
        float
            The requested parameter value.

        Raises
        ------
        AttributeError
            If the target does not provide any of the specified methods.

        Notes
        -----
        Tries each method name in order. For each callable method found,
        first attempts the call with the compartment index ``comp``, then
        falls back to calling without arguments if TypeError is raised (for
        neurons without multi-compartment support).
        """
        for name in names:
            fn = getattr(target, name, None)
            if fn is None or not callable(fn):
                continue
            try:
                return float(fn(int(comp)))
            except TypeError:
                return float(fn())
        raise AttributeError(
            f'Target must provide {"/".join(names)}(comp) for urbanczik_synapse ({field}).'
        )

    @staticmethod
    def _extract_history_entry(entry: Any) -> tuple[float, float]:
        r"""Parse a history entry into a ``(time, delta_w)`` tuple.

        Parameters
        ----------
        entry : Any
            History entry in one of the supported formats:

            - a dict with keys ``'t_'``/``'dw_'`` or ``'t'``/``'dw'`` (or
              variants),
            - a 2-tuple or list ``[time, delta_w]``, or
            - an object with attributes ``t_``/``dw_`` (NEST-style) or
              ``t``/``dw``.

        Returns
        -------
        tuple[float, float]
            Extracted ``(time_ms, delta_w)`` pair.

        Raises
        ------
        ValueError
            If the entry does not provide both time and delta_w values in any
            recognized format.

        Notes
        -----
        Supported key/attribute name variants for time: ``'t_'``, ``'t'``,
        ``'time_ms'``, ``'time'``. Supported variants for delta_w: ``'dw_'``,
        ``'dw'``, ``'delta_w'``, ``'weight_change'``. NEST-style naming
        (``t_``, ``dw_``) takes priority over the simpler alternatives.
        """
        t = None
        dw = None
        if isinstance(entry, dict):
            t = entry.get('t_', entry.get('t', entry.get('time_ms', entry.get('time', None))))
            dw = entry.get('dw_', entry.get('dw', entry.get('delta_w', entry.get('weight_change', None))))
        elif isinstance(entry, (tuple, list)) and len(entry) >= 2:
            t, dw = entry[0], entry[1]
        else:
            t = getattr(entry, 't_', getattr(entry, 't', None))
            dw = getattr(entry, 'dw_', getattr(entry, 'dw', None))
        if t is None or dw is None:
            raise ValueError('Each Urbanczik history entry must provide both time and dw values.')
        return float(t), float(dw)

    @staticmethod
    def _to_float_scalar(value: ArrayLike, name: str) -> float:
        r"""Convert a value to a finite float scalar, with validation.

        Parameters
        ----------
        value : ArrayLike
            Input value (scalar, array, or saiunit Quantity).
        name : str
            Parameter name for error messages.

        Returns
        -------
        float
            Validated finite float scalar.

        Raises
        ------
        ValueError
            If the value is not scalar, is non-finite (NaN/infinity), or
            cannot be converted to float.

        Notes
        -----
        Handles saiunit Quantity objects by extracting the mantissa. Flattens
        arrays to enforce the single-element constraint.
        """
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value)
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        v = float(arr[0])
        if not np.isfinite(v):
            raise ValueError(f'{name} must be finite.')
        return v

    @staticmethod
    def _to_int_scalar(value: ArrayLike, name: str) -> int:
        r"""Convert a value to an integer scalar, with validation.

        Parameters
        ----------
        value : ArrayLike
            Input value (scalar, array, or saiunit Quantity).
        name : str
            Parameter name for error messages.

        Returns
        -------
        int
            Validated integer scalar.

        Raises
        ------
        ValueError
            If the value is not scalar, is non-finite, is not sufficiently
            close to an integer (tolerance 1e-12), or cannot be converted.

        Notes
        -----
        Converts to float first, validates finiteness, then rounds and checks
        the integer constraint with a 1e-12 absolute tolerance. Handles
        saiunit Quantity objects.
        """
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value)
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        v = float(arr[0])
        if not np.isfinite(v):
            raise ValueError(f'{name} must be finite.')
        iv = int(round(v))
        if abs(v - iv) > 1e-12:
            raise ValueError(f'{name} must be an integer value.')
        return iv

    @classmethod
    def _validate_positive_delay(cls, value: ArrayLike) -> float:
        r"""Validate that the dendritic delay is positive.

        Parameters
        ----------
        value : ArrayLike
            Delay value to validate.

        Returns
        -------
        float
            Validated delay in milliseconds.

        Raises
        ------
        ValueError
            If the delay is not positive (must be > 0).
        """
        d = cls._to_float_scalar(value, name='delay')
        if d <= 0.0:
            raise ValueError('delay must be > 0.')
        return d

    @classmethod
    def _validate_delay_steps(cls, value: ArrayLike) -> int:
        r"""Validate the event delivery delay steps.

        Parameters
        ----------
        value : ArrayLike
            Delay-steps value to validate.

        Returns
        -------
        int
            Validated delay steps (must be >= 1).

        Raises
        ------
        ValueError
            If delay_steps is less than 1.
        """
        d = cls._to_int_scalar(value, name='delay_steps')
        if d < 1:
            raise ValueError('delay_steps must be >= 1.')
        return d

    @classmethod
    def _validate_multiplicity(cls, value: ArrayLike) -> float:
        r"""Validate that the spike event multiplicity is non-negative.

        Parameters
        ----------
        value : ArrayLike
            Multiplicity value to validate.

        Returns
        -------
        float
            Validated multiplicity (must be >= 0).

        Raises
        ------
        ValueError
            If the multiplicity is negative.

        Notes
        -----
        Multiplicity is validated but not used in the plasticity computation
        (NEST compatibility).
        """
        m = cls._to_float_scalar(value, name='multiplicity')
        if m < 0.0:
            raise ValueError('multiplicity must be >= 0.')
        return m
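The zero-as-negative convention used by ``_sign_like_wmax`` for the Wmax bound check is easy to confuse with an ordinary sign function. The following standalone sketch (a hypothetical re-implementation for illustration, not an import from this module) contrasts the two conventions:

```python
def sign_like_wmax(x: float) -> int:
    # NEST-style sign for weight-bound checks: +1 for x > 0, -1 for x <= 0.
    # Zero deliberately counts as negative.
    return int((x > 0.0) - (x <= 0.0))


def naive_sign(x: float) -> int:
    # Conventional three-valued sign: zero maps to 0.
    return (x > 0.0) - (x < 0.0)


# The two conventions disagree only at zero: a weight of exactly 0.0 is
# treated as negative, so it only passes validation against a negative Wmax.
for x in (2.5, 0.0, -1.0):
    print(x, sign_like_wmax(x), naive_sign(x))
```

Under this convention a synapse with ``weight = 0.0`` and ``Wmax > 0`` is rejected with "Weight and Wmax must have same sign.", which matches the behaviour of NEST's ``set_status`` checks rather than a conventional sign comparison.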