Source code for brainpy_state._nest.aeif_psc_delta_clopath

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, validate_aeif_overflow, AdaptiveRungeKuttaStep

__all__ = [
    'aeif_psc_delta_clopath',
]


class aeif_psc_delta_clopath(NESTNeuron):
    r"""Adaptive exponential integrate-and-fire neuron with delta-shaped synaptic input and Clopath voltage traces.

    This model extends the standard adaptive exponential integrate-and-fire (AdEx) neuron with additional
    state variables required for voltage-based Clopath plasticity. It implements delta-function postsynaptic
    currents (instantaneous voltage jumps), spike afterpotential dynamics, adaptive threshold, post-spike
    voltage clamping, and three low-pass filtered voltage traces (``u_bar_plus``, ``u_bar_minus``, ``u_bar_bar``)
    used by the Clopath learning rule.

    **1. Membrane and Adaptation Dynamics**

    The subthreshold membrane potential evolves according to:

    .. math::

       C_m \frac{dV}{dt} = -g_L (V - E_L) + g_L \Delta_T \exp\!\left(\frac{V - V_{th}}{\Delta_T}\right)
                           - w + z + I_e + I_{stim},

    where :math:`V` is the membrane potential, :math:`w` is the adaptation current, :math:`z` is the spike
    afterpotential current, :math:`V_{th}` is the adaptive threshold, :math:`I_e` is constant external current,
    and :math:`I_{stim}` is the one-step delayed synaptic input. The exponential term provides the spike
    upstroke when :math:`\Delta_T > 0`.

    Three auxiliary currents evolve as:

    .. math::

       \tau_w \frac{dw}{dt} &= a (V - E_L) - w, \\
       \tau_z \frac{dz}{dt} &= -z, \\
       \tau_{V_{th}} \frac{dV_{th}}{dt} &= -(V_{th} - V_{th,rest}).

    The adaptation current :math:`w` provides subthreshold coupling and spike-frequency adaptation.
    The afterpotential :math:`z` creates a depolarizing transient following each spike.
    The adaptive threshold :math:`V_{th}` relaxes toward :math:`V_{th,rest}` between spikes and jumps
    to :math:`V_{th,max}` upon spike emission.
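
    As a minimal sketch of the four equations above, the right-hand side can be written in plain
    Python with unitless floats (mV, pA, nS, pF, ms). This is not the class's own ``_vector_field``,
    whose body is not shown here; the name ``adex_clopath_rhs`` and the ``PARAMS`` dictionary are
    illustrative assumptions, and the exponential term is assumed to vanish when ``Delta_T == 0``.

```python
import math

# Default parameters from this docstring as plain floats (mV, pA, nS, pF, ms).
# The dictionary layout is an illustrative assumption, not part of the model API.
PARAMS = dict(g_L=30.0, C_m=281.0, E_L=-70.6, Delta_T=2.0,
              tau_w=144.0, tau_z=40.0, tau_V_th=50.0,
              V_th_rest=-50.4, a=4.0, I_e=0.0)


def adex_clopath_rhs(V, w, z, V_th, I_stim=0.0, p=PARAMS):
    # Returns (dV/dt, dw/dt, dz/dt, dV_th/dt) for the subthreshold dynamics.
    # Exponential spike-initiation current; assumed to be zero when Delta_T == 0.
    if p['Delta_T'] > 0.0:
        I_spike = p['g_L'] * p['Delta_T'] * math.exp((V - V_th) / p['Delta_T'])
    else:
        I_spike = 0.0
    # nS * mV = pA and pA / pF = mV / ms, so no unit conversion is needed.
    dV = (-p['g_L'] * (V - p['E_L']) + I_spike - w + z + p['I_e'] + I_stim) / p['C_m']
    dw = (p['a'] * (V - p['E_L']) - w) / p['tau_w']
    dz = -z / p['tau_z']
    dV_th = -(V_th - p['V_th_rest']) / p['tau_V_th']
    return dV, dw, dz, dV_th


# At rest only the very small exponential current drives V.
print(adex_clopath_rhs(V=-70.6, w=0.0, z=0.0, V_th=-50.4))
```

    Because the units cancel as noted in the comment, ``dV`` comes out directly in mV/ms.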

    **2. Clopath Low-Pass Voltage Traces**

    Three filtered voltage variables are maintained for plasticity:

    .. math::

       \tau_{u+} \frac{du_{bar+}}{dt} &= -u_{bar+} + V, \\
       \tau_{u-} \frac{du_{bar-}}{dt} &= -u_{bar-} + V, \\
       \tau_{u\bar{}} \frac{du_{bar\bar{}}}{dt} &= -u_{bar\bar{}} + u_{bar-}.

    Each trace is a first-order low-pass filter with its own time constant. ``u_bar_plus`` and
    ``u_bar_minus`` filter :math:`V` directly; ``u_bar_bar`` filters ``u_bar_minus``, which makes it a
    second-order filter of :math:`V`. Copies of ``u_bar_plus`` and ``u_bar_minus`` delayed by
    ``delay_u_bars`` are stored in ring buffers for use by Clopath synaptic plasticity rules.
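
    The filter cascade can be illustrated with a short stand-alone sketch. It assumes the input
    is held constant over each step and uses the exact exponential update for a first-order
    filter; the model itself integrates these equations with RKF45 together with the other state
    variables. The helper name ``update_traces`` is hypothetical.

```python
import math


def update_traces(V, u_plus, u_minus, u_bar_bar, dt,
                  tau_plus=7.0, tau_minus=10.0, tau_bar_bar=500.0):
    # Exact solution of tau * du/dt = -u + x over dt, for an input x held constant.
    new_plus = V + (u_plus - V) * math.exp(-dt / tau_plus)
    new_minus = V + (u_minus - V) * math.exp(-dt / tau_minus)
    # u_bar_bar filters u_bar_minus (its old value is used during the step).
    new_bar_bar = u_minus + (u_bar_bar - u_minus) * math.exp(-dt / tau_bar_bar)
    return new_plus, new_minus, new_bar_bar


# Hold V at -50 mV for 5 s (dt = 0.1 ms); all traces relax from -70.6 mV toward -50 mV.
traces = (-70.6, -70.6, -70.6)
for _ in range(50000):
    traces = update_traces(-50.0, *traces, dt=0.1)
print(traces)
```

    With the default time constants, ``u_bar_plus`` responds fastest and ``u_bar_bar`` slowest.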

    **3. Delta-Function Synaptic Input**

    Incoming synaptic spikes cause instantaneous voltage jumps:

    .. math::

       V \leftarrow V + \sum_k J_k \delta(t - t_k^{\mathrm{spike}}),

    where :math:`J_k` is the synaptic weight from presynaptic neuron :math:`k`. Delta inputs are summed
    from the ``delta_inputs`` dictionary and applied at the beginning of each accepted RKF45 substep
    (but only when the neuron is neither refractory nor clamped).
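
    The gating rule above can be sketched as a small function on plain floats. The name
    ``apply_delta_inputs`` is hypothetical, and this sketch simply skips input that arrives
    while the neuron is refractory or clamped; whether the implementation discards or defers
    such input is not stated here.

```python
def apply_delta_inputs(V, delta_inputs, refractory_steps, clamp_steps):
    # V in mV; delta_inputs maps an input label to a voltage jump in mV.
    if refractory_steps > 0 or clamp_steps > 0:
        return V  # assumption: input is ignored while refractory or clamped
    return V + sum(delta_inputs.values())


jumps = {'exc': 0.5, 'inh': -0.2}
print(apply_delta_inputs(-65.0, jumps, refractory_steps=0, clamp_steps=0))
print(apply_delta_inputs(-65.0, jumps, refractory_steps=3, clamp_steps=0))
```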

    **4. Spike Detection and Reset**

    Spike detection threshold depends on :math:`\Delta_T`:

    - If :math:`\Delta_T > 0`: threshold is ``V_peak`` (exponential blowup detector).
    - If :math:`\Delta_T = 0`: threshold is the dynamic ``V_th`` (standard IF threshold).

    Upon threshold crossing, the following spike-triggered updates occur:

    .. math::

       V &\leftarrow V_{\mathrm{clamp}}, \\
       w &\leftarrow w + b, \\
       z &\leftarrow I_{sp}, \\
       V_{th} &\leftarrow V_{th,max}, \\
       \text{clamp\_step\_count} &\leftarrow \lceil t_{\mathrm{clamp}} / dt \rceil + 1.
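
    A minimal sketch of the threshold choice and the spike-triggered updates listed above, on
    plain floats with the documented defaults. The function name ``detect_and_reset`` and the
    use of ``>=`` for the crossing test are assumptions made for illustration.

```python
import math


def detect_and_reset(V, w, z, V_th, dt, clamp_step_count=0, Delta_T=2.0,
                     V_peak=33.0, V_clamp=33.0, V_th_max=30.4,
                     b=80.5, I_sp=400.0, t_clamp=2.0):
    # Returns (spiked, V, w, z, V_th, clamp_step_count).
    # Delta_T > 0: detect the exponential blow-up at V_peak; otherwise use V_th.
    threshold = V_peak if Delta_T > 0.0 else V_th
    if V < threshold:
        return False, V, w, z, V_th, clamp_step_count
    # Spike: clamp V, increment w, reset z, raise the threshold, start clamping.
    new_count = math.ceil(t_clamp / dt) + 1
    return True, V_clamp, w + b, I_sp, V_th_max, new_count


print(detect_and_reset(V=40.0, w=0.0, z=0.0, V_th=-50.4, dt=0.5))
```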

    **5. Post-Spike Clamping and Refractory Period**

    The model implements a two-stage reset:

    1. **Clamping stage** (duration ``t_clamp``): voltage is held at ``V_clamp``, and adaptation dynamics
       are frozen (``dw/dt = 0``). At the end of clamping (when ``clamp_step_count`` reaches 1 during
       substep integration), voltage is reset to ``V_reset`` and the refractory period begins.

    2. **Refractory stage** (duration ``t_ref``): voltage is clamped to ``V_reset``, but adaptation
       dynamics continue (``dw/dt != 0``). Spike detection is disabled during both clamping and refractory.

    This two-stage mechanism reproduces NEST's spike handling order and allows modeling of realistic
    action potential waveforms with controlled overshoot.
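
    The two stages can be pictured as a small state machine. The sketch below advances it once
    per step on plain values; the real model applies this logic inside the RKF45 substep loop,
    so the per-step decrement, the function name ``advance_reset_stage``, and the
    ``refractory_total`` argument are simplifying assumptions.

```python
def advance_reset_stage(V, clamp_steps, refractory_steps, refractory_total,
                        V_clamp=33.0, V_reset=-60.0):
    # Returns (V, clamp_steps, refractory_steps) after one step.
    if clamp_steps > 0:
        if clamp_steps == 1:
            # Clamping ends: reset the voltage and start the refractory stage.
            return V_reset, 0, refractory_total
        return V_clamp, clamp_steps - 1, refractory_steps
    if refractory_steps > 0:
        return V_reset, 0, refractory_steps - 1
    return V, 0, 0  # neither clamped nor refractory: leave V untouched


# Start right after a spike with a clamp counter of 3 and 2 refractory steps.
state = (33.0, 3, 0)
trajectory = []
for _ in range(6):
    state = advance_reset_stage(*state, refractory_total=2)
    trajectory.append(state)
print(trajectory)
```

    Clamping takes precedence over refractoriness, matching the order in which the two stages
    are described above.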

    **6. Numerical Integration**

    The continuous-time dynamics are integrated using an adaptive Runge-Kutta-Fehlberg 4(5) solver (RKF45)
    with local error control. The integrator maintains a persistent step size (``integration_step``) that
    adapts based on local truncation error estimates. During refractory or clamping, the effective voltage
    used in the right-hand side is replaced with ``V_reset`` or ``V_clamp``, but state integration continues.
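
    For readers unfamiliar with the method, here is a scalar sketch of adaptive RKF45 stepping
    with a step size that persists between calls. It uses the standard Fehlberg 4(5)
    coefficients and a textbook step-size controller; the internals of
    ``AdaptiveRungeKuttaStep`` are not shown in this file and may differ. The names
    ``rkf45_step`` and ``integrate_interval`` are hypothetical.

```python
import math


def rkf45_step(f, t, y, h):
    # One Fehlberg 4(5) step: returns the 5th-order result and an error estimate.
    k1 = h * f(t, y)
    k2 = h * f(t + h / 4, y + k1 / 4)
    k3 = h * f(t + 3 * h / 8, y + 3 * k1 / 32 + 9 * k2 / 32)
    k4 = h * f(t + 12 * h / 13, y + (1932 * k1 - 7200 * k2 + 7296 * k3) / 2197)
    k5 = h * f(t + h, y + 439 * k1 / 216 - 8 * k2 + 3680 * k3 / 513 - 845 * k4 / 4104)
    k6 = h * f(t + h / 2, y - 8 * k1 / 27 + 2 * k2 - 3544 * k3 / 2565
               + 1859 * k4 / 4104 - 11 * k5 / 40)
    y4 = y + 25 * k1 / 216 + 1408 * k3 / 2565 + 2197 * k4 / 4104 - k5 / 5
    y5 = (y + 16 * k1 / 135 + 6656 * k3 / 12825 + 28561 * k4 / 56430
          - 9 * k5 / 50 + 2 * k6 / 55)
    return y5, abs(y5 - y4)


def integrate_interval(f, y, h, dt, tol=1e-6, min_h=1e-8, max_iters=100000):
    # Integrate dy/dt = f(t, y) over one interval dt; returns (y, h) so h can persist.
    t = 0.0
    for _ in range(max_iters):
        if t >= dt:
            return y, h
        step = min(h, dt - t)  # never step past the end of the interval
        y_new, err = rkf45_step(f, t, y, step)
        accepted = err <= tol or step <= min_h
        if accepted:
            t, y = t + step, y_new
        if not (accepted and step < h):
            # Textbook controller, limited to a change factor between 0.2 and 5.
            scale = 0.9 * (tol / err) ** 0.2 if err > 0.0 else 5.0
            h = max(min_h, step * min(5.0, max(0.2, scale)))
    if t >= dt:
        return y, h
    raise RuntimeError('RKF45 did not finish within max_iters')


y_end, h_next = integrate_interval(lambda t, y: -y, y=1.0, h=0.1, dt=1.0, tol=1e-9)
print(y_end, math.exp(-1.0))  # the two values should agree closely
```

    Returning ``h`` alongside ``y`` mirrors the idea of a persistent ``integration_step``: the
    next interval starts from the step size the controller last found acceptable.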

    Parameters
    ----------
    in_size : int or tuple of int
        Population shape. Scalar for 1D populations, tuple for multi-dimensional arrays.
    V_peak : ArrayLike, default: 33.0 * u.mV
        Spike detection threshold when ``Delta_T > 0``. Must satisfy ``V_peak >= V_th_rest``.
        Shape: scalar or broadcastable to ``in_size``.
    V_reset : ArrayLike, default: -60.0 * u.mV
        Reset potential after clamping ends. Must satisfy ``V_reset < V_peak``.
        Shape: scalar or broadcastable to ``in_size``.
    t_ref : ArrayLike, default: 0.0 * u.ms
        Absolute refractory period duration (non-negative). Default of 0 ms matches NEST defaults.
        Shape: scalar or broadcastable to ``in_size``.
    g_L : ArrayLike, default: 30.0 * u.nS
        Leak conductance (must be positive). Shape: scalar or broadcastable to ``in_size``.
    C_m : ArrayLike, default: 281.0 * u.pF
        Membrane capacitance (must be positive). Shape: scalar or broadcastable to ``in_size``.
    E_L : ArrayLike, default: -70.6 * u.mV
        Leak reversal potential. Shape: scalar or broadcastable to ``in_size``.
    Delta_T : ArrayLike, default: 2.0 * u.mV
        Exponential slope factor (non-negative). Set to 0 for non-exponential IF model.
        Shape: scalar or broadcastable to ``in_size``.
    tau_w : ArrayLike, default: 144.0 * u.ms
        Adaptation current time constant (must be positive). Shape: scalar or broadcastable to ``in_size``.
    tau_z : ArrayLike, default: 40.0 * u.ms
        Spike afterpotential time constant (must be positive). Shape: scalar or broadcastable to ``in_size``.
    tau_V_th : ArrayLike, default: 50.0 * u.ms
        Adaptive threshold time constant (must be positive). Shape: scalar or broadcastable to ``in_size``.
    V_th_max : ArrayLike, default: 30.4 * u.mV
        Threshold value immediately after spike. Must satisfy ``V_th_max >= V_th_rest``.
        Shape: scalar or broadcastable to ``in_size``.
    V_th_rest : ArrayLike, default: -50.4 * u.mV
        Resting threshold value (asymptotic value between spikes). Must satisfy ``V_th_rest <= V_peak``.
        Shape: scalar or broadcastable to ``in_size``.
    tau_u_bar_plus : ArrayLike, default: 7.0 * u.ms
        Time constant for ``u_bar_plus`` trace (must be positive). Shape: scalar or broadcastable to ``in_size``.
    tau_u_bar_minus : ArrayLike, default: 10.0 * u.ms
        Time constant for ``u_bar_minus`` trace (must be positive). Shape: scalar or broadcastable to ``in_size``.
    tau_u_bar_bar : ArrayLike, default: 500.0 * u.ms
        Time constant for ``u_bar_bar`` trace (must be positive). Shape: scalar or broadcastable to ``in_size``.
    a : ArrayLike, default: 4.0 * u.nS
        Subthreshold adaptation coupling strength. Shape: scalar or broadcastable to ``in_size``.
    b : ArrayLike, default: 80.5 * u.pA
        Spike-triggered adaptation increment. Shape: scalar or broadcastable to ``in_size``.
    I_sp : ArrayLike, default: 400.0 * u.pA
        Spike afterpotential current reset value (sets ``z`` on spike). Shape: scalar or broadcastable to ``in_size``.
    I_e : ArrayLike, default: 0.0 * u.pA
        Constant external current. Shape: scalar or broadcastable to ``in_size``.
    A_LTD : ArrayLike, default: 1.4e-4
        Clopath depression amplitude (dimensionless). Used in delayed-buffer bookkeeping for compatibility.
        Shape: scalar or broadcastable to ``in_size``.
    A_LTP : ArrayLike, default: 8.0e-5
        Clopath potentiation amplitude (dimensionless). Used in delayed-buffer bookkeeping for compatibility.
        Shape: scalar or broadcastable to ``in_size``.
    theta_plus : ArrayLike, default: -45.3 * u.mV
        Clopath potentiation voltage threshold. Shape: scalar or broadcastable to ``in_size``.
    theta_minus : ArrayLike, default: -70.6 * u.mV
        Clopath depression voltage threshold. Shape: scalar or broadcastable to ``in_size``.
    A_LTD_const : bool, default: True
        If True, LTD amplitude is constant. If False, LTD scales with ``u_bar_bar**2 / u_ref_squared`` (homeostatic).
    delay_u_bars : ArrayLike, default: 5.0 * u.ms
        Delay for Clopath u-bar traces (ring buffer delay). Rounded to nearest integer multiple of ``dt``.
        Shape: scalar or broadcastable to ``in_size``.
    u_ref_squared : ArrayLike, default: 60.0
        Clopath LTD homeostatic reference (dimensionless, must be positive). Only used when ``A_LTD_const=False``.
        Shape: scalar or broadcastable to ``in_size``.
    gsl_error_tol : ArrayLike, default: 1e-6
        RKF45 local error tolerance (must be positive). Smaller values increase accuracy and decrease step size.
        Shape: scalar or broadcastable to ``in_size``.
    t_clamp : ArrayLike, default: 2.0 * u.ms
        Spike clamping duration (non-negative). Shape: scalar or broadcastable to ``in_size``.
    V_clamp : ArrayLike, default: 33.0 * u.mV
        Clamped voltage immediately after spike. Shape: scalar or broadcastable to ``in_size``.
    V_initializer : Callable, default: braintools.init.Constant(-70.6 * u.mV)
        Initializer for membrane potential. Must return values with voltage units.
    w_initializer : Callable, default: braintools.init.Constant(0.0 * u.pA)
        Initializer for adaptation current. Must return values with current units.
    z_initializer : Callable, default: braintools.init.Constant(0.0 * u.pA)
        Initializer for spike afterpotential current. Must return values with current units.
    V_th_initializer : Callable, default: braintools.init.Constant(-50.4 * u.mV)
        Initializer for adaptive threshold. Must return values with voltage units.
    u_bar_plus_initializer : Callable, default: braintools.init.Constant(-70.6 * u.mV)
        Initializer for ``u_bar_plus`` trace. Must return values with voltage units.
    u_bar_minus_initializer : Callable, default: braintools.init.Constant(-70.6 * u.mV)
        Initializer for ``u_bar_minus`` trace. Must return values with voltage units.
    u_bar_bar_initializer : Callable, default: braintools.init.Constant(-70.6 * u.mV)
        Initializer for ``u_bar_bar`` trace. Must return values with voltage units.
    spk_fun : Callable, default: braintools.surrogate.ReluGrad()
        Surrogate gradient function for differentiable spike generation during training.
    spk_reset : str, default: 'hard'
        Spike reset mode. 'hard' (stop_gradient) matches NEST behavior; 'soft' (V -= V_th) preserves gradients.
    ref_var : bool, default: False
        If True, expose ``refractory`` state variable indicating whether neuron is refractory or clamped.
    name : str, optional
        Name for this neuron instance.

    Parameter Mapping
    -----------------
    The table below lists each constructor parameter with its default value, mathematical symbol, and a short description:

    =========================== ================== ================================ ================================================================
    **Parameter**               **Default**        **Math Symbol**                  **Description**
    =========================== ================== ================================ ================================================================
    ``in_size``                 (required)         —                                Population shape
    ``V_peak``                  33 mV              :math:`V_\mathrm{peak}`          Spike detection threshold for :math:`\Delta_T > 0`
    ``V_reset``                 -60 mV             :math:`V_\mathrm{reset}`         Reset potential
    ``t_ref``                   0 ms               :math:`t_\mathrm{ref}`           Absolute refractory duration
    ``g_L``                     30 nS              :math:`g_\mathrm{L}`             Leak conductance
    ``C_m``                     281 pF             :math:`C_\mathrm{m}`             Membrane capacitance
    ``E_L``                     -70.6 mV           :math:`E_\mathrm{L}`             Leak reversal potential
    ``Delta_T``                 2 mV               :math:`\Delta_T`                 Exponential slope factor
    ``tau_w``                   144 ms             :math:`\tau_w`                   Adaptation time constant
    ``tau_z``                   40 ms              :math:`\tau_z`                   Spike afterpotential time constant
    ``tau_V_th``                50 ms              :math:`\tau_{V_{th}}`            Adaptive threshold time constant
    ``V_th_max``                30.4 mV            :math:`V_{th,\mathrm{max}}`      Threshold value immediately after spike
    ``V_th_rest``               -50.4 mV           :math:`V_{th,\mathrm{rest}}`     Resting threshold value
    ``tau_u_bar_plus``          7 ms               :math:`\tau_{u+}`                Time constant of ``u_bar_plus``
    ``tau_u_bar_minus``         10 ms              :math:`\tau_{u-}`                Time constant of ``u_bar_minus``
    ``tau_u_bar_bar``           500 ms             :math:`\tau_{u\bar{}}`           Time constant of ``u_bar_bar``
    ``a``                       4 nS               :math:`a`                        Subthreshold adaptation strength
    ``b``                       80.5 pA            :math:`b`                        Spike-triggered adaptation increment
    ``I_sp``                    400 pA             :math:`I_{sp}`                   Spike afterpotential current reset value
    ``I_e``                     0 pA               :math:`I_\mathrm{e}`             Constant external current
    ``A_LTD``                   1.4e-4             :math:`A_\mathrm{LTD}`           Clopath depression amplitude
    ``A_LTP``                   8.0e-5             :math:`A_\mathrm{LTP}`           Clopath potentiation amplitude
    ``theta_plus``              -45.3 mV           :math:`\theta_+`                 Clopath potentiation threshold
    ``theta_minus``             -70.6 mV           :math:`\theta_-`                 Clopath depression threshold
    ``A_LTD_const``             ``True``           —                                If False, homeostatic LTD scaling
    ``delay_u_bars``            5 ms               —                                Delay for Clopath u-bar buffers
    ``u_ref_squared``           60                 :math:`u_\mathrm{ref}^2`         Clopath LTD homeostatic reference
    ``gsl_error_tol``           1e-6               —                                RKF45 local error tolerance
    ``t_clamp``                 2 ms               :math:`t_\mathrm{clamp}`         Spike clamping duration
    ``V_clamp``                 33 mV              :math:`V_\mathrm{clamp}`         Clamped voltage after spike
    ``V_initializer``           Constant(-70.6 mV) —                                Membrane voltage initializer
    ``w_initializer``           Constant(0 pA)     —                                Adaptation current initializer
    ``z_initializer``           Constant(0 pA)     —                                Spike afterpotential initializer
    ``V_th_initializer``        Constant(-50.4 mV) —                                Adaptive threshold initializer
    ``u_bar_plus_initializer``  Constant(-70.6 mV) —                                ``u_bar_plus`` initializer
    ``u_bar_minus_initializer`` Constant(-70.6 mV) —                                ``u_bar_minus`` initializer
    ``u_bar_bar_initializer``   Constant(-70.6 mV) —                                ``u_bar_bar`` initializer
    ``spk_fun``                 ReluGrad()         —                                Surrogate spike function
    ``spk_reset``               ``'hard'``         —                                Reset mode (``'hard'`` or ``'soft'``)
    ``ref_var``                 ``False``          —                                If True, expose refractory indicator
    =========================== ================== ================================ ================================================================

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential (mV). Shape: ``(*in_size,)``.
    w : brainstate.HiddenState
        Adaptation current (pA). Shape: ``(*in_size,)``.
    z : brainstate.HiddenState
        Spike afterpotential current (pA). Shape: ``(*in_size,)``.
    V_th : brainstate.HiddenState
        Adaptive threshold (mV). Shape: ``(*in_size,)``.
    u_bar_plus : brainstate.HiddenState
        Clopath low-pass filtered voltage trace (mV). Shape: ``(*in_size,)``.
    u_bar_minus : brainstate.HiddenState
        Clopath low-pass filtered voltage trace (mV). Shape: ``(*in_size,)``.
    u_bar_bar : brainstate.HiddenState
        Clopath second-order filtered voltage trace (mV). Shape: ``(*in_size,)``.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory time steps (int32). Shape: ``(*in_size,)``.
    clamp_step_count : brainstate.ShortTermState
        Remaining clamping time steps (int32). Shape: ``(*in_size,)``.
    integration_step : brainstate.ShortTermState
        Current RKF45 adaptive step size (ms). Shape: ``(*in_size,)``.
    I_stim : brainstate.ShortTermState
        One-step delayed synaptic current (pA). Shape: ``(*in_size,)``.
    delayed_u_bar_plus_buffer : brainstate.ShortTermState
        Ring buffer for delayed ``u_bar_plus`` (mV). Shape: ``(delay_steps, *in_size)``.
    delayed_u_bar_minus_buffer : brainstate.ShortTermState
        Ring buffer for delayed ``u_bar_minus`` (mV). Shape: ``(delay_steps, *in_size)``.
    delayed_u_bars_idx : brainstate.ShortTermState
        Current ring buffer write index (int32). Scalar.
    delayed_u_bars_steps : brainstate.ShortTermState
        Total ring buffer size (int32). Scalar.
    last_spike_time : brainstate.ShortTermState
        Last spike time (ms). Shape: ``(*in_size,)``.
    refractory : brainstate.ShortTermState, optional
        Boolean indicator: True if neuron is refractory or clamped. Only present if ``ref_var=True``.
        Shape: ``(*in_size,)``.

    Raises
    ------
    ValueError

        - If ``V_reset >= V_peak``.
        - If ``Delta_T < 0``.
        - If ``V_th_max < V_th_rest`` or ``V_peak < V_th_rest``.
        - If ``C_m <= 0``, ``t_ref < 0``, ``t_clamp < 0``, or any time constant <= 0.
        - If ``u_ref_squared <= 0`` or ``gsl_error_tol <= 0``.
        - If ``(V_peak - V_th_rest) / Delta_T`` exceeds overflow limit (when ``Delta_T > 0``).
        - If ``delay_u_bars`` maps to fewer than 1 delay buffer entry.
        - If ``delay_u_bars`` is spatially heterogeneous (delay steps must be uniform).
        - If numerical instability is detected during integration (voltage or adaptation out of bounds).

    Notes
    -----
    **Implementation Details:**

    - **RKF45 integration:** Uses adaptive-step Runge-Kutta-Fehlberg 4(5) with error control. Step size
      is persisted across time steps to improve stability. Minimum step size is 1e-8 ms; maximum iteration
      count is 100000 per ``dt`` to prevent infinite loops.

    - **Refractory/clamping precedence:** During integration, if ``clamp_step_count > 0``, voltage is clamped
      to ``V_clamp`` and adaptation dynamics freeze. If ``refractory_step_count > 0`` (and not clamped),
      voltage is clamped to ``V_reset`` but adaptation continues. Both conditions disable spike detection.

    - **Delta input timing:** Delta voltage jumps are applied at the start of each accepted substep, but
      only when the neuron is neither refractory nor clamped. This matches NEST's per-substep spike delivery.

    - **Spike timing convention:** ``last_spike_time`` is set to ``t + dt`` upon spike emission (end of
      current time step), matching NEST's convention.

    - **Clopath buffer bookkeeping:** This implementation maintains delayed ``u_bar_plus`` and ``u_bar_minus``
      buffers even without a dedicated Clopath synapse model, ensuring state-level compatibility with NEST
      for future plasticity extensions. The delayed traces are updated at the end of each ``update()`` call.

    - **Overflow protection:** The exponential term is guarded against overflow when ``Delta_T > 0``. If
      ``(V_peak - V_th_rest) / Delta_T`` would cause ``exp(...)`` to exceed ``max(float64) / 1e20``, an
      error is raised during initialization.
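
    The overflow bound in the last item can be made concrete with a stand-alone sketch on plain
    floats. The real check lives in the ``validate_aeif_overflow`` helper, whose body is not
    shown here; the function ``check_exp_overflow`` below and its error message are hypothetical.

```python
import math
import sys


def check_exp_overflow(V_peak, V_th_rest, Delta_T):
    # Raise ValueError if exp((V_peak - V_th_rest) / Delta_T) could exceed max(float64) / 1e20.
    if Delta_T <= 0.0:
        return  # no exponential term to guard
    max_exp_arg = math.log(sys.float_info.max / 1e20)  # roughly 663.7
    if (V_peak - V_th_rest) / Delta_T >= max_exp_arg:
        raise ValueError('Exponential term may overflow for these parameters.')


check_exp_overflow(33.0, -50.4, 2.0)      # defaults: argument is 41.7, well within range
try:
    check_exp_overflow(33.0, -50.4, 0.1)  # argument is 834, above the bound
except ValueError as err:
    print(err)
```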

    **Usage:**

    This model is designed for voltage-based plasticity studies and detailed spike waveform modeling.
    Use ``Delta_T > 0`` for exponential IF dynamics (rapid spike upstroke) or ``Delta_T = 0`` for standard
    IF with dynamic threshold. The ``t_clamp`` and ``V_clamp`` parameters control the spike overshoot and
    allow modeling realistic action potential shapes. For basic AdEx simulations without Clopath plasticity,
    consider using the simpler ``aeif_psc_delta`` or ``aeif_psc_exp`` models (if available).

    See Also
    --------
    aeif_psc_delta : Simplified AdEx without Clopath traces or clamping.
    aeif_psc_exp : AdEx with exponential postsynaptic currents.
    clopath_synapse : Voltage-based STDP synapse (NEST reference).

    References
    ----------
    .. [1] Clopath C, Büsing L, Vasilaki E, Gerstner W (2010). Connectivity reflects coding: a model of
           voltage-based STDP with homeostasis. *Nature Neuroscience*, 13(3):344-352.
           DOI: https://doi.org/10.1038/nn.2479
    .. [2] Brette R, Gerstner W (2005). Adaptive exponential integrate-and-fire model as an effective
           description of neuronal activity. *Journal of Neurophysiology*, 94:3637-3642.
           DOI: https://doi.org/10.1152/jn.00686.2005
    .. [3] NEST Simulator documentation: ``aeif_psc_delta_clopath`` model.
           https://nest-simulator.readthedocs.io/
    .. [4] NEST source code: ``models/aeif_psc_delta_clopath.h`` and ``models/aeif_psc_delta_clopath.cpp``.

    Examples
    --------
    Simulate a single neuron with step current input:

    .. code-block:: python

       >>> import brainpy.state as bst
       >>> import brainstate as bs
       >>> import saiunit as u
       >>> import matplotlib.pyplot as plt
       >>>
       >>> # Create neuron population
       >>> neuron = bst.aeif_psc_delta_clopath(in_size=1, I_e=300*u.pA)
       >>>
       >>> # Simulate for 100 ms
       >>> with bs.environ.context(dt=0.1*u.ms):
       ...     neuron.init_state()
       ...     times, voltages = [], []
       ...     for t in range(1000):
       ...         spike = neuron.update()
       ...         times.append(t * 0.1)
       ...         voltages.append(float(neuron.V.value / u.mV))
       >>>
       >>> # Plot membrane potential
       >>> plt.plot(times, voltages)
       >>> plt.xlabel('Time (ms)')
       >>> plt.ylabel('Voltage (mV)')
       >>> plt.show()

    Network simulation with delta-function synaptic connections:

    .. code-block:: python

       >>> import brainpy.state as bst
       >>> import brainstate as bs
       >>> import saiunit as u
       >>>
       >>> # Create excitatory and inhibitory populations
       >>> exc = bst.aeif_psc_delta_clopath(in_size=100, I_e=200*u.pA)
       >>> inh = bst.aeif_psc_delta_clopath(in_size=25, I_e=150*u.pA)
       >>>
       >>> # Create delta-function projection (instantaneous voltage jump)
       >>> # Note: Requires appropriate projection class that adds delta inputs
       >>> # exc_to_inh = bst.DeltaProj(exc, inh, weight=0.5*u.mV, prob=0.1)
       >>>
       >>> # Simulate network
       >>> with bs.environ.context(dt=0.1*u.ms):
       ...     exc.init_state()
       ...     inh.init_state()
       ...     for t in range(10000):  # 1 second
       ...         exc_spikes = exc.update()
       ...         inh_spikes = inh.update()
    """

    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 100000

    def __init__(
        self,
        in_size: Size,
        V_peak: ArrayLike = 33.0 * u.mV,
        V_reset: ArrayLike = -60.0 * u.mV,
        t_ref: ArrayLike = 0.0 * u.ms,
        g_L: ArrayLike = 30.0 * u.nS,
        C_m: ArrayLike = 281.0 * u.pF,
        E_L: ArrayLike = -70.6 * u.mV,
        Delta_T: ArrayLike = 2.0 * u.mV,
        tau_w: ArrayLike = 144.0 * u.ms,
        tau_z: ArrayLike = 40.0 * u.ms,
        tau_V_th: ArrayLike = 50.0 * u.ms,
        V_th_max: ArrayLike = 30.4 * u.mV,
        V_th_rest: ArrayLike = -50.4 * u.mV,
        tau_u_bar_plus: ArrayLike = 7.0 * u.ms,
        tau_u_bar_minus: ArrayLike = 10.0 * u.ms,
        tau_u_bar_bar: ArrayLike = 500.0 * u.ms,
        a: ArrayLike = 4.0 * u.nS,
        b: ArrayLike = 80.5 * u.pA,
        I_sp: ArrayLike = 400.0 * u.pA,
        I_e: ArrayLike = 0.0 * u.pA,
        A_LTD: ArrayLike = 14.0e-5,
        A_LTP: ArrayLike = 8.0e-5,
        theta_plus: ArrayLike = -45.3 * u.mV,
        theta_minus: ArrayLike = -70.6 * u.mV,
        A_LTD_const: bool = True,
        delay_u_bars: ArrayLike = 5.0 * u.ms,
        u_ref_squared: ArrayLike = 60.0,
        gsl_error_tol: ArrayLike = 1e-6,
        t_clamp: ArrayLike = 2.0 * u.ms,
        V_clamp: ArrayLike = 33.0 * u.mV,
        V_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        w_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        z_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        V_th_initializer: Callable = braintools.init.Constant(-50.4 * u.mV),
        u_bar_plus_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        u_bar_minus_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        u_bar_bar_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.V_peak = braintools.init.param(V_peak, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.Delta_T = braintools.init.param(Delta_T, self.varshape)
        self.tau_w = braintools.init.param(tau_w, self.varshape)
        self.tau_z = braintools.init.param(tau_z, self.varshape)
        self.tau_V_th = braintools.init.param(tau_V_th, self.varshape)
        self.V_th_max = braintools.init.param(V_th_max, self.varshape)
        self.V_th_rest = braintools.init.param(V_th_rest, self.varshape)
        self.tau_u_bar_plus = braintools.init.param(tau_u_bar_plus, self.varshape)
        self.tau_u_bar_minus = braintools.init.param(tau_u_bar_minus, self.varshape)
        self.tau_u_bar_bar = braintools.init.param(tau_u_bar_bar, self.varshape)
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.I_sp = braintools.init.param(I_sp, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)

        # Clopath-related parameters kept for source-level compatibility.
        self.A_LTD = braintools.init.param(A_LTD, self.varshape)
        self.A_LTP = braintools.init.param(A_LTP, self.varshape)
        self.theta_plus = braintools.init.param(theta_plus, self.varshape)
        self.theta_minus = braintools.init.param(theta_minus, self.varshape)
        self.A_LTD_const = bool(A_LTD_const)
        self.delay_u_bars = braintools.init.param(delay_u_bars, self.varshape)
        self.u_ref_squared = braintools.init.param(u_ref_squared, self.varshape)

        self.gsl_error_tol = gsl_error_tol
        self.t_clamp = braintools.init.param(t_clamp, self.varshape)
        self.V_clamp = braintools.init.param(V_clamp, self.varshape)

        self.V_initializer = V_initializer
        self.w_initializer = w_initializer
        self.z_initializer = z_initializer
        self.V_th_initializer = V_th_initializer
        self.u_bar_plus_initializer = u_bar_plus_initializer
        self.u_bar_minus_initializer = u_bar_minus_initializer
        self.u_bar_bar_initializer = u_bar_bar_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=brainstate.environ.get_dt()
        )

        # Refractory and clamp durations expressed in whole simulation steps.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)
        self.clamp_count = u.math.asarray(u.math.ceil(self.t_clamp / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated,
            or if the exponential term can overflow at spike time for the
            configured ``V_peak``, ``V_th_rest``, and ``Delta_T``.
        """
        v_reset = self.V_reset
        v_peak = self.V_peak
        v_th_rest = self.V_th_rest
        v_th_max = self.V_th_max
        delta_t = self.Delta_T / u.mV

        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (v_reset, v_peak, v_th_rest, v_th_max, delta_t)):
            return

        if np.any(v_reset >= v_peak):
            raise ValueError('Ensure that V_reset < V_peak.')
        if np.any(delta_t < 0.0):
            raise ValueError('Delta_T must be greater than or equal to zero.')
        if np.any(v_th_max < v_th_rest):
            raise ValueError('V_th_max >= V_th_rest required.')
        if np.any(v_peak < v_th_rest):
            raise ValueError('V_peak >= V_th_rest required.')

        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Ensure that C_m > 0')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.t_clamp < 0.0 * u.ms):
            raise ValueError('Ensure that t_clamp >= 0')

        time_constants = (
            self.tau_w,
            self.tau_z,
            self.tau_V_th,
            self.tau_u_bar_plus,
            self.tau_u_bar_minus,
            self.tau_u_bar_bar,
        )
        if any(np.any(tau <= 0.0 * u.ms) for tau in time_constants):
            raise ValueError('All time constants must be strictly positive.')

        if np.any(self.u_ref_squared <= 0.0):
            raise ValueError('Ensure that u_ref_squared > 0')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

        # Mirror NEST overflow guard for exponential term at spike time.
        validate_aeif_overflow(v_peak, v_th_rest, delta_t)

    def init_state(self, **kwargs):
        r"""Initialize all state variables.

        Allocates and initializes all neuron state variables using the configured initializers.
        This includes membrane dynamics states (V, w, z, V_th), Clopath voltage traces
        (u_bar_plus, u_bar_minus, u_bar_bar), refractory/clamping counters, RKF45 integration
        state, and delayed-buffer bookkeeping for Clopath plasticity.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        - ``last_spike_time`` is initialized to -1e7 ms (far in the past) to indicate no prior spike.
        - ``refractory_step_count`` and ``clamp_step_count`` are initialized to 0 (not refractory/clamped).
        - ``integration_step`` is initialized to the current simulation time step (``dt``).
        - Clopath delay buffers are allocated with size ``round(delay_u_bars / dt) + 1``.
        - If ``ref_var=True``, an additional ``refractory`` boolean state is created.

        Raises
        ------
        ValueError
            If an initializer cannot be broadcast to requested shape.
        TypeError
            If initializer outputs have incompatible units/dtypes for the
            corresponding state variables.

        See Also
        --------
        reset_state : Reset existing states to initial values.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        V = braintools.init.param(self.V_initializer, self.varshape)
        w = braintools.init.param(self.w_initializer, self.varshape)
        z = braintools.init.param(self.z_initializer, self.varshape)
        v_th = braintools.init.param(self.V_th_initializer, self.varshape)
        u_plus = braintools.init.param(self.u_bar_plus_initializer, self.varshape)
        u_minus = braintools.init.param(self.u_bar_minus_initializer, self.varshape)
        u_bar = braintools.init.param(self.u_bar_bar_initializer, self.varshape)

        self.V = brainstate.HiddenState(V)
        self.w = brainstate.HiddenState(w)
        self.z = brainstate.HiddenState(z)
        self.V_th = brainstate.HiddenState(v_th)
        self.u_bar_plus = brainstate.HiddenState(u_plus)
        self.u_bar_minus = brainstate.HiddenState(u_minus)
        self.u_bar_bar = brainstate.HiddenState(u_bar)

        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.clamp_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
        self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))

        # Clopath delay buffers
        self._allocate_clopath_delay_buffers(self.varshape, dt)

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def _delay_u_bars_steps(self, dt_q):
        """Compute the number of delay buffer steps for Clopath u-bar traces."""
        dt_ms = float(u.math.asarray(dt_q / u.ms))
        delay_ms = float(u.math.asarray(self.delay_u_bars / u.ms))
        # Round the delay to the nearest whole step; one extra slot holds the current value.
        delay_steps = int(np.rint(delay_ms / dt_ms)) + 1
        if delay_steps < 1:
            raise ValueError('delay_u_bars must map to at least one delay-buffer entry.')
        return delay_steps

    def _allocate_clopath_delay_buffers(self, state_shape, dt_q):
        """Allocate ring buffers for delayed Clopath u-bar traces."""
        delay_steps = self._delay_u_bars_steps(dt_q)
        self._delay_steps = delay_steps  # Python int for JIT-safe modulo
        ditype = brainstate.environ.ditype()
        self.delayed_u_bars_steps = brainstate.ShortTermState(np.asarray(delay_steps, dtype=ditype))
        self.delayed_u_bars_idx = brainstate.ShortTermState(np.asarray(0, dtype=ditype))

        buf_shape = (delay_steps,) + tuple(state_shape)
        dftype = brainstate.environ.dftype()
        self.delayed_u_bar_plus_buffer = brainstate.ShortTermState(np.zeros(buf_shape, dtype=dftype))
        self.delayed_u_bar_minus_buffer = brainstate.ShortTermState(np.zeros(buf_shape, dtype=dftype))

    def get_spike(self, V: ArrayLike = None):
        r"""Compute differentiable spike output using surrogate gradient function.

        Applies the surrogate gradient function to a scaled voltage relative to the dynamic
        threshold. This produces a continuous approximation of discrete spikes, enabling
        gradient-based learning. The scaling factor ``(v_th - V_reset)`` normalizes the
        voltage range for the surrogate function.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential (mV). If None, uses current ``self.V.value``.
            Shape: ``(*in_size,)``.

        Returns
        -------
        spike : ArrayLike
            Differentiable spike signal (dimensionless, approximately in [0, 1] for most
            surrogate functions). Shape matches input ``V``.

        Notes
        -----
        - This method is primarily used during training with surrogate gradient descent.
        - During inference with ``update()``, spikes are detected via hard threshold crossing
          (not this function).
        - The threshold used is the dynamic ``V_th`` (if available) or the resting
          ``V_th_rest`` otherwise.
        - The surrogate function is configured via the ``spk_fun`` parameter
          (default: ``ReluGrad``).

        See Also
        --------
        update : Hard spike detection and state integration.
        """
        V = self.V.value if V is None else V
        if hasattr(self, 'V_th'):
            v_th = self.V_th.value
        else:
            v_th = self.V_th_rest
        v_scaled = (V - v_th) / (v_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, w, z, V_th_adapt, u_plus, u_minus, u_bar -- ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, clamp_r, unstable, i_stim, v_peak_detect --
            mutable auxiliary data carried through the integrator.

        Returns
        -------
        DotDict with same keys as ``state``, containing time derivatives.
        """
        is_refractory = extra.r > 0
        is_clamped = extra.clamp_r > 0

        # Effective voltage: V_clamp if clamped, V_reset if refractory, else min(V, V_peak)
        v_eff = u.math.where(
            is_clamped,
            self.V_clamp,
            u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_peak))
        )

        delta_t_safe = u.math.where(self.Delta_T == 0.0 * u.mV, 1.0 * u.mV, self.Delta_T)
        exp_arg = u.math.clip((v_eff - state.V_th_adapt) / delta_t_safe, -500.0, 500.0)
        i_spike = self.g_L * self.Delta_T * u.math.exp(exp_arg)

        dV_raw = (
            -self.g_L * (v_eff - self.E_L) + i_spike - state.w + state.z + self.I_e + extra.i_stim
        ) / self.C_m
        dV = u.math.where(is_refractory | is_clamped, u.math.zeros_like(dV_raw), dV_raw)

        # NEST sets dw/dt = 0 while clamped, but not during pure refractory.
        dw_raw = (self.a * (v_eff - self.E_L) - state.w) / self.tau_w
        dw = u.math.where(is_clamped, u.math.zeros_like(dw_raw), dw_raw)

        dz = -state.z / self.tau_z
        dV_th_adapt = -(state.V_th_adapt - self.V_th_rest) / self.tau_V_th
        du_plus = (-state.u_plus + v_eff) / self.tau_u_bar_plus
        du_minus = (-state.u_minus + v_eff) / self.tau_u_bar_minus
        du_bar = (-state.u_bar + state.u_minus) / self.tau_u_bar_bar

        return DotDict(
            V=dV, w=dw, z=dz, V_th_adapt=dV_th_adapt,
            u_plus=du_plus, u_minus=du_minus, u_bar=du_bar
        )

    def _event_fn(self, state, extra, accept):
        """In-loop spike detection, reset, clamping, and refractory handling.

        Parameters
        ----------
        state : DotDict
            Keys: V, w, z, V_th_adapt, u_plus, u_minus, u_bar -- ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, clamp_r, unstable, i_stim, v_peak_detect.
        accept : array, bool
            Mask of neurons whose RK substep was accepted.

        Returns
        -------
        (new_state, new_extra)
            DotDicts with updated spike/reset/refractory info.
        """
        unstable = extra.unstable | jnp.any(
            accept & ((state.V < -1e3 * u.mV) | (state.w < -1e6 * u.pA) | (state.w > 1e6 * u.pA))
        )

        new_V = state.V
        new_w = state.w
        new_z = state.z
        new_V_th_adapt = state.V_th_adapt

        # Spike detection: not clamped, not refractory, voltage >= threshold
        spike_now = accept & (extra.clamp_r <= 0) & (extra.r <= 0) & (new_V >= extra.v_peak_detect)
        spike_mask = extra.spike_mask | spike_now

        # Spike-triggered updates
        new_V = u.math.where(spike_now, self.V_clamp, new_V)
        new_w = u.math.where(spike_now, state.w + self.b, new_w)
        new_z = u.math.where(spike_now, self.I_sp, new_z)
        new_V_th_adapt = u.math.where(spike_now, self.V_th_max, new_V_th_adapt)
        clamp_r = u.math.where(spike_now & (self.clamp_count > 0), self.clamp_count + 1, extra.clamp_r)

        # Clamp expiry: clamp_r == 1 means clamping ends this substep -> transition to refractory
        clamp_expiry = accept & (clamp_r == 1)
        new_V = u.math.where(clamp_expiry, self.V_reset, new_V)
        clamp_r = u.math.where(clamp_expiry, 0, clamp_r)
        r = u.math.where(clamp_expiry & (self.ref_count > 0), self.ref_count + 1, extra.r)

        # During refractory (not clamped), clamp voltage to V_reset
        refr_accept = accept & (r > 0) & (clamp_r <= 0)
        new_V = u.math.where(refr_accept, self.V_reset, new_V)

        new_state = DotDict({
            **state,
            'V': new_V, 'w': new_w, 'z': new_z, 'V_th_adapt': new_V_th_adapt
        })
        new_extra = DotDict({
            **extra,
            'spike_mask': spike_mask, 'r': r, 'clamp_r': clamp_r, 'unstable': unstable
        })
        return new_state, new_extra

    def _sum_delta_inputs(self):
        """Sum all delta (instantaneous voltage jump) inputs."""
        delta_v = u.math.zeros_like(self.V.value)
        if self.delta_inputs is None:
            return delta_v
        for key in tuple(self.delta_inputs.keys()):
            out = self.delta_inputs[key]
            if callable(out):
                out = out()
            else:
                self.delta_inputs.pop(key)
            delta_v = delta_v + out
        return delta_v

    def _write_clopath_history(self, V_m, u_plus, u_minus, u_bar):
        """Update Clopath delayed ring buffers with current u-bar traces."""
        ditype = brainstate.environ.ditype()
        idx = self.delayed_u_bars_idx.value

        plus_buf = jnp.asarray(self.delayed_u_bar_plus_buffer.value)
        minus_buf = jnp.asarray(self.delayed_u_bar_minus_buffer.value)

        u_plus_val = u.get_mantissa(u_plus)
        u_minus_val = u.get_mantissa(u_minus)

        plus_buf = plus_buf.at[idx].set(u_plus_val)
        minus_buf = minus_buf.at[idx].set(u_minus_val)

        new_idx = (idx + 1) % self._delay_steps

        self.delayed_u_bar_plus_buffer.value = plus_buf
        self.delayed_u_bar_minus_buffer.value = minus_buf
        self.delayed_u_bars_idx.value = jnp.asarray(new_idx, dtype=ditype)

    def update(self, x=0.0 * u.pA):
        r"""Advance neuron state by one time step using adaptive RKF45 integration.

        Integrates the neuron dynamics over the current simulation time step ``dt`` using an
        adaptive Runge-Kutta-Fehlberg 4(5) solver with local error control. Handles spike
        detection, post-spike reset, refractory period, voltage clamping, delta-function
        synaptic inputs, and Clopath trace updates. Returns binary spike output for the
        current time step.

        Parameters
        ----------
        x : ArrayLike, default: 0.0 * u.pA
            External input current for the current time step (pA). This is combined with
            synaptic currents from ``current_inputs`` dictionary.
            Shape: scalar or broadcastable to ``(*in_size,)``.

        Returns
        -------
        spike : ArrayLike
            Binary spike indicator (1.0 if neuron spiked during this time step, 0.0 otherwise).
            Shape: ``(*in_size,)``.

        Raises
        ------
        ValueError
            If numerical instability is detected (voltage < -1000 mV or
            abs(adaptation) > 1e6 pA).

        Notes
        -----
        Integration is performed with an adaptive vectorized RKF45 loop, including in-loop
        spike/reset/adaptation events and optional multiple spikes per step. All arithmetic
        is unit-aware via ``saiunit.math``.

        See Also
        --------
        init_state : Initialize state variables before first update.
        get_spike : Differentiable spike output for training.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        w = self.w.value  # pA
        z = self.z.value  # pA
        V_th_adapt = self.V_th.value  # mV
        u_plus = self.u_bar_plus.value  # mV
        u_minus = self.u_bar_minus.value  # mV
        u_bar = self.u_bar_bar.value  # mV
        r = self.refractory_step_count.value  # int
        clamp_r = self.clamp_step_count.value  # int
        i_stim = self.I_stim.value  # pA
        h = self.integration_step.value  # ms

        # Spike detection threshold: V_peak if Delta_T > 0, else V_th (dynamic).
        v_peak_detect = u.math.where(self.Delta_T > 0.0 * u.mV, self.V_peak, V_th_adapt)

        # Current input for next step (one-step delay).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # Adaptive RKF45 integration via generic integrator.
        ode_state = DotDict(
            V=V, w=w, z=z, V_th_adapt=V_th_adapt,
            u_plus=u_plus, u_minus=u_minus, u_bar=u_bar
        )
        extra = DotDict(
            spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
            r=r,
            clamp_r=clamp_r,
            unstable=jnp.array(False),
            i_stim=i_stim,
            v_peak_detect=v_peak_detect,
        )
        ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)

        V = ode_state.V
        w = ode_state.w
        z = ode_state.z
        V_th_adapt = ode_state.V_th_adapt
        u_plus = ode_state.u_plus
        u_minus = ode_state.u_minus
        u_bar = ode_state.u_bar
        spike_mask, r, clamp_r, unstable = extra.spike_mask, extra.r, extra.clamp_r, extra.unstable

        # Post-loop stability check.
        brainstate.transform.jit_error_if(
            jnp.any(unstable),
            'Numerical instability in aeif_psc_delta_clopath dynamics.'
        )

        # Clopath delay buffer bookkeeping.
        self._write_clopath_history(V, u_plus, u_minus, u_bar)

        # Decrement counters.
        clamp_r = u.math.where(clamp_r > 0, clamp_r - 1, clamp_r)
        r = u.math.where(r > 0, r - 1, r)

        # Delta inputs (applied after integration).
        delta_v = self._sum_delta_inputs()
        # Only apply delta inputs when not refractory and not clamped.
        can_receive = (r <= 0) & (clamp_r <= 0)
        V = u.math.where(can_receive, V + delta_v, V)

        # Write back state.
        self.V.value = V
        self.w.value = w
        self.z.value = z
        self.V_th.value = V_th_adapt
        self.u_bar_plus.value = u_plus
        self.u_bar_minus.value = u_minus
        self.u_bar_bar.value = u_bar
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        self.clamp_step_count.value = jnp.asarray(u.get_mantissa(clamp_r), dtype=ditype)
        self.integration_step.value = h
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA

        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)

        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(
                (self.refractory_step_count.value > 0) | (self.clamp_step_count.value > 0)
            )

        return u.math.asarray(spike_mask, dtype=dftype)
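

The delay-buffer bookkeeping in the listing (slot count of `round(delay / dt) + 1`, write at the current index, advance modulo the slot count) can be illustrated with a standalone NumPy sketch. `DelayRing` and `read_oldest` are hypothetical names used only for this illustration and are not part of `brainpy_state`; in particular, the listing shows only the write side of the buffers, so the read side here is an assumption about how a delayed value would typically be retrieved.

```python
import numpy as np


def delay_steps(delay_ms: float, dt_ms: float) -> int:
    """Slot count used by the listing: nearest whole step count, plus one."""
    return int(np.rint(delay_ms / dt_ms)) + 1


class DelayRing:
    """Hypothetical fixed-length ring buffer (illustration only)."""

    def __init__(self, n_slots: int, shape=()):
        self.buf = np.zeros((n_slots,) + tuple(shape))
        self.idx = 0  # slot that the next write will overwrite

    def write(self, value):
        # Same pattern as the listing: store, then advance the index modulo n_slots.
        self.buf[self.idx] = value
        self.idx = (self.idx + 1) % self.buf.shape[0]

    def read_oldest(self):
        # Assumed read side: after a write, idx points at the oldest surviving entry.
        return self.buf[self.idx]


n = delay_steps(delay_ms=0.4, dt_ms=0.1)  # 4 steps of delay -> 5 slots
ring = DelayRing(n)
for step in range(1, 8):
    ring.write(float(step))  # store the step number as a stand-in trace value
oldest = ring.read_oldest()
print(n, oldest)
```

With five slots, the oldest entry is always the value written `n - 1 = 4` steps before the most recent one, which is why the extra slot is needed to represent a delay of exactly `round(delay / dt)` steps.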