# Source code for brainpy_state._nest.iaf_cond_beta

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'iaf_cond_beta',
]


class iaf_cond_beta(NESTNeuron):
    r"""NEST-compatible conductance-based leaky integrate-and-fire neuron with beta-shaped synaptic conductances.

    This model implements a conductance-based LIF neuron with beta-function (dual-exponential)
    synaptic conductances for both excitatory and inhibitory channels. It follows NEST's
    ``iaf_cond_beta`` implementation, including hard threshold crossing, absolute refractory
    period, and one-step delayed external current buffering.

    **1. Mathematical Model**

    The membrane voltage :math:`V_m` evolves according to:

    .. math::

       C_m \frac{dV_m}{dt} = -I_\mathrm{leak} - I_{\mathrm{syn,ex}} - I_{\mathrm{syn,in}} + I_e + I_\mathrm{stim}

    where the currents are defined as:

    .. math::

       I_\mathrm{leak} &= g_L (V_m - E_L) \\
       I_{\mathrm{syn,ex}} &= g_\mathrm{ex}(t) (V_m - E_\mathrm{ex}) \\
       I_{\mathrm{syn,in}} &= g_\mathrm{in}(t) (V_m - E_\mathrm{in})

    During the refractory period, the membrane voltage is clamped to :math:`V_\mathrm{reset}`
    and :math:`dV_m/dt = 0`. Outside the refractory period, the effective voltage for
    synaptic current computation is bounded by :math:`\min(V_m, V_\mathrm{th})`.

    **2. Beta-Function Conductance Dynamics**

    Each synaptic conductance (excitatory and inhibitory) is modeled using two coupled
    state variables to produce a beta-function (rise-decay) waveform:

    .. math::

       \frac{d\,dg_\mathrm{ex}}{dt} &= -\frac{dg_\mathrm{ex}}{\tau_{\mathrm{decay,ex}}} \\
       \frac{d g_\mathrm{ex}}{dt} &= dg_\mathrm{ex} - \frac{g_\mathrm{ex}}{\tau_{\mathrm{rise,ex}}}

    .. math::

       \frac{d\,dg_\mathrm{in}}{dt} &= -\frac{dg_\mathrm{in}}{\tau_{\mathrm{decay,in}}} \\
       \frac{d g_\mathrm{in}}{dt} &= dg_\mathrm{in} - \frac{g_\mathrm{in}}{\tau_{\mathrm{rise,in}}}

    Incoming spikes cause instantaneous jumps in :math:`dg_\mathrm{ex}` or :math:`dg_\mathrm{in}`.
    Positive weights target the excitatory channel; negative weights target the inhibitory channel.
    Each spike weight (in nS) is multiplied by the beta normalization factor
    :math:`\kappa(\tau_\mathrm{rise}, \tau_\mathrm{decay})` to ensure unit weight produces
    a 1 nS peak conductance.

    The normalization factor is computed as:

    .. math::

       \kappa = \frac{1/\tau_\mathrm{rise} - 1/\tau_\mathrm{decay}}{\exp(-t_\mathrm{peak}/\tau_\mathrm{decay}) - \exp(-t_\mathrm{peak}/\tau_\mathrm{rise})}

    where :math:`t_\mathrm{peak} = \frac{\tau_\mathrm{rise} \tau_\mathrm{decay}}{\tau_\mathrm{decay} - \tau_\mathrm{rise}} \ln\left(\frac{\tau_\mathrm{decay}}{\tau_\mathrm{rise}}\right)`.

    **3. Numerical Integration**

    ODEs are integrated using the Runge-Kutta-Fehlberg (RKF45) adaptive-step method with
    embedded error control. The integrator maintains a persistent step size estimate
    (``integration_step``) across simulation steps, adjusting it based on local truncation
    error to satisfy a fixed absolute tolerance (``gsl_error_tol``).

    **4. Update Order (NEST Semantics)**

    Each simulation step executes the following operations in order:

    1. Integrate all ODEs on the interval :math:`(t, t+dt]` using RKF45.
    2. Inside integration loop: apply refractory clamp and spike/reset.
    3. After loop: decrement refractory counter once.
    4. Apply incoming spike weights to :math:`dg_\mathrm{ex}` and :math:`dg_\mathrm{in}`.
    5. Store external current input ``x`` into the delayed buffer ``I_stim`` (affects next step).

    This matches NEST's ring-buffer semantics: external currents applied at time :math:`t`
    take effect at time :math:`t + dt`.

    **5. Design Constraints and Assumptions**

    - **Refractory clamping**: During refractory period, voltage is fixed at :math:`V_\mathrm{reset}`
      and no integration occurs. NEST uses this approach for consistency with exact spike times.
    - **Beta normalization edge case**: When :math:`\tau_\mathrm{rise} \approx \tau_\mathrm{decay}`,
      the normalization factor approaches :math:`e / \tau_\mathrm{decay}` to avoid division by zero.

    Parameters
    ----------
    in_size : Size
        Population shape, specified as an integer (1D), tuple of integers (multi-dimensional),
        or brainstate Size object. Determines the shape of all state variables and parameters.
    E_L : ArrayLike, optional
        Leak reversal potential. Default: ``-70 mV``. Broadcast to ``in_size`` if scalar.
        Must have units of voltage (mV).
    C_m : ArrayLike, optional
        Membrane capacitance. Default: ``250 pF``. Broadcast to ``in_size`` if scalar.
        Must be strictly positive. Determines voltage response timescale :math:`\tau_m = C_m / g_L`.
    t_ref : ArrayLike, optional
        Absolute refractory period duration. Default: ``2 ms``. Broadcast to ``in_size`` if scalar.
        Must be non-negative. Converted to discrete grid steps via :math:`\lceil t_\mathrm{ref} / dt \rceil`.
    V_th : ArrayLike, optional
        Spike threshold voltage. Default: ``-55 mV``. Broadcast to ``in_size`` if scalar.
        Must satisfy :math:`V_\mathrm{reset} < V_\mathrm{th}`.
    V_reset : ArrayLike, optional
        Post-spike reset voltage. Default: ``-60 mV``. Broadcast to ``in_size`` if scalar.
        Must be strictly less than ``V_th``. Neuron is clamped to this value during refractory period.
    E_ex : ArrayLike, optional
        Excitatory reversal potential. Default: ``0 mV``. Broadcast to ``in_size`` if scalar.
        Typically positive (depolarizing).
    E_in : ArrayLike, optional
        Inhibitory reversal potential. Default: ``-85 mV``. Broadcast to ``in_size`` if scalar.
        Typically more negative than :math:`E_L` (hyperpolarizing).
    g_L : ArrayLike, optional
        Leak conductance. Default: ``16.6667 nS`` (yields :math:`\tau_m = 15` ms with default :math:`C_m`).
        Broadcast to ``in_size`` if scalar. Must be strictly positive.
    tau_rise_ex : ArrayLike, optional
        Excitatory conductance rise time constant. Default: ``0.2 ms``. Broadcast to ``in_size`` if scalar.
        Must be strictly positive. Smaller values produce faster rise times.
    tau_decay_ex : ArrayLike, optional
        Excitatory conductance decay time constant. Default: ``0.2 ms``. Broadcast to ``in_size`` if scalar.
        Must be strictly positive. When equal to ``tau_rise_ex``, beta function degenerates to alpha function.
    tau_rise_in : ArrayLike, optional
        Inhibitory conductance rise time constant. Default: ``2.0 ms``. Broadcast to ``in_size`` if scalar.
        Must be strictly positive. Typically slower than excitatory rise for GABA receptors.
    tau_decay_in : ArrayLike, optional
        Inhibitory conductance decay time constant. Default: ``2.0 ms``. Broadcast to ``in_size`` if scalar.
        Must be strictly positive. Determines inhibitory synaptic integration window.
    I_e : ArrayLike, optional
        Constant external current. Default: ``0 pA``. Broadcast to ``in_size`` if scalar.
        Added to membrane current at every time step.
    gsl_error_tol : ArrayLike, optional
        Unitless local RKF45 error tolerance, broadcastable and strictly positive.
        Default: ``1e-6``.
    V_initializer : Callable, optional
        Initialization function for membrane voltage. Default: ``Constant(-70 mV)``.
        Called as ``V_initializer(varshape)`` during ``init_state()``.
    g_ex_initializer : Callable, optional
        Initialization function for excitatory conductance. Default: ``Constant(0 nS)``.
        Called as ``g_ex_initializer(varshape)`` during ``init_state()``.
    g_in_initializer : Callable, optional
        Initialization function for inhibitory conductance. Default: ``Constant(0 nS)``.
        Called as ``g_in_initializer(varshape)`` during ``init_state()``.
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation. Default: ``ReluGrad()``.
        Applied to scaled voltage :math:`(V - V_\mathrm{th}) / (V_\mathrm{th} - V_\mathrm{reset})`
        to produce differentiable spike output for gradient-based learning.
    spk_reset : str, optional
        Spike reset mode. Default: ``'hard'`` (stop-gradient reset, matches NEST behavior).
        Alternative: ``'soft'`` (subtractive reset :math:`V \leftarrow V - V_\mathrm{th}`).
    ref_var : bool, optional
        If ``True``, create a boolean ``refractory`` state variable indicating refractory status.
        Default: ``False``. Useful for monitoring or conditional computations.
    name : str, optional
        Module name for debugging and visualization. Default: ``None`` (auto-generated).

    Parameter Mapping
    -----------------

    The following table maps constructor parameters to mathematical notation and NEST equivalents:

    ==================== ================== ======================================== ================================================
    **Parameter**        **Default**        **Math equivalent**                      **Description**
    ==================== ================== ======================================== ================================================
    ``in_size``          (required)         —                                        Population shape
    ``E_L``              -70 mV             :math:`E_\mathrm{L}`                     Leak reversal potential
    ``C_m``              250 pF             :math:`C_\mathrm{m}`                     Membrane capacitance
    ``t_ref``            2 ms               :math:`t_\mathrm{ref}`                   Absolute refractory duration
    ``V_th``             -55 mV             :math:`V_\mathrm{th}`                    Spike threshold
    ``V_reset``          -60 mV             :math:`V_\mathrm{reset}`                 Reset potential
    ``E_ex``             0 mV               :math:`E_\mathrm{ex}`                    Excitatory reversal potential
    ``E_in``             -85 mV             :math:`E_\mathrm{in}`                    Inhibitory reversal potential
    ``g_L``              16.6667 nS         :math:`g_\mathrm{L}`                     Leak conductance
    ``tau_rise_ex``      0.2 ms             :math:`\tau_{\mathrm{rise,ex}}`          Excitatory beta rise constant
    ``tau_decay_ex``     0.2 ms             :math:`\tau_{\mathrm{decay,ex}}`         Excitatory beta decay constant
    ``tau_rise_in``      2.0 ms             :math:`\tau_{\mathrm{rise,in}}`          Inhibitory beta rise constant
    ``tau_decay_in``     2.0 ms             :math:`\tau_{\mathrm{decay,in}}`         Inhibitory beta decay constant
    ``I_e``              0 pA               :math:`I_\mathrm{e}`                     Constant external current
    ``gsl_error_tol``    1e-6               —                                        RKF45 error tolerance
    ``V_initializer``    Constant(-70 mV)   —                                        Membrane initializer
    ``g_ex_initializer`` Constant(0 nS)     —                                        Excitatory conductance initializer
    ``g_in_initializer`` Constant(0 nS)     —                                        Inhibitory conductance initializer
    ``spk_fun``          ReluGrad()         —                                        Surrogate spike function
    ``spk_reset``        ``'hard'``         —                                        Reset mode (``'hard'`` matches NEST)
    ``ref_var``          ``False``          —                                        Expose boolean refractory indicator
    ==================== ================== ======================================== ================================================

    Raises
    ------
    ValueError
        If ``V_reset >= V_th`` (reset potential must be below threshold).
    ValueError
        If ``C_m <= 0`` (capacitance must be strictly positive).
    ValueError
        If ``t_ref < 0`` (refractory period cannot be negative).
    ValueError
        If any of ``tau_rise_ex``, ``tau_decay_ex``, ``tau_rise_in``, ``tau_decay_in`` are non-positive.
    ValueError
        If ``gsl_error_tol <= 0`` (error tolerance must be strictly positive).

    Notes
    -----
    **State Variables**


    - ``V`` : brainstate.HiddenState
        Membrane potential :math:`V_m` with shape ``(*in_size,)`` and units mV.
    - ``dg_ex`` : brainstate.ShortTermState
        Excitatory beta auxiliary state (nS/ms).
    - ``g_ex`` : brainstate.HiddenState
        Excitatory synaptic conductance with units nS.
    - ``dg_in`` : brainstate.ShortTermState
        Inhibitory beta auxiliary state (nS/ms).
    - ``g_in`` : brainstate.HiddenState
        Inhibitory synaptic conductance with units nS.
    - ``refractory_step_count`` : brainstate.ShortTermState
        Remaining refractory grid steps (integer, dtype ``int32``). Zero when not refractory.
    - ``integration_step`` : brainstate.ShortTermState
        Persistent RKF45 internal step size with units ms. Adapted automatically for numerical stability.
    - ``I_stim`` : brainstate.ShortTermState
        One-step delayed external current buffer with units pA. Updated after ODE integration.
    - ``last_spike_time`` : brainstate.ShortTermState
        Time of last emitted spike (units ms). Set to :math:`t + dt` when spike occurs.
    - ``refractory`` : brainstate.ShortTermState (optional)
        Boolean refractory indicator. Only created if ``ref_var=True``.

    **Performance Considerations:**

    This model uses per-neuron scalar NumPy integration loops, which are significantly slower
    than vectorized JAX operations. For large populations, consider using ``iaf_cond_exp``
    or ``iaf_cond_alpha`` with vectorized exponential integrators. The RKF45 method is
    primarily intended for high-accuracy validation against NEST rather than production simulations.

    **NEST Compatibility:**

    This implementation matches NEST 3.9+ ``iaf_cond_beta`` semantics, including:

    - Beta normalization factor computation (exact formula match).
    - One-step delayed external current handling.
    - Refractory voltage clamping during integration.
    - Hard threshold crossing and immediate reset.

    Minor differences from NEST:

    - NEST uses GSL's RK integrator; this uses a pure-Python RKF45 implementation.
    - Numerical differences may appear at :math:`O(10^{-6})` due to floating-point rounding.

    Examples
    --------
    **Basic Usage:**

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> import brainstate as bstate
        >>>
        >>> with bstate.environ.context(dt=0.1 * u.ms):
        ...     neuron = bst.iaf_cond_beta(10, V_th=-50*u.mV, V_reset=-65*u.mV)
        ...     neuron.init_all_states()
        ...     # Apply excitatory synaptic input (5 nS conductance jump)
        ...     neuron.add_delta_input('syn_input', 5.0 * u.nS)
        ...     spikes = neuron()
        ...     print(neuron.V.value[:3])  # Membrane voltages of first 3 neurons
        [-70. -70. -70.] mV

    **Comparing Excitatory and Inhibitory Time Constants:**

    .. code-block:: python

        >>> import matplotlib.pyplot as plt
        >>> with bstate.environ.context(dt=0.01 * u.ms):
        ...     fast_ex = bst.iaf_cond_beta(1, tau_rise_ex=0.2*u.ms, tau_decay_ex=2.0*u.ms)
        ...     slow_in = bst.iaf_cond_beta(1, tau_rise_in=2.0*u.ms, tau_decay_in=10.0*u.ms)
        ...     fast_ex.init_all_states()
        ...     slow_in.init_all_states()
        ...     # Single excitatory spike at t=1ms
        ...     fast_ex.add_delta_input('spike', 10.0 * u.nS)
        ...     # Record excitatory conductance
        ...     g_ex_trace = []
        ...     for _ in range(500):
        ...         fast_ex()
        ...         g_ex_trace.append(fast_ex.g_ex.value[0])
        ...     plt.plot(g_ex_trace)
        ...     plt.xlabel('Time (0.01 ms steps)')
        ...     plt.ylabel('g_ex (nS)')
        ...     plt.title('Beta-function conductance waveform')

    **Network with Balanced Excitation and Inhibition:**

    .. code-block:: python

        >>> from brainevent.nn import FixedProb
        >>> exc_neurons = bst.iaf_cond_beta(800, E_L=-70*u.mV, V_th=-50*u.mV)
        >>> inh_neurons = bst.iaf_cond_beta(200, E_L=-70*u.mV, V_th=-50*u.mV)
        >>> exc_neurons.init_all_states()
        >>> inh_neurons.init_all_states()
        >>> # Create projections (placeholder - requires brainevent)
        >>> # exc_proj = FixedProb(exc_neurons, exc_neurons, prob=0.1, weight=2.0*u.nS)
        >>> # inh_proj = FixedProb(inh_neurons, exc_neurons, prob=0.2, weight=-5.0*u.nS)

    See Also
    --------
    iaf_cond_alpha : LIF with alpha-function conductances (single time constant).
    iaf_cond_exp : LIF with exponential conductances (simpler, faster).
    iaf_psc_exp : Current-based LIF (no conductance dynamics).

    References
    ----------
    .. [1] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. Journal of Computational Neuroscience,
           16:159-175. DOI: https://doi.org/10.1023/B:JCNS.0000014108.03012.81
    .. [2] Bernander O, Douglas RJ, Martin KAC, Koch C (1991). Synaptic
           background activity influences spatiotemporal integration in single
           pyramidal cells. PNAS, 88(24):11569-11573.
           DOI: https://doi.org/10.1073/pnas.88.24.11569
    .. [3] Kuhn A, Rotter S (2004). Neuronal integration of synaptic input in
           the fluctuation-driven regime. Journal of Neuroscience, 24(10):2345-2356.
           DOI: https://doi.org/10.1523/JNEUROSCI.3349-03.2004
    .. [4] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling.
           Biological Cybernetics, 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [5] Roth A, van Rossum M (2010). Chapter 6: Modeling synapses.
           In De Schutter, Computational Modeling Methods for Neuroscientists.
    """
    # Advertise the class under the ``brainpy.state`` namespace instead of the
    # module in which it is actually defined.
    __module__ = 'brainpy.state'

    # Smallest step size the adaptive integrator may use (passed as ``min_h``).
    _MIN_H = 1e-8 * u.ms  # ms
    # Iteration cap handed to the adaptive integrator (passed as ``max_iters``).
    _MAX_ITERS = 100000
    # Double-precision machine epsilon; threshold for the degenerate branches of
    # ``_beta_normalization_factor_scalar``.
    _EPS = np.finfo(np.float64).eps

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -60. * u.mV,
        E_ex: ArrayLike = 0. * u.mV,
        E_in: ArrayLike = -85. * u.mV,
        g_L: ArrayLike = 16.6667 * u.nS,
        tau_rise_ex: ArrayLike = 0.2 * u.ms,
        tau_decay_ex: ArrayLike = 0.2 * u.ms,
        tau_rise_in: ArrayLike = 2.0 * u.ms,
        tau_decay_in: ArrayLike = 2.0 * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        gsl_error_tol: ArrayLike = 1e-6,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        g_ex_initializer: Callable = braintools.init.Constant(0. * u.nS),
        g_in_initializer: Callable = braintools.init.Constant(0. * u.nS),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        """Store the model parameters, validate them and build the integrator.

        The meaning, units and defaults of every argument are described in
        the class docstring. Construction reads the simulation time step from
        ``brainstate.environ``, so it must run where ``dt`` is available.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        shape = self.varshape
        as_param = braintools.init.param

        # Membrane, threshold and reversal-potential parameters.
        self.E_L = as_param(E_L, shape)
        self.C_m = as_param(C_m, shape)
        self.t_ref = as_param(t_ref, shape)
        self.V_th = as_param(V_th, shape)
        self.V_reset = as_param(V_reset, shape)
        self.E_ex = as_param(E_ex, shape)
        self.E_in = as_param(E_in, shape)
        self.g_L = as_param(g_L, shape)

        # Synaptic time constants and constant drive.
        self.tau_rise_ex = as_param(tau_rise_ex, shape)
        self.tau_decay_ex = as_param(tau_decay_ex, shape)
        self.tau_rise_in = as_param(tau_rise_in, shape)
        self.tau_decay_in = as_param(tau_decay_in, shape)
        self.I_e = as_param(I_e, shape)

        # The tolerance is kept exactly as given (no shape handling).
        self.gsl_error_tol = gsl_error_tol

        # Initializers are only stored here; they are evaluated in init_state().
        self.V_initializer = V_initializer
        self.g_ex_initializer = g_ex_initializer
        self.g_in_initializer = g_in_initializer
        self.ref_var = ref_var

        # Reject inconsistent parameters before anything is derived from them.
        self._validate_parameters()

        step = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=step,
        )

        # Refractory period expressed in whole simulation steps (rounded up).
        int_dtype = brainstate.environ.ditype()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / step), dtype=int_dtype)

    @classmethod
    def _beta_normalization_factor_scalar(cls, tau_rise: float, tau_decay: float):
        r"""Compute beta normalization factor for scalar time constants.

        Parameters
        ----------
        tau_rise : float
            Rise time constant in ms (unitless scalar).
        tau_decay : float
            Decay time constant in ms (unitless scalar).

        Returns
        -------
        float
            Normalization factor ensuring unit weight produces 1 nS peak conductance.
        """
        tau_difference = tau_decay - tau_rise
        peak_value = 0.0
        if abs(tau_difference) > cls._EPS:
            t_peak = tau_decay * tau_rise * np.log(tau_decay / tau_rise) / tau_difference
            peak_value = np.exp(-t_peak / tau_decay) - np.exp(-t_peak / tau_rise)

        if abs(peak_value) < cls._EPS:
            return np.e / tau_decay

        return (1.0 / tau_rise - 1.0 / tau_decay) / peak_value

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Every check below needs concrete values (``np.any`` forces them), so
        validation is skipped as a whole when *any* of the inspected
        parameters is a JAX tracer, e.g. when the model is built inside
        ``jit``. Checking only a subset of the parameters for tracers would
        let a traced ``V_th`` or time constant reach ``np.any`` and fail with
        a concretization error instead of being skipped.

        Raises
        ------
        ValueError
            If ``V_reset >= V_th``, ``C_m <= 0``, ``t_ref < 0``, any synaptic
            time constant is non-positive, or ``gsl_error_tol <= 0``.
        """
        checked = (
            self.V_reset,
            self.C_m,
            self.V_th,
            self.t_ref,
            self.tau_rise_ex,
            self.tau_decay_ex,
            self.tau_rise_in,
            self.tau_decay_in,
            self.gsl_error_tol,
        )
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in checked):
            return

        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')

        # Excitatory constants are checked before inhibitory ones, rise before decay.
        time_constants = (
            self.tau_rise_ex,
            self.tau_decay_ex,
            self.tau_rise_in,
            self.tau_decay_in,
        )
        for tau in time_constants:
            if np.any(tau <= 0.0 * u.ms):
                raise ValueError('All time constants must be strictly positive.')

        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

[docs] def init_state(self, **kwargs): r"""Initialize persistent and short-term state variables. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Raises ------ ValueError If an initializer cannot be broadcast to requested shape. TypeError If initializer outputs have incompatible units/dtypes for the corresponding state variables. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() g_ex = braintools.init.param(self.g_ex_initializer, self.varshape) g_in = braintools.init.param(self.g_in_initializer, self.varshape) V = braintools.init.param(self.V_initializer, self.varshape) zeros = u.math.zeros(self.varshape, dtype=V.dtype) * (u.nS / u.ms) self.dg_ex = brainstate.ShortTermState(zeros) self.dg_in = brainstate.ShortTermState(zeros) self.g_ex = brainstate.HiddenState(g_ex) self.g_in = brainstate.HiddenState(g_in) self.V = brainstate.HiddenState(V) self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms)) self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype)) self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape) self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype)) if self.ref_var: refractory = braintools.init.param(braintools.init.Constant(False), self.varshape) self.refractory = brainstate.ShortTermState(refractory)
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute differentiable spike output using surrogate gradients. Scales the membrane voltage relative to threshold and reset, then applies the surrogate spike function to produce a continuous spike signal suitable for gradient-based learning. Parameters ---------- V : ArrayLike, optional Membrane voltage to evaluate. If ``None`` (default), uses ``self.V.value``. Must have compatible shape with ``V_th`` and ``V_reset`` (broadcast-compatible). Expected units: mV (or dimensionless if consistent). Returns ------- spike : jax.Array Spike output with same shape as input ``V``. Values depend on ``spk_fun`` but are typically in :math:`[0, 1]` for surrogate gradient functions like ``ReluGrad``. Higher values indicate stronger spike activation. Dtype is ``float32``. Notes ----- The scaling formula is: .. math:: \mathrm{spike} = \mathrm{spk\_fun}\left(\frac{V - V_\mathrm{th}}{V_\mathrm{th} - V_\mathrm{reset}}\right) This normalization ensures that when :math:`V = V_\mathrm{th}`, the scaled input is zero, and when :math:`V = V_\mathrm{reset}`, the scaled input is :math:`-1`. The surrogate function (e.g., ``ReluGrad``) produces a differentiable approximation to the Heaviside step function for backpropagation. This method is called internally by ``update()`` to generate spike outputs, but can also be called manually for custom spike detection logic. """ V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def _vector_field(self, state, extra): """Unit-aware vectorized RHS for all neurons simultaneously. Parameters ---------- state : DotDict Keys: V, dg_ex, g_ex, dg_in, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim -- mutable auxiliary data carried through the integrator. Returns ------- DotDict with same keys as ``state``, containing time derivatives. """ is_refractory = extra.r > 0 v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_th)) i_syn_exc = state.g_ex * (v_eff - self.E_ex) i_syn_inh = state.g_in * (v_eff - self.E_in) i_leak = self.g_L * (v_eff - self.E_L) dV_raw = ( -i_leak - i_syn_exc - i_syn_inh + self.I_e + extra.i_stim ) / self.C_m dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw) ddg_ex = -state.dg_ex / self.tau_decay_ex dg_ex_dt = state.dg_ex - state.g_ex / self.tau_rise_ex ddg_in = -state.dg_in / self.tau_decay_in dg_in_dt = state.dg_in - state.g_in / self.tau_rise_in return DotDict(V=dV, dg_ex=ddg_ex, g_ex=dg_ex_dt, dg_in=ddg_in, g_in=dg_in_dt) def _event_fn(self, state, extra, accept): """In-loop spike detection, reset, and refractory handling. Parameters ---------- state : DotDict Keys: V, dg_ex, g_ex, dg_in, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim. accept : array, bool Mask of neurons whose RK substep was accepted. Returns ------- (new_state, new_extra) DotDicts with updated spike/reset/refractory info. 
""" unstable = extra.unstable | jnp.any( accept & (state.V < -1e3 * u.mV) ) refr_accept = accept & (extra.r > 0) new_V = u.math.where(refr_accept, self.V_reset, state.V) spike_now = accept & (extra.r <= 0) & (new_V >= self.V_th) spike_mask = extra.spike_mask | spike_now new_V = u.math.where(spike_now, self.V_reset, new_V) r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r) new_state = DotDict({**state, 'V': new_V}) new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable}) return new_state, new_extra
[docs] def update(self, x=0. * u.pA): r"""Advance neuron dynamics by one simulation time step. Performs the full NEST-compatible update cycle: ODE integration via RKF45, refractory countdown, threshold detection, spike emission, reset, synaptic input application, and delayed external current buffering. Parameters ---------- x : ArrayLike, optional External current input for the **next** time step (one-step delayed). Default: ``0 pA``. Must have shape compatible with ``(*in_size,)`` (broadcast-compatible). Units: pA (picoamperes). This current is stored in ``I_stim`` and takes effect at time :math:`t + dt`, matching NEST's ring-buffer semantics. Returns ------- spike : jax.Array Binary spike output for the current time step. Shape: ``self.V.value.shape``. Dtype: ``float64``. Values of ``1.0`` indicate at least one internal spike event occurred during the integrated interval :math:`(t, t+dt]`. Notes ----- **Update Order (NEST-compatible):** 1. **ODE Integration**: Integrate all differential equations on :math:`(t, t+dt]` using the Runge-Kutta-Fehlberg (RKF45) adaptive-step method. 2. **Refractory Handling**: Inside integration loop, apply refractory clamp and spike/reset events. 3. **Refractory Decrement**: After loop, decrement refractory counter once. 4. **Synaptic Input Application**: Sum all incoming delta inputs (spike weights), split by sign into excitatory (positive) and inhibitory (negative) channels, multiply by beta normalization factors, and add to ``dg_ex`` and ``dg_in`` states. 5. **External Current Buffering**: Store input ``x`` plus ``sum_current_inputs()`` into ``I_stim`` for use in the **next** time step. **Spike Weight Handling:** - All delta inputs (registered via ``add_delta_input()``) are summed and split by sign. - Positive weights :math:`w_\mathrm{ex} = \max(w, 0)` are multiplied by :math:`\kappa(\tau_{\mathrm{rise,ex}}, \tau_{\mathrm{decay,ex}})` and added to ``dg_ex``. 
- Negative weights :math:`w_\mathrm{in} = \max(-w, 0)` are multiplied by :math:`\kappa(\tau_{\mathrm{rise,in}}, \tau_{\mathrm{decay,in}})` and added to ``dg_in``. - The beta normalization factor ensures unit weight produces a 1 nS peak conductance. See Also -------- init_state : Initialize state variables before calling ``update()``. get_spike : Compute spike output from membrane voltage. add_delta_input : Register synaptic spike inputs. sum_current_inputs : Aggregate external current sources. """ t = brainstate.environ.get('t') dt = brainstate.environ.get_dt() dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() # Read state variables with their natural units. V = self.V.value # mV dg_ex = self.dg_ex.value # nS/ms g_ex = self.g_ex.value # nS dg_in = self.dg_in.value # nS/ms g_in = self.g_in.value # nS r = self.refractory_step_count.value # int i_stim = self.I_stim.value # pA h = self.integration_step.value # ms # Current input for next step (one-step delay). new_i_stim = self.sum_current_inputs(x, self.V.value) # pA # Adaptive RKF45 integration via generic integrator. ode_state = DotDict(V=V, dg_ex=dg_ex, g_ex=g_ex, dg_in=dg_in, g_in=g_in) extra = DotDict( spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_), r=r, unstable=jnp.array(False), i_stim=i_stim, ) ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra) V, dg_ex, g_ex = ode_state.V, ode_state.dg_ex, ode_state.g_ex dg_in, g_in = ode_state.dg_in, ode_state.g_in spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable # Post-loop stability check. brainstate.transform.jit_error_if( jnp.any(unstable), 'Numerical instability in iaf_cond_beta dynamics.' ) # Decrement refractory counter. r = u.math.where(r > 0, r - 1, r) # Synaptic spike inputs (applied after integration). w_ex = self.sum_delta_inputs(u.math.zeros_like(self.g_ex.value), label='w_ex') w_in = self.sum_delta_inputs(u.math.zeros_like(self.g_in.value), label='w_in') # Compute beta normalization factors. 
# Extract unitless tau values for the scalar beta normalization computation. tau_rise_ex_ms = float(u.get_mantissa(self.tau_rise_ex / u.ms)) if np.ndim(self.tau_rise_ex) == 0 else None tau_decay_ex_ms = float(u.get_mantissa(self.tau_decay_ex / u.ms)) if np.ndim(self.tau_decay_ex) == 0 else None tau_rise_in_ms = float(u.get_mantissa(self.tau_rise_in / u.ms)) if np.ndim(self.tau_rise_in) == 0 else None tau_decay_in_ms = float(u.get_mantissa(self.tau_decay_in / u.ms)) if np.ndim(self.tau_decay_in) == 0 else None if tau_rise_ex_ms is not None and tau_decay_ex_ms is not None: pscon_ex = self._beta_normalization_factor_scalar(tau_rise_ex_ms, tau_decay_ex_ms) / u.ms else: # Fallback: use element-wise computation for array taus pscon_ex = np.e / self.tau_decay_ex if tau_rise_in_ms is not None and tau_decay_in_ms is not None: pscon_in = self._beta_normalization_factor_scalar(tau_rise_in_ms, tau_decay_in_ms) / u.ms else: pscon_in = np.e / self.tau_decay_in # Apply synaptic spike inputs. dg_ex = dg_ex + pscon_ex * w_ex # nS/ms + 1/ms * nS = nS/ms dg_in = dg_in + pscon_in * w_in # nS/ms + 1/ms * nS = nS/ms # Write back state. self.V.value = V self.dg_ex.value = dg_ex self.g_ex.value = g_ex self.dg_in.value = dg_in self.g_in.value = g_in self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype) self.integration_step.value = h self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value) self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time) if self.ref_var: self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0) return u.math.asarray(spike_mask, dtype=dftype)