# Source code for brainpy_state._nest.iaf_tum_2000

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable, Iterable

import brainstate
import braintools
import saiunit as bu
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer, propagator_exp

__all__ = [
    'iaf_tum_2000',
]


class iaf_tum_2000(NESTNeuron):
    r"""NEST-compatible ``iaf_tum_2000`` neuron model.

    Description
    -----------

    ``iaf_tum_2000`` is a leaky integrate-and-fire neuron with exponential
    postsynaptic currents and integrated Tsodyks-Markram short-term synaptic
    plasticity. The model extends :class:`iaf_psc_exp` by maintaining
    presynaptic resource states ``x`` (readily-releasable pool), ``y``
    (cleft/active fraction), and ``u`` (release probability), and emitting a
    per-spike ``spike_offset`` signal that encodes the jump in ``y`` at each
    spike event. This signal is used for receptor-1 coupling between
    ``iaf_tum_2000`` neurons, enabling dynamic synaptic efficacy.

    The implementation follows NEST ``models/iaf_tum_2000.{h,cpp}`` update
    ordering and event semantics exactly, including NEST-style buffered input
    handling and receptor-type routing.

    **1. Membrane and synaptic dynamics**

    Subthreshold voltage evolution follows the same equation as
    :class:`iaf_psc_exp`:

    .. math::

       \frac{dV_m}{dt} =
       -\frac{V_m - E_L}{\tau_m} +
       \frac{I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}} + I_e + I_0}{C_m},

    where ``I_0`` is the buffered current from the previous time step. Synaptic
    currents decay exponentially:

    .. math::

       \frac{dI_{\mathrm{syn,ex}}}{dt} = -\frac{I_{\mathrm{syn,ex}}}{\tau_{\mathrm{syn,ex}}},
       \qquad
       \frac{dI_{\mathrm{syn,in}}}{dt} = -\frac{I_{\mathrm{syn,in}}}{\tau_{\mathrm{syn,in}}}.

    Receptor-1 current input ``I_1`` is filtered through the excitatory kernel:

    .. math::

       I_{\mathrm{syn,ex}} \leftarrow I_{\mathrm{syn,ex}} + (1 - e^{-h/\tau_{\mathrm{syn,ex}}}) I_1,

    where ``h = dt`` is the simulation time step.

    **2. Tsodyks-Markram short-term plasticity on spike**

    When a neuron emits a spike at time ``t_spike``, the Tsodyks states are
    updated. Let ``t_last`` be the previous spike time (with NEST-compatible
    first-spike convention: ``t_last = 0`` if the internal last-spike time is
    negative, indicating no prior spike), and ``h_ts = t_spike - t_last``.

    Define propagators:

    .. math::

       P_{uu} = \begin{cases}
       0, & \tau_{\mathrm{fac}}=0 \\
       e^{-h_{ts}/\tau_{\mathrm{fac}}}, & \text{otherwise}
       \end{cases},
       \quad
       P_{yy} = e^{-h_{ts}/\tau_{\mathrm{psc}}},

    .. math::

       P_{zz} = \mathrm{expm1}(-h_{ts}/\tau_{\mathrm{rec}}) = e^{-h_{ts}/\tau_{\mathrm{rec}}} - 1,

    .. math::

       P_{xy} =
       \frac{P_{zz}\tau_{\mathrm{rec}} - (P_{yy}-1)\tau_{\mathrm{psc}}}{\tau_{\mathrm{psc}}-\tau_{\mathrm{rec}}}.

    With :math:`z = 1 - x - y` (inactive/recovered fraction), NEST performs
    state propagation in this exact order:

    .. math::

       u &\leftarrow u P_{uu}, \\
       x &\leftarrow x + P_{xy}y - P_{zz}z, \\
       y &\leftarrow y P_{yy},

    followed by utilization jump and resource transfer:

    .. math::

       u &\leftarrow u + U(1-u), \\
       \Delta y &= u x, \\
       x &\leftarrow x - \Delta y, \\
       y &\leftarrow y + \Delta y.

    ``spike_offset`` is set to :math:`\Delta y` on spike steps, zero otherwise.

    **3. NEST update ordering**

    Per time step, the model follows this precise sequence:

    1. Update membrane potential if not refractory (exact exponential propagator).
    2. Decay synaptic currents :math:`I_{\mathrm{syn,ex}}` and :math:`I_{\mathrm{syn,in}}`.
    3. Add filtered receptor-1 current to :math:`I_{\mathrm{syn,ex}}`.
    4. Add arriving spike inputs (positive weights to excitatory, non-positive to inhibitory).
    5. Perform threshold test (deterministic or escape-noise), assign refractory and reset.
    6. On emitted spike, update Tsodyks states (using the order above) and set ``spike_offset``.
    7. Buffer current inputs ``i_0`` and ``i_1`` for the next step.

    **4. Escape-noise threshold dynamics**

    Spike generation uses deterministic thresholding when :math:`\delta < 10^{-10}`:
    :math:`V_{\mathrm{rel}} \ge \theta`, where :math:`\theta = V_{th} - E_L`.

    For :math:`\delta > 0`, the model uses exponential hazard:

    .. math::

       \phi(V) = \rho \exp\left(\frac{V_{\mathrm{rel}} - \theta}{\delta}\right),

    with step spike probability :math:`p=\phi(V)\,h\,10^{-3}` (``h`` in ms,
    :math:`\phi` in ``1/s``). Stochastic decisions use ``numpy.random.random``.

    **5. Event semantics and receptor routing**

    The :meth:`update` method accepts ``spike_events`` as an iterable of event
    descriptors in one of these formats:

    - ``(receptor_type, weight)``
    - ``(receptor_type, weight, offset)``
    - ``(receptor_type, weight, offset, multiplicity)``
    - ``(receptor_type, weight, offset, multiplicity, sender_model)``
    - ``dict`` with keys ``receptor_type``/``receptor``, ``weight``, ``offset``,
      ``multiplicity``, ``sender_model``

    Receptors:

    - **Receptor 0** (DEFAULT): regular spike input, effective weight is
      ``weight * multiplicity``.
    - **Receptor 1** (TSODYKS): Tsodyks-coupled input, effective weight is
      ``weight * multiplicity * offset``, where ``offset`` is typically the
      ``spike_offset`` from the presynaptic ``iaf_tum_2000`` neuron.

    For receptor 1, the ``sender_model`` field must be ``"iaf_tum_2000"``
    (default assumption if not provided); otherwise a ``ValueError`` is raised,
    mirroring NEST's connection constraints.

    **6. Stability constraints and computational implications**

    - Construction validates: ``V_reset < V_th``, ``C_m > 0``, ``tau_m > 0``,
      ``tau_syn_ex > 0``, ``tau_syn_in > 0``, ``tau_psc > 0``, ``tau_rec > 0``,
      ``tau_fac >= 0``, ``t_ref >= 0``, ``rho >= 0``, ``delta >= 0``,
      ``x + y <= 1``, ``u ∈ [0,1]``.
    - Tsodyks state propagation uses the same singularity-free logic as NEST to
      handle ``tau_psc == tau_rec`` or ``tau_fac == 0`` cases gracefully.
    - Per-call cost is :math:`O(\prod \mathrm{varshape})` with vectorized
      NumPy operations in ``float64`` for coefficient evaluation.
    - Buffered current semantics match NEST ring-buffer timing: ``x`` and
      ``x_filtered`` supplied at step ``n`` are stored and consumed at step ``n+1``.

    Parameters
    ----------
    in_size : Size
        Population shape specification. All per-neuron parameters are broadcast
        to ``self.varshape``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * bu.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable and strictly
        positive. Default is ``250. * bu.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable and
        strictly positive. Default is ``10. * bu.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_{\mathrm{ref}}` in ms;
        broadcastable and nonnegative. Converted to integer steps by
        ``ceil(t_ref / dt)``. Default is ``2. * bu.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_{th}` in mV; broadcastable to
        ``self.varshape``. Default is ``-55. * bu.mV``.
    V_reset : ArrayLike, optional
        Post-spike reset potential :math:`V_{\mathrm{reset}}` in mV;
        broadcastable and must satisfy ``V_reset < V_th`` elementwise. Default
        is ``-70. * bu.mV``.
    tau_syn_ex : ArrayLike, optional
        Excitatory synaptic decay constant :math:`\tau_{\mathrm{syn,ex}}` in
        ms; broadcastable and strictly positive. Default is ``2. * bu.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory synaptic decay constant :math:`\tau_{\mathrm{syn,in}}` in
        ms; broadcastable and strictly positive. Default is ``2. * bu.ms``.
    I_e : ArrayLike, optional
        Constant external injected current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * bu.pA``.
    rho : ArrayLike, optional
        Escape-noise base firing intensity :math:`\rho` in ``1/s``;
        broadcastable and nonnegative. Used only in stochastic mode
        (``delta > 0``). Default is ``0.01 / bu.second``.
    delta : ArrayLike, optional
        Escape-noise soft-threshold width :math:`\delta` in mV; broadcastable
        and nonnegative. ``delta == 0`` reproduces deterministic thresholding.
        Default is ``0. * bu.mV``.
    tau_fac : ArrayLike, optional
        Facilitation time constant :math:`\tau_{\mathrm{fac}}` in ms;
        broadcastable and nonnegative. ``tau_fac == 0`` disables facilitation
        (:math:`P_{uu}=0`). Default is ``1000. * bu.ms``.
    tau_psc : ArrayLike, optional
        Tsodyks postsynaptic current time constant :math:`\tau_{\mathrm{psc}}`
        in ms; broadcastable and strictly positive. Used in state propagators.
        Default is ``2. * bu.ms``.
    tau_rec : ArrayLike, optional
        Resource recovery time constant :math:`\tau_{\mathrm{rec}}` in ms;
        broadcastable and strictly positive. Default is ``400. * bu.ms``.
    U : ArrayLike, optional
        Utilization increment factor :math:`U` (dimensionless); broadcastable
        and must lie in ``[0, 1]``. Represents the per-spike increase in
        release probability. Default is ``0.5``.
    x : ArrayLike, optional
        Initial readily-releasable resource fraction (dimensionless);
        broadcastable. Must satisfy ``x + y <= 1`` and ``x >= 0``. Default is
        ``0.0``.
    y : ArrayLike, optional
        Initial cleft/active fraction (dimensionless); broadcastable. Must
        satisfy ``x + y <= 1`` and ``y >= 0``. Default is ``0.0``.
    u : ArrayLike, optional
        Initial release probability (dimensionless); broadcastable and must lie
        in ``[0, 1]``. Default is ``0.0``.
    V_initializer : Callable, optional
        Initializer for membrane state ``V`` used by :meth:`init_state`.
        Default is ``braintools.init.Constant(-70. * bu.mV)``.
    spk_fun : Callable, optional
        Surrogate spike nonlinearity used by :meth:`get_spike`. Default is
        ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy inherited from :class:`~brainpy_state._base.Neuron`.
        ``'hard'`` matches NEST reset behavior. Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, allocates ``self.refractory`` (boolean) for external
        inspection of refractory state. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table::
       :header-rows: 1
       :widths: 14 26 14 16 30

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines neuron population shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable (mV)
         - ``-70. * bu.mV``
         - :math:`E_L`
         - Resting membrane potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * bu.pF``
         - :math:`C_m`
         - Membrane capacitance in voltage integration.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * bu.ms``
         - :math:`\tau_m`
         - Membrane leak time constant.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * bu.ms``
         - :math:`t_{\mathrm{ref}}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * bu.mV``, ``-70. * bu.mV``
         - :math:`V_{th}`, :math:`V_{\mathrm{reset}}`
         - Threshold and post-spike reset voltages.
       * - ``tau_syn_ex`` and ``tau_syn_in``
         - ArrayLike, broadcastable (ms), each ``> 0``
         - ``2. * bu.ms``
         - :math:`\tau_{\mathrm{syn,ex}}`, :math:`\tau_{\mathrm{syn,in}}`
         - Exponential PSC decay constants.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * bu.pA``
         - :math:`I_e`
         - Constant current injected every step.
       * - ``rho`` and ``delta``
         - ArrayLike, broadcastable; ``rho`` in ``1/s`` and ``delta`` in mV,
           both ``>= 0``
         - ``0.01 / bu.second``, ``0. * bu.mV``
         - :math:`\rho`, :math:`\delta`
         - Escape-noise hazard parameters.
       * - ``tau_fac``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``1000. * bu.ms``
         - :math:`\tau_{\mathrm{fac}}`
         - Facilitation decay time constant; ``0`` disables.
       * - ``tau_psc``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``2. * bu.ms``
         - :math:`\tau_{\mathrm{psc}}`
         - Tsodyks PSC time constant.
       * - ``tau_rec``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``400. * bu.ms``
         - :math:`\tau_{\mathrm{rec}}`
         - Resource recovery time constant.
       * - ``U``
         - ArrayLike, broadcastable (dimensionless), ``∈ [0,1]``
         - ``0.5``
         - :math:`U`
         - Utilization increment per spike.
       * - ``x``
         - ArrayLike, broadcastable (dimensionless), ``x+y <= 1``
         - ``0.0``
         - :math:`x`
         - Initial readily-releasable fraction.
       * - ``y``
         - ArrayLike, broadcastable (dimensionless), ``x+y <= 1``
         - ``0.0``
         - :math:`y`
         - Initial cleft/active fraction.
       * - ``u``
         - ArrayLike, broadcastable (dimensionless), ``∈ [0,1]``
         - ``0.0``
         - :math:`u`
         - Initial release probability.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * bu.mV)``
         - --
         - Initializer for membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate function for output spikes.
       * - ``spk_reset``
         - ``str`` (typically ``'hard'``)
         - ``'hard'``
         - --
         - Reset behavior selection in base class.
       * - ``ref_var``
         - ``bool``
         - ``False``
         - --
         - Enables explicit boolean refractory state variable.
       * - ``name``
         - ``str`` or ``None``
         - ``None``
         - --
         - Optional instance name.

    Raises
    ------
    ValueError
        Raised at construction when any validated constraint is violated:
        ``V_reset >= V_th``, nonpositive ``C_m``/``tau_m``/synaptic time
        constants/Tsodyks time constants, negative ``tau_fac``/``t_ref``/``rho``/``delta``,
        ``U`` not in ``[0,1]``, ``u`` not in ``[0,1]``, or ``x + y > 1``.

    Examples
    --------
    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as bu
       >>> from brainpy_state._nest.iaf_tum_2000 import iaf_tum_2000
       >>> brainstate.environ.set(dt=0.1 * bu.ms, t=0.0 * bu.ms)
       >>> neu = iaf_tum_2000(
       ...     in_size=(2,),
       ...     I_e=250. * bu.pA,
       ...     tau_fac=500. * bu.ms,
       ...     tau_rec=400. * bu.ms,
       ...     U=0.3
       ... )
       >>> neu.init_state()
       >>> out = neu.update(x=0. * bu.pA, x_filtered=0. * bu.pA)
       >>> out.shape
       (2,)

    Notes
    -----
    - Shares the exact exponential propagator implementation with
      :class:`iaf_psc_exp` (via :func:`propagator_exp` from ``_utils``).
    - The Tsodyks update order matches NEST ``iaf_tum_2000.cpp`` exactly to
      ensure identical dynamics in network simulations.
    - Receptor-1 connections require both presynaptic and postsynaptic neurons
      to be ``iaf_tum_2000`` models, enforced via runtime validation.
    - ``spike_offset`` can be recorded and monitored for debugging or analysis
      of dynamic synaptic efficacy.
    - The model is grid-based with one-step input buffering matching NEST's
      ring-buffer semantics.
    """

    __module__ = 'brainpy.state'

    RECEPTOR_TYPES = {
        'DEFAULT': 0,
        'TSODYKS': 1,
    }

    RECORDABLES = (
        'V_m',
        'I_syn_ex',
        'I_syn_in',
        'x',
        'y',
        'u',
        'spike_offset',
    )

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * bu.mV,
        C_m: ArrayLike = 250. * bu.pF,
        tau_m: ArrayLike = 10. * bu.ms,
        t_ref: ArrayLike = 2. * bu.ms,
        V_th: ArrayLike = -55. * bu.mV,
        V_reset: ArrayLike = -70. * bu.mV,
        tau_syn_ex: ArrayLike = 2. * bu.ms,
        tau_syn_in: ArrayLike = 2. * bu.ms,
        I_e: ArrayLike = 0. * bu.pA,
        rho: ArrayLike = 0.01 / bu.second,
        delta: ArrayLike = 0. * bu.mV,
        tau_fac: ArrayLike = 1000. * bu.ms,
        tau_psc: ArrayLike = 2. * bu.ms,
        tau_rec: ArrayLike = 400. * bu.ms,
        U: ArrayLike = 0.5,
        x: ArrayLike = 0.0,
        y: ArrayLike = 0.0,
        u: ArrayLike = 0.0,
        V_initializer: Callable = braintools.init.Constant(-70. * bu.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Every per-neuron parameter is broadcast to the population shape.
        shape = self.varshape
        as_param = braintools.init.param

        # Leaky integrate-and-fire membrane / exponential-PSC parameters.
        self.E_L = as_param(E_L, shape)
        self.C_m = as_param(C_m, shape)
        self.tau_m = as_param(tau_m, shape)
        self.t_ref = as_param(t_ref, shape)
        self.V_th = as_param(V_th, shape)
        self.V_reset = as_param(V_reset, shape)
        self.tau_syn_ex = as_param(tau_syn_ex, shape)
        self.tau_syn_in = as_param(tau_syn_in, shape)
        self.I_e = as_param(I_e, shape)
        # Escape-noise threshold parameters.
        self.rho = as_param(rho, shape)
        self.delta = as_param(delta, shape)

        # Tsodyks-Markram plasticity parameters and the initial x / y / u values
        # later copied into the state variables by ``init_state``.
        self.tau_fac = as_param(tau_fac, shape)
        self.tau_psc = as_param(tau_psc, shape)
        self.tau_rec = as_param(tau_rec, shape)
        self.U = as_param(U, shape)
        self.x_init = as_param(x, shape)
        self.y_init = as_param(y, shape)
        self.u_init = as_param(u, shape)

        self.V_initializer = V_initializer
        self.ref_var = ref_var

        # Reject physically inconsistent parameter sets before anything else.
        self._validate_parameters()

        # Refractory period expressed as a whole number of steps (rounded up).
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        n_ref_steps = bu.math.ceil(self.t_ref / dt)
        self.ref_count = bu.math.asarray(n_ref_steps, dtype=ditype)

    @property
    def receptor_types(self):
        r"""Receptor labels understood by this model.

        Returns
        -------
        dict
            A fresh ``{name: id}`` copy of ``RECEPTOR_TYPES``, i.e.
            ``{'DEFAULT': 0, 'TSODYKS': 1}``. Mutating the returned mapping does
            not affect the class attribute.
        """
        return {label: receptor_id for label, receptor_id in self.RECEPTOR_TYPES.items()}

    @property
    def recordables(self):
        r"""Names of the quantities that can be recorded from this model.

        Returns
        -------
        list of str
            A new list built from ``RECORDABLES``: ``['V_m', 'I_syn_ex',
            'I_syn_in', 'x', 'y', 'u', 'spike_offset']``. The membrane potential
            uses the NEST name ``'V_m'`` even though the state attribute is
            ``self.V``.
        """
        return [*self.RECORDABLES]

    @classmethod
    def _normalize_spike_receptor(cls, receptor):
        r"""Normalize receptor label to canonical integer ID.

        Converts string labels like ``'DEFAULT'``, ``'TSODYKS'``, ``'R0'``,
        ``'R1'``, or numeric strings/integers to the standard receptor IDs (0 or 1).

        Parameters
        ----------
        receptor : str or int
            Receptor label. Valid string labels (case-insensitive, surrounding
            whitespace ignored):

            - Receptor 0: ``'DEFAULT'``, ``'R0'``, ``'RECEPTOR0'``, ``'0'``
            - Receptor 1: ``'TSODYKS'``, ``'R1'``, ``'RECEPTOR1'``, ``'1'``

            Numeric values must be integral (``1`` and ``1.0`` are accepted,
            ``1.5`` is not) and equal to 0 or 1.

        Returns
        -------
        int
            Canonical receptor ID: 0 or 1.

        Raises
        ------
        ValueError
            If ``receptor`` is an unrecognized string label, a non-integral
            number, or an integer not in ``{0, 1}``.
        """
        if isinstance(receptor, str):
            key = receptor.strip().upper()
            if key in ('DEFAULT', 'R0', 'RECEPTOR0', '0'):
                return 0
            if key in ('TSODYKS', 'R1', 'RECEPTOR1', '1'):
                return 1
            if not key.isdigit():
                raise ValueError(f'Unknown receptor label: {receptor}')
            receptor = int(key)

        receptor_id = int(receptor)
        # ``int()`` truncates toward zero, so without this check a value such as
        # 1.5 (or 0.7) would silently be routed to receptor 1 (or 0).
        if receptor_id != receptor:
            raise ValueError(f'Receptor type must be an integer, got {receptor}.')
        if receptor_id not in (0, 1):
            raise ValueError(f'Receptor type must be 0 or 1, got {receptor_id}.')
        return receptor_id

    def _validate_parameters(self):
        r"""Validate model parameters at construction time.

        Checks all parameter constraints to ensure physical consistency and
        numerical stability. Raises ``ValueError`` with a descriptive message if
        any constraint is violated. Constraints are checked in the order listed
        below, so the first violated one determines the message.

        Raises
        ------
        ValueError
            If any of the following constraints are violated:

            - ``V_reset >= V_th``: Reset must be below threshold.
            - ``C_m <= 0``: Capacitance must be strictly positive.
            - ``tau_m <= 0``: Membrane time constant must be strictly positive.
            - ``tau_syn_ex <= 0`` or ``tau_syn_in <= 0``: Synaptic time constants
              must be strictly positive.
            - ``tau_psc <= 0`` or ``tau_rec <= 0``: Tsodyks time constants must be
              strictly positive.
            - ``tau_fac < 0``: Facilitation time constant must be nonnegative.
            - ``t_ref < 0``: Refractory time must be nonnegative.
            - ``U < 0`` or ``U > 1``: Utilization factor must be in ``[0, 1]``.
            - ``rho < 0``: Firing intensity must be nonnegative.
            - ``delta < 0``: Threshold width must be nonnegative.
            - ``x + y > 1.0``: Resource fractions must sum to at most 1.
            - ``u < 0`` or ``u > 1``: Initial release probability must be in ``[0, 1]``.

        Notes
        -----
        ``np.any`` cannot be evaluated on JAX tracers (e.g. when parameters are
        traced inside ``jax.jit``). Each constraint is therefore evaluated only
        when every parameter it reads is concrete; constraints that involve a
        traced parameter are skipped rather than crashing construction.
        """

        def concrete(*values):
            # True when none of the given parameters is a JAX tracer.
            return not any(is_tracer(v) for v in values)

        if concrete(self.V_reset, self.V_th) and np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if concrete(self.C_m) and np.any(self.C_m <= 0.0 * bu.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if concrete(self.tau_m) and np.any(self.tau_m <= 0.0 * bu.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if concrete(self.tau_syn_ex, self.tau_syn_in) and (
            np.any(self.tau_syn_ex <= 0.0 * bu.ms) or np.any(self.tau_syn_in <= 0.0 * bu.ms)
        ):
            raise ValueError('Synaptic time constants must be strictly positive.')
        if concrete(self.tau_psc, self.tau_rec) and (
            np.any(self.tau_psc <= 0.0 * bu.ms) or np.any(self.tau_rec <= 0.0 * bu.ms)
        ):
            raise ValueError('Tsodyks time constants tau_psc and tau_rec must be strictly positive.')
        if concrete(self.tau_fac) and np.any(self.tau_fac < 0.0 * bu.ms):
            raise ValueError("'tau_fac' must be >= 0.")
        if concrete(self.t_ref) and np.any(self.t_ref < 0.0 * bu.ms):
            raise ValueError('Refractory time must not be negative.')
        if concrete(self.U) and (np.any(self.U < 0.0) or np.any(self.U > 1.0)):
            raise ValueError("'U' must be in [0,1].")
        if concrete(self.rho) and np.any(self.rho < 0.0 * (1 / bu.second)):
            raise ValueError('Stochastic firing intensity rho must not be negative.')
        if concrete(self.delta) and np.any(self.delta < 0.0 * bu.mV):
            raise ValueError('Threshold width delta must not be negative.')

        if concrete(self.x_init, self.y_init) and np.any(self.x_init + self.y_init > 1.0):
            raise ValueError('x + y must be <= 1.0.')
        if concrete(self.u_init) and (np.any(self.u_init < 0.0) or np.any(self.u_init > 1.0)):
            raise ValueError("'u' must be in [0,1].")

[docs] def init_state(self, **kwargs): r"""Initialize all neuron state variables. Creates and allocates state variables for membrane potential, synaptic currents, refractory counter, Tsodyks-Markram plasticity states, and buffered inputs. All states are allocated as ``brainstate.HiddenState`` or ``brainstate.ShortTermState`` with shape ``self.varshape``. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Notes ----- State variables created: - ``V`` : :class:`~brainstate.HiddenState` (mV) Membrane potential, initialized via ``self.V_initializer``. - ``i_syn_ex`` : :class:`~brainstate.ShortTermState` (pA) Excitatory synaptic current, initialized to zero. - ``i_syn_in`` : :class:`~brainstate.ShortTermState` (pA) Inhibitory synaptic current, initialized to zero. - ``i_0`` : :class:`~brainstate.ShortTermState` (pA) Buffered receptor-0 current input, initialized to zero. - ``i_1`` : :class:`~brainstate.ShortTermState` (pA) Buffered receptor-1 current input, initialized to zero. - ``refractory_step_count`` : :class:`~brainstate.ShortTermState` (int32) Remaining refractory steps, initialized to zero. - ``last_spike_time`` : :class:`~brainstate.ShortTermState` (ms) Time of last emitted spike, initialized to ``-1e7 * bu.ms`` (no prior spike). - ``x`` : :class:`~brainstate.ShortTermState` (dimensionless) Readily-releasable resource fraction, initialized to ``self.x_init``. - ``y`` : :class:`~brainstate.ShortTermState` (dimensionless) Cleft/active fraction, initialized to ``self.y_init``. - ``u`` : :class:`~brainstate.ShortTermState` (dimensionless) Release probability, initialized to ``self.u_init``. - ``spike_offset`` : :class:`~brainstate.ShortTermState` (dimensionless) Per-spike :math:`\Delta y` signal for receptor-1 coupling, initialized to zero. - ``refractory`` : :class:`~brainstate.ShortTermState` (bool), optional Boolean refractory flag, allocated only if ``ref_var=True``. 
""" ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() V = braintools.init.param(self.V_initializer, self.varshape) self.V = brainstate.HiddenState(V) self.i_syn_ex = brainstate.ShortTermState(bu.math.zeros(self.varshape, dtype=dftype) * bu.pA) self.i_syn_in = brainstate.ShortTermState(bu.math.zeros(self.varshape, dtype=dftype) * bu.pA) self.i_0 = brainstate.ShortTermState(bu.math.zeros(self.varshape, dtype=dftype) * bu.pA) self.i_1 = brainstate.ShortTermState(bu.math.zeros(self.varshape, dtype=dftype) * bu.pA) self.refractory_step_count = brainstate.ShortTermState(bu.math.full(self.varshape, 0, dtype=ditype)) self.last_spike_time = brainstate.ShortTermState(bu.math.full(self.varshape, -1e7 * bu.ms)) self.x = brainstate.ShortTermState( bu.math.asarray(np.broadcast_to(np.asarray(self.x_init, dtype=dftype), self.varshape).copy()) ) self.y = brainstate.ShortTermState( bu.math.asarray(np.broadcast_to(np.asarray(self.y_init, dtype=dftype), self.varshape).copy()) ) self.u = brainstate.ShortTermState( bu.math.asarray(np.broadcast_to(np.asarray(self.u_init, dtype=dftype), self.varshape).copy()) ) self.spike_offset = brainstate.ShortTermState(bu.math.zeros(self.varshape, dtype=dftype)) if self.ref_var: refractory = braintools.init.param(braintools.init.Constant(False), self.varshape) self.refractory = brainstate.ShortTermState(refractory) # Pre-compute propagator coefficients (constant for a given dt). self._precompute_propagators()
def _precompute_propagators(self): """Pre-compute NEST propagator coefficients and cached parameters from dt.""" dt_q = brainstate.environ.get_dt() dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() h = float(bu.math.asarray(dt_q / bu.ms)) tau_ex_np = np.asarray(bu.math.asarray(self.tau_syn_ex / bu.ms), dtype=dftype) tau_in_np = np.asarray(bu.math.asarray(self.tau_syn_in / bu.ms), dtype=dftype) tau_m_np = np.asarray(bu.math.asarray(self.tau_m / bu.ms), dtype=dftype) C_m_np = np.asarray(bu.math.asarray(self.C_m / bu.pF), dtype=dftype) self._P11_ex = jnp.asarray(np.exp(-h / tau_ex_np)) self._P11_in = jnp.asarray(np.exp(-h / tau_in_np)) self._P22 = jnp.asarray(np.exp(-h / tau_m_np)) self._P21_ex = jnp.asarray(propagator_exp(tau_ex_np, tau_m_np, C_m_np, h)) self._P21_in = jnp.asarray(propagator_exp(tau_in_np, tau_m_np, C_m_np, h)) self._P20 = jnp.asarray(tau_m_np / C_m_np * (1.0 - np.exp(-h / tau_m_np))) self._h = h # Stochastic threshold cached values. delta_np = np.asarray(bu.math.asarray(self.delta / bu.mV), dtype=dftype) rho_np = np.asarray(bu.math.asarray(self.rho / (1 / bu.second)), dtype=dftype) self._delta_np = jnp.asarray(delta_np) self._rho_np = jnp.asarray(rho_np) self._deterministic = self._delta_np < 1e-10 self._delta_safe = jnp.where(self._deterministic, 1.0, self._delta_np) # Tsodyks cached parameters (dimensionless ms values). self._tau_fac_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.tau_fac / bu.ms), dtype=dftype)) self._tau_psc_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.tau_psc / bu.ms), dtype=dftype)) self._tau_rec_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.tau_rec / bu.ms), dtype=dftype)) self._U_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.U), dtype=dftype)) # Refractory step count as JAX integer array. self._ref_count_jnp = jnp.asarray( np.asarray(bu.math.asarray(bu.math.ceil(self.t_ref / dt_q)), dtype=ditype) ) # Pre-compute unit-stripped static JAX parameter arrays for JIT-compatible update(). 
self._E_L_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.E_L / bu.mV), dtype=dftype)) self._theta_jnp = jnp.asarray(np.asarray(bu.math.asarray((self.V_th - self.E_L) / bu.mV), dtype=dftype)) self._V_reset_rel_jnp = jnp.asarray(np.asarray(bu.math.asarray((self.V_reset - self.E_L) / bu.mV), dtype=dftype)) self._I_e_jnp = jnp.asarray(np.asarray(bu.math.asarray(self.I_e / bu.pA), dtype=dftype))
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute surrogate spike output given membrane potential. Applies the surrogate spike function (``self.spk_fun``) to a normalized voltage that ranges from 0 (at reset) to 1 (at threshold). This enables differentiable spike computation for gradient-based learning. Parameters ---------- V : ArrayLike or None, optional Membrane potential in mV; broadcastable to ``self.varshape``. If ``None``, uses ``self.V.value``. Default is ``None``. Returns ------- out : dict Surrogate spike activation, shape matching the input ``V`` (or ``self.V.value``). The output is typically in ``[0, 1]`` for sub-threshold voltages and close to 1 for supra-threshold voltages, depending on the surrogate function used. Notes ----- Voltage normalization: .. math:: v_{\mathrm{scaled}} = \frac{V - V_{th}}{V_{th} - V_{\mathrm{reset}}}. The surrogate function ``self.spk_fun`` (default ``braintools.surrogate.ReluGrad()``) is then applied to ``v_scaled``, providing a differentiable approximation of the Heaviside step function. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _parse_spike_events(self, spike_events: Iterable, state_shape):
    r"""Parse explicit spike events into excitatory and inhibitory weights.

    Each descriptor in ``spike_events`` is unpacked (dict or tuple form, see
    :meth:`update`). Its receptor is normalized via
    :meth:`_normalize_spike_receptor`, and its effective weight is added to
    the excitatory channel if positive, otherwise to the inhibitory channel.

    Parameters
    ----------
    spike_events : Iterable or None
        Collection of event descriptors (see :meth:`update` for the format).
        ``None`` yields all-zero weights.
    state_shape : tuple of int
        Target shape every event weight is broadcast to.

    Returns
    -------
    w_ex : np.ndarray
        Summed positive weights in pA, shape ``state_shape``, dtype
        ``brainstate.environ.dftype()``.
    w_in : np.ndarray
        Summed non-positive weights in pA, same shape and dtype.

    Raises
    ------
    ValueError
        If a tuple event does not have length 2, 3, 4 or 5, or if a
        receptor-1 event has ``sender_model != "iaf_tum_2000"``. Invalid
        receptor identifiers are handled by ``_normalize_spike_receptor``.

    Notes
    -----
    Effective weights are:

    - Receptor 0: ``weight * multiplicity`` (``offset`` is ignored).
    - Receptor 1: ``weight * multiplicity * offset``.
    """
    dftype = brainstate.environ.dftype()
    w_ex = np.zeros(state_shape, dtype=dftype)
    w_in = np.zeros(state_shape, dtype=dftype)
    if spike_events is None:
        return w_ex, w_in

    for ev in spike_events:
        sender_model = 'iaf_tum_2000'
        multiplicity = 1.0
        offset = 1.0
        if isinstance(ev, dict):
            receptor = ev.get('receptor_type', ev.get('receptor', 0))
            weight = ev.get('weight', 0.0 * bu.pA)
            offset = ev.get('offset', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            sender_model = ev.get('sender_model', 'iaf_tum_2000')
        else:
            if len(ev) == 2:
                receptor, weight = ev
            elif len(ev) == 3:
                receptor, weight, offset = ev
            elif len(ev) == 4:
                receptor, weight, offset, multiplicity = ev
            elif len(ev) == 5:
                receptor, weight, offset, multiplicity, sender_model = ev
            else:
                raise ValueError('Spike event tuples must have length 2, 3, 4, or 5.')

        receptor_id = self._normalize_spike_receptor(receptor)
        s = np.broadcast_to(np.asarray(bu.math.asarray(weight / bu.pA), dtype=dftype), state_shape)
        s = s * np.broadcast_to(np.asarray(bu.math.asarray(multiplicity), dtype=dftype), state_shape)
        if receptor_id == 1:
            if sender_model != 'iaf_tum_2000':
                raise ValueError(
                    'For receptor_type 1 in iaf_tum_2000, pre-synaptic neuron must also be of type iaf_tum_2000.'
                )
            # Tsodyks coupling: scale by the presynaptic spike_offset (delta y).
            s = s * np.broadcast_to(np.asarray(bu.math.asarray(offset), dtype=dftype), state_shape)

        w_ex += np.where(s > 0.0, s, 0.0)
        w_in += np.where(s > 0.0, 0.0, s)
    return w_ex, w_in


def _parse_registered_spike_inputs(self, state_shape):
    r"""Collect registered delta inputs into excitatory and inhibitory weights.

    Reads the inputs registered via :meth:`add_delta_input` (stored in
    ``self.delta_inputs``) and adds each value to the excitatory channel
    where it is positive, otherwise to the inhibitory channel. The
    arithmetic uses ``jax.numpy``, so traced input values (inside
    ``jit``/``for_loop``) are supported.

    Parameters
    ----------
    state_shape : tuple of int
        Target shape every input value is broadcast to.

    Returns
    -------
    w_ex : jax.Array
        Summed positive weights in pA, shape ``state_shape``, dtype
        ``brainstate.environ.dftype()``.
    w_in : jax.Array
        Summed non-positive weights in pA, same shape and dtype.

    Notes
    -----
    - A key may carry a receptor label prefix, e.g. ``'TSODYKS // proj_0'``.
      The label is passed to :meth:`_normalize_spike_receptor`, but both
      receptors are routed the same way here, by sign only.
    - Callable values are invoked on every call and stay registered.
    - Non-callable values are used once and then removed from
      ``self.delta_inputs``.
    """
    dftype = brainstate.environ.dftype()
    w_ex = jnp.zeros(state_shape, dtype=dftype)
    w_in = jnp.zeros(state_shape, dtype=dftype)
    if self.delta_inputs is None:
        return w_ex, w_in

    for key in tuple(self.delta_inputs.keys()):
        val = self.delta_inputs[key]
        if callable(val):
            # Persistent input: re-evaluated every step, kept registered.
            val = val()
        else:
            # One-shot input: consume it.
            self.delta_inputs.pop(key)

        if ' // ' in key:
            label = key.split(' // ', maxsplit=1)[0]
            # Normalize the label (as the receptor lookup does elsewhere);
            # routing below is sign-based for every receptor.
            self._normalize_spike_receptor(label)

        # jnp (not np) so that traced values do not fail array conversion.
        s = jnp.broadcast_to(jnp.asarray(bu.math.asarray(val / bu.pA), dtype=dftype), state_shape)
        w_ex = w_ex + jnp.where(s > 0.0, s, 0.0)
        w_in = w_in + jnp.where(s > 0.0, 0.0, s)
    return w_ex, w_in
def update(self, x=0. * bu.pA, x_filtered=0. * bu.pA, spike_events=None, _w_ex_jnp=None, _w_in_jnp=None):
    r"""Advance the neuron state by one simulation step.

    Performs one integration step in NEST ``iaf_tum_2000`` order:

    1. membrane propagation (if not refractory),
    2. synaptic current decay,
    3. filtered-current injection,
    4. spike input addition,
    5. threshold test,
    6. Tsodyks state update on spike emission,
    7. input buffering for the next step.

    Parameters
    ----------
    x : ArrayLike, optional
        Current input in pA for receptor 0 (standard current port). It is
        combined with registered current inputs via ``sum_current_inputs``
        and buffered, so it acts in the next step (NEST ring-buffer
        semantics). Scalar or array broadcastable to ``self.varshape``.
        Default is ``0. * bu.pA``.
    x_filtered : ArrayLike, optional
        Current input in pA for receptor 1. It is buffered to ``self.i_1``
        and injected through excitatory exponential filtering at the next
        step via ``(1 - P11_ex) * i_1``. Scalar or array broadcastable to
        ``self.varshape``. Default is ``0. * bu.pA``.
    spike_events : Iterable or None, optional
        Collection of spike event descriptors for direct spike input. Each
        event can be:

        - ``(receptor_type, weight)``
        - ``(receptor_type, weight, offset)``
        - ``(receptor_type, weight, offset, multiplicity)``
        - ``(receptor_type, weight, offset, multiplicity, sender_model)``
        - ``dict`` with keys ``receptor_type``/``receptor`` (int or str),
          ``weight`` (ArrayLike in pA), ``offset`` (float, default ``1.0``),
          ``multiplicity`` (float, default ``1.0``), ``sender_model``
          (str, default ``"iaf_tum_2000"``).

        **Receptor types:**

        - ``0`` (or ``'DEFAULT'``): regular spike input. The effective
          weight is ``weight * multiplicity``.
        - ``1`` (or ``'TSODYKS'``): Tsodyks-coupled input. The effective
          weight is ``weight * multiplicity * offset``, where ``offset`` is
          typically the ``spike_offset`` of the presynaptic neuron. For this
          receptor ``sender_model`` must be ``"iaf_tum_2000"``; otherwise a
          ``ValueError`` is raised.

        Positive effective weights go to the excitatory channel, non-positive
        ones to the inhibitory channel. Descriptors are parsed in Python, so
        this argument cannot carry events inside a traced/jitted step.
        Default is ``None`` (no events).
    _w_ex_jnp, _w_in_jnp : ArrayLike or None, optional
        Internal fast path: unit-free (pA) excitatory / inhibitory spike
        weights that are already summed. If either is given, ``spike_events``
        and registered delta inputs are *not* read for this step. A missing
        one is treated as zeros. Default is ``None``.

    Returns
    -------
    out : jax.Array
        Surrogate spike output returned by :meth:`get_spike`, elementwise over
        the neuron state shape. For emitted spikes, the voltage passed to
        :meth:`get_spike` is set to ``V_th + 1e-12`` mV to keep a positive
        spike activation despite the hard reset.

    Raises
    ------
    ValueError
        If provided inputs cannot be broadcast to the internal state shape, or
        if receptor-1 events have ``sender_model != "iaf_tum_2000"``.

    Notes
    -----
    **Update order (following NEST ``iaf_tum_2000.cpp``):**

    1. **Membrane propagation**: If not refractory, update
       :math:`V_{\mathrm{rel}}` using exact exponential propagators (same as
       :class:`iaf_psc_exp`); otherwise decrement the refractory counter.
    2. **Synaptic decay**: Exponentially decay ``i_syn_ex`` and ``i_syn_in``.
    3. **Filtered current injection**: Add
       ``(1 - exp(-h/tau_syn_ex)) * i_1`` to ``i_syn_ex``.
    4. **Spike input addition**: Add arriving spike weights (from
       ``spike_events`` and registered delta inputs, or from the ``_w_*_jnp``
       fast path) to ``i_syn_ex`` and ``i_syn_in``.
    5. **Threshold test**: Deterministic (``V_rel >= theta``) where
       ``delta < 1e-10`` mV. Elsewhere escape noise: spike if a uniform draw
       is below :math:`\rho e^{(V_{\mathrm{rel}}-\theta)/\delta} h`. Then
       assign the refractory counter and reset the voltage.
    6. **Tsodyks update**: On an emitted spike, update ``(u, x, y)`` in NEST
       order, compute :math:`\Delta y`, and set ``spike_offset``.
    7. **Buffer inputs**: Store ``x`` and ``x_filtered`` for the next step.

    **Tsodyks state update on spike:**

    When a spike is emitted, the inter-spike interval
    ``h_ts = t_spike - t_last`` is computed, with ``t_spike = t + h`` and
    ``t_last = 0`` if ``last_spike_time < 0``. The propagators are:

    .. math::

        P_{uu} = \begin{cases}
            0, & \tau_{\mathrm{fac}}=0 \\
            e^{-h_{ts}/\tau_{\mathrm{fac}}}, & \text{otherwise}
        \end{cases},
        \quad
        P_{yy} = e^{-h_{ts}/\tau_{\mathrm{psc}}},
        \quad
        P_{zz} = e^{-h_{ts}/\tau_{\mathrm{rec}}} - 1,

    .. math::

        P_{xy} = \frac{P_{zz}\tau_{\mathrm{rec}} - (P_{yy}-1)\tau_{\mathrm{psc}}}
                      {\tau_{\mathrm{psc}}-\tau_{\mathrm{rec}}}.

    The states then update as:

    .. math::

        u \leftarrow u P_{uu}, \quad
        x \leftarrow x + P_{xy}y - P_{zz}(1-x-y), \quad
        y \leftarrow y P_{yy}, \\
        u \leftarrow u + U(1-u), \quad
        \Delta y = u x, \quad
        x \leftarrow x - \Delta y, \quad
        y \leftarrow y + \Delta y.

    ``spike_offset`` is set to :math:`\Delta y` on spike, zero otherwise.

    **Performance:**

    Per-step cost is :math:`O(\prod \mathrm{varshape})`. All state
    arithmetic is elementwise ``jax.numpy`` in the environment float dtype.
    Propagators are precomputed once from ``dt``.
    """
    t = brainstate.environ.get('t')
    h = self._h  # pre-computed Python float, safe under JIT
    t_ms = bu.math.asarray(t / bu.ms)  # JAX scalar, traced under JIT
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # Pre-computed static JAX parameter arrays (no unit stripping per step).
    E_L = self._E_L_jnp
    theta = self._theta_jnp
    V_reset_rel = self._V_reset_rel_jnp
    I_e = self._I_e_jnp
    tau_fac = self._tau_fac_jnp
    tau_psc = self._tau_psc_jnp
    tau_rec = self._tau_rec_jnp
    U = self._U_jnp

    # Pre-computed propagator coefficients.
    P11_ex = self._P11_ex
    P11_in = self._P11_in
    P22 = self._P22
    P21_ex = self._P21_ex
    P21_in = self._P21_in
    P20 = self._P20

    # Read state variables as JAX arrays (unit-stripped).
    V_rel = bu.math.asarray(self.V.value / bu.mV) - E_L
    i_0 = bu.math.asarray(self.i_0.value / bu.pA)
    i_1 = bu.math.asarray(self.i_1.value / bu.pA)
    i_syn_ex = bu.math.asarray(self.i_syn_ex.value / bu.pA)
    i_syn_in = bu.math.asarray(self.i_syn_in.value / bu.pA)
    r = self.refractory_step_count.value
    x_state = self.x.value
    y_state = self.y.value
    u_state = self.u.value
    last_spike_prev = bu.math.asarray(self.last_spike_time.value / bu.ms)

    # Spike event handling: the JIT-compatible path takes pre-computed JAX
    # arrays; the Python path parses spike_events dicts/tuples.
    if _w_ex_jnp is not None or _w_in_jnp is not None:
        zeros = jnp.zeros(self.varshape, dtype=dftype)
        w_ex = _w_ex_jnp if _w_ex_jnp is not None else zeros
        w_in = _w_in_jnp if _w_in_jnp is not None else zeros
    else:
        ev_ex, ev_in = self._parse_spike_events(spike_events, self.varshape)
        reg_ex, reg_in = self._parse_registered_spike_inputs(self.varshape)
        w_ex = jnp.asarray(ev_ex + reg_ex)
        w_in = jnp.asarray(ev_in + reg_in)

    # Buffer next-step inputs (ring-buffer semantics, one-step delay).
    # The `+ jnp.zeros(self.varshape)` broadcasts scalar inputs to varshape.
    i_0_next = bu.math.asarray(self.sum_current_inputs(x, self.V.value) / bu.pA) + jnp.zeros(self.varshape)
    i_1_next = bu.math.asarray(x_filtered / bu.pA) + jnp.zeros(self.varshape)

    # 1. Membrane propagation (skip if refractory).
    not_refractory = r == 0
    V_candidate = V_rel * P22 + i_syn_ex * P21_ex + i_syn_in * P21_in + (I_e + i_0) * P20
    V_rel = jnp.where(not_refractory, V_candidate, V_rel)
    r = jnp.where(not_refractory, r, r - 1)

    # 2. Synaptic decay.
    i_syn_ex = i_syn_ex * P11_ex
    i_syn_in = i_syn_in * P11_in

    # 3. Filtered receptor-1 current injection.
    i_syn_ex = i_syn_ex + (1.0 - P11_ex) * i_1

    # 4. Arriving spike inputs.
    i_syn_ex = i_syn_ex + w_ex
    i_syn_in = i_syn_in + w_in

    # 5. Threshold test (deterministic or escape-noise).
    det_spike = V_rel >= theta
    phi = self._rho_np * jnp.exp((V_rel - theta) / self._delta_safe)
    # Draw from brainstate's keyed RNG: a NumPy draw would be evaluated once at
    # trace time and then reused as a constant in every compiled step.
    noise = brainstate.random.rand(*self.varshape)
    stoch_spike = noise < phi * h * 1e-3  # rho in 1/s, h in ms
    spike_cond = jnp.where(self._deterministic, det_spike, stoch_spike)
    r = jnp.where(spike_cond, self._ref_count_jnp, r)
    V_before_reset = V_rel
    V_rel = jnp.where(spike_cond, V_reset_rel, V_rel)

    # 6. Tsodyks-Markram state update on spike.
    t_last = jnp.where(last_spike_prev < 0.0, 0.0, last_spike_prev)
    t_spike = t_ms + h
    h_tsodyks = t_spike - t_last

    # Avoid dividing by zero when tau_fac == 0 (Puu is forced to 0 there).
    tau_fac_safe = jnp.where(tau_fac == 0.0, 1.0, tau_fac)
    Puu = jnp.where(tau_fac == 0.0, 0.0, jnp.exp(-h_tsodyks / tau_fac_safe))
    Pyy = jnp.exp(-h_tsodyks / tau_psc)
    Pzz = jnp.expm1(-h_tsodyks / tau_rec)
    Pxy = (Pzz * tau_rec - (Pyy - 1.0) * tau_psc) / (tau_psc - tau_rec)

    z_state = 1.0 - x_state - y_state
    u_prop = u_state * Puu
    x_prop = x_state + Pxy * y_state - Pzz * z_state
    y_prop = y_state * Pyy

    u_jump = u_prop + U * (1.0 - u_prop)
    delta_y_tsp = u_jump * x_prop
    x_new = x_prop - delta_y_tsp
    y_new = y_prop + delta_y_tsp

    x_state = jnp.where(spike_cond, x_new, x_state)
    y_state = jnp.where(spike_cond, y_new, y_state)
    u_state = jnp.where(spike_cond, u_jump, u_state)
    spike_offset = jnp.where(spike_cond, delta_y_tsp, 0.0)
    last_spike_next = jnp.where(spike_cond, t_spike, last_spike_prev)

    # 7. Write back state.
    self.V.value = (V_rel + E_L) * bu.mV
    self.i_syn_ex.value = i_syn_ex * bu.pA
    self.i_syn_in.value = i_syn_in * bu.pA
    self.i_0.value = i_0_next * bu.pA
    self.i_1.value = i_1_next * bu.pA
    self.refractory_step_count.value = jnp.asarray(r, dtype=ditype)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_next * bu.ms)
    self.x.value = x_state
    self.y.value = y_state
    self.u.value = u_state
    self.spike_offset.value = spike_offset

    if self.ref_var:
        self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

    # Nudge spiking entries just above threshold so the surrogate output
    # stays positive even though V was hard-reset.
    V_out = jnp.where(spike_cond, theta + E_L + 1e-12, V_before_reset + E_L)
    return self.get_spike(V_out * bu.mV)