Source code for brainpy_state._nest.iaf_psc_delta

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron

# Public API of this module.
__all__ = ['iaf_psc_delta']


class iaf_psc_delta(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_delta`` neuron model.

    Description
    -----------

    ``iaf_psc_delta`` is a current-based leaky integrate-and-fire neuron with
    hard threshold/reset, absolute refractory period, and delta-shaped
    synaptic events represented as instantaneous membrane-voltage jumps
    (weights in mV). The implementation follows the NEST model
    ``iaf_psc_delta`` update semantics, including refractory handling and
    step-wise exact subthreshold propagation.

    **1. Continuous-Time Dynamics and Exact Per-Step Propagator**

    The membrane dynamics are

    .. math::

       \frac{dV_\text{m}}{dt} = -\frac{V_{\text{m}} - E_\text{L}}{\tau_{\text{m}}}
       + \dot{\Delta}_{\text{syn}}
       + \frac{I_{\text{syn}} + I_\text{e}}{C_{\text{m}}}

    where :math:`I_\text{syn}` is the sum of continuous current inputs and
    :math:`\dot{\Delta}_{\text{syn}}` captures impulse-like jumps from
    delta synapses.

    For fixed simulation step :math:`h=dt` and piecewise-constant current
    :math:`I_k`, exact integration of the linear subthreshold ODE yields

    .. math::

       V_{k+1}^{\mathrm{cand}}
       = E_L + (V_k - E_L)e^{-h/\tau_m}
       + \frac{\tau_m}{C_m}\left(I_k + I_e\right)\left(1 - e^{-h/\tau_m}\right),

    which is implemented directly in :meth:`update`. This is equivalent to
    the propagator formulation used in NEST for this linear system.

    **2. Spike Condition, Reset, and Refractory Countdown**

    After adding delta-input jump :math:`\Delta_{\text{syn},k}` (and applying
    the optional ``V_min`` bound), a spike is emitted at step end if the
    post-update potential reaches threshold:

    .. math::

       V_k^{\mathrm{post}} \ge V_{th}.

    On spike:

    .. math::

       V \leftarrow V_{reset}, \qquad
       r \leftarrow \left\lceil \frac{t_{ref}}{dt} \right\rceil,

    where :math:`r` is the integer refractory-step counter. The ratio is
    rounded up with a tiny tolerance so that floating-point noise (for
    example ``2.2 / 0.1 == 22.000000000000004``) does not add a spurious
    extra step. While :math:`r > 0`, the membrane is clamped (no subthreshold
    integration is committed), then :math:`r` decrements by one each
    simulation step.

    **3. Delta Synapses, Voltage Jumps, and Charge Interpretation**

    The change in membrane potential due to synaptic inputs can be formulated as:

    .. math::

       \dot{\Delta}_{\text{syn}}(t) = \sum_{j} w_j \sum_k \delta(t-t_j^k-d_j) \;,

    where :math:`j` indexes either excitatory (:math:`w_j > 0`) or inhibitory
    (:math:`w_j < 0`) presynaptic neurons, :math:`k` indexes the spike times of
    neuron :math:`j`, :math:`d_j` is the delay from neuron :math:`j`, and
    :math:`\delta` is the Dirac delta distribution. This implies that the jump in
    voltage upon a single synaptic input spike is

    .. math::

       \Delta_{\text{syn}} = w \;,

    where :math:`w` is synaptic weight in mV. Positive weights are excitatory
    and negative weights are inhibitory.

    The change in voltage caused by the synaptic input can be interpreted as being
    caused by individual post-synaptic currents (PSCs) given by

    .. math::

       i_{\text{syn}}(t) = C_{\text{m}} \cdot w \cdot \delta(t) \;.

    As a consequence, the total charge :math:`q` transferred by a single PSC is

    .. math::

       q = \int_0^{\infty}  i_{\text{syn}}(t)\, dt = C_{\text{m}} \cdot w \;.

    **4. Assumptions, Constraints, and Computational Implications**

    - The model assumes unit-compatible parameters and broadcast-compatible
      shapes against ``self.varshape``.
    - ``V_min`` is optional; when provided, candidate voltage is clipped with
      ``max(V, V_min)`` before threshold evaluation.
    - Per-step compute is :math:`O(\prod \mathrm{varshape})` with vectorized
      elementwise operations.
    - ``refractory_input=False`` discards delta events that arrive during
      refractory clamping, while ``refractory_input=True`` stores a decayed
      contribution that is released when refractoriness ends.

    .. note::

       This implementation uses exact integration for subthreshold dynamics
       and NEST-compatible conversion of refractory duration to grid steps via
       ``ceil(t_ref / dt)``.

    Parameters
    ----------

    in_size : Size
        Population shape specification. All neuron parameters are broadcast to
        ``self.varshape`` derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable to
        ``self.varshape``. Expected positive for physical behavior. Default is
        ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable to
        ``self.varshape``. Expected positive. Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory duration :math:`t_{ref}` in ms. Converted once, at
        construction, to integer simulation steps using ``ceil(t_ref / dt)``
        with the ``dt`` active at that time. Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_{th}` in mV; broadcastable to ``self.varshape``.
        Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Post-spike reset potential :math:`V_{reset}` in mV; broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    I_e : ArrayLike, optional
        Constant external current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * u.pA``.
    V_min : ArrayLike or None, optional
        Optional lower membrane bound :math:`V_{min}` in mV, broadcast to
        ``self.varshape``. If ``None``, no lower clipping is applied.
        Default is ``None``.
    V_initializer : Callable, optional
        Initializer for membrane state ``V`` in :meth:`init_state`, evaluated
        for ``self.varshape`` (:meth:`init_state` does not add a batch
        dimension). Default is ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike`. Default is
        ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy forwarded to the base-class constructor. :meth:`update`
        always applies a hard reset to ``V_reset`` (NEST behavior), regardless
        of this value. Default is ``'hard'``.
    refractory_input : bool, optional
        If ``False``, delta inputs during refractory are ignored. If ``True``,
        refractory-arriving delta jumps are accumulated in
        ``refractory_spike_buffer`` and applied after refractory release.
        Default is ``False``.
    ref_var : bool, optional
        If ``True``, allocate boolean refractory state ``self.refractory`` for
        inspection. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 28 14 16 35

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines population/state shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Resting membrane potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF)
         - ``250. * u.pF``
         - :math:`C_m`
         - Membrane capacitance in subthreshold propagator.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms)
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Membrane leak time constant.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), step-converted by ``ceil``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV)
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_{th}`, :math:`V_{reset}`
         - Threshold and reset voltages.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant injected current.
       * - ``V_min``
         - ArrayLike broadcastable (mV) or ``None``
         - ``None``
         - :math:`V_{min}`
         - Optional lower clamp before threshold test.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * u.mV)``
         - --
         - Initializes membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate spike output function.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Forwarded to the base class; :meth:`update` always hard-resets.
       * - ``refractory_input``
         - bool
         - ``False``
         - --
         - Controls refractory-period treatment of delta inputs.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - Enables persistent refractory boolean state.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node identifier.

    Raises
    ------
    ValueError
        If parameter initialization or broadcasting fails (for example due to
        incompatible array shape in ``braintools.init.param``).
    TypeError
        If provided values are not compatible with expected units/types
        (mV, ms, pF, pA, or callable initializers/spike functions).
    KeyError
        At runtime, if required simulation context entries (for example ``t``
        or ``dt``) are missing when :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` creates required
        state variables.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    last_spike_time : ShortTermState
        Time of the most recent spike (the step end time ``t + dt``);
        initialized to ``-1e7 ms``. It is only recorded: refractoriness is
        tracked by ``refractory_step_count``.
    refractory_step_count : ShortTermState
        Remaining number of refractory simulation steps (integer).
    refractory_spike_buffer : ShortTermState
        Decayed delta input accumulated while refractory; released into ``V``
        only when ``refractory_input=True``.
    refractory : ShortTermState
        Boolean refractory flag (only present if ``ref_var=True``).
    ref_count : array
        Refractory duration ``t_ref`` converted to integer simulation steps.

    Notes
    -----
    - State variables are ``V``, ``last_spike_time``,
      ``refractory_step_count``, and ``refractory_spike_buffer``. ``refractory``
      exists only when ``ref_var=True``.
    - Continuous current input ``x`` is combined with ``I_e`` through
      :meth:`sum_current_inputs` in the same simulation step.
    - Delta events from :meth:`sum_delta_inputs` are interpreted in mV and
      added as instantaneous voltage jumps.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_delta(in_size=10, t_ref=2.0 * u.ms)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         spk = neu.update(x=500.0 * u.pA)
       ...     _ = spk.shape

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_delta(in_size=(2,), V_min=-80.0 * u.mV)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = neu.update(x=120.0 * u.pA)

    References
    ----------
    .. [1] Rotter S, Diesmann M (1999). Exact simulation of time-invariant linear
           systems with applications to neuronal modeling. Biological Cybernetics
           81:381-402. DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, & Aertsen A (2001). State space
           analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X

    See Also
    --------
    LIF : Leaky integrate-and-fire with current-based synapses
    LIFRef : Leaky integrate-and-fire with refractory period (brainpy parameterization)
    """
    __module__ = 'brainpy.state'

    # Tolerance (in units of simulation steps) subtracted from ``t_ref / dt``
    # before rounding up, so floating-point noise such as
    # ``2.2 / 0.1 == 22.000000000000004`` does not produce an extra
    # refractory step.
    _REF_STEPS_TOL = 1e-6

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        I_e: ArrayLike = 0. * u.pA,
        V_min: ArrayLike = None,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        refractory_input: bool = False,
        ref_var: bool = False,
        name: str = None,
    ):
        """Store broadcast parameters and precompute the refractory step count.

        Requires a ``dt`` in the ``brainstate.environ`` context, because the
        refractory duration is converted to steps here.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to ``self.varshape``
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        # Optional lower bound, broadcast like the other parameters so that a
        # shape mismatch surfaces here rather than inside ``update``.
        self.V_min = None if V_min is None else braintools.init.param(V_min, self.varshape)
        self.V_initializer = V_initializer
        self.refractory_input = refractory_input
        self.ref_var = ref_var

        # Refractory duration in whole simulation steps, rounded up.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(
            u.math.ceil(self.t_ref / dt - self._REF_STEPS_TOL),
            dtype=ditype,
        )

    def init_state(self, **kwargs):
        r"""Initialize membrane and refractory runtime states.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If initializer outputs cannot be broadcast to target state shape.
        TypeError
            If initializer values are incompatible with required numeric/unit
            conversions.
        """
        ditype = brainstate.environ.ditype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        self.V = brainstate.HiddenState(V)
        # Far in the past, so no neuron starts with a "recent" spike.
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.refractory_spike_buffer = brainstate.ShortTermState(u.math.zeros_like(V))
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate surrogate spike activation for a voltage tensor.

        Parameters
        ----------
        V : ArrayLike or None, optional
            Membrane voltage input in mV, broadcast-compatible with
            ``self.varshape``. If ``None``, the method uses ``self.V.value``.

        Returns
        -------
        out : ArrayLike
            Output of ``self.spk_fun`` applied to the normalized voltage
            ``(V - V_th) / (V_th - V_reset)`` (so that the threshold maps to 0).

        Raises
        ------
        TypeError
            If ``V`` cannot participate in arithmetic with membrane parameters
            due to incompatible dtype/unit.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.pA):
        r"""Advance the neuron by one simulation step.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input in pA for this step. It is combined with
            ``I_e`` and additional current sources from
            :meth:`sum_current_inputs`.

        Returns
        -------
        out : jax.Array
            Surrogate spike output from :meth:`get_spike` with shape
            ``self.V.value.shape``. The returned spike signal is computed from
            the pre-reset voltage ``v_post``.

        Raises
        ------
        KeyError
            If simulation context does not provide required entries ``t`` or
            ``dt``.
        AttributeError
            If required states are missing because :meth:`init_state` has not
            been called.
        TypeError
            If input/state values are not unit-compatible with expected pA/mV
            arithmetic.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        last_v = self.V.value
        ref_steps = self.refractory_step_count.value

        # Exact subthreshold propagation for one fixed simulation step, with
        # the total current held constant over the step.
        decay = u.math.exp(-dt / self.tau_m)
        i_total = self.sum_current_inputs(self.I_e + x, last_v)
        v_candidate = (
            self.E_L
            + (last_v - self.E_L) * decay
            + (i_total / self.C_m) * self.tau_m * (1. - decay)
        )

        # Delta synapses: instantaneous voltage jumps in mV.
        delta_v = self.sum_delta_inputs(u.math.zeros_like(last_v))
        v_candidate = v_candidate + delta_v
        if self.refractory_input:
            # Release input buffered during the preceding refractory period;
            # it only takes effect where the neuron is non-refractory (v_post).
            v_candidate = v_candidate + self.refractory_spike_buffer.value
        if self.V_min is not None:
            v_candidate = u.math.maximum(v_candidate, self.V_min)

        # Refractory neurons keep their clamped voltage.
        not_refractory = ref_steps == 0
        v_post = u.math.where(not_refractory, v_candidate, last_v)

        if self.refractory_input:
            # Input arriving with ``r`` refractory steps left is discounted by
            # exp(-r * dt / tau_m), i.e. its decay until release. The buffer is
            # cleared on steps where it was just released above.
            refr_decay = u.math.exp(-ref_steps * dt / self.tau_m)
            self.refractory_spike_buffer.value = u.math.where(
                not_refractory,
                u.math.zeros_like(self.refractory_spike_buffer.value),
                self.refractory_spike_buffer.value + delta_v * refr_decay,
            )
        # Count down refractoriness only where it is active.
        ref_steps = u.math.where(not_refractory, ref_steps, ref_steps - 1)

        spike_cond = v_post >= self.V_th
        self.refractory_step_count.value = jax.lax.stop_gradient(
            u.math.where(spike_cond, self.ref_count, ref_steps)
        )
        # Hard reset, independent of ``spk_reset``.
        self.V.value = u.math.where(spike_cond, self.V_reset, v_post)
        # Spikes are stamped at the end of the step.
        self.last_spike_time.value = jax.lax.stop_gradient(
            u.math.where(spike_cond, t + dt, self.last_spike_time.value)
        )
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)
        # Use the pre-reset voltage so the surrogate sees the threshold crossing.
        return self.get_spike(v_post)