Source code for brainpy_state._nest.iaf_psc_exp

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer, propagator_exp

__all__ = [
    'iaf_psc_exp',
]


class iaf_psc_exp(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_exp`` neuron model.

    Description
    -----------

    ``iaf_psc_exp`` is a current-based leaky integrate-and-fire neuron with
    hard reset, fixed absolute refractory period, and exponential excitatory
    and inhibitory postsynaptic currents. The implementation follows NEST
    ``models/iaf_psc_exp.{h,cpp}`` update order, including one-step buffered
    current input and receptor-1 filtered current handling.

    **1. Continuous-Time Dynamics**


    The subthreshold membrane equation is

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m}
       + \frac{I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}} + I_e + I_0}{C_m}

    where :math:`I_0` is the buffered current from the previous simulation
    step. Synaptic currents decay exponentially:

    .. math::

       \frac{dI_{\mathrm{syn,ex}}}{dt} = -\frac{I_{\mathrm{syn,ex}}}{\tau_{\mathrm{syn,ex}}}
       \qquad
       \frac{dI_{\mathrm{syn,in}}}{dt} = -\frac{I_{\mathrm{syn,in}}}{\tau_{\mathrm{syn,in}}}.

    NEST also defines a second current receptor :math:`I_1` that is filtered
    through the excitatory kernel; this is exposed via
    ``update(x_filtered=...)``.

    **2. Exact Step Propagator and NEST Update Ordering**


    For time step :math:`h = dt` (in ms), exact exponentials are used for
    all linear sub-systems:

    .. math::

       P_{11,\mathrm{ex}} = e^{-h/\tau_{\mathrm{syn,ex}}}, \quad
       P_{11,\mathrm{in}} = e^{-h/\tau_{\mathrm{syn,in}}}, \quad
       P_{22} = e^{-h/\tau_m},

    .. math::

       P_{20} = \frac{\tau_m}{C_m}(1 - P_{22}),

    .. math::

       P_{21}(\tau_{\mathrm{syn}}) =
       \frac{\tau_{\mathrm{syn}}\tau_m}{C_m(\tau_m - \tau_{\mathrm{syn}})}
       \left(e^{-h/\tau_m} - e^{-h/\tau_{\mathrm{syn}}}\right),

    where :math:`P_{21}` is evaluated numerically stably by
    :func:`~brainpy_state._nest._utils.propagator_exp`. Let :math:`V_\mathrm{rel} = V_m - E_L`.
    The candidate membrane update is

    .. math::

       V_{\mathrm{rel},n+1} =
       P_{22} V_{\mathrm{rel},n}
       + P_{21,\mathrm{ex}} I_{\mathrm{syn,ex},n}
       + P_{21,\mathrm{in}} I_{\mathrm{syn,in},n}
       + P_{20}(I_e + I_{0,n}).

    Per-step update order is:

    1. Update membrane potential if not refractory.
    2. Decay synaptic currents.
    3. Add filtered-current contribution to excitatory synaptic current.
    4. Add arriving spikes (positive -> excitatory, negative -> inhibitory).
    5. Threshold test, reset and refractory assignment.
    6. Store buffered currents for next step.

    **3. Escape-Noise Threshold Dynamics**


    Deterministic thresholding is used when :math:`\delta < 10^{-10}`:
    :math:`V_{\mathrm{rel}} \ge \theta`, where
    :math:`\theta = V_{th} - E_L`.

    For :math:`\delta \ge 10^{-10}`, the model uses an exponential hazard:

    .. math::

       \phi(V) = \rho \exp\!\left(\frac{V_{\mathrm{rel}} - \theta}{\delta}\right),

    and spikes with step probability :math:`p = \phi(V)\,h\times10^{-3}`
    because :math:`\phi` is in ``1/s`` while ``h`` is in ms. Stochastic
    decisions draw uniform numbers with ``numpy.random.random`` (the global
    NumPy generator). No random numbers are drawn when every neuron is in
    deterministic mode.

    **4. Stability Constraints and Computational Implications**


    - Construction enforces ``V_reset < V_th``, ``C_m > 0``, ``tau_m > 0``,
      ``tau_syn_ex > 0``, ``tau_syn_in > 0``, ``t_ref >= 0``, ``rho >= 0``,
      and ``delta >= 0``.
    - :func:`~brainpy_state._nest._utils.propagator_exp` uses a singular fallback
      :math:`(h/C_m)\exp(-h/\tau_m)` when ``tau_syn`` is numerically close
      to ``tau_m``, avoiding cancellation in
      :math:`(e^{-h/\tau_m} - e^{-h/\tau_{\mathrm{syn}}})/(\tau_m - \tau_{\mathrm{syn}})`.
    - Per-call cost is :math:`O(\prod \mathrm{varshape})`. Propagator
      coefficients are computed once per :meth:`init_state` with NumPy in the
      environment float dtype (``brainstate.environ.dftype()``).
    - Buffered current semantics match NEST ring-buffer timing:
      ``x``/``x_filtered`` supplied at step ``n`` are stored and consumed at
      step ``n+1``.

    Parameters
    ----------
    in_size : Size
        Population shape specification. All per-neuron parameters are
        broadcast to ``self.varshape`` derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable and strictly
        positive. Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable and
        strictly positive. Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_{ref}` in ms; broadcastable and
        nonnegative. Converted to integer steps by ``ceil(t_ref / dt)``.
        Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_{th}` in mV; broadcastable to
        ``self.varshape``. Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Post-spike reset potential :math:`V_{reset}` in mV; broadcastable and
        must satisfy ``V_reset < V_th`` elementwise. Default is
        ``-70. * u.mV``.
    tau_syn_ex : ArrayLike, optional
        Excitatory synaptic decay constant :math:`\tau_{\mathrm{syn,ex}}` in
        ms; broadcastable and strictly positive. Default is ``2. * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory synaptic decay constant :math:`\tau_{\mathrm{syn,in}}` in
        ms; broadcastable and strictly positive. Default is ``2. * u.ms``.
    I_e : ArrayLike, optional
        Constant external injected current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * u.pA``.
    rho : ArrayLike, optional
        Escape-noise base firing intensity :math:`\rho` in ``1/s``;
        broadcastable and nonnegative. Used only in stochastic mode
        (``delta >= 1e-10``). Default is ``0.01 / u.second``.
    delta : ArrayLike, optional
        Escape-noise soft-threshold width :math:`\delta` in mV; broadcastable
        and nonnegative. ``delta == 0`` reproduces deterministic thresholding.
        Default is ``0. * u.mV``.
    V_initializer : Callable, optional
        Initializer for membrane state ``V`` used by :meth:`init_state`.
        Default is ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike nonlinearity used by :meth:`get_spike`. Default is
        ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy inherited from :class:`~brainpy_state._base.Neuron`.
        ``'hard'`` matches NEST reset behavior. Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, allocates ``self.refractory`` (boolean array) for
        external inspection of the refractory state. Default is ``False``.
    name : str or None, optional
        Optional node name passed to the parent module. Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 16 28 14 16 36

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines neuron population shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Resting membrane potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Membrane capacitance in voltage integration.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Membrane leak time constant.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_{th}`, :math:`V_{reset}`
         - Threshold and post-spike reset voltages.
       * - ``tau_syn_ex`` and ``tau_syn_in``
         - ArrayLike, broadcastable (ms), each ``> 0``
         - ``2. * u.ms``
         - :math:`\tau_{\mathrm{syn,ex}}`, :math:`\tau_{\mathrm{syn,in}}`
         - Exponential PSC decay constants.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant current injected every step.
       * - ``rho`` and ``delta``
         - ArrayLike, broadcastable; ``rho`` in ``1/s``, ``delta`` in mV,
           both ``>= 0``
         - ``0.01 / u.second``, ``0. * u.mV``
         - :math:`\rho`, :math:`\delta`
         - Escape-noise hazard parameters.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * u.mV)``
         - --
         - Initializer for membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate function used for output spikes.
       * - ``spk_reset``
         - ``str`` (typically ``'hard'``)
         - ``'hard'``
         - --
         - Reset behavior selection in base class.
       * - ``ref_var``
         - ``bool``
         - ``False``
         - --
         - Enables explicit boolean refractory state variable.
       * - ``name``
         - ``str`` or ``None``
         - ``None``
         - --
         - Optional instance name.

    Raises
    ------
    ValueError
        Raised at construction when any validated constraint is violated:
        ``V_reset >= V_th``, nonpositive ``C_m``/``tau_m``/synaptic time
        constants, negative ``t_ref``, negative ``rho``, or negative
        ``delta``.

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential in mV; shape ``self.varshape``.
    i_syn_ex : brainstate.ShortTermState
        Excitatory synaptic current in pA.
    i_syn_in : brainstate.ShortTermState
        Inhibitory synaptic current in pA.
    i_0 : brainstate.ShortTermState
        Buffered receptor-0 current (pA) applied on the next simulation step.
    i_1 : brainstate.ShortTermState
        Buffered receptor-1 current (pA) filtered through the excitatory
        exponential kernel on the next simulation step.
    refractory_step_count : brainstate.ShortTermState
        Integer countdown of remaining refractory steps (environment integer
        dtype ``brainstate.environ.ditype()``).
    last_spike_time : brainstate.ShortTermState
        Simulation time of the most recent spike (ms).
    refractory : brainstate.ShortTermState
        Boolean refractory mask; only present when ``ref_var=True``.

    Notes
    -----
    - This implementation uses exact (analytical) integration of the linear
      subthreshold ODE via pre-computed propagator coefficients, matching
      NEST's update precision for fixed-step simulation.
    - Continuous current input ``x`` is combined with ``I_e`` and any
      additional current sources registered via :meth:`sum_current_inputs`;
      the combined value is buffered one step (NEST ring-buffer semantics).
    - Delta spike inputs from :meth:`sum_delta_inputs` are split by sign:
      positive weights increment ``i_syn_ex``; negative weights increment
      ``i_syn_in``.
    - The value returned by :meth:`update` has the exact spike indicator of
      the step as its forward value: 1 where the neuron fired, 0 elsewhere.
      This holds even in escape-noise mode, where a neuron can be above
      threshold without firing. Its gradient is the surrogate gradient of
      :meth:`get_spike`.
    - The stochastic escape-noise mode (``delta >= 1e-10``) draws random
      numbers host-side with ``numpy.random.random``. Under JAX
      transformations such as ``jit``, these numbers would be drawn once at
      trace time and baked in as constants, so stochastic runs should not be
      compiled. Use ``delta=0`` for fully differentiable, JIT-compatible
      runs; in that case no random numbers are drawn.

    Examples
    --------
    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> from brainpy_state._nest.iaf_psc_exp import iaf_psc_exp
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = iaf_psc_exp(in_size=(3,), I_e=250. * u.pA, delta=0. * u.mV)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         out = neu.update(x=0. * u.pA, x_filtered=0. * u.pA)
       ...     _ = out.shape

    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> from brainpy_state._nest.iaf_psc_exp import iaf_psc_exp
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = iaf_psc_exp(
       ...         in_size=10,
       ...         tau_syn_ex=2.0 * u.ms,
       ...         tau_syn_in=5.0 * u.ms,
       ...         ref_var=True,
       ...     )
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         spk = neu.update(x=300.0 * u.pA)
       ...     _ = spk.shape

    References
    ----------
    .. [1] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling. Biological
           Cybernetics 81:381-402. DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, & Aertsen A (2001). State
           space analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [3] Brette R, Rudolph M, Carnevale T, et al. (2007). Simulation of
           networks of spiking neurons: a review of tools and strategies.
           Journal of Computational Neuroscience 23:349-398.
           DOI: https://doi.org/10.1007/s10827-007-0038-6

    See Also
    --------
    iaf_psc_delta : LIF neuron with delta-function PSCs (voltage-jump synapses)
    iaf_cond_exp : LIF neuron with exponential conductance synapses
    LIF : Leaky integrate-and-fire (brainpy parameterization)
    LIFRef : Leaky integrate-and-fire with explicit refractory tracking
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        tau_syn_ex: ArrayLike = 2. * u.ms,
        tau_syn_in: ArrayLike = 2. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        rho: ArrayLike = 0.01 / u.second,
        delta: ArrayLike = 0. * u.mV,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Broadcast every per-neuron parameter to the population shape.
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.rho = braintools.init.param(rho, self.varshape)
        self.delta = braintools.init.param(delta, self.varshape)

        self.V_initializer = V_initializer
        self.ref_var = ref_var
        self._validate_parameters()

        # Precompute refractory step count: ceil(t_ref / dt).
        # NOTE(review): ceil of a floating-point ratio can overshoot by one
        # step when t_ref is a multiple of dt that is not exactly representable
        # (e.g. 1.1 / 0.1 == 11.000000000000002 in float64). NEST presumably
        # avoids this by working on its integer tic grid; confirm before
        # changing.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Validation is skipped entirely when any checked parameter is a JAX
        tracer. Tracer values cannot be converted to concrete booleans.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        checked = (
            self.V_reset, self.V_th, self.C_m, self.tau_m, self.tau_syn_ex,
            self.tau_syn_in, self.t_ref, self.rho, self.delta,
        )
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in checked):
            return
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('Synaptic time constants must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')
        if np.any(self.rho < 0.0 / u.second):
            raise ValueError('Stochastic firing intensity rho must not be negative.')
        if np.any(self.delta < 0.0 * u.mV):
            raise ValueError('Threshold width delta must not be negative.')

    def init_state(self, **kwargs):
        r"""Initialize membrane potential and all synaptic/refractory states.

        Also pre-computes the propagator coefficients for the current ``dt``
        (see :meth:`_precompute_propagators`).

        Parameters
        ----------
        **kwargs : Any
            Unused compatibility arguments.

        Raises
        ------
        ValueError
            If ``V_initializer`` output cannot be broadcast to the target
            state shape.
        TypeError
            If initializer values are incompatible with required
            numeric/unit conversions.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        zeros_pA = u.math.zeros(self.varshape, dtype=dftype) * u.pA

        self.V = brainstate.HiddenState(V)
        self.i_syn_ex = brainstate.ShortTermState(zeros_pA)
        self.i_syn_in = brainstate.ShortTermState(zeros_pA)
        self.i_0 = brainstate.ShortTermState(zeros_pA)
        self.i_1 = brainstate.ShortTermState(zeros_pA)
        self.refractory_step_count = brainstate.ShortTermState(
            u.math.full(self.varshape, 0, dtype=ditype)
        )
        # A very old spike time marks "never spiked".
        self.last_spike_time = brainstate.ShortTermState(
            u.math.full(self.varshape, -1e7 * u.ms)
        )
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

        # Pre-compute propagator coefficients (constant for a given dt).
        self._precompute_propagators()

    def _precompute_propagators(self):
        """Pre-compute NEST propagator coefficients from dt and model parameters.

        All coefficients are computed with NumPy in the environment float
        dtype, in units of ms and pF. The ``P21``/``P20`` coefficients
        therefore have units ms/pF. This method also caches the
        escape-noise parameters, and records in ``_all_deterministic``
        whether every neuron uses deterministic thresholding.
        """
        dt_q = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        h = float(u.math.asarray(dt_q / u.ms))

        tau_ex_np = np.asarray(u.math.asarray(self.tau_syn_ex / u.ms), dtype=dftype)
        tau_in_np = np.asarray(u.math.asarray(self.tau_syn_in / u.ms), dtype=dftype)
        tau_m_np = np.asarray(u.math.asarray(self.tau_m / u.ms), dtype=dftype)
        C_m_np = np.asarray(u.math.asarray(self.C_m / u.pF), dtype=dftype)

        self._P11_ex = jnp.asarray(np.exp(-h / tau_ex_np))
        self._P11_in = jnp.asarray(np.exp(-h / tau_in_np))
        self._P22 = jnp.asarray(np.exp(-h / tau_m_np))
        self._P21_ex = jnp.asarray(propagator_exp(tau_ex_np, tau_m_np, C_m_np, h))
        self._P21_in = jnp.asarray(propagator_exp(tau_in_np, tau_m_np, C_m_np, h))
        self._P20 = jnp.asarray(tau_m_np / C_m_np * (1.0 - np.exp(-h / tau_m_np)))
        self._h = h

        # Pre-compute stochastic threshold parameters (delta in mV, rho in 1/s).
        self._delta_np = jnp.asarray(np.asarray(u.math.asarray(self.delta / u.mV), dtype=dftype))
        self._rho_np = jnp.asarray(np.asarray(u.math.asarray(self.rho * u.second), dtype=dftype))
        self._deterministic = self._delta_np < 1e-10
        # Placeholder width for deterministic neurons so the hazard never divides by 0.
        self._delta_safe = jnp.where(self._deterministic, 1.0, self._delta_np)
        # Static flag: when True, update() skips the host-side random draw.
        self._all_deterministic = bool(np.all(np.asarray(self._deterministic)))

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate surrogate spike activation for a voltage tensor.

        Scales the voltage relative to threshold and reset to compute a
        dimensionless argument passed to the surrogate nonlinearity
        ``self.spk_fun``:

        .. math::

           \text{out} = \mathrm{spk\_fun}\!\left(
               \frac{V - V_{th}}{V_{th} - V_{reset}}
           \right).

        Parameters
        ----------
        V : ArrayLike or None, optional
            Membrane voltage in mV, broadcast-compatible with
            ``self.varshape``. If ``None``, ``self.V.value`` is used.

        Returns
        -------
        out : ArrayLike
            Output of ``self.spk_fun`` applied to the scaled voltage. It has
            the broadcast shape of ``V`` (or of ``self.V.value`` when ``V`` is
            ``None``).

        Raises
        ------
        TypeError
            If ``V`` cannot participate in arithmetic with membrane parameters
            due to incompatible dtype or unit.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.pA, x_filtered=0. * u.pA):
        r"""Advance the neuron state by one simulation step.

        Parameters
        ----------
        x : ArrayLike, optional
            Current input in pA for receptor-0 (standard current port). Scalar
            or array broadcastable to ``self.varshape``. The value is buffered
            (stored in ``self.i_0``) and applied in the **next** step, matching
            NEST ring-buffer semantics. Default is ``0. * u.pA``.
        x_filtered : ArrayLike, optional
            Current input in pA for receptor-1. Buffered in ``self.i_1`` and
            injected through excitatory exponential filtering at the next
            update step via :math:`(1 - P_{11,\mathrm{ex}}) \times i_1`.
            Scalar or array broadcastable to ``self.varshape``. Default is
            ``0. * u.pA``.

        Returns
        -------
        out : jax.Array
            Spike output with the shape of :meth:`get_spike`'s result.
            The forward value is exactly the spike indicator of this step:
            1 where the neuron fired, 0 elsewhere, in the dtype returned by
            ``spk_fun``. This holds even for escape-noise neurons that are
            above threshold but did not fire. The gradient is that of
            :meth:`get_spike`, evaluated at
            :math:`\theta + E_L + 10^{-12}\,\text{mV}` for neurons that fired
            and at the pre-reset voltage otherwise. If ``spk_fun`` returns a
            non-floating array, the indicator is returned without a gradient
            path.

        Raises
        ------
        KeyError
            If the simulation environment context does not supply ``t`` or
            ``dt``.
        AttributeError
            If state variables are missing because :meth:`init_state` has not
            been called before ``update``.
        TypeError
            If input/state values are not unit-compatible with expected pA/mV
            arithmetic.
        ValueError
            If provided inputs cannot be broadcast to the internal state shape.
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        i_syn_ex = self.i_syn_ex.value  # pA
        i_syn_in = self.i_syn_in.value  # pA
        i_0 = self.i_0.value  # pA
        i_1 = self.i_1.value  # pA
        r = self.refractory_step_count.value  # int

        # Use pre-computed propagator coefficients.
        P11_ex = self._P11_ex
        P11_in = self._P11_in
        P22 = self._P22
        P21_ex = self._P21_ex
        P21_in = self._P21_in
        P20 = self._P20

        # Relative voltages and thresholds (unit-aware).
        V_rel = V - self.E_L  # mV
        theta = self.V_th - self.E_L  # mV
        V_reset_rel = self.V_reset - self.E_L  # mV

        # 1. Update membrane potential if not refractory; otherwise count down.
        not_refractory = r == 0
        # P21 coefficients have units ms/pF which, multiplied by pA, yield mV.
        # P22 is unitless, P20 has units ms/pF * pA = mV.
        V_candidate = (
            V_rel * P22
            + i_syn_ex * (P21_ex * (u.mV / u.pA))
            + i_syn_in * (P21_in * (u.mV / u.pA))
            + (self.I_e + i_0) * (P20 * (u.mV / u.pA))
        )
        V_rel = u.math.where(not_refractory, V_candidate, V_rel)
        r = u.math.where(not_refractory, r, r - 1)

        # 2. Decay synaptic currents.
        i_syn_ex = i_syn_ex * P11_ex
        i_syn_in = i_syn_in * P11_in

        # 3. Receptor type 1 current filtered through excitatory synapse.
        i_syn_ex = i_syn_ex + (1.0 - P11_ex) * i_1

        # 4. Add arriving spikes (positive -> excitatory, negative -> inhibitory).
        w_all = self.sum_delta_inputs(u.math.zeros_like(self.i_syn_ex.value))
        w_ex = u.math.where(w_all > 0.0 * u.pA, w_all, 0.0 * u.pA)
        w_in = u.math.where(w_all < 0.0 * u.pA, w_all, 0.0 * u.pA)
        i_syn_ex = i_syn_ex + w_ex
        i_syn_in = i_syn_in + w_in

        # Buffered current inputs for next step (one-step delay).
        new_i_0 = self.sum_current_inputs(x, self.V.value)
        new_i_1 = u.math.asarray(x_filtered) + u.math.zeros(self.varshape) * u.pA

        # 5. Threshold test, reset and refractory assignment.
        spike_cond = self._spike_condition(V_rel, theta)
        r = u.math.where(spike_cond, self.ref_count, r)
        V_before_reset = V_rel
        V_rel = u.math.where(spike_cond, V_reset_rel, V_rel)

        # 6. Write back state.
        self.V.value = V_rel + self.E_L
        self.i_syn_ex.value = i_syn_ex
        self.i_syn_in.value = i_syn_in
        self.i_0.value = new_i_0 + u.math.zeros(self.varshape) * u.pA
        self.i_1.value = new_i_1
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        last_spike_time = u.math.where(spike_cond, t + dt_q, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # Surrogate path: firing neurons are evaluated just above threshold,
        # the others at their pre-reset voltage.
        V_out = u.math.where(
            spike_cond,
            theta + self.E_L + 1e-12 * u.mV,
            V_before_reset + self.E_L,
        )
        return self._exact_spike_output(self.get_spike(V_out), spike_cond)

    def _spike_condition(self, V_rel, theta):
        """Return the boolean spike mask for relative voltage ``V_rel`` (mV).

        Deterministic neurons (``delta < 1e-10``) fire when ``V_rel >= theta``.
        Stochastic neurons fire when a uniform draw from
        ``numpy.random.random`` falls below ``phi * h * 1e-3``. No random
        numbers are drawn when every neuron is deterministic.
        """
        det_spike = V_rel >= theta
        # Default False keeps the always-draw behavior if the flag is absent.
        if getattr(self, '_all_deterministic', False):
            return det_spike

        # Stochastic escape-noise: phi * h * 1e-3 (phi in 1/s, h in ms).
        V_rel_mv = u.math.asarray(V_rel / u.mV)
        theta_mv = u.math.asarray(theta / u.mV)
        phi = self._rho_np * jnp.exp((V_rel_mv - theta_mv) / self._delta_safe)
        draws = jnp.asarray(np.random.random(size=self.varshape))
        stoch_spike = draws < phi * self._h * 1e-3
        return jnp.where(self._deterministic, det_spike, stoch_spike)

    @staticmethod
    def _exact_spike_output(spk, spike_cond):
        """Make the forward value of ``spk`` equal ``spike_cond`` while keeping its gradient.

        This is the straight-through construction: the forward value comes
        from ``spike_cond`` and the gradient flows through ``spk``. It is
        needed because the surrogate argument alone cannot encode the spike
        decision. In escape-noise mode a neuron can be above threshold
        without firing. In low precision, the ``1e-12 mV`` nudge applied to
        firing neurons can be rounded away.
        """
        target = u.math.where(spike_cond, u.math.ones_like(spk), u.math.zeros_like(spk))
        if not jnp.issubdtype(jnp.result_type(u.get_mantissa(spk)), jnp.inexact):
            # Non-floating outputs carry no gradient; return the exact indicator.
            return target
        return spk + jax.lax.stop_gradient(target - spk)