# Source code for brainpy_state._nest.aeif_psc_alpha

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, validate_aeif_overflow, AdaptiveRungeKuttaStep

__all__ = [
    'aeif_psc_alpha',
]


class aeif_psc_alpha(NESTNeuron):
    r"""NEST-compatible adaptive exponential integrate-and-fire neuron with alpha-shaped postsynaptic currents.

    This model implements the adaptive exponential integrate-and-fire (AdEx) neuron with
    current-based synapses following alpha-function kinetics. It replicates the behavior of
    NEST's ``aeif_psc_alpha`` model, including adaptive Runge-Kutta-Fehlberg (RKF45) numerical
    integration, in-loop spike detection and reset, and NEST-compatible refractory handling.

    **1. Mathematical Model**

    **Membrane Dynamics**

    The subthreshold membrane potential :math:`V` evolves according to:

    .. math::

       C_m \frac{dV}{dt} = -g_L (V - E_L) + g_L \Delta_T \exp\left(\frac{V - V_{th}}{\Delta_T}\right)
                           - w + I_{ex} - I_{in} + I_e + I_{stim}

    where:

    - :math:`C_m` -- membrane capacitance
    - :math:`g_L` -- leak conductance
    - :math:`E_L` -- leak reversal potential
    - :math:`\Delta_T` -- exponential slope factor (spike sharpness)
    - :math:`V_{th}` -- spike initiation threshold
    - :math:`w` -- adaptation current
    - :math:`I_{ex}`, :math:`I_{in}` -- excitatory and inhibitory synaptic currents
    - :math:`I_e` -- constant external current
    - :math:`I_{stim}` -- time-varying external input (one-step delayed)

    The exponential term :math:`g_L \Delta_T \exp((V - V_{th})/\Delta_T)` causes rapid
    voltage acceleration near :math:`V_{th}`, producing spike initiation. Setting
    :math:`\Delta_T = 0` removes the exponential term and recovers the leaky
    integrate-and-fire (LIF) limit.

    **Adaptation Dynamics**

    The adaptation current :math:`w` provides spike-frequency adaptation and subthreshold
    coupling:

    .. math::

       \tau_w \frac{dw}{dt} = a(V - E_L) - w

    - Subthreshold adaptation: parameter :math:`a` couples :math:`w` to membrane potential
    - Spike-triggered adaptation: at each spike, :math:`w \leftarrow w + b`

    **Alpha-Function Synaptic Currents**

    Excitatory and inhibitory currents are modeled as alpha functions, each requiring
    two state variables:

    .. math::

       \frac{d\,dI_{ex}}{dt} = -\frac{dI_{ex}}{\tau_{syn,ex}}, \quad
       \frac{dI_{ex}}{dt} = dI_{ex} - \frac{I_{ex}}{\tau_{syn,ex}}

    .. math::

       \frac{d\,dI_{in}}{dt} = -\frac{dI_{in}}{\tau_{syn,in}}, \quad
       \frac{dI_{in}}{dt} = dI_{in} - \frac{I_{in}}{\tau_{syn,in}}

    Incoming spike weights (in pA) are read after each integration interval from the
    delta inputs registered under the labels ``'w_ex'`` and ``'w_in'`` and delivered as:

    .. math::

       dI_{ex} \leftarrow dI_{ex} + \frac{e}{\tau_{syn,ex}} w_{ex}

    .. math::

       dI_{in} \leftarrow dI_{in} + \frac{e}{\tau_{syn,in}} w_{in}

    where :math:`e = \exp(1)` normalizes the alpha kernel so that a weight of 1 pA
    produces a peak current of 1 pA (reached at :math:`t = \tau_{syn}`). No sign
    splitting is performed on these inputs. Because :math:`I_{in}` enters the membrane
    equation with a negative sign, inhibitory weights delivered under ``'w_in'`` should
    be positive magnitudes.

    **2. Spike Detection and Reset**

    **Threshold Crossing**

    Spike detection threshold depends on :math:`\Delta_T`:

    - If :math:`\Delta_T > 0`: spike when :math:`V \geq V_{peak}`
    - If :math:`\Delta_T = 0`: spike when :math:`V \geq V_{th}` (LIF-like)

    **Reset Mechanism**

    Upon spike detection:

    1. :math:`V \leftarrow V_{reset}`
    2. :math:`w \leftarrow w + b` (spike-triggered adaptation)
    3. Refractory counter set to :math:`\lceil t_{ref}/dt \rceil + 1` (if :math:`t_{ref} > 0`)

    Spike detection and reset occur *inside* the RKF45 integration substeps, allowing
    multiple spikes per simulation time step when :math:`t_{ref} = 0`.

    **3. Refractory Period Handling**

    During the refractory period (:math:`r > 0` steps remaining):

    - Membrane potential clamped: :math:`V_{eff} = V_{reset}`
    - Voltage derivative forced: :math:`dV/dt = 0`
    - Alpha currents and adaptation continue evolving; the adaptation drive
      :math:`a(V_{eff} - E_L)` uses the clamped :math:`V_{eff} = V_{reset}`

    After each time step, the refractory counter is decremented: :math:`r \leftarrow r - 1`.

    **4. Numerical Integration**

    The model uses adaptive Runge-Kutta-Fehlberg (RKF45) with local error control:

    - **Order**: 5th-order accurate solution with 4th-order error estimate
    - **Error tolerance**: controlled by ``gsl_error_tol`` (default :math:`10^{-6}`)
    - **Step size adaptation**: :math:`h_{new} = h \cdot \min(5, \max(0.2, 0.9 (\epsilon/\text{err})^{0.2}))`
    - **Minimum step**: :math:`h_{min} = 10^{-8}` ms to prevent stalling
    - **Persistent step size**: each neuron maintains its own integration step size across time

    The RKF45 Butcher tableau coefficients follow the standard formulation with stages
    :math:`k_1` through :math:`k_6`.

    **5. Update Sequence**

    Each simulation step processes state updates in this order:

    1. **Integration loop**: Integrate ODEs from :math:`t` to :math:`t + dt` using RKF45
       substeps, checking for spikes and applying resets within the loop
    2. **Refractory decrement**: After integration, decrement refractory counter once
    3. **Synaptic input delivery**: Add the ``'w_ex'`` / ``'w_in'`` delta inputs, scaled by
       :math:`e/\tau_{syn}`, to :math:`dI_{ex}` and :math:`dI_{in}`
    4. **External current update**: Store the summed current input
       (``sum_current_inputs(x, V)``) into the one-step-delayed buffer
       :math:`I_{stim}` (to be used in next step)

    Parameters
    ----------
    in_size : int, tuple of int
        Shape of the neuron population. Can be an integer (1D) or tuple (multi-dimensional).
    V_peak : ArrayLike, optional
        Spike detection threshold voltage. Default: ``0.0 * u.mV``.
        Used for threshold detection when :math:`\Delta_T > 0`.
    V_reset : ArrayLike, optional
        Reset potential after spike. Default: ``-60.0 * u.mV``.
    t_ref : ArrayLike, optional
        Absolute refractory period duration. Default: ``0.0 * u.ms``.
        During refractory period, :math:`V` is clamped to :math:`V_{reset}` and :math:`dV/dt = 0`.
    g_L : ArrayLike, optional
        Leak conductance. Default: ``30.0 * u.nS``.
    C_m : ArrayLike, optional
        Membrane capacitance. Default: ``281.0 * u.pF``.
        Determines membrane time constant :math:`\tau_m = C_m / g_L`.
    E_L : ArrayLike, optional
        Leak reversal potential. Default: ``-70.6 * u.mV``.
    Delta_T : ArrayLike, optional
        Exponential slope factor. Default: ``2.0 * u.mV``.
        Controls spike sharpness; set to 0 for LIF-like behavior.
    tau_w : ArrayLike, optional
        Adaptation time constant. Default: ``144.0 * u.ms``.
    a : ArrayLike, optional
        Subthreshold adaptation coupling. Default: ``4.0 * u.nS``.
        Couples adaptation current to membrane potential deviation from :math:`E_L`.
    b : ArrayLike, optional
        Spike-triggered adaptation increment. Default: ``80.5 * u.pA``.
        Added to :math:`w` at each spike.
    V_th : ArrayLike, optional
        Spike initiation threshold. Default: ``-50.4 * u.mV``.
        Appears in exponential term and as fallback spike threshold when :math:`\Delta_T = 0`.
    tau_syn_ex : ArrayLike, optional
        Excitatory synaptic alpha time constant. Default: ``0.2 * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory synaptic alpha time constant. Default: ``2.0 * u.ms``.
    I_e : ArrayLike, optional
        Constant external current. Default: ``0.0 * u.pA``.
    gsl_error_tol : ArrayLike, optional
        RKF45 local error tolerance. Default: ``1e-6``.
        Smaller values increase accuracy but require smaller integration steps.
    V_initializer : Callable, optional
        Membrane potential initializer. Default: ``Constant(-70.6 * u.mV)``.
    I_ex_initializer : Callable, optional
        Excitatory current initializer. Default: ``Constant(0.0 * u.pA)``.
    I_in_initializer : Callable, optional
        Inhibitory current initializer. Default: ``Constant(0.0 * u.pA)``.
    w_initializer : Callable, optional
        Adaptation current initializer. Default: ``Constant(0.0 * u.pA)``.
    spk_fun : Callable, optional
        Surrogate gradient function for differentiable spike generation.
        Default: ``ReluGrad()``.
    spk_reset : str, optional
        Spike reset mode: ``'soft'`` (subtract threshold) or ``'hard'`` (stop gradient).
        Default: ``'hard'`` (matches NEST behavior).
    ref_var : bool, optional
        If ``True``, expose a boolean ``refractory`` state variable indicating refractory status.
        Default: ``False``.
    name : str, optional
        Name of the neuron population.

    Parameter Mapping
    -----------------

    ==================== ================== ========================================== =====================================================
    **Parameter**        **Default**        **Math equivalent**                        **Description**
    ==================== ================== ========================================== =====================================================
    ``in_size``          (required)         —                                          Population shape
    ``V_peak``           0 mV               :math:`V_\mathrm{peak}`                    Spike detection threshold (if :math:`\Delta_T > 0`)
    ``V_reset``          -60 mV             :math:`V_\mathrm{reset}`                   Reset potential
    ``t_ref``            0 ms               :math:`t_\mathrm{ref}`                     Absolute refractory duration
    ``g_L``              30 nS              :math:`g_\mathrm{L}`                       Leak conductance
    ``C_m``              281 pF             :math:`C_\mathrm{m}`                       Membrane capacitance
    ``E_L``              -70.6 mV           :math:`E_\mathrm{L}`                       Leak reversal potential
    ``Delta_T``          2 mV               :math:`\Delta_T`                           Exponential slope factor
    ``tau_w``            144 ms             :math:`\tau_w`                             Adaptation time constant
    ``a``                4 nS               :math:`a`                                  Subthreshold adaptation coupling
    ``b``                80.5 pA            :math:`b`                                  Spike-triggered adaptation increment
    ``V_th``             -50.4 mV           :math:`V_\mathrm{th}`                      Spike initiation threshold
    ``tau_syn_ex``       0.2 ms             :math:`\tau_{\mathrm{syn,ex}}`             Excitatory alpha time constant
    ``tau_syn_in``       2.0 ms             :math:`\tau_{\mathrm{syn,in}}`             Inhibitory alpha time constant
    ``I_e``              0 pA               :math:`I_\mathrm{e}`                       Constant external current
    ``gsl_error_tol``    1e-6               —                                          RKF45 local error tolerance
    ``V_initializer``    Constant(-70.6 mV) —                                          Membrane initializer
    ``I_ex_initializer`` Constant(0 pA)     —                                          Excitatory current initializer
    ``I_in_initializer`` Constant(0 pA)     —                                          Inhibitory current initializer
    ``w_initializer``    Constant(0 pA)     —                                          Adaptation current initializer
    ``spk_fun``          ReluGrad()         —                                          Surrogate spike function
    ``spk_reset``        ``'hard'``         —                                          Reset mode (hard matches NEST)
    ``ref_var``          ``False``          —                                          Expose boolean refractory indicator
    ==================== ================== ========================================== =====================================================

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential, shape ``(*in_size,)`` with unit mV.
    dI_ex : brainstate.ShortTermState
        Excitatory alpha auxiliary state (derivative of :math:`I_{ex}`), unit pA/ms.
    I_ex : brainstate.HiddenState
        Excitatory synaptic current, unit pA.
    dI_in : brainstate.ShortTermState
        Inhibitory alpha auxiliary state (derivative of :math:`I_{in}`), unit pA/ms.
    I_in : brainstate.HiddenState
        Inhibitory synaptic current, unit pA.
    w : brainstate.HiddenState
        Adaptation current, unit pA.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory time steps (integer array with dtype
        ``brainstate.environ.ditype()``).
    integration_step : brainstate.ShortTermState
        Current RKF45 integration step size, unit ms. Persists across simulation steps.
    I_stim : brainstate.ShortTermState
        One-step-delayed external current buffer, unit pA.
    last_spike_time : brainstate.ShortTermState
        Time of last spike emission, unit ms. Updated to :math:`t + dt` on spike.
    refractory : brainstate.ShortTermState, optional
        Boolean refractory indicator (only if ``ref_var=True``).
    ref_count : Array
        Refractory duration in simulation steps, :math:`\lceil t_{ref}/dt \rceil`, computed
        at construction from the environment's ``dt``.

    Raises
    ------
    ValueError
        At construction (validation is skipped when parameters are JAX tracers):

        - If :math:`V_{reset} \geq V_{peak}`
        - If :math:`\Delta_T < 0`
        - If :math:`V_{peak} < V_{th}`
        - If :math:`C_m \leq 0`
        - If :math:`t_{ref} < 0`
        - If any time constant :math:`\leq 0`
        - If ``gsl_error_tol`` :math:`\leq 0`
        - If :math:`(V_{peak} - V_{th})/\Delta_T` would cause exponential overflow

    Numerical instability detected during integration (:math:`V < -1000` mV or
    :math:`|w| > 10^6` pA on an accepted substep) is not checked at construction. It is
    reported while calling ``update``, through ``brainstate.transform.jit_error_if``.

    Notes
    -----
    **NEST Compatibility**

    - Replicates NEST ``aeif_psc_alpha`` dynamics including RKF45 integration and in-loop
      spike handling
    - Default parameter values match NEST defaults, expressed in NEST's native units
      (mV, ms, pA, nS, pF) as ``saiunit`` quantities
    - Refractory clamping follows NEST semantics: :math:`V_{eff} = V_{reset}` during
      refractory, with :math:`dV/dt = 0`

    **Numerical Considerations**

    - Maximum iteration limit: 100,000 substeps per time step (prevents infinite loops)
    - Minimum step size: :math:`h_{min} = 10^{-8}` ms
    - Voltage capping during integration: :math:`V_{eff} = \min(V, V_{peak})` outside
      refractory period to prevent exponential overflow
    - Overflow protection: validates that :math:`\exp((V_{peak} - V_{th})/\Delta_T)`
      remains within floating-point range

    **Multiple Spikes Per Step**

    With :math:`t_{ref} = 0` (default), a neuron can spike multiple times within a single
    simulation step. The internal adaptation :math:`w` accumulates all spike-triggered
    increments :math:`b`, but the returned spike tensor is binary (0 or 1) per step.

    **Surrogate Gradients**

    The ``spk_fun`` parameter controls backpropagation through spikes for gradient-based
    learning. The surrogate function approximates the derivative of the Heaviside step
    function during backward passes.

    Examples
    --------
    The constructor reads the simulation step with ``brainstate.environ.get_dt()``, and
    ``update`` reads the current time ``t`` from the environment. Both must therefore be
    available when the neuron is created and when it is advanced.

    Basic usage with default parameters:

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> import brainstate as bs
        >>>
        >>> with bs.environ.context(dt=0.1 * u.ms):
        ...     # Create a population of 100 AdEx neurons (reads dt)
        ...     neuron = bst.aeif_psc_alpha(100)
        ...     # Initialize states
        ...     neuron.init_all_states()
        >>>
        >>> # Simulate one step with external current
        >>> with bs.environ.context(dt=0.1 * u.ms, t=0.0 * u.ms):
        ...     spikes = neuron.update(x=500 * u.pA)
        >>> spikes.shape
        (100,)

    Custom parameters for fast-spiking interneuron:

    .. code-block:: python

        >>> # Fast-spiking configuration
        >>> with bs.environ.context(dt=0.1 * u.ms):
        ...     fs_neuron = bst.aeif_psc_alpha(
        ...         in_size=50,
        ...         C_m=150 * u.pF,
        ...         g_L=20 * u.nS,
        ...         tau_w=30 * u.ms,
        ...         a=0 * u.nS,  # Minimal subthreshold adaptation
        ...         b=20 * u.pA,  # Small spike-triggered adaptation
        ...         V_th=-52 * u.mV,
        ...         Delta_T=1.5 * u.mV,
        ...         tau_syn_ex=0.5 * u.ms,
        ...         tau_syn_in=1.0 * u.ms,
        ...     )

    Connecting to upstream spike sources:

    .. code-block:: python

        >>> import brainevent as be
        >>>
        >>> # (construct inside ``bs.environ.context(dt=...)``, as above)
        >>> # Create presynaptic spike generator
        >>> spike_gen = bst.PoissonSpike(100, freq=10 * u.Hz)
        >>>
        >>> # Create postsynaptic AdEx neurons
        >>> neurons = bst.aeif_psc_alpha(50)
        >>>
        >>> # Create projection with alpha synapses
        >>> proj = be.nn.FixedProb(
        ...     pre=spike_gen,
        ...     post=neurons,
        ...     prob=0.2,
        ...     weight=50.0,  # pA per spike
        ... )

    See Also
    --------
    aeif_cond_alpha : Conductance-based AdEx with alpha synapses
    aeif_psc_exp : AdEx with exponential postsynaptic currents
    aeif_psc_delta : AdEx with delta-function synaptic currents
    iaf_psc_alpha : Leaky integrate-and-fire with alpha currents (the LIF limit of this
        model: ``Delta_T=0``, ``a=0``, ``b=0``)

    References
    ----------
    .. [1] Brette R, Gerstner W (2005). Adaptive exponential integrate-and-fire
           model as an effective description of neuronal activity.
           Journal of Neurophysiology, 94:3637-3642.
           DOI: https://doi.org/10.1152/jn.00686.2005
    .. [2] Gerstner W, Kistler WM, Naud R, Paninski L (2014). Neuronal Dynamics:
           From Single Neurons to Networks and Models of Cognition.
           Cambridge University Press. Chapter 6.
    .. [3] NEST Simulator Documentation. ``aeif_psc_alpha`` model.
           https://nest-simulator.readthedocs.io/
    .. [4] NEST source code: ``models/aeif_psc_alpha.h`` and ``models/aeif_psc_alpha.cpp``.
    """

    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # lower bound on the adaptive RKF45 substep (passed as min_h)
    _MAX_ITERS = 100000  # cap on RKF45 substeps per simulation step (passed as max_iters)

[docs] def __init__( self, in_size: Size, V_peak: ArrayLike = 0.0 * u.mV, V_reset: ArrayLike = -60.0 * u.mV, t_ref: ArrayLike = 0.0 * u.ms, g_L: ArrayLike = 30.0 * u.nS, C_m: ArrayLike = 281.0 * u.pF, E_L: ArrayLike = -70.6 * u.mV, Delta_T: ArrayLike = 2.0 * u.mV, tau_w: ArrayLike = 144.0 * u.ms, a: ArrayLike = 4.0 * u.nS, b: ArrayLike = 80.5 * u.pA, V_th: ArrayLike = -50.4 * u.mV, tau_syn_ex: ArrayLike = 0.2 * u.ms, tau_syn_in: ArrayLike = 2.0 * u.ms, I_e: ArrayLike = 0.0 * u.pA, gsl_error_tol: ArrayLike = 1e-6, V_initializer: Callable = braintools.init.Constant(-70.6 * u.mV), I_ex_initializer: Callable = braintools.init.Constant(0.0 * u.pA), I_in_initializer: Callable = braintools.init.Constant(0.0 * u.pA), w_initializer: Callable = braintools.init.Constant(0.0 * u.pA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'hard', ref_var: bool = False, name: str = None, ): r"""Initialize the aeif_psc_alpha neuron model. See class docstring for detailed parameter descriptions. 
""" super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) self.V_peak = braintools.init.param(V_peak, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.t_ref = braintools.init.param(t_ref, self.varshape) self.g_L = braintools.init.param(g_L, self.varshape) self.C_m = braintools.init.param(C_m, self.varshape) self.E_L = braintools.init.param(E_L, self.varshape) self.Delta_T = braintools.init.param(Delta_T, self.varshape) self.tau_w = braintools.init.param(tau_w, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape) self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape) self.I_e = braintools.init.param(I_e, self.varshape) self.gsl_error_tol = gsl_error_tol self.V_initializer = V_initializer self.I_ex_initializer = I_ex_initializer self.I_in_initializer = I_in_initializer self.w_initializer = w_initializer self.ref_var = ref_var self._validate_parameters() self.integrator = AdaptiveRungeKuttaStep( method='RKF45', vf=self._vector_field, event_fn=self._event_fn, min_h=self._MIN_H, max_iters=self._MAX_ITERS, atol=self.gsl_error_tol, dt=brainstate.environ.get_dt() ) # other variable ditype = brainstate.environ.ditype() dt = brainstate.environ.get_dt() self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)
def _validate_parameters(self): r"""Validate parameter constraints and check for numerical overflow conditions. Raises ------ ValueError If any of the following conditions are violated: - :math:`V_{reset} < V_{peak}` (reset must be below spike threshold) - :math:`\Delta_T \geq 0` (slope factor must be non-negative) - :math:`V_{peak} \geq V_{th}` (detection threshold must exceed initiation) - :math:`C_m > 0` (capacitance must be positive) - :math:`t_{ref} \geq 0` (refractory time cannot be negative) - :math:`\tau_{syn,ex}, \tau_{syn,in}, \tau_w > 0` (time constants must be positive) - ``gsl_error_tol`` :math:`> 0` (tolerance must be positive) - :math:`\exp((V_{peak} - V_{th})/\Delta_T)` within float64 range (prevents overflow at spike time) Notes ----- The overflow check mirrors NEST's validation: it ensures the exponential term :math:`g_L \Delta_T \exp((V_{peak} - V_{th})/\Delta_T)` remains computable in float64 precision. Uses threshold :math:`\log(\text{float64}_{\max} / 10^{20})` to provide safety margin. """ v_reset = self.V_reset v_peak = self.V_peak v_th = self.V_th delta_t = self.Delta_T / u.mV # Skip validation when parameters are JAX tracers (e.g. during jit). 
if any(is_tracer(v) for v in (v_reset, v_peak, v_th, delta_t)): return if np.any(v_reset >= v_peak): raise ValueError('Ensure that: V_reset < V_peak .') if np.any(delta_t < 0.0): raise ValueError('Delta_T must be positive.') if np.any(v_peak < v_th): raise ValueError('V_peak >= V_th required.') if np.any(self.C_m <= 0.0 * u.pF): raise ValueError('Capacitance must be strictly positive.') if np.any(self.t_ref < 0.0 * u.ms): raise ValueError('Refractory time cannot be negative.') if np.any(self.tau_syn_ex <= 0.0 * u.ms): raise ValueError('All time constants must be strictly positive.') if np.any(self.tau_syn_in <= 0.0 * u.ms): raise ValueError('All time constants must be strictly positive.') if np.any(self.tau_w <= 0.0 * u.ms): raise ValueError('All time constants must be strictly positive.') if np.any(self.gsl_error_tol <= 0.0): raise ValueError('The gsl_error_tol must be strictly positive.') # Mirror NEST overflow guard for exponential term at spike time. validate_aeif_overflow(v_peak, v_th, delta_t)
[docs] def init_state(self, **kwargs): r"""Initialize all state variables. Creates and initializes the membrane potential, synaptic currents, adaptation current, refractory counters, integration step size, and stimulus buffer. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Raises ------ ValueError If an initializer cannot be broadcast to requested shape. TypeError If initializer outputs have incompatible units/dtypes for the corresponding state variables. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() I_ex = braintools.init.param(self.I_ex_initializer, self.varshape) I_in = braintools.init.param(self.I_in_initializer, self.varshape) V = braintools.init.param(self.V_initializer, self.varshape) zeros = u.math.zeros(self.varshape, dtype=V.dtype) * (u.pA / u.ms) w = braintools.init.param(self.w_initializer, self.varshape) self.dI_ex = brainstate.ShortTermState(zeros) self.dI_in = brainstate.ShortTermState(zeros) self.I_ex = brainstate.HiddenState(I_ex) self.I_in = brainstate.HiddenState(I_in) self.V = brainstate.HiddenState(V) self.w = brainstate.HiddenState(w) self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms)) self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype)) self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape) self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)) if self.ref_var: refractory = braintools.init.param(braintools.init.Constant(False), self.varshape) self.refractory = brainstate.ShortTermState(refractory)
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute differentiable spike output using surrogate gradient. Applies the surrogate spike function to the scaled membrane potential for gradient-based learning. This method is used for backpropagation and does not affect the internal spike detection logic (which uses hard threshold crossing during integration). Parameters ---------- V : ArrayLike, optional Membrane potential array with unit mV. If ``None``, uses the current ``self.V.value``. Shape: ``(*in_size,)``. Returns ------- spike : Array Differentiable spike output with shape matching ``V``. Values are continuous in the forward pass (soft spikes) but use surrogate gradients in the backward pass. Typically in range [0, 1] depending on surrogate function. Notes ----- The voltage is scaled before applying the surrogate function: .. math:: v_{scaled} = \\frac{V - V_{th}}{V_{th} - V_{reset}} This normalization ensures the surrogate function operates in a consistent range regardless of the specific voltage parameters. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _vector_field(self, state, extra): """Unit-aware vectorized RHS for all neurons simultaneously. Parameters ---------- state : DotDict Keys: V, dI_ex, I_ex, dI_in, I_in, w -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim, v_peak_detect -- mutable auxiliary data carried through the integrator. Returns ------- DotDict with same keys as ``state``, containing time derivatives. """ is_refractory = extra.r > 0 v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_peak)) delta_t_safe = u.math.where(self.Delta_T == 0.0 * u.mV, 1.0 * u.mV, self.Delta_T) exp_arg = u.math.clip((v_eff - self.V_th) / delta_t_safe, -500.0, 500.0) i_spike = self.g_L * self.Delta_T * u.math.exp(exp_arg) dV_raw = ( -self.g_L * (v_eff - self.E_L) + i_spike + state.I_ex - state.I_in - state.w + self.I_e + extra.i_stim ) / self.C_m dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw) ddI_ex = -state.dI_ex / self.tau_syn_ex dI_ex_dt = state.dI_ex - state.I_ex / self.tau_syn_ex ddI_in = -state.dI_in / self.tau_syn_in dI_in_dt = state.dI_in - state.I_in / self.tau_syn_in dw = (self.a * (v_eff - self.E_L) - state.w) / self.tau_w return DotDict(V=dV, dI_ex=ddI_ex, I_ex=dI_ex_dt, dI_in=ddI_in, I_in=dI_in_dt, w=dw) def _event_fn(self, state, extra, accept): """In-loop spike detection, reset, and refractory handling. Parameters ---------- state : DotDict Keys: V, dI_ex, I_ex, dI_in, I_in, w -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim, v_peak_detect. accept : array, bool Mask of neurons whose RK substep was accepted. Returns ------- (new_state, new_extra) DotDicts with updated spike/reset/refractory info. 
""" unstable = extra.unstable | jnp.any( accept & ((state.V < -1e3 * u.mV) | (state.w < -1e6 * u.pA) | (state.w > 1e6 * u.pA)) ) refr_accept = accept & (extra.r > 0) new_V = u.math.where(refr_accept, self.V_reset, state.V) spike_now = accept & (extra.r <= 0) & (new_V >= extra.v_peak_detect) spike_mask = extra.spike_mask | spike_now new_V = u.math.where(spike_now, self.V_reset, new_V) new_w = u.math.where(spike_now, state.w + self.b, state.w) r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r) new_state = DotDict({**state, 'V': new_V, 'w': new_w}) new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable}) return new_state, new_extra
[docs] def update(self, x=0.0 * u.pA): r"""Advance the neuron state by one simulation time step. Performs adaptive RKF45 integration of membrane, synaptic, and adaptation dynamics over the interval :math:`[t, t+dt]`, with in-loop spike detection, reset, and refractory handling matching NEST semantics. Parameters ---------- x : ArrayLike, optional External current input at the current time step, with unit pA. Shape must be broadcastable to ``(*in_size,)``. Default: ``0.0 * u.pA``. This input is stored in the one-step-delayed buffer ``I_stim`` and will be used in the *next* time step's dynamics (matching NEST input handling). Returns ------- spike : Array Binary spike indicator with shape ``(*in_size,)``, dtype float. Value is ``1.0`` where at least one spike occurred during the integration interval, ``0.0`` otherwise. Note: With ``t_ref=0``, neurons may spike multiple times within the step, but the returned tensor is binary per neuron per step. Internal adaptation dynamics accumulate all spike-triggered increments. Notes ----- **Integration Process** 1. **Adaptive RKF45 loop**: Starting from current state at time :math:`t`, integrate ODEs using RKF45 with adaptive step sizing until reaching :math:`t + dt`. - Each substep computes 6 stages (:math:`k_1` through :math:`k_6`) - Error estimate: :math:`err = \max|y_5 - y_4|` - Step acceptance: if :math:`err \leq atol` or :math:`h \leq h_{min}` - Step size update: :math:`h_{new} = h \cdot \min(5, \max(0.2, 0.9(atol/err)^{0.2}))` 2. **In-loop spike handling**: After each accepted substep, check if :math:`V \geq V_{peak}` (or :math:`V \geq V_{th}` if :math:`\Delta_T=0`). If spike detected: - Reset: :math:`V \leftarrow V_{reset}` - Adaptation jump: :math:`w \leftarrow w + b` - Refractory counter: :math:`r \leftarrow \lceil t_{ref}/dt \rceil + 1` (if enabled) 3. 
**Post-integration cleanup**: - Decrement refractory counter: :math:`r \leftarrow r - 1` (if :math:`r > 0`) - Deliver synaptic inputs: add spike weights to :math:`dI_{ex}` and :math:`dI_{in}` - Store external input: :math:`I_{stim} \leftarrow x` (for next step) - Update spike time: :math:`t_{spike} \leftarrow t + dt` (where spikes occurred) **Refractory Clamping** During refractory period (:math:`r > 0`): - Effective voltage: :math:`V_{eff} = V_{reset}` - Voltage derivative: :math:`dV/dt = 0` - All other state variables evolve normally **Voltage Capping** Outside refractory period, effective voltage is capped to prevent exponential overflow: :math:`V_{eff} = \min(V, V_{peak})`. **Numerical Stability** - Raises ``ValueError`` if :math:`V < -1000` mV (indicates divergence) - Raises ``ValueError`` if :math:`|w| > 10^6` pA (adaptation overflow) - Maximum iteration limit: 100,000 substeps per time step """ t = brainstate.environ.get('t') dt = brainstate.environ.get_dt() dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() # Read state variables with their natural units. V = self.V.value # mV dI_ex = self.dI_ex.value # pA/ms I_ex = self.I_ex.value # pA dI_in = self.dI_in.value # pA/ms I_in = self.I_in.value # pA w = self.w.value # pA r = self.refractory_step_count.value # int i_stim = self.I_stim.value # pA h = self.integration_step.value # ms # Spike detection threshold: V_peak if Delta_T > 0, else V_th. v_peak_detect = u.math.where(self.Delta_T > 0.0 * u.mV, self.V_peak, self.V_th) # Current input for next step (one-step delay). new_i_stim = self.sum_current_inputs(x, self.V.value) # pA # Adaptive RKF45 integration via generic integrator. 
ode_state = DotDict(V=V, dI_ex=dI_ex, I_ex=I_ex, dI_in=dI_in, I_in=I_in, w=w) extra = DotDict( spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_), r=r, unstable=jnp.array(False), i_stim=i_stim, v_peak_detect=v_peak_detect, ) ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra) V, dI_ex, I_ex = ode_state.V, ode_state.dI_ex, ode_state.I_ex dI_in, I_in, w = ode_state.dI_in, ode_state.I_in, ode_state.w spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable # Post-loop stability check. brainstate.transform.jit_error_if( jnp.any(unstable), 'Numerical instability in aeif_psc_alpha dynamics.' ) # Decrement refractory counter. r = u.math.where(r > 0, r - 1, r) # Synaptic spike inputs (applied after integration). w_ex = self.sum_delta_inputs(u.math.zeros_like(self.I_ex.value), label='w_ex') w_in = self.sum_delta_inputs(u.math.zeros_like(self.I_in.value), label='w_in') pscon_ex = np.e / self.tau_syn_ex # 1/ms pscon_in = np.e / self.tau_syn_in # 1/ms # Apply synaptic spike inputs. dI_ex = dI_ex + pscon_ex * w_ex # pA/ms + 1/ms * pA = pA/ms dI_in = dI_in + pscon_in * w_in # pA/ms + 1/ms * pA = pA/ms # Write back state. self.V.value = V self.dI_ex.value = dI_ex self.I_ex.value = I_ex self.dI_in.value = dI_in self.I_in.value = I_in self.w.value = w self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype) self.integration_step.value = h self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value) self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time) if self.ref_var: self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0) return u.math.asarray(spike_mask, dtype=dftype)