# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
r"""NEST-compatible ``ht_neuron`` model (Hill & Tononi, 2005).
This module implements the neuron model described in:
Hill S, Tononi G (2005). Modeling sleep and wakefulness in the
thalamocortical system. Journal of Neurophysiology, 93:1671-1698.
DOI: https://doi.org/10.1152/jn.00915.2004
The implementation follows the NEST ``models/ht_neuron.{h,cpp}`` source
exactly, including:
- Integrate-and-fire with adaptive (dynamic) threshold.
- Repolarizing potassium current instead of hard reset.
- AMPA, NMDA, GABA_A, and GABA_B conductance-based synapses with
beta-function (difference of exponentials) time course.
- Voltage-dependent NMDA with instantaneous or two-stage unblocking.
- Intrinsic currents I_h, I_T, I_Na(p), and I_KNa.
- Adaptive RKF45 ODE integration via AdaptiveRungeKuttaStep.
"""
from typing import Callable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep
__all__ = [
'ht_neuron',
]
# ---------------------------------------------------------------------------
# Equilibrium / steady-state helper functions (module-level, pure NumPy)
# ---------------------------------------------------------------------------
def _m_eq_h(V):
    r"""Compute the steady-state activation of the I_h current.

    I_h is the hyperpolarization-activated cation current, so its activation
    gate opens as the membrane hyperpolarizes and closes as it depolarizes.
    This function returns only the equilibrium (voltage-clamped) value of the
    gate; it says nothing about the gate's time course.

    Parameters
    ----------
    V : float or numpy.ndarray
        Membrane potential in mV.

    Returns
    -------
    float or numpy.ndarray
        Equilibrium activation :math:`m_\infty^{I_h}(V) \in [0, 1]`. It is
        exactly 0.5 at V = -75 mV, tends to 1 for V well below -75 mV and to 0
        for V well above it (about 1.2e-6 at 0 mV).

    Notes
    -----
    Decreasing Boltzmann sigmoid with half-activation voltage -75 mV and slope
    factor 5.5 mV; both are fixed constants, not model parameters:

    .. math::

       m_\infty^{I_h}(V) = \frac{1}{1 + \exp\left(\frac{V + 75}{5.5}\right)}
    """
    I_h_Vthreshold = -75.0  # mV, half-activation voltage
    return 1.0 / (1.0 + np.exp((V - I_h_Vthreshold) / 5.5))
def _h_eq_T(V):
    r"""Return the equilibrium inactivation gate of the T-type Ca²⁺ current.

    The T-type (low-threshold) calcium channel inactivates when the membrane is
    depolarized and recovers (de-inactivates) only after sufficient
    hyperpolarization. This helper gives the voltage-clamped steady state of
    that inactivation variable.

    Parameters
    ----------
    V : float or numpy.ndarray
        Membrane potential in mV.

    Returns
    -------
    float or numpy.ndarray
        :math:`h_\infty^{I_T}(V) \in [0, 1]`: exactly 0.5 at V = -83 mV,
        approaching 1 (channel available) well below -83 mV and 0 (channel
        inactivated) well above it.

    Notes
    -----
    Decreasing Boltzmann sigmoid, midpoint -83 mV, slope factor 4 mV:

    .. math::

       h_\infty^{I_T}(V) = \frac{1}{1 + \exp\left(\frac{V + 83}{4}\right)}

    The small slope factor makes the transition steep, so only a marked
    hyperpolarization makes the channel available again.
    """
    midpoint = -83.0  # mV
    slope_factor = 4.0  # mV
    scaled_distance = (V - midpoint) / slope_factor
    return 1.0 / (1.0 + np.exp(scaled_distance))
def _m_eq_T(V):
    r"""Return the equilibrium activation gate of the T-type Ca²⁺ current.

    The T-type channel activates at comparatively negative potentials (hence
    "low-threshold"). This helper gives the voltage-clamped steady state of its
    activation variable.

    Parameters
    ----------
    V : float or numpy.ndarray
        Membrane potential in mV.

    Returns
    -------
    float or numpy.ndarray
        :math:`m_\infty^{I_T}(V) \in [0, 1]`: exactly 0.5 at V = -59 mV,
        approaching 1 well above -59 mV and 0 well below it.

    Notes
    -----
    Increasing Boltzmann sigmoid, midpoint -59 mV, slope factor 6.2 mV:

    .. math::

       m_\infty^{I_T}(V) = \frac{1}{1 + \exp\left(-\frac{V + 59}{6.2}\right)}

    In the full current described in the ``ht_neuron`` class docstring,
    :math:`I_T = -g_T\, m^{N_T}\, h\, (V - E_T)`, the activation is raised to
    the power ``N_T`` and multiplied by the inactivation variable.
    """
    midpoint = -59.0  # mV
    slope_factor = 6.2  # mV
    z = (V - midpoint) / slope_factor
    return 1.0 / (1.0 + np.exp(-z))
def _D_eq_KNa(V, tau_D_KNa):
    r"""Compute the steady-state value of the I_KNa driving variable D.

    In the ``ht_neuron`` model, D is driven by a voltage-dependent influx and
    relaxes back toward a baseline of 0.001 with time constant ``tau_D_KNa``.
    Its equilibrium is therefore proportional to ``tau_D_KNa``. This function
    returns that equilibrium for a given membrane potential.

    Parameters
    ----------
    V : float or numpy.ndarray
        Membrane potential in mV.
    tau_D_KNa : float
        Relaxation time constant of D in ms (model default: 1250 ms).

    Returns
    -------
    float or numpy.ndarray
        :math:`D_\infty(V) = \tau_{D,KNa}\, D_{influx}(V) + 0.001`. For
        positive ``tau_D_KNa`` this lies between 0.001 (strong
        hyperpolarization) and ``0.025 * tau_D_KNa + 0.001`` (strong
        depolarization). With the default ``tau_D_KNa = 1250`` it is about
        0.0012 at -70 mV and about 15.6 at -10 mV, and it saturates near 31.25.

    Notes
    -----
    .. math::

       D_{influx}(V) &= \frac{0.025}{1 + \exp\left(-\frac{V + 10}{5}\right)} \\
       D_\infty(V)   &= \tau_{D,KNa} \cdot D_{influx}(V) + 0.001

    :math:`D_\infty` is the fixed point of
    :math:`dD/dt = (\tau_{D,KNa} D_{influx} + 0.001 - D) / \tau_{D,KNa}`,
    the form given in the ``ht_neuron`` class docstring. The influx is a
    sigmoid centred at -10 mV with a 5 mV slope factor. The additive 0.001 is
    the floor reached when the influx vanishes.
    """
    D_influx_peak = 0.025  # maximal influx rate (scaled by tau_D_KNa below)
    D_thresh = -10.0       # mV, sigmoid midpoint
    D_slope = 5.0          # mV, sigmoid slope factor
    D_eq = 0.001           # baseline D with zero influx
    D_influx = D_influx_peak / (1.0 + np.exp(-(V - D_thresh) / D_slope))
    return tau_D_KNa * D_influx + D_eq
def _m_eq_NMDA(V, S_act_NMDA, V_act_NMDA):
r"""Compute steady-state magnesium unblock ratio for NMDA receptor channels.
NMDA receptors are blocked by extracellular Mg²⁺ at hyperpolarized potentials
and unblock with depolarization, providing voltage-dependent gain and enabling
coincidence detection of pre- and post-synaptic activity.
Parameters
----------
V : float
Membrane potential in mV. Typical physiological range: -90 to +30 mV.
S_act_NMDA : float
Slope parameter for the NMDA unblocking sigmoid in 1/mV. Default: 0.081 mV⁻¹.
Higher values produce steeper voltage dependence.
V_act_NMDA : float
Voltage at inflection point of the NMDA unblocking sigmoid in mV. Default:
-25.57 mV. This is the potential at which 50% of channels are unblocked.
Returns
-------
float
Equilibrium Mg²⁺ unblock fraction m_∞^NMDA ∈ [0, 1]. At V = V_act_NMDA,
m_∞ = 0.5. Approaches 1 at depolarized potentials (full unblock) and 0 at
hyperpolarized potentials (full block).
Notes
-----
The unblock fraction follows a Boltzmann sigmoid:
.. math::
m_\infty^{NMDA}(V) = \frac{1}{1 + \exp\left(-S_{act} \cdot (V - V_{act})\right)}
When ``instant_unblock_NMDA`` is True, this equilibrium value is used directly
for the NMDA conductance calculation. When False, the model uses two-stage
kinetics with fast and slow unblocking time constants (tau_Mg_fast_NMDA,
tau_Mg_slow_NMDA) as described in Vargas-Caballero & Robinson (2003).
"""
return 1.0 / (1.0 + np.exp(-S_act_NMDA * (V - V_act_NMDA)))
def _beta_normalization_factor(tau_rise, tau_decay):
    r"""Return the impulse scale that gives a beta-function synapse a unit peak.

    With the two-variable synaptic kinetics

    .. math::

       \frac{dDG}{dt} = -\frac{DG}{\tau_{rise}}, \qquad
       \frac{dG}{dt}  = DG - \frac{G}{\tau_{decay}},

    adding the returned value to ``DG`` makes ``G`` peak at exactly 1.
    Multiplying by a peak conductance therefore makes ``G`` peak at that
    conductance. In ``ht_neuron`` the products ``g_peak_* * factor`` are stored
    as ``_cond_step_*``. Mirrors NEST's ``beta_normalization_factor()`` from
    ``libnestutil/beta_normalization_factor.h``.

    Parameters
    ----------
    tau_rise : float
        Rise time constant in ms (positive).
    tau_decay : float
        Decay time constant in ms (positive).

    Returns
    -------
    float
        Normalization factor in 1/ms.

    Notes
    -----
    For distinct time constants the unnormalized kernel
    :math:`\exp(-t/\tau_{decay}) - \exp(-t/\tau_{rise})` peaks at

    .. math::

       t_{peak} = \frac{\tau_{rise}\,\tau_{decay}}{\tau_{decay} - \tau_{rise}}
                  \ln\frac{\tau_{decay}}{\tau_{rise}},

    and the factor is :math:`(1/\tau_{rise} - 1/\tau_{decay}) / g(t_{peak})`.
    The result is symmetric in the two time constants.

    The alpha-function normalization :math:`e / \tau_{decay}` is returned
    instead in two cases: when the two time constants differ by no more than
    machine epsilon, or when the computed peak has magnitude below machine
    epsilon. Non-positive time constants are not guarded against here.
    """
    machine_eps = np.finfo(np.float64).eps
    spread = tau_decay - tau_rise

    # Coincident time constants: the beta kernel degenerates to an alpha kernel.
    # (``not >`` rather than ``<=`` keeps NaN inputs on this branch.)
    if not abs(spread) > machine_eps:
        return np.e / tau_decay

    t_peak = tau_decay * tau_rise * np.log(tau_decay / tau_rise) / spread
    amplitude = np.exp(-t_peak / tau_decay) - np.exp(-t_peak / tau_rise)

    # Peak numerically indistinguishable from zero: fall back to alpha as well.
    if abs(amplitude) < machine_eps:
        return np.e / tau_decay

    return (1.0 / tau_rise - 1.0 / tau_decay) / amplitude
class ht_neuron(NESTNeuron):
r"""NEST-compatible Hill-Tononi thalamocortical neuron model with intrinsic currents.
Implements the conductance-based integrate-and-fire neuron model from Hill & Tononi
(2005) designed to simulate sleep-wake dynamics in thalamocortical networks. Features
adaptive threshold, repolarizing post-spike potassium current, four receptor types
(AMPA, NMDA, GABA_A, GABA_B), and four intrinsic currents (I_NaP, I_KNa, I_T, I_h)
that mediate burst firing, adaptation, and oscillatory behavior.
This implementation replicates NEST's ``ht_neuron`` (models/ht_neuron.{h,cpp})
using JAX-compatible adaptive ODE integration with AdaptiveRungeKuttaStep.
Parameters
----------
in_size : int or tuple of int
Population shape (e.g., 100 or (10, 10)). Determines the number of neurons
in this layer.
E_Na : float, default=30.0
Sodium reversal potential in mV. Sets the depolarized reset level after spike.
E_K : float, default=-90.0
Potassium reversal potential in mV. Sets the hyperpolarized target for
repolarization during refractory period.
g_NaL : float, default=0.2
Sodium leak conductance (unitless). Contributes depolarizing leak current.
g_KL : float, default=1.0
Potassium leak conductance (unitless). Contributes hyperpolarizing leak current.
tau_m : float, default=16.0
Membrane time constant in ms. Governs the rate of subthreshold voltage changes.
theta_eq : float, default=-51.0
Equilibrium spike threshold in mV. The threshold relaxes to this value with
time constant tau_theta.
tau_theta : float, default=2.0
Threshold relaxation time constant in ms. Controls adaptation timescale.
tau_spike : float, default=1.75
Repolarization time constant for post-spike potassium current in ms. Governs
the speed of voltage recovery during refractory period.
t_ref : float, default=2.0
Absolute refractory period in ms. Duration of post-spike potassium current.
g_peak_AMPA : float, default=0.1
Peak AMPA conductance (unitless). Scaled by spike inputs to produce excitatory
synaptic current.
tau_rise_AMPA : float, default=0.5
AMPA conductance rise time constant in ms. Must be < tau_decay_AMPA.
tau_decay_AMPA : float, default=2.4
AMPA conductance decay time constant in ms.
E_rev_AMPA : float, default=0.0
AMPA reversal potential in mV.
g_peak_NMDA : float, default=0.075
Peak NMDA conductance (unitless). Subject to voltage-dependent Mg²⁺ block.
tau_rise_NMDA : float, default=4.0
NMDA conductance rise time constant in ms. Must be < tau_decay_NMDA.
tau_decay_NMDA : float, default=40.0
NMDA conductance decay time constant in ms.
E_rev_NMDA : float, default=0.0
NMDA reversal potential in mV.
V_act_NMDA : float, default=-25.57
Voltage at 50% NMDA Mg²⁺ unblock in mV.
S_act_NMDA : float, default=0.081
Slope of NMDA Mg²⁺ unblocking sigmoid in 1/mV.
tau_Mg_slow_NMDA : float, default=22.7
Slow Mg²⁺ unblocking time constant in ms. Must be > tau_Mg_fast_NMDA.
tau_Mg_fast_NMDA : float, default=0.68
Fast Mg²⁺ unblocking time constant in ms.
instant_unblock_NMDA : bool, default=False
If True, use instantaneous Mg²⁺ unblocking (m^NMDA = m_∞). If False, use
two-stage kinetics with fast and slow unblocking components.
g_peak_GABA_A : float, default=0.33
Peak GABA_A conductance (unitless). Fast inhibitory synaptic current.
tau_rise_GABA_A : float, default=1.0
GABA_A rise time constant in ms. Must be < tau_decay_GABA_A.
tau_decay_GABA_A : float, default=7.0
GABA_A decay time constant in ms.
E_rev_GABA_A : float, default=-70.0
GABA_A reversal potential in mV.
g_peak_GABA_B : float, default=0.0132
Peak GABA_B conductance (unitless). Slow inhibitory synaptic current.
tau_rise_GABA_B : float, default=60.0
GABA_B rise time constant in ms. Must be < tau_decay_GABA_B.
tau_decay_GABA_B : float, default=200.0
GABA_B decay time constant in ms.
E_rev_GABA_B : float, default=-90.0
GABA_B reversal potential in mV.
g_peak_NaP : float, default=1.0
Peak persistent sodium current conductance (unitless). Mediates subthreshold
depolarization and bistability.
E_rev_NaP : float, default=30.0
I_NaP reversal potential in mV.
N_NaP : float, default=3.0
I_NaP activation exponent (power to which m_∞ is raised).
g_peak_KNa : float, default=1.0
Peak I_KNa conductance (unitless). Provides slow spike-frequency adaptation.
E_rev_KNa : float, default=-90.0
I_KNa reversal potential in mV.
tau_D_KNa : float, default=1250.0
I_KNa D-variable relaxation time constant in ms. Large value produces very
slow adaptation (~seconds).
g_peak_T : float, default=1.0
Peak low-threshold Ca²⁺ current conductance (unitless). Mediates rebound
bursts and oscillations.
E_rev_T : float, default=0.0
I_T reversal potential in mV.
N_T : float, default=2.0
I_T activation exponent (power to which m_IT is raised).
g_peak_h : float, default=1.0
Peak hyperpolarization-activated current conductance (unitless). Contributes
to rebound excitation and resonance.
E_rev_h : float, default=-40.0
I_h reversal potential in mV.
voltage_clamp : bool, default=False
If True, clamp membrane potential at its initial value throughout simulation.
Used for testing intrinsic current dynamics in isolation.
gsl_error_tol : float, default=1e-3
Absolute error tolerance for the adaptive RKF45 integrator.
spk_fun : Callable, default=braintools.surrogate.ReluGrad()
Surrogate gradient function for differentiable spike generation.
spk_reset : str, default='hard'
Spike reset mode: 'hard' (stop gradient) or 'soft' (V -= V_th).
name : str or None, default=None
Name of this neuron population.
Parameter Mapping
-----------------
The table below maps brainpy.state parameter names to NEST equivalents:
========================= ====================== ======== ==============
brainpy.state Parameter NEST Parameter Default Units
========================= ====================== ======== ==============
``in_size`` (N/A) --- ---
``E_Na`` ``E_Na`` 30.0 mV
``E_K`` ``E_K`` -90.0 mV
``g_NaL`` ``g_NaL`` 0.2 (unitless)
``g_KL`` ``g_KL`` 1.0 (unitless)
``tau_m`` ``tau_m`` 16.0 ms
``theta_eq`` ``theta_eq`` -51.0 mV
``tau_theta`` ``tau_theta`` 2.0 ms
``tau_spike`` ``tau_spike`` 1.75 ms
``t_ref`` ``t_ref`` 2.0 ms
``g_peak_AMPA`` ``g_peak_AMPA`` 0.1 (unitless)
``tau_rise_AMPA`` ``tau_rise_AMPA`` 0.5 ms
``tau_decay_AMPA`` ``tau_decay_AMPA`` 2.4 ms
``E_rev_AMPA`` ``E_rev_AMPA`` 0.0 mV
``g_peak_NMDA`` ``g_peak_NMDA`` 0.075 (unitless)
``tau_rise_NMDA`` ``tau_rise_NMDA`` 4.0 ms
``tau_decay_NMDA`` ``tau_decay_NMDA`` 40.0 ms
``E_rev_NMDA`` ``E_rev_NMDA`` 0.0 mV
``V_act_NMDA`` ``V_act_NMDA`` -25.57 mV
``S_act_NMDA`` ``S_act_NMDA`` 0.081 1/mV
``tau_Mg_slow_NMDA`` ``tau_Mg_slow_NMDA`` 22.7 ms
``tau_Mg_fast_NMDA`` ``tau_Mg_fast_NMDA`` 0.68 ms
``instant_unblock_NMDA`` ``instant_unblock`` False ---
``g_peak_GABA_A`` ``g_peak_GABA_A`` 0.33 (unitless)
``tau_rise_GABA_A`` ``tau_rise_GABA_A`` 1.0 ms
``tau_decay_GABA_A`` ``tau_decay_GABA_A`` 7.0 ms
``E_rev_GABA_A`` ``E_rev_GABA_A`` -70.0 mV
``g_peak_GABA_B`` ``g_peak_GABA_B`` 0.0132 (unitless)
``tau_rise_GABA_B`` ``tau_rise_GABA_B`` 60.0 ms
``tau_decay_GABA_B`` ``tau_decay_GABA_B`` 200.0 ms
``E_rev_GABA_B`` ``E_rev_GABA_B`` -90.0 mV
``g_peak_NaP`` ``g_peak_NaP`` 1.0 (unitless)
``E_rev_NaP`` ``E_rev_NaP`` 30.0 mV
``N_NaP`` ``NaP_N`` 3.0 ---
``g_peak_KNa`` ``g_peak_KNa`` 1.0 (unitless)
``E_rev_KNa`` ``E_rev_KNa`` -90.0 mV
``tau_D_KNa`` ``tau_D_KNa`` 1250.0 ms
``g_peak_T`` ``g_peak_T`` 1.0 (unitless)
``E_rev_T`` ``E_rev_T`` 0.0 mV
``N_T`` ``T_N`` 2.0 ---
``g_peak_h`` ``g_peak_h`` 1.0 (unitless)
``E_rev_h`` ``E_rev_h`` -40.0 mV
``voltage_clamp`` ``voltage_clamp`` False ---
``gsl_error_tol`` (GSL tolerance) 1e-3 ---
========================= ====================== ======== ==============
Notes
-----
**1. Model Architecture**
The ht_neuron is an integrate-and-fire model with:
- **Adaptive threshold**: Threshold increases transiently after spike, then relaxes
to theta_eq, providing spike-frequency adaptation on ~ms timescale.
- **Soft reset**: No hard voltage reset. Instead, V and theta are set to E_Na, and
a repolarizing K⁺ current drives voltage back toward E_K during t_ref.
- **Four synaptic receptor types**: AMPA (fast excitation), NMDA (slow excitation
with voltage-dependent Mg²⁺ block), GABA_A (fast inhibition), GABA_B (slow
inhibition). Each uses beta-function (biexponential) conductance time course.
- **Four intrinsic currents**:
* **I_NaP** (persistent Na⁺): Subthreshold depolarizing current; enables
bistability and up-states.
* **I_KNa** (depolarization-activated K⁺): Very slow adaptation (~1 s timescale).
* **I_T** (low-threshold Ca²⁺): Mediates rebound bursts; deinactivates during
hyperpolarization and activates rapidly on depolarization.
* **I_h** (hyperpolarization-activated cation current): Sag current; contributes
to rebound and resonance.
**2. Membrane Dynamics**
The membrane potential obeys:
.. math::
\frac{dV}{dt} = \frac{I_{leak} + I_{syn} + I_{intrinsic} + I_{stim}}{\tau_m} + I_{spike}
where:
.. math::
I_{leak} &= -g_{NaL}(V - E_{Na}) - g_{KL}(V - E_K) \\
I_{syn} &= -g_{AMPA}(V - E_{AMPA}) - g_{NMDA} m^{NMDA}(V - E_{NMDA}) \\
&\quad - g_{GABA_A}(V - E_{GABA_A}) - g_{GABA_B}(V - E_{GABA_B}) \\
I_{intrinsic} &= I_{NaP} + I_{KNa} + I_T + I_h \\
I_{spike} &= \begin{cases}
-(V - E_K) / \tau_{spike} & \text{if refractory} \\
0 & \text{otherwise}
\end{cases}
**3. Dynamic Threshold**
.. math::
\frac{d\theta}{dt} = -\frac{\theta - \theta_{eq}}{\tau_\theta}
On spike, theta is reset to E_Na (along with V), then decays back to theta_eq.
This provides fast spike-frequency adaptation.
**4. Beta-Function Synapses**
Each synapse type uses a two-variable beta function (difference of exponentials):
.. math::
\frac{dg'}{dt} &= -\frac{g'}{\tau_{rise}} \\
\frac{dg}{dt} &= g' - \frac{g}{\tau_{decay}}
On arrival of a spike, the DG variable (g') is incremented by:
.. math::
\Delta g' = g_{peak} \cdot \text{norm}(\tau_{rise}, \tau_{decay}) \cdot w
where norm is the beta normalization factor and w is the synaptic weight.
**5. NMDA Voltage Dependence**
NMDA channels are blocked by Mg²⁺ at hyperpolarized potentials and unblock with
depolarization. Two modes are available:
- **Instantaneous unblocking** (instant_unblock_NMDA=True):
.. math::
m^{NMDA} = \frac{1}{1 + \exp(-S_{act}(V - V_{act}))}
- **Two-stage kinetics** (instant_unblock_NMDA=False):
.. math::
\frac{dm_{fast}}{dt} &= \frac{m_\infty - m_{fast}}{\tau_{Mg,fast}} \\
\frac{dm_{slow}}{dt} &= \frac{m_\infty - m_{slow}}{\tau_{Mg,slow}} \\
m^{NMDA} &= A_1(V) m_{fast} + A_2(V) m_{slow}
where A₁(V) = 0.51 - 0.0028V and A₂ = 1 - A₁. This captures the experimentally
observed slow Mg²⁺ unblocking kinetics (Vargas-Caballero & Robinson, 2003).
**6. Intrinsic Currents**
- **I_NaP** (persistent sodium):
.. math::
m_\infty &= \frac{1}{1 + \exp(-(V + 55.7)/7.7)} \\
I_{NaP} &= -g_{NaP} \cdot (m_\infty)^{N_{NaP}} \cdot (V - E_{NaP})
No inactivation; provides tonic depolarizing drive.
- **I_KNa** (depolarization-activated potassium):
.. math::
D_{influx} &= \frac{0.025}{1 + \exp(-(V + 10)/5)} \\
\frac{dD}{dt} &= \frac{\tau_{D,KNa} \cdot D_{influx} + 0.001 - D}{\tau_{D,KNa}} \\
m_\infty &= \frac{1}{1 + (0.25/D)^{3.5}} \\
I_{KNa} &= -g_{KNa} \cdot m_\infty \cdot (V - E_{KNa})
D accumulates slowly during depolarization; provides adaptation on ~second timescale.
- **I_T** (low-threshold Ca²⁺):
.. math::
m_\infty &= \frac{1}{1 + \exp(-(V + 59)/6.2)} \\
h_\infty &= \frac{1}{1 + \exp((V + 83)/4)} \\
\tau_m &= 0.22 / (\exp(-(V+132)/16.7) + \exp((V+16.8)/18.2)) + 0.13 \\
\tau_h &= 8.2 + \frac{56.6 + 0.27 \exp((V+115.2)/5)}{1 + \exp((V+86)/3.2)} \\
I_T &= -g_T \cdot m^{N_T} \cdot h \cdot (V - E_T)
Activation is fast; inactivation is slower. Channel deinactivates during
hyperpolarization, enabling rebound bursts.
- **I_h** (hyperpolarization-activated cation current):
.. math::
m_\infty &= \frac{1}{1 + \exp((V + 75)/5.5)} \\
\tau_m &= \frac{1}{\exp(-14.59 - 0.086V) + \exp(-1.87 + 0.0701V)} \\
I_h &= -g_h \cdot m \cdot (V - E_h)
Activates slowly at hyperpolarized potentials; provides depolarizing sag and
contributes to rebound.
**7. Spike Detection and Reset**
A spike occurs when ``ref_steps == 0`` and ``V >= theta``. On spike:
- V → E_Na (≈ +30 mV)
- theta → E_Na
- ref_steps → ceil(t_ref / dt) + 1
During the refractory period, I_spike drives V back toward E_K.
**8. Numerical Integration**
The model uses AdaptiveRungeKuttaStep with RKF45 (Runge-Kutta-Fehlberg 4(5)
adaptive integration). This matches NEST's GSL RKF45 integrator in terms of order
and adaptive step-size control.
**9. Conductance Units**
All conductances are **unitless** in this model. The membrane equation is written
as dV/dt = I/tau_m, meaning currents have units of mV/ms (i.e., they are already
divided by capacitance). Peak conductances g_peak_* scale the synaptic currents.
**10. Sleep-Wake Transitions**
The ht_neuron was designed to model thalamocortical neurons that exhibit two
distinct firing modes:
- **Tonic firing** (awake/depolarized): Regular spiking driven by excitatory input
and I_NaP.
- **Burst firing** (sleep/hyperpolarized): Rebound bursts mediated by I_T
deinactivation and I_h rebound.
By varying the balance of intrinsic current conductances (g_peak_T, g_peak_h,
g_peak_NaP) and background synaptic input, the model can transition between these
modes, reproducing the sleep-wake dynamics observed in thalamocortical circuits.
Examples
--------
**Example 1: Single neuron with injected current**
.. code-block:: python
>>> import brainpy as bp
>>> import brainpy.state as bps
>>> import saiunit as u
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>>
>>> # Create a single ht_neuron
>>> neuron = bps.ht_neuron(1, g_peak_T=0.0, g_peak_h=0.0)
>>>
>>> # Initialize state
>>> with bp.environ.context(dt=0.1 * u.ms):
... neuron.init_all_states()
...
... # Simulate 500 ms with step current
... currents = np.concatenate([
... np.zeros(1000),
... np.ones(3000) * 2.0, # 2 mV/ms injected current
... np.zeros(1000)
... ])
... voltages = []
... for I in currents:
... neuron.update(I)
... voltages.append(neuron.V.value[0])
>>>
>>> # Plot membrane potential
>>> times = np.arange(len(voltages)) * 0.1
>>> plt.figure(figsize=(10, 4))
>>> plt.plot(times, voltages)
>>> plt.xlabel('Time (ms)')
>>> plt.ylabel('Membrane potential (mV)')
>>> plt.title('ht_neuron response to step current')
>>> plt.show()
**Example 2: Rebound burst with I_T**
.. code-block:: python
>>> # Enable I_T for burst firing
>>> neuron = bps.ht_neuron(1, g_peak_T=1.0, g_peak_h=0.5)
>>>
>>> with bp.environ.context(dt=0.1 * u.ms):
... neuron.init_all_states()
...
... # Hyperpolarize, then release
... currents = np.concatenate([
... np.zeros(500),
... -np.ones(1000) * 3.0, # hyperpolarizing current
... np.zeros(1500)
... ])
... voltages = []
... for I in currents:
... neuron.update(I)
... voltages.append(neuron.V.value[0])
>>>
>>> # Observe rebound burst after hyperpolarization ends
>>> plt.figure(figsize=(10, 4))
>>> plt.plot(np.arange(len(voltages)) * 0.1, voltages)
>>> plt.xlabel('Time (ms)')
>>> plt.ylabel('V (mV)')
>>> plt.title('Rebound burst mediated by I_T')
>>> plt.show()
**Example 3: Multi-receptor synaptic input**
.. code-block:: python
>>> # Network with AMPA and NMDA receptors
>>> pre = bps.LIF(100, V_rest=-70*u.mV, V_th=-50*u.mV, V_reset=-70*u.mV,
... tau=20*u.ms, R=1*u.ohm, V_initializer=bp.init.Normal(-70, 5))
>>> post = bps.ht_neuron(10, g_peak_AMPA=0.1, g_peak_NMDA=0.05,
... instant_unblock_NMDA=False)
>>>
>>> # Create projections for AMPA and NMDA
>>> ampa_proj = bps.AlignPostProj(
... pre=pre, post=post,
... comm=bp.event.FixedProb(0.1, weight=1.0),
... syn=bps.Expon.desc(tau=2.4 * u.ms),
... label='AMPA'
... )
>>> nmda_proj = bps.AlignPostProj(
... pre=pre, post=post,
... comm=bp.event.FixedProb(0.1, weight=1.0),
... syn=bps.Expon.desc(tau=40 * u.ms),
... label='NMDA'
... )
>>>
>>> # Simulate network dynamics
>>> # (implementation depends on BrainPy network API)
See Also
--------
hh_psc_alpha : Hodgkin-Huxley neuron with alpha-shaped PSCs
iaf_cond_exp : Simple IAF with exponential conductance synapses
aeif_cond_alpha : Adaptive exponential IAF with alpha conductances
References
----------
.. [1] Hill S, Tononi G (2005). Modeling sleep and wakefulness in the
thalamocortical system. Journal of Neurophysiology, 93:1671-1698.
DOI: https://doi.org/10.1152/jn.00915.2004
.. [2] Vargas-Caballero M, Robinson HPC (2003). A slow fraction of Mg²⁺
unblock of NMDA receptors limits their contribution to spike
generation in cortical pyramidal neurons. Journal of Neurophysiology,
89:2778-2783. DOI: https://doi.org/10.1152/jn.01038.2002
"""
__module__ = 'brainpy.state'
# Synapse receptor type constants (matching NEST enum)
AMPA = 1
NMDA = 2
GABA_A = 3
GABA_B = 4
_MIN_H = 1e-8 # ms (dimensionless float — state variables are unitless)
_MAX_ITERS = 100000
def __init__(
    self,
    in_size: Size,
    # Leak / reversal
    E_Na: float = 30.0,
    E_K: float = -90.0,
    g_NaL: float = 0.2,
    g_KL: float = 1.0,
    tau_m: float = 16.0,
    # Dynamic threshold
    theta_eq: float = -51.0,
    tau_theta: float = 2.0,
    # Post-spike potassium current
    tau_spike: float = 1.75,
    t_ref: float = 2.0,
    # AMPA synapse
    g_peak_AMPA: float = 0.1,
    tau_rise_AMPA: float = 0.5,
    tau_decay_AMPA: float = 2.4,
    E_rev_AMPA: float = 0.0,
    # NMDA synapse
    g_peak_NMDA: float = 0.075,
    tau_rise_NMDA: float = 4.0,
    tau_decay_NMDA: float = 40.0,
    E_rev_NMDA: float = 0.0,
    V_act_NMDA: float = -25.57,
    S_act_NMDA: float = 0.081,
    tau_Mg_slow_NMDA: float = 22.7,
    tau_Mg_fast_NMDA: float = 0.68,
    instant_unblock_NMDA: bool = False,
    # GABA_A synapse
    g_peak_GABA_A: float = 0.33,
    tau_rise_GABA_A: float = 1.0,
    tau_decay_GABA_A: float = 7.0,
    E_rev_GABA_A: float = -70.0,
    # GABA_B synapse
    g_peak_GABA_B: float = 0.0132,
    tau_rise_GABA_B: float = 60.0,
    tau_decay_GABA_B: float = 200.0,
    E_rev_GABA_B: float = -90.0,
    # Intrinsic: I_NaP
    g_peak_NaP: float = 1.0,
    E_rev_NaP: float = 30.0,
    N_NaP: float = 3.0,
    # Intrinsic: I_KNa
    g_peak_KNa: float = 1.0,
    E_rev_KNa: float = -90.0,
    tau_D_KNa: float = 1250.0,
    # Intrinsic: I_T
    g_peak_T: float = 1.0,
    E_rev_T: float = 0.0,
    N_T: float = 2.0,
    # Intrinsic: I_h
    g_peak_h: float = 1.0,
    E_rev_h: float = -40.0,
    # Testing
    voltage_clamp: bool = False,
    # Solver
    gsl_error_tol: float = 1e-3,
    # Base class
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    name: str = None,
):
    r"""Create an ``ht_neuron`` population and pre-compute derived constants.

    Every model parameter (see the class docstring) is stored unchanged, as a
    plain unitless value, under an attribute of the same name. Construction
    then proceeds as follows:

    1. The parameters are validated.
    2. The per-receptor spike increments ``_cond_step_*`` are pre-computed.
    3. The simulation step is converted to a bare float in ms (``_dt_ms``) and
       the refractory step count ``ref_count`` is derived from it.
    4. The adaptive RKF45 integrator is built.
    5. The leak-balance potential ``_V_clamp`` is computed.

    Raises
    ------
    ValueError
        Propagated from ``_validate_parameters`` when a constraint is violated.
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    # ---- raw model parameters (plain floats, NEST convention) -------------
    # leak, membrane, threshold, post-spike K+ current
    self.E_Na = E_Na
    self.E_K = E_K
    self.g_NaL = g_NaL
    self.g_KL = g_KL
    self.tau_m = tau_m
    self.theta_eq = theta_eq
    self.tau_theta = tau_theta
    self.tau_spike = tau_spike
    self.t_ref = t_ref
    # AMPA
    self.g_peak_AMPA = g_peak_AMPA
    self.tau_rise_AMPA = tau_rise_AMPA
    self.tau_decay_AMPA = tau_decay_AMPA
    self.E_rev_AMPA = E_rev_AMPA
    # NMDA, including Mg2+ unblocking
    self.g_peak_NMDA = g_peak_NMDA
    self.tau_rise_NMDA = tau_rise_NMDA
    self.tau_decay_NMDA = tau_decay_NMDA
    self.E_rev_NMDA = E_rev_NMDA
    self.V_act_NMDA = V_act_NMDA
    self.S_act_NMDA = S_act_NMDA
    self.tau_Mg_slow_NMDA = tau_Mg_slow_NMDA
    self.tau_Mg_fast_NMDA = tau_Mg_fast_NMDA
    self.instant_unblock_NMDA = instant_unblock_NMDA
    # GABA_A
    self.g_peak_GABA_A = g_peak_GABA_A
    self.tau_rise_GABA_A = tau_rise_GABA_A
    self.tau_decay_GABA_A = tau_decay_GABA_A
    self.E_rev_GABA_A = E_rev_GABA_A
    # GABA_B
    self.g_peak_GABA_B = g_peak_GABA_B
    self.tau_rise_GABA_B = tau_rise_GABA_B
    self.tau_decay_GABA_B = tau_decay_GABA_B
    self.E_rev_GABA_B = E_rev_GABA_B
    # intrinsic currents
    self.g_peak_NaP = g_peak_NaP
    self.E_rev_NaP = E_rev_NaP
    self.N_NaP = N_NaP
    self.g_peak_KNa = g_peak_KNa
    self.E_rev_KNa = E_rev_KNa
    self.tau_D_KNa = tau_D_KNa
    self.g_peak_T = g_peak_T
    self.E_rev_T = E_rev_T
    self.N_T = N_T
    self.g_peak_h = g_peak_h
    self.E_rev_h = E_rev_h
    # testing / solver
    self.voltage_clamp = voltage_clamp
    self.gsl_error_tol = gsl_error_tol

    self._validate_parameters()

    # ---- spike-triggered increment of each receptor's DG variable ----------
    # g_peak times the beta normalization, so a unit input peaks at g_peak.
    (
        self._cond_step_AMPA,
        self._cond_step_NMDA,
        self._cond_step_GABA_A,
        self._cond_step_GABA_B,
    ) = (
        g_peak * _beta_normalization_factor(tau_r, tau_d)
        for g_peak, tau_r, tau_d in (
            (g_peak_AMPA, tau_rise_AMPA, tau_decay_AMPA),
            (g_peak_NMDA, tau_rise_NMDA, tau_decay_NMDA),
            (g_peak_GABA_A, tau_rise_GABA_A, tau_decay_GABA_A),
            (g_peak_GABA_B, tau_rise_GABA_B, tau_decay_GABA_B),
        )
    )

    # ---- time step ---------------------------------------------------------
    # All ht_neuron state variables are dimensionless, so the global step is
    # turned into a bare float (ms) once. A unitful step inside the RK
    # weighted sum ``s + h * k`` would mix units.
    step_ms = float(u.math.asarray(brainstate.environ.get_dt() / u.ms))
    self._dt_ms = step_ms
    # Refractory duration in whole steps, rounded to the nearest step.
    # NOTE(review): the class docstring describes ``ceil(t_ref / dt) + 1``
    # refractory steps, and Python's ``round`` breaks exact .5 ties to even.
    # Confirm this matches how ``ref_count`` is consumed in ``update``.
    self.ref_count = int(round(self.t_ref / step_ms))

    # ---- adaptive RKF45 integrator -----------------------------------------
    # ``min_h`` and ``dt`` are dimensionless floats (ms), so ``h * dy/dt`` has
    # the same (absent) units as the state leaves.
    self.integrator = AdaptiveRungeKuttaStep(
        method='RKF45',
        vf=self._vector_field,
        event_fn=self._event_fn,
        min_h=self._MIN_H,
        max_iters=self._MAX_ITERS,
        atol=self.gsl_error_tol,
        dt=step_ms,
    )

    # ---- voltage-clamp level: balance point of the Na+/K+ leak currents ----
    # NOTE(review): g_NaL == g_KL == 0 passes the ``>= 0`` conductance checks
    # but makes this division raise ZeroDivisionError. Confirm this is intended.
    leak_total = self.g_NaL + self.g_KL
    self._V_clamp = (self.g_NaL * self.E_Na + self.g_KL * self.E_K) / leak_total
def _validate_parameters(self):
    r"""Validate parameter constraints to ensure physiological consistency.

    Checks parameter values against the constraints enforced by NEST's
    ``ht_neuron::Parameters_::set()``. Raises ValueError for the first
    violated constraint, in the order listed below.

    Raises
    ------
    ValueError
        If any conductance is negative, if S_act_NMDA < 0, if t_ref < 0, if
        any time constant is non-positive, or if a rise time is not strictly
        smaller than the matching decay time (synapses and NMDA Mg²⁺
        unblocking kinetics).

    Notes
    -----
    Enforced constraints:

    1. **Non-negative conductances**: all g_peak_* and g_*L parameters >= 0.
    2. **Non-negative slope**: S_act_NMDA >= 0.
    3. **Non-negative refractory period**: t_ref >= 0 (zero is allowed and
       disables the refractory period).
    4. **Positive time constants**: all tau_* parameters > 0.
    5. **Rise < decay ordering**: for beta-function synapses and NMDA Mg²⁺
       kinetics, tau_rise (tau_Mg_fast) must be strictly less than tau_decay
       (tau_Mg_slow) to give a proper biexponential shape.

    Parameters that are JAX tracers (e.g. during ``jit``) cannot be compared
    with Python control flow, so each individual check is skipped when one of
    its operands is a tracer; all checks on concrete values still run.

    Called automatically during ``__init__`` to catch configuration errors
    early.
    """

    def _concrete(*names):
        # True when none of the named parameters is a JAX tracer, i.e. the
        # comparison can be evaluated eagerly.
        return not any(is_tracer(getattr(self, n)) for n in names)

    # Constraints of the form ``value >= 0``; order matches the error
    # precedence of the original checks.
    non_negative = (
        'g_peak_AMPA', 'g_peak_NMDA', 'g_peak_GABA_A', 'g_peak_GABA_B',
        'g_peak_NaP', 'g_peak_KNa', 'g_peak_T', 'g_peak_h',
        'g_NaL', 'g_KL',
        'S_act_NMDA', 't_ref',
    )
    for name in non_negative:
        if _concrete(name) and getattr(self, name) < 0:
            raise ValueError(f'{name} >= 0 required.')

    # Strictly positive time constants.
    strictly_positive = (
        'tau_rise_AMPA', 'tau_decay_AMPA',
        'tau_rise_NMDA', 'tau_decay_NMDA',
        'tau_rise_GABA_A', 'tau_decay_GABA_A',
        'tau_rise_GABA_B', 'tau_decay_GABA_B',
        'tau_Mg_fast_NMDA', 'tau_Mg_slow_NMDA',
        'tau_spike', 'tau_theta', 'tau_m', 'tau_D_KNa',
    )
    for name in strictly_positive:
        if _concrete(name) and getattr(self, name) <= 0:
            raise ValueError(f'{name} > 0 required.')

    # Rise must be faster than decay for every biexponential pair.
    ordered_pairs = (
        ('tau_rise_AMPA', 'tau_decay_AMPA'),
        ('tau_rise_GABA_A', 'tau_decay_GABA_A'),
        ('tau_rise_GABA_B', 'tau_decay_GABA_B'),
        ('tau_rise_NMDA', 'tau_decay_NMDA'),
        ('tau_Mg_fast_NMDA', 'tau_Mg_slow_NMDA'),
    )
    for fast, slow in ordered_pairs:
        if _concrete(fast, slow) and getattr(self, fast) >= getattr(self, slow):
            raise ValueError(f'{fast} < {slow} required.')
[docs]
def init_state(self, **kwargs):
    r"""Initialize all state variables to leak-equilibrium values.

    Sets the membrane potential to the leak reversal potential (conductance-
    weighted average of E_Na and E_K), the threshold to theta_eq, all synaptic
    variables to zero, and all intrinsic gating variables to their
    voltage-dependent steady-state values at that membrane potential.

    Parameters
    ----------
    **kwargs
        Unused compatibility parameters accepted by the base-state API.

    Notes
    -----
    **1. Initial Membrane Potential**

    Computed from the leak conductance balance:

    .. math::

        V_{init} = \frac{g_{NaL} \cdot E_{Na} + g_{KL} \cdot E_K}{g_{NaL} + g_{KL}}

    With g_NaL=0.2, g_KL=1.0, E_Na=30 mV, E_K=-90 mV:

    .. math::

        V_{init} = \frac{0.2 \cdot 30 + 1.0 \cdot (-90)}{0.2 + 1.0} = -70\ \text{mV}

    **2. Threshold Initialization**

    theta is set to theta_eq.

    **3. Synaptic Variables**

    All beta-function state variables start at zero:
    DG_AMPA, G_AMPA, DG_NMDA, G_NMDA, DG_GABA_A, G_GABA_A, DG_GABA_B, G_GABA_B.

    **4. Intrinsic Gating Variables**

    All gating variables are set to their steady-state values at V_init:

    - m_fast_NMDA = m_slow_NMDA = m_∞^NMDA(V_init)
    - m_Ih = m_∞^Ih(V_init)
    - D_IKNa = D_∞(V_init)
    - m_IT = m_∞^IT(V_init)
    - h_IT = h_∞^IT(V_init)

    Evaluating the steady-state expressions written in ``_vector_field`` at
    V_init = -70 mV gives approximately:

    - m_Ih ≈ 0.29 (half-activation at -75 mV, so partly active at rest)
    - m_IT ≈ 0.15 (half-activation at -59 mV)
    - h_IT ≈ 0.04 (half-inactivation at -83 mV: I_T is largely inactivated
      at rest and needs hyperpolarization to de-inactivate)
    - D_IKNa ≈ 0.001 (the D_eq floor; sodium influx is negligible at rest)
    - m_NMDA is small (strong Mg²⁺ block); its exact value depends on
      S_act_NMDA and V_act_NMDA.

    **5. Bookkeeping States**

    - Recorded intrinsic currents (I_NaP, I_KNa, I_T, I_h) start at 0.
    - ref_steps = 0 (not refractory).
    - I_stim = 0 (no buffered external current).
    - last_spike_time = -1e7 ms (far in the past).
    - integration_step = dt expressed as a dimensionless number of ms. It must
      be dimensionless because every ODE state leaf is dimensionless, so the
      RK update ``s + h * k`` has to stay dimensionless.

    **6. Voltage Clamp Value**

    This method does not modify ``_V_clamp``; ``__init__`` computes it from the
    same leak-equilibrium expression, so with ``voltage_clamp=True`` the clamp
    target coincides with V_init.

    The initialization is intended to mirror NEST's ``ht_neuron::State_``
    construction so that simulations start from comparable conditions.
    """
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # Leak equilibrium (same expression __init__ uses for _V_clamp).
    V_init = (self.g_NaL * self.E_Na + self.g_KL * self.E_K) / (self.g_NaL + self.g_KL)

    # Dimensionless dt in ms; a unit-carrying step would make h * k pick up
    # ms units and raise UnitMismatchError in AdaptiveRungeKuttaStep.
    dt_ms = float(u.math.asarray(brainstate.environ.get_dt() / u.ms))

    # Steady-state gating values at V_init.
    m_nmda_init = _m_eq_NMDA(V_init, self.S_act_NMDA, self.V_act_NMDA)
    m_ih_init = _m_eq_h(V_init)
    d_ikna_init = _D_eq_KNa(V_init, self.tau_D_KNa)
    m_it_init = _m_eq_T(V_init)
    h_it_init = _h_eq_T(V_init)

    def _hidden(value):
        # HiddenState filled with the constant ``value`` over ``varshape``.
        return brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(value), self.varshape)
        )

    # ODE state variables (dimensionless, mV scale)
    self.V = _hidden(V_init)
    self.theta = _hidden(self.theta_eq)

    # Synaptic beta-function variables: all start at zero
    self.DG_AMPA = _hidden(0.0)
    self.G_AMPA = _hidden(0.0)
    self.DG_NMDA = _hidden(0.0)
    self.G_NMDA = _hidden(0.0)
    self.DG_GABA_A = _hidden(0.0)
    self.G_GABA_A = _hidden(0.0)
    self.DG_GABA_B = _hidden(0.0)
    self.G_GABA_B = _hidden(0.0)

    # NMDA Mg²⁺ unblocking kinetics (both components start at equilibrium)
    self.m_fast_NMDA_state = _hidden(m_nmda_init)
    self.m_slow_NMDA_state = _hidden(m_nmda_init)

    # Intrinsic gating variables at equilibrium
    self.m_Ih_state = _hidden(m_ih_init)
    self.D_IKNa_state = _hidden(d_ikna_init)
    self.m_IT_state = _hidden(m_it_init)
    self.h_IT_state = _hidden(h_it_init)

    # Intrinsic current values (recording only)
    self.I_NaP_val = _hidden(0.0)
    self.I_KNa_val = _hidden(0.0)
    self.I_T_val = _hidden(0.0)
    self.I_h_val = _hidden(0.0)

    # Refractory counter, external-current buffer, and last spike time
    self.ref_steps = brainstate.ShortTermState(
        jnp.zeros(self.varshape, dtype=ditype)
    )
    self.I_stim = brainstate.ShortTermState(
        jnp.zeros(self.varshape, dtype=dftype)
    )
    self.last_spike_time = brainstate.ShortTermState(
        u.math.full(self.varshape, -1e7 * u.ms)
    )

    # Adaptive integration step (dimensionless ms; must match the
    # dimensionless dt handed to AdaptiveRungeKuttaStep in __init__).
    self.integration_step = brainstate.ShortTermState.init(
        braintools.init.Constant(dt_ms), self.varshape
    )
[docs]
def get_spike(self, V: ArrayLike = None):
    r"""Map a membrane potential to a surrogate-gradient spike output.

    The distance of ``V`` from the *current* dynamic threshold
    (``self.theta.value``) is normalized and fed to ``self.spk_fun``. The
    result is a Heaviside-like forward value with a smooth backward gradient.

    Parameters
    ----------
    V : ArrayLike or None, default=None
        Membrane potential in mV. When omitted, ``self.V.value`` is used. Any
        shape that broadcasts against the threshold state is accepted. Pass
        a synthetic value to force an output: a value above theta gives a
        spike, a value below theta gives none.

    Returns
    -------
    ArrayLike
        ``spk_fun((V - theta) / max(|theta_eq|, 1))``. Values near 1.0 mean a
        spike and values near 0.0 mean no spike, with gradients supplied by
        the surrogate function.

    Notes
    -----
    **1. Normalization**

    .. math::

        v_{scaled} = \frac{V - \theta}{\max(|\theta_{eq}|, 1)}

    ``v_scaled`` is zero exactly at threshold and positive above it. The
    lower bound of 1 on the denominator avoids blowing up the input when
    theta_eq is close to zero.

    **2. Surrogate Function**

    ``s = spk_fun(v_scaled)``. The forward pass is typically a step
    function; the backward pass uses the surrogate's smooth derivative.

    **3. Relation to ``update()``**

    Discrete spike detection, reset and refractoriness happen inside
    ``update()``. That method then calls ``get_spike()`` with a synthetic V
    chosen from its discrete spike mask to produce the returned output.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy.state as bps
        >>> import saiunit as u
        >>> import brainstate
        >>>
        >>> neuron = bps.ht_neuron(1)
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     neuron.init_all_states()
        ...     print(neuron.get_spike())   # V ≈ -70 < theta ≈ -51 → ≈ 0.0
        ...     neuron.V.value = -45.0      # above theta
        ...     print(neuron.get_spike())   # ≈ 1.0 (surrogate dependent)
    """
    if V is None:
        V = self.V.value
    threshold = self.theta.value
    # Normalize by |theta_eq|, but never by less than 1 mV.
    denominator = max(abs(self.theta_eq), 1.0)
    return self.spk_fun((V - threshold) / denominator)
def _vector_field(self, state, extra):
    """Vectorized ODE right-hand side for all neurons simultaneously.

    All quantities are dimensionless floats on the mV / ms scale; no
    saiunit units are involved.

    Parameters
    ----------
    state : DotDict
        Keys: V_m, theta, DG_AMPA, G_AMPA, DG_NMDA, G_NMDA,
        DG_GABA_A, G_GABA_A, DG_GABA_B, G_GABA_B,
        m_fast_NMDA, m_slow_NMDA, m_Ih, D_IKNa, m_IT, h_IT
        -- ODE state variables.
    extra : DotDict
        Keys: spike_mask, r, unstable, i_stim, V_clamp_val
        -- auxiliary data carried through the integrator. Only ``r``
        (refractory counter), ``i_stim`` and ``V_clamp_val`` are read here.

    Returns
    -------
    DotDict with the same keys as ``state``, containing time derivatives.

    Notes
    -----
    With ``voltage_clamp`` enabled, all currents are evaluated at the clamp
    potential and the V_m derivative is set to zero, as NEST does. The
    membrane potential is reset to the clamp value after every accepted
    substep anyway (see ``_event_fn``). A non-zero derivative would
    therefore only distort the adaptive step-size error estimate.
    """
    is_refractory = extra.r > 0
    V = jnp.where(self.voltage_clamp, extra.V_clamp_val, state.V_m)

    # NMDA conductance with instantaneous blocking (cap m_fast, m_slow at m_eq)
    m_eq_nmda = 1.0 / (1.0 + jnp.exp(-self.S_act_NMDA * (V - self.V_act_NMDA)))
    mf = jnp.minimum(m_eq_nmda, state.m_fast_NMDA)
    ms = jnp.minimum(m_eq_nmda, state.m_slow_NMDA)
    if self.instant_unblock_NMDA:
        m_nmda = m_eq_nmda
    else:
        # Voltage-dependent weighting of fast vs. slow unblocking components
        A1 = 0.51 - 0.0028 * V
        A2 = 1.0 - A1
        m_nmda = A1 * mf + A2 * ms

    # Synaptic currents: I = -g * (V - E)
    I_syn = (
        -state.G_AMPA * (V - self.E_rev_AMPA)
        - state.G_NMDA * m_nmda * (V - self.E_rev_NMDA)
        - state.G_GABA_A * (V - self.E_rev_GABA_A)
        - state.G_GABA_B * (V - self.E_rev_GABA_B)
    )

    # Post-spike repolarizing K current (only during refractoriness)
    I_spike = jnp.where(is_refractory, -(V - self.E_K) / self.tau_spike, 0.0)

    # Leak currents
    I_Na = -self.g_NaL * (V - self.E_Na)
    I_K_leak = -self.g_KL * (V - self.E_K)

    # I_NaP (persistent sodium)
    INaP_thresh = -55.7
    INaP_slope = 7.7
    m_inf_NaP = 1.0 / (1.0 + jnp.exp(-(V - INaP_thresh) / INaP_slope))
    i_NaP = -self.g_peak_NaP * (m_inf_NaP ** self.N_NaP) * (V - self.E_rev_NaP)

    # I_KNa (depolarization-activated K); guard the power against D <= 0
    d_half = 0.25
    d_val = state.D_IKNa
    m_inf_KNa = jnp.where(
        d_val > 0,
        1.0 / (1.0 + (d_half / jnp.maximum(d_val, 1e-30)) ** 3.5),
        0.0
    )
    i_KNa = -self.g_peak_KNa * m_inf_KNa * (V - self.E_rev_KNa)

    # I_T (low-threshold Ca)
    i_T = -self.g_peak_T * (state.m_IT ** self.N_T) * state.h_IT * (V - self.E_rev_T)

    # I_h (hyperpolarization-activated)
    i_h = -self.g_peak_h * state.m_Ih * (V - self.E_rev_h)

    # dV/dt; held at zero under voltage clamp (matches NEST)
    dV_raw = (I_Na + I_K_leak + I_syn + i_NaP + i_KNa + i_T + i_h + extra.i_stim) / self.tau_m + I_spike
    dV = jnp.where(self.voltage_clamp, 0.0, dV_raw)

    # d(theta)/dt: relaxation to theta_eq
    d_theta = -(state.theta - self.theta_eq) / self.tau_theta

    # AMPA synapse
    d_DG_AMPA = -state.DG_AMPA / self.tau_rise_AMPA
    d_G_AMPA = state.DG_AMPA - state.G_AMPA / self.tau_decay_AMPA

    # NMDA synapse and Mg²⁺ unblocking
    d_DG_NMDA = -state.DG_NMDA / self.tau_rise_NMDA
    d_G_NMDA = state.DG_NMDA - state.G_NMDA / self.tau_decay_NMDA
    d_m_fast_NMDA = (m_eq_nmda - mf) / self.tau_Mg_fast_NMDA
    d_m_slow_NMDA = (m_eq_nmda - ms) / self.tau_Mg_slow_NMDA

    # GABA_A synapse
    d_DG_GABA_A = -state.DG_GABA_A / self.tau_rise_GABA_A
    d_G_GABA_A = state.DG_GABA_A - state.G_GABA_A / self.tau_decay_GABA_A

    # GABA_B synapse
    d_DG_GABA_B = -state.DG_GABA_B / self.tau_rise_GABA_B
    d_G_GABA_B = state.DG_GABA_B - state.G_GABA_B / self.tau_decay_GABA_B

    # I_KNa D variable (sodium-influx driven)
    D_influx_peak = 0.025
    D_thresh = -10.0
    D_slope = 5.0
    D_eq = 0.001
    D_influx = D_influx_peak / (1.0 + jnp.exp(-(V - D_thresh) / D_slope))
    D_eq_val = self.tau_D_KNa * D_influx + D_eq
    d_D_IKNa = (D_eq_val - state.D_IKNa) / self.tau_D_KNa

    # I_T gating
    tau_m_T = 0.22 / (jnp.exp(-(V + 132.0) / 16.7) + jnp.exp((V + 16.8) / 18.2)) + 0.13
    tau_h_T = 8.2 + (56.6 + 0.27 * jnp.exp((V + 115.2) / 5.0)) / (1.0 + jnp.exp((V + 86.0) / 3.2))
    m_eq_t = 1.0 / (1.0 + jnp.exp(-(V + 59.0) / 6.2))
    h_eq_t = 1.0 / (1.0 + jnp.exp((V + 83.0) / 4.0))
    d_m_IT = (m_eq_t - state.m_IT) / tau_m_T
    d_h_IT = (h_eq_t - state.h_IT) / tau_h_T

    # I_h gating
    tau_m_h = 1.0 / (jnp.exp(-14.59 - 0.086 * V) + jnp.exp(-1.87 + 0.0701 * V))
    I_h_Vthreshold = -75.0
    m_eq_ih = 1.0 / (1.0 + jnp.exp((V - I_h_Vthreshold) / 5.5))
    d_m_Ih = (m_eq_ih - state.m_Ih) / tau_m_h

    return DotDict(
        V_m=dV, theta=d_theta,
        DG_AMPA=d_DG_AMPA, G_AMPA=d_G_AMPA,
        DG_NMDA=d_DG_NMDA, G_NMDA=d_G_NMDA,
        DG_GABA_A=d_DG_GABA_A, G_GABA_A=d_G_GABA_A,
        DG_GABA_B=d_DG_GABA_B, G_GABA_B=d_G_GABA_B,
        m_fast_NMDA=d_m_fast_NMDA, m_slow_NMDA=d_m_slow_NMDA,
        m_Ih=d_m_Ih, D_IKNa=d_D_IKNa,
        m_IT=d_m_IT, h_IT=d_h_IT,
    )
def _event_fn(self, state, extra, accept):
    """Post-substep hook: stability flag, clamp, NMDA cap, spike and reset.

    Parameters
    ----------
    state : DotDict
        Keys: V_m, theta, DG_AMPA, G_AMPA, DG_NMDA, G_NMDA,
        DG_GABA_A, G_GABA_A, DG_GABA_B, G_GABA_B,
        m_fast_NMDA, m_slow_NMDA, m_Ih, D_IKNa, m_IT, h_IT
        -- ODE state variables after the RK substep.
    extra : DotDict
        Keys: spike_mask, r, unstable, i_stim, V_clamp_val.
    accept : array, bool
        Mask of neurons whose RK substep was accepted.

    Returns
    -------
    (new_state, new_extra)
        ``new_state`` has V_m, theta, m_fast_NMDA and m_slow_NMDA replaced.
        ``new_extra`` has spike_mask, r and unstable replaced. All other
        entries are carried over unchanged.

    Notes
    -----
    The following happens in this order:

    1. Flag instability if any accepted voltage leaves [-1e3, 1e3]. A single
       such neuron sets the population-wide scalar flag.
    2. Under voltage clamp, overwrite V with the clamp value for accepted
       neurons.
    3. Cap both NMDA unblocking variables at m_eq evaluated at that (clamped)
       V.
    4. Detect spikes: the substep is accepted, the neuron is not refractory
       (r <= 0) and V >= theta.
    5. On a spike, set V and theta to E_Na. When ref_count > 0, set
       r = ref_count + 1.
    """
    V_acc = state.V_m

    # (1) divergence check restricted to accepted substeps
    diverged = jnp.any(accept & ((V_acc < -1e3) | (V_acc > 1e3)))
    unstable = extra.unstable | diverged

    # (2) voltage clamp only overrides accepted substeps
    V_post = jnp.where(self.voltage_clamp & accept, extra.V_clamp_val, V_acc)

    # (3) Mg²⁺ unblocking may not overshoot its instantaneous equilibrium
    m_eq = 1.0 / (1.0 + jnp.exp(-self.S_act_NMDA * (V_post - self.V_act_NMDA)))
    m_fast_capped = jnp.minimum(m_eq, state.m_fast_NMDA)
    m_slow_capped = jnp.minimum(m_eq, state.m_slow_NMDA)

    # (4) threshold crossing outside refractoriness
    fired = accept & (extra.r <= 0) & (V_post >= state.theta)

    # (5) reset and refractory bookkeeping
    refractory = jnp.where(fired & (self.ref_count > 0), self.ref_count + 1, extra.r)
    state_updates = {
        'V_m': jnp.where(fired, self.E_Na, V_post),
        'theta': jnp.where(fired, self.E_Na, state.theta),
        'm_fast_NMDA': m_fast_capped,
        'm_slow_NMDA': m_slow_capped,
    }
    extra_updates = {
        'spike_mask': extra.spike_mask | fired,
        'r': refractory,
        'unstable': unstable,
    }
    return DotDict({**state, **state_updates}), DotDict({**extra, **extra_updates})
[docs]
def update(self, x=0.0):
    r"""Advance the neuron state by one simulation step with adaptive ODE integration.

    One update consists of the following stages:

    1. Adaptive RKF45 integration of the 16-dimensional ODE system, with
       in-loop spike detection, reset and refractory handling.
    2. The refractory counter is decremented.
    3. Synaptic spike inputs are delivered.
    4. The external current is buffered for the next step.

    The sequence is modeled on NEST's ``ht_neuron::update()``.

    Parameters
    ----------
    x : float or ArrayLike, default=0.0
        External stimulation current. Conductances in this model are
        dimensionless, so currents carry voltage units (mV). The value is
        added to the summed membrane currents before division by tau_m.
        It can be a scalar applied to all neurons, or an array that
        broadcasts to the population shape. The value is stored in
        ``I_stim`` and only affects dV/dt from the *next* update on.

    Returns
    -------
    ArrayLike
        Surrogate spike output with the population shape: ≈1.0 for neurons
        that fired during this step and ≈0.0 otherwise.

    Notes
    -----
    **1. Update Sequence**

    **Step 1: ODE Integration.** The state vector

    .. math::

        \mathbf{y} = [V, \theta, DG_{AMPA}, G_{AMPA}, DG_{NMDA}, G_{NMDA},
                      DG_{GABA_A}, G_{GABA_A}, DG_{GABA_B}, G_{GABA_B},
                      m_{fast}^{NMDA}, m_{slow}^{NMDA}, m_{Ih}, D_{IKNa},
                      m_{IT}, h_{IT}]

    is integrated from t to t+dt with AdaptiveRungeKuttaStep. The vector
    field (``_vector_field``) covers:

    - membrane currents (leak, synaptic, intrinsic, stimulation);
    - post-spike repolarization while refractory;
    - threshold relaxation;
    - beta-function synapses;
    - NMDA Mg²⁺ unblocking;
    - I_h / I_T / I_KNa gating.

    **Step 2: In-loop constraints.** After every accepted substep:

    - if voltage_clamp=True, V is reset to ``_V_clamp``;
    - m_fast_NMDA and m_slow_NMDA are capped at m_∞^NMDA(V).

    **Step 3: Spike detection and reset.** If ``ref_steps <= 0`` and
    ``V >= theta`` on an accepted substep, the neuron spikes:

    - V → E_Na and θ → E_Na;
    - ref_steps → round(t_ref / dt) + 1, applied only when that count is
      positive.

    **Step 4: Refractory decrement.** Positive counters are decremented by 1
    after integration.

    **Step 5: Synaptic input delivery.** Delta inputs labelled 'AMPA',
    'NMDA', 'GABA_A' and 'GABA_B' are added to the matching DG variables:

    .. math::

        DG_{receptor} \mathrel{+}= g_{peak,receptor} \cdot \text{norm} \cdot \text{input}

    Unlabeled delta inputs are routed to AMPA.

    **Step 6: Stimulation buffering.** ``x`` is stored in ``I_stim`` for
    the next update (one-step delay).

    **2. Refractory Dynamics**

    While refractory, the neuron cannot spike and the post-spike potassium
    current is active:

    .. math::

        I_{spike} = -\frac{V - E_K}{\tau_{spike}}

    **3. Synaptic Input Routing**

    .. code-block:: python

        post.add_delta_input(weight * pre_spike, label='AMPA')

    **4. Spike Output**

    ``get_spike`` measures V relative to the current (post-update) threshold.
    The synthetic voltage passed to it is therefore theta + 1 for neurons
    that fired and theta - 1 for neurons that did not. This keeps the
    forward output consistent with the discrete spike mask even though theta
    is reset to E_Na on a spike. The synthetic voltage is wrapped in
    ``stop_gradient``, so the surrogate gradient still flows through theta
    inside ``get_spike``.

    **5. Numerical Considerations**

    - The RKF45 integrator adapts its internal step to ``gsl_error_tol``. The
      last step size is carried across calls in ``integration_step``.
    - All neurons are integrated together with vectorized JAX operations.
    - The recorded intrinsic currents (I_NaP, I_KNa, I_T, I_h) are a
      snapshot evaluated at the post-integration state.
    - A numerical-instability error is raised when any accepted voltage
      leaves [-1e3, 1e3] mV.

    Warnings
    --------
    - **Unlabeled inputs default to AMPA.** Synaptic inputs sent without a
      receptor label are treated as AMPA input.
    """
    t = brainstate.environ.get('t')
    dt = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # ---- Steps 1-3: adaptive integration with in-loop events ----
    ode_state = DotDict(
        V_m=self.V.value, theta=self.theta.value,
        DG_AMPA=self.DG_AMPA.value, G_AMPA=self.G_AMPA.value,
        DG_NMDA=self.DG_NMDA.value, G_NMDA=self.G_NMDA.value,
        DG_GABA_A=self.DG_GABA_A.value, G_GABA_A=self.G_GABA_A.value,
        DG_GABA_B=self.DG_GABA_B.value, G_GABA_B=self.G_GABA_B.value,
        m_fast_NMDA=self.m_fast_NMDA_state.value,
        m_slow_NMDA=self.m_slow_NMDA_state.value,
        m_Ih=self.m_Ih_state.value, D_IKNa=self.D_IKNa_state.value,
        m_IT=self.m_IT_state.value, h_IT=self.h_IT_state.value,
    )
    extra = DotDict(
        spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
        r=self.ref_steps.value,
        unstable=jnp.array(False),
        i_stim=self.I_stim.value,
        V_clamp_val=jnp.full(self.varshape, self._V_clamp, dtype=dftype),
    )
    h = self.integration_step.value
    ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
    spike_mask = extra.spike_mask

    brainstate.transform.jit_error_if(
        jnp.any(extra.unstable), 'Numerical instability in ht_neuron dynamics.'
    )

    # ---- Step 4: refractory decrement ----
    r = jnp.where(extra.r > 0, extra.r - 1, extra.r)

    # ---- Step 5: synaptic spike inputs (labelled first, leftovers -> AMPA) ----
    zeros = jnp.zeros(self.varshape, dtype=dftype)
    spk_ampa = self.sum_delta_inputs(zeros, label='AMPA')
    spk_nmda = self.sum_delta_inputs(zeros, label='NMDA')
    spk_gaba_a = self.sum_delta_inputs(zeros, label='GABA_A')
    spk_gaba_b = self.sum_delta_inputs(zeros, label='GABA_B')
    spk_ampa = spk_ampa + self.sum_delta_inputs(zeros)

    DG_AMPA_out = ode_state.DG_AMPA + self._cond_step_AMPA * spk_ampa
    DG_NMDA_out = ode_state.DG_NMDA + self._cond_step_NMDA * spk_nmda
    DG_GABA_A_out = ode_state.DG_GABA_A + self._cond_step_GABA_A * spk_gaba_a
    DG_GABA_B_out = ode_state.DG_GABA_B + self._cond_step_GABA_B * spk_gaba_b

    # ---- Intrinsic currents for recording (post-integration snapshot) ----
    V_final = ode_state.V_m
    INaP_thresh = -55.7
    INaP_slope = 7.7
    m_inf_NaP = 1.0 / (1.0 + jnp.exp(-(V_final - INaP_thresh) / INaP_slope))
    I_NaP_final = -self.g_peak_NaP * (m_inf_NaP ** self.N_NaP) * (V_final - self.E_rev_NaP)
    d_half = 0.25
    d_val = ode_state.D_IKNa
    m_inf_KNa = jnp.where(
        d_val > 0,
        1.0 / (1.0 + (d_half / jnp.maximum(d_val, 1e-30)) ** 3.5),
        0.0
    )
    I_KNa_final = -self.g_peak_KNa * m_inf_KNa * (V_final - self.E_rev_KNa)
    I_T_final = -self.g_peak_T * (ode_state.m_IT ** self.N_T) * ode_state.h_IT * (V_final - self.E_rev_T)
    I_h_final = -self.g_peak_h * ode_state.m_Ih * (V_final - self.E_rev_h)

    # ---- Step 6: buffer external current (one-step delay) ----
    new_i_stim = jnp.broadcast_to(jnp.asarray(x, dtype=dftype), self.varshape)

    # ---- Write back state ----
    self.V.value = ode_state.V_m
    self.theta.value = ode_state.theta
    self.DG_AMPA.value = DG_AMPA_out
    self.G_AMPA.value = ode_state.G_AMPA
    self.DG_NMDA.value = DG_NMDA_out
    self.G_NMDA.value = ode_state.G_NMDA
    self.DG_GABA_A.value = DG_GABA_A_out
    self.G_GABA_A.value = ode_state.G_GABA_A
    self.DG_GABA_B.value = DG_GABA_B_out
    self.G_GABA_B.value = ode_state.G_GABA_B
    self.m_fast_NMDA_state.value = ode_state.m_fast_NMDA
    self.m_slow_NMDA_state.value = ode_state.m_slow_NMDA
    self.m_Ih_state.value = ode_state.m_Ih
    self.D_IKNa_state.value = ode_state.D_IKNa
    self.m_IT_state.value = ode_state.m_IT
    self.h_IT_state.value = ode_state.h_IT

    self.I_NaP_val.value = I_NaP_final
    self.I_KNa_val.value = I_KNa_final
    self.I_T_val.value = I_T_final
    self.I_h_val.value = I_h_final

    self.ref_steps.value = jnp.asarray(r, dtype=ditype)
    self.integration_step.value = h
    self.I_stim.value = new_i_stim

    last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)

    # ---- Surrogate spike output ----
    # get_spike() evaluates (V - theta) with the theta written above. After
    # a spike theta sits near E_Na, so absolute constants would be
    # misclassified: they gave ~0 for firing neurons and ~1 for resting ones.
    # Offsetting by +/-1 around theta gives the correct forward value.
    # stop_gradient keeps the gradient path through theta inside get_spike
    # intact.
    theta_now = ode_state.theta
    V_spike = jax.lax.stop_gradient(
        jnp.where(spike_mask, theta_now + 1.0, theta_now - 1.0)
    )
    return self.get_spike(V_spike)