Source code for brainpy_state._nest.hh_cond_exp_traub

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'hh_cond_exp_traub',
]


class hh_cond_exp_traub(NESTNeuron):
    r"""NEST-compatible ``hh_cond_exp_traub`` neuron model.

    Hodgkin-Huxley model for Brette et al. (2007) review, based on Traub and
    Miles (1991) hippocampal pyramidal cell model.

    This is a modified Hodgkin-Huxley neuron model specifically developed for
    the Brette et al. (2007) simulator review, based on a model of hippocampal
    pyramidal cells by Traub and Miles (1991). Key differences from the original
    Traub-Miles model:

    - This is a point neuron, not a compartmental model.
    - Only ``I_Na`` and ``I_K`` ionic currents are included (no calcium dynamics),
      with simplified ``I_K`` dynamics giving three gating variables instead of
      eight.
    - Incoming spikes induce an instantaneous conductance change followed by
      exponential decay (conductance-based synapses), not activation over time.

    Parameters
    ----------
    in_size : int, tuple of int
        Population shape (number of neurons or spatial dimensions).
    E_L : ArrayLike, default -60 mV
        Leak reversal potential. Must be finite.
    C_m : ArrayLike, default 200 pF
        Membrane capacitance. Must be strictly positive.
    g_Na : ArrayLike, default 20000 nS
        Sodium peak conductance. Must be non-negative.
    g_K : ArrayLike, default 6000 nS
        Potassium peak conductance. Must be non-negative.
    g_L : ArrayLike, default 10 nS
        Leak conductance. Must be non-negative.
    E_Na : ArrayLike, default 50 mV
        Sodium reversal potential. Must be finite.
    E_K : ArrayLike, default -90 mV
        Potassium reversal potential. Must be finite.
    V_T : ArrayLike, default -63 mV
        Voltage offset for gating dynamics. Shifts the effective threshold
        to approximately V_T + 30 mV.
    E_ex : ArrayLike, default 0 mV
        Excitatory synaptic reversal potential. Must be finite.
    E_in : ArrayLike, default -80 mV
        Inhibitory synaptic reversal potential. Must be finite.
    t_ref : ArrayLike, default 2 ms
        Duration of refractory period. Must be non-negative. Traub and Miles
        used 3 ms; NEST default is 2 ms.
    tau_syn_ex : ArrayLike, default 5 ms
        Excitatory synaptic time constant. Must be strictly positive.
    tau_syn_in : ArrayLike, default 10 ms
        Inhibitory synaptic time constant. Must be strictly positive.
    I_e : ArrayLike, default 0 pA
        Constant external input current. Can be positive or negative.
    V_m_init : ArrayLike, optional
        Initial membrane potential. If None, defaults to E_L.
    Act_m_init : ArrayLike, optional
        Initial sodium activation gating variable (0 <= m <= 1). If None,
        computed from equilibrium at V_m_init.
    Inact_h_init : ArrayLike, optional
        Initial sodium inactivation gating variable (0 <= h <= 1). If None,
        computed from equilibrium at V_m_init.
    Act_n_init : ArrayLike, optional
        Initial potassium activation gating variable (0 <= n <= 1). If None,
        computed from equilibrium at V_m_init.
    gsl_error_tol : ArrayLike
        Unitless local RKF45 error tolerance, broadcastable and strictly positive.
    spk_fun : Callable, default braintools.surrogate.ReluGrad()
        Surrogate spike function for differentiable spike generation.
    spk_reset : str, default 'hard'
        Reset mode ('hard' or 'soft'). Note: HH models do not reset voltage
        after spikes; this parameter affects gradient computation only.
    name : str, optional
        Name of the neuron population.

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential with shape (\*in_size,) in mV.
    m : brainstate.HiddenState
        Sodium activation gating variable (0 <= m <= 1), shape (\*in_size,).
    h : brainstate.HiddenState
        Sodium inactivation gating variable (0 <= h <= 1), shape (\*in_size,).
    n : brainstate.HiddenState
        Potassium activation gating variable (0 <= n <= 1), shape (\*in_size,).
    g_ex : brainstate.HiddenState
        Excitatory synaptic conductance in nS, shape (\*in_size,).
    g_in : brainstate.HiddenState
        Inhibitory synaptic conductance in nS, shape (\*in_size,).
    I_stim : brainstate.ShortTermState
        Stimulation current buffer in pA, shape (\*in_size,).
    refractory_step_count : brainstate.ShortTermState
        Refractory countdown in grid steps, shape (\*in_size,), dtype int32.
    integration_step : brainstate.ShortTermState
        Persistent RKF45 substep size estimate (ms).
    last_spike_time : brainstate.ShortTermState
        Time of most recent spike in ms, shape (\*in_size,).

    Raises
    ------
    ValueError
        If C_m <= 0, t_ref < 0, tau_syn_ex <= 0, or tau_syn_in <= 0.

    Notes
    -----
    - Unlike IAF models, the HH model does **not** reset the membrane
      potential after a spike. Repolarization occurs naturally through
      the potassium current.
    - During the refractory period, subthreshold dynamics continue to
      evolve freely; only spike emission is suppressed.
    - Synaptic spike weights are interpreted in conductance units (nS).
      Positive weights drive excitatory synapses; negative weights drive
      inhibitory synapses (sign is flipped, i.e. ``g_in += |w|``).
    - The numerical integration uses an adaptive RKF45 (Runge-Kutta-Fehlberg)
      integrator implemented in JAX with unit-aware arithmetic via saiunit.
      This is equivalent to NEST's GSL RKF45 implementation for numerical
      correspondence.

    Mathematical Formulation
    -------------------------

    **1. Membrane and Ionic Current Dynamics**

    The membrane potential evolves as:

    .. math::

       C_m \frac{dV_m}{dt} = -(I_{Na} + I_K + I_L + I_{syn,ex} + I_{syn,in})
                              + I_{stim} + I_e

    where the currents are:

    .. math::

       I_{Na}     &= g_{Na}\, m^3\, h\, (V_m - E_{Na})  \\
       I_K        &= g_K\,   n^4\,     (V_m - E_K)       \\
       I_L        &= g_L\,             (V_m - E_L)        \\
       I_{syn,ex} &= g_{ex}\,          (V_m - E_{ex})     \\
       I_{syn,in} &= g_{in}\,          (V_m - E_{in})

    **2. Channel Gating Variables**

    Gating variables :math:`m`, :math:`h`, :math:`n` obey first-order kinetics:

    .. math::

       \frac{dx}{dt} = \alpha_x(V)(1 - x) - \beta_x(V)\,x
                     = \alpha_x - (\alpha_x + \beta_x)\, x

    with Traub-Miles rate functions using shifted voltage
    :math:`V = V_m - V_T` (voltage in mV, rates in 1/ms):

    .. math::

       \alpha_n &= \frac{0.032\,(15 - V)}{e^{(15 - V)/5} - 1}, \quad
       \beta_n  = 0.5\,e^{(10 - V)/40}                                \\
       \alpha_m &= \frac{0.32\,(13 - V)}{e^{(13 - V)/4} - 1}, \quad
       \beta_m  = \frac{0.28\,(V - 40)}{e^{(V - 40)/5} - 1}          \\
       \alpha_h &= 0.128\,e^{(17 - V)/18}, \quad
       \beta_h  = \frac{4}{1 + e^{(40 - V)/5}}

    The voltage offset :math:`V_T` (default -63 mV) shifts the effective
    threshold to approximately -50 mV.

    **3. Exponential Conductance Synapses**

    Synaptic conductances decay exponentially:

    .. math::

       \frac{dg_{ex}}{dt} &= -g_{ex} / \tau_{syn,ex} \\
       \frac{dg_{in}}{dt} &= -g_{in} / \tau_{syn,in}

    A presynaptic spike with weight :math:`w` causes an instantaneous
    conductance jump:

    - :math:`w > 0` -- :math:`g_{ex} \leftarrow g_{ex} + w`
    - :math:`w < 0` -- :math:`g_{in} \leftarrow g_{in} + |w|`

    **4. Spike Detection**

    A spike is emitted when all three conditions are satisfied:

    1. ``r == 0`` (not in refractory period), **and**
    2. ``V_m >= V_T + 30`` mV (threshold crossing), **and**
    3. ``V_old > V_m`` (local maximum, the potential is now falling).

    Unlike integrate-and-fire models, no voltage reset occurs -- the
    potassium current naturally repolarizes the membrane.

    .. warning::

       To avoid multiple spikes during the falling flank of a spike, it is
       essential to choose a sufficiently long refractory period.
       Traub and Miles used :math:`t_{ref} = 3` ms, while the default here
       is :math:`t_{ref} = 2` ms (matching NEST).

    **5. Numerical Integration**

    NEST uses GSL RKF45 (Runge-Kutta-Fehlberg 4/5) with adaptive step-size
    control. This implementation uses an adaptive RKF45 integrator implemented
    in JAX with unit-aware arithmetic via saiunit, matching NEST's integration
    approach for numerical correspondence.

    The ODE system is 6-dimensional per neuron:
    :math:`[V_m, m, h, n, g_{ex}, g_{in}]`.

    Parameter Mapping
    -----------------

    The following table shows the correspondence between brainpy.state parameters
    and NEST/mathematical notation:

    ==================== ================== =============================== ====================================================
    **Parameter**        **Default**        **Math equivalent**             **Description**
    ==================== ================== =============================== ====================================================
    ``in_size``          (required)         --                              Population shape
    ``E_L``              -60 mV             :math:`E_L`                     Leak reversal potential
    ``C_m``              200 pF             :math:`C_m`                     Membrane capacitance
    ``g_Na``             20000 nS           :math:`g_{Na}`                  Sodium peak conductance
    ``g_K``              6000 nS            :math:`g_K`                     Potassium peak conductance
    ``g_L``              10 nS              :math:`g_L`                     Leak conductance
    ``E_Na``             50 mV              :math:`E_{Na}`                  Sodium reversal potential
    ``E_K``              -90 mV             :math:`E_K`                     Potassium reversal potential
    ``V_T``              -63 mV             :math:`V_T`                     Voltage offset for gating dynamics
    ``E_ex``             0 mV               :math:`E_{ex}`                  Excitatory synaptic reversal potential
    ``E_in``             -80 mV             :math:`E_{in}`                  Inhibitory synaptic reversal potential
    ``t_ref``            2 ms               :math:`t_{ref}`                 Duration of refractory period
    ``tau_syn_ex``       5 ms               :math:`\tau_{syn,ex}`           Excitatory synaptic time constant
    ``tau_syn_in``       10 ms              :math:`\tau_{syn,in}`           Inhibitory synaptic time constant
    ``I_e``              0 pA               :math:`I_e`                     Constant external input current
    ``V_m_init``         None               --                              Initial V_m (None -> E_L)
    ``Act_m_init``       None               --                              Initial Na activation (None -> equilibrium)
    ``Inact_h_init``     None               --                              Initial Na inactivation (None -> equilibrium)
    ``Act_n_init``       None               --                              Initial K activation (None -> equilibrium)
    ``gsl_error_tol``    1e-3               --                              Local RKF45 error tolerance
    ``spk_fun``          ReluGrad()         --                              Surrogate spike function
    ``spk_reset``        ``'hard'``         --                              Reset mode
    ==================== ================== =============================== ====================================================

    Examples
    --------
    .. code-block:: python

       >>> import brainstate as bst
       >>> import saiunit as u
       >>> from brainpy_state import hh_cond_exp_traub
       >>>
       >>> # Create a population of 100 Traub HH neurons
       >>> neurons = hh_cond_exp_traub(100)
       >>> neurons.init_all_states()
       >>>
       >>> # Run a simulation with constant current injection
       >>> with bst.environ.context(dt=0.1*u.ms):
       ...     for i in range(1000):
       ...         spikes = neurons.update(I_e=200*u.pA)

    .. code-block:: python

       >>> # Compare with NEST default parameters
       >>> import nest
       >>> nest_neuron = nest.Create('hh_cond_exp_traub')
       >>> nest.GetStatus(nest_neuron, ['V_m', 'E_L', 'C_m', 'g_Na', 'g_K'])
       [(-60.0, -60.0, 200.0, 20000.0, 6000.0)]
       >>>
       >>> # Match in brainpy.state
       >>> bp_neuron = hh_cond_exp_traub(1, E_L=-60*u.mV, C_m=200*u.pF,
       ...                               g_Na=20000*u.nS, g_K=6000*u.nS)

    References
    ----------
    .. [1] Brette R et al. (2007). Simulation of networks of spiking neurons:
           A review of tools and strategies. Journal of Computational
           Neuroscience 23:349-98.
           DOI: https://doi.org/10.1007/s10827-007-0038-6
    .. [2] Traub RD and Miles R (1991). Neuronal networks of the hippocampus.
           Cambridge University Press, Cambridge UK.
    .. [3] ModelDB entry: http://modeldb.yale.edu/83319

    See Also
    --------
    hh_psc_alpha : Hodgkin-Huxley with alpha-shaped postsynaptic currents.
    iaf_cond_exp : Leaky integrate-and-fire with conductance-based synapses.
    """

    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 100000

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -60. * u.mV,
        C_m: ArrayLike = 200. * u.pF,
        g_Na: ArrayLike = 20000. * u.nS,
        g_K: ArrayLike = 6000. * u.nS,
        g_L: ArrayLike = 10. * u.nS,
        E_Na: ArrayLike = 50. * u.mV,
        E_K: ArrayLike = -90. * u.mV,
        V_T: ArrayLike = -63. * u.mV,
        E_ex: ArrayLike = 0. * u.mV,
        E_in: ArrayLike = -80. * u.mV,
        t_ref: ArrayLike = 2. * u.ms,
        tau_syn_ex: ArrayLike = 5. * u.ms,
        tau_syn_in: ArrayLike = 10. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_m_init: ArrayLike = None,
        Act_m_init: ArrayLike = None,
        Inact_h_init: ArrayLike = None,
        Act_n_init: ArrayLike = None,
        gsl_error_tol: ArrayLike = 1e-3,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.g_Na = braintools.init.param(g_Na, self.varshape)
        self.g_K = braintools.init.param(g_K, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.E_Na = braintools.init.param(E_Na, self.varshape)
        self.E_K = braintools.init.param(E_K, self.varshape)
        self.V_T = braintools.init.param(V_T, self.varshape)
        self.E_ex = braintools.init.param(E_ex, self.varshape)
        self.E_in = braintools.init.param(E_in, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_m_init = V_m_init
        self.Act_m_init = Act_m_init
        self.Inact_h_init = Inact_h_init
        self.Act_n_init = Act_n_init
        self.gsl_error_tol = gsl_error_tol

        self._validate_parameters()

        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=brainstate.environ.get_dt()
        )

        # other variable
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate parameter constraints.

        Raises
        ------
        ValueError
            If capacitance C_m <= 0, refractory time t_ref < 0, or any synaptic
            time constant (tau_syn_ex, tau_syn_in) <= 0.

        Notes
        -----
        This is called during __init__ to ensure physical validity of parameters.
        Conductances (g_L, g_Na, g_K) are not validated for positivity since
        zero conductance is physically meaningful (though unusual).
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.C_m, self.t_ref, self.tau_syn_ex)):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')

[docs] def init_state(self, **kwargs): r"""Initialize all state variables for the neuron population. Initializes membrane potential, gating variables, synaptic conductances, stimulation current buffer, refractory counter, and last spike time. If initial values are not explicitly provided, they are computed as follows: - ``V``: defaults to ``E_L`` - ``m, h, n``: computed from equilibrium at initial ``V`` using Traub-Miles rate equations (without V_T offset, matching NEST initialization) - ``g_ex, g_in``: initialized to zero - ``I_stim``: initialized to zero - ``refractory_step_count``: initialized to zero (not refractory) - ``last_spike_time``: initialized to -1e7 ms (far in the past) Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Notes ----- The equilibrium gating variable computation uses the raw voltage V (not V - V_T) to match NEST's initialization procedure. During dynamics, the rate equations use the shifted voltage V - V_T, but initialization uses the unshifted value for consistency with NEST's ``State_::State_`` constructor. This initialization ensures the neuron starts in a stable resting state when V_m_init = E_L (default). For custom initial voltages, gating variables are automatically adjusted to the corresponding equilibrium. Examples -------- .. code-block:: python >>> import brainstate as bst >>> import saiunit as u >>> from brainpy_state import hh_cond_exp_traub >>> >>> # Initialize with default rest state >>> neurons = hh_cond_exp_traub(100) >>> neurons.init_state() >>> print(neurons.V.value[0]) # Should be E_L = -60 mV -60.0 mV >>> >>> # Initialize with custom voltage >>> neurons = hh_cond_exp_traub(100, V_m_init=-65*u.mV) >>> neurons.init_state() >>> print(neurons.V.value[0]) -65.0 mV Raises ------ ValueError If an initializer cannot be broadcast to requested shape. TypeError If initializer outputs have incompatible units/dtypes for the corresponding state variables. See Also -------- _hh_cond_exp_traub_equilibrium : Computes equilibrium gating values. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() # Default V_m_init to E_L (matching NEST: y_[0] = p.E_L) if self.V_m_init is not None: V_init_val = self.V_m_init else: V_init_val = self.E_L V = braintools.init.param(braintools.init.Constant(V_init_val), self.varshape) # Compute equilibrium gating variables at initial V. # NEST uses raw V_m (not V_m - V_T) for equilibrium initialization. V_init_mV = float(np.asarray(u.math.asarray(V_init_val / u.mV)).flat[0]) m_eq, h_eq, n_eq = _hh_cond_exp_traub_equilibrium(V_init_mV) if self.Act_m_init is not None: m_init = float(np.asarray(u.math.asarray(self.Act_m_init / u.UNITLESS)).flat[0]) else: m_init = m_eq if self.Inact_h_init is not None: h_init = float(np.asarray(u.math.asarray(self.Inact_h_init / u.UNITLESS)).flat[0]) else: h_init = h_eq if self.Act_n_init is not None: n_init = float(np.asarray(u.math.asarray(self.Act_n_init / u.UNITLESS)).flat[0]) else: n_init = n_eq self.V = brainstate.HiddenState(V) self.m = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(m_init), self.varshape) ) self.h = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(h_init), self.varshape) ) self.n = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(n_init), self.varshape) ) self.g_ex = brainstate.HiddenState(u.math.zeros(self.varshape, dtype=V.dtype) * u.nS) self.g_in = brainstate.HiddenState(u.math.zeros(self.varshape, dtype=V.dtype) * u.nS) self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms)) self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype)) self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape) self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute differentiable spike output using surrogate gradient function. Applies the surrogate spike function to the membrane potential. This is used for gradient-based learning; actual spike detection in the update method uses discrete threshold crossing logic (V >= V_T + 30 and local maximum). Parameters ---------- V : ArrayLike, optional Membrane potential in mV, shape (\*in_size,) or (batch_size, \*in_size). If None, uses the current state ``self.V.value``. Returns ------- ArrayLike Differentiable spike output with the same shape as input V. Values are approximately 0 (no spike) or 1 (spike) with smooth gradients for backpropagation. Notes ----- The voltage is scaled to unitless values (mV) before applying the surrogate function. For Hodgkin-Huxley neurons, the actual spike threshold is V_T + 30 mV (default: -33 mV), but the surrogate function operates on the raw scaled voltage for gradient computation. This method is primarily used for surrogate gradient learning. The discrete spike detection logic in the update method is independent and uses the three-condition test (refractory, threshold, local maximum). Examples -------- .. code-block:: python >>> import saiunit as u >>> import jax.numpy as jnp >>> from brainpy_state import hh_cond_exp_traub >>> >>> neurons = hh_cond_exp_traub(10) >>> neurons.init_state() >>> >>> # Get spike output for current state >>> spikes = neurons.get_spike() >>> print(spikes.shape) (10,) >>> >>> # Get spike output for custom voltage >>> V_custom = jnp.array([-60., -50., -40.]) * u.mV >>> neurons_3 = hh_cond_exp_traub(3) >>> neurons_3.init_state() >>> spikes_custom = neurons_3.get_spike(V_custom) See Also -------- update : Main update method with discrete spike detection logic. """ V = self.V.value if V is None else V # For HH neurons with Traub threshold: spike at V_T + 30. # Scale relative to 0 mV for the surrogate function. v_scaled = V / (1. * u.mV) return self.spk_fun(v_scaled)
def _vector_field(self, state, extra): """Unit-aware vectorized RHS for all neurons simultaneously. Parameters ---------- state : DotDict Keys: V, m, h, n, g_ex, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim, V_old -- mutable auxiliary data carried through the integrator. Returns ------- DotDict with same keys as ``state``, containing time derivatives. """ V_m = state.V # Ionic currents I_Na = self.g_Na * state.m ** 3 * state.h * (V_m - self.E_Na) I_K = self.g_K * state.n ** 4 * (V_m - self.E_K) I_L = self.g_L * (V_m - self.E_L) # Synaptic currents (conductance-based) I_syn_exc = state.g_ex * (V_m - self.E_ex) I_syn_inh = state.g_in * (V_m - self.E_in) # Membrane voltage derivative dV = (-I_Na - I_K - I_L - I_syn_exc - I_syn_inh + extra.i_stim + self.I_e) / self.C_m # Shifted voltage for gating variable rate equations V_shifted = (V_m - self.V_T) / u.mV # unitless # Traub-Miles rate functions alpha_n = 0.032 * (15.0 - V_shifted) / (u.math.exp((15.0 - V_shifted) / 5.0) - 1.0) / u.ms beta_n = 0.5 * u.math.exp((10.0 - V_shifted) / 40.0) / u.ms alpha_m = 0.32 * (13.0 - V_shifted) / (u.math.exp((13.0 - V_shifted) / 4.0) - 1.0) / u.ms beta_m = 0.28 * (V_shifted - 40.0) / (u.math.exp((V_shifted - 40.0) / 5.0) - 1.0) / u.ms alpha_h = 0.128 * u.math.exp((17.0 - V_shifted) / 18.0) / u.ms beta_h = 4.0 / (1.0 + u.math.exp((40.0 - V_shifted) / 5.0)) / u.ms # Gating variable derivatives dm = alpha_m - (alpha_m + beta_m) * state.m dh = alpha_h - (alpha_h + beta_h) * state.h dn = alpha_n - (alpha_n + beta_n) * state.n # Synaptic conductance derivatives dg_ex = -state.g_ex / self.tau_syn_ex dg_in = -state.g_in / self.tau_syn_in return DotDict(V=dV, m=dm, h=dh, n=dn, g_ex=dg_ex, g_in=dg_in) def _event_fn(self, state, extra, accept): """In-loop spike detection and refractory handling. Detects spikes using threshold crossing and local maximum conditions, and manages refractory state. Unlike IAF models, no voltage reset is applied -- repolarization occurs naturally through potassium currents. Parameters ---------- state : DotDict Keys: V, m, h, n, g_ex, g_in -- ODE state variables. extra : DotDict Keys: spike_mask, r, unstable, i_stim, V_old. accept : array, bool Mask of neurons whose RK substep was accepted. Returns ------- (new_state, new_extra) DotDicts with updated spike/refractory info. """ unstable = extra.unstable | jnp.any( accept & ((state.V < -1e3 * u.mV) | (state.V > 1e3 * u.mV)) ) # Spike detection threshold: V_T + 30 mV v_threshold = self.V_T + 30.0 * u.mV # Spike conditions: not refractory, threshold crossed, and local maximum (V_old > V) not_refractory = extra.r <= 0 crossed_threshold = state.V >= v_threshold local_max = extra.V_old > state.V spike_now = accept & not_refractory & crossed_threshold & local_max spike_mask = extra.spike_mask | spike_now # Update V_old to track the previous voltage for local-max detection new_V_old = u.math.where(accept, state.V, extra.V_old) # Set refractory counter on spike r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count, extra.r) new_state = DotDict({**state}) new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable, 'V_old': new_V_old}) return new_state, new_extra
[docs] def update(self, x=0. * u.pA): r"""Update neuron state for one simulation step. Integrates the 6-dimensional ODE system for one time step using adaptive RKF45 solver, processes incoming synaptic inputs, detects spikes based on threshold crossing and local maximum, and updates refractory state. The update follows the NEST ``hh_cond_exp_traub`` update order: 1. Record pre-integration membrane potential (``V_old``). 2. Integrate the full 6-dimensional ODE system over one time step using an adaptive RKF45 solver. 3. Add arriving synaptic conductance jumps to ``g_ex`` / ``g_in``. 4. Check spike condition: ``V_m >= V_T + 30 and V_old > V_m`` (threshold + local maximum). 5. Update refractory counter and record spike time. 6. Store buffered stimulation current for the next step. Parameters ---------- x : ArrayLike, default 0 pA External stimulation current input (in addition to ``I_e``), shape () or (\*in_size,). This current is added to the constant ``I_e`` parameter and any registered current inputs via ``add_current_input()``. Returns ------- ArrayLike Spike output with shape (\*in_size,). Values are computed using the surrogate spike function for differentiability. Spikes occur only when the discrete spike condition is satisfied (not refractory, threshold crossed, and local maximum detected). Notes ----- **Integration Details:** Each neuron's state is integrated using an adaptive RKF45 integrator implemented in JAX with unit-aware arithmetic. This matches NEST's GSL RKF45 solver. The ODE system is: .. math:: \frac{d}{dt}\begin{bmatrix} V_m \\ m \\ h \\ n \\ g_{ex} \\ g_{in} \end{bmatrix} = \begin{bmatrix} (-I_{Na} - I_K - I_L - I_{syn,ex} - I_{syn,in} + I_{stim} + I_e) / C_m \\ \alpha_m - (\alpha_m + \beta_m) m \\ \alpha_h - (\alpha_h + \beta_h) h \\ \alpha_n - (\alpha_n + \beta_n) n \\ -g_{ex} / \tau_{syn,ex} \\ -g_{in} / \tau_{syn,in} \end{bmatrix} **Spike Detection Logic:** A spike is detected when all three conditions are met: 1. ``refractory_step_count == 0`` (not in refractory period) 2. ``V_m >= V_T + 30`` (threshold crossing) 3. ``V_old > V_m`` (local maximum - voltage falling) No voltage reset occurs; repolarization is handled by intrinsic currents. **Synaptic Input Processing:** Delta inputs (spike events) are collected and split by sign: - Positive weights -> excitatory conductance (``g_ex += w``) - Negative weights -> inhibitory conductance (``g_in += |w|``) Conductance jumps are applied **after** ODE integration, matching NEST's update sequence. **Computational Complexity** Integration is performed with an adaptive vectorized RKF45 loop, including in-loop spike detection and refractory handling. All arithmetic is unit-aware via ``saiunit.math``. **Failure Modes** - If the integrator detects numerical instability (``V < -1e3 mV`` or ``V > 1e3 mV``), a runtime error is raised. - Extreme parameter values (very large conductances, very small time constants) may cause numerical instability. See Also -------- init_state : Initialize state variables. get_spike : Compute surrogate spike output. """ t = brainstate.environ.get('t') dt = brainstate.environ.get_dt() dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() # Read state variables with their natural units. V = self.V.value # mV m = self.m.value # unitless h_val = self.h.value # unitless n = self.n.value # unitless g_ex = self.g_ex.value # nS g_in = self.g_in.value # nS r = self.refractory_step_count.value # int i_stim = self.I_stim.value # pA h_step = self.integration_step.value # ms # Current input for next step (one-step delay). new_i_stim = self.sum_current_inputs(x, self.V.value) # pA # Adaptive RKF45 integration via generic integrator. ode_state = DotDict(V=V, m=m, h=h_val, n=n, g_ex=g_ex, g_in=g_in) extra = DotDict( spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_), r=r, unstable=jnp.array(False), i_stim=i_stim, V_old=V, # Track previous V for local-max spike detection ) ode_state, h_step, extra = self.integrator(state=ode_state, h=h_step, extra=extra) V, m, h_val = ode_state.V, ode_state.m, ode_state.h n, g_ex, g_in = ode_state.n, ode_state.g_ex, ode_state.g_in spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable # Post-loop stability check. brainstate.transform.jit_error_if( jnp.any(unstable), 'Numerical instability in hh_cond_exp_traub dynamics.' ) # Decrement refractory counter. r = u.math.where(r > 0, r - 1, r) # Synaptic spike inputs (applied after integration). w_ex = self.sum_delta_inputs(u.math.zeros_like(self.g_ex.value), label='w_ex') w_in = self.sum_delta_inputs(u.math.zeros_like(self.g_in.value), label='w_in') # Apply synaptic spike inputs (instantaneous conductance jump). g_ex = g_ex + w_ex g_in = g_in + w_in # Write back state. self.V.value = V self.m.value = m self.h.value = h_val self.n.value = n self.g_ex.value = g_ex self.g_in.value = g_in self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype) self.integration_step.value = h_step self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value) self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time) return u.math.asarray(spike_mask, dtype=dftype)
def _hh_cond_exp_traub_equilibrium(V): r"""Compute Traub HH gating variable equilibrium values at voltage V (mV). This matches NEST's ``State_::State_(const Parameters_&)`` initialization, which applies the Traub rate equations **without** the V_T offset. The dynamics function uses ``V - V_T`` in its rate equations, but the equilibrium initialization in NEST uses the raw voltage ``y_[0]`` (= E_L). Parameters ---------- V : float Membrane potential in mV (unitless scalar). Must be a finite value. Returns ------- m_inf : float Sodium activation equilibrium value (0 <= m_inf <= 1). h_inf : float Sodium inactivation equilibrium value (0 <= h_inf <= 1). n_inf : float Potassium activation equilibrium value (0 <= n_inf <= 1). Notes ----- The function evaluates the steady-state gating variables using Traub-Miles rate equations at the unshifted voltage V (not V - V_T). This is the correct initialization procedure for NEST compatibility, where initial gating states are computed at the raw E_L value, not shifted by V_T. The alpha/beta rate functions may encounter division by zero at special voltage values (e.g., V = 15 mV for alpha_n). These singularities are removable via L'Hospital's rule but may cause numerical issues if V is exactly at these points. Mathematical Formulation ------------------------ Equilibrium values are computed from the Traub-Miles rate equations: .. math:: x_{\infty}(V) = \frac{\alpha_x(V)}{\alpha_x(V) + \beta_x(V)} **1. Potassium Activation (n)** .. math:: \alpha_n &= \frac{0.032(15 - V)}{e^{(15-V)/5} - 1} \\ \beta_n &= 0.5 \, e^{(10-V)/40} **2. Sodium Activation (m)** .. math:: \alpha_m &= \frac{0.32(13 - V)}{e^{(13-V)/4} - 1} \\ \beta_m &= \frac{0.28(V - 40)}{e^{(V-40)/5} - 1} **3. Sodium Inactivation (h)** .. math:: \alpha_h &= 0.128 \, e^{(17-V)/18} \\ \beta_h &= \frac{4}{1 + e^{(40-V)/5}} Examples -------- .. code-block:: python >>> from brainpy_state._nest.hh_cond_exp_traub import _hh_cond_exp_traub_equilibrium >>> m_inf, h_inf, n_inf = _hh_cond_exp_traub_equilibrium(-60.0) >>> print(f"m={m_inf:.4f}, h={h_inf:.4f}, n={n_inf:.4f}") m=0.0529, h=0.5961, n=0.3177 See Also -------- hh_cond_exp_traub : The neuron model class that uses these equilibrium values. """ import math alpha_n = 0.032 * (15.0 - V) / (math.exp((15.0 - V) / 5.0) - 1.0) beta_n = 0.5 * math.exp((10.0 - V) / 40.0) alpha_m = 0.32 * (13.0 - V) / (math.exp((13.0 - V) / 4.0) - 1.0) beta_m = 0.28 * (V - 40.0) / (math.exp((V - 40.0) / 5.0) - 1.0) alpha_h = 0.128 * math.exp((17.0 - V) / 18.0) beta_h = 4.0 / (1.0 + math.exp((40.0 - V) / 5.0)) m_inf = alpha_m / (alpha_m + beta_m) h_inf = alpha_h / (alpha_h + beta_h) n_inf = alpha_n / (alpha_n + beta_n) return m_inf, h_inf, n_inf