Source code for brainpy_state._nest.amat2_psc_exp

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer

__all__ = [
    'amat2_psc_exp',
]


class amat2_psc_exp(NESTNeuron):
    r"""NEST-compatible ``amat2_psc_exp`` neuron model.

    Non-resetting leaky integrate-and-fire neuron with exponential postsynaptic
    currents, two-timescale adaptive threshold, and a voltage-dependent threshold
    component that tracks the low-pass filtered membrane potential derivative.

    **Model Overview**

    ``amat2_psc_exp`` extends the ``mat2_psc_exp`` model by adding a voltage-dependent
    threshold component :math:`V_{th,v}` that captures the effect of fast voltage
    fluctuations on spike threshold. This mechanism improves the model's ability to
    reproduce diverse firing patterns observed in biological neurons, including
    burst firing and spike-frequency adaptation.

    The model features:

    - **Non-resetting membrane potential**: After spike emission, the membrane
      potential continues to integrate normally without reset
    - **Exponential PSCs**: Postsynaptic currents decay exponentially with separate
      time constants for excitation and inhibition
    - **Multi-timescale adaptation**: Two independent threshold components (fast and
      slow) capture short-term and long-term adaptation
    - **Voltage-dependent threshold**: A third threshold component tracks the
      low-pass filtered derivative of membrane potential, making the threshold
      sensitive to voltage velocity
    - **Absolute refractory period**: Spike emission is blocked for a fixed duration
      after each spike

    When ``beta = 0``, this model reduces to ``mat2_psc_exp``.

    Mathematical Description
    ------------------------

    **1. Subthreshold Membrane Dynamics**

    The membrane potential evolves according to:

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m}
       + \frac{I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}} + I_e + I_0}{C_m}

    where :math:`V_m` is the absolute membrane potential, :math:`E_L` is the resting
    potential, :math:`\tau_m` is the membrane time constant, :math:`C_m` is the
    membrane capacitance, and :math:`I_{\mathrm{syn,ex}}`, :math:`I_{\mathrm{syn,in}}`,
    :math:`I_e`, and :math:`I_0` are excitatory synaptic, inhibitory synaptic,
    constant external, and dynamic external currents, respectively.
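
    For a constant total input current, the linear equation above has a closed-form
    one-step solution. The minimal NumPy sketch below uses plain floats in mV, ms, pA
    and pF. The helper ``membrane_step`` is hypothetical and not part of this module.
    The sketch applies the one-step solution repeatedly and checks that :math:`V_m`
    relaxes to the steady state :math:`E_L + \tau_m I / C_m`:

```python
import numpy as np


def membrane_step(V, I, E_L=-70.0, tau_m=10.0, C_m=200.0, h=0.1):
    # Hypothetical helper: one exact step of dV/dt = -(V - E_L)/tau_m + I/C_m,
    # assuming I is constant over the step. Units: mV, pA, pF, ms (pA/pF = mV/ms).
    P33 = np.exp(-h / tau_m)          # decay factor of V - E_L over one step
    P30 = tau_m / C_m * (1.0 - P33)   # response to a constant current (mV per pA)
    return E_L + P33 * (V - E_L) + P30 * I


V = -70.0
for _ in range(2000):                 # 200 ms = 20 membrane time constants
    V = membrane_step(V, I=100.0)

V_inf = -70.0 + 10.0 * 100.0 / 200.0  # steady state E_L + tau_m * I / C_m
print(V, V_inf)                       # both close to -65 mV
```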

    **2. Synaptic Current Dynamics**

    Postsynaptic currents decay exponentially:

    .. math::

       \frac{dI_{\mathrm{syn,ex}}}{dt} = -\frac{I_{\mathrm{syn,ex}}}{\tau_{\mathrm{syn,ex}}}

       \frac{dI_{\mathrm{syn,in}}}{dt} = -\frac{I_{\mathrm{syn,in}}}{\tau_{\mathrm{syn,in}}}

    Incoming spikes cause instantaneous jumps in the corresponding current by the
    synaptic weight.
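
    The NumPy sketch below shows one synaptic step. Each current first decays by its
    exact factor :math:`e^{-h/\tau}` and then jumps by the weights arriving in that step.
    Following the input-handling notes of ``update``, the sketch routes positive weights
    to the excitatory current and negative weights to the inhibitory current. The helper
    ``synaptic_step`` and its array layout are hypothetical:

```python
import numpy as np


def synaptic_step(i_ex, i_in, spike_weights, tau_ex=1.0, tau_in=3.0, h=0.1):
    # Hypothetical helper. spike_weights has shape (n_spikes, n_neurons) in pA;
    # positive weights are routed to i_ex, negative ones (kept negative) to i_in.
    w = np.atleast_2d(np.asarray(spike_weights, dtype=float))
    P11, P22 = np.exp(-h / tau_ex), np.exp(-h / tau_in)
    i_ex = P11 * i_ex + np.where(w > 0.0, w, 0.0).sum(axis=0)  # decay, then jump
    i_in = P22 * i_in + np.where(w < 0.0, w, 0.0).sum(axis=0)
    return i_ex, i_in


# Step 1: two spikes arrive at three neurons; step 2: no input, pure decay.
ex1, in1 = synaptic_step(np.zeros(3), np.zeros(3), [[50.0, -20.0, 0.0], [10.0, 0.0, -5.0]])
ex2, in2 = synaptic_step(ex1, in1, np.zeros((1, 3)))
print(ex1, in1)   # i_ex = 60, 0, 0 pA; i_in = 0, -20, -5 pA
print(ex2, in2)   # each value multiplied by its decay factor exp(-h / tau)
```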

    **3. Adaptive Threshold Dynamics**

    The total spike threshold is:

    .. math::

       V_{th}(t) = \omega + V_{th,1}(t) + V_{th,2}(t) + V_{th,v}(t)

    where :math:`\omega` is the resting threshold (an absolute voltage), and
    :math:`V_{th,1}`, :math:`V_{th,2}`, :math:`V_{th,v}` are adaptive components.

    The two time-dependent threshold components decay exponentially:

    .. math::

       \frac{dV_{th,1}}{dt} = -\frac{V_{th,1}}{\tau_1}
       \qquad
       \frac{dV_{th,2}}{dt} = -\frac{V_{th,2}}{\tau_2}

    On each spike emission, these components are incremented:

    .. math::

       V_{th,1} \leftarrow V_{th,1} + \alpha_1
       \qquad
       V_{th,2} \leftarrow V_{th,2} + \alpha_2
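
    The NumPy sketch below applies these two rules to a hypothetical spike train with
    spikes at 10, 15 and 20 ms. It uses the default ``tau_1``, ``tau_2``, ``alpha_1``
    and ``omega``, plus an illustrative ``alpha_2 = 0.5`` mV. On every step the sketch
    first applies the exact decay factor :math:`e^{-h/\tau_i}`. On a spike step it then
    adds the jump. This matches the documented update order, in which the thresholds
    decay before the spike test:

```python
import numpy as np

h, n_steps = 0.1, 1000                      # ms; 100 ms in total
tau_1, tau_2 = 10.0, 200.0                  # ms (defaults)
alpha_1, alpha_2 = 10.0, 0.5                # mV (alpha_2 = 0.5 chosen for illustration)
omega = -65.0                               # resting threshold, mV
spike_steps = {100, 150, 200}               # hypothetical spikes at 10, 15 and 20 ms

P44, P55 = np.exp(-h / tau_1), np.exp(-h / tau_2)
V_th_1 = V_th_2 = 0.0
threshold = np.empty(n_steps)
for k in range(n_steps):
    V_th_1 *= P44                           # exact exponential decay over one step
    V_th_2 *= P55
    if k in spike_steps:                    # spike: jump by the adaptation amplitudes
        V_th_1 += alpha_1
        V_th_2 += alpha_2
    threshold[k] = omega + V_th_1 + V_th_2  # V_th_v omitted here (beta = 0)
```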

    **4. Voltage-Dependent Threshold Component**

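    The voltage-dependent threshold component is defined as in [3]_ (Eqs. 16-17):

    .. math::

       V_{th,v}(t) = \beta \int_0^t K(s) \frac{dV_m}{dt}(t-s)\, ds

    with the kernel

    .. math::

       K(s) = s \exp\!\left(-\frac{s}{\tau_v}\right)

    With :math:`\beta` in 1/ms and :math:`K(s)` in ms, :math:`V_{th,v}` has units of
    voltage. The convolution is implemented with two auxiliary state variables,
    :math:`V_{th,dv}` and :math:`V_{th,v}`, that obey the linear cascade

    .. math::

       \frac{dV_{th,dv}}{dt} = -\frac{V_{th,dv}}{\tau_v} + \beta \frac{dV_m}{dt}
       \qquad
       \frac{dV_{th,v}}{dt} = -\frac{V_{th,v}}{\tau_v} + V_{th,dv}

    The impulse response of this cascade from :math:`\beta\, dV_m/dt` to
    :math:`V_{th,v}` is exactly :math:`K(s)`. Both variables are advanced with the
    exact integration scheme. Their propagator coefficients depend on :math:`\beta`,
    :math:`\tau_v`, :math:`\tau_m`, :math:`\tau_{\mathrm{syn,ex}}`,
    :math:`\tau_{\mathrm{syn,in}}` and :math:`C_m`, and they follow the formulas of the
    NEST implementation (see the ``update`` method for details).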

    **5. Spike Emission and Refractory Period**

    A spike is emitted when:

    .. math::

       V_m \geq V_{th}(t) \quad \text{and} \quad t - t_{\mathrm{last\_spike}} > t_{\mathrm{ref}}

    where :math:`t_{\mathrm{ref}}` is the absolute refractory period. Upon spike
    emission:

    - The threshold components :math:`V_{th,1}` and :math:`V_{th,2}` are incremented
      by :math:`\alpha_1` and :math:`\alpha_2`
    - The refractory step counter is set to :math:`\lceil t_{\mathrm{ref}} / dt \rceil`
    - **The membrane potential is NOT reset** but continues to integrate normally

    **6. Numerical Integration**

    The model uses the exact integration scheme for linear ODEs [1]_. Each time step
    applies closed-form propagators that solve the linear system exactly, assuming
    the input current is constant over the step. The update is therefore stable and
    free of discretization error for any step size. This holds only if the
    constrained pairs of time constants listed under Implementation Notes differ;
    otherwise the propagator expressions divide by zero.
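
    The closed-form propagators can be cross-checked against the matrix exponential
    of the linear system they integrate. The sketch below is a verification aid only,
    not the model's code path. It rests on these assumptions:

    - the state ordering is :math:`(I_{\mathrm{syn,ex}}, V_m - E_L, V_{th,dv}, V_{th,v})`;
    - the coupling is :math:`dV_{th,dv}/dt = -V_{th,dv}/\tau_v + \beta\, dV_m/dt` and
      :math:`dV_{th,v}/dt = -V_{th,v}/\tau_v + V_{th,dv}`;
    - ``expm`` is a plain scaling-and-squaring Taylor series, standing in for
      ``scipy.linalg.expm``;
    - units are ms, pF and 1/ms.

    The sketch compares the result with the closed-form expressions used in
    ``_precompute_constants``:

```python
import numpy as np


def expm(M, terms=30):
    # Matrix exponential via scaling and squaring with a truncated Taylor series.
    norm = np.linalg.norm(M, ord=np.inf)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    A = M / 2.0 ** s
    E, term = np.eye(len(M)), np.eye(len(M))
    for k in range(1, terms):
        term = term @ A / k
        E = E + term
    for _ in range(s):
        E = E @ E
    return E


h, C = 0.1, 200.0                        # ms, pF
tE, tm, tV, beta = 1.0, 10.0, 5.0, 0.5   # ms, ms, ms, 1/ms
eE, em, eV = np.exp(-h / tE), np.exp(-h / tm), np.exp(-h / tV)

# Generator for x = (I_syn_ex, V_m - E_L, V_th_dv, V_th_v); dV_m/dt = -V/tm + I/C.
dV = np.array([1.0 / C, -1.0 / tm, 0.0, 0.0])
A = np.array([
    [-1.0 / tE, 0.0, 0.0, 0.0],
    dV,
    beta * dV + np.array([0.0, 0.0, -1.0 / tV, 0.0]),
    [0.0, 0.0, 1.0, -1.0 / tV],
])
P = expm(A * h)                          # P[i, j]: effect of x_j(t) on x_i(t + h)

closed = {
    'P31': (eE - em) * tE * tm / (C * (tE - tm)),
    'P61': beta * tE * tm * tV * (eV * (tm - tE) + em * (tE - tV) + eE * (tV - tm))
           / (C * (tE - tm) * (tE - tV) * (tm - tV)),
    'P63': beta * (eV - em) * tV / (tm - tV),
    'P71': beta * tE * tm * tV
           * ((em * tm * (tE - tV) ** 2 - eE * tE * (tm - tV) ** 2) * tV
              - eV * (tE - tm) * (h * (tE - tV) * (tm - tV) + tE * tm * tV - tV ** 3))
           / (C * (tE - tm) * (tE - tV) ** 2 * (tm - tV) ** 2),
    'P73': beta * tV * (eV * (h * (tm - tV) + tm * tV) - em * tm * tV) / (tm - tV) ** 2,
    'P76': h * eV,
}
numeric = {'P31': P[1, 0], 'P61': P[2, 0], 'P63': P[2, 1],
           'P71': P[3, 0], 'P73': P[3, 1], 'P76': P[3, 2]}
for name in closed:
    print(name, closed[name], numeric[name])
```

    Each closed-form coefficient matches the corresponding matrix-exponential entry to
    many significant digits. This is what "exact" means here: for an input held
    constant over the step, the only error left is floating-point rounding.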

    **Update Order**

    Each simulation step proceeds as follows (matching NEST's update order):

    1. Evolve voltage-dependent threshold component (``V_th_v``, ``V_th_dv``)
       using exact integration propagators
    2. Evolve membrane potential using exact integration
    3. Decay adaptive threshold components (``V_th_1``, ``V_th_2``)
    4. Decay synaptic currents and add incoming spike weights
    5. Check spike condition: if not refractory and :math:`V_m \geq V_{th}`,
       emit spike, increment threshold components, set refractory counter
    6. If refractory, decrement refractory counter
    7. Store buffered external currents for next step
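
    The seven steps can be condensed into a single-neuron sketch in plain Python
    floats (mV, ms, pA, pF). This is a hypothetical illustration, not the BrainPy
    implementation:

    - ``amat2_step``, the state dict and the propagator dict ``P`` (keyed ``'33'`` for
      :math:`P_{33}`, and so on) are invented here;
    - the membrane potential is kept relative to ``E_L`` so that the propagators apply
      directly;
    - the usage example sets ``beta = 0``, so only the membrane and synaptic
      propagators need to be filled in.

```python
import math


def amat2_step(s, P, prm, w_ex=0.0, w_in=0.0, I_next=0.0):
    # Hypothetical single-neuron step in the documented order. s holds the state
    # (V relative to E_L, thresholds in mV, currents in pA, r = refractory steps);
    # P holds propagators keyed by index ('33' for P33, ...); prm holds I_e,
    # alpha_1, alpha_2, omega_rel (= omega - E_L) and ref_count. Returns the spike flag.
    I = prm['I_e'] + s['i_0']
    # 1. Voltage-dependent threshold, from start-of-step values only.
    V_th_v = (P['70'] * I + P['71'] * s['i_ex'] + P['72'] * s['i_in'] + P['73'] * s['V']
              + P['76'] * s['V_th_dv'] + P['77'] * s['V_th_v'])
    V_th_dv = (P['60'] * I + P['61'] * s['i_ex'] + P['62'] * s['i_in'] + P['63'] * s['V']
               + P['66'] * s['V_th_dv'])
    s['V_th_v'], s['V_th_dv'] = V_th_v, V_th_dv
    # 2. Membrane potential (no reset anywhere in this function).
    s['V'] = P['30'] * I + P['31'] * s['i_ex'] + P['32'] * s['i_in'] + P['33'] * s['V']
    # 3. Decay of the adaptive threshold components.
    s['V_th_1'] *= P['44']
    s['V_th_2'] *= P['55']
    # 4. Synaptic decay plus this step's spike weights.
    s['i_ex'] = P['11'] * s['i_ex'] + w_ex
    s['i_in'] = P['22'] * s['i_in'] + w_in
    # 5./6. Threshold test when not refractory, otherwise count down.
    spike = False
    if s['r'] == 0:
        if s['V'] >= prm['omega_rel'] + s['V_th_1'] + s['V_th_2'] + s['V_th_v']:
            spike = True
            s['V_th_1'] += prm['alpha_1']
            s['V_th_2'] += prm['alpha_2']
            s['r'] = prm['ref_count']
    else:
        s['r'] -= 1
    # 7. Buffer the external current; it takes effect in the next step.
    s['i_0'] = I_next
    return spike


# Usage with beta = 0, so every beta-dependent propagator is zero.
h, tau_m, C_m, tau_ex, tau_in, tau_v = 0.1, 10.0, 200.0, 1.0, 3.0, 5.0
em, eE, eI, eV = (math.exp(-h / t) for t in (tau_m, tau_ex, tau_in, tau_v))
P = {'11': eE, '22': eI, '33': em, '44': math.exp(-h / 10.0), '55': math.exp(-h / 200.0),
     '30': tau_m / C_m * (1.0 - em),
     '31': (eE - em) * tau_ex * tau_m / (C_m * (tau_ex - tau_m)),
     '32': (eI - em) * tau_in * tau_m / (C_m * (tau_in - tau_m)),
     '66': eV, '77': eV, '76': h * eV}
P.update({k: 0.0 for k in ('60', '61', '62', '63', '70', '71', '72', '73')})
prm = {'I_e': 0.0, 'alpha_1': 10.0, 'alpha_2': 0.0, 'omega_rel': 5.0,
       'ref_count': math.ceil(2.0 / h)}
state = dict(V=0.0, V_th_1=0.0, V_th_2=0.0, V_th_v=0.0, V_th_dv=0.0,
             i_ex=0.0, i_in=0.0, i_0=0.0, r=0)
spikes = [amat2_step(state, P, prm, I_next=400.0) for _ in range(2000)]
print(sum(spikes), state['V'])  # repeated spiking; V approaches E_L + 20 mV without reset
```

    Comparing ``V - E_L`` against ``omega - E_L`` mirrors the ``_omega_rel_mV`` constant
    that the class precomputes. Because ``V`` is never reset, adaptation in the sketch
    comes entirely from the growing ``V_th_1``.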

    Implementation Notes
    --------------------

    - All time constants must be strictly positive. In addition, the closed-form
      propagators divide by :math:`\tau_a - \tau_b` for the pairs
      (``tau_m``, ``tau_syn_ex``), (``tau_m``, ``tau_syn_in``), (``tau_m``, ``tau_v``),
      (``tau_v``, ``tau_syn_ex``) and (``tau_v``, ``tau_syn_in``), so the members of
      each pair must differ. ``tau_syn_ex`` and ``tau_syn_in`` may be equal.
    - Numerics may be inaccurate when the members of such a pair are very close
      but not exactly equal. The propagator expressions then subtract nearly equal
      exponentials (catastrophic cancellation).
    - Some parameter values in Table 1 of [4]_ are incorrect; see Table 4 of [5]_
      for corrected values.
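    - The voltage-dependent threshold adds two state variables (``V_th_v`` and
      ``V_th_dv``) and their propagator terms to every update step. With
      ``beta = 0``, all :math:`\beta`-dependent propagators vanish, so both
      variables stay at zero. The model then reproduces ``mat2_psc_exp`` dynamics,
      provided the shared parameters are identical.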
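
    The division by :math:`\tau_a - \tau_b` can be seen on :math:`P_{31}` alone. The
    NumPy sketch below uses a hypothetical helper ``p31`` (units ms, pF). It evaluates
    the closed form from ``_precompute_constants`` in two places:

    - at ``tau_syn_ex == tau_m``, where the expression degenerates to 0/0;
    - at a nearby value, where it agrees with the analytic limit
      :math:`h e^{-h/\tau_m}/C_m`.

    The closed form is therefore undefined exactly at the limit point, which is why
    ``_validate_parameters`` rejects equal pairs:

```python
import numpy as np


def p31(tau_ex, tau_m=10.0, C_m=200.0, h=0.1):
    # Closed-form coupling of I_syn_ex into V_m over one step (mV per pA),
    # written as in _precompute_constants; float64 so 0/0 yields nan, not an error.
    tau_ex, tau_m = np.float64(tau_ex), np.float64(tau_m)
    eE, em = np.exp(-h / tau_ex), np.exp(-h / tau_m)
    return (eE - em) * tau_ex * tau_m / (C_m * (tau_ex - tau_m))


limit = 0.1 * np.exp(-0.1 / 10.0) / 200.0  # analytic limit as tau_syn_ex -> tau_m
with np.errstate(divide='ignore', invalid='ignore'):
    at_equal = p31(10.0)                   # 0/0: the formula breaks down
near = p31(10.0 + 1e-3)                    # 0.01 % apart: close to the limit
print(at_equal, near, limit)
```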

    Parameters
    ----------
    in_size : int, tuple of int
        Population shape (number of neurons or spatial dimensions).
    E_L : Quantity, ndarray
        Resting membrane potential. Default: -70 mV.
    C_m : Quantity, ndarray
        Membrane capacitance. Must be strictly positive. Default: 200 pF.
    tau_m : Quantity, ndarray
        Membrane time constant. Must be strictly positive and differ from
        ``tau_syn_ex``, ``tau_syn_in``, and ``tau_v``. Default: 10 ms.
    t_ref : Quantity, ndarray
        Absolute refractory period (duration of spike emission block).
        Must be strictly positive. Default: 2 ms.
    tau_syn_ex : Quantity, ndarray
        Excitatory postsynaptic current time constant. Must be strictly positive
        and differ from ``tau_m`` and ``tau_v``. Default: 1 ms.
    tau_syn_in : Quantity, ndarray
        Inhibitory postsynaptic current time constant. Must be strictly positive
        and differ from ``tau_m`` and ``tau_v``. Default: 3 ms.
    I_e : Quantity, ndarray
        Constant external input current. Default: 0 pA.
    tau_1 : Quantity, ndarray
        Time constant for short-timescale adaptive threshold component.
        Must be strictly positive. Default: 10 ms.
    tau_2 : Quantity, ndarray
        Time constant for long-timescale adaptive threshold component.
        Must be strictly positive. Default: 200 ms.
    alpha_1 : Quantity, ndarray
        Increment to ``V_th_1`` on each spike (fast adaptation amplitude).
        Default: 10 mV.
    alpha_2 : Quantity, ndarray
        Increment to ``V_th_2`` on each spike (slow adaptation amplitude).
        Default: 0 mV.
    beta : Quantity, ndarray
        Scaling coefficient for voltage-dependent threshold component.
        Units: 1/ms. Set to 0 to disable voltage-dependent threshold and
        recover ``mat2_psc_exp`` behavior. Default: 0 / ms.
    tau_v : Quantity, ndarray
        Time constant for voltage-dependent threshold component. Must be
        strictly positive and differ from ``tau_m``, ``tau_syn_ex``, and
        ``tau_syn_in``. Default: 5 ms.
    omega : Quantity, ndarray
        Resting spike threshold (absolute voltage, not relative to ``E_L``).
        Default: -65 mV.
    V_initializer : Callable, Quantity
        Initializer for membrane potential. Can be a ``braintools.init``
        initializer or a constant value. Default: Constant(-70 mV).
    spk_fun : Callable
        Surrogate gradient function for differentiable spike generation.
        Default: ``braintools.surrogate.ReluGrad()``.
    spk_reset : str
        Reset mode for surrogate gradient computation. Options: ``'soft'``
        (subtract threshold) or ``'hard'`` (stop gradient). Note: this does
        NOT affect the membrane potential dynamics (no reset occurs). It only
        affects gradient flow through the spike function. Default: ``'hard'``.
    ref_var : bool
        If True, expose a boolean ``refractory`` state variable indicating
        whether each neuron is currently in the refractory period.
        Default: False.
    name : str, optional
        Name of the neuron population.

    Parameter Mapping
    -----------------

    The following table maps BrainPy parameter names to their mathematical symbols
    and NEST equivalents:

    ==================== ================== =============================== ==========================================================
    **Parameter**        **Default**        **Math equivalent**             **Description**
    ==================== ================== =============================== ==========================================================
    ``in_size``          (required)                                         Population shape
    ``E_L``              -70 mV             :math:`E_L`                     Resting membrane potential
    ``C_m``              200 pF             :math:`C_m`                     Membrane capacitance
    ``tau_m``            10 ms              :math:`\tau_m`                  Membrane time constant
    ``t_ref``            2 ms               :math:`t_{ref}`                 Duration of absolute refractory period (no spiking)
    ``tau_syn_ex``       1 ms               :math:`\tau_{\mathrm{syn,ex}}`  Time constant of excitatory postsynaptic current
    ``tau_syn_in``       3 ms               :math:`\tau_{\mathrm{syn,in}}`  Time constant of inhibitory postsynaptic current
    ``I_e``              0 pA               :math:`I_e`                     Constant external input current
    ``tau_1``            10 ms              :math:`\tau_1`                  Short time constant of adaptive threshold
    ``tau_2``            200 ms             :math:`\tau_2`                  Long time constant of adaptive threshold
    ``alpha_1``          10 mV              :math:`\alpha_1`                Amplitude of short-timescale threshold adaptation
    ``alpha_2``          0 mV               :math:`\alpha_2`                Amplitude of long-timescale threshold adaptation
    ``beta``             0 1/ms             :math:`\beta`                   Scaling coefficient for voltage-dependent threshold
    ``tau_v``            5 ms               :math:`\tau_v`                  Time constant for voltage-dependent threshold component
    ``omega``            -65 mV             :math:`\omega`                  Resting spike threshold (absolute value, not relative to E_L)
    ``V_initializer``    Constant(-70 mV)                                   Membrane potential initializer
    ``spk_fun``          ReluGrad()                                         Surrogate spike function
    ``spk_reset``        ``'hard'``                                         Surrogate-gradient reset mode (does not reset ``V``)
    ``ref_var``          ``False``                                          If True, expose boolean refractory state
    ==================== ================== =============================== ==========================================================

    State Variables
    ---------------

    ========================= ===================== ====================================================
    **Variable**              **Type**              **Description**
    ========================= ===================== ====================================================
    ``V``                     ``HiddenState`` (mV)  Membrane potential (absolute)
    ``V_th_1``                ``ShortTermState``    Short-timescale adaptive threshold component (mV, relative to omega)
    ``V_th_2``                ``ShortTermState``    Long-timescale adaptive threshold component (mV, relative to omega)
    ``V_th_v``                ``ShortTermState``    Voltage-dependent threshold component (mV)
    ``V_th_dv``               ``ShortTermState``    Auxiliary variable driving ``V_th_v`` (filtered voltage derivative)
    ``i_syn_ex``              ``ShortTermState``    Excitatory postsynaptic current (pA)
    ``i_syn_in``              ``ShortTermState``    Inhibitory postsynaptic current (pA)
    ``i_0``                   ``ShortTermState``    External input current buffered from the previous step (pA)
    ``refractory_step_count`` ``ShortTermState``    Refractory countdown (integer steps)
    ``last_spike_time``       ``ShortTermState``    Time of last spike (ms)
    ``refractory``            ``ShortTermState``    Boolean refractory state (only if ``ref_var=True``)
    ========================= ===================== ====================================================

    Raises
    ------
    ValueError
        If ``C_m <= 0``.
    ValueError
        If any time constant is non-positive.
    ValueError
        If ``tau_m`` equals ``tau_syn_ex``, ``tau_syn_in``, or ``tau_v``
        (exact integration propagators become singular).
    ValueError
        If ``tau_v`` equals ``tau_syn_ex`` or ``tau_syn_in``
        (exact integration propagators become singular).

    Examples
    --------
    **Basic usage with voltage-dependent threshold:**

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> import brainstate as bstate
        >>>
        >>> # dt must be set before construction: __init__ converts t_ref into a
        >>> # number of steps using the current dt.
        >>> with bstate.environ.context(dt=0.1 * u.ms):
        ...     neuron = bst.amat2_psc_exp(
        ...         in_size=100,
        ...         beta=0.5 / u.ms,  # enable voltage-dependent threshold
        ...         tau_v=5.0 * u.ms,
        ...         alpha_1=10.0 * u.mV,
        ...         alpha_2=0.5 * u.mV,
        ...     )
        ...     neuron.init_all_states()  # initialize states
        ...     spike = neuron.update(x=500.0 * u.pA)  # apply input current

    **Comparing with mat2_psc_exp (beta=0):**

    .. code-block:: python

        >>> # With beta=0 the voltage-dependent term vanishes, so AMAT2 follows
        >>> # MAT2 dynamics once all shared parameters are set to the same values
        >>> # (the default parameter sets of the two classes need not coincide).
        >>> with bstate.environ.context(dt=0.1 * u.ms):
        ...     amat2 = bst.amat2_psc_exp(in_size=10, beta=0.0 / u.ms)
        ...     mat2 = bst.mat2_psc_exp(in_size=10)
        ...     amat2.init_all_states()
        ...     mat2.init_all_states()

    **Simulating burst firing with strong voltage-dependent threshold:**

    .. code-block:: python

        >>> with bstate.environ.context(dt=0.1 * u.ms):
        ...     neuron = bst.amat2_psc_exp(
        ...         in_size=1,
        ...         beta=1.0 / u.ms,  # strong voltage dependence
        ...         tau_v=2.0 * u.ms,  # fast voltage tracking; must differ from tau_syn_in (3 ms)
        ...         alpha_1=15.0 * u.mV,  # strong fast adaptation
        ...         tau_1=5.0 * u.ms,
        ...     )
        ...     neuron.init_all_states()
        ...     spikes = []
        ...     for _ in range(1000):
        ...         spk = neuron.update(x=600.0 * u.pA)
        ...         spikes.append(spk)

    References
    ----------
    .. [1] Rotter S and Diesmann M (1999). Exact simulation of
           time-invariant linear systems with applications to neuronal
           modeling. Biological Cybernetics 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State
           space analysis of synchronous spiking in cortical neural
           networks. Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [3] Kobayashi R, Tsubo Y and Shinomoto S (2009). Made-to-order
           spiking neuron model equipped with a multi-timescale adaptive
           threshold. Frontiers in Computational Neuroscience 3:9.
           DOI: https://doi.org/10.3389/neuro.10.009.2009
    .. [4] Yamauchi S, Kim H, Shinomoto S (2011). Elemental spiking neuron
           model for reproducing diverse firing patterns and predicting precise
           firing times. Frontiers in Computational Neuroscience 5:42.
           DOI: https://doi.org/10.3389/fncom.2011.00042
    .. [5] Heiberg T, Kriener B, Tetzlaff T, Einevoll GT, Plesser HE (2018).
           Firing-rate model for neurons with a broad repertoire of spiking
           behaviors. J Comput Neurosci 45:103.
           DOI: https://doi.org/10.1007/s10827-018-0693-9

    See Also
    --------
    mat2_psc_exp : Same model without voltage-dependent threshold component.
    aeif_psc_exp : Adaptive exponential integrate-and-fire with spike reset.
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 200. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        tau_syn_ex: ArrayLike = 1. * u.ms,
        tau_syn_in: ArrayLike = 3. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        tau_1: ArrayLike = 10. * u.ms,
        tau_2: ArrayLike = 200. * u.ms,
        alpha_1: ArrayLike = 10. * u.mV,
        alpha_2: ArrayLike = 0. * u.mV,
        beta: ArrayLike = 0. / u.ms,
        tau_v: ArrayLike = 5. * u.ms,
        omega: ArrayLike = -65. * u.mV,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.tau_1 = braintools.init.param(tau_1, self.varshape)
        self.tau_2 = braintools.init.param(tau_2, self.varshape)
        self.alpha_1 = braintools.init.param(alpha_1, self.varshape)
        self.alpha_2 = braintools.init.param(alpha_2, self.varshape)
        self.beta = braintools.init.param(beta, self.varshape)
        self.tau_v = braintools.init.param(tau_v, self.varshape)
        self.omega = braintools.init.param(omega, self.varshape)

        self.V_initializer = V_initializer
        self.ref_var = ref_var
        self._validate_parameters()

        # Precompute refractory step count
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    @staticmethod
    def _to_numpy(x, unit):
        r"""Convert a quantity to a plain NumPy array in specified units.

        Parameters
        ----------
        x : Quantity, ndarray
            Input value with units.
        unit : Quantity
            Target unit for conversion.

        Returns
        -------
        ndarray
            NumPy array of dtype ``brainstate.environ.dftype()`` with units stripped.
        """
        dftype = brainstate.environ.dftype()
        return np.asarray(u.math.asarray(x / unit), dtype=dftype)

    @staticmethod
    def _broadcast_to_state(x_np: np.ndarray, shape):
        r"""Broadcast a parameter array to match state variable shape.

        Parameters
        ----------
        x_np : ndarray
            Parameter array (plain NumPy, no units).
        shape : tuple
            Target shape for broadcasting.

        Returns
        -------
        ndarray
            Broadcasted array with shape matching ``shape``.
        """
        return np.broadcast_to(x_np, shape)

    def _validate_parameters(self):
        r"""Validate model parameters for physical and numerical constraints.

        This method checks that:
        - Capacitance is strictly positive
        - All time constants are strictly positive
        - Time constants are pairwise distinct (required for exact integration)

        Raises
        ------
        ValueError
            If ``C_m <= 0``.
        ValueError
            If any time constant (``tau_m``, ``tau_syn_ex``, ``tau_syn_in``,
            ``tau_1``, ``tau_2``, ``tau_v``, ``t_ref``) is non-positive.
        ValueError
            If ``tau_m`` equals ``tau_syn_ex``, ``tau_syn_in``, or ``tau_v``
            (causes singularities in propagator matrix).
        ValueError
            If ``tau_v`` equals ``tau_syn_ex`` or ``tau_syn_in``
            (causes singularities in propagator matrix).
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.C_m, self.tau_m)):
            return

        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        tau_m_val = self.tau_m
        tau_ex_val = self.tau_syn_ex
        tau_in_val = self.tau_syn_in
        tau_v_val = self.tau_v
        if np.any(tau_m_val <= 0.0 * u.ms) or np.any(tau_ex_val <= 0.0 * u.ms) or np.any(tau_in_val <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.t_ref <= 0.0 * u.ms):
            raise ValueError('Refractory time must be strictly positive.')
        if np.any(self.tau_1 <= 0.0 * u.ms) or np.any(self.tau_2 <= 0.0 * u.ms):
            raise ValueError('Adaptive threshold time constants must be strictly positive.')
        if np.any(tau_v_val <= 0.0 * u.ms):
            raise ValueError('tau_v must be strictly positive.')
        if np.any(tau_m_val == tau_ex_val) or np.any(tau_m_val == tau_in_val) or np.any(tau_m_val == tau_v_val):
            raise ValueError(
                'tau_m must differ from tau_syn_ex, tau_syn_in and tau_v. '
                'See note in documentation.'
            )
        if np.any(tau_v_val == tau_ex_val) or np.any(tau_v_val == tau_in_val):
            raise ValueError(
                'tau_v must differ from tau_syn_ex, tau_syn_in and tau_m. '
                'See note in documentation.'
            )

    def _precompute_constants(self):
        r"""Pre-compute static propagator coefficients from model parameters.

        All propagators depend only on fixed parameters and ``dt``, so they are
        computed once here and stored as JAX arrays for use in every ``update()``
        call.  This avoids re-running ``np.exp`` and the full propagator algebra on
        every time step and makes ``update()`` fully JIT-compatible.
        """
        dftype = brainstate.environ.dftype()
        dt_q = brainstate.environ.get_dt()
        h = float(u.get_mantissa(dt_q / u.ms))

        # Extract parameters as plain NumPy arrays (default float dtype)
        taum = self._to_numpy(self.tau_m, u.ms)
        tauE = self._to_numpy(self.tau_syn_ex, u.ms)
        tauI = self._to_numpy(self.tau_syn_in, u.ms)
        tauV = self._to_numpy(self.tau_v, u.ms)
        c = self._to_numpy(self.C_m, u.pF)
        tau_1 = self._to_numpy(self.tau_1, u.ms)
        tau_2 = self._to_numpy(self.tau_2, u.ms)
        beta = self._to_numpy(self.beta, 1.0 / u.ms)

        eE = np.exp(-h / tauE)
        eI = np.exp(-h / tauI)
        em = np.exp(-h / taum)
        e1 = np.exp(-h / tau_1)
        e2 = np.exp(-h / tau_2)
        eV = np.exp(-h / tauV)

        P30 = (taum - em * taum) / c
        P31 = ((eE - em) * tauE * taum) / (c * (tauE - taum))
        P32 = ((eI - em) * tauI * taum) / (c * (tauI - taum))

        P60 = (beta * (em - eV) * taum * tauV) / (c * (taum - tauV))
        P61 = (beta * tauE * taum * tauV * (eV * (-tauE + taum) + em * (tauE - tauV) + eE * (-taum + tauV))) \
              / (c * (tauE - taum) * (tauE - tauV) * (taum - tauV))
        P62 = (beta * tauI * taum * tauV * (eV * (-tauI + taum) + em * (tauI - tauV) + eI * (-taum + tauV))) \
              / (c * (tauI - taum) * (tauI - tauV) * (taum - tauV))
        P63 = (beta * (-em + eV) * tauV) / (taum - tauV)

        P70 = (beta * taum * tauV * (em * taum * tauV - eV * (h * (taum - tauV) + taum * tauV))) \
              / (c * (taum - tauV) ** 2)
        P71 = (beta * tauE * taum * tauV
               * ((em * taum * (tauE - tauV) ** 2 - eE * tauE * (taum - tauV) ** 2) * tauV
                  - eV * (tauE - taum)
                  * (h * (tauE - tauV) * (taum - tauV) + tauE * taum * tauV - tauV ** 3))) \
              / (c * (tauE - taum) * (tauE - tauV) ** 2 * (taum - tauV) ** 2)
        P72 = (beta * tauI * taum * tauV
               * ((em * taum * (tauI - tauV) ** 2 - eI * tauI * (taum - tauV) ** 2) * tauV
                  - eV * (tauI - taum)
                  * (h * (tauI - tauV) * (taum - tauV) + tauI * taum * tauV - tauV ** 3))) \
              / (c * (tauI - taum) * (tauI - tauV) ** 2 * (taum - tauV) ** 2)
        P73 = (beta * tauV * (-(em * taum * tauV) + eV * (h * (taum - tauV) + taum * tauV))) \
              / (taum - tauV) ** 2
        P76 = eV * h

        def _j(arr):
            return jnp.asarray(arr, dtype=dftype)

        self._P11 = _j(eE)
        self._P22 = _j(eI)
        self._P33 = _j(em)
        self._P44 = _j(e1)
        self._P55 = _j(e2)
        self._P66 = _j(eV)
        self._P77 = _j(eV)
        self._P30 = _j(P30)
        self._P31 = _j(P31)
        self._P32 = _j(P32)
        self._P60 = _j(P60)
        self._P61 = _j(P61)
        self._P62 = _j(P62)
        self._P63 = _j(P63)
        self._P70 = _j(P70)
        self._P71 = _j(P71)
        self._P72 = _j(P72)
        self._P73 = _j(P73)
        self._P76 = _j(P76)

        # Pre-extract scalar / per-neuron constants used every step
        self._E_L_mV = _j(self._to_numpy(self.E_L, u.mV))
        self._I_e_pA = _j(self._to_numpy(self.I_e, u.pA))
        self._alpha_1_mV = _j(self._to_numpy(self.alpha_1, u.mV))
        self._alpha_2_mV = _j(self._to_numpy(self.alpha_2, u.mV))
        self._omega_rel_mV = _j(self._to_numpy(self.omega - self.E_L, u.mV))

    def init_state(self, **kwargs):
        r"""Initialize all state variables.

        Creates and initializes all state variables, including membrane potential,
        adaptive threshold components, voltage-dependent threshold components,
        synaptic currents, and refractory state.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        State variables initialized:

        - ``V``: Membrane potential (from ``V_initializer``)
        - ``V_th_1``, ``V_th_2``: Adaptive threshold components (zero)
        - ``V_th_v``, ``V_th_dv``: Voltage-dependent threshold components (zero)
        - ``i_syn_ex``, ``i_syn_in``: Synaptic currents (zero)
        - ``i_0``: External current buffer (zero)
        - ``refractory_step_count``: Refractory counter (zero, not refractory)
        - ``last_spike_time``: Last spike time (large negative value)
        - ``refractory`` (if ``ref_var=True``): Boolean refractory state (False)
        """
        ditype = brainstate.environ.ditype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        zeros = u.math.zeros_like(u.math.asarray(V / u.mV))
        self.V = brainstate.HiddenState(V)
        self.V_th_1 = brainstate.ShortTermState(zeros * u.mV)
        self.V_th_2 = brainstate.ShortTermState(zeros * u.mV)
        self.V_th_v = brainstate.ShortTermState(zeros * u.mV)
        self.V_th_dv = brainstate.ShortTermState(zeros * u.mV)
        self.i_syn_ex = brainstate.ShortTermState(zeros * u.pA)
        self.i_syn_in = brainstate.ShortTermState(zeros * u.pA)
        self.i_0 = brainstate.ShortTermState(zeros * u.pA)
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)
        self._precompute_constants()

    def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None):
        r"""Compute spike output using surrogate gradient function.

        Applies the surrogate gradient function to the scaled distance between
        membrane potential and adaptive threshold. This enables differentiable
        spike generation for gradient-based learning.

        Parameters
        ----------
        V : Quantity, ndarray, optional
            Membrane potential (absolute voltage). If None, uses current
            ``self.V.value``. Shape: ``(*varshape,)`` or ``(batch_size, *varshape)``.
        V_th : Quantity, ndarray, optional
            Total spike threshold (absolute voltage). If None, computed as
            ``omega + V_th_1 + V_th_2 + V_th_v``. Shape: same as ``V``.

        Returns
        -------
        ndarray
            Spike output (pseudo-probability in [0, 1] from surrogate function).
            Shape: same as ``V``.

        Notes
        -----
        The spike function is applied to the scaled voltage distance:

        .. math::

           s = \mathrm{spk\_fun}\left(\frac{V - V_{th}}{|\omega - E_L|}\right)

        The scaling factor ``|omega - E_L|`` normalizes the voltage distance to
        the typical threshold range, improving numerical stability across
        different parameter regimes.
        """
        V = self.V.value if V is None else V
        if V_th is None:
            V_th = self.omega + self.V_th_1.value + self.V_th_2.value + self.V_th_v.value
        v_scaled = (V - V_th) / u.math.abs(self.omega - self.E_L)
        return self.spk_fun(v_scaled)
    def update(self, x=0. * u.pA):
        r"""Perform one simulation time step.

        Integrates the membrane potential, synaptic currents, and adaptive
        threshold components for one time step using exact integration. It then
        detects spike emission and updates the refractory state. The membrane
        potential is NOT reset after spikes.

        This method follows the NEST update order for ``amat2_psc_exp``:

        1. Evolve the voltage-dependent threshold components (``V_th_v``,
           ``V_th_dv``) using exact propagators that depend on all synaptic and
           membrane currents.
        2. Evolve the membrane potential using exact integration.
        3. Decay the time-dependent threshold components (``V_th_1``,
           ``V_th_2``).
        4. Decay the synaptic currents and add incoming spike weights.
        5. Detect spikes: if not refractory and
           ``V >= omega + V_th_1 + V_th_2 + V_th_v``:

           - increment the threshold components by ``alpha_1`` and ``alpha_2``;
           - set the refractory counter to ``ceil(t_ref / dt)``;
           - record the spike time.

        6. If refractory, decrement the refractory counter.
        7. Buffer external currents for the next step.

        Parameters
        ----------
        x : Quantity, ndarray, optional
            External input current for the current time step. This current is
            buffered and applied in the NEXT time step (one-step delay,
            following the NEST convention). Shape: scalar, ``(*varshape,)``, or
            ``(batch_size, *varshape)``. Default: 0 pA.

        Returns
        -------
        ndarray
            Spike output from the surrogate gradient function. Values in [0, 1]
            represent pseudo-spike probabilities. Actual spike detection (for
            the threshold increment and refractory period) uses a hard threshold
            crossing. Shape: same as the state variables.

        Notes
        -----
        **Exact Integration Propagators**

        The model uses closed-form propagators for linear ODEs [1]_. For a
        single time step of size ``h``, the propagators are:

        **Independent exponential decays:**

        .. math::

            P_{11} &= e^{-h/\tau_{syn,ex}} \quad (i\_syn\_ex) \\
            P_{22} &= e^{-h/\tau_{syn,in}} \quad (i\_syn\_in) \\
            P_{33} &= e^{-h/\tau_m} \quad (V_m) \\
            P_{44} &= e^{-h/\tau_1} \quad (V_{th,1}) \\
            P_{55} &= e^{-h/\tau_2} \quad (V_{th,2}) \\
            P_{66} &= e^{-h/\tau_v} \quad (V_{th,dv}) \\
            P_{77} &= e^{-h/\tau_v} \quad (V_{th,v})

        **Membrane potential coupling to currents:**

        .. math::

            P_{30} &= \frac{\tau_m}{C_m}(1 - e^{-h/\tau_m}) \\
            P_{31} &= \frac{\tau_m \tau_{syn,ex}}{C_m(\tau_{syn,ex} - \tau_m)}
                      (e^{-h/\tau_{syn,ex}} - e^{-h/\tau_m}) \\
            P_{32} &= \frac{\tau_m \tau_{syn,in}}{C_m(\tau_{syn,in} - \tau_m)}
                      (e^{-h/\tau_{syn,in}} - e^{-h/\tau_m})

        **Voltage-dependent threshold (derivative component, ``V_th_dv``):**

        .. math::

            P_{60} &= \frac{\beta \tau_m \tau_v}{C_m(\tau_m - \tau_v)}
                      (e^{-h/\tau_m} - e^{-h/\tau_v}) \\
            P_{61} &= \frac{\beta \tau_{syn,ex} \tau_m \tau_v}
                      {C_m(\tau_{syn,ex}-\tau_m)(\tau_{syn,ex}-\tau_v)(\tau_m-\tau_v)} \times \\
                   &\quad \left(e^{-h/\tau_v}(-\tau_{syn,ex}+\tau_m)
                      + e^{-h/\tau_m}(\tau_{syn,ex}-\tau_v)
                      + e^{-h/\tau_{syn,ex}}(-\tau_m+\tau_v)\right) \\
            P_{62} &= P_{61}\big|_{\tau_{syn,ex} \to \tau_{syn,in}} \\
            P_{63} &= \frac{\beta \tau_v}{\tau_m - \tau_v}
                      (e^{-h/\tau_v} - e^{-h/\tau_m})

        Here :math:`P_{62}` is :math:`P_{61}` with :math:`\tau_{syn,ex}`
        replaced by :math:`\tau_{syn,in}`.

        **Voltage-dependent threshold (integrated component, ``V_th_v``):**

        .. math::

            P_{70} &= \frac{\beta \tau_m \tau_v}{C_m(\tau_m-\tau_v)^2}
                      \left(e^{-h/\tau_m} \tau_m \tau_v
                      - e^{-h/\tau_v}(h(\tau_m-\tau_v) + \tau_m \tau_v)\right) \\
            P_{71} &= \text{[complex expression, see code]} \\
            P_{72} &= \text{[complex expression, see code]} \\
            P_{73} &= \frac{\beta \tau_v}{(\tau_m-\tau_v)^2}
                      \left(e^{-h/\tau_v}(h(\tau_m-\tau_v)+\tau_m\tau_v)
                      - e^{-h/\tau_m}\tau_m\tau_v\right) \\
            P_{76} &= h e^{-h/\tau_v}

        These propagators depend only on the parameters and on ``h``. They are
        precomputed (stored as ``self._P11`` through ``self._P77``) rather than
        evaluated inside this method. They are applied element-wise, so
        spatially varying parameters (for example, different time constants for
        different neurons) are supported.
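        As a numerical sanity check on the propagators above, this sketch
        compares the exact one-step update of :math:`V_m` and
        :math:`V_{th,dv}` against a fine forward-Euler integration. It assumes
        that these propagators solve
        :math:`dV_m/dt = -V_m/\tau_m + (I + I_{syn,ex})/C_m`,
        :math:`dI_{syn,ex}/dt = -I_{syn,ex}/\tau_{syn,ex}`, and
        :math:`dV_{th,dv}/dt = -V_{th,dv}/\tau_v + \beta\, dV_m/dt`, with
        :math:`V_m` measured relative to :math:`E_L`. The inhibitory terms
        are omitted because they mirror the excitatory ones. The parameter
        values and the helper ``fine_euler`` are hypothetical and serve only
        as an illustration.

        ```python
        import numpy as np

        # Hypothetical parameters (ms, pF, dimensionless beta), for illustration
        # only; all time constants are distinct to avoid singular propagators.
        tau_m, tau_ex, tau_v, C_m, beta, h = 10.0, 2.0, 5.0, 200.0, 0.5, 0.1
        e_m, e_ex, e_v = np.exp(-h / tau_m), np.exp(-h / tau_ex), np.exp(-h / tau_v)

        # Membrane propagators, as written above.
        P30 = tau_m / C_m * (1.0 - e_m)
        P31 = tau_m * tau_ex / (C_m * (tau_ex - tau_m)) * (e_ex - e_m)
        P33 = e_m

        # Derivative-threshold propagators, as written above.
        P60 = beta * tau_m * tau_v / (C_m * (tau_m - tau_v)) * (e_m - e_v)
        P61 = (beta * tau_ex * tau_m * tau_v
               / (C_m * (tau_ex - tau_m) * (tau_ex - tau_v) * (tau_m - tau_v))
               * (e_v * (tau_m - tau_ex) + e_m * (tau_ex - tau_v) + e_ex * (tau_v - tau_m)))
        P63 = beta * tau_v / (tau_m - tau_v) * (e_v - e_m)
        P66 = e_v


        def fine_euler(V, I_ex, th_dv, I_const, n_sub=10_000):
            # Hypothetical reference integrator: forward Euler over one step h
            # for the assumed ODEs
            #   dV/dt     = -V / tau_m + (I_const + I_ex) / C_m
            #   dI_ex/dt  = -I_ex / tau_ex
            #   dth_dv/dt = -th_dv / tau_v + beta * dV/dt
            dt = h / n_sub
            for _ in range(n_sub):
                dV = -V / tau_m + (I_const + I_ex) / C_m
                d_th_dv = -th_dv / tau_v + beta * dV
                V, I_ex, th_dv = V + dt * dV, I_ex - dt * I_ex / tau_ex, th_dv + dt * d_th_dv
            return V, th_dv


        # Initial state: V relative to E_L (mV), I_ex (pA), V_th_dv (mV); the
        # constant current I_const stands in for I_e + I_0 (pA).
        V0, I_ex0, th_dv0, I_const = 5.0, 150.0, 0.2, 300.0

        V_exact = P30 * I_const + P31 * I_ex0 + P33 * V0
        th_dv_exact = P60 * I_const + P61 * I_ex0 + P63 * V0 + P66 * th_dv0
        V_ref, th_dv_ref = fine_euler(V0, I_ex0, th_dv0, I_const)

        print(f"V_m:     exact={V_exact:.6f}  euler={V_ref:.6f}")
        print(f"V_th_dv: exact={th_dv_exact:.6f}  euler={th_dv_ref:.6f}")
        ```

        The two columns should agree to within the Euler discretization error.
        Because the exact update is linear, the same superposition extends to
        the remaining propagators.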
        **Update Equations**

        The state update proceeds as follows. All right-hand sides use the
        state at the start of the step:

        .. math::

            V_{th,v}^{new} &= P_{70} (I_e + I_0) + P_{71} I_{syn,ex} + P_{72} I_{syn,in}
                              + P_{73} V_m + P_{76} V_{th,dv} + P_{77} V_{th,v} \\
            V_{th,dv}^{new} &= P_{60} (I_e + I_0) + P_{61} I_{syn,ex} + P_{62} I_{syn,in}
                               + P_{63} V_m + P_{66} V_{th,dv} \\
            V_m^{new} &= P_{30} (I_e + I_0) + P_{31} I_{syn,ex} + P_{32} I_{syn,in}
                         + P_{33} V_m \\
            V_{th,1}^{new} &= P_{44} V_{th,1} \\
            V_{th,2}^{new} &= P_{55} V_{th,2} \\
            I_{syn,ex}^{new} &= P_{11} I_{syn,ex} + \Delta I_{ex} \\
            I_{syn,in}^{new} &= P_{22} I_{syn,in} + \Delta I_{in}

        Here :math:`\Delta I_{ex}` and :math:`\Delta I_{in}` are the summed
        weights of the excitatory and inhibitory spikes arriving in the current
        step.

        **Spike Detection and Threshold Increment**

        A spike is detected when

        .. math::

            V_m \geq \omega + V_{th,1} + V_{th,2} + V_{th,v}
            \quad \text{and} \quad r = 0,

        where :math:`r` is the refractory counter. On spike detection:

        .. math::

            V_{th,1} &\leftarrow V_{th,1} + \alpha_1 \\
            V_{th,2} &\leftarrow V_{th,2} + \alpha_2 \\
            r &\leftarrow \lceil t_{ref} / dt \rceil

        **No Membrane Reset**

        Unlike many spiking neuron models, this model does NOT reset the
        membrane potential after a spike. The potential keeps integrating
        according to its differential equation, and adaptation comes solely
        from threshold elevation.

        **Input Handling**

        - **Spike inputs**: obtained via ``self.sum_delta_inputs()``, which
          aggregates weights from all connected projections. Positive weights
          are added to the excitatory current and negative weights to the
          inhibitory current.
        - **Current inputs**: obtained via ``self.sum_current_inputs(x, V)``,
          which sums the external current ``x`` and any currents from
          projections. This current is buffered in ``i_0`` and applied in the
          NEXT time step.

        **Surrogate Gradient**

        The return value uses the surrogate gradient function for
        differentiability.
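        This minimal JAX sketch illustrates the surrogate-gradient idea. It
        assumes a Heaviside forward pass and a scaled sigmoid derivative as the
        backward pass. ``spike_fn`` and the sharpness ``ALPHA`` are
        hypothetical and are not the ``braintools`` surrogate configured as
        ``spk_fun``. The forward output is a hard 0/1 step at ``x >= 0``,
        while ``jax.grad`` sees a smooth, nonzero slope near the threshold.

        ```python
        import jax
        import jax.numpy as jnp

        ALPHA = 4.0  # hypothetical surrogate sharpness


        @jax.custom_jvp
        def spike_fn(x):
            # Forward pass: hard threshold, 1.0 where x >= 0 and 0.0 elsewhere.
            return jnp.where(x >= 0.0, 1.0, 0.0)


        @spike_fn.defjvp
        def spike_fn_jvp(primals, tangents):
            (x,), (x_dot,) = primals, tangents
            # Gradient pass: replace the zero/undefined step derivative with the
            # derivative of a sigmoid of slope ALPHA.
            sig = jax.nn.sigmoid(ALPHA * x)
            return spike_fn(x), ALPHA * sig * (1.0 - sig) * x_dot


        x = jnp.array([-0.5, 0.0, 0.5])  # scaled distances (V - V_th) / |omega - E_L|
        print(spike_fn(x))                      # hard 0/1 spikes
        print(jax.vmap(jax.grad(spike_fn))(x))  # smooth surrogate slopes, peak at x = 0
        ```

        Using ``custom_jvp`` keeps the forward semantics of a real spike while
        giving reverse-mode differentiation a usable slope.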
        The actual spike condition (hard threshold) is evaluated separately and
        drives the threshold increment and the refractory logic. This allows
        gradient-based learning while preserving biological spike semantics.

        Warnings
        --------
        - If time constants are very close but not exactly equal, the
          propagator computation may become numerically unstable. This happens
          because of near-singular denominators such as
          :math:`\tau_m - \tau_v`.
        - The one-step delay in applying external current (``i_0``) is
          required for consistency with NEST and with the exact integration
          scheme.
        - Large values of ``beta`` make the voltage-dependent threshold very
          sensitive to voltage fluctuations and can cause numerical issues.

        References
        ----------
        .. [1] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
               linear systems with applications to neuronal modeling.
               Biological Cybernetics 81:381-402.
               DOI: https://doi.org/10.1007/s004220050570
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()

        # --- Read state as dimensionless JAX arrays (units stripped, no numpy) ---
        V_rel = (self.V.value - self.E_L) / u.mV  # (V - E_L) in mV, as a plain float
        V_th_1 = self.V_th_1.value / u.mV
        V_th_2 = self.V_th_2.value / u.mV
        V_th_v = self.V_th_v.value / u.mV
        V_th_dv = self.V_th_dv.value / u.mV
        i_syn_ex = self.i_syn_ex.value / u.pA
        i_syn_in = self.i_syn_in.value / u.pA
        i_0 = self.i_0.value / u.pA
        r = self.refractory_step_count.value  # integer JAX array

        # --- Use pre-computed propagators and static parameters ---
        I_e = self._I_e_pA

        # --- Get spike inputs (dimensionless pA) ---
        w_all = self.sum_delta_inputs(0. * u.pA) / u.pA
        w_ex = u.math.where(w_all > 0.0, w_all, 0.0)
        w_in = u.math.where(w_all < 0.0, w_all, 0.0)

        # --- Get current inputs (one-step delayed, broadcast to varshape) ---
        i_0_next = self.sum_current_inputs(x, self.V.value) / u.pA + u.math.zeros(self.varshape)

        # === NEST update ordering (amat2_psc_exp.cpp update() lines 375-421) ===

        # Step 1: Evolve voltage-dependent threshold (V_th_v and V_th_dv).
        # V_th_v uses the OLD V_th_dv, so compute both from the current state first.
        V_th_v_new = ((I_e + i_0) * self._P70
                      + i_syn_ex * self._P71
                      + i_syn_in * self._P72
                      + V_rel * self._P73
                      + V_th_dv * self._P76
                      + V_th_v * self._P77)
        V_th_dv_new = ((I_e + i_0) * self._P60
                       + i_syn_ex * self._P61
                       + i_syn_in * self._P62
                       + V_rel * self._P63
                       + V_th_dv * self._P66)
        V_th_v = V_th_v_new
        V_th_dv = V_th_dv_new

        # Step 2: Evolve membrane potential
        V_rel = ((I_e + i_0) * self._P30
                 + i_syn_ex * self._P31
                 + i_syn_in * self._P32
                 + V_rel * self._P33)

        # Step 3: Decay adaptive threshold components
        V_th_1 = V_th_1 * self._P44
        V_th_2 = V_th_2 * self._P55

        # Step 4: Decay synaptic currents and add incoming spikes
        i_syn_ex = i_syn_ex * self._P11 + w_ex
        i_syn_in = i_syn_in * self._P22 + w_in

        # Steps 5-6: Spike detection and refractory countdown (no voltage reset!)
        not_refractory = r == 0
        spike_cond = not_refractory & (V_rel >= self._omega_rel_mV + V_th_1 + V_th_2 + V_th_v)

        # On spike: jump threshold components, set refractory counter
        V_th_1 = u.math.where(spike_cond, V_th_1 + self._alpha_1_mV, V_th_1)
        V_th_2 = u.math.where(spike_cond, V_th_2 + self._alpha_2_mV, V_th_2)
        r = u.math.where(
            spike_cond,
            u.math.asarray(self.ref_count, dtype=ditype),
            u.math.where(not_refractory, r, r - 1),
        )

        # --- Write back state variables ---
        self.V.value = (V_rel + self._E_L_mV) * u.mV
        self.V_th_1.value = V_th_1 * u.mV
        self.V_th_2.value = V_th_2 * u.mV
        self.V_th_v.value = V_th_v * u.mV
        self.V_th_dv.value = V_th_dv * u.mV
        self.i_syn_ex.value = i_syn_ex * u.pA
        self.i_syn_in.value = i_syn_in * u.pA
        self.i_0.value = i_0_next * u.pA
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        self.last_spike_time.value = jax.lax.stop_gradient(
            u.math.where(spike_cond, t + dt_q, self.last_spike_time.value)
        )
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # Return spike output via surrogate gradient.
        # V_out sits 1e-12 mV above (spike) or below (no spike) the total threshold.
        V_th_abs = self._omega_rel_mV + V_th_1 + V_th_2 + V_th_v + self._E_L_mV
        V_out = u.math.where(spike_cond, V_th_abs + 1e-12, V_th_abs - 1e-12)
        return self.get_spike(V_out * u.mV, V_th_abs * u.mV)