# Source code for brainpy_state._nest.iaf_cond_exp

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

# Names exported by ``from <module> import *``.
__all__ = ['iaf_cond_exp']


class iaf_cond_exp(NESTNeuron):
    r"""Leaky integrate-and-fire model with exponential conductance synapses.

    This is a conductance-based leaky integrate-and-fire neuron with hard threshold,
    fixed absolute refractory period, exponentially decaying excitatory and inhibitory
    synaptic conductances, and no adaptation variables.

    This implementation follows NEST ``iaf_cond_exp`` dynamics and update order,
    using NEST C++ model behavior as the source of truth.

    **1. Membrane Potential and Synaptic Conductances**

    The membrane potential evolves according to

    .. math::

       \frac{dV_\mathrm{m}}{dt} =
       \frac{-g_\mathrm{L}(V_\mathrm{m}-E_\mathrm{L})
             - I_\mathrm{syn}
             + I_\mathrm{e}
             + I_\mathrm{stim}}
            {C_\mathrm{m}}

    with the total synaptic current given by

    .. math::

       I_\mathrm{syn}
       = I_{\mathrm{syn,ex}} + I_{\mathrm{syn,in}}
       = g_\mathrm{ex}(V_\mathrm{m}-E_\mathrm{ex})
       + g_\mathrm{in}(V_\mathrm{m}-E_\mathrm{in}) .

    When evaluating the right-hand side, :math:`V_\mathrm{m}` is capped at
    :math:`V_\mathrm{th}` (and replaced by :math:`V_\mathrm{reset}` while
    refractory).

    Synaptic conductances decay exponentially:

    .. math::

       \frac{dg_\mathrm{ex}}{dt} = -\frac{g_\mathrm{ex}}{\tau_{\mathrm{syn,ex}}},
       \qquad
       \frac{dg_\mathrm{in}}{dt} = -\frac{g_\mathrm{in}}{\tau_{\mathrm{syn,in}}}.

    Presynaptic spike input causes an instantaneous conductance jump at the
    end of the simulation step, after integration:

    .. math::

       g_\mathrm{ex} \leftarrow g_\mathrm{ex} + w_\mathrm{ex},
       \qquad
       g_\mathrm{in} \leftarrow g_\mathrm{in} + w_\mathrm{in},

    Here :math:`w_\mathrm{ex}` and :math:`w_\mathrm{in}` are the summed delta
    inputs registered under the labels ``'w_ex'`` and ``'w_in'``. NEST routes a
    single signed weight by its sign (:math:`w > 0` excitatory, :math:`w < 0`
    inhibitory with :math:`|w|`). Here routing is by label instead, and values
    are added as-is. Both increments are therefore expected to be non-negative
    conductances (nS).

    **2. Spike Emission and Refractory Mechanism**

    Threshold crossings are detected inside the integration loop. After every
    accepted RKF45 substep, a non-refractory neuron with
    :math:`V_\mathrm{m} \ge V_\mathrm{th}` emits a spike. On spike:

    * :math:`V_\mathrm{m}` is reset to :math:`V_\mathrm{reset}`,
    * the neuron stays refractory for the following
      :math:`\lceil t_\mathrm{ref}/dt \rceil` simulation steps (the ceiling is
      taken with a tiny relative tolerance, so exact multiples of ``dt`` are
      not inflated by floating-point round-off),
    * spike time is recorded as :math:`t + dt`.

    If :math:`\lceil t_\mathrm{ref}/dt \rceil = 0`, a neuron may spike more than
    once within a single step. The returned spike output only indicates that at
    least one spike occurred in the step.

    During absolute refractory period:

    * membrane potential is clamped to :math:`V_\mathrm{reset}`,
    * :math:`dV_\mathrm{m}/dt = 0`,
    * conductances continue to decay.

    **3. Numerical Integration and Update Order**

    NEST integrates this model with adaptive RKF45. This implementation mirrors
    that behavior with an RKF45(4,5) integrator and persistent internal step size.
    The discrete-time update order is:

    1. Integrate continuous dynamics on :math:`(t, t+dt]`, applying the
       refractory clamp and threshold/reset events after each accepted substep.
    2. Decrement the refractory counter of refractory neurons.
    3. Add synaptic conductance jumps from spike inputs arriving this step.
    4. Store external current input as :math:`I_\mathrm{stim}` for the next step.

    The one-step delayed application of current input (``I_stim`` buffer) is
    intentional and matches NEST's ring-buffer update semantics.

    Parameters
    ----------
    in_size : int, tuple of int
        Shape of the neuron population. Can be an integer for 1D population or
        tuple for multi-dimensional populations.
    E_L : float, ArrayLike, optional
        Leak reversal potential. Must have unit of voltage (mV).
        Default: -70 mV
    C_m : float, ArrayLike, optional
        Membrane capacitance. Must be strictly positive with unit of capacitance (pF).
        Default: 250 pF
    t_ref : float, ArrayLike, optional
        Absolute refractory period duration. Must be non-negative with unit of time (ms).
        Default: 2 ms
    V_th : float, ArrayLike, optional
        Spike threshold voltage. Must be greater than ``V_reset`` with unit of voltage (mV).
        Default: -55 mV
    V_reset : float, ArrayLike, optional
        Reset potential after spike. Must be less than ``V_th`` with unit of voltage (mV).
        Default: -60 mV
    E_ex : float, ArrayLike, optional
        Excitatory reversal potential. Must have unit of voltage (mV).
        Default: 0 mV
    E_in : float, ArrayLike, optional
        Inhibitory reversal potential. Must have unit of voltage (mV).
        Default: -85 mV
    g_L : float, ArrayLike, optional
        Leak conductance. Must be strictly positive with unit of conductance (nS).
        Default: 16.6667 nS
    tau_syn_ex : float, ArrayLike, optional
        Excitatory synaptic conductance time constant. Must be strictly positive
        with unit of time (ms). Default: 0.2 ms
    tau_syn_in : float, ArrayLike, optional
        Inhibitory synaptic conductance time constant. Must be strictly positive
        with unit of time (ms). Default: 2.0 ms
    I_e : float, ArrayLike, optional
        Constant external input current. Must have unit of current (pA).
        Default: 0 pA
    gsl_error_tol : float, ArrayLike, optional
        Unitless local RKF45 error tolerance, broadcastable and strictly
        positive. Passed to the integrator as its absolute tolerance.
        Default: 1e-3
    V_initializer : Callable, optional
        Initializer function for membrane potential state. Must return values with
        voltage units. Default: ``braintools.init.Constant(-70 * u.mV)``
    g_ex_initializer : Callable, optional
        Initializer function for excitatory conductance state. Must return values
        with conductance units. Default: ``braintools.init.Constant(0 * u.nS)``
    g_in_initializer : Callable, optional
        Initializer function for inhibitory conductance state. Must return values
        with conductance units. Default: ``braintools.init.Constant(0 * u.nS)``
    spk_fun : Callable, optional
        Surrogate gradient function for differentiable spike generation. Must be
        a callable with signature ``(x: ArrayLike) -> ArrayLike``.
        Default: ``braintools.surrogate.ReluGrad()``
    spk_reset : str, optional
        Spike reset mode. Options: ``'hard'`` (gradient blocking, matches NEST),
        ``'soft'`` (gradient-friendly subtraction). Default: ``'hard'``
    ref_var : bool, optional
        If True, expose a boolean ``refractory`` state variable indicating whether
        each neuron is currently in the refractory period. Default: False
    name : str, optional
        Name of the neuron population. If None, an automatic name is generated.


    Parameter Mapping
    -----------------

    ==================== ================== ==========================================
    **Parameter**        **Default**        **Math equivalent**
    ==================== ================== ==========================================
    ``in_size``          (required)         —
    ``E_L``              -70 mV             :math:`E_\mathrm{L}`
    ``C_m``              250 pF             :math:`C_\mathrm{m}`
    ``t_ref``            2 ms               :math:`t_\mathrm{ref}`
    ``V_th``             -55 mV             :math:`V_\mathrm{th}`
    ``V_reset``          -60 mV             :math:`V_\mathrm{reset}`
    ``E_ex``             0 mV               :math:`E_\mathrm{ex}`
    ``E_in``             -85 mV             :math:`E_\mathrm{in}`
    ``g_L``              16.6667 nS         :math:`g_\mathrm{L}`
    ``tau_syn_ex``       0.2 ms             :math:`\tau_{\mathrm{syn,ex}}`
    ``tau_syn_in``       2.0 ms             :math:`\tau_{\mathrm{syn,in}}`
    ``I_e``              0 pA               :math:`I_\mathrm{e}`
    ``gsl_error_tol``    1e-3               —
    ``V_initializer``    Constant(-70 mV)   —
    ``g_ex_initializer`` Constant(0 nS)     —
    ``g_in_initializer`` Constant(0 nS)     —
    ``spk_fun``          ReluGrad()         —
    ``spk_reset``        ``'hard'``         —
    ``ref_var``          ``False``          —
    ==================== ================== ==========================================

    State Variables
    ---------------
    V : brainstate.HiddenState
        Membrane potential :math:`V_\mathrm{m}` in mV, shape ``(*in_size)``.
    g_ex : brainstate.HiddenState
        Excitatory synaptic conductance :math:`g_\mathrm{ex}` in nS,
        shape ``(*in_size)``.
    g_in : brainstate.HiddenState
        Inhibitory synaptic conductance :math:`g_\mathrm{in}` in nS,
        shape ``(*in_size)``.
    last_spike_time : brainstate.ShortTermState
        Last spike emission time in ms, shape ``(*in_size)``.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory time steps (integer, environment default integer
        dtype ``brainstate.environ.ditype()``), shape ``(*in_size)``.
    integration_step : brainstate.ShortTermState
        Internal RKF45 adaptive step size in ms, shape ``(*in_size)``.
    I_stim : brainstate.ShortTermState
        Buffered external current (one-step delayed) in pA, shape ``(*in_size)``.
    refractory : brainstate.ShortTermState, optional
        Boolean refractory state indicator, shape ``(*in_size)``.
        Only present if ``ref_var=True``.

    Raises
    ------
    ValueError
        If ``V_reset >= V_th`` (reset must be below threshold).
    ValueError
        If ``C_m <= 0`` (capacitance must be strictly positive).
    ValueError
        If ``t_ref < 0`` (refractory time cannot be negative).
    ValueError
        If ``tau_syn_ex <= 0`` or ``tau_syn_in <= 0`` (time constants must be positive).
    ValueError
        If ``gsl_error_tol <= 0`` (error tolerance must be strictly positive).

    Notes
    -----
    * Defaults follow NEST C++ source for ``iaf_cond_exp``.
    * In NEST docs, some printed default values may differ from the source for
      specific releases; source code behavior is used here for parity.
    * Synaptic spike inputs are interpreted in conductance units (nS). They are
      routed to the excitatory or inhibitory channel by their label (``'w_ex'``
      or ``'w_in'``), not by the sign of the weight.
    * The RKF45 integrator uses an absolute error tolerance of
      ``gsl_error_tol`` (default 1e-3). It uses a minimum step size of 1e-8 ms
      and at most 10000 iterations per simulation step.
    * Integration may fall back to minimum step size if adaptive control fails,
      potentially degrading accuracy for stiff dynamics.
    * The simulation time step ``dt`` is read in ``__init__``, so the model
      must be constructed inside the ``dt`` context it is simulated with.

    Examples
    --------
    Create a population of 100 conductance-based LIF neurons and simulate it
    with external current input. ``dt`` must be set before construction, and
    ``t`` must be set for every ``update`` call:

    .. code-block:: python

        >>> import brainstate
        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     neurons = bst.iaf_cond_exp(100, V_th=-50 * u.mV, t_ref=5 * u.ms)
        ...     neurons.init_all_states()
        ...     for i in range(1000):
        ...         with brainstate.environ.context(t=i * 0.1 * u.ms):
        ...             spike = neurons.update(x=500 * u.pA)

    References
    ----------
    .. [1] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. Journal of Computational Neuroscience
           16:159-175. DOI: https://doi.org/10.1023/B:JCNS.0000014108.03012.81
    .. [2] NEST Simulator ``iaf_cond_exp`` model documentation and C++ source:
           ``models/iaf_cond_exp.h`` and ``models/iaf_cond_exp.cpp``.

    See Also
    --------
    iaf_psc_delta : Current-based LIF with delta synapses
    iaf_psc_exp : Current-based LIF with exponential synapses
    iaf_cond_alpha : Conductance-based LIF with alpha-function synapses
    """
    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # smallest RKF45 substep the integrator may take
    _MAX_ITERS = 10000  # RKF45 iteration budget per simulation step
    # Relative tolerance applied before ceil(t_ref / dt). Floating-point
    # division can overshoot an exact integer (1.1 / 0.1 == 11.000000000000002),
    # which would otherwise add a spurious extra refractory step. NEST avoids
    # this by rounding t_ref to simulation tics before counting steps.
    _REF_STEPS_RTOL = 1e-6

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -60. * u.mV,
        E_ex: ArrayLike = 0. * u.mV,
        E_in: ArrayLike = -85. * u.mV,
        g_L: ArrayLike = 16.6667 * u.nS,
        tau_syn_ex: ArrayLike = 0.2 * u.ms,
        tau_syn_in: ArrayLike = 2.0 * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        gsl_error_tol: ArrayLike = 1e-3,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        g_ex_initializer: Callable = braintools.init.Constant(0. * u.nS),
        g_in_initializer: Callable = braintools.init.Constant(0. * u.nS),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Broadcast every physical parameter to the population shape.
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.E_ex = braintools.init.param(E_ex, self.varshape)
        self.E_in = braintools.init.param(E_in, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.gsl_error_tol = gsl_error_tol

        self.V_initializer = V_initializer
        self.g_ex_initializer = g_ex_initializer
        self.g_in_initializer = g_in_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # dt is read once here: the integrator and the refractory step count
        # are both fixed for the lifetime of this object.
        dt = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=dt,
        )

        # Number of refractory steps, ceil(t_ref / dt), shrunk by a tiny
        # relative tolerance so round-off cannot push an exact multiple of dt
        # up to the next integer.
        ditype = brainstate.environ.ditype()
        ref_steps = u.math.ceil((self.t_ref / dt) * (1.0 - self._REF_STEPS_RTOL))
        self.ref_count = u.math.asarray(ref_steps, dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Validation is skipped when any checked parameter is a JAX tracer, for
        example when the model is constructed under ``jax.jit``. Concrete
        comparisons are impossible on abstract values.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        checked = (
            self.V_reset,
            self.V_th,
            self.C_m,
            self.t_ref,
            self.tau_syn_ex,
            self.tau_syn_in,
            self.gsl_error_tol,
        )
        if any(is_tracer(v) for v in checked):
            return
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All synaptic time constants must be strictly positive.')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

    def init_state(self, **kwargs):
        r"""Initialize all state variables for the neuron population.

        Creates and registers state variables for membrane potential, synaptic
        conductances, refractory tracking, RKF45 integration, and buffered
        currents. ``V``, ``g_ex`` and ``g_in`` are initialized with the
        configured initializer functions.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        * State variables are registered as ``brainstate.HiddenState``
          (continuous dynamics) or ``brainstate.ShortTermState``
          (discrete/reset behavior).
        * ``last_spike_time`` is initialized to -1e7 ms (far past) to indicate
          no prior spikes.
        * ``integration_step`` is initialized to the simulation timestep ``dt``.
        * If ``ref_var=True``, an additional boolean ``refractory`` state is
          created.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        V = braintools.init.param(self.V_initializer, self.varshape)
        g_ex = braintools.init.param(self.g_ex_initializer, self.varshape)
        g_in = braintools.init.param(self.g_in_initializer, self.varshape)
        self.V = brainstate.HiddenState(V)
        self.g_ex = brainstate.HiddenState(g_ex)
        self.g_in = brainstate.HiddenState(g_in)

        self.last_spike_time = brainstate.ShortTermState(
            u.math.full(self.varshape, -1e7 * u.ms)
        )
        self.refractory_step_count = brainstate.ShortTermState(
            u.math.full(self.varshape, 0, dtype=ditype)
        )
        # The RKF45 step size persists across simulation steps, starting at dt.
        self.integration_step = brainstate.ShortTermState.init(
            braintools.init.Constant(dt), self.varshape
        )
        self.I_stim = brainstate.ShortTermState(
            u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)
        )

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: ArrayLike = None):
        r"""Compute differentiable spike output using surrogate gradient.

        Transforms membrane potential into a continuous spike signal suitable
        for gradient-based learning. The configured surrogate gradient function
        (``spk_fun``) is applied to the normalized voltage distance from
        threshold.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential values to evaluate (with voltage units). If None,
            uses current ``self.V.value``. Default: None

        Returns
        -------
        ArrayLike
            Output of ``spk_fun`` with the same shape as ``V``.

        Notes
        -----
        * Voltage is normalized as ``(V - V_th) / (V_th - V_reset)`` before
          applying the surrogate function.
        * The normalization ensures consistent surrogate behavior across
          different threshold/reset voltage configurations.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, g_ex, g_in — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim — auxiliary data carried
            through the integrator.

        Returns
        -------
        DotDict
            Same keys as ``state``, containing time derivatives.
        """
        is_refractory = extra.r > 0
        # Refractory neurons see V_reset; others are capped at V_th so that
        # an overshoot between event checks does not inflate the currents.
        v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_th))

        i_syn_exc = state.g_ex * (v_eff - self.E_ex)
        i_syn_inh = state.g_in * (v_eff - self.E_in)
        i_leak = self.g_L * (v_eff - self.E_L)

        dV_raw = (-i_leak - i_syn_exc - i_syn_inh + self.I_e + extra.i_stim) / self.C_m
        # The membrane is clamped while refractory.
        dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw)
        # Conductances always decay, refractory or not.
        dg_ex = -state.g_ex / self.tau_syn_ex
        dg_in = -state.g_in / self.tau_syn_in
        return DotDict(V=dV, g_ex=dg_ex, g_in=dg_in)

    def _event_fn(self, state, extra, accept):
        """In-loop spike detection, reset, and refractory handling.

        Parameters
        ----------
        state : DotDict
            Keys: V, g_ex, g_in — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim.
        accept : array, bool
            Mask of neurons whose RK substep was accepted.

        Returns
        -------
        (new_state, new_extra)
            DotDicts with updated spike/reset/refractory info.
        """
        # Flag divergence if any accepted substep drove V below -1 V.
        unstable = extra.unstable | jnp.any(
            accept & (state.V < -1e3 * u.mV)
        )
        # Refractory neurons are held at V_reset.
        refr_accept = accept & (extra.r > 0)
        new_V = u.math.where(refr_accept, self.V_reset, state.V)
        # Non-refractory neurons crossing threshold spike and reset.
        spike_now = accept & (extra.r <= 0) & (new_V >= self.V_th)
        spike_mask = extra.spike_mask | spike_now
        new_V = u.math.where(spike_now, self.V_reset, new_V)
        # +1 compensates for the decrement applied once at the end of update().
        r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r)
        new_state = DotDict({**state, 'V': new_V})
        new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable})
        return new_state, new_extra

    def update(self, x=0. * u.pA):
        r"""Advance neuron dynamics by one simulation timestep.

        Integrates membrane potential and synaptic conductances using adaptive
        RKF45, with spike/reset events inside the loop. It then decrements the
        refractory counter, applies synaptic input increments, and stores
        external current for the next step.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input for the **next** timestep (one-step delay
            buffer). Must have current units (pA). Can be scalar (broadcast to
            all neurons) or array with shape matching ``in_size``.
            Default: 0 pA

        Returns
        -------
        ArrayLike
            Binary spike tensor with the environment float dtype
            (``brainstate.environ.dftype()``) and shape ``self.V.value.shape``.
            A value of ``1.0`` indicates at least one internal spike event
            occurred during the integrated interval :math:`(t, t+dt]`.

        Raises
        ------
        Exception
            Raised through ``brainstate.transform.jit_error_if`` when an
            accepted substep enters the guarded unstable regime
            (``V < -1e3 mV``). This indicates divergent dynamics for the current
            parameter/input regime. The concrete exception type is determined
            by ``brainstate``.

        Notes
        -----
        Requires ``t`` and ``dt`` to be set in ``brainstate.environ``.
        All arithmetic is unit-aware via ``saiunit.math``.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        g_ex = self.g_ex.value  # nS
        g_in = self.g_in.value  # nS
        r = self.refractory_step_count.value  # int
        i_stim = self.I_stim.value  # pA
        h = self.integration_step.value  # ms

        # Current input for next step (one-step delay).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # Adaptive RKF45 integration via generic integrator.
        ode_state = DotDict(V=V, g_ex=g_ex, g_in=g_in)
        extra = DotDict(
            spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
            r=r,
            unstable=jnp.array(False),
            i_stim=i_stim,
        )
        ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
        V, g_ex, g_in = ode_state.V, ode_state.g_ex, ode_state.g_in
        spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable

        # Post-loop stability check.
        brainstate.transform.jit_error_if(
            jnp.any(unstable),
            'Numerical instability in iaf_cond_exp dynamics.'
        )

        # Decrement refractory counter.
        r = u.math.where(r > 0, r - 1, r)

        # Synaptic spike inputs (applied after integration), routed by label.
        w_ex = self.sum_delta_inputs(u.math.zeros_like(self.g_ex.value), label='w_ex')
        w_in = self.sum_delta_inputs(u.math.zeros_like(self.g_in.value), label='w_in')

        # Direct conductance jump for exponential synapses.
        g_ex = g_ex + w_ex
        g_in = g_in + w_in

        # Write back state.
        self.V.value = V
        self.g_ex.value = g_ex
        self.g_in.value = g_in
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        self.integration_step.value = h
        # Broadcast a possibly scalar input to the population shape.
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA

        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        return u.math.asarray(spike_mask, dtype=dftype)