Source code for brainpy_state._nest.iaf_psc_exp_ps

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable, Iterable

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer, propagator_exp

__all__ = [
    'iaf_psc_exp_ps',
]


class iaf_psc_exp_ps(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_exp_ps`` with precise spike times.

    Description
    -----------

    ``iaf_psc_exp_ps`` is a current-based leaky integrate-and-fire neuron with
    exponential excitatory/inhibitory PSC states and off-grid event/spike
    timing. The implementation follows NEST
    ``models/iaf_psc_exp_ps.{h,cpp}`` semantics: within-step event ordering by
    precise offsets, exact closed-form mini-step propagation, sub-step
    threshold localization by root search, and refractory release modeled as an
    explicit pseudo-event.

    **1. Continuous-time dynamics and exact integration**

    Let :math:`U = V_m - E_L`, :math:`I_{ex}` and :math:`I_{in}` be
    excitatory/inhibitory PSC states (pA), and :math:`y_0` the one-step
    buffered continuous input current (pA). Subthreshold dynamics are

    .. math::

       \frac{dU}{dt} = -\frac{U}{\tau_m}
       + \frac{I_e + y_0 + I_{ex} + I_{in}}{C_m},

    .. math::

       \frac{dI_{ex}}{dt} = -\frac{I_{ex}}{\tau_{syn,ex}}, \qquad
       \frac{dI_{in}}{dt} = -\frac{I_{in}}{\tau_{syn,in}}.

    Over a mini-interval :math:`\Delta t`, exact integration gives

    .. math::

       U(t+\Delta t) = P_{20}(\Delta t)\,(I_e+y_0)
       + P_{21,ex}(\Delta t)\,I_{ex}(t)
       + P_{21,in}(\Delta t)\,I_{in}(t)
       + U(t)e^{-\Delta t/\tau_m},

    where
    :math:`P_{20}=-\frac{\tau_m}{C_m}\left(e^{-\Delta t/\tau_m}-1\right)` and
    :math:`P_{21,X}` are evaluated by
    :func:`propagator_exp` (from ``_utils``). PSC states decay exactly via
    :math:`I_X(t+\Delta t)=I_X(t)e^{-\Delta t/\tau_{syn,X}}`.

    **2. Precise-time event processing**

    Event offsets use NEST convention: ``offset=dt`` at step start and
    ``offset=0`` at step end. For each global step:

    1. Build local event list from ``spike_events`` and on-grid delta input
       (always added at ``offset=0``).
    2. Sort events in descending offset and split the step into mini-intervals.
    3. Propagate exactly on each mini-interval.
    4. If :math:`U` reaches threshold, solve
       :math:`f(\delta)=U(\delta)-U_{th}=0` with bounded bisection
       (64 iterations) to obtain off-grid spike time.
    5. Reset to ``V_reset`` and enter refractory state; release from refractory
       occurs through a pseudo-event when
       ``step_idx + 1 - last_spike_step == ceil(t_ref / dt)``.

    **3. Assumptions, constraints, and computational complexity**

    - Parameters are scalar or broadcastable to ``self.varshape``.
    - Construction-time constraints enforce
      ``V_reset < V_th``, ``C_m > 0``, ``tau_m > 0``,
      ``tau_syn_ex > 0``, ``tau_syn_in > 0``, and when ``V_min`` is provided:
      ``V_reset >= V_min``.
    - Runtime requires ``ceil(t_ref / dt) >= 1``.
    - All precise offsets must satisfy ``0 <= offset <= dt``.
    - Continuous input ``x`` is buffered (stored into ``y0`` for the next
      global step), matching NEST current-event timing.
    - Per-step complexity is
      :math:`O(|\mathrm{state}| \cdot K)` for ``K`` local events, plus root
      search cost on threshold-crossing mini-intervals.

    Parameters
    ----------
    in_size : Size
        Population shape specification. Model parameters and states are
        broadcast to ``self.varshape`` derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV, broadcastable to ``self.varshape``.
        Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF, broadcastable to
        ``self.varshape``. Must be strictly positive elementwise.
        Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms, broadcastable to
        ``self.varshape``. Must be strictly positive elementwise.
        Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory duration :math:`t_{ref}` in ms, broadcastable to
        ``self.varshape``. Converted at runtime to steps using
        ``ceil(t_ref / dt)`` and must produce at least one step.
        Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Threshold voltage :math:`V_{th}` in mV, broadcastable to
        ``self.varshape``. Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Reset voltage :math:`V_{reset}` in mV, broadcastable to
        ``self.varshape``. Must satisfy ``V_reset < V_th`` elementwise.
        Default is ``-70. * u.mV``.
    tau_syn_ex : ArrayLike, optional
        Excitatory PSC decay constant :math:`\tau_{syn,ex}` in ms,
        broadcastable to ``self.varshape`` and strictly positive.
        Default is ``2. * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory PSC decay constant :math:`\tau_{syn,in}` in ms,
        broadcastable to ``self.varshape`` and strictly positive.
        Default is ``2. * u.ms``.
    I_e : ArrayLike, optional
        Constant external current :math:`I_e` in pA, broadcastable to
        ``self.varshape``. Added in each mini-step propagation.
        Default is ``0. * u.pA``.
    V_min : ArrayLike or None, optional
        Optional lower bound :math:`V_{min}` in mV, broadcastable to
        ``self.varshape``. If ``None``, no lower clip is applied.
        Default is ``None``.
    V_initializer : Callable, optional
        Initializer used by :meth:`init_state` for membrane state ``V``.
        Must return mV-compatible values with shape compatible with
        ``self.varshape`` (and optional batch prefix). Default is
        ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike` and
        :meth:`update`. Receives normalized threshold distance tensor.
        Default is ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy forwarded to :class:`~brainpy_state._base.Neuron`.
        ``'hard'`` matches NEST hard reset. Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, creates exposed ``self.refractory`` mirroring
        ``self.is_refractory`` for inspection. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 28 14 16 35

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines ``self.varshape`` for parameter/state broadcasting.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Resting potential and voltage-offset origin.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Converts current terms to membrane-rate contribution.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Leak time constant in exact subthreshold propagation.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), runtime ``ceil(t_ref/dt) >= 1``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_{th}`, :math:`V_{reset}`
         - Threshold and post-spike reset levels.
       * - ``tau_syn_ex`` and ``tau_syn_in``
         - ArrayLike, broadcastable (ms), each ``> 0``
         - ``2. * u.ms``
         - :math:`\tau_{syn,ex}`, :math:`\tau_{syn,in}`
         - Exponential PSC decay constants.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant injected current added every mini-step.
       * - ``V_min``
         - ArrayLike broadcastable (mV) or ``None``
         - ``None``
         - :math:`V_{min}`
         - Optional lower clamp applied after membrane propagation.
       * - ``V_initializer``
         - Callable returning mV-compatible values
         - ``Constant(-70. * u.mV)``
         - --
         - Initializes membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate spike output nonlinearity.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode inherited from base ``Neuron``.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - Allocate exposed ``refractory`` mirror state.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node name.

    Raises
    ------
    ValueError
        If validated constraints fail (for example ``V_reset >= V_th``,
        non-positive capacitance/time constants, ``V_reset < V_min``,
        ``ceil(t_ref / dt) < 1``, or event offsets outside ``[0, dt]``).
    TypeError
        If provided arguments are incompatible with expected units/callables
        (mV, pA, pF, ms).
    KeyError
        If simulation context values ``t`` and/or ``dt`` are missing when
        :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` creates required
        runtime states.

    Attributes
    ----------
    V : HiddenState
        Membrane potential state in mV.
    I_syn_ex : ShortTermState
        Excitatory PSC state in pA.
    I_syn_in : ShortTermState
        Inhibitory PSC state in pA.
    y0 : ShortTermState
        One-step buffered continuous current in pA.
    is_refractory : ShortTermState
        Boolean refractory mask.
    last_spike_step : ShortTermState
        Step index of latest emitted spike.
    last_spike_offset : ShortTermState
        Precise offset (ms) from right step boundary for latest spike.
    last_spike_time : ShortTermState
        Absolute precise spike time in ms.
    refractory : ShortTermState
        Optional mirror of ``is_refractory`` when ``ref_var=True``.

    Notes
    -----
    - ``spike_events`` accepts ``(offset, weight)`` tuples or
      ``{'offset': ..., 'weight': ...}`` dicts.
    - Offsets are in ms and measured from the right edge of the current step.
    - Positive event weights contribute to excitatory PSC state; negative
      weights contribute to inhibitory PSC state.
    - Internal propagation and root finding are evaluated in NumPy float64 and
      written back into BrainUnit states at end of step.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_exp_ps(in_size=2, I_e=200.0 * u.pA)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         spk = neu.update()
       ...     _ = spk.shape

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_exp_ps(in_size=1)
       ...     neu.init_state()
       ...     ev = [{'offset': 0.08 * u.ms, 'weight': 120.0 * u.pA}]
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = neu.update(spike_events=ev)

    References
    ----------
    .. [1] NEST source: ``models/iaf_psc_exp_ps.h`` and
           ``models/iaf_psc_exp_ps.cpp``.
    .. [2] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling.
           Biological Cybernetics 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [3] Morrison A, Straube S, Plesser HE, Diesmann M (2007). Exact
           subthreshold integration with continuous spike times in discrete
           time neural network simulations. Neural Computation 19(1):47-79.
           DOI: https://doi.org/10.1162/neco.2007.19.1.47
    .. [4] Hanuschkin A, Kunkel S, Helias M, Morrison A, Diesmann M (2010).
           A general and efficient method for incorporating exact spike times
           in globally time-driven simulations. Frontiers in Neuroinformatics.
           DOI: https://doi.org/10.3389/fninf.2010.00113
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        tau_syn_ex: ArrayLike = 2. * u.ms,
        tau_syn_in: ArrayLike = 2. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_min: ArrayLike = None,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_min = None if V_min is None else braintools.init.param(V_min, self.varshape)
        self.V_initializer = V_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Precompute refractory step count.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against constraints.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.V_reset, self.C_m, self.tau_m)):
            return

        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if self.V_min is not None and np.any(self.V_reset < self.V_min):
            raise ValueError('Reset potential must be greater equal minimum potential.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')

[docs] def init_state(self, **kwargs): r"""Initialize membrane, synaptic, and precise-timing runtime states. This method allocates all internal state variables required for precise spike-time simulation. Membrane potential ``V`` is initialized using ``self.V_initializer``, synaptic currents and buffered inputs are initialized to zero, and spike-tracking states are initialized to sentinel values (``last_spike_step = -1``, ``last_spike_time = -1e7 ms``) indicating no prior spike events. Parameters ---------- **kwargs : Any Unused compatibility arguments for subclass extension. Raises ------ ValueError If initializer outputs cannot be broadcast to state shape ``self.varshape`` or if shapes are incompatible. TypeError If initializer outputs are not unit-compatible with expected state units (mV for voltage, pA for currents, ms for time, bool for flags). AttributeError If ``self.V_initializer`` is not callable or does not produce valid output for the requested shape. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() V = braintools.init.param(self.V_initializer, self.varshape) zeros = u.math.zeros_like(u.math.asarray(V / u.mV)) self.V = brainstate.HiddenState(V) self.I_syn_ex = brainstate.ShortTermState(zeros * u.pA) self.I_syn_in = brainstate.ShortTermState(zeros * u.pA) self.y0 = brainstate.ShortTermState(zeros * u.pA) self.is_refractory = brainstate.ShortTermState(np.zeros(self.varshape, dtype=bool)) self.last_spike_step = brainstate.ShortTermState( u.math.full(self.varshape, -1, dtype=ditype) ) self.last_spike_offset = brainstate.ShortTermState(zeros * u.ms) self.last_spike_time = brainstate.ShortTermState( u.math.full(self.varshape, -1e7 * u.ms) ) if self.ref_var: self.refractory = brainstate.ShortTermState(np.zeros(self.varshape, dtype=bool))
[docs] def get_spike(self, V: ArrayLike = None): r"""Evaluate surrogate spike output from membrane potential. Applies the surrogate spike function (typically ``braintools.surrogate.ReluGrad`` or similar) to a normalized threshold-distance metric. This enables differentiable spike generation for gradient-based learning while maintaining biological spike semantics. The normalized threshold distance is computed as :math:`(V - V_{th}) / (V_{th} - V_{reset})`, which maps the voltage range between reset and threshold to ``[0, 1]``, with values above threshold producing positive outputs through the surrogate function. Parameters ---------- V : ArrayLike or None, optional Voltage tensor in mV, broadcast-compatible with ``self.varshape`` (or current batched state shape). If ``None``, uses ``self.V.value``. Default is ``None``. Returns ------- out : dict Output of ``self.spk_fun`` applied to normalized threshold distance ``(V - V_th) / (V_th - V_reset)`` with same shape as input ``V``. Typically float values in ``[0, 1]`` or similar range depending on the surrogate function's output characteristics. Raises ------ TypeError If ``V`` is not compatible with unit arithmetic in mV or if unit conversion operations fail. AttributeError If ``self.spk_fun`` is not callable or if required parameters (``V_th``, ``V_reset``) are not available. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _parse_spike_events(self, spike_events: Iterable, v_shape): r"""Parse spike events into normalized (offset_ms, weight_array) tuples. Converts mixed-format spike events (tuples or dicts) into a uniform internal representation suitable for event processing. Offsets are extracted in ms units, weights are extracted in pA units and broadcast to match the neuron population shape. Parameters ---------- spike_events : Iterable or None User-provided spike events as ``(offset, weight)`` tuples or ``{'offset': ..., 'weight': ...}`` dicts. v_shape : tuple Target shape for broadcasting weight arrays (typically ``self.V.value.shape``). Returns ------- list of tuple[float, np.ndarray] List of ``(offset_ms, weight_np)`` pairs where ``offset_ms`` is a float scalar in ms and ``weight_np`` is a float64 array broadcast to ``v_shape``. """ events = [] if spike_events is None: return events dftype = brainstate.environ.dftype() for ev in spike_events: if isinstance(ev, dict): offs = ev.get('offset', 0.0 * u.ms) w = ev.get('weight', 0.0 * u.pA) else: offs, w = ev off_ms = float(u.math.asarray(offs / u.ms)) w_np = np.asarray(u.math.asarray(w / u.pA), dtype=dftype) events.append((off_ms, np.broadcast_to(w_np, v_shape))) return events @staticmethod def _bisect_root(f, t_hi: float): r"""Find root of scalar function using bounded bisection. Locates the point where ``f(t)`` crosses zero within the interval ``[0, t_hi]`` using bisection with 64 iterations. Assumes ``f`` is continuous and monotonically increasing within the search interval. This method is used to find the precise sub-step time at which the membrane potential crosses the spike threshold during exact integration. Parameters ---------- f : Callable[[float], float] Scalar function representing threshold distance as a function of time offset within a mini-interval. Expected to be negative at ``t=0`` and positive or zero at ``t=t_hi`` for a valid crossing. t_hi : float Upper bound of search interval in ms (mini-interval duration). Returns ------- float Estimated root location in ``[0, t_hi]`` in ms. If no crossing is detected (``f(0) > 0`` or ``f(t_hi) <= 0``), returns boundary values ``0.0`` or ``t_hi`` respectively. If ``f(t_hi)`` is non-finite, returns ``t_hi``. Notes ----- The bisection uses 64 iterations, providing approximately :math:`2^{-64}` relative precision on the root location within the search interval. This is sufficient for neuroscience simulation time scales (ms resolution). """ lo = 0.0 hi = float(t_hi) f_lo = f(lo) f_hi = f(hi) if not np.isfinite(f_hi): return hi if f_lo > 0.0: return 0.0 if f_hi <= 0.0: return hi for _ in range(64): mid = 0.5 * (lo + hi) f_mid = f(mid) if f_mid > 0.0: hi = mid else: lo = mid return 0.5 * (lo + hi)
[docs] def update(self, x=0. * u.pA, spike_events=None): r"""Advance one global step with precise within-step event handling. This method implements the complete NEST-compatible precise-spike-time algorithm for ``iaf_psc_exp_ps``. Each global time step is subdivided into mini-intervals determined by spike event offsets. Within each mini-interval, membrane potential and synaptic currents are propagated exactly using closed-form exponential solutions. When the membrane potential crosses threshold, bisection root-finding (64 iterations) localizes the precise sub-step spike time. **Update sequence:** 1. Parse and validate ``spike_events`` and on-grid delta inputs. 2. Sort events in descending offset (from step start to step end). 3. For each neuron, process events sequentially: a. Propagate states exactly over each mini-interval. b. Apply event weights to PSC states (ex/in channels by sign). c. Check for threshold crossing and localize spike time if needed. d. Apply hard reset and enter refractory state on spike. e. Release from refractory via pseudo-event at calculated step. 4. Buffer incoming current ``x`` into ``y0`` for next step. 5. Compute surrogate spike output for gradient-based learning. **Implementation notes:** - All propagation uses NumPy float64 for numerical stability. - Event offsets follow NEST convention: ``offset=dt`` at step start, ``offset=0`` at step end. - Refractory neurons clamp membrane potential but allow PSC decay. - Root finding uses bounded bisection over ``[0, dt]`` with 64 iterations. Parameters ---------- x : ArrayLike, optional Continuous current input in pA for the current global step. Aggregated through :meth:`sum_current_inputs` and stored in ``y0`` for use in the next step (one-step buffering). Scalar or array-like broadcastable to ``self.V.value.shape``. Default is ``0. * u.pA``. spike_events : Iterable[tuple[Any, Any] | dict[str, Any]] or None, optional Optional off-grid events inside the current step. Each entry is ``(offset, weight)`` or ``{'offset': ..., 'weight': ...}``, where ``offset`` is in ms measured from the right step boundary and ``weight`` is in pA. Offsets must satisfy ``0 <= offset <= dt``. Positive weights update excitatory PSC; negative weights update inhibitory PSC. ``None`` means no extra precise events. On-grid delta inputs are automatically included at ``offset=0``. Default is ``None``. Returns ------- out : jax.Array Surrogate spike output from :meth:`get_spike`, shape ``self.V.value.shape``. Values correspond to ``self.spk_fun((V - V_th) / (V_th - V_reset))`` after exact piecewise propagation, event application, refractory logic, and precise spike-time localization. For neurons that spiked, the voltage is clamped slightly above threshold to ensure differentiable spike detection; for non-spiking neurons, voltage is clamped below threshold. Raises ------ ValueError If ``ceil(t_ref / dt) < 1`` (refractory period too short for time step), or if any event offset lies outside ``[0, dt]``, or if parameter constraints are violated at runtime. KeyError If simulation context values ``t`` (current time) or ``dt`` (time step) are unavailable from ``brainstate.environ``. TypeError If ``x`` or ``spike_events`` entries are not unit-compatible with pA/ms conversions, or if type conversions fail during numerical computation. AttributeError If required runtime states (``V``, ``I_syn_ex``, ``I_syn_in``, ``y0``, ``is_refractory``, etc.) are missing because :meth:`init_state` has not been called. """ import math t = brainstate.environ.get('t') dt_q = brainstate.environ.get_dt() h = float(u.math.asarray(dt_q / u.ms)) t_ms = float(u.math.asarray(t / u.ms)) step_idx = int(round(t_ms / h)) eps = np.finfo(np.float64).eps dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() v_shape = self.V.value.shape E_L = np.broadcast_to(np.asarray(u.math.asarray(self.E_L / u.mV), dtype=dftype), v_shape) y2 = np.broadcast_to(np.asarray(u.math.asarray(self.V.value / u.mV), dtype=dftype), v_shape) - E_L y1_ex = np.broadcast_to(np.asarray(u.math.asarray(self.I_syn_ex.value / u.pA), dtype=dftype), v_shape) y1_in = np.broadcast_to(np.asarray(u.math.asarray(self.I_syn_in.value / u.pA), dtype=dftype), v_shape) y0 = np.broadcast_to(np.asarray(u.math.asarray(self.y0.value / u.pA), dtype=dftype), v_shape) is_refractory = np.broadcast_to( np.asarray(u.math.asarray(self.is_refractory.value), dtype=bool), v_shape ) last_spike_step = np.broadcast_to( np.asarray(u.math.asarray(self.last_spike_step.value), dtype=ditype), v_shape ) last_spike_offset = np.broadcast_to( np.asarray(u.math.asarray(self.last_spike_offset.value / u.ms), dtype=dftype), v_shape ) last_spike_time_prev = np.broadcast_to( np.asarray(u.math.asarray(self.last_spike_time.value / u.ms), dtype=dftype), v_shape ) tau_m = np.broadcast_to(np.asarray(u.math.asarray(self.tau_m / u.ms), dtype=dftype), v_shape) tau_ex = np.broadcast_to(np.asarray(u.math.asarray(self.tau_syn_ex / u.ms), dtype=dftype), v_shape) tau_in = np.broadcast_to(np.asarray(u.math.asarray(self.tau_syn_in / u.ms), dtype=dftype), v_shape) c_m = np.broadcast_to(np.asarray(u.math.asarray(self.C_m / u.pF), dtype=dftype), v_shape) i_e = np.broadcast_to(np.asarray(u.math.asarray(self.I_e / u.pA), dtype=dftype), v_shape) u_th = np.broadcast_to( np.asarray(u.math.asarray((self.V_th - self.E_L) / u.mV), dtype=dftype), v_shape ) u_reset = np.broadcast_to( np.asarray(u.math.asarray((self.V_reset - self.E_L) / u.mV), dtype=dftype), v_shape ) u_min = -np.inf * np.ones(v_shape, dtype=dftype) if self.V_min is not None: u_min = np.broadcast_to( np.asarray(u.math.asarray((self.V_min - self.E_L) / u.mV), dtype=dftype), v_shape ) refr_steps = np.broadcast_to( np.asarray(u.math.asarray(self.ref_count), dtype=ditype), v_shape ) if np.any(refr_steps < 1): raise ValueError('Refractory time must be at least one time step.') # Events in a step, sorted from step start (offset=dt) to step end (offset=0). events = self._parse_spike_events(spike_events, v_shape) on_grid = np.broadcast_to( np.asarray(u.math.asarray(self.sum_delta_inputs(0. * u.pA) / u.pA), dtype=dftype), v_shape ) events.append((0.0, on_grid)) events.sort(key=lambda z: z[0], reverse=True) for off, _ in events: if off < 0.0 or off > h: raise ValueError('All spike event offsets must satisfy 0 <= offset <= dt.') y0_next = np.broadcast_to( np.asarray(u.math.asarray(self.sum_current_inputs(x, self.V.value) / u.pA), dtype=dftype), v_shape ) y0_new = np.empty_like(y0) y1_ex_new = np.empty_like(y1_ex) y1_in_new = np.empty_like(y1_in) y2_new = np.empty_like(y2) refr_new = np.empty_like(is_refractory) last_step_new = np.empty_like(last_spike_step) last_offset_new = np.empty_like(last_spike_offset) last_time_new = np.empty_like(last_spike_time_prev) spike_mask = np.zeros(v_shape, dtype=bool) v_for_spike = np.empty_like(y2) for idx in np.ndindex(v_shape): y0_i = float(y0[idx]) y1e_i = float(y1_ex[idx]) y1i_i = float(y1_in[idx]) y2_i = float(y2[idx]) refr_i = bool(is_refractory[idx]) last_step_i = int(last_spike_step[idx]) last_off_i = float(last_spike_offset[idx]) spike_time_i = float(last_spike_time_prev[idx]) tau_m_i = float(tau_m[idx]) tau_ex_i = float(tau_ex[idx]) tau_in_i = float(tau_in[idx]) c_m_i = float(c_m[idx]) i_e_i = float(i_e[idx]) u_th_i = float(u_th[idx]) u_reset_i = float(u_reset[idx]) u_min_i = float(u_min[idx]) refr_steps_i = int(refr_steps[idx]) did_spike = False before = [y0_i, y1e_i, y1i_i, y2_i] def set_before(): before[0] = y0_i before[1] = y1e_i before[2] = y1i_i before[3] = y2_i def threshold_distance(dt_local): P20 = -tau_m_i / c_m_i * math.expm1(-dt_local / tau_m_i) P21e = propagator_exp(np.asarray(tau_ex_i), np.asarray(tau_m_i), np.asarray(c_m_i), dt_local) P21i = propagator_exp(np.asarray(tau_in_i), np.asarray(tau_m_i), np.asarray(c_m_i), dt_local) y2_r = P20 * (i_e_i + before[0]) + P21e * before[1] + P21i * before[2] + before[3] * math.exp( -dt_local / tau_m_i) return y2_r - u_th_i def propagate(dt_local): nonlocal y1e_i, y1i_i, y2_i if dt_local <= 0.0: return if not refr_i: P20 = -tau_m_i / c_m_i * math.expm1(-dt_local / tau_m_i) P21e = propagator_exp(np.asarray(tau_ex_i), np.asarray(tau_m_i), np.asarray(c_m_i), dt_local) P21i = propagator_exp(np.asarray(tau_in_i), np.asarray(tau_m_i), np.asarray(c_m_i), dt_local) y2_i = P20 * (i_e_i + y0_i) + P21e * y1e_i + P21i * y1i_i + y2_i * math.exp(-dt_local / tau_m_i) y2_i = max(y2_i, u_min_i) y1e_i = y1e_i * math.exp(-dt_local / tau_ex_i) y1i_i = y1i_i * math.exp(-dt_local / tau_in_i) def emit_spike(t0, dt_local): nonlocal y2_i, refr_i, last_step_i, last_off_i, spike_time_i, did_spike root = self._bisect_root(threshold_distance, dt_local) spike_off = h - (t0 + root) spike_off = min(h, max(0.0, spike_off)) last_step_i = step_idx + 1 last_off_i = spike_off y2_i = u_reset_i refr_i = True spike_time_i = t_ms + h - spike_off did_spike = True def emit_instant_spike(spike_off): nonlocal y2_i, refr_i, last_step_i, last_off_i, spike_time_i, did_spike so = min(h, max(0.0, spike_off)) last_step_i = step_idx + 1 last_off_i = so y2_i = u_reset_i refr_i = True spike_time_i = t_ms + h - so did_spike = True if (not refr_i) and (y2_i >= u_th_i): emit_instant_spike(h * (1.0 - eps)) local_events = [(off, w[idx], False) for off, w in events] if refr_i and (step_idx + 1 - last_step_i == refr_steps_i): local_events.append((last_off_i, 0.0, True)) local_events.sort(key=lambda z: z[0], reverse=True) last_off = h if len(local_events) == 0: propagate(h) if y2_i >= u_th_i: set_before() emit_spike(0.0, h) else: for ev_off, ev_w, end_of_refract in local_events: ministep = last_off - ev_off if ministep > 0.0: set_before() propagate(ministep) if y2_i >= u_th_i: emit_spike(h - last_off, ministep) if end_of_refract: refr_i = False else: if ev_w >= 0.0: y1e_i += ev_w else: y1i_i += ev_w set_before() last_off = ev_off if last_off > 0.0: set_before() propagate(last_off) if y2_i >= u_th_i: emit_spike(h - last_off, last_off) y0_i = float(y0_next[idx]) y0_new[idx] = y0_i y1_ex_new[idx] = y1e_i y1_in_new[idx] = y1i_i y2_new[idx] = y2_i refr_new[idx] = refr_i last_step_new[idx] = last_step_i last_offset_new[idx] = last_off_i last_time_new[idx] = spike_time_i spike_mask[idx] = did_spike v_for_spike[idx] = (u_th_i + 1e-12) if did_spike else min(y2_i, u_th_i - 1e-12) self.y0.value = y0_new * u.pA self.I_syn_ex.value = y1_ex_new * u.pA self.I_syn_in.value = y1_in_new * u.pA self.V.value = (y2_new + E_L) * u.mV self.is_refractory.value = jnp.asarray(refr_new, dtype=bool) self.last_spike_step.value = jnp.asarray(last_step_new, dtype=ditype) self.last_spike_offset.value = last_offset_new * u.ms self.last_spike_time.value = jax.lax.stop_gradient(last_time_new * u.ms) if self.ref_var: self.refractory.value = jax.lax.stop_gradient(self.is_refractory.value) return self.get_spike((v_for_spike + E_L) * u.mV)