Source code for brainpy_state._nest.iaf_cond_alpha

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'iaf_cond_alpha',
]


class iaf_cond_alpha(NESTNeuron):
    r"""Leaky integrate-and-fire model with alpha-shaped conductance synapses.

    Description
    -----------

    ``iaf_cond_alpha`` is a conductance-based leaky integrate-and-fire neuron with

    * hard threshold,
    * fixed absolute refractory period,
    * alpha-shaped excitatory and inhibitory synaptic conductances (second-order kinetics),
    * no adaptation variables.

    This implementation follows NEST ``iaf_cond_alpha`` dynamics and update order,
    using NEST C++ model behavior as the source of truth.

    **1. Membrane Potential and Synaptic Currents**

    The membrane potential evolves according to

    .. math::

       \frac{dV_\mathrm{m}}{dt} =
       \frac{-g_\mathrm{L}(V_\mathrm{m}-E_\mathrm{L})
             - I_\mathrm{syn}
             + I_\mathrm{e}
             + I_\mathrm{stim}}
            {C_\mathrm{m}}

    with

    .. math::

       I_\mathrm{syn}
       = I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}}
       = g_\mathrm{ex}(V_\mathrm{m}-E_\mathrm{ex})
       + g_\mathrm{in}(V_\mathrm{m}-E_\mathrm{in}) .

    **2. Alpha-Shaped Conductance Kinetics**

    Alpha conductances use two coupled state variables per channel:

    .. math::

       \frac{d\,dg_\mathrm{ex}}{dt} = -\frac{dg_\mathrm{ex}}{\tau_{\mathrm{syn,ex}}},
       \qquad
       \frac{d g_\mathrm{ex}}{dt}
       = dg_\mathrm{ex} - \frac{g_\mathrm{ex}}{\tau_{\mathrm{syn,ex}}},

    .. math::

       \frac{d\,dg_\mathrm{in}}{dt} = -\frac{dg_\mathrm{in}}{\tau_{\mathrm{syn,in}}},
       \qquad
       \frac{d g_\mathrm{in}}{dt}
       = dg_\mathrm{in} - \frac{g_\mathrm{in}}{\tau_{\mathrm{syn,in}}}.

    A presynaptic spike with weight :math:`w` causes an instantaneous jump at
    the end of the simulation step. Positive/negative weights map to
    excitatory/inhibitory channels:

    .. math::

       w > 0 \Rightarrow dg_\mathrm{ex} \leftarrow dg_\mathrm{ex} + \frac{e}{\tau_{\mathrm{syn,ex}}} w,

    .. math::

       w < 0 \Rightarrow dg_\mathrm{in} \leftarrow dg_\mathrm{in} + \frac{e}{\tau_{\mathrm{syn,in}}} |w|.

    The normalization factor :math:`e/\tau` ensures the conductance peak matches
    the weight magnitude (in nS).

    **3. Spike Emission and Refractory Mechanism**

    A spike is emitted when :math:`V_\mathrm{m} \ge V_\mathrm{th}` at the end of
    a simulation step. On spike:

    * :math:`V_\mathrm{m}` is reset to :math:`V_\mathrm{reset}`,
    * refractory counter is set to :math:`\lceil t_\mathrm{ref}/dt \rceil`,
    * spike time is recorded as :math:`t + dt`.

    During absolute refractory period:

    * effective membrane potential in current computation is clamped to :math:`V_\mathrm{reset}`,
    * :math:`dV_\mathrm{m}/dt = 0`,
    * conductances continue to decay.

    **4. Numerical Integration and Update Order**

    NEST integrates this model with adaptive RKF45. This implementation mirrors
    that behavior with an RKF45(4,5) integrator and persistent internal step size.
    The discrete-time update order is:

    1. Integrate continuous dynamics on :math:`(t, t+dt]` using RKF45 with adaptive substeps.
    2. Apply refractory countdown / threshold test / reset and spike emission.
    3. Add synaptic conductance jumps from spike inputs arriving this step.
    4. Store external current input as :math:`I_\mathrm{stim}` for the next step.

    The one-step delayed application of current input (``I_stim`` buffer) is
    intentional and matches NEST's ring-buffer update semantics.

    Parameters
    ----------
    in_size : tuple of int or int
        Shape of the neuron population. Can be an integer for 1D populations or
        a tuple for multi-dimensional populations.
    E_L : ArrayLike, optional
        Leak reversal potential :math:`E_\mathrm{L}`. Default: -70 mV.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_\mathrm{m}`. Must be strictly positive.
        Default: 250 pF.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_\mathrm{ref}`. Must be non-negative.
        Default: 2 ms.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_\mathrm{th}`. Must be larger than ``V_reset``.
        Default: -55 mV.
    V_reset : ArrayLike, optional
        Reset potential :math:`V_\mathrm{reset}`. Must be smaller than ``V_th``.
        Default: -60 mV.
    E_ex : ArrayLike, optional
        Excitatory reversal potential :math:`E_\mathrm{ex}`. Default: 0 mV.
    E_in : ArrayLike, optional
        Inhibitory reversal potential :math:`E_\mathrm{in}`. Default: -85 mV.
    g_L : ArrayLike, optional
        Leak conductance :math:`g_\mathrm{L}`. Must be strictly positive.
        Default: 16.6667 nS (yields :math:`\tau_\mathrm{m} = 15` ms with default ``C_m``).
    tau_syn_ex : ArrayLike, optional
        Excitatory alpha time constant :math:`\tau_{\mathrm{syn,ex}}`. Must be
        strictly positive. Default: 0.2 ms.
    tau_syn_in : ArrayLike, optional
        Inhibitory alpha time constant :math:`\tau_{\mathrm{syn,in}}`. Must be
        strictly positive. Default: 2.0 ms.
    I_e : ArrayLike, optional
        Constant external current :math:`I_\mathrm{e}`. Default: 0 pA.
    gsl_error_tol : ArrayLike, optional
        Unitless local RKF45 error tolerance, broadcastable and strictly positive.
        Default: 1e-3.
    V_initializer : Callable, optional
        Initializer for membrane potential. Default: Constant(-70 mV).
    g_ex_initializer : Callable, optional
        Initializer for excitatory conductance. Default: Constant(0 nS).
    g_in_initializer : Callable, optional
        Initializer for inhibitory conductance. Default: Constant(0 nS).
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation (differentiable approximation).
        Default: ReluGrad().
    spk_reset : str, optional
        Spike reset mode. ``'hard'`` uses stop_gradient (matches NEST behavior),
        ``'soft'`` allows gradients through reset. Default: ``'hard'``.
    ref_var : bool, optional
        If True, expose ``refractory`` state variable as boolean indicator.
        Default: False.
    name : str, optional
        Name of the neuron group.

    Parameter Mapping
    -----------------

    ==================== ================== ========================================
    **Parameter**        **Default**        **Math equivalent**
    ==================== ================== ========================================
    ``in_size``          (required)         --
    ``E_L``              -70 mV             :math:`E_\mathrm{L}`
    ``C_m``              250 pF             :math:`C_\mathrm{m}`
    ``t_ref``            2 ms               :math:`t_\mathrm{ref}`
    ``V_th``             -55 mV             :math:`V_\mathrm{th}`
    ``V_reset``          -60 mV             :math:`V_\mathrm{reset}`
    ``E_ex``             0 mV               :math:`E_\mathrm{ex}`
    ``E_in``             -85 mV             :math:`E_\mathrm{in}`
    ``g_L``              16.6667 nS         :math:`g_\mathrm{L}`
    ``tau_syn_ex``       0.2 ms             :math:`\tau_{\mathrm{syn,ex}}`
    ``tau_syn_in``       2.0 ms             :math:`\tau_{\mathrm{syn,in}}`
    ``I_e``              0 pA               :math:`I_\mathrm{e}`
    ``gsl_error_tol``    1e-3               --
    ``V_initializer``    Constant(-70 mV)   --
    ``g_ex_initializer`` Constant(0 nS)     --
    ``g_in_initializer`` Constant(0 nS)     --
    ``spk_fun``          ReluGrad()         --
    ``spk_reset``        ``'hard'``         --
    ``ref_var``          ``False``          --
    ==================== ================== ========================================

    State Variables
    ---------------

    ========================= ================================================================
    **State variable**        **Description**
    ========================= ================================================================
    ``V``                     Membrane potential :math:`V_\mathrm{m}`
    ``dg_ex``                 Excitatory alpha auxiliary state
    ``g_ex``                  Excitatory conductance :math:`g_\mathrm{ex}`
    ``dg_in``                 Inhibitory alpha auxiliary state
    ``g_in``                  Inhibitory conductance :math:`g_\mathrm{in}`
    ``last_spike_time``       Last spike time (recorded at :math:`t+dt`)
    ``refractory_step_count`` Remaining refractory grid steps
    ``integration_step``      Internal RKF45 step-size state (persistent)
    ``I_stim``                Buffered current applied in next step
    ``refractory``            Optional boolean refractory indicator (if ``ref_var=True``)
    ========================= ================================================================

    **Sends:**
    ``SpikeEvent`` (conceptually; represented as returned spike tensor from ``update``).

    **Receives:**

    - Signed spike-weight conductance increments through ``add_delta_input``.
    - External current input through ``x`` in :meth:`update` (one-step delayed).

    Raises
    ------
    ValueError
        If ``V_reset >= V_th``, ``C_m <= 0``, ``t_ref < 0``, or any time constants
        are non-positive.

    Notes
    -----

    - Defaults follow NEST C++ source for ``iaf_cond_alpha`` (``models/iaf_cond_alpha.h/.cpp``).
    - Synaptic spike weights are interpreted in conductance units (nS), with
      positive/negative sign selecting excitatory/inhibitory channel.
    - The alpha shape produces a smoother conductance transient than single exponentials,
      peaking at :math:`t = \tau` after a spike.
    - During the refractory period, the effective voltage used for current computation is
      clamped to ``V_reset`` and the ``V`` state itself is held at ``V_reset``
      (:math:`dV_\mathrm{m}/dt = 0`), while the conductances keep evolving.

    References
    ----------
    .. [1] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. Journal of Computational Neuroscience,
           16:159-175. DOI: https://doi.org/10.1023/B:JCNS.0000014108.03012.81
    .. [2] Bernander O, Douglas RJ, Martin KAC, Koch C (1991). Synaptic
           background activity influences spatiotemporal integration in single
           pyramidal cells. PNAS, 88(24):11569-11573.
           DOI: https://doi.org/10.1073/pnas.88.24.11569
    .. [3] Kuhn A, Rotter S (2004). Neuronal integration of synaptic input in
           the fluctuation-driven regime. Journal of Neuroscience, 24(10):2345-2356.
           DOI: https://doi.org/10.1523/JNEUROSCI.3349-03.2004
    .. [4] NEST Simulator ``iaf_cond_alpha`` model documentation and C++ source:
           ``models/iaf_cond_alpha.h`` and ``models/iaf_cond_alpha.cpp``.

    See Also
    --------
    iaf_cond_exp : Conductance-based LIF with exponential synapses
    iaf_psc_alpha : Current-based LIF with alpha synapses
    iaf_psc_delta : Current-based LIF with delta synapses

    Examples
    --------
    Create a population of 100 conductance-based neurons:

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> neurons = bst.iaf_cond_alpha(
        ...     in_size=100,
        ...     V_th=-50. * u.mV,
        ...     tau_syn_ex=0.5 * u.ms,
        ...     tau_syn_in=2.0 * u.ms
        ... )
    """
    # Report the class under the ``brainpy.state`` namespace instead of the
    # module that actually defines it.
    __module__ = 'brainpy.state'

    # Handed to ``AdaptiveRungeKuttaStep`` as ``min_h`` in ``__init__``
    # (presumably the smallest substep the integrator may take -- confirm
    # against the integrator implementation).
    _MIN_H = 1e-8 * u.ms  # ms
    # Handed to ``AdaptiveRungeKuttaStep`` as ``max_iters`` in ``__init__``
    # (presumably the cap on substep iterations per call -- confirm against
    # the integrator implementation).
    _MAX_ITERS = 100000

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -60. * u.mV,
        E_ex: ArrayLike = 0. * u.mV,
        E_in: ArrayLike = -85. * u.mV,
        g_L: ArrayLike = 16.6667 * u.nS,
        tau_syn_ex: ArrayLike = 0.2 * u.ms,
        tau_syn_in: ArrayLike = 2.0 * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        gsl_error_tol: ArrayLike = 1e-3,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        g_ex_initializer: Callable = braintools.init.Constant(0. * u.nS),
        g_in_initializer: Callable = braintools.init.Constant(0. * u.nS),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        """Store parameters, validate them and build the RKF45 integrator.

        Parameter meanings and defaults are described in the class docstring.
        Every per-neuron parameter is passed through ``braintools.init.param``
        with the population shape; ``gsl_error_tol`` and the initializers are
        stored as given.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Per-neuron parameters, registered in a fixed order.
        per_neuron = {
            'E_L': E_L,
            'C_m': C_m,
            't_ref': t_ref,
            'V_th': V_th,
            'V_reset': V_reset,
            'E_ex': E_ex,
            'E_in': E_in,
            'g_L': g_L,
            'tau_syn_ex': tau_syn_ex,
            'tau_syn_in': tau_syn_in,
            'I_e': I_e,
        }
        for attr_name, raw_value in per_neuron.items():
            setattr(self, attr_name, braintools.init.param(raw_value, self.varshape))
        self.gsl_error_tol = gsl_error_tol

        # Initializers are only evaluated later, in ``init_state``.
        self.V_initializer = V_initializer
        self.g_ex_initializer = g_ex_initializer
        self.g_in_initializer = g_in_initializer
        self.ref_var = ref_var

        # Reject inconsistent parameters before anything is built from them.
        self._validate_parameters()

        step = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=step,
        )

        # Refractory period expressed in whole grid steps: ceil(t_ref / dt).
        self.ref_count = u.math.asarray(
            u.math.ceil(self.t_ref / step),
            dtype=brainstate.environ.ditype(),
        )

    def _validate_parameters(self):
        """Check the parameter constraints of the model.

        Validation is skipped entirely when any of the inspected parameters is
        a JAX tracer (e.g. the model is constructed under ``jit``), because the
        comparisons below need concrete values.

        Raises
        ------
        ValueError
            If ``V_reset >= V_th``, ``C_m <= 0``, ``t_ref < 0``, a synaptic
            time constant is non-positive, or ``gsl_error_tol`` is
            non-positive.
        """
        inspected = (
            self.V_reset,
            self.V_th,
            self.C_m,
            self.t_ref,
            self.tau_syn_ex,
            self.tau_syn_in,
            self.gsl_error_tol,
        )
        # Every value compared below must be concrete; a tracer in any of
        # them (not only V_reset / C_m) would make ``np.any`` fail.
        if any(is_tracer(v) for v in inspected):
            return
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        # A non-positive tolerance can never be satisfied by the adaptive
        # step-size controller, so reject it up front.
        if np.any(np.asarray(self.gsl_error_tol) <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

[docs] def init_state(self, **kwargs): r"""Initialize all state variables. Creates and initializes membrane potential, conductance states, refractory counters, integration step size, and optional refractory indicator. Parameters ---------- **kwargs : dict Additional keyword arguments (unused, for API compatibility). Notes ----- - ``V``, ``g_ex``, ``g_in`` are initialized using their respective initializers. - ``dg_ex``, ``dg_in`` (alpha auxiliary states) are initialized to zero. - ``last_spike_time`` is set to large negative value (-1e7 ms). - ``refractory_step_count`` starts at 0 (not in refractory period). - ``integration_step`` is initialized to the global timestep ``dt``. - ``I_stim`` buffer starts at 0 pA. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() V = braintools.init.param(self.V_initializer, self.varshape) g_ex = braintools.init.param(self.g_ex_initializer, self.varshape) g_in = braintools.init.param(self.g_in_initializer, self.varshape) zeros = u.math.zeros(self.varshape, dtype=V.dtype) * (u.nS / u.ms) self.dg_ex = brainstate.ShortTermState(zeros) self.dg_in = brainstate.ShortTermState(zeros) self.g_ex = brainstate.HiddenState(g_ex) self.g_in = brainstate.HiddenState(g_in) self.V = brainstate.HiddenState(V) self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms)) self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype)) self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape) self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)) if self.ref_var: refractory = braintools.init.param(braintools.init.Constant(False), self.varshape) self.refractory = brainstate.ShortTermState(refractory)
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute differentiable spike output using surrogate gradient. Applies the surrogate spike function to a normalized voltage to produce a continuous approximation of spike events suitable for gradient-based learning. Parameters ---------- V : ArrayLike, optional Membrane potential to evaluate. If None, uses current ``self.V.value``. Shape must match neuron population shape. Returns ------- ArrayLike Spike output in [0, 1], where values close to 1 indicate spike events. Shape matches input voltage shape. Notes ----- The voltage is normalized to :math:`(V - V_\mathrm{th}) / (V_\mathrm{th} - V_\mathrm{reset})` before applying the surrogate function. This makes the surrogate function operate in a standardized range regardless of absolute voltage values. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _vector_field(self, state, extra): """Unit-aware vectorized RHS for all neurons simultaneously. Parameters ---------- state : DotDict Keys: V, dg_ex, g_ex, dg_in, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, i_stim -- mutable auxiliary data carried through the integrator. Returns ------- DotDict with same keys as ``state``, containing time derivatives. """ is_refractory = extra.r > 0 # During refractory: v_eff = V_reset. Otherwise: v_eff = min(V, V_th). v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_th)) i_syn_exc = state.g_ex * (v_eff - self.E_ex) i_syn_inh = state.g_in * (v_eff - self.E_in) i_leak = self.g_L * (v_eff - self.E_L) dV_raw = (-i_leak - i_syn_exc - i_syn_inh + self.I_e + extra.i_stim) / self.C_m dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw) ddg_ex = -state.dg_ex / self.tau_syn_ex dg_ex_dt = state.dg_ex - state.g_ex / self.tau_syn_ex ddg_in = -state.dg_in / self.tau_syn_in dg_in_dt = state.dg_in - state.g_in / self.tau_syn_in return DotDict(V=dV, dg_ex=ddg_ex, g_ex=dg_ex_dt, dg_in=ddg_in, g_in=dg_in_dt) def _event_fn(self, state, extra, accept): """In-loop spike detection, reset, and refractory handling. Parameters ---------- state : DotDict Keys: V, dg_ex, g_ex, dg_in, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, i_stim. accept : array, bool Mask of neurons whose RK substep was accepted. Returns ------- (new_state, new_extra) DotDicts with updated spike/reset/refractory info. """ # Clamp voltage during refractory period. refr_accept = accept & (extra.r > 0) new_V = u.math.where(refr_accept, self.V_reset, state.V) # Spike detection: not refractory and V >= V_th. 
spike_now = accept & (extra.r <= 0) & (new_V >= self.V_th) spike_mask = extra.spike_mask | spike_now new_V = u.math.where(spike_now, self.V_reset, new_V) r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r) new_state = DotDict({**state, 'V': new_V}) new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r}) return new_state, new_extra
[docs] def update(self, x=0. * u.pA): r"""Advance neuron state by one simulation timestep. Integrates ODEs, handles refractory period and spike emission, applies synaptic conductance jumps, and buffers external current for next step. This method implements the full NEST update semantics. Parameters ---------- x : ArrayLike, optional External current input for this timestep (pA). Broadcasted to population shape. This input is buffered and applied in the *next* timestep (one-step delay) to match NEST ring-buffer semantics. Default: 0 pA. Returns ------- ArrayLike Differentiable spike output (values in [0, 1], shape matching population). Computed using surrogate gradient on pre-reset membrane potential. Notes ----- **Update order** (matching NEST): 1. **Integrate ODEs**: Use RKF45 to advance ``V``, ``dg_ex``, ``g_ex``, ``dg_in``, ``g_in`` over ``(t, t+dt]`` with ``I_stim`` from previous step. 2. **Refractory/spike handling**: - If in refractory period: clamp ``V`` to ``V_reset``, decrement counter. - Else if ``V >= V_th``: emit spike, reset ``V`` to ``V_reset``, set refractory counter. 3. **Apply synaptic inputs**: Add conductance jumps from incoming spikes (via ``add_delta_input``) to ``dg_ex`` / ``dg_in`` with alpha normalization. 4. **Buffer current input**: Store ``x`` into ``I_stim`` for next timestep. The surrogate spike is computed from the *pre-reset* voltage to allow gradient flow through spike events during training. **Failure modes**: If integration does not converge within ``_MAX_ITERS`` iterations, the final state may be inaccurate. Reduce global ``dt`` or check for extreme parameter values if this occurs. """ t = brainstate.environ.get('t') dt = brainstate.environ.get_dt() dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() # Read state variables with their natural units. 
V = self.V.value # mV dg_ex = self.dg_ex.value # nS/ms g_ex = self.g_ex.value # nS dg_in = self.dg_in.value # nS/ms g_in = self.g_in.value # nS r = self.refractory_step_count.value # int i_stim = self.I_stim.value # pA h = self.integration_step.value # ms # Current input for next step (one-step delay). new_i_stim = self.sum_current_inputs(x, self.V.value) # pA # Adaptive RKF45 integration via generic integrator. ode_state = DotDict(V=V, dg_ex=dg_ex, g_ex=g_ex, dg_in=dg_in, g_in=g_in) extra = DotDict( spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_), r=r, i_stim=i_stim, ) ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra) V, dg_ex, g_ex = ode_state.V, ode_state.dg_ex, ode_state.g_ex dg_in, g_in = ode_state.dg_in, ode_state.g_in spike_mask, r = extra.spike_mask, extra.r # Decrement refractory counter. r = u.math.where(r > 0, r - 1, r) # Synaptic spike inputs (applied after integration). w_ex = self.sum_delta_inputs(u.math.zeros_like(self.g_ex.value), label='w_ex') w_in = self.sum_delta_inputs(u.math.zeros_like(self.g_in.value), label='w_in') pscon_ex = np.e / self.tau_syn_ex # 1/ms pscon_in = np.e / self.tau_syn_in # 1/ms # Apply synaptic spike inputs. dg_ex = dg_ex + pscon_ex * w_ex # nS/ms + 1/ms * nS = nS/ms dg_in = dg_in + pscon_in * w_in # nS/ms + 1/ms * nS = nS/ms # Write back state. self.V.value = V self.dg_ex.value = dg_ex self.g_ex.value = g_ex self.dg_in.value = dg_in self.g_in.value = g_in self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype) self.integration_step.value = h self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value) self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time) if self.ref_var: self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0) return u.math.asarray(spike_mask, dtype=dftype)