# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep
# Public API of this module: only the neuron model class is exported; the
# equilibrium helper below is private.
__all__ = [
    "hh_psc_alpha",
]
def _hh_psc_alpha_equilibrium(V):
    r"""Compute equilibrium values of Hodgkin-Huxley gating variables.

    Calculates steady-state activation and inactivation at a given membrane
    potential using the voltage-dependent rate functions from NEST's
    ``hh_psc_alpha`` model. These equilibria are used for state initialization
    when explicit gating values are not provided.

    Parameters
    ----------
    V : float or array_like
        Membrane potential in mV (unitless, not a ``saiunit`` quantity).
        Array input is evaluated element-wise.

    Returns
    -------
    tuple
        ``(m_inf, h_inf, n_inf)`` — equilibrium values (dimensionless) for
        Na activation, Na inactivation, and K activation, respectively.
        Each is in [0, 1]. For scalar ``V`` the entries are NumPy scalars;
        for array ``V`` they are arrays with the shape of ``V``.

    Notes
    -----
    Uses the Hodgkin-Huxley rate functions with sign conventions matching
    NEST's implementation:

    .. math::

       \alpha_n = \frac{0.01(V + 55)}{1 - e^{-(V+55)/10}}, \quad
       \beta_n = 0.125 e^{-(V+65)/80}

    .. math::

       \alpha_m = \frac{0.1(V + 40)}{1 - e^{-(V+40)/10}}, \quad
       \beta_m = 4 e^{-(V+65)/18}

    .. math::

       \alpha_h = 0.07 e^{-(V+65)/20}, \quad
       \beta_h = \frac{1}{1 + e^{-(V+35)/10}}

    Equilibrium is :math:`x_\infty = \alpha_x / (\alpha_x + \beta_x)`.

    :math:`\alpha_n` and :math:`\alpha_m` have removable singularities
    (0/0) at :math:`V = -55` mV and :math:`V = -40` mV. Close to those points
    the series of the limit is used, giving the continuous values
    :math:`\alpha_n(-55) = 0.1` and :math:`\alpha_m(-40) = 1.0` instead of NaN.
    """
    V = np.asarray(V, dtype=float)

    def _linexp(coef, x, scale):
        # coef * x / (1 - exp(-x / scale)). For |x / scale| < 1e-3 use the series
        # coef * scale * (1 + y/2 + y**2/12), y = x / scale (truncation error
        # ~ y**4 / 720, below float64 rounding). The closed form is fed a guarded
        # x so the unused branch never evaluates 0/0.
        y = x / scale
        small = np.abs(y) < 1e-3
        x_safe = np.where(small, scale, x)
        closed_form = (coef * x_safe) / (1.0 - np.exp(-x_safe / scale))
        series = coef * scale * (1.0 + y / 2.0 + y * y / 12.0)
        return np.where(small, series, closed_form)

    alpha_n = _linexp(0.01, V + 55.0, 10.0)
    beta_n = 0.125 * np.exp(-(V + 65.0) / 80.0)
    alpha_m = _linexp(0.1, V + 40.0, 10.0)
    beta_m = 4.0 * np.exp(-(V + 65.0) / 18.0)
    alpha_h = 0.07 * np.exp(-(V + 65.0) / 20.0)
    beta_h = 1.0 / (1.0 + np.exp(-(V + 35.0) / 10.0))

    m_inf = alpha_m / (alpha_m + beta_m)
    h_inf = alpha_h / (alpha_h + beta_h)
    n_inf = alpha_n / (alpha_n + beta_n)
    # ``[()]`` unwraps 0-d results to NumPy scalars and leaves arrays unchanged.
    return m_inf[()], h_inf[()], n_inf[()]
class hh_psc_alpha(NESTNeuron):
    r"""NEST-compatible Hodgkin-Huxley neuron with alpha-shaped postsynaptic currents.

    Current-based spiking neuron using the Hodgkin-Huxley formalism with
    voltage-gated sodium and potassium channels, leak conductance, alpha-function
    postsynaptic currents, threshold-and-local-maximum spike detection, and an
    explicit refractory period that suppresses spike emission only (subthreshold
    dynamics continue freely). Follows the NEST ``models/hh_psc_alpha.{h,cpp}``
    implementation with adaptive Runge-Kutta-Fehlberg (RKF45) integration.

    **1. Mathematical Model**

    **Membrane and ionic current dynamics:**

    The membrane potential evolves as

    .. math::

       C_m \frac{dV_m}{dt} = -(I_{Na} + I_K + I_L) + I_{stim} + I_e
                             + I_{syn,ex} + I_{syn,in}

    where

    .. math::

       I_{Na} &= g_{Na}\, m^3\, h\, (V_m - E_{Na}) \\
       I_K &= g_K\, n^4\, (V_m - E_K) \\
       I_L &= g_L\, (V_m - E_L)

    Gating variables :math:`m` (Na activation), :math:`h` (Na inactivation),
    :math:`n` (K activation) obey first-order kinetics:

    .. math::

       \frac{dx}{dt} = \alpha_x(V)(1 - x) - \beta_x(V)\,x

    with voltage-dependent rate functions (voltage :math:`V` in mV, rates in 1/ms):

    .. math::

       \alpha_n &= \frac{0.01\,(V + 55)}{1 - e^{-(V+55)/10}}, \quad
       \beta_n = 0.125\,e^{-(V+65)/80} \\
       \alpha_m &= \frac{0.1\,(V + 40)}{1 - e^{-(V+40)/10}}, \quad
       \beta_m = 4\,e^{-(V+65)/18} \\
       \alpha_h &= 0.07\,e^{-(V+65)/20}, \quad
       \beta_h = \frac{1}{1 + e^{-(V+35)/10}}

    **Alpha-function synaptic currents:**

    Each synapse type (excitatory/inhibitory) is modelled as a second-order
    linear system producing an alpha-shaped postsynaptic current:

    .. math::

       \frac{dI_{syn}}{dt} &= dI_{syn} - \frac{I_{syn}}{\tau_{syn}} \\
       \frac{d(dI_{syn})}{dt} &= -\frac{dI_{syn}}{\tau_{syn}}

    A spike arriving with weight :math:`w` (in pA) adds
    :math:`w \cdot e / \tau_{syn}` to :math:`dI_{syn}`, so that the resulting
    PSC peaks at amplitude :math:`w` at :math:`t = \tau_{syn}` after arrival.
    Incoming spike weights are split by sign: positive weights drive the
    excitatory state (:math:`dI_{syn,ex}`), negative weights drive the
    inhibitory state (:math:`dI_{syn,in}`).

    **2. Spike Detection and Refractory Handling**

    A spike is emitted when the membrane potential is at or above 0 mV **and**
    has just passed a local maximum (the potential decreased over the last
    step). Formally, a spike is emitted when:

    1. ``refractory_step_count == 0`` (not in refractory period), **and**
    2. ``V_m >= 0 mV`` (threshold), **and**
    3. ``V_old > V_m`` (local maximum — potential is now falling).

    Unlike integrate-and-fire models, **no voltage reset occurs**. The potassium
    current naturally repolarizes the membrane after a spike. During the
    refractory period :math:`t_{ref}`, spike emission is suppressed but all
    state variables continue evolving according to their differential equations.

    **3. Update Order Per Simulation Step**

    The update follows NEST's order:

    1. Record pre-integration membrane potential (``V_old``).
    2. Integrate the full 8-dimensional ODE system
       :math:`(V_m, m, h, n, dI_{ex}, I_{ex}, dI_{in}, I_{in})` over one
       time step :math:`[t, t+dt]` using adaptive RKF45 (Runge-Kutta-Fehlberg).
    3. Add arriving synaptic spike inputs to :math:`dI_{syn,ex}` and
       :math:`dI_{syn,in}`.
    4. Check spike condition: ``V_m >= 0 and V_old > V_m and r == 0``.
    5. Update refractory counter and record spike time.
    6. Store buffered external stimulation current for the next step.

    **4. Numerical Integration**

    Uses ``AdaptiveRungeKuttaStep`` with method ``'RKF45'`` to match NEST's
    GSL RKF45 adaptive integrator. Default tolerance is ``gsl_error_tol=1e-3``.
    All neurons are integrated simultaneously in a vectorized fashion.
    :math:`\alpha_n` and :math:`\alpha_m` have removable singularities (0/0)
    at :math:`V = -55` mV and :math:`V = -40` mV; within 0.01 mV of these
    points their series limit is evaluated instead, so the state never becomes
    NaN there.

    **5. Assumptions, Constraints, and Computational Implications**

    - ``C_m > 0``, ``g_Na >= 0``, ``g_K >= 0``, ``g_L >= 0``,
      ``tau_syn_ex > 0``, ``tau_syn_in > 0``, ``t_ref >= 0`` and
      ``gsl_error_tol > 0`` are enforced at construction (skipped when any of
      them is a JAX tracer).
    - External current ``update(x=...)`` is buffered for one step, matching
      NEST ring-buffer semantics.
    - The adaptive RKF45 integrator performs vectorized integration across all
      neurons simultaneously using JAX operations.
    - Spike detection uses a local maximum criterion rather than threshold
      crossing alone, matching biological action potential dynamics.

    Parameters
    ----------
    in_size : Size
        Population shape specification. All per-neuron parameters are broadcast
        to ``self.varshape`` derived from ``in_size``.
    E_L : ArrayLike, optional
        Leak reversal potential :math:`E_L` in mV; scalar or array broadcastable
        to ``self.varshape``. Determines resting potential. Default is
        ``-54.402 * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable to ``self.varshape``
        and strictly positive. Default is ``100. * u.pF``.
    g_Na : ArrayLike, optional
        Sodium peak conductance :math:`g_{Na}` in nS; broadcastable to
        ``self.varshape`` and non-negative. Default is ``12000. * u.nS``.
    g_K : ArrayLike, optional
        Potassium peak conductance :math:`g_K` in nS; broadcastable to
        ``self.varshape`` and non-negative. Default is ``3600. * u.nS``.
    g_L : ArrayLike, optional
        Leak conductance :math:`g_L` in nS; broadcastable to ``self.varshape``
        and non-negative. Default is ``30. * u.nS``.
    E_Na : ArrayLike, optional
        Sodium reversal potential :math:`E_{Na}` in mV; broadcastable to
        ``self.varshape``. Default is ``50. * u.mV``.
    E_K : ArrayLike, optional
        Potassium reversal potential :math:`E_K` in mV; broadcastable to
        ``self.varshape``. Default is ``-77. * u.mV``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_{ref}` in ms; broadcastable to
        ``self.varshape`` and non-negative. Converted to integer step counts by
        ``ceil(t_ref / dt)`` (with a tiny relative tolerance so that
        floating-point noise does not add an extra step). Default is
        ``2. * u.ms``.
    tau_syn_ex : ArrayLike, optional
        Excitatory alpha time constant :math:`\tau_{syn,ex}` in ms; broadcastable
        to ``self.varshape`` and strictly positive. Default is ``0.2 * u.ms``.
    tau_syn_in : ArrayLike, optional
        Inhibitory alpha time constant :math:`\tau_{syn,in}` in ms; broadcastable
        to ``self.varshape`` and strictly positive. Default is ``2. * u.ms``.
    I_e : ArrayLike, optional
        Constant injected current :math:`I_e` in pA; scalar or array broadcastable
        to ``self.varshape``. Default is ``0. * u.pA``.
    V_m_init : ArrayLike, optional
        Initial membrane potential in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-65. * u.mV``.
    Act_m_init : ArrayLike or None, optional
        Initial Na activation gating variable (dimensionless, range [0,1]);
        broadcastable to ``self.varshape``. If ``None``, initialized to
        equilibrium value at ``V_m_init``. Default is ``None``.
    Inact_h_init : ArrayLike or None, optional
        Initial Na inactivation gating variable (dimensionless, range [0,1]);
        broadcastable to ``self.varshape``. If ``None``, initialized to
        equilibrium value at ``V_m_init``. Default is ``None``.
    Act_n_init : ArrayLike or None, optional
        Initial K activation gating variable (dimensionless, range [0,1]);
        broadcastable to ``self.varshape``. If ``None``, initialized to
        equilibrium value at ``V_m_init``. Default is ``None``.
    spk_fun : Callable, optional
        Surrogate spike nonlinearity used by :meth:`get_spike` for differentiable
        spike generation. Default is ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy inherited from :class:`~brainpy_state._base.Neuron`.
        ``'hard'`` applies stop-gradient to match NEST hard reset semantics.
        This model itself never resets the membrane potential.
        Default is ``'hard'``.
    gsl_error_tol : float, optional
        Unitless local RKF45 error tolerance, strictly positive.
        Default is ``1e-3``.
    name : str or None, optional
        Optional node name for debugging and visualization.

    Parameter Mapping
    -----------------

    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 27 14 16 36

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines neuron population shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-54.402 * u.mV``
         - :math:`E_L`
         - Leak reversal potential (resting potential).
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``100. * u.pF``
         - :math:`C_m`
         - Membrane capacitance.
       * - ``g_Na``
         - ArrayLike, broadcastable (nS), ``>= 0``
         - ``12000. * u.nS``
         - :math:`g_{Na}`
         - Sodium peak conductance.
       * - ``g_K``
         - ArrayLike, broadcastable (nS), ``>= 0``
         - ``3600. * u.nS``
         - :math:`g_K`
         - Potassium peak conductance.
       * - ``g_L``
         - ArrayLike, broadcastable (nS), ``>= 0``
         - ``30. * u.nS``
         - :math:`g_L`
         - Leak conductance.
       * - ``E_Na``
         - ArrayLike, broadcastable (mV)
         - ``50. * u.mV``
         - :math:`E_{Na}`
         - Sodium reversal potential.
       * - ``E_K``
         - ArrayLike, broadcastable (mV)
         - ``-77. * u.mV``
         - :math:`E_K`
         - Potassium reversal potential.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory period duration.
       * - ``tau_syn_ex``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``0.2 * u.ms``
         - :math:`\tau_{syn,ex}`
         - Excitatory alpha-kernel time constant.
       * - ``tau_syn_in``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``2. * u.ms``
         - :math:`\tau_{syn,in}`
         - Inhibitory alpha-kernel time constant.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant external current added to the membrane equation.
       * - ``V_m_init``
         - ArrayLike, broadcastable (mV)
         - ``-65. * u.mV``
         - --
         - Initial membrane potential.
       * - ``Act_m_init``
         - ArrayLike or ``None``, dimensionless
         - ``None``
         - --
         - Initial Na activation; ``None`` uses equilibrium at ``V_m_init``.
       * - ``Inact_h_init``
         - ArrayLike or ``None``, dimensionless
         - ``None``
         - --
         - Initial Na inactivation; ``None`` uses equilibrium at ``V_m_init``.
       * - ``Act_n_init``
         - ArrayLike or ``None``, dimensionless
         - ``None``
         - --
         - Initial K activation; ``None`` uses equilibrium at ``V_m_init``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate gradient function for spike generation.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode; ``'hard'`` stops gradient through reset.
       * - ``gsl_error_tol``
         - float, ``> 0``
         - ``1e-3``
         - --
         - Local absolute tolerance for the embedded RKF45 error estimate.

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential :math:`V_m`. Shape: ``(*in_size,)``.
        Units: mV.
    m : brainstate.HiddenState
        Na activation gating variable (dimensionless). Shape: ``(*in_size,)``.
        Range: [0, 1].
    h : brainstate.HiddenState
        Na inactivation gating variable (dimensionless). Shape: ``(*in_size,)``.
        Range: [0, 1].
    n : brainstate.HiddenState
        K activation gating variable (dimensionless). Shape: ``(*in_size,)``.
        Range: [0, 1].
    I_syn_ex : brainstate.HiddenState
        Excitatory postsynaptic current :math:`I_{syn,ex}`. Shape: ``(*in_size,)``.
        Units: pA.
    I_syn_in : brainstate.HiddenState
        Inhibitory postsynaptic current :math:`I_{syn,in}`. Shape: ``(*in_size,)``.
        Units: pA.
    dI_syn_ex : brainstate.ShortTermState
        Excitatory alpha-kernel derivative state. Shape: ``(*in_size,)``.
        Units: pA/ms.
    dI_syn_in : brainstate.ShortTermState
        Inhibitory alpha-kernel derivative state. Shape: ``(*in_size,)``.
        Units: pA/ms.
    I_stim : brainstate.ShortTermState
        One-step delayed external current buffer. Shape: ``(*in_size,)``.
        Units: pA.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory steps. Shape: ``(*in_size,)``. Dtype: int32.
    integration_step : brainstate.ShortTermState
        Persistent RKF45 substep size estimate (ms).
    last_spike_time : brainstate.ShortTermState
        Time of most recent spike emission. Shape: ``(*in_size,)``.
        Units: ms. Updated to ``t + dt`` on spike emission.
    ref_count : ArrayLike
        Refractory period expressed in simulation steps (integer).

    Raises
    ------
    ValueError
        If any of the following conditions are violated:

        - ``C_m <= 0``
        - ``t_ref < 0``
        - ``tau_syn_ex <= 0`` or ``tau_syn_in <= 0``
        - ``g_Na < 0``, ``g_K < 0``, or ``g_L < 0``
        - ``gsl_error_tol <= 0``

    Notes
    -----
    - Unlike IAF models, the HH model does **not** reset the membrane potential
      after a spike. Repolarization occurs naturally through the potassium current.
    - During the refractory period, the neuron's subthreshold dynamics continue
      to evolve freely; only spike emission is suppressed.
    - Spike weights are interpreted as current amplitudes (pA). Positive weights
      are excitatory; negative weights are inhibitory.
    - The adaptive RKF45 integrator evaluates the ODE right-hand side multiple
      times per step, so computation cost scales with desired accuracy (controlled
      by ``gsl_error_tol``).
    - Spike detection combines threshold (0 mV) and local maximum
      detection, matching the biological action potential waveform.

    References
    ----------
    .. [1] Hodgkin AL, Huxley AF (1952). A quantitative description of membrane
           current and its application to conduction and excitation in nerve.
           The Journal of Physiology 117:500-544.
           DOI: https://doi.org/10.1113/jphysiol.1952.sp004764
    .. [2] Gerstner W, Kistler W (2002). Spiking neuron models: Single neurons,
           populations, plasticity. Cambridge University Press.
    .. [3] Dayan P, Abbott LF (2001). Theoretical neuroscience: Computational
           and mathematical modeling of neural systems. MIT Press.
    .. [4] NEST Simulator Documentation. hh_psc_alpha neuron model.
           https://nest-simulator.readthedocs.io/en/stable/models/hh_psc_alpha.html

    See Also
    --------
    iaf_psc_alpha : Leaky integrate-and-fire with alpha-shaped PSCs.
    hh_psc_alpha_clopath : HH neuron with Clopath voltage-based STDP.
    hh_psc_alpha_gap : HH neuron with gap junction support.

    Examples
    --------
    Create a single Hodgkin-Huxley neuron and observe spiking behavior under
    constant current injection. ``update`` reads the current time ``t`` from
    the ``brainstate`` environment, so each step runs inside an environment
    context that sets it:

    .. code-block:: python

        >>> import brainstate as bst
        >>> import saiunit as u
        >>> import brainpy.state as bps
        >>> import matplotlib.pyplot as plt
        >>> # Initialize simulation context
        >>> dt = 0.1 * u.ms
        >>> bst.environ.set(dt=dt)
        >>> # Create neuron
        >>> neuron = bps.hh_psc_alpha(in_size=1, I_e=500. * u.pA)
        >>> neuron.init_all_states()
        >>> # Run simulation
        >>> times = []
        >>> voltages = []
        >>> for i in range(2000):  # 200 ms
        ...     t = i * dt
        ...     with bst.environ.context(t=t):
        ...         neuron.update()
        ...     times.append(float(t / u.ms))
        ...     voltages.append(float(neuron.V.value[0] / u.mV))
        >>> # Plot results
        >>> plt.plot(times, voltages)
        >>> plt.xlabel('Time (ms)')
        >>> plt.ylabel('Membrane potential (mV)')
        >>> plt.title('Hodgkin-Huxley neuron spiking')
        >>> plt.show()
    """
    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 100000
    # Relative slack applied to t_ref / dt before rounding up to whole steps;
    # far below NEST's tic resolution, large enough to absorb float32 rounding.
    _REF_STEP_RTOL = 1e-6

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -54.402 * u.mV,
        C_m: ArrayLike = 100. * u.pF,
        g_Na: ArrayLike = 12000. * u.nS,
        g_K: ArrayLike = 3600. * u.nS,
        g_L: ArrayLike = 30. * u.nS,
        E_Na: ArrayLike = 50. * u.mV,
        E_K: ArrayLike = -77. * u.mV,
        t_ref: ArrayLike = 2. * u.ms,
        tau_syn_ex: ArrayLike = 0.2 * u.ms,
        tau_syn_in: ArrayLike = 2. * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_m_init: ArrayLike = -65. * u.mV,
        Act_m_init: ArrayLike = None,
        Inact_h_init: ArrayLike = None,
        Act_n_init: ArrayLike = None,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        gsl_error_tol: float = 1e-3,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.g_Na = braintools.init.param(g_Na, self.varshape)
        self.g_K = braintools.init.param(g_K, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.E_Na = braintools.init.param(E_Na, self.varshape)
        self.E_K = braintools.init.param(E_K, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_m_init = V_m_init
        self.Act_m_init = Act_m_init
        self.Inact_h_init = Inact_h_init
        self.Act_n_init = Act_n_init
        self.gsl_error_tol = gsl_error_tol

        self._validate_parameters()

        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=brainstate.environ.get_dt()
        )

        # Refractory period in whole steps, rounded up like NEST's
        # ``Time::get_steps``. The relative slack absorbs floating-point noise in
        # ``t_ref / dt`` (e.g. ``1.1 / 0.1 == 11.000000000000002``), which would
        # otherwise add a spurious extra refractory step.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        ref_steps = (self.t_ref / dt) * (1.0 - self._REF_STEP_RTOL)
        self.ref_count = u.math.asarray(u.math.ceil(ref_steps), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate parameter constraints at construction time.

        Validation is skipped when any checked value is a JAX tracer (e.g.
        during ``jit``), because the comparisons below need concrete values.

        Raises
        ------
        ValueError
            If ``C_m <= 0``, ``t_ref < 0``, ``tau_syn_ex <= 0``, ``tau_syn_in <= 0``,
            any conductance is negative, or ``gsl_error_tol <= 0``.
        """
        checked = (
            self.C_m, self.t_ref, self.tau_syn_ex, self.tau_syn_in,
            self.g_Na, self.g_K, self.g_L, self.gsl_error_tol,
        )
        if any(is_tracer(v) for v in checked):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.g_Na < 0.0 * u.nS) or np.any(self.g_K < 0.0 * u.nS) or np.any(
                self.g_L < 0.0 * u.nS):
            raise ValueError('All conductances must be non-negative.')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

    @staticmethod
    def _as_unitless(value):
        """Strip units from ``value``; scalars become Python floats, arrays NumPy arrays.

        Scalars are returned as ``float`` (as before). Arrays are kept so that
        per-neuron initial values can be broadcast to ``self.varshape``.
        """
        arr = np.asarray(u.get_mantissa(u.math.asarray(value)), dtype=float)
        return float(arr) if arr.ndim == 0 else arr

    def init_state(self, **kwargs):
        r"""Initialize all state variables.

        Sets initial values for membrane potential, gating variables, synaptic
        currents, refractory counters, and buffers. If gating variable initial
        values are not explicitly provided, they are computed at equilibrium for
        the given initial membrane potential.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments (unused, for compatibility).

        Notes
        -----
        State variables initialized:

        - ``V``: membrane potential (from ``V_m_init``)
        - ``m``, ``h``, ``n``: gating variables (from ``Act_m_init``,
          ``Inact_h_init``, ``Act_n_init`` if provided; otherwise computed
          at equilibrium for ``V_m_init``). Scalar and array-valued
          (per-neuron) initial values are both accepted.
        - ``I_syn_ex``, ``I_syn_in``, ``dI_syn_ex``, ``dI_syn_in``: synaptic
          states (initialized to zero)
        - ``I_stim``: external current buffer (initialized to zero)
        - ``refractory_step_count``: refractory countdown (initialized to zero)
        - ``integration_step``: persistent RKF45 substep size (initialized to ``dt``)
        - ``last_spike_time``: spike time record (initialized to -1e7 ms)
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        # Equilibrium gating values at the initial potential (mV, unitless);
        # element-wise when V_m_init is an array.
        V_init_mV = self._as_unitless(self.V_m_init / u.mV)
        m_eq, h_eq, n_eq = _hh_psc_alpha_equilibrium(V_init_mV)

        m_init = m_eq if self.Act_m_init is None else self._as_unitless(self.Act_m_init)
        h_init = h_eq if self.Inact_h_init is None else self._as_unitless(self.Inact_h_init)
        n_init = n_eq if self.Act_n_init is None else self._as_unitless(self.Act_n_init)

        V = braintools.init.param(braintools.init.Constant(self.V_m_init), self.varshape)
        self.V = brainstate.HiddenState(V)
        self.m = brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(m_init), self.varshape)
        )
        self.h = brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(h_init), self.varshape)
        )
        self.n = brainstate.HiddenState(
            braintools.init.param(braintools.init.Constant(n_init), self.varshape)
        )

        zeros_pA_per_ms = u.math.zeros(self.varshape, dtype=dftype) * (u.pA / u.ms)
        self.dI_syn_ex = brainstate.ShortTermState(zeros_pA_per_ms)
        self.dI_syn_in = brainstate.ShortTermState(zeros_pA_per_ms)
        self.I_syn_ex = brainstate.HiddenState(u.math.zeros(self.varshape, dtype=dftype) * u.pA)
        self.I_syn_in = brainstate.HiddenState(u.math.zeros(self.varshape, dtype=dftype) * u.pA)
        self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))

    def get_spike(self, V: ArrayLike = None):
        r"""Generate differentiable spike output using surrogate gradient.

        Applies the surrogate spike function to the membrane potential scaled
        relative to the 0 mV threshold. This enables gradient-based learning
        through the spike generation process.

        Parameters
        ----------
        V : ArrayLike or None, optional
            Membrane potential in mV. If ``None``, uses ``self.V.value``.
            Shape must broadcast with ``self.varshape``.

        Returns
        -------
        ArrayLike
            Output of ``self.spk_fun`` applied to ``V / (1 mV)``, with shape
            ``(*in_size,)``.

        Notes
        -----
        The spike threshold for HH neurons is 0 mV, so the voltage is only
        made unitless (divided by 1 mV) before applying the surrogate function.
        """
        V = self.V.value if V is None else V
        # For HH neurons, spike threshold is 0 mV. Scale relative to 0 mV.
        v_scaled = V / (1. * u.mV)
        return self.spk_fun(v_scaled)

    @staticmethod
    def _linexp_rate(coef, x, scale):
        r"""Evaluate ``coef * x / (1 - exp(-x / scale))`` without the 0/0 at ``x == 0``.

        :math:`\alpha_n` and :math:`\alpha_m` have this form with a removable
        singularity where ``x = V + 55`` (resp. ``V + 40``) vanishes. For
        ``|x / scale| < 1e-3`` the series ``coef * scale * (1 + y/2 + y**2/12)``
        with ``y = x / scale`` is used (truncation error ~ ``y**4 / 720``);
        elsewhere the closed form is evaluated exactly as NEST does. The closed
        form receives a guarded ``x`` so neither the value nor its gradient
        picks up NaN from the unused branch.
        """
        y = x / scale
        small = u.math.abs(y) < 1e-3
        x_safe = u.math.where(small, scale, x)
        closed_form = (coef * x_safe) / (1.0 - u.math.exp(-x_safe / scale))
        series = coef * scale * (1.0 + y / 2.0 + y * y / 12.0)
        return u.math.where(small, series, closed_form)

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, m, h, n, dI_ex, I_ex, dI_in, I_in — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, V_old, i_stim — mutable auxiliary data
            carried through the integrator. Only ``i_stim`` is read here.

        Returns
        -------
        DotDict with same keys as ``state``, containing time derivatives.
        """
        V = state.V
        m_ = state.m
        h_ = state.h
        n_ = state.n
        dI_ex = state.dI_ex
        I_ex = state.I_ex
        dI_in = state.dI_in
        I_in = state.I_in

        # Voltage in mV for rate functions (unitless computation)
        V_mV = V / u.mV
        alpha_n = self._linexp_rate(0.01, V_mV + 55.0, 10.0)
        beta_n = 0.125 * u.math.exp(-(V_mV + 65.0) / 80.0)
        alpha_m = self._linexp_rate(0.1, V_mV + 40.0, 10.0)
        beta_m = 4.0 * u.math.exp(-(V_mV + 65.0) / 18.0)
        alpha_h = 0.07 * u.math.exp(-(V_mV + 65.0) / 20.0)
        beta_h = 1.0 / (1.0 + u.math.exp(-(V_mV + 35.0) / 10.0))

        # Ionic currents (nS * mV = pA)
        I_Na = self.g_Na * m_ * m_ * m_ * h_ * (V - self.E_Na)
        I_K = self.g_K * n_ * n_ * n_ * n_ * (V - self.E_K)
        I_L = self.g_L * (V - self.E_L)

        # Membrane voltage derivative: dV/dt = (-(I_Na + I_K + I_L) + I_stim + I_e + I_ex + I_in) / C_m
        dV = (-(I_Na + I_K + I_L) + extra.i_stim + self.I_e + I_ex + I_in) / self.C_m

        # Gating variable derivatives (rates are in 1/ms)
        dm = (alpha_m * (1.0 - m_) - beta_m * m_) / u.ms
        dh = (alpha_h * (1.0 - h_) - beta_h * h_) / u.ms
        dn = (alpha_n * (1.0 - n_) - beta_n * n_) / u.ms

        # Alpha-kernel synaptic current derivatives
        ddI_ex = -dI_ex / self.tau_syn_ex
        dI_ex_dt = dI_ex - I_ex / self.tau_syn_ex
        ddI_in = -dI_in / self.tau_syn_in
        dI_in_dt = dI_in - I_in / self.tau_syn_in

        return DotDict(V=dV, m=dm, h=dh, n=dn, dI_ex=ddI_ex, I_ex=dI_ex_dt, dI_in=ddI_in, I_in=dI_in_dt)

    def _event_fn(self, state, extra, accept):
        """In-loop event hook for the adaptive integrator.

        The HH model has no voltage reset and performs spike detection in
        :meth:`update` *after* integration, so this hook leaves ``state``
        untouched. It only records the latest accepted substep voltage in
        ``extra.V_old``. Note that :meth:`update` uses its own pre-integration
        ``V_old`` for the local-maximum test, so this tracked value is not
        consumed there.

        Parameters
        ----------
        state : DotDict
            Keys: V, m, h, n, dI_ex, I_ex, dI_in, I_in — ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, V_old, i_stim.
        accept : array, bool
            Mask of neurons whose RK substep was accepted.

        Returns
        -------
        (new_state, new_extra) DotDicts; ``new_state`` is ``state`` unchanged.
        """
        new_V_old = u.math.where(accept, state.V, extra.V_old)
        new_extra = DotDict({**extra, 'V_old': new_V_old})
        return state, new_extra

    def update(self, x=0. * u.pA):
        r"""Update neuron state for one simulation step.

        Integrates the full Hodgkin-Huxley dynamics over one time step :math:`dt`,
        applies synaptic inputs, detects spikes using threshold-and-local-maximum
        criterion, updates refractory state, and buffers external current for the
        next step. Follows the NEST ``hh_psc_alpha`` update order.

        **Update Order:**

        1. Record pre-integration membrane potential (``V_old``).
        2. Integrate the 8-dimensional ODE system
           :math:`(V_m, m, h, n, dI_{ex}, I_{ex}, dI_{in}, I_{in})` over
           :math:`[t, t+dt]` using adaptive RKF45 (Runge-Kutta-Fehlberg).
        3. Add arriving synaptic spike inputs to :math:`dI_{syn,ex}` and
           :math:`dI_{syn,in}`.
        4. Check spike condition:
           ``refractory_step_count == 0 and V_m >= 0 and V_old > V_m``.
        5. Update refractory counter and record spike time.
        6. Store buffered external stimulation current ``x`` for next step.

        Parameters
        ----------
        x : ArrayLike, optional
            External stimulation current input in pA (in addition to ``I_e``).
            Shape must broadcast with ``(*in_size,)``.
            Default is ``0. * u.pA``.

        Returns
        -------
        ArrayLike
            Binary spike indicator with shape ``(*in_size,)`` and the
            environment float dtype: 1.0 where a spike was emitted this step,
            0.0 otherwise. It is the boolean spike mask cast to float, not the
            output of ``self.spk_fun``; use :meth:`get_spike` for a
            surrogate-differentiable signal.

        Notes
        -----
        - The current time ``t`` is read from the ``brainstate`` environment.
        - The external current ``x`` is buffered for one step via ``I_stim``,
          matching NEST's ring-buffer semantics. Current provided at step
          :math:`n` affects dynamics at step :math:`n+1`.
        - Spike weights are collected via ``sum_delta_inputs(0*pA)`` and split
          by sign: positive weights drive excitatory state, negative weights
          drive inhibitory state.
        - During the refractory period, all state variables evolve freely;
          only spike emission is suppressed.
        - Spike detection combines threshold (0 mV) and local maximum
          detection (``V_old > V_m``) to match biological action potential
          characteristics.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        m = self.m.value  # dimensionless
        h_val = self.h.value  # dimensionless
        n = self.n.value  # dimensionless
        dI_ex = self.dI_syn_ex.value  # pA/ms
        I_ex = self.I_syn_ex.value  # pA
        dI_in = self.dI_syn_in.value  # pA/ms
        I_in = self.I_syn_in.value  # pA
        r = self.refractory_step_count.value  # int
        i_stim = self.I_stim.value  # pA
        h_step = self.integration_step.value  # ms

        # Record V before integration for spike detection
        V_old = V

        # Current input for next step (one-step delay).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # Adaptive RKF45 integration via generic integrator.
        ode_state = DotDict(V=V, m=m, h=h_val, n=n, dI_ex=dI_ex, I_ex=I_ex, dI_in=dI_in, I_in=I_in)
        extra = DotDict(
            spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
            r=r,
            V_old=V_old,
            i_stim=i_stim,
        )
        ode_state, h_step, extra = self.integrator(state=ode_state, h=h_step, extra=extra)
        V = ode_state.V
        m = ode_state.m
        h_val = ode_state.h
        n = ode_state.n
        dI_ex = ode_state.dI_ex
        I_ex = ode_state.I_ex
        dI_in = ode_state.dI_in
        I_in = ode_state.I_in

        # Synaptic spike inputs (applied after integration, matching NEST).
        w_all = self.sum_delta_inputs(0. * u.pA)
        w_ex = u.math.where(w_all > 0.0 * u.pA, w_all, 0.0 * u.pA)
        w_in = u.math.where(w_all < 0.0 * u.pA, w_all, 0.0 * u.pA)

        # PSC normalization: an increment of w * e / tau in dI makes the alpha
        # PSC peak at amplitude w (reached tau after the spike).
        pscon_ex = np.e / self.tau_syn_ex  # 1/ms
        pscon_in = np.e / self.tau_syn_in  # 1/ms

        # Apply synaptic spike inputs.
        dI_ex = dI_ex + pscon_ex * w_ex  # pA/ms + 1/ms * pA = pA/ms
        dI_in = dI_in + pscon_in * w_in  # pA/ms + 1/ms * pA = pA/ms

        # Spike detection: threshold + local maximum
        not_refractory = r == 0
        crossed_threshold = V >= 0.0 * u.mV
        local_max = V_old > V
        spike_mask = not_refractory & crossed_threshold & local_max

        # Refractory update: reload on spike, otherwise count down to zero.
        r_new = u.math.where(spike_mask, self.ref_count, u.math.where(r > 0, r - 1, r))

        # Write back state.
        self.V.value = V
        self.m.value = m
        self.h.value = h_val
        self.n.value = n
        self.dI_syn_ex.value = dI_ex
        self.I_syn_ex.value = I_ex
        self.dI_syn_in.value = dI_in
        self.I_syn_in.value = I_in
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r_new), dtype=ditype)
        self.integration_step.value = h_step
        # Adding zeros broadcasts a scalar current input to the population shape.
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        return u.math.asarray(spike_mask, dtype=dftype)