# Source code for brainpy_state._nest.iaf_chxk_2008

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'iaf_chxk_2008',
]


class iaf_chxk_2008(NESTNeuron):
    r"""NEST-compatible ``iaf_chxk_2008`` with alpha synapses and precise AHP timing.

    Description
    -----------

    ``iaf_chxk_2008`` is a conductance-based leaky integrate-and-fire neuron
    with alpha-function excitatory/inhibitory synaptic conductances and a
    spike-triggered after-hyperpolarization (AHP) conductance, developed for
    modeling retina-LGN transmission (Casti et al., 2008). The implementation
    follows NEST ``models/iaf_chxk_2008.{h,cpp}`` semantics: adaptive RKF45
    integration, threshold crossing from below, precise in-step spike timing via
    linear interpolation, spike-triggered AHP kicks with exact sub-step decay,
    and optional ``ahp_bug`` mode that reproduces the historical single-AHP
    behavior from the original Fortran code.

    **1. Membrane and conductance dynamics**

    Let :math:`V_m` be membrane potential (mV), :math:`g_\mathrm{ex}`,
    :math:`g_\mathrm{in}`, :math:`g_\mathrm{ahp,state}` be conductance states
    (nS), and :math:`I_\mathrm{stim}` be the one-step buffered external current
    (pA). Subthreshold dynamics are

    .. math::

       \frac{dV_m}{dt} =
       \frac{-I_\mathrm{leak} - I_{\mathrm{syn,ex}} - I_{\mathrm{syn,in}}
             - I_\mathrm{ahp} + I_e + I_\mathrm{stim}}{C_m},

    where

    .. math::

       I_\mathrm{leak} = g_L (V_m - E_L),
       \quad
       I_{\mathrm{syn,ex}} = g_\mathrm{ex}(V_m - E_\mathrm{ex}),

    .. math::

       I_{\mathrm{syn,in}} = g_\mathrm{in}(V_m - E_\mathrm{in}),
       \quad
       I_\mathrm{ahp} = g_\mathrm{ahp,state}(V_m - E_\mathrm{ahp}).

    Each conductance channel (excitatory, inhibitory, AHP) evolves as an
    alpha-function state pair :math:`(dg, g_\mathrm{state})`:

    .. math::

       \frac{d\,dg}{dt} = -\frac{dg}{\tau},
       \qquad
       \frac{dg_\mathrm{state}}{dt} = dg - \frac{g_\mathrm{state}}{\tau}.

    Incoming spike weights (nS) arrive on two separate channels, ``w_ex`` and
    ``w_in`` (see :meth:`update`). NEST routes positive weights to the
    excitatory channel and the absolute value of negative weights to the
    inhibitory channel. This implementation adds each channel's weight as
    given, so ``w_in`` should be supplied as a non-negative magnitude. Jumps
    are applied to :math:`dg` with NEST normalization:

    .. math::

       dg_\mathrm{ex} \leftarrow dg_\mathrm{ex} + \frac{e}{\tau_\mathrm{ex}} w_\mathrm{ex},
       \qquad
       dg_\mathrm{in} \leftarrow dg_\mathrm{in} + \frac{e}{\tau_\mathrm{in}} w_\mathrm{in}.

    **2. Precise output spike timing and AHP kick**

    A spike is emitted only on threshold crossing from below:

    .. math::

       V_m(t_k^-) < V_{th} \;\wedge\; V_m(t_k^+) \ge V_{th}.

    When a crossing is detected, the precise in-step spike time is computed by
    linear interpolation. Let :math:`dt_\mathrm{spike}` be the duration from
    spike time to step end:

    .. math::

       dt_\mathrm{spike}
       = h \frac{V_m(t_k^+) - V_{th}}{V_m(t_k^+) - V_m(t_k^-)},

    where :math:`h` is the step size. The AHP alpha is initialized at spike
    time and decayed forward to step end:

    .. math::

       \Delta dg_\mathrm{ahp}
       = \frac{g_\mathrm{ahp} e}{\tau_\mathrm{ahp}}
         \exp\!\left(-\frac{dt_\mathrm{spike}}{\tau_\mathrm{ahp}}\right),

    .. math::

       \Delta g_\mathrm{ahp,state}
       = \Delta dg_\mathrm{ahp}\, dt_\mathrm{spike}.

    If ``ahp_bug=True``, these values **replace** the current AHP state (single
    AHP mode); otherwise they are **added** (multiple AHP accumulation).

    **3. Numerical integration via RKF45**

    The seven coupled ODEs (:math:`V_m`, three :math:`dg` states, three
    :math:`g_\mathrm{state}` variables) are integrated using Runge-Kutta-Fehlberg
    4(5) with adaptive step size control. Local truncation error is estimated
    by comparing 4th and 5th order solutions and step size is adjusted to keep
    error below ``gsl_error_tol``. Minimum step size is ``_MIN_H = 1e-8`` ms and
    iteration limit is ``_MAX_ITERS = 10000`` per global step.

    **4. Update order matching NEST semantics**

    Each simulation step follows NEST ordering:

    1. Integrate all ODE states over :math:`[t, t+dt]` via RKF45.
    2. Check threshold crossing from below; if crossed, compute precise spike
       time and apply AHP kick at that time (with ``ahp_bug`` mode if enabled).
    3. Apply arriving excitatory/inhibitory spike weights (``w_ex``, ``w_in``)
       to :math:`dg_\mathrm{ex}` and :math:`dg_\mathrm{in}` after integration
       completes.
    4. Store incoming continuous current ``x`` into buffered ``I_stim`` for
       next step (NEST current-event timing convention).

    **5. Assumptions, constraints, and failure modes**

    - Parameters are scalar or broadcastable to ``self.varshape``.
    - Construction-time constraints enforce ``C_m > 0``, ``tau_syn_ex > 0``,
      ``tau_syn_in > 0``, ``tau_ahp > 0`` and ``gsl_error_tol > 0`` (values that
      are JAX tracers cannot be inspected and are skipped).
    - No explicit reset or refractory period: neuron can spike repeatedly if
      voltage remains above threshold.
    - Adaptive integration can fail if ``_MAX_ITERS`` is exceeded; in practice
      this is rare with reasonable parameter values.
    - Continuous input ``x`` passed to :meth:`update` is delayed by one step
      via ``I_stim`` (ring-buffer semantics), while spike events are applied
      after ODE integration.
    - Per-step complexity is :math:`O(|\mathrm{state}| \cdot K_\mathrm{iter})`
      where :math:`K_\mathrm{iter}` is the number of RKF45 substeps (adaptive;
      depends on the dynamics and on ``gsl_error_tol``).

    Parameters
    ----------
    in_size : Size
        Population shape specification. Model parameters and states are
        broadcast to ``self.varshape`` derived from ``in_size``.
    V_th : ArrayLike, optional
        Spike threshold voltage :math:`V_{th}` in mV, broadcastable to
        ``self.varshape``. Default is ``-45. * u.mV``.
    g_L : ArrayLike, optional
        Leak conductance :math:`g_L` in nS, broadcastable to ``self.varshape``.
        Default is ``100. * u.nS``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF, broadcastable to
        ``self.varshape``. Must be strictly positive elementwise.
        Default is ``1000. * u.pF``.
    E_ex : ArrayLike, optional
        Excitatory reversal potential :math:`E_\mathrm{ex}` in mV,
        broadcastable to ``self.varshape``. Default is ``20. * u.mV``.
    E_in : ArrayLike, optional
        Inhibitory reversal potential :math:`E_\mathrm{in}` in mV,
        broadcastable to ``self.varshape``. Default is ``-90. * u.mV``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV, broadcastable to
        ``self.varshape``. Default is ``-60. * u.mV``.
    tau_syn_ex : ArrayLike, optional
        Excitatory alpha time constant :math:`\tau_\mathrm{ex}` in ms,
        broadcastable to ``self.varshape``. Must be strictly positive
        elementwise. Default is ``1. * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory alpha time constant :math:`\tau_\mathrm{in}` in ms,
        broadcastable to ``self.varshape``. Must be strictly positive
        elementwise. Default is ``1. * u.ms``.
    I_e : ArrayLike, optional
        Constant external current :math:`I_e` in pA, broadcastable to
        ``self.varshape``. Added in each integration substep.
        Default is ``0. * u.pA``.
    tau_ahp : ArrayLike, optional
        AHP alpha time constant :math:`\tau_\mathrm{ahp}` in ms,
        broadcastable to ``self.varshape``. Must be strictly positive
        elementwise. Default is ``0.5 * u.ms``.
    E_ahp : ArrayLike, optional
        AHP reversal potential :math:`E_\mathrm{ahp}` in mV, broadcastable to
        ``self.varshape``. Default is ``-95. * u.mV``.
    g_ahp : ArrayLike, optional
        AHP kick conductance scale :math:`g_\mathrm{ahp}` in nS,
        broadcastable to ``self.varshape``. Controls magnitude of AHP alpha
        initialized at each spike. Default is ``443.8 * u.nS``.
    ahp_bug : ArrayLike, optional
        Boolean flag (broadcastable to ``self.varshape``) enabling historical
        single-AHP bug mode. If ``True``, each spike replaces existing AHP
        state with new AHP kick. If ``False``, AHP kicks accumulate.
        Default is ``False``.
    gsl_error_tol : ArrayLike, optional
        Unitless local RKF45 error tolerance, strictly positive. Stored as
        given (not broadcast at construction) and passed to the integrator as
        its absolute tolerance. Default is ``1e-3``.
    V_initializer : Callable, optional
        Initializer used by :meth:`init_state` for membrane potential ``V``.
        Must return mV-compatible values with shape compatible with
        ``self.varshape``.
        Default is ``braintools.init.Constant(-60. * u.mV)``.
    g_ex_initializer : Callable, optional
        Initializer for excitatory conductance state ``g_ex`` (nS).
        Default is ``braintools.init.Constant(0. * u.nS)``.
    g_in_initializer : Callable, optional
        Initializer for inhibitory conductance state ``g_in`` (nS).
        Default is ``braintools.init.Constant(0. * u.nS)``.
    g_ahp_initializer : Callable, optional
        Initializer for AHP conductance state ``g_ahp_state`` (nS).
        Default is ``braintools.init.Constant(0. * u.nS)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike`. It receives the
        threshold distance normalized by :math:`|V_{th} - E_L|`. Note that
        :meth:`update` returns the hard threshold-crossing indicator and does
        not call this function.
        Default is ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy forwarded to the base class. :meth:`update` applies no
        voltage reset (the model has none). Default is ``'hard'``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 18 28 14 15 35

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines ``self.varshape`` for parameter/state broadcasting.
       * - ``V_th``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-45. * u.mV``
         - :math:`V_{th}`
         - Spike threshold voltage.
       * - ``g_L``
         - ArrayLike, broadcastable (nS)
         - ``100. * u.nS``
         - :math:`g_L`
         - Leak conductance.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``1000. * u.pF``
         - :math:`C_m`
         - Membrane capacitance.
       * - ``E_ex``
         - ArrayLike, broadcastable (mV)
         - ``20. * u.mV``
         - :math:`E_\mathrm{ex}`
         - Excitatory reversal potential.
       * - ``E_in``
         - ArrayLike, broadcastable (mV)
         - ``-90. * u.mV``
         - :math:`E_\mathrm{in}`
         - Inhibitory reversal potential.
       * - ``E_L``
         - ArrayLike, broadcastable (mV)
         - ``-60. * u.mV``
         - :math:`E_L`
         - Resting potential.
       * - ``tau_syn_ex``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``1. * u.ms``
         - :math:`\tau_\mathrm{ex}`
         - Excitatory alpha time constant.
       * - ``tau_syn_in``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``1. * u.ms``
         - :math:`\tau_\mathrm{in}`
         - Inhibitory alpha time constant.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant external current.
       * - ``tau_ahp``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``0.5 * u.ms``
         - :math:`\tau_\mathrm{ahp}`
         - AHP alpha time constant.
       * - ``E_ahp``
         - ArrayLike, broadcastable (mV)
         - ``-95. * u.mV``
         - :math:`E_\mathrm{ahp}`
         - AHP reversal potential.
       * - ``g_ahp``
         - ArrayLike, broadcastable (nS)
         - ``443.8 * u.nS``
         - :math:`g_\mathrm{ahp}`
         - AHP kick conductance scale.
       * - ``ahp_bug``
         - ArrayLike broadcastable bool
         - ``False``
         - --
         - Enable single-AHP historical bug mode.
       * - ``gsl_error_tol``
         - ArrayLike, unitless, ``> 0``
         - ``1e-3``
         - --
         - Local absolute tolerance for the embedded RKF45 error estimate.
       * - ``V_initializer``
         - Callable returning mV-compatible values
         - ``Constant(-60. * u.mV)``
         - --
         - Initializes membrane state ``V``.
       * - ``g_ex_initializer``
         - Callable returning nS-compatible values
         - ``Constant(0. * u.nS)``
         - --
         - Initializes excitatory conductance.
       * - ``g_in_initializer``
         - Callable returning nS-compatible values
         - ``Constant(0. * u.nS)``
         - --
         - Initializes inhibitory conductance.
       * - ``g_ahp_initializer``
         - Callable returning nS-compatible values
         - ``Constant(0. * u.nS)``
         - --
         - Initializes AHP conductance state.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate spike nonlinearity used by :meth:`get_spike`.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode forwarded to the base class; no reset is applied by
           :meth:`update`.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node name.

    Raises
    ------
    ValueError
        If validated constraints fail (non-positive capacitance, non-positive
        time constants, non-positive gsl_error_tol).
    TypeError
        If provided arguments are incompatible with expected units/callables
        (mV, pA, pF, ms, nS).
    KeyError
        If simulation context values ``t`` and/or ``dt`` are missing when
        :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` creates required
        runtime states.

    Attributes
    ----------
    V : HiddenState
        Membrane potential state in mV.
    dg_ex : ShortTermState
        Excitatory conductance rate-of-change state (nS/ms).
    g_ex : HiddenState
        Excitatory conductance state in nS.
    dg_in : ShortTermState
        Inhibitory conductance rate-of-change state (nS/ms).
    g_in : HiddenState
        Inhibitory conductance state in nS.
    dg_ahp : ShortTermState
        AHP conductance rate-of-change state (nS/ms).
    g_ahp_state : HiddenState
        AHP conductance state in nS.
    I_syn_ex, I_syn_in, I_ahp : ShortTermState
        Recorded excitatory, inhibitory and AHP currents in pA, evaluated at
        step end after the AHP kick.
    I_stim : ShortTermState
        One-step buffered external current in pA.
    integration_step : ShortTermState
        Adaptive RKF45 step size hint in ms.
    last_spike_time : ShortTermState
        Absolute precise spike time in ms.
    last_spike_offset : ShortTermState
        Offset of the last spike from its step end in ms.

    Notes
    -----
    - The model has no explicit membrane reset or refractory state: after
      crossing threshold, voltage continues to evolve and can spike again.
    - Continuous input ``x`` passed to :meth:`update` is **buffered** and
      affects the **next** step (NEST current-event timing).
    - Like NEST, this model provides precise output spike timing via linear
      interpolation but does not process off-grid spike-input offsets.
    - RKF45 integration is performed via the adaptive integrator and
      written back into BrainUnit states at step end.
    - ``ahp_bug=True`` reproduces the original Fortran behavior where only one
      AHP is tracked; this is primarily for validation against legacy code.
    - The constructor and :meth:`init_state` read ``dt`` from
      ``brainstate.environ``, so the neuron must be created inside a context
      that defines it.

    Examples
    --------
    Create a single neuron with default parameters and simulate:

    .. code-block:: python

        >>> import brainstate as bs
        >>> import saiunit as u
        >>> import brainpy.state as bps
        >>> with bs.environ.context(dt=0.1 * u.ms, t=0. * u.ms):
        ...     neuron = bps.iaf_chxk_2008(1)
        ...     neuron.init_state()
        ...     spike = neuron.update(100. * u.pA)  # buffered to next step

    Inspect AHP kick behavior after spike:

    .. code-block:: python

        >>> neuron.V.value  # check membrane potential
        >>> neuron.g_ahp_state.value  # check AHP conductance state

    Recordables
    -----------

    ``V_m``, ``g_ex``, ``g_in``, ``g_ahp``, ``I_syn_ex``, ``I_syn_in``, ``I_ahp``

    References
    ----------
    .. [1] Casti A, Hayot F, Xiao Y, Kaplan E (2008). A simple model of
           retina-LGN transmission. Journal of Computational Neuroscience
           24:235-252. DOI: https://doi.org/10.1007/s10827-007-0053-7
    .. [2] NEST source: ``models/iaf_chxk_2008.h`` and
           ``models/iaf_chxk_2008.cpp``.
    """

    __module__ = 'brainpy.state'

    RECORDABLES = (
        'V_m',
        'g_ex',
        'g_in',
        'g_ahp',
        'I_syn_ex',
        'I_syn_in',
        'I_ahp',
    )

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 10000

    def __init__(
        self,
        in_size: Size,
        V_th: ArrayLike = -45.0 * u.mV,
        g_L: ArrayLike = 100.0 * u.nS,
        C_m: ArrayLike = 1000.0 * u.pF,
        E_ex: ArrayLike = 20.0 * u.mV,
        E_in: ArrayLike = -90.0 * u.mV,
        E_L: ArrayLike = -60.0 * u.mV,
        tau_syn_ex: ArrayLike = 1.0 * u.ms,
        tau_syn_in: ArrayLike = 1.0 * u.ms,
        I_e: ArrayLike = 0.0 * u.pA,
        tau_ahp: ArrayLike = 0.5 * u.ms,
        E_ahp: ArrayLike = -95.0 * u.mV,
        g_ahp: ArrayLike = 443.8 * u.nS,
        ahp_bug: ArrayLike = False,
        gsl_error_tol: ArrayLike = 1e-3,
        V_initializer: Callable = braintools.init.Constant(-60.0 * u.mV),
        g_ex_initializer: Callable = braintools.init.Constant(0.0 * u.nS),
        g_in_initializer: Callable = braintools.init.Constant(0.0 * u.nS),
        g_ahp_initializer: Callable = braintools.init.Constant(0.0 * u.nS),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Broadcast every model parameter to the population shape.
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.E_ex = braintools.init.param(E_ex, self.varshape)
        self.E_in = braintools.init.param(E_in, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.tau_ahp = braintools.init.param(tau_ahp, self.varshape)
        self.E_ahp = braintools.init.param(E_ahp, self.varshape)
        self.g_ahp = braintools.init.param(g_ahp, self.varshape)
        self.ahp_bug = braintools.init.param(ahp_bug, self.varshape)
        # Kept as given: it is a solver tolerance, not a per-neuron state parameter.
        self.gsl_error_tol = gsl_error_tol

        self.V_initializer = V_initializer
        self.g_ex_initializer = g_ex_initializer
        self.g_in_initializer = g_in_initializer
        self.g_ahp_initializer = g_ahp_initializer

        self._validate_parameters()

        # No per-substep event callback: spike detection happens once per global step.
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=None,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=brainstate.environ.get_dt()
        )

        # NOTE(review): ``ref_count`` is not read anywhere in this class; presumably
        # it is consumed by shared tooling elsewhere -- confirm before removing.
        self.ref_count = u.math.asarray(0)

    @property
    def recordables(self):
        """List of recordable quantity names (a fresh copy of ``RECORDABLES``)."""
        return list(self.RECORDABLES)

    def _validate_parameters(self):
        """Check construction-time parameter constraints.

        Each parameter is validated independently. Values that are JAX tracers
        (e.g. when the model is built under ``jax.jit``) carry no concrete data
        and are skipped, while the remaining concrete values are still checked.

        Raises
        ------
        ValueError
            If ``C_m``, any time constant, or ``gsl_error_tol`` is not strictly
            positive.
        """
        checks = (
            (self.C_m, 0.0 * u.pF, 'Capacitance must be strictly positive.'),
            (self.tau_syn_ex, 0.0 * u.ms, 'All time constants must be strictly positive.'),
            (self.tau_syn_in, 0.0 * u.ms, 'All time constants must be strictly positive.'),
            (self.tau_ahp, 0.0 * u.ms, 'All time constants must be strictly positive.'),
            (self.gsl_error_tol, 0.0, 'The gsl_error_tol must be strictly positive.'),
        )
        for value, zero, message in checks:
            # Comparing a tracer with ``np.any`` would fail with a concretization error.
            if is_tracer(value):
                continue
            if np.any(value <= zero):
                raise ValueError(message)

    def init_state(self, **kwargs):
        r"""Initialize persistent and short-term state variables.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If an initializer cannot be broadcast to requested shape.
        TypeError
            If initializer outputs have incompatible units/dtypes for the
            corresponding state variables.
        """
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        V = braintools.init.param(self.V_initializer, self.varshape)
        g_ex = braintools.init.param(self.g_ex_initializer, self.varshape)
        g_in = braintools.init.param(self.g_in_initializer, self.varshape)
        g_ahp_init = braintools.init.param(self.g_ahp_initializer, self.varshape)

        # Alpha-function "dg" states start at rest (no pending synaptic/AHP drive).
        zeros = u.math.zeros(self.varshape, dtype=V.dtype) * (u.nS / u.ms)
        zeros_pA = u.math.zeros(self.varshape, dtype=dftype) * u.pA

        self.V = brainstate.HiddenState(V)
        self.dg_ex = brainstate.ShortTermState(zeros)
        self.g_ex = brainstate.HiddenState(g_ex)
        self.dg_in = brainstate.ShortTermState(zeros)
        self.g_in = brainstate.HiddenState(g_in)
        self.dg_ahp = brainstate.ShortTermState(zeros)
        self.g_ahp_state = brainstate.HiddenState(g_ahp_init)

        self.I_syn_ex = brainstate.ShortTermState(zeros_pA)
        self.I_syn_in = brainstate.ShortTermState(zeros_pA)
        self.I_ahp = brainstate.ShortTermState(zeros_pA)

        # A far-past spike time marks "no spike yet".
        self.last_spike_time = brainstate.ShortTermState(
            u.math.full(self.varshape, -1e7 * u.ms, dtype=dftype)
        )
        self.last_spike_offset = brainstate.ShortTermState(
            u.math.full(self.varshape, 0.0 * u.ms, dtype=dftype)
        )
        # The adaptive solver starts from a full global step and shrinks as needed.
        self.integration_step = brainstate.ShortTermState.init(
            braintools.init.Constant(dt), self.varshape
        )
        self.I_stim = brainstate.ShortTermState(
            u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)
        )

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate surrogate spike output from membrane voltage.

        Parameters
        ----------
        V : ArrayLike, optional
            Voltage values with shape broadcastable to ``self.varshape`` and
            units compatible with mV. If ``None``, uses current state
            ``self.V.value``.

        Returns
        -------
        ArrayLike
            Surrogate spike activation produced by
            ``spk_fun((V - V_th) / |V_th - E_L|)``.
        """
        V = self.V.value if V is None else V
        # Tiny offset avoids division by zero when V_th == E_L.
        denom = u.math.abs(self.V_th - self.E_L) + 1e-12 * u.mV
        v_scaled = (V - self.V_th) / denom
        return self.spk_fun(v_scaled)

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, dg_ex, g_ex, dg_in, g_in, dg_ahp, g_ahp_state -- ODE state
            variables.
        extra : DotDict
            Keys: i_stim -- buffered external current for this step.

        Returns
        -------
        DotDict
            Same keys as ``state``, containing time derivatives.
        """
        i_leak = self.g_L * (state.V - self.E_L)
        i_syn_exc = state.g_ex * (state.V - self.E_ex)
        i_syn_inh = state.g_in * (state.V - self.E_in)
        i_ahp = state.g_ahp_state * (state.V - self.E_ahp)
        dV = (-i_leak - i_syn_exc - i_syn_inh - i_ahp + self.I_e + extra.i_stim) / self.C_m

        # Alpha-function pairs: dg decays exponentially and drives g.
        ddg_ex = -state.dg_ex / self.tau_syn_ex
        dg_ex_dt = state.dg_ex - state.g_ex / self.tau_syn_ex
        ddg_in = -state.dg_in / self.tau_syn_in
        dg_in_dt = state.dg_in - state.g_in / self.tau_syn_in
        ddg_ahp = -state.dg_ahp / self.tau_ahp
        dg_ahp_dt = state.dg_ahp - state.g_ahp_state / self.tau_ahp

        return DotDict(
            V=dV,
            dg_ex=ddg_ex,
            g_ex=dg_ex_dt,
            dg_in=ddg_in,
            g_in=dg_in_dt,
            dg_ahp=ddg_ahp,
            g_ahp_state=dg_ahp_dt,
        )

    def _precise_spike_offset(self, V_start, V_end, crossed, dt):
        """Linearly interpolate the spike time, measured back from the step end.

        Returns ``dt * (V_end - V_th) / (V_end - V_start)`` clipped to
        ``[0, dt]`` where ``crossed`` is true, and ``0 ms`` elsewhere.
        """
        denom = V_end - V_start
        # Guard the division for non-crossing elements; their result is masked below.
        safe_denom = u.math.where(u.math.abs(denom) < 1e-30 * u.mV, 1.0 * u.mV, denom)
        offset = dt * (V_end - self.V_th) / safe_denom
        offset = u.math.clip(offset, 0.0 * u.ms, dt)
        return u.math.where(crossed, offset, 0.0 * u.ms)

    def _apply_ahp_kick(self, crossed, spike_offset, dg_ahp, g_ahp_state):
        """Apply the spike-triggered AHP alpha kick for neurons that crossed threshold.

        The AHP alpha is started at the interpolated spike time and evolved
        exactly to the step end (``dg = A exp(-s/tau)``, ``g = dg * s`` with
        ``A = g_ahp * e / tau_ahp`` and ``s = spike_offset``). With ``ahp_bug``
        the kick replaces the current AHP state; otherwise it is added.

        Returns
        -------
        tuple
            Updated ``(dg_ahp, g_ahp_state)``; non-spiking elements are unchanged.
        """
        pscon_ahp = self.g_ahp * np.e / self.tau_ahp  # nS/ms
        delta_dg = pscon_ahp * u.math.exp(-spike_offset / self.tau_ahp)
        delta_g = delta_dg * spike_offset
        single_ahp = jnp.asarray(self.ahp_bug, dtype=bool)
        kicked_dg = u.math.where(single_ahp, delta_dg, dg_ahp + delta_dg)
        kicked_g = u.math.where(single_ahp, delta_g, g_ahp_state + delta_g)
        return (
            u.math.where(crossed, kicked_dg, dg_ahp),
            u.math.where(crossed, kicked_g, g_ahp_state),
        )

    def update(self, x=0.0 * u.pA, w_ex=None, w_in=None):
        r"""Advance the neuron by one simulation step.

        Parameters
        ----------
        x : ArrayLike, optional
            Continuous external current input in pA, broadcastable to
            ``self.varshape``. It is combined with registered current inputs,
            stored into ``I_stim`` and applied at the next simulation step
            (one-step delay).
        w_ex : ArrayLike or None, optional
            Excitatory synaptic weight increment (nS) to add to ``dg_ex`` after
            integration, scaled by ``e/tau_syn_ex``. When ``None`` (default), the
            value is read from registered delta inputs with label ``'w_ex'``.
            Provide an explicit array for JIT-compatible (for_loop) usage.
        w_in : ArrayLike or None, optional
            Inhibitory synaptic weight increment (nS), analogous to ``w_ex`` but
            for ``dg_in`` with label ``'w_in'``. It is added as given, so it
            should be a non-negative magnitude.

        Returns
        -------
        ArrayLike
            Binary spike tensor with dtype ``brainstate.environ.dftype()`` and
            shape ``self.V.value.shape``. A value of ``1.0`` indicates a
            threshold crossing from below during the integrated interval
            :math:`(t, t+dt]`.

        Notes
        -----
        Integration uses an adaptive RKF45 loop. Spike detection and AHP kicks
        follow NEST semantics: crossing is checked at the *global* step level
        (comparing V before and after the full integration), and the AHP state
        is updated post-integration using linear interpolation of the spike
        time. Synaptic inputs (``w_ex``, ``w_in``) are applied after integration.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()

        # Snapshot of the state at the start of the step (units noted per line).
        V_start = self.V.value  # mV; also used for crossing detection
        ode_state = DotDict(
            V=V_start,
            dg_ex=self.dg_ex.value,  # nS/ms
            g_ex=self.g_ex.value,  # nS
            dg_in=self.dg_in.value,  # nS/ms
            g_in=self.g_in.value,  # nS
            dg_ahp=self.dg_ahp.value,  # nS/ms
            g_ahp_state=self.g_ahp_state.value,  # nS
        )
        i_stim = self.I_stim.value  # pA
        h = self.integration_step.value  # ms

        # Current collected now only takes effect next step (NEST ring-buffer timing).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # 1) Adaptive RKF45 integration over [t, t + dt].
        ode_state, h, _ = self.integrator(state=ode_state, h=h, extra=DotDict(i_stim=i_stim))
        V = ode_state.V
        dg_ex, g_ex = ode_state.dg_ex, ode_state.g_ex
        dg_in, g_in = ode_state.dg_in, ode_state.g_in

        # 2) Threshold crossing from below, checked on the global step only,
        #    followed by the AHP kick at the interpolated spike time.
        crossed = (V_start < self.V_th) & (V >= self.V_th)
        spike_offset = self._precise_spike_offset(V_start, V, crossed, dt)
        dg_ahp, g_ahp_state = self._apply_ahp_kick(
            crossed, spike_offset, ode_state.dg_ahp, ode_state.g_ahp_state
        )

        # Recordable currents (post-integration, post-AHP, pre-weight-update).
        I_syn_ex = g_ex * (V - self.E_ex)  # nS * mV = pA
        I_syn_in = g_in * (V - self.E_in)  # nS * mV = pA
        I_ahp_cur = g_ahp_state * (V - self.E_ahp)  # nS * mV = pA

        # 3) Synaptic spike inputs, applied after integration with e / tau scaling.
        if w_ex is None:
            w_ex = self.sum_delta_inputs(u.math.zeros_like(self.g_ex.value), label='w_ex')
        if w_in is None:
            w_in = self.sum_delta_inputs(u.math.zeros_like(self.g_in.value), label='w_in')
        pscon_ex = np.e / self.tau_syn_ex  # 1/ms
        pscon_in = np.e / self.tau_syn_in  # 1/ms
        dg_ex = dg_ex + pscon_ex * w_ex  # nS/ms
        dg_in = dg_in + pscon_in * w_in  # nS/ms

        # Precise spike bookkeeping: spike time = step end - offset.
        new_spike_offset = u.math.where(crossed, spike_offset, self.last_spike_offset.value)
        new_spike_time = u.math.where(crossed, t + dt - spike_offset, self.last_spike_time.value)

        # Write back state.
        self.V.value = V
        self.dg_ex.value = dg_ex
        self.g_ex.value = g_ex
        self.dg_in.value = dg_in
        self.g_in.value = g_in
        self.dg_ahp.value = dg_ahp
        self.g_ahp_state.value = g_ahp_state
        self.I_syn_ex.value = I_syn_ex
        self.I_syn_in.value = I_syn_in
        self.I_ahp.value = I_ahp_cur
        self.integration_step.value = h
        # 4) Buffer this step's current for the next step, broadcast to varshape.
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
        self.last_spike_offset.value = jax.lax.stop_gradient(new_spike_offset)
        self.last_spike_time.value = jax.lax.stop_gradient(new_spike_time)

        return u.math.asarray(crossed, dtype=dftype)