# Source code for brainpy_state._nest.mat2_psc_exp

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer

# Public API of this module: only the neuron model class is exported.
__all__ = ['mat2_psc_exp']


class mat2_psc_exp(NESTNeuron):
    r"""NEST-compatible ``mat2_psc_exp`` neuron model.

    Non-resetting leaky integrate-and-fire neuron model with exponential
    postsynaptic currents and a two-timescale adaptive threshold.

    **1. Model Overview**

    ``mat2_psc_exp`` implements a leaky integrate-and-fire model with exponential
    shaped postsynaptic currents (PSCs) and a multi-timescale adaptive threshold
    (MAT) [3]_. Key features:

    - **No voltage reset**: The membrane potential continues to integrate through
      spikes; only the threshold reacts to a spike.
    - **Two-timescale threshold adaptation**: Separate fast (τ₁) and slow (τ₂)
      threshold components capture short-term spike frequency adaptation and
      long-term gain control.
    - **Absolute refractory period**: Prevents multiple spikes within ``t_ref``
      without clamping the membrane potential.
    - **Exact integration**: Subthreshold dynamics are integrated exactly with
      precomputed exponential propagators [1]_, so there is no step-size
      stability restriction.

    **2. Mathematical Formulation**

    **2.1 Subthreshold Membrane Dynamics**

    The membrane potential evolves according to:

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m}
       + \frac{I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}} + I_e + I_0}{C_m}

    where:

    - :math:`V_m` -- membrane potential (absolute voltage)
    - :math:`E_L` -- resting potential
    - :math:`\tau_m = C_m / g_L` -- membrane time constant
    - :math:`I_{\mathrm{syn,ex}}, I_{\mathrm{syn,in}}` -- synaptic currents
    - :math:`I_e` -- constant external current
    - :math:`I_0` -- buffered step current input (updated each time step)

    **2.2 Synaptic Currents**

    Exponentially decaying currents with infinitely fast rise:

    .. math::

       \frac{dI_{\mathrm{syn,ex}}}{dt} = -\frac{I_{\mathrm{syn,ex}}}{\tau_{\mathrm{syn,ex}}}
       \qquad
       \frac{dI_{\mathrm{syn,in}}}{dt} = -\frac{I_{\mathrm{syn,in}}}{\tau_{\mathrm{syn,in}}}

    Incoming spike weights are added instantaneously: :math:`I_{\mathrm{syn}} \leftarrow I_{\mathrm{syn}} + w`.
    Positive weights are routed to the excitatory current, negative weights to the
    inhibitory current.

    **2.3 Adaptive Threshold**

    The effective spike threshold is the sum of a baseline and two decaying components:

    .. math::

       V_{th}(t) = \omega + V_{th,1}(t) + V_{th,2}(t)

    where:

    .. math::

       \frac{dV_{th,1}}{dt} = -\frac{V_{th,1}}{\tau_1}
       \qquad
       \frac{dV_{th,2}}{dt} = -\frac{V_{th,2}}{\tau_2}

    On each spike at time :math:`t_{\text{spike}}`:

    .. math::

       V_{th,1}(t_{\text{spike}}^+) = V_{th,1}(t_{\text{spike}}^-) + \alpha_1
       \qquad
       V_{th,2}(t_{\text{spike}}^+) = V_{th,2}(t_{\text{spike}}^-) + \alpha_2

    **2.4 Spike Emission**

    A spike is emitted when the neuron is not refractory (its refractory step
    counter is zero) and

    .. math::

       V_m \geq \omega + V_{th,1} + V_{th,2}

    After spiking:

    - Threshold components jump by :math:`\alpha_1, \alpha_2`
    - Refractory counter is set to :math:`\lceil t_{\text{ref}} / \Delta t \rceil`
    - **Membrane potential is NOT reset** (continues integrating)

    **3. Numerical Integration**

    The model uses exact integration for the linear subthreshold system [1]_.
    For a time step :math:`h = \Delta t`:

    **3.1 Membrane Potential Propagators**

    .. math::

       V_m(t+h) &= V_m(t) e^{-h/\tau_m} + E_L (1 - e^{-h/\tau_m}) \\
                &\quad + P_{21}^{\text{ex}} I_{\text{syn,ex}}(t)
                + P_{21}^{\text{in}} I_{\text{syn,in}}(t)
                + P_{20} (I_e + I_0)

    where:

    .. math::

       P_{21}^{\text{ex}} &= -\frac{\tau_m}{C_m (1 - \tau_m/\tau_{\text{syn,ex}})}
                            e^{-h/\tau_{\text{syn,ex}}}
                            (e^{h(1/\tau_{\text{syn,ex}} - 1/\tau_m)} - 1) \\
       P_{21}^{\text{in}} &= -\frac{\tau_m}{C_m (1 - \tau_m/\tau_{\text{syn,in}})}
                            e^{-h/\tau_{\text{syn,in}}}
                            (e^{h(1/\tau_{\text{syn,in}} - 1/\tau_m)} - 1) \\
       P_{20} &= -\frac{\tau_m}{C_m} (e^{-h/\tau_m} - 1)

    **3.2 Synaptic and Threshold Propagators**

    .. math::

       I_{\text{syn}}(t+h) &= I_{\text{syn}}(t) e^{-h/\tau_{\text{syn}}} + w_{\text{spike}} \\
       V_{th,1}(t+h) &= V_{th,1}(t) e^{-h/\tau_1} \\
       V_{th,2}(t+h) &= V_{th,2}(t) e^{-h/\tau_2}

    **3.3 Numerical Stability Constraint**

    The propagators become ill-conditioned when :math:`\tau_m \approx \tau_{\text{syn,ex}}`
    or :math:`\tau_m \approx \tau_{\text{syn,in}}` due to division by
    :math:`(1 - \tau_m/\tau_{\text{syn}})`. The constructor rejects exact equality.

    **4. Update Order (NEST-Compatible)**

    For each time step (matching NEST's ``mat2_psc_exp.cpp``):

    1. **Integrate membrane potential** using exact propagators
    2. **Decay adaptive threshold components** (:math:`V_{th,1}`, :math:`V_{th,2}`)
    3. **Decay synaptic currents** and add incoming spike weights
    4. **Detect spikes**: if not refractory and :math:`V_m \geq V_{th}`, emit spike,
       jump threshold components, reset refractory counter
    5. **Update refractory state**: decrement counter if active
    6. **Buffer current inputs** for next step (:math:`I_0`)

    **5. Surrogate Gradient Handling**

    The output spike signal passes through a surrogate gradient function
    (``spk_fun``). In :meth:`get_spike`, the voltage is scaled relative to the
    adaptive threshold:

    .. math::

       v_{\text{scaled}} = \frac{V_m - V_{th}}{|\omega - E_L|}

    where the denominator provides a characteristic voltage scale (~19 mV with defaults).

    In :meth:`update`, the argument passed to :meth:`get_spike` is placed just
    above (spike) or just below (no spike) the threshold according to the NEST
    spike decision, so the forward output reproduces NEST including the
    refractory gate. Because that argument is derived from the threshold only,
    no gradient with respect to :math:`V_m` flows through the value returned by
    :meth:`update`.

    Parameters
    ----------
    in_size : int, tuple of int
        Shape of the neuron population. Can be an integer (1D) or tuple (multi-dimensional).
    E_L : Quantity, ArrayLike, optional
        Resting membrane potential (default: -70 mV). Broadcastable to ``in_size``.
    C_m : Quantity, ArrayLike, optional
        Membrane capacitance (default: 100 pF). Must be strictly positive.
    tau_m : Quantity, ArrayLike, optional
        Membrane time constant (default: 5 ms). Must be strictly positive and differ
        from ``tau_syn_ex`` and ``tau_syn_in`` to avoid numerical degeneracy.
    t_ref : Quantity, ArrayLike, optional
        Duration of absolute refractory period (default: 2 ms). Must be strictly positive.
    tau_syn_ex : Quantity, ArrayLike, optional
        Time constant of excitatory postsynaptic current (default: 1 ms). Must be
        strictly positive and differ from ``tau_m``.
    tau_syn_in : Quantity, ArrayLike, optional
        Time constant of inhibitory postsynaptic current (default: 3 ms). Must be
        strictly positive and differ from ``tau_m``.
    I_e : Quantity, ArrayLike, optional
        Constant external input current (default: 0 pA). Broadcastable to ``in_size``.
    tau_1 : Quantity, ArrayLike, optional
        Short time constant of adaptive threshold (default: 10 ms). Must be strictly positive.
    tau_2 : Quantity, ArrayLike, optional
        Long time constant of adaptive threshold (default: 200 ms). Must be strictly positive.
    alpha_1 : Quantity, ArrayLike, optional
        Amplitude of short-timescale threshold jump on spike (default: 37 mV).
    alpha_2 : Quantity, ArrayLike, optional
        Amplitude of long-timescale threshold jump on spike (default: 2 mV).
    omega : Quantity, ArrayLike, optional
        Resting spike threshold (default: -51 mV). This is an **absolute voltage**,
        not relative to ``E_L``. With defaults, the threshold is 19 mV above resting.
    V_initializer : Callable, optional
        Initializer for membrane potential (default: Constant(-70 mV)). Evaluated in
        ``init_state`` as ``braintools.init.param(V_initializer, varshape)``.
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation (default: ReluGrad()).
        Maps scaled voltage to differentiable spike signal.
    spk_reset : str, optional
        Reset mode for spike output (default: ``'hard'``). Options: ``'hard'`` (stop gradient)
        or ``'soft'`` (preserve gradient). Does NOT affect membrane voltage reset
        (which never occurs in this model).
    ref_var : bool, optional
        If True, expose a boolean ``refractory`` state variable (default: False).
    name : str, optional
        Name of the neuron population. If None, auto-generated.

    Parameter Mapping
    -----------------
    Correspondence between constructor parameters and mathematical symbols:

    ==================== ================== =============================== ==========================================================
    **Parameter**        **Default**        **Math Symbol**                 **Description**
    ==================== ================== =============================== ==========================================================
    ``in_size``          (required)         —                               Population shape
    ``E_L``              -70 mV             :math:`E_L`                     Resting membrane potential
    ``C_m``              100 pF             :math:`C_m`                     Membrane capacitance
    ``tau_m``            5 ms               :math:`\tau_m`                  Membrane time constant
    ``t_ref``            2 ms               :math:`t_{\text{ref}}`          Duration of absolute refractory period
    ``tau_syn_ex``       1 ms               :math:`\tau_{\text{syn,ex}}`    Time constant of excitatory PSC
    ``tau_syn_in``       3 ms               :math:`\tau_{\text{syn,in}}`    Time constant of inhibitory PSC
    ``I_e``              0 pA               :math:`I_e`                     Constant external input current
    ``tau_1``            10 ms              :math:`\tau_1`                  Short time constant of adaptive threshold
    ``tau_2``            200 ms             :math:`\tau_2`                  Long time constant of adaptive threshold
    ``alpha_1``          37 mV              :math:`\alpha_1`                Amplitude of short-timescale threshold jump
    ``alpha_2``          2 mV               :math:`\alpha_2`                Amplitude of long-timescale threshold jump
    ``omega``            -51 mV             :math:`\omega`                  Resting spike threshold (absolute voltage)
    ``V_initializer``    Constant(-70 mV)   —                               Membrane potential initializer
    ``spk_fun``          ReluGrad()         —                               Surrogate spike function
    ``spk_reset``        ``'hard'``         —                               Reset mode (for gradient handling)
    ``ref_var``          ``False``          —                               If True, expose ``refractory`` boolean state
    ==================== ================== =============================== ==========================================================

    State Variables
    ---------------
    After ``init_state()``, the following state variables are available:

    ========================= ===================== ====================================================
    **Variable**              **Type**              **Description**
    ========================= ===================== ====================================================
    ``V``                     ``HiddenState`` (mV)  Membrane potential (absolute voltage)
    ``V_th_1``                ``ShortTermState``    Short-timescale adaptive threshold component (mV)
    ``V_th_2``                ``ShortTermState``    Long-timescale adaptive threshold component (mV)
    ``i_syn_ex``              ``ShortTermState``    Excitatory postsynaptic current (pA)
    ``i_syn_in``              ``ShortTermState``    Inhibitory postsynaptic current (pA)
    ``i_0``                   ``ShortTermState``    Buffered DC input current (pA, one-step delayed)
    ``refractory_step_count`` ``ShortTermState``    Refractory countdown (integer steps remaining)
    ``last_spike_time``       ``ShortTermState``    Time of last spike (ms)
    ``refractory``            ``ShortTermState``    Boolean refractory flag (only if ``ref_var=True``)
    ========================= ===================== ====================================================

    Raises
    ------
    ValueError
        If ``C_m <= 0``, ``tau_m <= 0``, ``tau_syn_ex <= 0``, ``tau_syn_in <= 0``,
        ``t_ref <= 0``, ``tau_1 <= 0``, or ``tau_2 <= 0``.
    ValueError
        If ``tau_m == tau_syn_ex`` or ``tau_m == tau_syn_in`` (numerical degeneracy).

    See Also
    --------
    amat2_psc_exp : Variant that adds a voltage-dependent threshold component.
    iaf_psc_exp : Standard LIF with voltage reset (no adaptive threshold).
    aeif_psc_exp : Adaptive exponential IF with spike-triggered adaptation current.

    Notes
    -----
    **Biological Interpretation:**
    The MAT model captures spike frequency adaptation without explicit adaptation currents.
    The fast threshold component (τ₁ ~ 10 ms) models sodium channel inactivation,
    while the slow component (τ₂ ~ 200 ms) models calcium-dependent potassium currents.

    **Comparison to NEST:**
    This implementation matches NEST's ``mat2_psc_exp`` update order (see NEST 3.7+
    ``mat2_psc_exp.cpp``). Key differences:

    - **Surrogate gradients**: brainpy.state adds spike signals via ``spk_fun``
      for gradient-based learning; NEST uses exact spike times.
    - **Batch dimension**: brainpy.state supports batch processing for parallel simulations;
      NEST operates on single neuron instances.
    - **Precision**: brainpy.state uses ``brainstate.environ.dftype()`` (float32
      unless JAX x64 mode is enabled); NEST uses float64. Minor numerical
      differences may occur for long simulations.

    **Time step and performance:**

    - All propagators and the refractory step count are computed once in
      ``__init__`` from the ``dt`` active in ``brainstate.environ`` at construction
      time. ``dt`` must therefore be set before the model is created; changing it
      afterwards requires constructing a new instance.
    - ``update`` only performs element-wise multiply-adds on the precomputed
      coefficients and is intended for use under ``jax.jit`` /
      ``brainstate.transform.for_loop``.
    - ``update`` reads the current simulation time ``t`` from ``brainstate.environ``.

    References
    ----------
    .. [1] Rotter S and Diesmann M (1999). Exact simulation of time-invariant linear
           systems with applications to neuronal modeling. Biological Cybernetics 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State space analysis of
           synchronous spiking in cortical neural networks. Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [3] Kobayashi R, Tsubo Y and Shinomoto S (2009). Made-to-order spiking neuron model
           equipped with a multi-timescale adaptive threshold. Frontiers in Computational
           Neuroscience 3:9. DOI: https://doi.org/10.3389/neuro.10.009.2009

    Examples
    --------
    **Basic Usage:**

    .. code-block:: python

       >>> import brainstate
       >>> import brainpy.state as bp
       >>> import saiunit as u
       >>> # dt must be set before construction: propagators are precomputed in __init__
       >>> brainstate.environ.set(dt=0.1 * u.ms)
       >>> # Create a population of 100 MAT neurons
       >>> neurons = bp.mat2_psc_exp(100, tau_1=10*u.ms, tau_2=200*u.ms)
       >>> neurons.init_all_states()
       >>> # update() reads the current time ``t`` from the environment
       >>> with brainstate.environ.context(t=0. * u.ms):
       ...     spikes = neurons.update(500*u.pA)  # buffered; acts from the next step on

    **Demonstrating Adaptive Threshold:**

    .. code-block:: python

       >>> # Single neuron with strong adaptation (dt = 0.1 ms set as above)
       >>> neuron = bp.mat2_psc_exp(1, alpha_1=50*u.mV, alpha_2=5*u.mV)
       >>> neuron.init_all_states()
       >>> V_trace = []
       >>> for i in range(1000):  # 100 ms simulation
       ...     with brainstate.environ.context(t=i * 0.1 * u.ms):
       ...         spk = neuron.update(800*u.pA)
       ...     V_trace.append(neuron.V.value)
       >>> # Plot V_trace to observe spike frequency adaptation

    **Network with Excitatory/Inhibitory Synapses:**

    .. code-block:: python

       >>> # dt must already be set in brainstate.environ
       >>> exc = bp.mat2_psc_exp(800)
       >>> inh = bp.mat2_psc_exp(200, tau_syn_ex=0.5*u.ms)
       >>> exc.init_all_states()
       >>> inh.init_all_states()
       >>> # Connect populations via projections (see brainpy.state.Projection)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 100. * u.pF,
        tau_m: ArrayLike = 5. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        tau_syn_ex: ArrayLike = 1. * u.ms,
        tau_syn_in: ArrayLike = 3. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        tau_1: ArrayLike = 10. * u.ms,
        tau_2: ArrayLike = 200. * u.ms,
        alpha_1: ArrayLike = 37. * u.mV,
        alpha_2: ArrayLike = 2. * u.mV,
        omega: ArrayLike = -51. * u.mV,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Broadcast every model parameter to the population shape.
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.tau_1 = braintools.init.param(tau_1, self.varshape)
        self.tau_2 = braintools.init.param(tau_2, self.varshape)
        self.alpha_1 = braintools.init.param(alpha_1, self.varshape)
        self.alpha_2 = braintools.init.param(alpha_2, self.varshape)
        self.omega = braintools.init.param(omega, self.varshape)

        self.V_initializer = V_initializer
        self.ref_var = ref_var
        self._validate_parameters()

        # Precompute refractory step count (dt is read here, at construction time).
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

        # Pre-compute all propagator constants for JIT-compatible update()
        self._precompute_constants()

    @staticmethod
    def _to_numpy(x, unit):
        """Strip ``unit`` from ``x`` and return a NumPy array of the environment float dtype."""
        dftype = brainstate.environ.dftype()
        return np.asarray(u.math.asarray(x / unit), dtype=dftype)

    @staticmethod
    def _broadcast_to_state(x_np: np.ndarray, shape):
        """Return a (read-only) broadcast view of ``x_np`` with the given ``shape``."""
        return np.broadcast_to(x_np, shape)

    def _precompute_constants(self):
        """Pre-compute time-step propagator coefficients as JAX arrays (called once at init).

        All coefficients are computed in NumPy in unit-stripped form
        (ms, pF, mV, pA) using the ``dt`` active at call time, then stored as
        JAX arrays of the environment float dtype.
        """
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()
        h = float(np.asarray(u.math.asarray(dt / u.ms)))

        tau_m = self._to_numpy(self.tau_m, u.ms)
        tau_ex = self._to_numpy(self.tau_syn_ex, u.ms)
        tau_in = self._to_numpy(self.tau_syn_in, u.ms)
        C_m = self._to_numpy(self.C_m, u.pF)
        tau_1 = self._to_numpy(self.tau_1, u.ms)
        tau_2 = self._to_numpy(self.tau_2, u.ms)

        # Synaptic current decay factors.
        self._P11ex = jnp.asarray(np.exp(-h / tau_ex), dtype=dftype)
        self._P11in = jnp.asarray(np.exp(-h / tau_in), dtype=dftype)
        # Membrane decay stored as expm1 so V*(1+P22_expm1) keeps precision for small h/tau_m.
        self._P22_expm1 = jnp.asarray(np.expm1(-h / tau_m), dtype=dftype)
        # Synaptic-current -> membrane coupling; singular when tau_m == tau_syn (rejected in validation).
        self._P21ex = jnp.asarray(
            -tau_m / (C_m * (1.0 - tau_m / tau_ex)) * np.exp(-h / tau_ex)
            * np.expm1(h * (1.0 / tau_ex - 1.0 / tau_m)),
            dtype=dftype,
        )
        self._P21in = jnp.asarray(
            -tau_m / (C_m * (1.0 - tau_m / tau_in)) * np.exp(-h / tau_in)
            * np.expm1(h * (1.0 / tau_in - 1.0 / tau_m)),
            dtype=dftype,
        )
        # Constant-current -> membrane coupling.
        self._P20 = jnp.asarray(-tau_m / C_m * np.expm1(-h / tau_m), dtype=dftype)
        # Adaptive-threshold decay factors.
        self._P11th = jnp.asarray(np.exp(-h / tau_1), dtype=dftype)
        self._P22th = jnp.asarray(np.exp(-h / tau_2), dtype=dftype)

        self._E_L_mV = jnp.asarray(self._to_numpy(self.E_L, u.mV), dtype=dftype)
        self._I_e_pA = jnp.asarray(self._to_numpy(self.I_e, u.pA), dtype=dftype)
        self._alpha_1_mV = jnp.asarray(self._to_numpy(self.alpha_1, u.mV), dtype=dftype)
        self._alpha_2_mV = jnp.asarray(self._to_numpy(self.alpha_2, u.mV), dtype=dftype)
        # Threshold baseline relative to E_L (update() integrates V relative to E_L).
        self._omega_rel_mV = jnp.asarray(self._to_numpy(self.omega - self.E_L, u.mV), dtype=dftype)

    def _validate_parameters(self):
        """Check parameter constraints, raising ``ValueError`` on violation.

        Raises
        ------
        ValueError
            If any time constant, the capacitance or ``t_ref`` is non-positive, or
            if ``tau_m`` equals ``tau_syn_ex`` or ``tau_syn_in``.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.C_m, self.tau_m)):
            return

        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('Synaptic time constants must be strictly positive.')
        if np.any(self.t_ref <= 0.0 * u.ms):
            raise ValueError('Refractory time must be strictly positive.')
        if np.any(self.tau_1 <= 0.0 * u.ms) or np.any(self.tau_2 <= 0.0 * u.ms):
            raise ValueError('Adaptive threshold time constants must be strictly positive.')
        if np.any(self.tau_m == self.tau_syn_ex) or np.any(self.tau_m == self.tau_syn_in):
            raise ValueError(
                'Membrane and synapse time constant(s) must differ. '
                'See note in documentation.'
            )

    def init_state(self, **kwargs):
        """Allocate and initialise all state variables.

        ``V`` comes from ``V_initializer``. Thresholds, currents and the buffered
        input start at zero, the refractory counter at 0, and ``last_spike_time``
        at -1e7 ms. ``kwargs`` are accepted for interface compatibility but are
        not used here.
        """
        ditype = brainstate.environ.ditype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        zeros = u.math.zeros_like(u.math.asarray(V / u.mV))
        self.V = brainstate.HiddenState(V)
        self.V_th_1 = brainstate.ShortTermState(zeros * u.mV)
        self.V_th_2 = brainstate.ShortTermState(zeros * u.mV)
        self.i_syn_ex = brainstate.ShortTermState(zeros * u.pA)
        self.i_syn_in = brainstate.ShortTermState(zeros * u.pA)
        self.i_0 = brainstate.ShortTermState(zeros * u.pA)
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None):
        r"""Compute surrogate gradient spike signal.

        Parameters
        ----------
        V : Quantity, ArrayLike, optional
            Membrane potential (mV). If None, uses current state ``self.V.value``.
        V_th : Quantity, ArrayLike, optional
            Effective threshold (mV). If None, computes as ``omega + V_th_1 + V_th_2``.

        Returns
        -------
        spike : ArrayLike
            Spike signal from the surrogate function. Shape matches ``V``.

        Notes
        -----
        The voltage is scaled relative to the adaptive threshold (and normalised by
        ``|omega - E_L|``) before passing through the surrogate function, providing
        a normalized input that improves gradient stability.
        """
        V = self.V.value if V is None else V
        if V_th is None:
            V_th = self.omega + self.V_th_1.value + self.V_th_2.value
        # Scale relative to the effective adaptive threshold.
        v_scaled = (V - V_th) / u.math.abs(self.omega - self.E_L)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.pA, spike_delta=None):
        r"""Advance the neuron state by one time step.

        Implements the NEST-compatible update order for the MAT2 model with exact
        integration of subthreshold dynamics.

        Parameters
        ----------
        x : Quantity, ArrayLike, optional
            External input current (pA) for this time step (default: 0 pA).
            Broadcastable to population shape. This current is buffered and applied
            in the **next** time step (one-step delay).
        spike_delta : Quantity, optional
            Instantaneous spike-weight input (pA) to add to synaptic currents.
            When provided, bypasses ``sum_delta_inputs()`` — useful for
            JIT-compiled ``brainstate.transform.for_loop`` simulations where delta
            inputs are pre-computed as a JAX array indexed by step. Positive values
            go to ``i_syn_ex``; negative values go to ``i_syn_in``.

        Returns
        -------
        spike : ArrayLike
            Spike signal for this time step. Shape matches population size. The
            forward value follows the NEST spike decision (including the
            refractory gate). The surrogate input is built from the threshold
            only, so no gradient with respect to ``V`` flows through it.

        Notes
        -----
        **Update sequence (NEST-compatible):**

        1. Integrate membrane potential using exact propagators
        2. Decay adaptive threshold components (V_th_1, V_th_2)
        3. Decay synaptic currents and add incoming spike weights
        4. Detect spikes: if not refractory and V_m >= V_th, emit spike
        5. On spike: jump threshold components, reset refractory counter
        6. Buffer external current for next step

        **Key implementation details:**

        - Membrane potential is **never reset** (non-resetting LIF)
        - Spike detection compares V_m against the adaptive threshold
          V_th = ω + V_th_1 + V_th_2
        - Refractory period is implemented as an integer countdown; no voltage clamping
        - External current ``x`` is stored in ``i_0`` and applied in the **next** time step
        - The current time ``t`` is read from ``brainstate.environ``; propagators use
          the ``dt`` captured at construction.

        **Numerical stability:**

        The exact integration scheme requires tau_m ≠ tau_syn_ex and tau_m ≠ tau_syn_in.
        Violations of this constraint are caught during initialization.
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Extract state variables as dimensionless JAX arrays (JIT-compatible)
        V_rel = u.math.asarray(self.V.value / u.mV, dtype=dftype) - self._E_L_mV
        V_th_1 = u.math.asarray(self.V_th_1.value / u.mV, dtype=dftype)
        V_th_2 = u.math.asarray(self.V_th_2.value / u.mV, dtype=dftype)
        i_syn_ex = u.math.asarray(self.i_syn_ex.value / u.pA, dtype=dftype)
        i_syn_in = u.math.asarray(self.i_syn_in.value / u.pA, dtype=dftype)
        i_0 = u.math.asarray(self.i_0.value / u.pA, dtype=dftype)
        r = self.refractory_step_count.value

        # --- Get spike inputs ---
        if spike_delta is not None:
            w_all = u.math.asarray(spike_delta / u.pA, dtype=dftype)
        else:
            w_all = u.math.asarray(self.sum_delta_inputs(0. * u.pA) / u.pA, dtype=dftype)
        # Route by sign: positive weights -> excitatory, negative -> inhibitory.
        w_ex = jnp.where(w_all > 0.0, w_all, jnp.zeros_like(w_all))
        w_in = jnp.where(w_all < 0.0, w_all, jnp.zeros_like(w_all))

        # --- Get current inputs (one-step delayed, stored for next step) ---
        i_0_next = jnp.broadcast_to(
            u.math.asarray(self.sum_current_inputs(x, self.V.value) / u.pA, dtype=dftype),
            self.varshape,
        )

        # === NEST update ordering (mat2_psc_exp.cpp lines 316-358) ===

        # Step 1: Evolve membrane potential using pre-computed propagators
        V_rel = (V_rel * self._P22_expm1 + V_rel
                 + i_syn_ex * self._P21ex
                 + i_syn_in * self._P21in
                 + (self._I_e_pA + i_0) * self._P20)

        # Step 2: Evolve adaptive threshold
        V_th_1 = V_th_1 * self._P11th
        V_th_2 = V_th_2 * self._P22th

        # Step 3: Decay synaptic currents and add incoming spikes
        i_syn_ex = i_syn_ex * self._P11ex + w_ex
        i_syn_in = i_syn_in * self._P11in + w_in

        # Step 4-5: Spike detection (no voltage reset!)
        not_refractory = r == 0
        spike_cond = not_refractory & (V_rel >= self._omega_rel_mV + V_th_1 + V_th_2)

        # On spike: jump threshold components, set refractory counter;
        # otherwise count down while refractory.
        V_th_1 = jnp.where(spike_cond, V_th_1 + self._alpha_1_mV, V_th_1)
        V_th_2 = jnp.where(spike_cond, V_th_2 + self._alpha_2_mV, V_th_2)
        r = jnp.where(
            spike_cond,
            self.ref_count,
            jnp.where(not_refractory, r, r - 1),
        ).astype(ditype)

        # --- Write back state variables ---
        self.V.value = (V_rel + self._E_L_mV) * u.mV
        self.V_th_1.value = V_th_1 * u.mV
        self.V_th_2.value = V_th_2 * u.mV
        self.i_syn_ex.value = i_syn_ex * u.pA
        self.i_syn_in.value = i_syn_in * u.pA
        self.i_0.value = i_0_next * u.pA
        self.refractory_step_count.value = r
        self.last_spike_time.value = jax.lax.stop_gradient(
            u.math.where(spike_cond, t + dt_q, self.last_spike_time.value)
        )
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # Return spike output via surrogate gradient. The surrogate input is
        # placed just above (spike) or just below (no spike) the threshold so the
        # forward value reproduces ``spike_cond``, including the refractory gate.
        # The offset must be resolvable at the threshold's magnitude. A fixed
        # 1e-12 mV offset is absorbed by float32 rounding (spacing near 50 mV is
        # ~4e-6 mV), which made V_out == V_th_abs on both branches and the
        # returned spike independent of spike_cond.
        V_th_abs = self._omega_rel_mV + V_th_1 + V_th_2 + self._E_L_mV
        eps = jnp.finfo(V_th_abs.dtype).eps
        # stop_gradient: the offset is a numerical device and must not add a gradient path.
        margin = jax.lax.stop_gradient(1e3 * eps * jnp.maximum(jnp.abs(V_th_abs), 1.0))
        V_out = jnp.where(spike_cond, V_th_abs + margin, V_th_abs - margin)
        return self.get_spike(V_out * u.mV, V_th_abs * u.mV)