# Source code for brainpy_state._nest.iaf_psc_alpha

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer

# Public API of this module: only the neuron model class is exported.
__all__ = ["iaf_psc_alpha"]


class iaf_psc_alpha(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_alpha`` neuron model.

    Description
    -----------
    ``iaf_psc_alpha`` is a current-based leaky integrate-and-fire neuron with
    hard threshold/reset, fixed absolute refractory period, and alpha-shaped
    excitatory/inhibitory current kernels. The implementation mirrors NEST
    ``models/iaf_psc_alpha.{h,cpp}`` update order and propagator formulas.

    **1. Continuous-Time Dynamics**

    The membrane dynamics are

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m} + \frac{I_\text{syn} + I_e}{C_m}

    with :math:`I_\text{syn} = I_{\text{syn,ex}} + I_{\text{syn,in}}`.

    Each alpha current channel is represented by a two-state linear system:

    .. math::

       \frac{d\,dI_X}{dt} = -\frac{dI_X}{\tau_{\text{syn},X}}, \qquad
       \frac{dI_X}{dt} = dI_X - \frac{I_X}{\tau_{\text{syn},X}},
       \quad X \in \{\mathrm{ex}, \mathrm{in}\}.

    This is equivalent to the normalized alpha kernel

    .. math::

       i_X(t) = \frac{e}{\tau_{\text{syn},X}}\, t\, e^{-t/\tau_{\text{syn},X}} \Theta(t),

    which peaks at 1 when :math:`t=\tau_{\text{syn},X}`. Incoming spike weight
    :math:`w` (pA) is split by sign so :math:`w_+=\max(w,0)` drives excitatory
    state and :math:`w_-=\min(w,0)` drives inhibitory state.

    **2. Exact Discrete Propagator and NEST Update Order**

    For fixed step :math:`h=dt`, exact linear propagation is applied to
    :math:`y_3=V_m-E_L`, synaptic states, and a one-step delayed current buffer
    :math:`y_0`:

    .. math::

       dI_{X,n+1} = P_{11,X}\, dI_{X,n} + \frac{e}{\tau_{\text{syn},X}} w_{X,n},

    .. math::

       I_{X,n+1} = P_{21,X}\, dI_{X,n} + P_{22,X}\, I_{X,n},

    .. math::

       y_{3,n+1} = y_{3,n} + \big(e^{-h/\tau_m}-1\big) y_{3,n}
       + P_{30}(y_{0,n} + I_e)
       + \sum_{X \in \{\mathrm{ex},\mathrm{in}\}}
       \left(P_{31,X} dI_{X,n} + P_{32,X} I_{X,n}\right).

    Here :math:`P_{11,X}=P_{22,X}=e^{-h/\tau_{\text{syn},X}}`,
    :math:`P_{21,X}=h\,e^{-h/\tau_{\text{syn},X}}`, and
    :math:`P_{30}=\tau_m(1-e^{-h/\tau_m})/C_m`.

    Internal state (NEST notation):

    - :math:`y_0` -- buffered external current for next step,
    - :math:`dI_{ex}, I_{ex}` and :math:`dI_{in}, I_{in}` -- alpha-kernel states,
    - :math:`y_3 = V_m - E_L`,
    - :math:`r` -- refractory countdown in grid steps.

    Per-step update order:

    1. Update membrane potential if not refractory.
    2. Update synaptic alpha states.
    3. Add arriving spike input to :math:`dI_{ex}` / :math:`dI_{in}`.
    4. Perform threshold test, reset, refractory assignment, spike emission.
    5. Store buffered external current for the next step.

    **3. Near-Singular Regime :math:`\tau_m \approx \tau_{\text{syn}}`**

    Direct formulas for :math:`P_{31}` and :math:`P_{32}` contain divisions by
    :math:`(\tau_m-\tau_{\text{syn}})`, which are ill-conditioned near
    equality. The helper :meth:`_alpha_propagator_p31_p32` follows NEST's
    ``IAFPropagatorAlpha`` behavior and switches to stable limits:

    .. math::

       P_{32}^{\mathrm{sing}} = \frac{h}{C_m} e^{-h/\tau_m}, \qquad
       P_{31}^{\mathrm{sing}} = \frac{h^2}{2C_m} e^{-h/\tau_m},

    preventing cancellation/underflow artifacts around
    :math:`\tau_m=\tau_{\text{syn}}`.

    **4. Assumptions, Constraints, and Computational Implications**

    - ``C_m > 0``, ``tau_m > 0``, ``tau_syn_ex > 0``, ``tau_syn_in > 0``,
      ``t_ref >= 0``, and ``V_reset < V_th`` are enforced at construction.
    - ``update(x=...)`` uses one-step delayed current buffering (NEST
      ring-buffer semantics): current provided at step ``n`` contributes at
      step ``n+1`` through ``y0``.
    - The update path is vectorized over ``self.varshape`` and performs
      :math:`O(\prod \mathrm{varshape})` floating-point work per call.
    - Internal coefficient math is in ``float64`` via NumPy conversion, while
      exposed states remain in BrainUnit quantities.

    Parameters
    ----------
    in_size : Size
        Population shape specification. All per-neuron parameters are
        broadcast to ``self.varshape`` derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable to
        ``self.varshape`` and strictly positive. Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable and strictly
        positive. Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_{ref}` in ms; broadcastable and
        nonnegative. Converted to integer step counts by ``ceil(t_ref / dt)``.
        Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_{th}` in mV; broadcastable to ``self.varshape``.
        Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Reset potential :math:`V_{reset}` in mV; broadcastable and must satisfy
        ``V_reset < V_th`` elementwise. Default is ``-70. * u.mV``.
    tau_syn_ex : ArrayLike, optional
        Excitatory alpha time constant :math:`\tau_{\text{syn,ex}}` in ms;
        broadcastable and strictly positive. Default is ``2. * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory alpha time constant :math:`\tau_{\text{syn,in}}` in ms;
        broadcastable and strictly positive. Default is ``2. * u.ms``.
    I_e : ArrayLike, optional
        Constant injected current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * u.pA``.
    V_min : ArrayLike or None, optional
        Optional lower voltage clamp :math:`V_{min}` in mV. When provided,
        membrane candidate update is clipped with ``max(V, V_min)`` before
        thresholding. ``None`` disables clipping. Default is ``None``.
    V_initializer : Callable, optional
        Initializer for membrane state ``V``. Called by :meth:`init_state`.
        Default is ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike nonlinearity used by :meth:`get_spike`. Default is
        ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy inherited from :class:`~brainpy_state._base.Neuron`.
        ``'hard'`` matches NEST hard reset semantics. Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, allocates boolean state ``self.refractory`` for external
        inspection of refractory condition. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 27 14 16 36

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines neuron population shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Resting membrane potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Membrane capacitance used in subthreshold integration.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Leak time constant in membrane propagator.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_{th}`, :math:`V_{reset}`
         - Threshold and post-spike reset levels.
       * - ``tau_syn_ex`` and ``tau_syn_in``
         - ArrayLike, broadcastable (ms), each ``> 0``
         - ``2. * u.ms``
         - :math:`\tau_{\text{syn,ex}}`, :math:`\tau_{\text{syn,in}}`
         - Alpha kernel time constants for excitatory/inhibitory currents.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant external current added every step.
       * - ``V_min``
         - ArrayLike broadcastable (mV) or ``None``
         - ``None``
         - :math:`V_{min}`
         - Optional lower bound on membrane candidate update.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * u.mV)``
         - --
         - Initializer used for membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate spike function returned by :meth:`update`.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode inherited from :class:`~brainpy_state._base.Neuron`.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - Allocate boolean state ``self.refractory`` when enabled.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node name.

    Raises
    ------
    ValueError
        If parameter constraints are violated: ``C_m <= 0``, ``tau_m <= 0``,
        ``tau_syn_ex <= 0``, ``tau_syn_in <= 0``, ``t_ref < 0``, or
        ``V_reset >= V_th``.
    TypeError
        If provided quantities are not unit-compatible with expected units
        (mV, ms, pF, pA) during conversion/broadcasting.
    KeyError
        At runtime, if required simulation context entries (for example ``t``
        or ``dt``) are missing when :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` creates required
        state holders.

    Notes
    -----

    - State variables are ``V``, ``I_syn_ex``, ``I_syn_in``, ``dI_syn_ex``,
      ``dI_syn_in``, ``y0``, ``refractory_step_count``, and ``last_spike_time``.
      ``refractory`` is added only when ``ref_var=True``.
    - Spike weights from ``sum_delta_inputs`` are interpreted in pA:
      positive values are excitatory and negative values are inhibitory.
    - ``update(x=...)`` stores ``x`` into ``y0`` for the next step, matching
      NEST current-event buffering semantics.
    - The propagator coefficients and the refractory step count
      ``ref_count = ceil(t_ref / dt)`` are (re)computed by :meth:`init_state`
      from the ``dt`` active at that call, so both always refer to the same
      grid. Changing parameters or ``dt`` afterwards requires calling
      :meth:`init_state` again.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_alpha(in_size=(2,), I_e=220.0 * u.pA)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=1.0 * u.ms):
       ...         spk = neu.update()
       ...     _ = spk.shape

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_alpha(in_size=1, tau_syn_ex=1.5 * u.ms)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = neu.update(x=200.0 * u.pA)
       ...     with brainstate.environ.context(t=0.1 * u.ms):
       ...         spk_next = neu.update()
       ...     _ = spk_next

    References
    ----------
    .. [1] NEST source: ``models/iaf_psc_alpha.h`` and
           ``models/iaf_psc_alpha.cpp``.
    .. [2] Rotter S, Diesmann M (1999). Exact simulation of time-invariant linear
           systems with applications to neuronal modeling. Biological Cybernetics
           81:381-402. DOI: https://doi.org/10.1007/s004220050570
    .. [3] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State space
           analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [4] Morrison A, Straube S, Plesser HE, Diesmann M (2007). Exact
           subthreshold integration with continuous spike times in discrete time
           neural network simulations. Neural Computation 19(1):47-79.
           DOI: https://doi.org/10.1162/neco.2007.19.1.47
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        tau_syn_ex: ArrayLike = 2. * u.ms,
        tau_syn_in: ArrayLike = 2. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_min: ArrayLike = None,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_min = None if V_min is None else braintools.init.param(V_min, self.varshape)
        self.V_initializer = V_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Precompute refractory step count (matches aeif_cond_alpha pattern).
        # ``init_state`` recomputes it so it always agrees with the ``dt`` the
        # propagator coefficients are built for.
        self.ref_count = self._refractory_counts(brainstate.environ.get_dt())

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.V_reset, self.C_m)):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be > 0.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be > 0.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All synaptic time constants must be > 0.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError("The refractory time t_ref can't be negative.")
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')

    def _refractory_counts(self, dt):
        r"""Convert ``t_ref`` into an integer number of simulation steps.

        Parameters
        ----------
        dt : ArrayLike
            Simulation step (time quantity) used for the conversion.

        Returns
        -------
        out : ArrayLike
            ``ceil(t_ref / dt)`` cast to the environment's integer dtype,
            shaped like ``self.t_ref``.
        """
        ditype = brainstate.environ.ditype()
        return u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def init_state(self, **kwargs):
        r"""Initialize runtime states for membrane, synapses, and refractoriness.

        Also pre-computes the propagator coefficients and the refractory step
        count for the ``dt`` active at call time (see
        :meth:`_precompute_propagators`).

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If initializers cannot broadcast to ``self.varshape``.
        TypeError
            If initializer outputs are incompatible with expected unit/array
            conversions for voltage, current, or integer refractory states.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        zeros = u.math.zeros(self.varshape, dtype=V.dtype)
        self.V = brainstate.HiddenState(V)
        self.I_syn_ex = brainstate.ShortTermState(zeros * u.pA)
        self.I_syn_in = brainstate.ShortTermState(zeros * u.pA)
        self.dI_syn_ex = brainstate.ShortTermState(zeros * (u.pA / u.ms))
        self.dI_syn_in = brainstate.ShortTermState(zeros * (u.pA / u.ms))
        self.y0 = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
        # Far-past sentinel meaning "has not spiked yet".
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)
        # Pre-compute propagator coefficients (constant for a given dt).
        self._precompute_propagators()

    def _precompute_propagators(self):
        """Pre-compute NEST propagator coefficients from dt and model parameters.

        Called once during ``init_state`` so that ``update`` never needs to
        call ``float(dt)`` or recompute exponentials each step. The refractory
        step count ``self.ref_count`` is refreshed here too, so it is always
        expressed in steps of the same ``dt`` the propagators were built for.
        """
        dt = brainstate.environ.get_dt()
        h = float(u.math.asarray(dt / u.ms))

        # Keep refractoriness on the same time grid as the propagators.
        self.ref_count = self._refractory_counts(dt)

        tau_ex = np.asarray(u.get_mantissa(self.tau_syn_ex / u.ms), dtype=np.float64)
        tau_in = np.asarray(u.get_mantissa(self.tau_syn_in / u.ms), dtype=np.float64)
        tau_m = np.asarray(u.get_mantissa(self.tau_m / u.ms), dtype=np.float64)
        c_m = np.asarray(u.get_mantissa(self.C_m / u.pF), dtype=np.float64)

        # Synaptic alpha-kernel propagators.
        self._P11_ex = np.exp(-h / tau_ex)
        self._P22_ex = self._P11_ex
        self._P21_ex = h * self._P11_ex
        self._P11_in = np.exp(-h / tau_in)
        self._P22_in = self._P11_in
        self._P21_in = h * self._P11_in

        # Membrane propagators; expm1 keeps precision for small h / tau_m.
        self._expm1_tau_m = np.expm1(-h / tau_m)
        self._P30 = -tau_m / c_m * self._expm1_tau_m
        self._P31_ex, self._P32_ex = self._alpha_propagator_p31_p32(tau_ex, tau_m, c_m, h)
        self._P31_in, self._P32_in = self._alpha_propagator_p31_p32(tau_in, tau_m, c_m, h)

        # Jump in dI per unit spike weight, so the resulting current peaks at w.
        self._epsc_init = np.e / self.tau_syn_ex  # 1/ms (unit-aware)
        self._ipsc_init = np.e / self.tau_syn_in  # 1/ms (unit-aware)

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate surrogate spike output for a voltage tensor.

        Parameters
        ----------
        V : ArrayLike or None, optional
            Voltage input in mV, broadcast-compatible with ``self.varshape``.
            If ``None``, uses current membrane state ``self.V.value``.

        Returns
        -------
        out : ArrayLike
            Output of ``self.spk_fun`` applied to the normalized voltage
            ``(V - V_th) / (V_th - V_reset)``; same shape as ``V`` (or
            ``self.V.value`` when ``V is None``).
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    @staticmethod
    def _alpha_propagator_p31_p32(tau_syn: np.ndarray, tau_m: np.ndarray, c_m: np.ndarray, h_ms: float):
        r"""Compute alpha-kernel membrane propagator terms ``P31`` and ``P32``.

        Parameters
        ----------
        tau_syn : numpy.ndarray
            Synaptic time constants in ms. Shape must be broadcast-compatible
            with state tensors.
        tau_m : numpy.ndarray
            Membrane time constants in ms, broadcast-compatible with
            ``tau_syn`` and positive.
        c_m : numpy.ndarray
            Membrane capacitances in pF, broadcast-compatible with ``tau_syn``
            and positive.
        h_ms : float
            Integration step in ms.

        Returns
        -------
        p31, p32 : tuple of numpy.ndarray
            ``float64`` arrays broadcast to the input shapes. Singular
            fallback limits are applied where the regular formulas become
            numerically unreliable near ``tau_m ~= tau_syn``.

        Notes
        -----
        This helper reproduces NEST ``IAFPropagatorAlpha`` masking logic with
        NumPy finite/normal checks to avoid catastrophic cancellation.
        """
        # Mirrors NEST IAFPropagatorAlpha and singular fallback behavior.
        # Divisions by (tau_m - tau_syn) may produce inf/nan; those entries are
        # replaced by the singular limits below, so the warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
            beta = tau_syn * tau_m / (tau_m - tau_syn)
            gamma = beta / c_m
            inv_beta = (tau_m - tau_syn) / (tau_syn * tau_m)
            exp_h_tau_syn = np.exp(-h_ms / tau_syn)
            expm1_h_tau = np.expm1(h_ms * inv_beta)

            p32_raw = gamma * exp_h_tau_syn * expm1_h_tau
            exp_h_tau_m = np.exp(-h_ms / tau_m)
            p32_singular = h_ms / c_m * exp_h_tau_m
            # NEST checks "isnormal && > 0". Approximate isnormal in NumPy.
            normal_min = np.finfo(np.float64).tiny
            p32_regular_mask = np.isfinite(p32_raw) & (np.abs(p32_raw) >= normal_min) & (p32_raw > 0.0)
            p32 = np.where(p32_regular_mask, p32_raw, p32_singular)

            # Regular P31 is only trusted when h is large relative to the
            # cancellation scale; tau_m == tau_syn gives inf -> singular limit.
            h_min_regular = 1e-7 * tau_m * tau_m / np.abs(tau_m - tau_syn)
            p31_regular_mask = np.isfinite(h_min_regular) & (h_ms > h_min_regular)
            p31_regular = gamma * exp_h_tau_syn * (beta * expm1_h_tau - h_ms)
            p31_singular = 0.5 * h_ms * h_ms / c_m * exp_h_tau_m
            p31 = np.where(p31_regular_mask, p31_regular, p31_singular)
        return p31, p32

    def update(self, x=0. * u.pA):
        r"""Advance the neuron by one simulation step.

        Parameters
        ----------
        x : ArrayLike, optional
            Continuous current input in pA for this step. ``x`` is accumulated
            through :meth:`sum_current_inputs` and stored in ``y0`` for use on
            the next call (one-step delayed buffering).

        Returns
        -------
        out : jax.Array
            Spike output tensor from :meth:`get_spike`, shape
            ``self.V.value.shape``. On threshold crossings, ``v_out`` is nudged
            above threshold by ``1e-12`` mV-equivalent to preserve positive
            surrogate activation.

        Raises
        ------
        KeyError
            If simulation context does not provide ``t`` or ``dt``.
        AttributeError
            If required states are missing because :meth:`init_state` was not
            called.
        TypeError
            If ``x`` or stored states are not unit-compatible with expected
            pA / mV conversions.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        I_syn_ex = self.I_syn_ex.value  # pA
        I_syn_in = self.I_syn_in.value  # pA
        dI_syn_ex = self.dI_syn_ex.value  # pA/ms
        dI_syn_in = self.dI_syn_in.value  # pA/ms
        y0 = self.y0.value  # pA
        r = self.refractory_step_count.value  # int

        # Spike input for this step, split by sign into ex/in channels, and
        # continuous current buffered for the next step.
        w_all = self.sum_delta_inputs(0. * u.pA)
        w_ex = u.math.where(w_all > 0.0 * u.pA, w_all, 0.0 * u.pA)
        w_in = u.math.where(w_all < 0.0 * u.pA, w_all, 0.0 * u.pA)
        y0_next = self.sum_current_inputs(x, self.V.value)  # pA

        # Relative voltages for propagator math.
        y3 = V - self.E_L  # mV
        theta_rel = self.V_th - self.E_L  # mV
        v_reset_rel = self.V_reset - self.E_L  # mV

        # Strip units once. The propagators are plain float64 arrays expressed
        # in the (ms, pF, pA, mV) system:
        #   P30 [ms/pF] * pA = mV, P31 [ms^2/pF] * pA/ms = mV,
        #   P32 [ms/pF] * pA = mV, expm1_tau_m is dimensionless.
        dI_ex = u.get_mantissa(dI_syn_ex / (u.pA / u.ms))
        I_ex = u.get_mantissa(I_syn_ex / u.pA)
        dI_in = u.get_mantissa(dI_syn_in / (u.pA / u.ms))
        I_in = u.get_mantissa(I_syn_in / u.pA)
        y3_m = u.get_mantissa(y3 / u.mV)

        # 1) Membrane update (only for non-refractory neurons).
        not_refractory = r == 0
        y3_candidate = (
            self._P30 * (u.get_mantissa(y0 / u.pA) + u.get_mantissa(self.I_e / u.pA))
            + self._P31_ex * dI_ex
            + self._P32_ex * I_ex
            + self._P31_in * dI_in
            + self._P32_in * I_in
            + self._expm1_tau_m * y3_m
            + y3_m
        ) * u.mV
        if self.V_min is not None:
            lower_rel = self.V_min - self.E_L
            y3_candidate = u.math.maximum(y3_candidate, lower_rel)
        y3 = u.math.where(not_refractory, y3_candidate, y3)
        # Refractory neurons keep their voltage and count down one step.
        r = jnp.where(not_refractory, r, r - 1)

        # 2) Synaptic alpha updates; I uses the pre-update dI (NEST order),
        #    then spikes arriving this step are added to dI.
        I_syn_ex = (self._P21_ex * dI_ex + self._P22_ex * I_ex) * u.pA
        dI_syn_ex = (dI_ex * self._P11_ex) * (u.pA / u.ms)
        dI_syn_ex = dI_syn_ex + self._epsc_init * w_ex
        I_syn_in = (self._P21_in * dI_in + self._P22_in * I_in) * u.pA
        dI_syn_in = (dI_in * self._P11_in) * (u.pA / u.ms)
        dI_syn_in = dI_syn_in + self._ipsc_init * w_in

        # 3) Threshold + reset.
        spike_cond = y3 >= theta_rel
        r = jnp.where(spike_cond, u.get_mantissa(self.ref_count), r)
        y3_for_spike = y3
        y3 = u.math.where(spike_cond, v_reset_rel, y3)
        # NEST stamps the spike at the end of the step (t + dt).
        last_spike_time = u.math.where(spike_cond, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)

        # Write back state.
        self.V.value = y3 + self.E_L
        self.I_syn_ex.value = I_syn_ex
        self.I_syn_in.value = I_syn_in
        self.dI_syn_ex.value = dI_syn_ex
        self.dI_syn_in.value = dI_syn_in
        # Broadcast the (possibly scalar) buffered current to the population shape.
        self.y0.value = y0_next + u.math.zeros(self.varshape) * u.pA
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # Surrogate input: the pre-reset voltage, pinned just above threshold
        # for neurons that spiked so the surrogate activation is positive.
        v_out = u.math.where(spike_cond, theta_rel + self.E_L + 1e-12 * u.mV, y3_for_spike + self.E_L)
        return self.get_spike(v_out)