# Source code for brainpy_state._nest.glif_psc_double_alpha

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

r"""Current-based generalized leaky integrate-and-fire (GLIF) neuron model
with double alpha-function shaped synaptic currents.

This module implements the ``glif_psc_double_alpha`` neuron model from the
NEST simulator. It extends the ``glif_psc`` model by using a double alpha
function (fast + slow components) for synaptic currents, allowing more
flexible control over the synaptic current waveform shape and tail.

The implementation uses exact integration (propagator matrices) matching
NEST's numerical scheme for linear subthreshold dynamics.

References
----------
.. [1] Teeter C, Iyer R, Menon V, Gouwens N, Feng D, Berg J, Szafer A,
       Cain N, Zeng H, Hawrylycz M, Koch C, & Mihalas S (2018).
       Generalized leaky integrate-and-fire models classify multiple neuron
       types. Nature Communications 9:709.
.. [2] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
       the large, fluctuating synaptic conductance state typical of
       neocortical neurons in vivo. J. Comput. Neurosci. 16:159-175.
.. [3] NEST Simulator ``glif_psc_double_alpha`` model documentation and
       C++ source: ``models/glif_psc_double_alpha.h`` and
       ``models/glif_psc_double_alpha.cpp``.
"""

from typing import Callable, Sequence

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer, alpha_propagator_p31_p32

# Public API of this module: only the neuron model class is exported by
# ``from <module> import *``.
__all__ = [
    'glif_psc_double_alpha',
]


class glif_psc_double_alpha(NESTNeuron):
    r"""Current-based generalized leaky integrate-and-fire (GLIF) neuron model
    with double alpha-function shaped synaptic currents.

    Implements the NEST ``glif_psc_double_alpha`` model, which extends the basic
    GLIF framework [1]_ with dual-component (fast + slow) alpha-function shaped
    postsynaptic currents [2]_. This allows flexible control over synaptic
    waveform shape, including realistic biphasic or long-tailed currents observed
    experimentally. The model provides five GLIF variants (Models 1-5) selectable
    via boolean flags, ranging from simple LIF to adaptive threshold models with
    after-spike currents.

    **Model Family Overview**

    The five GLIF models are hierarchical, each adding biological mechanisms:

    * **GLIF Model 1** (LIF) — Traditional leaky integrate-and-fire
    * **GLIF Model 2** (LIF_R) — LIF with biologically defined reset rules
    * **GLIF Model 3** (LIF_ASC) — LIF with after-spike currents
    * **GLIF Model 4** (LIF_R_ASC) — LIF with reset rules and after-spike currents
    * **GLIF Model 5** (LIF_R_ASC_A) — LIF with reset rules, after-spike currents,
      and a voltage-dependent threshold

    Model selection is determined by three boolean parameters:

    +--------+---------------------------+----------------------+--------------------+
    | Model  | spike_dependent_threshold | after_spike_currents | adapting_threshold |
    +========+===========================+======================+====================+
    | GLIF1  | False                     | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF2  | True                      | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF3  | False                     | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF4  | True                      | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF5  | True                      | True                 | True               |
    +--------+---------------------------+----------------------+--------------------+

    **Double Alpha-Function Synaptic Currents**

    Each synaptic receptor port receives inputs shaped by a sum of two alpha
    functions (fast and slow components) [2]_:

    .. math::

        I_\mathrm{syn,k}(t) = \alpha_\mathrm{fast}(t; \tau_{\mathrm{syn,fast},k})
                            + \mathrm{amp\_slow}_k \cdot
                              \alpha_\mathrm{slow}(t; \tau_{\mathrm{syn,slow},k})

    Normalization: A spike of weight 1.0 produces a peak current of 1 pA for the
    fast component at :math:`t = \tau_\mathrm{syn,fast}`. The slow component peaks
    at :math:`\mathrm{amp\_slow}_k` pA at :math:`t = \tau_\mathrm{syn,slow}`.

    Multiple receptor ports are supported by passing arrays to ``tau_syn_fast``,
    ``tau_syn_slow``, and ``amp_slow``. By default, one receptor port is created.
    Projections specify receptor ports via ``receptor_<k>`` labels (0-based indexing).

    **Detailed Mathematical Description**

    **1. Membrane Dynamics**

    The membrane potential :math:`U` (tracked relative to :math:`E_L`) evolves
    via exact integration using propagator matrices:

    .. math::

       U(t+dt) = U(t) \cdot P_{33} + (I_e + I_\mathrm{stim} + I_\mathrm{ASC,sum}) \cdot P_{30}
                 + \sum_k \left( P_{31,k}^\mathrm{fast} \cdot y_{1,k}^\mathrm{fast}
                               + P_{32,k}^\mathrm{fast} \cdot y_{2,k}^\mathrm{fast} \right)
                 + \sum_k \left( P_{31,k}^\mathrm{slow} \cdot y_{1,k}^\mathrm{slow}
                               + P_{32,k}^\mathrm{slow} \cdot y_{2,k}^\mathrm{slow} \right)

    where:

    .. math::

       P_{33} = \exp\left(-\frac{dt}{\tau_m}\right), \quad
       P_{30} = \frac{\tau_m}{C_m} \left(1 - P_{33}\right), \quad
       \tau_m = \frac{C_m}{g}

    The propagators :math:`P_{31,k}`, :math:`P_{32,k}` for each receptor and
    component (fast/slow) are computed using the ``IAFPropagatorAlpha`` algorithm,
    which handles the singularity when :math:`\tau_m \approx \tau_{\mathrm{syn},k}`.

    **2. Synaptic Current Dynamics (Double Alpha Function)**

    Each receptor port :math:`k` maintains **four** state variables: two for the
    fast component and two for the slow component. Each pair :math:`(y_1, y_2)`
    represents an alpha function.

    **Fast component:**

    .. math::

       y_{2,k}^\mathrm{fast}(t+dt) = P_{21,k}^\mathrm{fast} \cdot y_{1,k}^\mathrm{fast}(t)
                                   + P_{22,k}^\mathrm{fast} \cdot y_{2,k}^\mathrm{fast}(t)

    .. math::

       y_{1,k}^\mathrm{fast}(t+dt) = P_{11,k}^\mathrm{fast} \cdot y_{1,k}^\mathrm{fast}(t)

    **Slow component:**

    .. math::

       y_{2,k}^\mathrm{slow}(t+dt) = P_{21,k}^\mathrm{slow} \cdot y_{1,k}^\mathrm{slow}(t)
                                   + P_{22,k}^\mathrm{slow} \cdot y_{2,k}^\mathrm{slow}(t)

    .. math::

       y_{1,k}^\mathrm{slow}(t+dt) = P_{11,k}^\mathrm{slow} \cdot y_{1,k}^\mathrm{slow}(t)

    where:

    .. math::

       P_{11,k}^\mathrm{fast} = P_{22,k}^\mathrm{fast} = \exp\left(-\frac{dt}{\tau_{\mathrm{syn,fast},k}}\right)

       P_{21,k}^\mathrm{fast} = dt \cdot P_{11,k}^\mathrm{fast}

       P_{11,k}^\mathrm{slow} = P_{22,k}^\mathrm{slow} = \exp\left(-\frac{dt}{\tau_{\mathrm{syn,slow},k}}\right)

       P_{21,k}^\mathrm{slow} = dt \cdot P_{11,k}^\mathrm{slow}

    On a presynaptic spike of weight :math:`w` to receptor port :math:`k`:

    .. math::

       y_{1,k}^\mathrm{fast} \leftarrow y_{1,k}^\mathrm{fast} + w \cdot \frac{e}{\tau_{\mathrm{syn,fast},k}}

       y_{1,k}^\mathrm{slow} \leftarrow y_{1,k}^\mathrm{slow} + w \cdot \frac{e}{\tau_{\mathrm{syn,slow},k}} \cdot \mathrm{amp\_slow}_k

    The total synaptic current is:

    .. math::

       I_\mathrm{syn,total} = \sum_k \left( y_{2,k}^\mathrm{fast} + y_{2,k}^\mathrm{slow} \right)

    **3. After-Spike Currents (GLIF3/4/5)**

    After-spike currents (ASC) model slow adaptation via exponentially decaying
    currents triggered by spikes. Each ASC component :math:`I_j` decays with rate
    :math:`k_j`:

    .. math::

       I_j(t+dt) = I_j(t) \cdot \exp(-k_j \cdot dt)

    The time-averaged ASC over a step (used for stable integration) is:

    .. math::

       \bar{I}_j = \frac{1 - \exp(-k_j \cdot dt)}{k_j \cdot dt} \cdot I_j(t)

    On spike, ASC values are reset:

    .. math::

       I_j \leftarrow \Delta I_j + I_j \cdot r_j \cdot \exp(-k_j \cdot t_\mathrm{ref})

    where :math:`\Delta I_j` is the jump amplitude and :math:`r_j` is the fraction
    coefficient.

    **4. Spike-Dependent Threshold (GLIF2/4/5)**

    The spike component of the threshold :math:`\theta_s` decays exponentially:

    .. math::

       \theta_s(t+dt) = \theta_s(t) \cdot \exp(-b_s \cdot dt)

    On spike, after refractory decay:

    .. math::

       \theta_s \leftarrow \theta_s \cdot \exp(-b_s \cdot t_\mathrm{ref}) + \Delta\theta_s

    Voltage reset uses a biologically defined rule:

    .. math::

       U \leftarrow f_v \cdot U_\mathrm{old} + V_\mathrm{add}

    where :math:`f_v` is the voltage fraction coefficient and :math:`V_\mathrm{add}`
    is the additive constant.

    **5. Voltage-Dependent Threshold (GLIF5)**

    The voltage component of the threshold :math:`\theta_v` evolves according to:

    .. math::

       \theta_v(t+dt) = \phi \cdot (U_\mathrm{old} - \beta) \cdot P_\mathrm{decay}
           + \frac{1}{P_{\theta,v}} \cdot \left(\theta_v(t)
               - \phi \cdot (U_\mathrm{old} - \beta)
               - \frac{a_v}{b_v} \cdot \beta \right)
           + \frac{a_v}{b_v} \cdot \beta

    where:

    .. math::

       \phi = \frac{a_v}{b_v - g/C_m}, \quad
       P_\mathrm{decay} = \exp\left(-\frac{g \cdot dt}{C_m}\right), \quad
       P_{\theta,v} = \exp(b_v \cdot dt), \quad
       \beta = \frac{I_e + I_\mathrm{stim} + I_\mathrm{ASC,sum}}{g}

    **6. Overall Threshold and Spike Condition**

    .. math::

       \theta_\mathrm{total} = \theta_\infty + \theta_s + \theta_v

    Spike condition (checked after voltage update):

    .. math::

       \text{spike} = \begin{cases}
       \text{True} & \text{if } U > \theta_\mathrm{total} \\
       \text{False} & \text{otherwise}
       \end{cases}

    **7. Numerical Integration and Update Order**

    The discrete-time update sequence per simulation step is:

    1. Record :math:`U_\mathrm{old}` (relative to :math:`E_L`).
    2. If not refractory:

       a. Decay spike threshold component :math:`\theta_s`.
       b. Compute time-averaged ASC :math:`\bar{I}_j` and decay ASC values.
       c. Update membrane potential :math:`U` (include fast/slow synaptic contributions).
       d. Compute voltage-dependent threshold component :math:`\theta_v` (using :math:`U_\mathrm{old}`).
       e. Update total threshold :math:`\theta_\mathrm{total}`.
       f. If :math:`U > \theta_\mathrm{total}`: emit spike, apply reset rules.

    3. If refractory: decrement refractory counter, hold :math:`U` at :math:`U_\mathrm{old}`.
    4. Update synaptic current state variables for both fast and slow components.
    5. Add incoming spike current jumps (scaled for fast/slow).
    6. Buffer external current input for next step.
    7. Save :math:`U_\mathrm{old}` for next step.

    Parameters
    ----------
    in_size : int, tuple of int
        Population shape (number of neurons). Scalars are interpreted as (n,).
    g : ArrayLike, optional
        Membrane (leak) conductance. Default: 9.43 nS. Must be strictly positive.
        Shape: scalar or broadcastable to ``in_size``.
    E_L : ArrayLike, optional
        Resting (leak) membrane potential (absolute). Default: -78.85 mV.
        Shape: scalar or broadcastable to ``in_size``.
    V_th : ArrayLike, optional
        Instantaneous spike threshold (absolute). Default: -51.68 mV.
        Must be greater than ``V_reset``. Shape: scalar or broadcastable to ``in_size``.
    C_m : ArrayLike, optional
        Membrane capacitance. Default: 58.72 pF. Must be strictly positive.
        Shape: scalar or broadcastable to ``in_size``.
    t_ref : ArrayLike, optional
        Absolute refractory period. Default: 3.75 ms. Must be strictly positive.
        Shape: scalar or broadcastable to ``in_size``.
    V_reset : ArrayLike, optional
        Reset potential (absolute; used in GLIF1/3). Default: -78.85 mV.
        Must be less than ``V_th``. Shape: scalar or broadcastable to ``in_size``.
    th_spike_add : float, optional
        Threshold additive constant after spike (:math:`\Delta\theta_s`).
        Default: 0.37 mV. Used in GLIF2/4/5.
    th_spike_decay : float, optional
        Spike threshold decay rate (:math:`b_s`). Default: 0.009 /ms.
        Must be strictly positive. Used in GLIF2/4/5.
    voltage_reset_fraction : float, optional
        Voltage fraction coefficient after spike (:math:`f_v`).
        Default: 0.20. Must be in [0.0, 1.0]. Used in GLIF2/4/5.
    voltage_reset_add : float, optional
        Voltage additive constant after spike (:math:`V_\mathrm{add}`).
        Default: 18.51 mV. Used in GLIF2/4/5.
    th_voltage_index : float, optional
        Voltage-dependent threshold leak rate (:math:`a_v`). Default: 0.005 /ms.
        Used in GLIF5.
    th_voltage_decay : float, optional
        Voltage-dependent threshold decay rate (:math:`b_v`). Default: 0.09 /ms.
        Must be strictly positive. Used in GLIF5.
    asc_init : Sequence[float], optional
        Initial values of after-spike current components (pA). Default: (0.0, 0.0).
        Length must match ``asc_decay``, ``asc_amps``, ``asc_r``. Used in GLIF3/4/5.
    asc_decay : Sequence[float], optional
        After-spike current decay rates (:math:`k_j`, /ms). Default: (0.003, 0.1).
        All values must be strictly positive. Used in GLIF3/4/5.
    asc_amps : Sequence[float], optional
        After-spike current jump amplitudes (:math:`\Delta I_j`, pA). Default: (-9.18, -198.94).
        Used in GLIF3/4/5.
    asc_r : Sequence[float], optional
        After-spike current fraction coefficients (:math:`r_j`). Default: (1.0, 1.0).
        All values must be in [0.0, 1.0]. Used in GLIF3/4/5.
    tau_syn_fast : Sequence[float], optional
        Fast synaptic alpha-function time constants (ms). Default: (2.0,).
        All values must be strictly positive. Length determines number of receptor ports.
    tau_syn_slow : Sequence[float], optional
        Slow synaptic alpha-function time constants (ms). Default: (6.0,).
        All values must be strictly positive. Length must match ``tau_syn_fast``.
    amp_slow : Sequence[float], optional
        Relative amplitude of slow component (unitless). Default: (0.3,).
        All values must be strictly positive. Length must match ``tau_syn_fast``.
    spike_dependent_threshold : bool, optional
        Enable biologically defined reset rules (GLIF2/4/5). Default: False.
    after_spike_currents : bool, optional
        Enable after-spike currents (GLIF3/4/5). Default: False.
    adapting_threshold : bool, optional
        Enable voltage-dependent threshold (GLIF5). Default: False.
    I_e : ArrayLike, optional
        Constant external current input (pA). Default: 0.0 pA.
        Shape: scalar or broadcastable to ``in_size``.
    V_initializer : Callable, optional
        Membrane potential initializer. Default: ``Constant(E_L)``.
        Should return values in mV when called with shape and batch_size.
    spk_fun : Callable, optional
        Surrogate gradient function for differentiable spike generation.
        Default: ``ReluGrad()``. Must accept scaled voltage and return spike output.
    spk_reset : str, optional
        Spike reset mode. Default: ``'hard'`` (stop gradient). Alternative: ``'soft'``.
    name : str, optional
        Name of this neuron population.


    Parameter Mapping
    -----------------

    =============================== ========================== ========================================== =====================================================
    **Parameter**                   **Default**                **Math equivalent**                        **Description**
    =============================== ========================== ========================================== =====================================================
    ``in_size``                     (required)                 —                                          Population shape
    ``g``                           9.43 nS                    :math:`g`                                  Membrane (leak) conductance
    ``E_L``                         -78.85 mV                  :math:`E_L`                                Resting membrane potential
    ``V_th``                        -51.68 mV                  :math:`V_\mathrm{th}`                      Instantaneous threshold (absolute)
    ``C_m``                         58.72 pF                   :math:`C_m`                                Membrane capacitance
    ``t_ref``                       3.75 ms                    :math:`t_\mathrm{ref}`                     Absolute refractory period
    ``V_reset``                     -78.85 mV                  :math:`V_\mathrm{reset}`                   Reset potential (absolute; GLIF1/3)
    ``th_spike_add``                0.37 mV                    :math:`\Delta\theta_s`                     Threshold additive constant after spike
    ``th_spike_decay``              0.009 /ms                  :math:`b_s`                                Spike threshold decay rate
    ``voltage_reset_fraction``      0.20                       :math:`f_v`                                Voltage fraction after spike
    ``voltage_reset_add``           18.51 mV                   :math:`V_\mathrm{add}`                     Voltage additive after spike
    ``th_voltage_index``            0.005 /ms                  :math:`a_v`                                Voltage-dependent threshold leak
    ``th_voltage_decay``            0.09 /ms                   :math:`b_v`                                Voltage-dependent threshold decay rate
    ``asc_init``                    (0.0, 0.0) pA              :math:`I_j(0)`                             Initial values of ASC
    ``asc_decay``                   (0.003, 0.1) /ms           :math:`k_j`                                ASC decay rates
    ``asc_amps``                    (-9.18, -198.94) pA        :math:`\Delta I_j`                         ASC amplitudes on spike
    ``asc_r``                       (1.0, 1.0)                 :math:`r_j`                                ASC fraction coefficient
    ``tau_syn_fast``                (2.0,) ms                  :math:`\tau_{\mathrm{syn,fast},k}`         Fast synaptic alpha-function time constants
    ``tau_syn_slow``                (6.0,) ms                  :math:`\tau_{\mathrm{syn,slow},k}`         Slow synaptic alpha-function time constants
    ``amp_slow``                    (0.3,)                     :math:`\mathrm{amp\_slow}_k`               Relative amplitude of slow component
    ``spike_dependent_threshold``   False                      —                                          Enable biologically defined reset (GLIF2/4/5)
    ``after_spike_currents``        False                      —                                          Enable after-spike currents (GLIF3/4/5)
    ``adapting_threshold``          False                      —                                          Enable voltage-dependent threshold (GLIF5)
    ``I_e``                         0.0 pA                     :math:`I_e`                                Constant external current
    ``V_initializer``               Constant(E_L)              —                                          Membrane potential initializer
    ``spk_fun``                     ReluGrad()                 —                                          Surrogate spike function
    ``spk_reset``                   ``'hard'``                 —                                          Reset mode (``'hard'`` or ``'soft'``)
    =============================== ========================== ========================================== =====================================================

    Notes
    -----
    - **Default parameters** are from GLIF Model 5 of Cell 490626718 in the
      Allen Cell Type Database (https://celltypes.brain-map.org).
    - **Voltage tracking**: ``V_th`` and ``V_reset`` are specified in absolute mV.
      Internally, membrane potential is stored relative to ``E_L`` (matching NEST).
    - **Stability constraint** for GLIF2/4/5: The reset condition should satisfy:

      .. math::

          E_L + f_v \cdot (V_\mathrm{th} - E_L) + V_\mathrm{add} < V_\mathrm{th} + \Delta\theta_s

      Otherwise, the neuron may spike continuously.
    - **Numerical integration**: Uses exact integration via propagator matrices
      (matching NEST), unlike ``glif_cond`` which uses RKF45 ODE integration.
    - **Singularity handling**: If :math:`\tau_m \approx \tau_{\mathrm{syn,fast}}`
      or :math:`\tau_m \approx \tau_{\mathrm{syn,slow}}`, the model automatically
      applies singularity-safe formulas (see NEST IAF_Integration_Singularity notebook).
    - **Synaptic waveform control**: The double alpha function provides more flexible
      control over synaptic current shape compared to single alpha (``glif_psc``).
      By tuning ``tau_syn_fast``, ``tau_syn_slow``, and ``amp_slow``, experimentally
      observed waveforms can be matched.
    - **Receptor port indexing**: Synaptic inputs are registered via
      ``add_delta_input()`` with labels like ``'receptor_0'``, ``'receptor_1'``, etc.
      Inputs without explicit receptor labels default to receptor 0.
    - **State persistence**: After-spike current values (``_ASCurrents``), threshold
      components (``_threshold_spike``, ``_threshold_voltage``), and total threshold
      (``_threshold``) are stored as NumPy arrays (not JAX arrays) to match NEST's
      state handling and allow in-place updates during the per-neuron loop.

    Examples
    --------
    **1. GLIF Model 1 (basic LIF) with single receptor:**

    .. code-block:: python

        >>> import brainpy.state as st
        >>> import brainstate as bst
        >>> import saiunit as u
        >>> bst.environ.set(dt=0.1 * u.ms)
        >>> neurons = st.glif_psc_double_alpha(
        ...     in_size=100,
        ...     g=10.0 * u.nS,
        ...     E_L=-70.0 * u.mV,
        ...     V_th=-55.0 * u.mV,
        ...     C_m=250.0 * u.pF,
        ...     t_ref=2.0 * u.ms,
        ...     V_reset=-70.0 * u.mV,
        ...     tau_syn_fast=(2.0,),  # ms (plain floats; converted with float())
        ...     tau_syn_slow=(6.0,),  # ms
        ...     amp_slow=(0.5,),
        ... )
        >>> neurons.init_all_states()
        >>> spikes = neurons.update(10.0 * u.pA)

    **2. GLIF Model 5 (full model) with multiple receptors:**

    .. code-block:: python

        >>> neurons = st.glif_psc_double_alpha(
        ...     in_size=50,
        ...     spike_dependent_threshold=True,
        ...     after_spike_currents=True,
        ...     adapting_threshold=True,
        ...     tau_syn_fast=(1.0, 3.0),  # ms; two receptor ports
        ...     tau_syn_slow=(5.0, 10.0),  # ms
        ...     amp_slow=(0.3, 0.4),
        ...     asc_decay=(0.01, 0.05),  # 1/ms
        ...     asc_amps=(-10.0, -100.0),  # pA
        ... )
        >>> neurons.init_all_states()
        >>> # Synaptic inputs can target different receptors
        >>> neurons.add_delta_input('excitatory_receptor_0')
        >>> neurons.add_delta_input('inhibitory_receptor_1')

    **3. Accessing synaptic current components:**

    .. code-block:: python

        >>> I_syn_total = neurons.get_I_syn()
        >>> I_syn_fast = neurons.get_I_syn_fast()
        >>> I_syn_slow = neurons.get_I_syn_slow()

    References
    ----------
    .. [1] Teeter C, Iyer R, Menon V, Gouwens N, Feng D, Berg J, Szafer A,
           Cain N, Zeng H, Hawrylycz M, Koch C, & Mihalas S (2018).
           Generalized leaky integrate-and-fire models classify multiple neuron
           types. Nature Communications 9:709.
           DOI: 10.1038/s41467-017-02717-4
    .. [2] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. J. Comput. Neurosci. 16:159-175.
           DOI: 10.1023/B:JCNS.0000014108.03012.81
    .. [3] NEST Simulator ``glif_psc_double_alpha`` model documentation and C++
           source: ``models/glif_psc_double_alpha.h`` and
           ``models/glif_psc_double_alpha.cpp`` in NEST repository.

    See Also
    --------
    glif_psc : Single alpha-function variant.
    glif_cond : Conductance-based GLIF using ODE integration.
    gif_psc_exp_multisynapse : Generalized IF with exponential PSCs and multisynapse support.
    aeif_psc_alpha : Adaptive exponential IF with alpha PSCs.
    """
    # Override the module name reported for this class so that it appears
    # under the public ``brainpy.state`` namespace.
    __module__ = 'brainpy.state'

    # NOTE(review): neither constant below is referenced in the part of the
    # class visible here. The names suggest a minimum step size and an
    # iteration cap for an adaptive ODE solver -- confirm they are still
    # needed, since this model is documented as using exact (propagator)
    # integration.
    _MIN_H = 1e-8 * u.ms  # carries the unit itself (milliseconds)
    _MAX_ITERS = 100000

    def __init__(
        self,
        in_size: Size,
        g: ArrayLike = 9.43 * u.nS,
        E_L: ArrayLike = -78.85 * u.mV,
        V_th: ArrayLike = -51.68 * u.mV,
        C_m: ArrayLike = 58.72 * u.pF,
        t_ref: ArrayLike = 3.75 * u.ms,
        V_reset: ArrayLike = -78.85 * u.mV,
        th_spike_add: float = 0.37,  # mV
        th_spike_decay: float = 0.009,  # 1/ms
        voltage_reset_fraction: float = 0.20,
        voltage_reset_add: float = 18.51,  # mV
        th_voltage_index: float = 0.005,  # 1/ms
        th_voltage_decay: float = 0.09,  # 1/ms
        asc_init: Sequence[float] = (0.0, 0.0),  # pA
        asc_decay: Sequence[float] = (0.003, 0.1),  # 1/ms
        asc_amps: Sequence[float] = (-9.18, -198.94),  # pA
        asc_r: Sequence[float] = (1.0, 1.0),
        tau_syn_fast: Sequence[float] = (2.0,),  # ms
        tau_syn_slow: Sequence[float] = (6.0,),  # ms
        amp_slow: Sequence[float] = (0.3,),  # unitless
        spike_dependent_threshold: bool = False,
        after_spike_currents: bool = False,
        adapting_threshold: bool = False,
        I_e: ArrayLike = 0.0 * u.pA,
        gsl_error_tol: ArrayLike = 1e-6,
        V_initializer: Callable = None,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
    ):
        r"""Store the model parameters, validate them and derive ``ref_count``.

        The arguments are described in the class docstring. ``gsl_error_tol``
        is not listed there; this constructor only stores it on the instance
        and does not read it.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        shape = self.varshape
        make_param = braintools.init.param

        # Membrane quantities, initialised against the population shape.
        self.g_m = make_param(g, shape)
        self.E_L = make_param(E_L, shape)
        self.C_m = make_param(C_m, shape)
        self.t_ref = make_param(t_ref, shape)
        self.I_e = make_param(I_e, shape)

        # Threshold and reset potentials are stored exactly as supplied
        # (absolute values); _validate_parameters subtracts E_L itself when it
        # compares them.
        self.V_th = make_param(V_th, shape)
        self.V_reset = make_param(V_reset, shape)

        # Scalar GLIF coefficients, coerced to plain Python floats.
        self.th_spike_add = float(th_spike_add)
        self.th_spike_decay = float(th_spike_decay)
        self.voltage_reset_fraction = float(voltage_reset_fraction)
        self.voltage_reset_add = float(voltage_reset_add)
        self.th_voltage_index = float(th_voltage_index)
        self.th_voltage_decay = float(th_voltage_decay)

        # After-spike-current settings, frozen as tuples of floats.
        self.asc_init = tuple(map(float, asc_init))
        self.asc_decay = tuple(map(float, asc_decay))
        self.asc_amps = tuple(map(float, asc_amps))
        self.asc_r = tuple(map(float, asc_r))

        # Double-alpha synapse settings: one entry per receptor port for the
        # fast time constant, the slow time constant and the slow amplitude.
        self.tau_syn_fast = tuple(map(float, tau_syn_fast))
        self.tau_syn_slow = tuple(map(float, tau_syn_slow))
        self.amp_slow = tuple(map(float, amp_slow))

        # Switches selecting which GLIF mechanisms are active.
        self.has_theta_spike = bool(spike_dependent_threshold)
        self.has_asc = bool(after_spike_currents)
        self.has_theta_voltage = bool(adapting_threshold)

        # Without an explicit initializer the membrane potential starts at E_L.
        self.V_initializer = (
            braintools.init.Constant(E_L) if V_initializer is None else V_initializer
        )

        # The receptor count follows the number of fast time constants.
        self._n_receptors = len(self.tau_syn_fast)
        self.gsl_error_tol = gsl_error_tol

        self._validate_parameters()

        # Refractory period expressed in simulation steps, rounded up, using
        # the environment's integer dtype and time step.
        index_dtype = brainstate.environ.ditype()
        step = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / step), dtype=index_dtype)

    @property
    def n_receptors(self):
        r"""Return how many synaptic receptor ports this population has.

        The count is fixed in the constructor as ``len(tau_syn_fast)``; the
        parameter validation requires ``tau_syn_slow`` and ``amp_slow`` to have
        that same length. According to the class documentation, every port
        carries its own fast and slow alpha-function current, so several
        receptor types (e.g., AMPA, NMDA, GABA_A, GABA_B) can be modelled side
        by side.

        Returns
        -------
        int
            Number of receptor ports.

        Notes
        -----
        - Valid port indices run from 0 up to ``n_receptors - 1``.
        - Projections address a port through a label such as ``'receptor_0'``
          or ``'receptor_1'``.
        - With the default single-element synaptic parameters the value is 1.

        See Also
        --------
        _collect_receptor_delta_inputs : Routes synaptic inputs to receptor ports.
        """
        return self._n_receptors

    def _validate_parameters(self):
        r"""Validate the model configuration, raising ``ValueError`` on bad values.

        The checks run in a fixed order and stop at the first failure:

        1. The three mechanism flags must form one of the five GLIF variants.
        2. If ``is_tracer`` reports that ``V_reset`` or ``C_m`` is a tracer,
           every remaining check is skipped.
        3. Reset potential below threshold; capacitance, conductance and
           refractory time strictly positive.
        4. Mechanism-specific parameters (spike threshold, after-spike
           currents, voltage threshold), only for enabled mechanisms.
        5. Per-receptor synaptic parameter lengths and signs.
        6. ``gsl_error_tol`` strictly positive.

        Raises
        ------
        ValueError
            If any of the checks above fails.
        """
        # Only five combinations of the three mechanism switches are GLIF models.
        mechanisms = (self.has_theta_spike, self.has_asc, self.has_theta_voltage)
        allowed_mechanisms = (
            (False, False, False),  # GLIF1
            (True, False, False),  # GLIF2
            (False, True, False),  # GLIF3
            (True, True, False),  # GLIF4
            (True, True, True),  # GLIF5
        )
        if mechanisms not in allowed_mechanisms:
            raise ValueError(
                "Incorrect model mechanism combination. "
                "Valid combinations: GLIF1(FFF), GLIF2(TFF), GLIF3(FTF), "
                "GLIF4(TTF), GLIF5(TTT). Got spike_dependent_threshold=%s, "
                "after_spike_currents=%s, adapting_threshold=%s." % mechanisms
            )

        # Value checks need concrete numbers; bail out for traced parameters
        # (e.g. during jit).
        if is_tracer(self.V_reset) or is_tracer(self.C_m):
            return

        # Compare reset and threshold after shifting both by E_L.
        resting = self.E_L
        if np.any(self.V_reset - resting >= self.V_th - resting):
            raise ValueError("Reset potential must be smaller than threshold.")

        # Membrane quantities that must be strictly positive, each compared
        # against a zero carrying its own unit.
        strictly_positive = (
            (self.C_m, u.pF, "Capacitance must be strictly positive."),
            (self.g_m, u.nS, "Membrane conductance must be strictly positive."),
            (self.t_ref, u.ms, "Refractory time constant must be strictly positive."),
        )
        for quantity, unit, message in strictly_positive:
            if np.any(quantity <= 0.0 * unit):
                raise ValueError(message)

        if self.has_theta_spike:
            if self.th_spike_decay <= 0.0:
                raise ValueError("Spike induced threshold time constant must be strictly positive.")
            # Written as "not (lo <= x <= hi)" so that NaN is rejected as well.
            if not (0.0 <= self.voltage_reset_fraction <= 1.0):
                raise ValueError("Voltage fraction coefficient following spike must be within [0.0, 1.0].")

        if self.has_asc:
            expected = len(self.asc_decay)
            companions = (self.asc_init, self.asc_amps, self.asc_r)
            if any(len(values) != expected for values in companions):
                raise ValueError(
                    "All after spike current parameters (asc_init, asc_decay, asc_amps, asc_r) "
                    "must have the same size."
                )
            if any(rate <= 0.0 for rate in self.asc_decay):
                raise ValueError("After-spike current time constant must be strictly positive.")
            if any(not (0.0 <= fraction <= 1.0) for fraction in self.asc_r):
                raise ValueError(
                    "After spike current fraction coefficients r must be within [0.0, 1.0]."
                )

        if self.has_theta_voltage and self.th_voltage_decay <= 0.0:
            raise ValueError("Voltage-induced threshold time constant must be strictly positive.")

        # The fast time constants define the receptor count; the slow
        # parameters must line up with them one-to-one.
        n_rec = len(self.tau_syn_fast)
        if len(self.tau_syn_slow) != n_rec:
            raise ValueError(
                f"tau_syn_slow must have same length as tau_syn_fast ({n_rec}), "
                f"got {len(self.tau_syn_slow)}."
            )
        if len(self.amp_slow) != n_rec:
            raise ValueError(
                f"amp_slow must have same length as tau_syn_fast ({n_rec}), "
                f"got {len(self.amp_slow)}."
            )

        if any(tau <= 0.0 for tau in self.tau_syn_fast):
            raise ValueError("All fast synaptic time constants must be strictly positive.")
        if any(tau <= 0.0 for tau in self.tau_syn_slow):
            raise ValueError("All slow synaptic time constants must be strictly positive.")
        if any(amp <= 0.0 for amp in self.amp_slow):
            raise ValueError("All slow synaptic amplitudes must be strictly positive.")

        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

def init_state(self, **kwargs):
    r"""Initialize all state variables for the neuron population.

    Creates the membrane potential, the per-receptor synaptic current states
    (fast and slow alpha-function components), the threshold components, the
    after-spike current values, the refractory counter, and the buffered
    input current. It also pre-computes every step-size dependent constant
    used by ``update``: decay factors and the exact propagator entries.

    The GLIF-specific dynamic variables are stored as ``HiddenState`` arrays.
    The pre-computed constants are scalars evaluated once here with NumPy,
    so they depend on the ``dt`` in effect when this method is called.

    Parameters
    ----------
    **kwargs
        Unused compatibility parameters accepted by the base-state API.
    """
    ditype = brainstate.environ.ditype()
    dftype = brainstate.environ.dftype()
    dt = brainstate.environ.get_dt()
    dt_ms = float(np.asarray(u.get_mantissa(dt / u.ms)))

    V = braintools.init.param(self.V_initializer, self.varshape)
    self.V = brainstate.HiddenState(V)

    # Per-receptor alpha-function current states: fast component
    # (y1_fast rate pA/ms, y2_fast current pA)
    self.y1_fast = [
        brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(0.0 * u.pA / u.ms), self.varshape)
        )
        for _ in range(self._n_receptors)
    ]
    self.y2_fast = [
        brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(0.0 * u.pA), self.varshape)
        )
        for _ in range(self._n_receptors)
    ]

    # Per-receptor alpha-function current states: slow component
    # (y1_slow rate pA/ms, y2_slow current pA)
    self.y1_slow = [
        brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(0.0 * u.pA / u.ms), self.varshape)
        )
        for _ in range(self._n_receptors)
    ]
    self.y2_slow = [
        brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(0.0 * u.pA), self.varshape)
        )
        for _ in range(self._n_receptors)
    ]

    self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
    self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
    self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))

    # GLIF-specific state as HiddenState (JAX-traceable, compatible with for_loop)
    n_asc = len(self.asc_decay)
    self._asc_states = [
        brainstate.HiddenState(jnp.full(self.varshape, self.asc_init[a], dtype=dftype))
        for a in range(n_asc)
    ]

    # Threshold components (relative to E_L) as HiddenState
    E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV)))
    th_inf = float(np.asarray(u.get_mantissa(self.V_th / u.mV))) - E_L_mV
    self._th_inf = th_inf
    self._threshold_spike_state = brainstate.HiddenState(
        jnp.zeros(self.varshape, dtype=dftype)
    )
    self._threshold_voltage_state = brainstate.HiddenState(
        jnp.zeros(self.varshape, dtype=dftype)
    )
    # The total threshold starts at its resting value because both adaptive
    # components start at zero.
    self._threshold_state = brainstate.HiddenState(
        jnp.full(self.varshape, th_inf, dtype=dftype)
    )

    # Pre-compute decay rates (scalar constants, computed once per init_state call)
    G = float(np.asarray(u.get_mantissa(self.g_m / u.nS)))
    C_m_val = float(np.asarray(u.get_mantissa(self.C_m / u.pF)))
    t_ref_ms = float(np.asarray(u.get_mantissa(self.t_ref / u.ms)))

    # Each group below exists only when its mechanism is enabled; ``update``
    # guards every use with the same flag.
    if self.has_theta_spike:
        self._decay_spike = np.exp(-self.th_spike_decay * dt_ms)
        self._decay_spike_refr = np.exp(-self.th_spike_decay * t_ref_ms)

    if self.has_asc:
        self._asc_decay_rates = [np.exp(-self.asc_decay[a] * dt_ms) for a in range(n_asc)]
        self._asc_stable_coeff = [
            ((1.0 / self.asc_decay[a]) / dt_ms) * (1.0 - self._asc_decay_rates[a])
            for a in range(n_asc)
        ]
        self._asc_refr_decay_rates = [
            self.asc_r[a] * np.exp(-self.asc_decay[a] * t_ref_ms)
            for a in range(n_asc)
        ]

    if self.has_theta_voltage:
        self._potential_decay_rate = np.exp(-G * dt_ms / C_m_val)
        self._theta_voltage_decay_rate_inv = 1.0 / np.exp(self.th_voltage_decay * dt_ms)
        self._phi = self.th_voltage_index / (self.th_voltage_decay - G / C_m_val)
        self._abpara_ratio = self.th_voltage_index / self.th_voltage_decay

    # Pre-compute exact propagator matrices (NEST IAFPropagatorAlpha scheme)
    tau_m = C_m_val / G  # membrane time constant in ms
    self._P33 = np.exp(-dt_ms / tau_m)
    self._P30 = (1.0 / C_m_val) * (1.0 - self._P33) * tau_m  # mV/pA

    self._P11_fast = []
    self._P21_fast = []
    self._P22_fast = []
    self._P31_fast = []
    self._P32_fast = []
    self._PSCInitialValues_fast = []

    self._P11_slow = []
    self._P21_slow = []
    self._P22_slow = []
    self._P31_slow = []
    self._P32_slow = []
    self._PSCInitialValues_slow = []

    for k in range(self._n_receptors):
        # Fast component
        p11_f = np.exp(-dt_ms / self.tau_syn_fast[k])
        self._P11_fast.append(p11_f)
        self._P22_fast.append(p11_f)
        self._P21_fast.append(dt_ms * p11_f)
        p31_f, p32_f = alpha_propagator_p31_p32(self.tau_syn_fast[k], tau_m, C_m_val, dt_ms)
        self._P31_fast.append(float(p31_f))
        self._P32_fast.append(float(p32_f))
        # e / tau makes a unit-weight input produce an alpha current that
        # peaks at 1 (y2(t) = y1_0 * t * exp(-t / tau), maximal at t = tau).
        self._PSCInitialValues_fast.append(np.e / self.tau_syn_fast[k])

        # Slow component
        p11_s = np.exp(-dt_ms / self.tau_syn_slow[k])
        self._P11_slow.append(p11_s)
        self._P22_slow.append(p11_s)
        self._P21_slow.append(dt_ms * p11_s)
        p31_s, p32_s = alpha_propagator_p31_p32(self.tau_syn_slow[k], tau_m, C_m_val, dt_ms)
        self._P31_slow.append(float(p31_s))
        self._P32_slow.append(float(p32_s))
        # Same normalization as the fast component, scaled by amp_slow.
        self._PSCInitialValues_slow.append(np.e / self.tau_syn_slow[k] * self.amp_slow[k])
# Backward-compatible properties for threshold components
@property
def _threshold(self):
    """Current value held by ``_threshold_state`` (the total threshold)."""
    return self._threshold_state.value

@property
def _threshold_spike(self):
    """Current value held by ``_threshold_spike_state``."""
    return self._threshold_spike_state.value

@property
def _threshold_voltage(self):
    """Current value held by ``_threshold_voltage_state``."""
    return self._threshold_voltage_state.value
def get_spike(self, V: ArrayLike = None):
    r"""Compute a surrogate-gradient spike signal from the membrane potential.

    Applies ``self.spk_fun`` to a normalized membrane potential. The method
    reads ``self.V`` only when ``V`` is omitted and writes no state, which
    makes it suitable for inspection or custom integration schemes.

    **Scaling**

    .. math::

        v_\mathrm{scaled} = \frac{V - V_\mathrm{th}}{V_\mathrm{th} - V_\mathrm{reset}}

    This normalization ensures that:

    - When :math:`V = V_\mathrm{th}`, :math:`v_\mathrm{scaled} = 0`
    - When :math:`V = V_\mathrm{reset}`, :math:`v_\mathrm{scaled} = -1`

    Parameters
    ----------
    V : ArrayLike, optional
        Membrane potential (absolute). If ``None`` (default), the current
        state ``self.V.value`` is used. It is combined arithmetically with
        ``self.V_th`` and ``self.V_reset``, so it must be compatible with
        them in shape and unit.

    Returns
    -------
    spike
        Whatever ``self.spk_fun`` returns for the scaled potential.

    Notes
    -----
    - The normalization uses the static ``V_th`` parameter, not the adaptive
      threshold tracked in ``_threshold_state``.
    - ``update`` in this class does not call this method: it compares the
      new membrane potential against the adaptive threshold directly and
      returns a hard 0/1 array.

    See Also
    --------
    update : Main simulation step.
    braintools.surrogate.ReluGrad : Example surrogate gradient function.
    """
    V = self.V.value if V is None else V
    v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
    return self.spk_fun(v_scaled)
def _collect_receptor_delta_inputs(self):
    r"""Collect delta inputs per receptor port using label-based routing.

    For each receptor index ``k`` in ``range(self._n_receptors)`` this calls
    ``self.sum_delta_inputs`` with a zero current array of shape
    ``self.varshape`` as the initial value and the label ``'receptor_k'``.

    Returns
    -------
    list
        One entry per receptor port, in port order, holding the result of
        ``sum_delta_inputs`` for that port (current jumps in pA).
    """
    dftype = brainstate.environ.dftype()
    return [
        self.sum_delta_inputs(
            jnp.zeros(self.varshape, dtype=dftype) * u.pA,
            label=f'receptor_{k}',
        )
        for k in range(self._n_receptors)
    ]
def update(self, x=0.0 * u.pA):
    r"""Perform a single simulation step using exact propagator matrices.

    Implements the NEST ``glif_psc_double_alpha`` update using the exact
    IAFPropagatorAlpha integration scheme. All GLIF-specific discrete updates
    (threshold decay, ASC, voltage-dependent threshold) are applied as
    vectorised JAX operations, making this method compatible with
    ``brainstate.transform.for_loop``.

    The step proceeds in this order: pre-integration threshold and
    after-spike-current updates, membrane propagation from the *old*
    synaptic states, spike detection and resets, synaptic state propagation,
    injection of newly arrived synaptic input, and finally the write-back of
    all state.

    Parameters
    ----------
    x : ArrayLike, optional
        External current input (pA), applied with one-step delay: it is
        buffered into ``I_stim`` at the end of this call and the value
        buffered by the previous call is what drives this step.
        Default: 0.0 pA.

    Returns
    -------
    spike : jax.Array
        Binary spike tensor (float32), shape ``(*varshape)``. This is a hard
        threshold result; ``spk_fun`` is not applied here.
    """
    t = brainstate.environ.get('t')
    dt_q = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # Python-level constants (concrete, not JAX-traced)
    E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV)))
    I_e_pA = float(np.asarray(u.get_mantissa(self.I_e / u.pA)))
    V_reset_rel = float(np.asarray(u.get_mantissa(self.V_reset / u.mV))) - E_L_mV
    G_nS = float(np.asarray(u.get_mantissa(self.g_m / u.nS)))

    # JAX state (traced under for_loop)
    r = self.refractory_step_count.value  # int array, varshape
    i_stim_pA = u.get_mantissa(self.I_stim.value / u.pA)  # float array, varshape

    # V_rel (old, before this step's update)
    V_rel = jax.lax.stop_gradient(
        u.get_mantissa(self.V.value / u.mV) - E_L_mV
    )  # plain JAX array, mV relative to E_L

    # Buffer new external current (one-step delay)
    new_i_stim_q = self.sum_current_inputs(x, self.V.value)

    is_refractory = r > 0
    i_ext = I_e_pA + i_stim_pA  # pA, plain JAX array

    # ---- Pre-integration GLIF updates (vectorised JAX) ----
    n_asc = len(self.asc_decay)

    # 1. Spike threshold decay (non-refractory only)
    if self.has_theta_spike:
        tspk = self._threshold_spike_state.value
        tspk = jnp.where(is_refractory, tspk, tspk * self._decay_spike)
    else:
        tspk = jnp.zeros(self.varshape, dtype=dftype)

    # 2. ASC stable-coeff sum + decay (non-refractory only)
    if self.has_asc:
        asc_sum_new = jnp.zeros(self.varshape, dtype=dftype)
        asc_decayed = []
        for a in range(n_asc):
            asc_a = self._asc_states[a].value
            asc_sum_new = asc_sum_new + self._asc_stable_coeff[a] * asc_a
            asc_decayed.append(asc_a * self._asc_decay_rates[a])
        asc_sum = jnp.where(is_refractory, jnp.zeros(self.varshape, dtype=dftype), asc_sum_new)
    else:
        asc_sum = jnp.zeros(self.varshape, dtype=dftype)
        asc_decayed = []

    # 3. Voltage-dependent threshold (non-refractory only, using old V_rel)
    if self.has_theta_voltage:
        tvlt = self._threshold_voltage_state.value
        beta = (i_ext + asc_sum) / G_nS  # pA/nS = mV
        tvlt_new = (
            self._phi * (V_rel - beta) * self._potential_decay_rate
            + self._theta_voltage_decay_rate_inv * (
                tvlt - self._phi * (V_rel - beta) - self._abpara_ratio * beta
            )
            + self._abpara_ratio * beta
        )
        tvlt = jnp.where(is_refractory, tvlt, tvlt_new)
    else:
        tvlt = jnp.zeros(self.varshape, dtype=dftype)

    # 4. Total threshold
    threshold = tspk + tvlt + self._th_inf

    # ---- Exact propagator update for V and y1/y2 fast/slow ----
    # Read y1/y2 old values (stripped of units -> plain floats matching NEST convention)
    y1f_old = [u.get_mantissa(self.y1_fast[k].value / (u.pA / u.ms)) for k in range(self._n_receptors)]
    y2f_old = [u.get_mantissa(self.y2_fast[k].value / u.pA) for k in range(self._n_receptors)]
    y1s_old = [u.get_mantissa(self.y1_slow[k].value / (u.pA / u.ms)) for k in range(self._n_receptors)]
    y2s_old = [u.get_mantissa(self.y2_slow[k].value / u.pA) for k in range(self._n_receptors)]

    # 5. V update via exact propagator
    v_new = V_rel * self._P33 + (i_ext + asc_sum) * self._P30
    for k in range(self._n_receptors):
        v_new = (v_new
                 + self._P31_fast[k] * y1f_old[k] + self._P32_fast[k] * y2f_old[k]
                 + self._P31_slow[k] * y1s_old[k] + self._P32_slow[k] * y2s_old[k])
    # Clamp refractory neurons to old V_rel
    v_new = jnp.where(is_refractory, V_rel, v_new)

    # 6. Spike check (non-refractory only, uses threshold from step 4)
    spiked = (v_new > threshold) & ~is_refractory

    # 7. ASC reset on spike
    if self.has_asc:
        for a in range(n_asc):
            asc_a = self._asc_states[a].value
            asc_reset = self.asc_amps[a] + asc_decayed[a] * self._asc_refr_decay_rates[a]
            # Spiking neurons take the reset value, refractory neurons keep
            # their old value, all others take the decayed value.
            self._asc_states[a].value = jnp.where(
                spiked, asc_reset,
                jnp.where(is_refractory, asc_a, asc_decayed[a])
            )

    # 8. Voltage reset on spike
    if not self.has_theta_spike:
        # GLIF1/3: simple reset
        V_final_rel = jnp.where(spiked, V_reset_rel, v_new)
    else:
        # GLIF2/4/5: biologically defined reset
        V_reset_bio = self.voltage_reset_fraction * V_rel + self.voltage_reset_add
        V_final_rel = jnp.where(spiked, V_reset_bio, v_new)
        # 9. Theta_spike reset on spike
        # Kept inside this branch: _decay_spike_refr is only created by
        # init_state when has_theta_spike is set.
        tspk_reset = tspk * self._decay_spike_refr + self.th_spike_add
        tspk = jnp.where(spiked, tspk_reset, tspk)
        threshold = jnp.where(spiked, tspk + tvlt + self._th_inf, threshold)

    # 10. Refractory counter
    r_new = jnp.where(
        spiked, self.ref_count,
        jnp.where(is_refractory, r - 1, r)
    )

    # 11. Y1/Y2 propagator update (unconditional - all neurons, including refractory)
    y1f_new = [self._P11_fast[k] * y1f_old[k] for k in range(self._n_receptors)]
    y2f_new = [self._P21_fast[k] * y1f_old[k] + self._P22_fast[k] * y2f_old[k]
               for k in range(self._n_receptors)]
    y1s_new = [self._P11_slow[k] * y1s_old[k] for k in range(self._n_receptors)]
    y2s_new = [self._P21_slow[k] * y1s_old[k] + self._P22_slow[k] * y2s_old[k]
               for k in range(self._n_receptors)]

    # 12. Collect and apply synaptic delta inputs to y1
    dy_input = self._collect_receptor_delta_inputs()
    for k in range(self._n_receptors):
        w_k = u.get_mantissa(dy_input[k] / u.pA)  # weight in pA
        y1f_new[k] = y1f_new[k] + self._PSCInitialValues_fast[k] * w_k
        y1s_new[k] = y1s_new[k] + self._PSCInitialValues_slow[k] * w_k

    # ---- Write back all state ----
    self.V.value = (V_final_rel + E_L_mV) * u.mV
    for k in range(self._n_receptors):
        self.y1_fast[k].value = y1f_new[k] * (u.pA / u.ms)
        self.y2_fast[k].value = y2f_new[k] * u.pA
        self.y1_slow[k].value = y1s_new[k] * (u.pA / u.ms)
        self.y2_slow[k].value = y2s_new[k] * u.pA
    self._threshold_spike_state.value = tspk
    self._threshold_voltage_state.value = tvlt
    self._threshold_state.value = threshold
    self.refractory_step_count.value = jnp.asarray(r_new, dtype=ditype)
    # Adding a zero array broadcasts the buffered current to varshape.
    self.I_stim.value = new_i_stim_q + u.math.zeros(self.varshape) * u.pA
    last_spike_time = u.math.where(spiked, t + dt_q, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    return jnp.asarray(spiked, dtype=jnp.float32)
def get_I_syn(self):
    r"""Get the total synaptic current summed across all receptor ports and components.

    Sums the fast and slow alpha-function current states (``y2_fast`` and
    ``y2_slow``) of every receptor port, starting from ``0.0 pA``. No state
    is modified.

    Returns
    -------
    I_syn
        Total synaptic current (fast + slow) over all receptors, in pA.

    Notes
    -----
    For a population with :math:`N` receptor ports, the total current is:

    .. math::

        I_\mathrm{syn,total} = \sum_{k=0}^{N-1} \left( y_{2,k}^\mathrm{fast} + y_{2,k}^\mathrm{slow} \right)

    See Also
    --------
    get_I_syn_fast : Get only the fast component.
    get_I_syn_slow : Get only the slow component.
    """
    I_syn = 0.0 * u.pA
    # Accumulate port by port, fast before slow.
    for fast_state, slow_state in zip(self.y2_fast, self.y2_slow):
        I_syn = I_syn + fast_state.value + slow_state.value
    return I_syn
def get_I_syn_fast(self):
    r"""Get the fast component of synaptic current summed across all receptor ports.

    Sums the fast alpha-function current states (``y2_fast``) of every
    receptor port, starting from ``0.0 pA``. No state is modified. The fast
    component corresponds to the ``tau_syn_fast`` time constants.

    Returns
    -------
    I_syn_fast
        Fast synaptic current over all receptors, in pA.

    Notes
    -----
    For a population with :math:`N` receptor ports, the fast current is:

    .. math::

        I_\mathrm{syn,fast} = \sum_{k=0}^{N-1} y_{2,k}^\mathrm{fast}

    See Also
    --------
    get_I_syn : Get total synaptic current (fast + slow).
    get_I_syn_slow : Get only the slow component.
    """
    I_syn = 0.0 * u.pA
    for fast_state in self.y2_fast:
        I_syn = I_syn + fast_state.value
    return I_syn
def get_I_syn_slow(self):
    r"""Get the slow component of synaptic current summed across all receptor ports.

    Sums the slow alpha-function current states (``y2_slow``) of every
    receptor port, starting from ``0.0 pA``. No state is modified. The slow
    component corresponds to the ``tau_syn_slow`` time constants, with its
    input scaled by ``amp_slow``.

    Returns
    -------
    I_syn_slow
        Slow synaptic current over all receptors, in pA.

    Notes
    -----
    For a population with :math:`N` receptor ports, the slow current is:

    .. math::

        I_\mathrm{syn,slow} = \sum_{k=0}^{N-1} y_{2,k}^\mathrm{slow}

    See Also
    --------
    get_I_syn : Get total synaptic current (fast + slow).
    get_I_syn_fast : Get only the fast component.
    """
    I_syn = 0.0 * u.pA
    for slow_state in self.y2_slow:
        I_syn = I_syn + slow_state.value
    return I_syn