# Source code for brainpy_state._nest.glif_cond

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable, Sequence

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'glif_cond',
]


class glif_cond(NESTNeuron):
    r"""Conductance-based generalized leaky integrate-and-fire (GLIF) neuron model.

    Implements the five-level GLIF model hierarchy from Teeter et al. (2018) [1]_,
    with conductance-based alpha-function synapses and adaptive RKF45 integration.
    Designed for fitting to Allen Institute single-neuron electrophysiology data.
    Supports multiple receptor ports with distinct reversal potentials and synaptic
    time constants.

    **Model Selection**

    The five GLIF variants are:

    1. **GLIF1 (LIF)** — Traditional leaky integrate-and-fire
    2. **GLIF2 (LIF_R)** — LIF with biologically defined voltage reset rules
    3. **GLIF3 (LIF_ASC)** — LIF with after-spike currents (adaptation)
    4. **GLIF4 (LIF_R_ASC)** — LIF with reset rules and after-spike currents
    5. **GLIF5 (LIF_R_ASC_A)** — LIF with reset rules, after-spike currents, and
       voltage-dependent threshold

    Model mechanism selection is controlled by three boolean parameters:

    +--------+---------------------------+----------------------+--------------------+
    | Model  | spike_dependent_threshold | after_spike_currents | adapting_threshold |
    +========+===========================+======================+====================+
    | GLIF1  | False                     | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF2  | True                      | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF3  | False                     | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF4  | True                      | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF5  | True                      | True                 | True               |
    +--------+---------------------------+----------------------+--------------------+

    Mathematical Formulation
    ------------------------

    **1. Membrane Dynamics**

    The membrane potential :math:`V` (tracked relative to :math:`E_L` internally)
    evolves according to:

    .. math::

       C_\mathrm{m} \frac{dV}{dt} = -g \cdot V
           - \sum_k g_k(t) \left( V + E_L - E_{\mathrm{rev},k} \right)
           + I_\mathrm{e} + I_\mathrm{ASC,sum}

    where:

    * :math:`g` — membrane (leak) conductance
    * :math:`g_k(t)` — synaptic conductance for receptor port :math:`k`
    * :math:`E_{\mathrm{rev},k}` — reversal potential for port :math:`k`
    * :math:`I_\mathrm{e}` — constant external current
    * :math:`I_\mathrm{ASC,sum}` — sum of after-spike currents (GLIF3/4/5 only)

    **2. Synaptic Conductances (Alpha Function)**

    Each receptor port :math:`k` has a conductance modeled by an alpha function
    with two state variables :math:`dg_k` and :math:`g_k`:

    .. math::

       \frac{d(dg_k)}{dt} = -\frac{dg_k}{\tau_{\mathrm{syn},k}}

    .. math::

       \frac{dg_k}{dt} = dg_k - \frac{g_k}{\tau_{\mathrm{syn},k}}

    On a presynaptic spike with weight :math:`w`, the derivative is incremented:

    .. math::

       dg_k \leftarrow dg_k + w \cdot \frac{e}{\tau_{\mathrm{syn},k}}

    This normalization ensures that a spike of weight 1.0 produces a peak conductance
    of 1 nS at time :math:`t = \tau_{\mathrm{syn},k}`.

    **3. After-Spike Currents (GLIF3/4/5)**

    After-spike currents (ASC) model spike-triggered adaptation as exponentially
    decaying currents. Each ASC component :math:`I_j` decays with rate :math:`k_j`:

    .. math::

       I_j(t+dt) = I_j(t) \cdot \exp(-k_j \cdot dt)

    The time-averaged ASC over a simulation step uses the exact integral (stable
    coefficient method):

    .. math::

       \bar{I}_j = \frac{1 - \exp(-k_j \cdot dt)}{k_j \cdot dt} \cdot I_j(t)

    On spike, ASC values are updated with amplitude and refractory decay:

    .. math::

       I_j \leftarrow \Delta I_j + I_j \cdot r_j \cdot \exp(-k_j \cdot t_\mathrm{ref})

    where :math:`\Delta I_j` is the amplitude jump and :math:`r_j \in [0, 1]` is
    the retention fraction.

    **4. Spike-Dependent Threshold (GLIF2/4/5)**

    The spike component of the threshold :math:`\theta_s` decays exponentially:

    .. math::

       \theta_s(t+dt) = \theta_s(t) \cdot \exp(-b_s \cdot dt)

    On spike, after accounting for refractory decay, it is incremented:

    .. math::

       \theta_s \leftarrow \theta_s \cdot \exp(-b_s \cdot t_\mathrm{ref})
           + \Delta\theta_s

    Voltage reset with spike-dependent threshold uses:

    .. math::

       V \leftarrow f_v \cdot V_\mathrm{old} + V_\mathrm{add}

    where :math:`f_v \in [0, 1]` is the dimensionless fraction coefficient and
    :math:`V_\mathrm{add}` is the additive term (in mV, passed as a plain float
    following the NEST convention).

    **5. Voltage-Dependent Threshold (GLIF5)**

    The voltage component :math:`\theta_v` evolves according to:

    .. math::

       \theta_v(t+dt) = \phi \cdot (V_\mathrm{old} - \beta) \cdot P_\mathrm{decay}
           + \frac{1}{P_{\theta,v}} \cdot \left(\theta_v(t)
               - \phi \cdot (V_\mathrm{old} - \beta)
               - \frac{a_v}{b_v} \cdot \beta \right)
           + \frac{a_v}{b_v} \cdot \beta

    where:

    * :math:`\phi = a_v / (b_v - g/C_m)`
    * :math:`P_\mathrm{decay} = \exp(-g \cdot dt / C_m)`
    * :math:`P_{\theta,v} = \exp(b_v \cdot dt)`
    * :math:`\beta = (I_e + I_\mathrm{ASC,sum}) / g`

    The total threshold is the sum of all components:

    .. math::

       \theta = \theta_\infty + \theta_s + \theta_v

    Spike condition (checked after ODE integration):

    .. math::

       V > \theta

    **Numerical Integration**

    The ODE system :math:`[V, dg_0, g_0, dg_1, g_1, \ldots]` is integrated using
    an adaptive RKF45(4,5) Runge-Kutta-Fehlberg method with error tolerance
    ``ATOL = 1e-3`` and minimum step size ``MIN_H = 1e-8`` ms, matching NEST's
    GSL integrator behavior.

    **Update Order (Per Simulation Step)**

    1. Record :math:`V_\mathrm{old}` (relative to :math:`E_L`)
    2. Integrate ODE system over :math:`(t, t+dt]` using RKF45
    3. If not refractory:

       a. Decay spike threshold component :math:`\theta_s`
       b. Compute time-averaged ASC :math:`\bar{I}_\mathrm{ASC,sum}` and decay ASC values
       c. Compute voltage-dependent threshold :math:`\theta_v` (using :math:`V_\mathrm{old}`)
       d. Update total threshold :math:`\theta = \theta_\infty + \theta_s + \theta_v`
       e. If :math:`V > \theta`: emit spike, apply reset rules

    4. If refractory: decrement counter, clamp :math:`V` to :math:`V_\mathrm{old}`
    5. Add incoming spike conductance jumps (scaled by :math:`e/\tau_\mathrm{syn}`)
    6. Update external current buffer :math:`I_\mathrm{stim}`
    7. Save :math:`V_\mathrm{old}` for next step

    Parameters
    ----------
    in_size : Size
        Shape of the neuron population. Can be an int for 1D or tuple for multi-D.
    g : ArrayLike, optional
        Membrane (leak) conductance in nS. Broadcast to population shape.
        Default: 9.43 nS (from Allen Cell 490626718 GLIF5).
    E_L : ArrayLike, optional
        Resting membrane potential (leak reversal) in mV. Default: -78.85 mV.
    V_th : ArrayLike, optional
        Instantaneous spike threshold (absolute) in mV. Default: -51.68 mV.
        Internally, threshold is tracked relative to ``E_L``.
    C_m : ArrayLike, optional
        Membrane capacitance in pF. Must be strictly positive. Default: 58.72 pF.
    t_ref : ArrayLike, optional
        Absolute refractory period in ms. During this period, voltage is clamped
        and spike detection is disabled. Must be > 0. Default: 3.75 ms.
    V_reset : ArrayLike, optional
        Reset potential (absolute) in mV for GLIF1/3 models. Ignored if
        ``spike_dependent_threshold=True``. Default: -78.85 mV (same as ``E_L``).
    th_spike_add : float, optional
        Threshold additive constant :math:`\Delta\theta_s` after spike (mV,
        dimensionless in NEST units). Only used if ``spike_dependent_threshold=True``.
        Default: 0.37 mV.
    th_spike_decay : float, optional
        Spike threshold decay rate :math:`b_s` in 1/ms. Must be > 0 if
        ``spike_dependent_threshold=True``. Default: 0.009 /ms.
    voltage_reset_fraction : float, optional
        Voltage fraction coefficient :math:`f_v \in [0, 1]` after spike.
        Only used if ``spike_dependent_threshold=True``. Default: 0.20.
    voltage_reset_add : float, optional
        Voltage additive term :math:`V_\mathrm{add}` after spike (mV, dimensionless).
        Only used if ``spike_dependent_threshold=True``. Default: 18.51 mV.
    th_voltage_index : float, optional
        Voltage-dependent threshold leak :math:`a_v` in 1/ms. Only used if
        ``adapting_threshold=True``. Default: 0.005 /ms.
    th_voltage_decay : float, optional
        Voltage-dependent threshold decay rate :math:`b_v` in 1/ms. Must be > 0 if
        ``adapting_threshold=True``. Default: 0.09 /ms.
    asc_init : Sequence[float], optional
        Initial values of after-spike currents in pA. Tuple/list of length ``n_asc``.
        Default: (0.0, 0.0) pA.
    asc_decay : Sequence[float], optional
        ASC decay rates :math:`k_j` in 1/ms. All values must be > 0. Length must
        match ``asc_init``. Default: (0.003, 0.1) /ms.
    asc_amps : Sequence[float], optional
        ASC amplitude jumps :math:`\Delta I_j` on spike, in pA. Length must match
        ``asc_init``. Negative values cause hyperpolarizing adaptation. Default:
        (-9.18, -198.94) pA.
    asc_r : Sequence[float], optional
        ASC retention fraction coefficients :math:`r_j \in [0, 1]`. Length must
        match ``asc_init``. Default: (1.0, 1.0).
    tau_syn : Sequence[float], optional
        Synaptic alpha-function time constants :math:`\tau_{\mathrm{syn},k}` in ms,
        one per receptor port. All values must be > 0. Default: (0.2, 2.0) ms
        (fast excitatory, slow inhibitory).
    E_rev : Sequence[float], optional
        Synaptic reversal potentials :math:`E_{\mathrm{rev},k}` in mV, one per
        receptor port. Must have same length as ``tau_syn``. Default: (0.0, -85.0) mV
        (excitatory, inhibitory).
    spike_dependent_threshold : bool, optional
        Enable biologically defined voltage reset rules (GLIF2/4/5). Default: False.
    after_spike_currents : bool, optional
        Enable after-spike currents (adaptation) (GLIF3/4/5). Default: False.
    adapting_threshold : bool, optional
        Enable voltage-dependent threshold component (GLIF5 only). Requires
        ``spike_dependent_threshold=True`` and ``after_spike_currents=True``.
        Default: False.
    I_e : ArrayLike, optional
        Constant external current in pA. Broadcast to population shape. Default: 0.0 pA.
    V_initializer : Callable, optional
        Initializer for membrane potential. If None, defaults to ``Constant(E_L)``.
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation. Default: ``ReluGrad()``.
    spk_reset : str, optional
        Spike reset mode: ``'hard'`` (stop gradient) or ``'soft'`` (subtract threshold).
        Default: ``'hard'``.
    name : str, optional
        Name of the neuron population.


    Parameter Mapping
    -----------------

    =============================== =================== ========================================== =====================================================
    **Parameter**                   **Default**         **Math equivalent**                        **Description**
    =============================== =================== ========================================== =====================================================
    ``in_size``                     (required)                                                     Population shape
    ``g``                           9.43 nS             :math:`g`                                  Membrane (leak) conductance
    ``E_L``                         -78.85 mV           :math:`E_L`                                Resting membrane potential
    ``V_th``                        -51.68 mV           :math:`V_\mathrm{th}`                      Instantaneous threshold (absolute)
    ``C_m``                         58.72 pF            :math:`C_\mathrm{m}`                       Membrane capacitance
    ``t_ref``                       3.75 ms             :math:`t_\mathrm{ref}`                     Absolute refractory period
    ``V_reset``                     -78.85 mV           :math:`V_\mathrm{reset}`                   Reset potential (absolute; GLIF1/3)
    ``th_spike_add``                0.37 mV             :math:`\Delta\theta_s`                     Threshold additive constant after spike
    ``th_spike_decay``              0.009 /ms           :math:`b_s`                                Spike threshold decay rate
    ``voltage_reset_fraction``      0.20                :math:`f_v`                                Voltage fraction after spike
    ``voltage_reset_add``           18.51 mV            :math:`V_\mathrm{add}`                     Voltage additive after spike
    ``th_voltage_index``            0.005 /ms           :math:`a_v`                                Voltage-dependent threshold leak
    ``th_voltage_decay``            0.09 /ms            :math:`b_v`                                Voltage-dependent threshold decay rate
    ``asc_init``                    (0.0, 0.0) pA                                                  Initial values of ASC
    ``asc_decay``                   (0.003, 0.1) /ms    :math:`k_j`                                ASC time constants (decay rates)
    ``asc_amps``                    (-9.18, -198.94) pA :math:`\Delta I_j`                         ASC amplitudes on spike
    ``asc_r``                       (1.0, 1.0)          :math:`r_j`                                ASC fraction coefficient
    ``tau_syn``                     (0.2, 2.0) ms       :math:`\tau_{\mathrm{syn},k}`              Synaptic alpha-function time constants
    ``E_rev``                       (0.0, -85.0) mV     :math:`E_{\mathrm{rev},k}`                 Synaptic reversal potentials
    ``spike_dependent_threshold``   False                                                          Enable biologically defined reset (GLIF2/4/5)
    ``after_spike_currents``        False                                                          Enable after-spike currents (GLIF3/4/5)
    ``adapting_threshold``          False                                                          Enable voltage-dependent threshold (GLIF5)
    ``I_e``                         0.0 pA              :math:`I_e`                                Constant external current
    ``V_initializer``               Constant(E_L)                                                  Membrane potential initializer
    ``spk_fun``                     ReluGrad()                                                     Surrogate spike function
    ``spk_reset``                   ``'hard'``                                                     Reset mode
    =============================== =================== ========================================== =====================================================

    Attributes
    ----------
    V : HiddenState
        Membrane potential in mV (absolute, broadcast to population shape).
    g_syn : list[HiddenState]
        Synaptic conductances :math:`g_k` in nS, one per receptor port.
    dg_syn : list[HiddenState]
        Synaptic conductance derivative states :math:`dg_k` in nS/ms, one per receptor port.
    last_spike_time : ShortTermState
        Time of last spike in ms.
    refractory_step_count : ShortTermState
        Remaining refractory steps (int32), decremented each step.
    integration_step : ShortTermState
        Internal RKF45 adaptive step size in ms (updated per neuron).
    I_stim : ShortTermState
        Buffered external current in pA (applied with one-step delay).

    Notes
    -----
    **Implementation Details**

    * **Internal state convention**: Membrane potential is tracked relative to ``E_L``
      internally (matching NEST), but exposed as absolute values in mV.
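    * **Threshold components**: the spike, voltage and total threshold components are
      held in ``HiddenState`` variables (``_threshold_spike_state``,
      ``_threshold_voltage_state``, ``_threshold_state``) and exposed read-only through
      the ``_threshold_spike``, ``_threshold_voltage`` and ``_threshold`` properties;
      ``_th_inf`` is a Python float (``V_th - E_L`` in mV).
    * **After-spike currents**: each ASC component is a ``HiddenState`` of shape
      ``in_size`` stored in the list ``_asc_states``; ``_asc_sum_state`` is initialised
      to the sum of ``asc_init``.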
    * **Receptor port routing**: Delta inputs (from projections) with keys containing
      ``'receptor_<k>'`` (0-based) are routed to receptor port ``k``. Inputs without
      a receptor tag default to receptor 0.
    * **Stability constraint**: For GLIF2/4/5, the reset condition must satisfy:

      .. math::

          E_L + f_v \cdot (V_\mathrm{th} - E_L) + V_\mathrm{add} < V_\mathrm{th} + \Delta\theta_s

      Otherwise the neuron may spike continuously.

    * **Valid mechanism combinations**: Only the five combinations listed in the
      parameter table are valid. Other combinations will raise ``ValueError``.
    * **Adaptive integration**: RKF45 step size adapts per-neuron and is preserved
      across simulation steps.

    **Failure Modes**


    * Raises ``ValueError`` if parameter validation fails (invalid model combination,
      non-positive capacitance/conductance/time constants, mismatched sequence lengths).
    * Raises ``ValueError`` if ``V_reset >= V_th`` (relative to ``E_L``).
    * Integration may fail to converge if ``dt`` is too large relative to ``tau_syn``
      or if threshold parameters cause continuous spiking.

    **Default Parameters**

    Default parameter values are from GLIF Model 5 of Cell 490626718 from the
    `Allen Cell Type Database <https://celltypes.brain-map.org>`_, fitted to
    mouse visual cortex layer 5 pyramidal neuron electrophysiology.

    References
    ----------
    .. [1] Teeter C, Iyer R, Menon V, Gouwens N, Feng D, Berg J, Szafer A,
           Cain N, Zeng H, Hawrylycz M, Koch C, & Mihalas S (2018).
           Generalized leaky integrate-and-fire models classify multiple neuron
           types. Nature Communications 9:709.
           DOI: `10.1038/s41467-017-02717-4 <https://doi.org/10.1038/s41467-017-02717-4>`_
    .. [2] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. J. Comput. Neurosci. 16:159-175.
           DOI: `10.1023/B:JCNS.0000014108.03012.81 <https://doi.org/10.1023/B:JCNS.0000014108.03012.81>`_
    .. [3] NEST Simulator ``glif_cond`` model documentation and C++ source:
           ``models/glif_cond.h`` and ``models/glif_cond.cpp``.

    Examples
    --------
    **Example 1: GLIF1 (simple LIF) with dual-receptor synapses**

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> import brainstate as bts
        >>> # Create GLIF1 neuron (all mechanisms disabled)
        >>> neuron = bst.glif_cond(
        ...     100,
        ...     spike_dependent_threshold=False,
        ...     after_spike_currents=False,
        ...     adapting_threshold=False,
        ...     tau_syn=(0.2, 2.0),  # fast excitatory, slow inhibitory
        ...     E_rev=(0.0, -85.0)   # mV
        ... )
        >>> neuron.init_all_states()
        >>> # Stimulate with constant current
        >>> with bts.environ.context(dt=0.1 * u.ms):
        ...     for _ in range(100):
        ...         spike = neuron(200.0 * u.pA)

    **Example 2: GLIF5 (full model) with custom parameters**

    .. code-block:: python

        >>> # Create GLIF5 with all mechanisms enabled
        >>> neuron = bst.glif_cond(
        ...     (10, 10),  # 10x10 population
        ...     spike_dependent_threshold=True,
        ...     after_spike_currents=True,
        ...     adapting_threshold=True,
        ...     g=10.0 * u.nS,
        ...     C_m=100.0 * u.pF,
        ...     tau_syn=(0.5, 1.5, 5.0),  # three receptor ports
        ...     E_rev=(0.0, 0.0, -80.0)   # two excitatory, one inhibitory
        ... )
        >>> neuron.init_all_states()
        >>> print(neuron.n_receptors)  # 3

    **Example 3: Multi-receptor input routing**

    .. code-block:: python

        >>> from brainevent.nn import FixedProb
        >>> # Create projection targeting receptor 1
        >>> proj = bst.align_post_projection(
        ...     pre=pre_neurons,
        ...     post=glif_neurons,
        ...     comm=FixedProb(0.1, weight=0.5 * u.nS),
        ...     label='receptor_1'  # route to receptor port 1
        ... )

    See Also
    --------
    iaf_cond_exp : Simpler conductance-based LIF with exponential synapses
    gif_cond_exp_multisynapse : Generalized integrate-and-fire with exponential synapses
    glif_psc : Current-based GLIF variant
    """
    __module__ = 'brainpy.state'

    _ATOL = 1e-3
    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 10000

    def __init__(
        self,
        in_size: Size,
        g: ArrayLike = 9.43 * u.nS,
        E_L: ArrayLike = -78.85 * u.mV,
        V_th: ArrayLike = -51.68 * u.mV,
        C_m: ArrayLike = 58.72 * u.pF,
        t_ref: ArrayLike = 3.75 * u.ms,
        V_reset: ArrayLike = -78.85 * u.mV,
        th_spike_add: float = 0.37,  # mV
        th_spike_decay: float = 0.009,  # 1/ms
        voltage_reset_fraction: float = 0.20,
        voltage_reset_add: float = 18.51,  # mV
        th_voltage_index: float = 0.005,  # 1/ms
        th_voltage_decay: float = 0.09,  # 1/ms
        asc_init: Sequence[float] = (0.0, 0.0),  # pA
        asc_decay: Sequence[float] = (0.003, 0.1),  # 1/ms
        asc_amps: Sequence[float] = (-9.18, -198.94),  # pA
        asc_r: Sequence[float] = (1.0, 1.0),
        tau_syn: Sequence[float] = (0.2, 2.0),  # ms
        E_rev: Sequence[float] = (0.0, -85.0),  # mV
        spike_dependent_threshold: bool = False,
        after_spike_currents: bool = False,
        adapting_threshold: bool = False,
        I_e: ArrayLike = 0.0 * u.pA,
        V_initializer: Callable = None,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
    ):
        # Stores all parameters on the instance, validates them, and builds the
        # adaptive RKF45 integrator. State variables are created in init_state().
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        as_param = braintools.init.param
        shape = self.varshape

        # Unit-carrying membrane parameters.
        self.g_m = as_param(g, shape)
        self.E_L = as_param(E_L, shape)
        self.C_m = as_param(C_m, shape)
        self.t_ref = as_param(t_ref, shape)
        self.I_e = as_param(I_e, shape)

        # Threshold and reset potentials are kept as given (absolute values).
        self.V_th = as_param(V_th, shape)
        self.V_reset = as_param(V_reset, shape)

        # Scalar GLIF coefficients, coerced to plain Python floats.
        self.th_spike_add = float(th_spike_add)
        self.th_spike_decay = float(th_spike_decay)
        self.voltage_reset_fraction = float(voltage_reset_fraction)
        self.voltage_reset_add = float(voltage_reset_add)
        self.th_voltage_index = float(th_voltage_index)
        self.th_voltage_decay = float(th_voltage_decay)

        # After-spike-current coefficients, frozen as tuples of floats.
        self.asc_init = tuple(map(float, asc_init))
        self.asc_decay = tuple(map(float, asc_decay))
        self.asc_amps = tuple(map(float, asc_amps))
        self.asc_r = tuple(map(float, asc_r))

        # Per-receptor synaptic coefficients, frozen as tuples of floats.
        self.tau_syn = tuple(map(float, tau_syn))
        self.E_rev = tuple(map(float, E_rev))

        # Flags selecting which GLIF mechanisms are active.
        self.has_theta_spike = bool(spike_dependent_threshold)
        self.has_asc = bool(after_spike_currents)
        self.has_theta_voltage = bool(adapting_threshold)

        # Without an explicit initializer the membrane potential starts at E_L.
        self.V_initializer = (
            braintools.init.Constant(E_L) if V_initializer is None else V_initializer
        )

        # One receptor port per synaptic time constant.
        self._n_receptors = len(self.tau_syn)

        self._validate_parameters()

        # The simulation step is read once and shared by the integrator and the
        # refractory step count below.
        step = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self._ATOL,
            dt=step,
        )

        # Refractory period expressed in whole simulation steps (rounded up).
        int_dtype = brainstate.environ.ditype()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / step), dtype=int_dtype)

    @property
    def n_receptors(self):
        r"""Count of synaptic receptor ports of this population.

        Returns
        -------
        int
            The value cached in ``__init__`` as ``len(tau_syn)``.
        """
        return self._n_receptors

    def _validate_parameters(self):
        r"""Validate the model configuration.

        Checks run in three stages:

        1. The mechanism-flag combination must be one of GLIF1-GLIF5.
        2. Array-valued parameters (``V_reset``, ``V_th``, ``E_L``, ``C_m``,
           ``g_m``, ``t_ref``) are range-checked, unless any of them is a JAX
           tracer, in which case this stage is skipped because their values
           cannot be inspected.
        3. Static parameters (Python floats and tuples) are always checked,
           including while tracing.

        Raises
        ------
        ValueError
            If any of the checks above fails.
        """
        mechanisms = (self.has_theta_spike, self.has_asc, self.has_theta_voltage)
        valid_combos = (
            (False, False, False),  # GLIF1
            (True, False, False),  # GLIF2
            (False, True, False),  # GLIF3
            (True, True, False),  # GLIF4
            (True, True, True),  # GLIF5
        )
        if mechanisms not in valid_combos:
            raise ValueError(
                "Incorrect model mechanism combination. "
                "Valid combinations: GLIF1(FFF), GLIF2(TFF), GLIF3(FTF), "
                "GLIF4(TTF), GLIF5(TTT). Got spike_dependent_threshold=%s, "
                "after_spike_currents=%s, adapting_threshold=%s." % mechanisms
            )

        # Value checks on array parameters are only possible for concrete
        # arrays. Every parameter used in a comparison below is inspected, so a
        # traced E_L / V_th / g_m / t_ref no longer reaches ``np.any``.
        array_params = (
            self.V_reset, self.V_th, self.E_L, self.C_m, self.g_m, self.t_ref,
        )
        if not any(is_tracer(p) for p in array_params):
            # V_reset (relative) < V_th (relative) — both relative to E_L
            V_reset_rel = self.V_reset - self.E_L
            V_th_rel = self.V_th - self.E_L
            if np.any(V_reset_rel >= V_th_rel):
                raise ValueError("Reset potential must be smaller than threshold.")

            if np.any(self.C_m <= 0.0 * u.pF):
                raise ValueError("Capacitance must be strictly positive.")
            if np.any(self.g_m <= 0.0 * u.nS):
                raise ValueError("Membrane conductance must be strictly positive.")
            if np.any(self.t_ref <= 0.0 * u.ms):
                raise ValueError("Refractory time constant must be strictly positive.")

        # Everything below involves only Python floats/tuples, so it is checked
        # regardless of tracing (previously skipped by the tracer early-return).
        if self.has_theta_spike:
            if self.th_spike_decay <= 0.0:
                raise ValueError("Spike induced threshold time constant must be strictly positive.")
            if not (0.0 <= self.voltage_reset_fraction <= 1.0):
                raise ValueError("Voltage fraction coefficient following spike must be within [0.0, 1.0].")

        if self.has_asc:
            n_asc = len(self.asc_decay)
            same_size = (
                len(self.asc_init) == n_asc
                and len(self.asc_amps) == n_asc
                and len(self.asc_r) == n_asc
            )
            if not same_size:
                raise ValueError(
                    "All after spike current parameters (asc_init, asc_decay, asc_amps, asc_r) "
                    "must have the same size."
                )
            if any(k_val <= 0.0 for k_val in self.asc_decay):
                raise ValueError("After-spike current time constant must be strictly positive.")
            if not all(0.0 <= r_val <= 1.0 for r_val in self.asc_r):
                raise ValueError(
                    "After spike current fraction coefficients r must be within [0.0, 1.0]."
                )

        if self.has_theta_voltage:
            if self.th_voltage_decay <= 0.0:
                raise ValueError("Voltage-induced threshold time constant must be strictly positive.")

        if len(self.tau_syn) != len(self.E_rev):
            raise ValueError(
                "tau_syn and E_rev must have the same size. "
                "Got %d and %d." % (len(self.tau_syn), len(self.E_rev))
            )

        if any(tau <= 0.0 for tau in self.tau_syn):
            raise ValueError("All synaptic time constants must be strictly positive.")

[docs] def init_state(self, **kwargs): r"""Initialize persistent and short-term state variables. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Raises ------ ValueError If an initializer cannot be broadcast to requested shape. TypeError If initializer outputs have incompatible units/dtypes for the corresponding state variables. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() dt_ms = float(np.asarray(u.get_mantissa(dt / u.ms))) V = braintools.init.param(self.V_initializer, self.varshape) self.V = brainstate.HiddenState(V) # Per-receptor alpha-function conductance states: dg and g self.g_syn = [ brainstate.HiddenState( braintools.init.param(braintools.init.Constant(0.0 * u.nS), self.varshape) ) for _ in range(self._n_receptors) ] self.dg_syn = [ brainstate.HiddenState( braintools.init.param(braintools.init.Constant(0.0 * u.nS / u.ms), self.varshape) ) for _ in range(self._n_receptors) ] self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms)) self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype)) self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape) self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)) # GLIF-specific state as HiddenState (JAX-traceable) n_asc = len(self.asc_decay) self._asc_states = [ brainstate.HiddenState(jnp.full(self.varshape, self.asc_init[a], dtype=dftype)) for a in range(n_asc) ] asc_sum_init = float(sum(self.asc_init[:n_asc])) self._asc_sum_state = brainstate.HiddenState( jnp.full(self.varshape, asc_sum_init, dtype=dftype) ) # Threshold components (relative to E_L) as HiddenState E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV))) th_inf = float(np.asarray(u.get_mantissa(self.V_th / u.mV))) - E_L_mV self._th_inf = th_inf self._threshold_spike_state = 
brainstate.HiddenState( jnp.zeros(self.varshape, dtype=dftype) ) self._threshold_voltage_state = brainstate.HiddenState( jnp.zeros(self.varshape, dtype=dftype) ) self._threshold_state = brainstate.HiddenState( jnp.full(self.varshape, th_inf, dtype=dftype) ) # Pre-compute decay rates (constants, computed once) G = float(np.asarray(u.get_mantissa(self.g_m / u.nS))) C_m_val = float(np.asarray(u.get_mantissa(self.C_m / u.pF))) t_ref_ms = float(np.asarray(u.get_mantissa(self.t_ref / u.ms))) if self.has_theta_spike: self._decay_spike = np.exp(-self.th_spike_decay * dt_ms) self._decay_spike_refr = np.exp(-self.th_spike_decay * t_ref_ms) if self.has_asc: self._asc_decay_rates = [np.exp(-self.asc_decay[a] * dt_ms) for a in range(n_asc)] self._asc_stable_coeff = [ ((1.0 / self.asc_decay[a]) / dt_ms) * (1.0 - self._asc_decay_rates[a]) for a in range(n_asc) ] self._asc_refr_decay_rates = [ self.asc_r[a] * np.exp(-self.asc_decay[a] * t_ref_ms) for a in range(n_asc) ] if self.has_theta_voltage: self._potential_decay_rate = np.exp(-G * dt_ms / C_m_val) self._theta_voltage_decay_rate_inv = 1.0 / np.exp(self.th_voltage_decay * dt_ms) self._phi = self.th_voltage_index / (self.th_voltage_decay - G / C_m_val) self._abpara_ratio = self.th_voltage_index / self.th_voltage_decay
# Backward-compatible properties for threshold components @property def _threshold(self): return self._threshold_state.value @property def _threshold_spike(self): return self._threshold_spike_state.value @property def _threshold_voltage(self): return self._threshold_voltage_state.value
[docs] def get_spike(self, V: ArrayLike = None): r"""Generate surrogate spike signal from membrane potential. Computes a differentiable spike signal by scaling membrane potential relative to threshold range and applying the surrogate gradient function. Parameters ---------- V : ArrayLike, optional Membrane potential in mV. If None, uses ``self.V.value``. Shape: ``(*batch_dims, *in_size)``. Returns ------- spike : ArrayLike Surrogate spike output in [0, 1], same shape as input. Values near 1 indicate spiking neurons. Notes ----- Scaling: :math:`v_\mathrm{scaled} = (V - V_\mathrm{th}) / (V_\mathrm{th} - V_\mathrm{reset})` This method is used internally for gradient computation but does not affect the discrete spike logic in ``update()``. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _vector_field(self, state, extra):
    """Evaluate the ODE right-hand side for every neuron at once (unit-aware).

    Parameters
    ----------
    state : DotDict
        ODE variables: ``V_rel`` (membrane potential relative to ``E_L``, mV)
        and, for each receptor port ``k``, ``dg_k`` (nS/ms) and ``g_k`` (nS),
        stored under the keys ``'dg_%d' % k`` and ``'g_%d' % k``.
    extra : DotDict
        Auxiliary data carried through the integrator. This method reads
        ``r``, ``V_reset_rel``, ``E_L``, ``G``, ``C_m``, ``I_e``, ``i_stim``
        and ``asc_sum``.

    Returns
    -------
    DotDict
        Time derivatives stored under the same keys as ``state``.
    """
    in_refractory = extra.r > 0
    # Refractory neurons are evaluated at the reset potential instead of the
    # integrated one.
    v_eff = u.math.where(in_refractory, extra.V_reset_rel, state.V_rel)

    ports = range(self._n_receptors)

    # Synaptic current: sum_k g_k * (V + E_L - E_rev_k); nS * mV = pA.
    syn_current = jnp.zeros_like(u.get_mantissa(state.V_rel)) * u.nS * u.mV
    for port in ports:
        driving_force = v_eff + extra.E_L - self.E_rev[port] * u.mV
        syn_current = syn_current + state[f'g_{port}'] * driving_force

    # Leak current uses V relative to E_L, so it is simply G * V.
    leak_current = extra.G * v_eff

    net_current = -leak_current - syn_current + extra.I_e + extra.i_stim + extra.asc_sum
    dv_free = net_current / extra.C_m
    # The membrane potential does not move while refractory.
    dv = u.math.where(in_refractory, u.math.zeros_like(dv_free), dv_free)

    derivs = DotDict(V_rel=dv)

    # Alpha-function conductance: two coupled first-order equations per port.
    for port in ports:
        dg_key = f'dg_{port}'
        g_key = f'g_{port}'
        tau = self.tau_syn[port] * u.ms
        derivs[dg_key] = -state[dg_key] / tau
        derivs[g_key] = state[dg_key] - state[g_key] / tau

    return derivs

def _event_fn(self, state, extra, accept):
    """Apply the refractory clamp and instability check after an RK substep.

    Spike detection and reset are not performed here; ``update()`` does them
    after the integrator returns.

    Parameters
    ----------
    state : DotDict
        ODE variables, including ``V_rel`` and the per-receptor entries.
    extra : DotDict
        Auxiliary data; ``r``, ``unstable`` and ``V_reset_rel`` are read.
    accept : array of bool
        Mask of neurons whose substep was accepted.

    Returns
    -------
    tuple of DotDict
        ``(new_state, new_extra)``: ``V_rel`` is clamped to ``V_reset_rel``
        for accepted refractory neurons, and ``unstable`` is OR-ed with the
        divergence flag. All other entries are passed through unchanged.
    """
    # Any accepted neuron whose V_rel has fallen below -1e3 mV marks the
    # whole step as unstable.
    diverged = accept & (state.V_rel < -1e3 * u.mV)
    unstable_flag = extra.unstable | jnp.any(diverged)

    # Hold accepted refractory neurons at the reset potential (their dV/dt is
    # already zeroed in _vector_field).
    clamp_mask = accept & (extra.r > 0)
    clamped_v = u.math.where(clamp_mask, extra.V_reset_rel, state.V_rel)

    updated_state = DotDict({**state, 'V_rel': clamped_v})
    updated_extra = DotDict({**extra, 'unstable': unstable_flag})
    return updated_state, updated_extra

def _collect_receptor_delta_inputs(self):
    r"""Gather the delta inputs addressed to each receptor port.

    For port ``k`` this calls ``sum_delta_inputs`` with the label
    ``'receptor_k'`` and a zero baseline in nS, so only inputs registered
    under that label contribute to the port's total.

    Returns
    -------
    list
        One entry per receptor port (length ``n_receptors``), each the
        summed delta input for that port, built on a zero baseline of shape
        ``self.varshape`` in nS.
    """
    per_port = []
    for port in range(self._n_receptors):
        baseline = jnp.zeros(self.varshape) * u.nS
        per_port.append(
            self.sum_delta_inputs(baseline, label=f'receptor_{port}')
        )
    return per_port
def update(self, x=0.0 * u.pA):
    r"""Perform a single simulation step with GLIF dynamics.

    Executes the full GLIF update cycle: ODE integration through
    ``self.integrator``, threshold computation (spike- and voltage-dependent
    components if enabled), spike detection, reset rules, refractory
    handling, and application of incoming synaptic conductance jumps.

    This method is JIT-traceable and compatible with
    ``brainstate.transform.for_loop``.

    Parameters
    ----------
    x : ArrayLike, optional
        External current input in pA. Shape: scalar or ``(*in_size,)``.
        Applied with one-step delay: the summed input is stored in ``I_stim``
        at the end of this call, while the ODE of this call uses the value
        ``I_stim`` held on entry. Default: 0.0 pA.

    Returns
    -------
    spike : jax.Array
        Binary spike output (float32), shape ``(*in_size)``.

    Notes
    -----
    State written by this method: ``V``, ``dg_syn[k]``, ``g_syn[k]``, the
    after-spike current states and their sum, the three threshold states,
    ``refractory_step_count``, ``integration_step``, ``I_stim`` and
    ``last_spike_time``.

    An error is raised through ``brainstate.transform.jit_error_if`` when the
    integrator reports an unstable step.

    See Also
    --------
    get_spike : Compute surrogate spike signal for gradient computation
    """
    t = brainstate.environ.get('t')
    dt_q = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # Python-level constants (concrete values, not JAX-traced).
    # Voltages are handled relative to E_L throughout this method.
    E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV)))
    G = float(np.asarray(u.get_mantissa(self.g_m / u.nS)))
    C_m_val = float(np.asarray(u.get_mantissa(self.C_m / u.pF)))
    V_reset_rel = float(np.asarray(u.get_mantissa(self.V_reset / u.mV))) - E_L_mV
    I_e = float(np.asarray(u.get_mantissa(self.I_e / u.pA)))

    # JAX state (traced under for_loop)
    r = self.refractory_step_count.value  # int array, varshape
    i_stim_pA = u.get_mantissa(self.I_stim.value / u.pA)  # float array, varshape
    asc_sum_pA = self._asc_sum_state.value  # float array, varshape

    # Snapshot V_rel before ODE integration (needed for voltage-dependent
    # threshold and for holding V during the refractory period)
    V_rel_old = jax.lax.stop_gradient(
        u.get_mantissa(self.V.value / u.mV) - E_L_mV
    )

    # Buffer new external current (one-step delay): it is only stored into
    # I_stim at the end of this call.
    new_i_stim_q = self.sum_current_inputs(x, self.V.value)

    # ---- Adaptive RKF45 ODE integration ----
    # ODE state: V relative to E_L plus (dg_k, g_k) for every receptor port.
    ode_state = DotDict(V_rel=u.get_mantissa(self.V.value / u.mV - E_L_mV) * u.mV)
    for k in range(self._n_receptors):
        ode_state['dg_%d' % k] = self.dg_syn[k].value
        ode_state['g_%d' % k] = self.g_syn[k].value

    # Auxiliary values handed to the integrator alongside the ODE state.
    extra = DotDict(
        spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
        r=r,
        unstable=jnp.array(False),
        i_stim=i_stim_pA * u.pA,
        V_reset_rel=V_reset_rel * u.mV,
        G=G * u.nS,
        C_m=C_m_val * u.pF,
        E_L=E_L_mV * u.mV,
        I_e=I_e * u.pA,
        asc_sum=asc_sum_pA * u.pA,
        threshold=self._threshold_state.value,
        v_old=V_rel_old * u.mV,
    )
    ode_state, h, extra = self.integrator(
        state=ode_state, h=self.integration_step.value, extra=extra
    )
    brainstate.transform.jit_error_if(
        jnp.any(extra.unstable), 'Numerical instability in glif_cond dynamics.'
    )
    V_rel_new = u.get_mantissa(ode_state.V_rel / u.mV)  # JAX array
    # Refractoriness is judged from the counter value on entry to this step.
    is_refractory = r > 0

    # ---- Vectorised GLIF post-integration (JAX) ----
    n_asc = len(self.asc_decay)

    # 1. Spike threshold decay (non-refractory only; held while refractory)
    if self.has_theta_spike:
        tspk = self._threshold_spike_state.value
        tspk = jnp.where(is_refractory, tspk, tspk * self._decay_spike)
    else:
        tspk = jnp.zeros(self.varshape, dtype=dftype)

    # 2. ASC time-averaged sum and decay (non-refractory only).
    #    The sum uses the ASC values from the start of the step; the decayed
    #    values are kept aside and committed in step 6.
    if self.has_asc:
        asc_sum_new = jnp.zeros(self.varshape, dtype=dftype)
        asc_decayed = []
        for a in range(n_asc):
            asc_a = self._asc_states[a].value
            asc_sum_new = asc_sum_new + self._asc_stable_coeff[a] * asc_a
            asc_decayed.append(asc_a * self._asc_decay_rates[a])
        asc_sum_final = jnp.where(is_refractory, asc_sum_pA, asc_sum_new)
    else:
        asc_sum_final = jnp.zeros(self.varshape, dtype=dftype)
        asc_decayed = []

    # 3. Voltage-dependent threshold (non-refractory only; held while refractory)
    if self.has_theta_voltage:
        tvlt = self._threshold_voltage_state.value
        beta = (I_e + i_stim_pA + asc_sum_final) / G  # pA/nS = mV
        tvlt_new = (
            self._phi * (V_rel_old - beta) * self._potential_decay_rate
            + self._theta_voltage_decay_rate_inv * (
                tvlt - self._phi * (V_rel_old - beta) - self._abpara_ratio * beta
            )
            + self._abpara_ratio * beta
        )
        tvlt = jnp.where(is_refractory, tvlt, tvlt_new)
    else:
        tvlt = jnp.zeros(self.varshape, dtype=dftype)

    # 4. Total threshold
    threshold = tspk + tvlt + self._th_inf

    # 5. Spike check (non-refractory only)
    spiked = (V_rel_new > threshold) & ~is_refractory

    # 6. On spike: update ASC (using already-decayed values, matching NEST).
    #    Refractory neurons keep their ASC; all others take the decayed value.
    if self.has_asc:
        for a in range(n_asc):
            asc_a = self._asc_states[a].value
            asc_reset = self.asc_amps[a] + asc_decayed[a] * self._asc_refr_decay_rates[a]
            self._asc_states[a].value = jnp.where(
                spiked, asc_reset,
                jnp.where(is_refractory, asc_a, asc_decayed[a])
            )
    self._asc_sum_state.value = asc_sum_final

    # 7. Voltage reset. Refractory neurons are held at their pre-step voltage.
    if not self.has_theta_spike:
        # GLIF1/3: simple reset
        V_final_rel = jnp.where(
            spiked, V_reset_rel,
            jnp.where(is_refractory, V_rel_old, V_rel_new)
        )
    else:
        # GLIF2/4/5: biologically defined reset
        V_reset_bio = self.voltage_reset_fraction * V_rel_old + self.voltage_reset_add
        V_final_rel = jnp.where(
            spiked, V_reset_bio,
            jnp.where(is_refractory, V_rel_old, V_rel_new)
        )
        # Reset spike threshold on spike: decay over the refractory period,
        # then add the per-spike increment.
        tspk_reset = tspk * self._decay_spike_refr + self.th_spike_add
        tspk = jnp.where(spiked, tspk_reset, tspk)
        # Update total threshold after spike reset
        threshold = jnp.where(spiked, tspk + tvlt + self._th_inf, threshold)

    # 8. Refractory counter: reload on spike, count down while refractory.
    r_new = jnp.where(
        spiked, self.ref_count,
        jnp.where(is_refractory, r - 1, r)
    )

    # 9. Collect receptor delta inputs (JAX-compatible via sum_delta_inputs)
    dg_input = self._collect_receptor_delta_inputs()
    cond_init_vals = [np.e / self.tau_syn[k] for k in range(self._n_receptors)]

    # 10. Write back all state
    self.V.value = (V_final_rel + E_L_mV) * u.mV
    for k in range(self._n_receptors):
        dg_k = ode_state['dg_%d' % k]
        g_k = ode_state['g_%d' % k]
        # Add incoming conductance jump (e/tau_syn scaling); the jump is
        # applied after integration, so it first acts in the next step.
        dg_jump = u.get_mantissa(dg_input[k] / u.nS) * cond_init_vals[k]
        dg_k = dg_k + dg_jump * (u.nS / u.ms)
        self.dg_syn[k].value = dg_k
        self.g_syn[k].value = g_k
    self._threshold_spike_state.value = tspk
    self._threshold_voltage_state.value = tvlt
    self._threshold_state.value = threshold
    self.refractory_step_count.value = jnp.asarray(r_new, dtype=ditype)
    self.integration_step.value = h
    # Adding a zero array of varshape broadcasts a scalar input to varshape.
    self.I_stim.value = new_i_stim_q + u.math.zeros(self.varshape) * u.pA
    # Spikes are stamped at the end of the current step (t + dt).
    last_spike_time = u.math.where(spiked, t + dt_q, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    return jnp.asarray(spiked, dtype=jnp.float32)