Source code for brainpy_state._nest.aeif_psc_exp

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, validate_aeif_overflow, AdaptiveRungeKuttaStep

__all__ = [
    'aeif_psc_exp',
]


class aeif_psc_exp(NESTNeuron):
    r"""NEST-compatible adaptive exponential integrate-and-fire neuron with exponential synapses.

    Current-based adaptive exponential integrate-and-fire neuron with exponentially
    decaying synaptic currents. Implements the AdEx model of Brette & Gerstner (2005)
    with spike-triggered adaptation, subthreshold adaptation coupling, and separate
    excitatory/inhibitory exponential current synapses. Follows NEST
    ``models/aeif_psc_exp.{h,cpp}`` implementation exactly.

    **1. Mathematical Model**

    **Membrane and adaptation dynamics:**

    The membrane potential :math:`V` and adaptation current :math:`w` evolve as:

    .. math::

       C_m \frac{dV}{dt}
       =
       -g_L (V - E_L)
       + g_L \Delta_T \exp\!\left(\frac{V - V_{th}}{\Delta_T}\right)
       - w + I_{ex} - I_{in} + I_e + I_{stim}

    .. math::

       \tau_w \frac{dw}{dt} = a (V - E_L) - w

    where :math:`C_m` is membrane capacitance, :math:`g_L` is leak conductance,
    :math:`E_L` is leak reversal, :math:`\Delta_T` is the exponential slope factor,
    :math:`V_{th}` is the spike threshold, :math:`a` couples subthreshold voltage to
    adaptation, and :math:`\tau_w` is the adaptation time constant.

    **Synaptic current dynamics:**

    Excitatory and inhibitory currents decay exponentially:

    .. math::

       \frac{d I_{ex}}{dt} = -\frac{I_{ex}}{\tau_{syn,ex}},
       \qquad
       \frac{d I_{in}}{dt} = -\frac{I_{in}}{\tau_{syn,in}}

    Incoming spike weights (in pA) are read from the delta inputs registered
    under the labels ``'w_ex'`` and ``'w_in'`` and added instantaneously after
    the integration step:

    .. math::

       I_{ex} \leftarrow I_{ex} + \sum w_{ex},
       \qquad
       I_{in} \leftarrow I_{in} + \sum w_{in}

    No sign splitting is performed by the model itself: the caller routes
    excitatory weights to ``'w_ex'`` and inhibitory weights to ``'w_in'``.
    Because :math:`I_{in}` enters the membrane equation with a minus sign,
    inhibitory weights should be given as positive magnitudes.

    **2. Refractory and Spike Handling (NEST Semantics)**

    During refractory period (:math:`r > 0` steps remaining), the effective voltage
    used in the RHS is clamped to :math:`V_{\text{reset}}` and :math:`dV/dt = 0`.
    Outside refractory, the effective voltage is :math:`\min(V, V_{\text{peak}})`.

    Spike detection threshold:
      - :math:`V_{\text{peak}}` if :math:`\Delta_T > 0` (exponential regime)
      - :math:`V_{th}` if :math:`\Delta_T = 0` (integrate-and-fire limit)

    On each detected spike:
      1. :math:`V \leftarrow V_{\text{reset}}`
      2. :math:`w \leftarrow w + b` (spike-triggered adaptation increment)
      3. Refractory counter set to ``ref_count + 1`` (if ``ref_count > 0``), where
         ``ref_count`` is the number of simulation steps spanned by ``t_ref``
         (rounded up).

    Spike detection/reset occurs *inside* the RKF45 substep loop. With ``t_ref = 0``,
    multiple spikes can occur within one simulation step, matching NEST behavior.

    **3. Update Order Per Simulation Step**

    1. Integrate ODEs on :math:`(t, t+dt]` via adaptive RKF45 (Runge-Kutta-Fehlberg 4(5))
    2. Inside integration loop: apply refractory clamp, detect spike, reset, adapt
    3. After integration: decrement refractory counter by 1
    4. Apply arriving spike weights to :math:`I_{ex}`, :math:`I_{in}`
    5. Store external current input :math:`x` into one-step delayed buffer :math:`I_{\text{stim}}`

    **4. Numerical Integration**

    Uses adaptive RKF45 with local error control. Step size :math:`h` is adjusted
    to keep error below ``gsl_error_tol``. Integration step size is persistent across
    simulation steps for efficiency.

    Parameters
    ----------
    in_size : int or tuple of int
        Population shape. Scalar for 1D population, tuple for multi-dimensional.
    V_peak : ArrayLike, optional
        Spike detection threshold (if ``Delta_T > 0``). Units: mV. Default: 0.0 mV.
        Scalar or broadcastable to ``in_size``.
    V_reset : ArrayLike, optional
        Reset potential after spike. Units: mV. Default: -60.0 mV.
        Scalar or broadcastable to ``in_size``. Must satisfy ``V_reset < V_peak``.
    t_ref : ArrayLike, optional
        Absolute refractory period duration. Units: ms. Default: 0.0 ms.
        Scalar or broadcastable to ``in_size``. Zero allows multiple spikes per step.
    g_L : ArrayLike, optional
        Leak conductance. Units: nS. Default: 30.0 nS.
        Scalar or broadcastable to ``in_size``. Must be positive.
    C_m : ArrayLike, optional
        Membrane capacitance. Units: pF. Default: 281.0 pF.
        Scalar or broadcastable to ``in_size``. Must be positive.
    E_L : ArrayLike, optional
        Leak reversal potential. Units: mV. Default: -70.6 mV.
        Scalar or broadcastable to ``in_size``.
    Delta_T : ArrayLike, optional
        Exponential slope factor. Units: mV. Default: 2.0 mV.
        Scalar or broadcastable to ``in_size``. Zero recovers integrate-and-fire limit.
        Must be non-negative. Small (non-zero) values relative to ``V_peak - V_th``
        make ``(V_peak - V_th) / Delta_T`` large and may overflow the exponential term.
    tau_w : ArrayLike, optional
        Adaptation time constant. Units: ms. Default: 144.0 ms.
        Scalar or broadcastable to ``in_size``. Must be positive.
    a : ArrayLike, optional
        Subthreshold adaptation coupling. Units: nS. Default: 4.0 nS.
        Scalar or broadcastable to ``in_size``. Couples voltage deviation to adaptation.
    b : ArrayLike, optional
        Spike-triggered adaptation increment. Units: pA. Default: 80.5 pA.
        Scalar or broadcastable to ``in_size``. Added to ``w`` on each spike.
    V_th : ArrayLike, optional
        Spike initiation threshold (in exponential term). Units: mV. Default: -50.4 mV.
        Scalar or broadcastable to ``in_size``. Must satisfy ``V_th <= V_peak``.
    tau_syn_ex : ArrayLike, optional
        Excitatory synaptic current time constant. Units: ms. Default: 0.2 ms.
        Scalar or broadcastable to ``in_size``. Must be positive.
    tau_syn_in : ArrayLike, optional
        Inhibitory synaptic current time constant. Units: ms. Default: 2.0 ms.
        Scalar or broadcastable to ``in_size``. Must be positive.
    I_e : ArrayLike, optional
        Constant external current. Units: pA. Default: 0.0 pA.
        Scalar or broadcastable to ``in_size``.
    gsl_error_tol : ArrayLike, optional
        RKF45 local error tolerance. Dimensionless. Default: 1e-6.
        Scalar or broadcastable to ``in_size``. Must be positive. Smaller values
        increase accuracy at the cost of smaller integration steps.
    V_initializer : Callable, optional
        Membrane potential initializer. Default: Constant(-70.6 mV).
        Must return quantity with mV units when called with ``(shape,)``.
    I_ex_initializer : Callable, optional
        Excitatory current initializer. Default: Constant(0.0 pA).
        Must return quantity with pA units when called with ``(shape,)``.
    I_in_initializer : Callable, optional
        Inhibitory current initializer. Default: Constant(0.0 pA).
        Must return quantity with pA units when called with ``(shape,)``.
    w_initializer : Callable, optional
        Adaptation current initializer. Default: Constant(0.0 pA).
        Must return quantity with pA units when called with ``(shape,)``.
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation. Default: ReluGrad().
        Must be a differentiable spike function from ``braintools.surrogate``.
    spk_reset : str, optional
        Spike reset mode. Default: 'hard'.
        - 'hard': Stop gradient through reset (matches NEST behavior)
        - 'soft': Allow gradient through reset
    ref_var : bool, optional
        If True, expose boolean refractory state variable. Default: False.
        When True, creates ``self.refractory`` indicating refractory status.
    name : str, optional
        Model instance name. Default: None (auto-generated).

    Parameter Mapping
    -----------------

    ==================== ================== ========================================== =======================================================
    **Parameter**        **Default**        **Math equivalent**                        **Description**
    ==================== ================== ========================================== =======================================================
    ``in_size``          (required)                                                    Population shape
    ``V_peak``           0 mV               :math:`V_\mathrm{peak}`                    Spike detection threshold (if ``Delta_T > 0``)
    ``V_reset``          -60 mV             :math:`V_\mathrm{reset}`                   Reset potential
    ``t_ref``            0 ms               :math:`t_\mathrm{ref}`                     Absolute refractory duration
    ``g_L``              30 nS              :math:`g_\mathrm{L}`                       Leak conductance
    ``C_m``              281 pF             :math:`C_\mathrm{m}`                       Membrane capacitance
    ``E_L``              -70.6 mV           :math:`E_\mathrm{L}`                       Leak reversal potential
    ``Delta_T``          2 mV               :math:`\Delta_T`                           Exponential slope factor
    ``tau_w``            144 ms             :math:`\tau_w`                             Adaptation time constant
    ``a``                4 nS               :math:`a`                                  Subthreshold adaptation
    ``b``                80.5 pA            :math:`b`                                  Spike-triggered adaptation increment
    ``V_th``             -50.4 mV           :math:`V_\mathrm{th}`                      Spike initiation threshold (in exponential term)
    ``tau_syn_ex``       0.2 ms             :math:`\tau_{\mathrm{syn,ex}}`             Excitatory exponential time constant
    ``tau_syn_in``       2.0 ms             :math:`\tau_{\mathrm{syn,in}}`             Inhibitory exponential time constant
    ``I_e``              0 pA               :math:`I_\mathrm{e}`                       Constant external current
    ``gsl_error_tol``    1e-6               (solver tolerance)                         RKF45 local error tolerance
    ``V_initializer``    Constant(-70.6 mV)                                            Membrane initializer
    ``I_ex_initializer`` Constant(0 pA)                                                Excitatory current initializer
    ``I_in_initializer`` Constant(0 pA)                                                Inhibitory current initializer
    ``w_initializer``    Constant(0 pA)                                                Adaptation current initializer
    ``spk_fun``          ReluGrad()                                                    Surrogate spike function
    ``spk_reset``        ``'hard'``                                                    Reset mode; hard reset matches NEST behavior
    ``ref_var``          ``False``                                                     If True, expose boolean refractory indicator
    ==================== ================== ========================================== =======================================================

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential. Shape: ``(*in_size,)``. Units: mV.
    I_ex : brainstate.HiddenState
        Excitatory synaptic current. Shape: ``(*in_size,)``. Units: pA.
    I_in : brainstate.HiddenState
        Inhibitory synaptic current. Shape: ``(*in_size,)``. Units: pA.
    w : brainstate.HiddenState
        Adaptation current. Shape: ``(*in_size,)``. Units: pA.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory steps. Shape: ``(*in_size,)``. Dtype: int32.
    integration_step : brainstate.ShortTermState
        Persistent RKF45 internal step size. Shape: ``(*in_size,)``. Units: ms.
    I_stim : brainstate.ShortTermState
        One-step delayed current buffer. Shape: ``(*in_size,)``. Units: pA.
    last_spike_time : brainstate.ShortTermState
        Last emitted spike time. Shape: ``(*in_size,)``. Units: ms.
        Updated to ``t + dt`` on spike emission.
    refractory : brainstate.ShortTermState, optional
        Boolean refractory indicator. Only exists if ``ref_var=True``.
        Shape: ``(*in_size,)``. Dtype: bool.

    Raises
    ------
    ValueError
        At construction time (validation is skipped when parameters are JAX tracers):

        - If ``V_reset >= V_peak``
        - If ``Delta_T < 0``
        - If ``V_peak < V_th``
        - If ``C_m <= 0``
        - If ``t_ref < 0``
        - If any time constant ``<= 0``
        - If ``gsl_error_tol <= 0``
        - If ``(V_peak - V_th) / Delta_T`` is too large (overflow risk in exponential term)

    Notes
    -----
    **Implementation Details:**

    - **Adaptive integration:** RKF45 adjusts step size :math:`h` dynamically to meet error
      tolerance. Step size is persistent across simulation steps for efficiency.
    - **Refractory semantics:** During refractory, voltage is clamped to ``V_reset`` in the
      ODE RHS and ``dV/dt = 0``. This matches NEST exactly.
    - **Multiple spikes per step:** With ``t_ref = 0``, multiple spikes can occur within
      one simulation step. Each spike triggers reset and adaptation increment.
    - **Overflow protection:** Parameter validation checks that the exponential term
      :math:`\exp((V_{\text{peak}} - V_{th}) / \Delta_T)` does not overflow.
    - **Numerical instability:** During :meth:`update`, an error is raised through
      ``brainstate.transform.jit_error_if`` if an accepted substep yields
      ``V < -1e3 mV`` or ``|w| > 1e6 pA``.
    - **Simulation resolution:** The constructor reads ``brainstate.environ.get_dt()``
      (refractory step count and integrator setup), so ``dt`` must be available when
      the model is created.
    - **Surrogate gradients:** For backpropagation, spike generation uses ``spk_fun``
      (default: ReLU gradient). Hard reset (``spk_reset='hard'``) stops gradient through
      reset, matching biological discontinuity.

    **Differences from other models:**

    - ``aeif_cond_exp``: Uses conductance-based synapses instead of current-based.
    - ``aeif_psc_alpha``: Uses alpha-function synapses instead of exponential.
    - ``aeif_psc_delta``: Uses delta-function (instantaneous) synapses.

    Examples
    --------
    Basic usage with constant input current:

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> import brainstate
        >>>
        >>> dt = 0.1 * u.ms
        >>> with brainstate.environ.context(dt=dt):
        ...     # The constructor reads ``dt``, so create the population here.
        ...     neurons = bp.aeif_psc_exp(100, I_e=200 * u.pA)
        ...     neurons.init_all_states()
        ...
        ...     # Run for 100 ms; ``update`` reads the current time ``t``.
        ...     spikes = []
        ...     for i in range(1000):
        ...         with brainstate.environ.context(t=i * dt):
        ...             spikes.append(neurons.update())

    With synaptic input and refractory period:

    .. code-block:: python

        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     # Create neurons with 2 ms refractory period
        ...     neurons = bp.aeif_psc_exp(
        ...         in_size=100,
        ...         t_ref=2.0 * u.ms,
        ...         tau_syn_ex=5.0 * u.ms,
        ...         tau_syn_in=10.0 * u.ms,
        ...     )
        ...     neurons.init_all_states()
        ...
        ...     # Excitatory spike input: ``update`` sums only delta inputs labelled
        ...     # ``'w_ex'`` (excitatory) and ``'w_in'`` (inhibitory).
        ...     neurons.add_delta_input('exc', lambda: 100 * u.pA, label='w_ex')
        ...
        ...     # Simulation step with an external current
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         spike = neurons.update(x=50 * u.pA)

    References
    ----------
    .. [1] Brette R, Gerstner W (2005). Adaptive exponential integrate-and-fire
           model as an effective description of neuronal activity.
           Journal of Neurophysiology, 94:3637-3642.
           DOI: https://doi.org/10.1152/jn.00686.2005
    .. [2] NEST source: ``models/aeif_psc_exp.h`` and
           ``models/aeif_psc_exp.cpp``.
    """

    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 100000
    # Relative tolerance applied to ``t_ref / dt`` before rounding up, so that
    # floating-point noise in an exact multiple (e.g. 0.07 ms / 0.01 ms evaluates
    # to 7.000000000000001) does not add a spurious refractory step.
    _REF_COUNT_RTOL = 1e-6

    def __init__(
        self,
        in_size: Size,
        V_peak: ArrayLike = 0.0 * u.mV,
        V_reset: ArrayLike = -60.0 * u.mV,
        t_ref: ArrayLike = 0.0 * u.ms,
        g_L: ArrayLike = 30.0 * u.nS,
        C_m: ArrayLike = 281.0 * u.pF,
        E_L: ArrayLike = -70.6 * u.mV,
        Delta_T: ArrayLike = 2.0 * u.mV,
        tau_w: ArrayLike = 144.0 * u.ms,
        a: ArrayLike = 4.0 * u.nS,
        b: ArrayLike = 80.5 * u.pA,
        V_th: ArrayLike = -50.4 * u.mV,
        tau_syn_ex: ArrayLike = 0.2 * u.ms,
        tau_syn_in: ArrayLike = 2.0 * u.ms,
        I_e: ArrayLike = 0.0 * u.pA,
        gsl_error_tol: ArrayLike = 1e-6,
        V_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        I_ex_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        I_in_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        w_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Broadcast every model parameter to the population shape.
        self.V_peak = braintools.init.param(V_peak, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.Delta_T = braintools.init.param(Delta_T, self.varshape)
        self.tau_w = braintools.init.param(tau_w, self.varshape)
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.gsl_error_tol = gsl_error_tol

        self.V_initializer = V_initializer
        self.I_ex_initializer = I_ex_initializer
        self.I_in_initializer = I_in_initializer
        self.w_initializer = w_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Both the integrator and the refractory step count depend on the
        # simulation resolution, so it must be available at construction time.
        dt = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=dt,
        )

        # Number of simulation steps spanned by t_ref, rounded up. The tiny
        # relative shrink keeps exact multiples of dt from being bumped to the
        # next integer by floating-point noise.
        ditype = brainstate.environ.ditype()
        self.ref_count = u.math.asarray(
            u.math.ceil(self.t_ref / dt * (1.0 - self._REF_COUNT_RTOL)),
            dtype=ditype,
        )

    def _validate_parameters(self):
        r"""Validate parameter consistency and numerical stability.

        Checks parameter constraints following NEST validation rules:
          - ``V_reset < V_peak``
          - ``Delta_T >= 0``
          - ``V_peak >= V_th``
          - ``C_m > 0``
          - ``t_ref >= 0``
          - All time constants ``> 0``
          - ``gsl_error_tol > 0``
          - Exponential-term overflow guard, delegated to ``validate_aeif_overflow``
            with ``(V_peak, V_th, Delta_T / mV)``.

        All checks are skipped when any checked parameter is a JAX tracer
        (e.g. the model is constructed under ``jit``), because the comparisons
        require concrete values.

        Raises
        ------
        ValueError
            If any parameter constraint is violated.
        """
        v_reset = self.V_reset
        v_peak = self.V_peak
        v_th = self.V_th
        delta_t = self.Delta_T / u.mV

        # Skip validation when any checked parameter is a JAX tracer: concrete
        # comparisons such as ``np.any(tracer <= 0)`` cannot be evaluated on
        # abstract values.
        checked = (
            v_reset, v_peak, v_th, delta_t,
            self.C_m, self.t_ref,
            self.tau_syn_ex, self.tau_syn_in, self.tau_w,
            self.gsl_error_tol,
        )
        if any(is_tracer(v) for v in checked):
            return

        if np.any(v_reset >= v_peak):
            raise ValueError('Ensure that V_reset < V_peak .')
        if np.any(delta_t < 0.0):
            raise ValueError('Delta_T must be positive.')
        if np.any(v_peak < v_th):
            raise ValueError('V_peak >= V_th required.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Ensure that C_m > 0')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.tau_w <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

        # Mirror NEST overflow guard for exponential term at spike time.
        validate_aeif_overflow(v_peak, v_th, delta_t)

    def init_state(self, **kwargs):
        r"""Initialize persistent and short-term state variables.

        Creates the hidden states ``V``, ``I_ex``, ``I_in`` and ``w`` from their
        initializers, plus the short-term bookkeeping states
        (``last_spike_time``, ``refractory_step_count``, ``integration_step``,
        ``I_stim`` and, if ``ref_var`` is True, ``refractory``).

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If an initializer cannot be broadcast to requested shape.
        TypeError
            If initializer outputs have incompatible units/dtypes for the
            corresponding state variables.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        V = braintools.init.param(self.V_initializer, self.varshape)
        I_ex = braintools.init.param(self.I_ex_initializer, self.varshape)
        I_in = braintools.init.param(self.I_in_initializer, self.varshape)
        w = braintools.init.param(self.w_initializer, self.varshape)
        self.V = brainstate.HiddenState(V)
        self.I_ex = brainstate.HiddenState(I_ex)
        self.I_in = brainstate.HiddenState(I_in)
        self.w = brainstate.HiddenState(w)

        # Far-past sentinel so no neuron appears to have spiked recently.
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        # The adaptive integrator starts from a full simulation step.
        self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
        self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: ArrayLike = None):
        r"""Compute differentiable spike output using surrogate gradient.

        Applies the surrogate gradient function ``spk_fun`` to a scaled voltage.
        The scaling maps the threshold region to a suitable range for the
        surrogate.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential. Units: mV. If None, uses ``self.V.value``.
            Shape: ``(*in_size,)``.

        Returns
        -------
        ArrayLike
            Output of ``spk_fun`` on the scaled voltage. Shape: ``(*in_size,)``.

        Notes
        -----
        The voltage is scaled by ``(V - V_th) / (V_th - V_reset)`` before passing
        to the surrogate function. This normalizes the threshold crossing region
        for the surrogate gradient approximation.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, I_ex, I_in, w — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim, v_peak_detect — mutable
            auxiliary data carried through the integrator.

        Returns
        -------
        DotDict with same keys as ``state``, containing time derivatives.
        """
        is_refractory = extra.r > 0
        # Refractory neurons see V_reset; others see V capped at V_peak.
        v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_peak))
        # Avoid division by zero when Delta_T == 0; the spike current is then
        # zero anyway because it is multiplied by Delta_T.
        delta_t_safe = u.math.where(self.Delta_T == 0.0 * u.mV, 1.0 * u.mV, self.Delta_T)
        exp_arg = u.math.clip((v_eff - self.V_th) / delta_t_safe, -500.0, 500.0)
        i_spike = self.g_L * self.Delta_T * u.math.exp(exp_arg)
        dV_raw = (
            -self.g_L * (v_eff - self.E_L)
            + i_spike
            + state.I_ex
            - state.I_in
            - state.w
            + self.I_e
            + extra.i_stim
        ) / self.C_m
        # Membrane potential is frozen during the refractory period.
        dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw)
        dI_ex = -state.I_ex / self.tau_syn_ex
        dI_in = -state.I_in / self.tau_syn_in
        dw = (self.a * (v_eff - self.E_L) - state.w) / self.tau_w
        return DotDict(V=dV, I_ex=dI_ex, I_in=dI_in, w=dw)

    def _event_fn(self, state, extra, accept):
        """In-loop spike detection, reset, and refractory handling.

        Parameters
        ----------
        state : DotDict
            Keys: V, I_ex, I_in, w — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim, v_peak_detect.
        accept : array, bool
            Mask of neurons whose RK substep was accepted.

        Returns
        -------
        (new_state, new_extra)
            DotDicts with updated spike/reset/refractory info.
        """
        # Flag divergence on any accepted substep (checked after the loop).
        unstable = extra.unstable | jnp.any(
            accept & ((state.V < -1e3 * u.mV) | (state.w < -1e6 * u.pA) | (state.w > 1e6 * u.pA))
        )
        # Refractory neurons are held at V_reset.
        refr_accept = accept & (extra.r > 0)
        new_V = u.math.where(refr_accept, self.V_reset, state.V)
        # Only non-refractory neurons can spike.
        spike_now = accept & (extra.r <= 0) & (new_V >= extra.v_peak_detect)
        spike_mask = extra.spike_mask | spike_now
        new_V = u.math.where(spike_now, self.V_reset, new_V)
        new_w = u.math.where(spike_now, state.w + self.b, state.w)
        # +1 compensates for the decrement applied right after integration
        # (NEST convention); with no refractory time the counter is left as is.
        r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r)
        new_state = DotDict({**state, 'V': new_V, 'w': new_w})
        new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable})
        return new_state, new_extra

    def update(self, x=0.0 * u.pA):
        r"""Advance neuron state by one simulation timestep.

        Performs one simulation step of the adaptive exponential integrate-and-fire
        neuron using adaptive RKF45 integration. Handles refractory clamping, spike
        detection, reset, adaptation increment, and synaptic input application.

        The update sequence follows NEST semantics:

        1. **Integrate ODEs** over :math:`[t, t+dt]` using adaptive RKF45:

           - Vectorized integration with adaptive step size
           - Inside integration: apply refractory clamp, detect spikes, reset
             voltage, increment adaptation, update refractory counter
           - Step size persists across simulation steps for efficiency

        2. **Post-integration processing:**

           - Decrement refractory counter by 1
           - Apply delta inputs labelled ``'w_ex'`` / ``'w_in'`` (spike weights)
             to :math:`I_{ex}`, :math:`I_{in}`
           - Store external current :math:`x` into one-step delayed buffer
             :math:`I_{\text{stim}}`
           - Update ``last_spike_time`` for neurons that spiked

        3. **Return spike tensor:**

           - Binary array indicating which neurons spiked during :math:`[t, t+dt]`

        Parameters
        ----------
        x : ArrayLike, optional
            External current input. Units: pA. Default: 0.0 pA.
            Shape: scalar or broadcastable to ``(*in_size,)``.
            Combined with ``current_inputs`` and stored in ``I_stim`` for next step.

        Returns
        -------
        jax.Array
            Binary spike tensor with the environment float dtype
            (``brainstate.environ.dftype()``) and shape ``self.V.value.shape``.
            A value of ``1.0`` indicates at least one internal spike event
            occurred during the integrated interval :math:`(t, t+dt]`.

        Raises
        ------
        Exception
            Raised through ``brainstate.transform.jit_error_if`` if numerical
            instability was detected on an accepted substep
            (``V < -1e3 mV`` or ``|w| > 1e6 pA``).

        Notes
        -----
        Reads the current simulation time ``t`` and ``dt`` from
        ``brainstate.environ``. Integration is performed with an adaptive
        vectorized RKF45 loop, including in-loop spike/reset/adaptation events
        and optional multiple spikes per step. All arithmetic is unit-aware via
        ``saiunit.math``.

        See Also
        --------
        init_state : Initialize state variables
        get_spike : Compute differentiable spike output
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        I_ex = self.I_ex.value  # pA
        I_in = self.I_in.value  # pA
        w = self.w.value  # pA
        r = self.refractory_step_count.value  # int
        i_stim = self.I_stim.value  # pA
        h = self.integration_step.value  # ms

        # Spike detection threshold: V_peak if Delta_T > 0, else V_th.
        v_peak_detect = u.math.where(self.Delta_T > 0.0 * u.mV, self.V_peak, self.V_th)

        # Current input for next step (one-step delay).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # Adaptive RKF45 integration via generic integrator.
        ode_state = DotDict(V=V, I_ex=I_ex, I_in=I_in, w=w)
        extra = DotDict(
            spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
            r=r,
            unstable=jnp.array(False),
            i_stim=i_stim,
            v_peak_detect=v_peak_detect,
        )
        ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
        V, I_ex, I_in, w = ode_state.V, ode_state.I_ex, ode_state.I_in, ode_state.w
        spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable

        # Post-loop stability check.
        brainstate.transform.jit_error_if(
            jnp.any(unstable),
            'Numerical instability in aeif_psc_exp dynamics.'
        )

        # Decrement refractory counter.
        r = u.math.where(r > 0, r - 1, r)

        # Synaptic spike inputs (applied after integration).
        w_ex = self.sum_delta_inputs(u.math.zeros_like(self.I_ex.value), label='w_ex')
        w_in = self.sum_delta_inputs(u.math.zeros_like(self.I_in.value), label='w_in')

        # Apply synaptic spike inputs (current-based: direct addition in pA).
        I_ex = I_ex + w_ex
        I_in = I_in + w_in

        # Write back state.
        self.V.value = V
        self.I_ex.value = I_ex
        self.I_in.value = I_in
        self.w.value = w
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        self.integration_step.value = h
        # Broadcast the (possibly scalar) stimulus to the population shape.
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)
        return u.math.asarray(spike_mask, dtype=dftype)