# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep
# Public names exported by ``from <this module> import *``.
__all__ = [
'hh_psc_alpha_clopath',
]
def _hh_psc_alpha_clopath_equilibrium(V):
r"""Compute equilibrium values of Hodgkin-Huxley gating variables for Clopath model.
Calculates steady-state activation and inactivation at a given membrane
potential using the voltage-dependent rate functions from NEST's
``hh_psc_alpha_clopath`` model. These equilibria are used for state initialization
when explicit gating values are not provided.
Parameters
----------
V : float
Membrane potential in mV (unitless, not a ``saiunit`` quantity).
Returns
-------
tuple of float
``(m_inf, h_inf, n_inf)`` — equilibrium values (dimensionless) for
Na activation, Na inactivation, and K activation, respectively.
Each is in [0, 1].
Notes
-----
Uses the Hodgkin-Huxley rate functions with sign conventions matching
NEST's implementation:
.. math::
\alpha_n = \frac{0.01(V + 55)}{1 - e^{-(V+55)/10}}, \quad
\beta_n = 0.125 e^{-(V+65)/80}
.. math::
\alpha_m = \frac{0.1(V + 40)}{1 - e^{-(V+40)/10}}, \quad
\beta_m = 4 e^{-(V+65)/18}
.. math::
\alpha_h = 0.07 e^{-(V+65)/20}, \quad
\beta_h = \frac{1}{1 + e^{-(V+35)/10}}
Equilibrium is :math:`x_\infty = \alpha_x / (\alpha_x + \beta_x)`.
"""
alpha_n = (0.01 * (V + 55.0)) / (1.0 - np.exp(-(V + 55.0) / 10.0))
beta_n = 0.125 * np.exp(-(V + 65.0) / 80.0)
alpha_m = (0.1 * (V + 40.0)) / (1.0 - np.exp(-(V + 40.0) / 10.0))
beta_m = 4.0 * np.exp(-(V + 65.0) / 18.0)
alpha_h = 0.07 * np.exp(-(V + 65.0) / 20.0)
beta_h = 1.0 / (1.0 + np.exp(-(V + 35.0) / 10.0))
m_inf = alpha_m / (alpha_m + beta_m)
h_inf = alpha_h / (alpha_h + beta_h)
n_inf = alpha_n / (alpha_n + beta_n)
return m_inf, h_inf, n_inf
class hh_psc_alpha_clopath(NESTNeuron):
r"""NEST-compatible Hodgkin-Huxley neuron with Clopath plasticity support.
Current-based spiking neuron using the Hodgkin-Huxley formalism with
voltage-gated sodium and potassium channels, leak conductance, alpha-function
postsynaptic currents, threshold-and-local-maximum spike detection, and three
additional low-pass filtered voltage traces for Clopath voltage-based STDP.
Follows NEST ``models/hh_psc_alpha_clopath.{h,cpp}`` implementation with
adaptive Runge-Kutta-Fehlberg integration (RKF45).
**1. Mathematical Model**
**Membrane and ionic current dynamics:**
The membrane potential evolves as
.. math::
C_m \frac{dV_m}{dt} = -(I_{Na} + I_K + I_L) + I_{stim} + I_e
+ I_{syn,ex} + I_{syn,in}
where
.. math::
I_{Na} &= g_{Na}\, m^3\, h\, (V_m - E_{Na}) \\
I_K &= g_K\, n^4\, (V_m - E_K) \\
I_L &= g_L\, (V_m - E_L)
Gating variables :math:`m` (Na activation), :math:`h` (Na inactivation),
:math:`n` (K activation) obey first-order kinetics:
.. math::
\frac{dx}{dt} = \alpha_x(V)(1 - x) - \beta_x(V)\,x
with voltage-dependent rate functions (voltage :math:`V` in mV, rates in 1/ms):
.. math::
\alpha_n &= \frac{0.01\,(V + 55)}{1 - e^{-(V+55)/10}}, \quad
\beta_n = 0.125\,e^{-(V+65)/80} \\
\alpha_m &= \frac{0.1\,(V + 40)}{1 - e^{-(V+40)/10}}, \quad
\beta_m = 4\,e^{-(V+65)/18} \\
\alpha_h &= 0.07\,e^{-(V+65)/20}, \quad
\beta_h = \frac{1}{1 + e^{-(V+35)/10}}
**Clopath low-pass filtered voltage traces:**
The model extends standard ``hh_psc_alpha`` with three additional state
variables for Clopath voltage-based plasticity:
.. math::
\frac{d\bar{u}_+}{dt} &= \frac{-\bar{u}_+ + V_m}{\tau_{\bar{u}_+}} \\
\frac{d\bar{u}_-}{dt} &= \frac{-\bar{u}_- + V_m}{\tau_{\bar{u}_-}} \\
\frac{d\bar{\bar{u}}}{dt} &= \frac{-\bar{\bar{u}} + \bar{u}_-}{\tau_{\bar{\bar{u}}}}
- :math:`\bar{u}_+` (``u_bar_plus``) is a slow-filtered voltage with time
constant :math:`\tau_{\bar{u}_+} = 114` ms, used for LTP induction.
- :math:`\bar{u}_-` (``u_bar_minus``) is a fast-filtered voltage with time
constant :math:`\tau_{\bar{u}_-} = 10` ms, used for LTD induction.
- :math:`\bar{\bar{u}}` (``u_bar_bar``) is a second-stage slow filter of
:math:`\bar{u}_-` with time constant :math:`\tau_{\bar{\bar{u}}} = 500` ms,
used for homeostatic sliding threshold in the Clopath rule.
These traces are integrated as part of the same 11-dimensional ODE system
and are accessible to connected Clopath synapse models for computing
voltage-dependent weight updates.
**Alpha-function synaptic currents:**
Each synapse type (excitatory/inhibitory) is modelled as a second-order
linear system producing an alpha-shaped postsynaptic current:
.. math::
\frac{dI_{syn}}{dt} &= dI_{syn} - \frac{I_{syn}}{\tau_{syn}} \\
\frac{d(dI_{syn})}{dt} &= -\frac{dI_{syn}}{\tau_{syn}}
A spike arriving with weight :math:`w` (in pA) adds
:math:`w \cdot e / \tau_{syn}` to :math:`dI_{syn}`, so that the resulting
postsynaptic current peaks at :math:`w`, a time :math:`\tau_{syn}` after the
spike. Incoming spike weights are
split by sign: positive weights drive excitatory state (:math:`dI_{syn,ex}`),
negative weights drive inhibitory state (:math:`dI_{syn,in}`).
**2. Spike Detection and Refractory Handling**
A spike is detected when the membrane potential is at or above 0 mV
**and** has just passed a local maximum (i.e., the potential has started to decrease).
Formally, a spike is emitted when:
1. ``refractory_step_count == 0`` (not in refractory period), **and**
2. ``V_m >= 0 mV`` (threshold crossing), **and**
3. ``V_old > V_m`` (local maximum — potential is now falling).
Unlike integrate-and-fire models, **no voltage reset occurs**. The potassium
current naturally repolarizes the membrane after a spike. During the
refractory period :math:`t_{ref}`, spike emission is suppressed but all
state variables (including the Clopath filtered voltages) continue evolving
according to their differential equations.
**3. Update Order Per Simulation Step**
The update follows NEST's exact order:
1. Record pre-integration membrane potential (``V_old``).
2. Integrate the full 11-dimensional ODE system
:math:`(V_m, m, h, n, dI_{ex}, I_{ex}, dI_{in}, I_{in}, \bar{u}_+, \bar{u}_-, \bar{\bar{u}})`
over one time step :math:`[t, t+dt]` using adaptive RKF45 (Runge-Kutta-Fehlberg).
3. Add arriving synaptic spike inputs to :math:`dI_{syn,ex}` and
:math:`dI_{syn,in}`.
4. Check spike condition: ``V_m >= 0 and V_old > V_m and r == 0``.
5. Update refractory counter and record spike time.
6. Store buffered external stimulation current for the next step.
**4. Numerical Integration**
Uses a JAX-based adaptive RKF45 integrator via
:class:`~brainpy_state._nest._utils.AdaptiveRungeKuttaStep` to match
NEST's GSL RKF45 adaptive integrator. Default tolerance is
``gsl_error_tol=1e-6``. All neurons are integrated simultaneously in a
vectorized fashion.
**5. Assumptions, Constraints, and Computational Implications**
- ``C_m > 0``, ``g_Na >= 0``, ``g_K >= 0``, ``g_L >= 0``,
``tau_syn_ex > 0``, ``tau_syn_in > 0``, ``tau_u_bar_plus > 0``,
``tau_u_bar_minus > 0``, ``tau_u_bar_bar > 0``, and ``t_ref >= 0``
are enforced at construction.
- External current ``update(x=...)`` is buffered for one step, matching
NEST ring-buffer semantics.
- The adaptive RKF45 integrator performs vectorized integration across
all neurons simultaneously, enabling efficient GPU acceleration.
- Spike detection uses a local maximum criterion rather than threshold
crossing alone, matching biological action potential dynamics.
- The three Clopath voltage traces add computational overhead (11 ODE
dimensions instead of the 8 of ``hh_psc_alpha``, a 37.5% increase), but enable
voltage-based plasticity without requiring additional post-hoc filtering.
Parameters
----------
in_size : Size
Population shape specification. All per-neuron parameters are broadcast
to ``self.varshape`` derived from ``in_size``.
E_L : ArrayLike, optional
Leak reversal potential :math:`E_L` in mV; scalar or array broadcastable
to ``self.varshape``. Determines resting potential. Default is
``-54.402 * u.mV``.
C_m : ArrayLike, optional
Membrane capacitance :math:`C_m` in pF; broadcastable to ``self.varshape``
and strictly positive. Default is ``100. * u.pF``.
g_Na : ArrayLike, optional
Sodium peak conductance :math:`g_{Na}` in nS; broadcastable to
``self.varshape`` and non-negative. Default is ``12000. * u.nS``.
g_K : ArrayLike, optional
Potassium peak conductance :math:`g_K` in nS; broadcastable to
``self.varshape`` and non-negative. Default is ``3600. * u.nS``.
g_L : ArrayLike, optional
Leak conductance :math:`g_L` in nS; broadcastable to ``self.varshape``
and non-negative. Default is ``30. * u.nS``.
E_Na : ArrayLike, optional
Sodium reversal potential :math:`E_{Na}` in mV; broadcastable to
``self.varshape``. Default is ``50. * u.mV``.
E_K : ArrayLike, optional
Potassium reversal potential :math:`E_K` in mV; broadcastable to
``self.varshape``. Default is ``-77. * u.mV``.
t_ref : ArrayLike, optional
Absolute refractory period :math:`t_{ref}` in ms; broadcastable to
``self.varshape`` and non-negative. Converted to integer step counts by
``ceil(t_ref / dt)``. Default is ``2. * u.ms``.
tau_syn_ex : ArrayLike, optional
Excitatory alpha time constant :math:`\tau_{syn,ex}` in ms; broadcastable
to ``self.varshape`` and strictly positive. Default is ``0.2 * u.ms``.
tau_syn_in : ArrayLike, optional
Inhibitory alpha time constant :math:`\tau_{syn,in}` in ms; broadcastable
to ``self.varshape`` and strictly positive. Default is ``2. * u.ms``.
I_e : ArrayLike, optional
Constant injected current :math:`I_e` in pA; scalar or array broadcastable
to ``self.varshape``. Default is ``0. * u.pA``.
tau_u_bar_plus : ArrayLike, optional
Time constant :math:`\tau_{\bar{u}_+}` in ms for slow voltage filter
:math:`\bar{u}_+` (used in Clopath LTP); broadcastable to ``self.varshape``
and strictly positive. Default is ``114. * u.ms``.
tau_u_bar_minus : ArrayLike, optional
Time constant :math:`\tau_{\bar{u}_-}` in ms for fast voltage filter
:math:`\bar{u}_-` (used in Clopath LTD); broadcastable to ``self.varshape``
and strictly positive. Default is ``10. * u.ms``.
tau_u_bar_bar : ArrayLike, optional
Time constant :math:`\tau_{\bar{\bar{u}}}` in ms for second-stage slow
filter :math:`\bar{\bar{u}}` (used in Clopath homeostatic threshold);
broadcastable to ``self.varshape`` and strictly positive. Default is
``500. * u.ms``.
V_m_init : ArrayLike, optional
Initial membrane potential in mV; broadcastable to ``self.varshape``.
Default is ``-65. * u.mV``.
Act_m_init : ArrayLike or None, optional
Initial Na activation gating variable (dimensionless, range [0,1]);
broadcastable to ``self.varshape``. If ``None``, initialized to
equilibrium value at ``V_m_init``. Default is ``None``.
Inact_h_init : ArrayLike or None, optional
Initial Na inactivation gating variable (dimensionless, range [0,1]);
broadcastable to ``self.varshape``. If ``None``, initialized to
equilibrium value at ``V_m_init``. Default is ``None``.
Act_n_init : ArrayLike or None, optional
Initial K activation gating variable (dimensionless, range [0,1]);
broadcastable to ``self.varshape``. If ``None``, initialized to
equilibrium value at ``V_m_init``. Default is ``None``.
u_bar_plus_init : ArrayLike, optional
Initial value for :math:`\bar{u}_+` in mV; broadcastable to
``self.varshape``. Default is ``0. * u.mV``.
u_bar_minus_init : ArrayLike, optional
Initial value for :math:`\bar{u}_-` in mV; broadcastable to
``self.varshape``. Default is ``0. * u.mV``.
u_bar_bar_init : ArrayLike, optional
Initial value for :math:`\bar{\bar{u}}` in mV; broadcastable to
``self.varshape``. Default is ``0. * u.mV``.
gsl_error_tol : ArrayLike, optional
Unitless local RKF45 error tolerance, broadcastable and strictly positive.
Default is ``1e-6``.
spk_fun : Callable, optional
Surrogate spike nonlinearity used by :meth:`get_spike` for differentiable
spike generation. Default is ``braintools.surrogate.ReluGrad()``.
spk_reset : str, optional
Reset policy inherited from :class:`~brainpy_state._base.Neuron`.
``'hard'`` applies stop-gradient to match NEST hard reset semantics.
Default is ``'hard'``.
name : str or None, optional
Optional node name for debugging and visualization.
Parameter Mapping
-----------------
.. list-table:: Parameter mapping to model symbols
:header-rows: 1
:widths: 17 27 14 16 36
* - Parameter
- Type / shape / unit
- Default
- Math symbol
- Semantics
* - ``in_size``
- :class:`~brainstate.typing.Size`; scalar/tuple
- required
- --
- Defines neuron population shape ``self.varshape``.
* - ``E_L``
- ArrayLike, broadcastable to ``self.varshape`` (mV)
- ``-54.402 * u.mV``
- :math:`E_L`
- Leak reversal potential (resting potential).
* - ``C_m``
- ArrayLike, broadcastable (pF), ``> 0``
- ``100. * u.pF``
- :math:`C_m`
- Membrane capacitance.
* - ``g_Na``
- ArrayLike, broadcastable (nS), ``>= 0``
- ``12000. * u.nS``
- :math:`g_{Na}`
- Sodium peak conductance.
* - ``g_K``
- ArrayLike, broadcastable (nS), ``>= 0``
- ``3600. * u.nS``
- :math:`g_K`
- Potassium peak conductance.
* - ``g_L``
- ArrayLike, broadcastable (nS), ``>= 0``
- ``30. * u.nS``
- :math:`g_L`
- Leak conductance.
* - ``E_Na``
- ArrayLike, broadcastable (mV)
- ``50. * u.mV``
- :math:`E_{Na}`
- Sodium reversal potential.
* - ``E_K``
- ArrayLike, broadcastable (mV)
- ``-77. * u.mV``
- :math:`E_K`
- Potassium reversal potential.
* - ``t_ref``
- ArrayLike, broadcastable (ms), ``>= 0``
- ``2. * u.ms``
- :math:`t_{ref}`
- Absolute refractory period duration.
* - ``tau_syn_ex``
- ArrayLike, broadcastable (ms), ``> 0``
- ``0.2 * u.ms``
- :math:`\tau_{syn,ex}`
- Excitatory alpha-kernel time constant.
* - ``tau_syn_in``
- ArrayLike, broadcastable (ms), ``> 0``
- ``2. * u.ms``
- :math:`\tau_{syn,in}`
- Inhibitory alpha-kernel time constant.
* - ``I_e``
- ArrayLike, broadcastable (pA)
- ``0. * u.pA``
- :math:`I_e`
- Constant external current added every step.
* - ``tau_u_bar_plus``
- ArrayLike, broadcastable (ms), ``> 0``
- ``114. * u.ms``
- :math:`\tau_{\bar{u}_+}`
- Time constant for slow voltage filter (Clopath LTP).
* - ``tau_u_bar_minus``
- ArrayLike, broadcastable (ms), ``> 0``
- ``10. * u.ms``
- :math:`\tau_{\bar{u}_-}`
- Time constant for fast voltage filter (Clopath LTD).
* - ``tau_u_bar_bar``
- ArrayLike, broadcastable (ms), ``> 0``
- ``500. * u.ms``
- :math:`\tau_{\bar{\bar{u}}}`
- Time constant for second-stage filter (Clopath homeostasis).
* - ``V_m_init``
- ArrayLike, broadcastable (mV)
- ``-65. * u.mV``
- --
- Initial membrane potential.
* - ``Act_m_init``
- ArrayLike or ``None``, dimensionless
- ``None``
- --
- Initial Na activation; ``None`` uses equilibrium at ``V_m_init``.
* - ``Inact_h_init``
- ArrayLike or ``None``, dimensionless
- ``None``
- --
- Initial Na inactivation; ``None`` uses equilibrium at ``V_m_init``.
* - ``Act_n_init``
- ArrayLike or ``None``, dimensionless
- ``None``
- --
- Initial K activation; ``None`` uses equilibrium at ``V_m_init``.
* - ``u_bar_plus_init``
- ArrayLike, broadcastable (mV)
- ``0. * u.mV``
- --
- Initial value for :math:`\bar{u}_+`.
* - ``u_bar_minus_init``
- ArrayLike, broadcastable (mV)
- ``0. * u.mV``
- --
- Initial value for :math:`\bar{u}_-`.
* - ``u_bar_bar_init``
- ArrayLike, broadcastable (mV)
- ``0. * u.mV``
- --
- Initial value for :math:`\bar{\bar{u}}`.
* - ``gsl_error_tol``
- ArrayLike, broadcastable, unitless, ``> 0``
- ``1e-6``
- --
- Local absolute tolerance for the embedded RKF45 error estimate.
* - ``spk_fun``
- Callable
- ``ReluGrad()``
- --
- Surrogate gradient function for spike generation.
* - ``spk_reset``
- str
- ``'hard'``
- --
- Reset mode; ``'hard'`` stops gradient through reset.
Attributes
----------
V : brainstate.HiddenState
Membrane potential :math:`V_m`. Shape: ``self.varshape``.
Units: mV.
m : brainstate.HiddenState
Na activation gating variable (dimensionless). Shape: ``self.varshape``.
Range: [0, 1].
h : brainstate.HiddenState
Na inactivation gating variable (dimensionless). Shape: ``self.varshape``.
Range: [0, 1].
n : brainstate.HiddenState
K activation gating variable (dimensionless). Shape: ``self.varshape``.
Range: [0, 1].
I_syn_ex : brainstate.ShortTermState
Excitatory postsynaptic current :math:`I_{syn,ex}`. Shape: ``self.varshape``.
Units: pA.
I_syn_in : brainstate.ShortTermState
Inhibitory postsynaptic current :math:`I_{syn,in}`. Shape: ``self.varshape``.
Units: pA.
dI_syn_ex : brainstate.ShortTermState
Excitatory alpha-kernel derivative state. Shape: ``self.varshape``.
Units: pA/ms.
dI_syn_in : brainstate.ShortTermState
Inhibitory alpha-kernel derivative state. Shape: ``self.varshape``.
Units: pA/ms.
u_bar_plus : brainstate.HiddenState
Slow-filtered voltage :math:`\bar{u}_+` for Clopath LTP. Shape: ``self.varshape``.
Units: mV.
u_bar_minus : brainstate.HiddenState
Fast-filtered voltage :math:`\bar{u}_-` for Clopath LTD. Shape: ``self.varshape``.
Units: mV.
u_bar_bar : brainstate.HiddenState
Second-stage filtered voltage :math:`\bar{\bar{u}}` for Clopath homeostasis.
Shape: ``self.varshape``. Units: mV.
I_stim : brainstate.ShortTermState
One-step delayed external current buffer. Shape: ``self.varshape``.
Units: pA.
refractory_step_count : brainstate.ShortTermState
Remaining refractory steps. Shape: ``self.varshape``. Dtype: int32.
last_spike_time : brainstate.ShortTermState
Time of most recent spike emission. Shape: ``self.varshape``.
Units: ms. Updated to ``t + dt`` on spike emission.
integration_step : brainstate.ShortTermState
Persistent RKF45 substep size estimate (ms).
Raises
------
ValueError
If any of the following conditions are violated:
- ``C_m <= 0``
- ``t_ref < 0``
- ``tau_syn_ex <= 0`` or ``tau_syn_in <= 0``
- ``tau_u_bar_plus <= 0``, ``tau_u_bar_minus <= 0``, or ``tau_u_bar_bar <= 0``
- ``g_Na < 0``, ``g_K < 0``, or ``g_L < 0``
- ``gsl_error_tol <= 0``
Notes
-----
- Unlike IAF models, the HH model does **not** reset the membrane potential
after a spike. Repolarization occurs naturally through the potassium current.
- During the refractory period, the neuron's subthreshold dynamics continue
to evolve freely; only spike emission is suppressed.
- Spike weights are interpreted as current amplitudes (pA). Positive weights
are excitatory; negative weights are inhibitory.
- The three Clopath-related voltage traces (``u_bar_plus``, ``u_bar_minus``,
``u_bar_bar``) are integrated as part of the same 11-dimensional ODE system,
matching NEST's GSL integration. This enlarges the ODE system by three
dimensions (11 instead of 8) compared to ``hh_psc_alpha``.
- The adaptive RKF45 integrator evaluates the ODE right-hand side multiple
times per step, so computation cost scales with desired accuracy (controlled
by ``gsl_error_tol``).
- Spike detection combines threshold crossing (0 mV) and local maximum
detection, matching the biological action potential waveform.
References
----------
.. [1] Hodgkin AL, Huxley AF (1952). A quantitative description of membrane
current and its application to conduction and excitation in nerve.
The Journal of Physiology 117:500-544.
DOI: https://doi.org/10.1113/jphysiol.1952.sp004764
.. [2] Clopath C, Büsing L, Vasilaki E, Gerstner W (2010). Connectivity
reflects coding: a model of voltage-based STDP with homeostasis.
Nature Neuroscience 13(3):344-352.
DOI: https://doi.org/10.1038/nn.2479
.. [3] Clopath C, Gerstner W (2010). Voltage and spike timing interact
in STDP -- a unified model. Frontiers in Synaptic Neuroscience 2:25.
DOI: https://doi.org/10.3389/fnsyn.2010.00025
.. [4] Gerstner W, Kistler WM (2002). Spiking neuron models: Single neurons,
populations, plasticity. Cambridge University Press.
.. [5] Dayan P, Abbott LF (2001). Theoretical neuroscience: Computational
and mathematical modeling of neural systems. MIT Press.
See Also
--------
hh_psc_alpha : Hodgkin-Huxley neuron without Clopath plasticity support.
clopath_synapse : Voltage-based STDP synapse model that uses these filtered traces.
Examples
--------
Create a population of HH neurons with Clopath plasticity support:
.. code-block:: python
>>> import brainstate as bst
>>> import brainpy_state as bps
>>> import saiunit as u
>>> bst.environ.set(dt=0.1 * u.ms)
>>> neurons = bps.hh_psc_alpha_clopath(
... in_size=100,
... tau_u_bar_plus=114. * u.ms, # Slow LTP filter
... tau_u_bar_minus=10. * u.ms, # Fast LTD filter
... tau_u_bar_bar=500. * u.ms, # Homeostatic filter
... )
>>> neurons.init_state()
>>> # Simulate with constant current injection
>>> spikes = neurons.update(400. * u.pA)
>>> # Access Clopath voltage traces for plasticity computation
>>> u_plus = neurons.u_bar_plus.value
>>> u_minus = neurons.u_bar_minus.value
>>> u_bar = neurons.u_bar_bar.value
"""
# Overrides the class's ``__module__`` so it is reported under the public
# ``brainpy.state`` namespace instead of this private module.
__module__ = 'brainpy.state'
# Passed to the adaptive integrator as ``min_h`` (smallest allowed substep).
_MIN_H = 1e-8 * u.ms  # ms
# Passed to the adaptive integrator as ``max_iters`` (iteration cap).
_MAX_ITERS = 100000
def __init__(
    self,
    in_size: Size,
    E_L: ArrayLike = -54.402 * u.mV,
    C_m: ArrayLike = 100. * u.pF,
    g_Na: ArrayLike = 12000. * u.nS,
    g_K: ArrayLike = 3600. * u.nS,
    g_L: ArrayLike = 30. * u.nS,
    E_Na: ArrayLike = 50. * u.mV,
    E_K: ArrayLike = -77. * u.mV,
    t_ref: ArrayLike = 2. * u.ms,
    tau_syn_ex: ArrayLike = 0.2 * u.ms,
    tau_syn_in: ArrayLike = 2. * u.ms,
    I_e: ArrayLike = 0. * u.pA,
    tau_u_bar_plus: ArrayLike = 114. * u.ms,
    tau_u_bar_minus: ArrayLike = 10. * u.ms,
    tau_u_bar_bar: ArrayLike = 500. * u.ms,
    V_m_init: ArrayLike = -65. * u.mV,
    Act_m_init: ArrayLike = None,
    Inact_h_init: ArrayLike = None,
    Act_n_init: ArrayLike = None,
    u_bar_plus_init: ArrayLike = 0. * u.mV,
    u_bar_minus_init: ArrayLike = 0. * u.mV,
    u_bar_bar_init: ArrayLike = 0. * u.mV,
    gsl_error_tol: ArrayLike = 1e-6,
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    name: str = None,
):
    """Build the population: broadcast parameters, validate them, create the RKF45 integrator.

    The meaning, units and defaults of every argument are documented in the
    class docstring.
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    shape = self.varshape
    as_param = braintools.init.param

    # Biophysical and Clopath-filter parameters, broadcast to the population shape.
    self.E_L = as_param(E_L, shape)
    self.C_m = as_param(C_m, shape)
    self.g_Na = as_param(g_Na, shape)
    self.g_K = as_param(g_K, shape)
    self.g_L = as_param(g_L, shape)
    self.E_Na = as_param(E_Na, shape)
    self.E_K = as_param(E_K, shape)
    self.t_ref = as_param(t_ref, shape)
    self.tau_syn_ex = as_param(tau_syn_ex, shape)
    self.tau_syn_in = as_param(tau_syn_in, shape)
    self.I_e = as_param(I_e, shape)
    self.tau_u_bar_plus = as_param(tau_u_bar_plus, shape)
    self.tau_u_bar_minus = as_param(tau_u_bar_minus, shape)
    self.tau_u_bar_bar = as_param(tau_u_bar_bar, shape)

    # Initial-condition specifications are kept as given; init_state resolves them.
    self.V_m_init = V_m_init
    self.Act_m_init = Act_m_init
    self.Inact_h_init = Inact_h_init
    self.Act_n_init = Act_n_init
    self.u_bar_plus_init = u_bar_plus_init
    self.u_bar_minus_init = u_bar_minus_init
    self.u_bar_bar_init = u_bar_bar_init
    self.gsl_error_tol = gsl_error_tol

    self._validate_parameters()

    dt = brainstate.environ.get_dt()
    self.integrator = AdaptiveRungeKuttaStep(
        method='RKF45',
        vf=self._vector_field,
        event_fn=self._event_fn,
        min_h=self._MIN_H,
        max_iters=self._MAX_ITERS,
        atol=self.gsl_error_tol,
        dt=dt,
    )

    # Refractory period as an integer number of simulation steps: ceil(t_ref / dt).
    self.ref_count = u.math.asarray(
        u.math.ceil(self.t_ref / dt),
        dtype=brainstate.environ.ditype(),
    )
def _validate_parameters(self):
    r"""Validate parameter constraints at construction time.

    Every check is applied to each parameter whose value is concrete.
    Parameters that are JAX tracers (e.g. when the model is built inside
    ``jax.jit``) cannot be compared eagerly, so only the checks on those
    particular values are skipped. An eager comparison would otherwise
    fail with a JAX conversion error.

    Raises
    ------
    ValueError
        If any concrete parameter violates physical or numerical constraints:

        - ``C_m <= 0``
        - ``t_ref < 0``
        - ``tau_syn_ex <= 0`` or ``tau_syn_in <= 0``
        - ``tau_u_bar_plus <= 0``, ``tau_u_bar_minus <= 0``, or ``tau_u_bar_bar <= 0``
        - ``g_Na < 0``, ``g_K < 0``, or ``g_L < 0``
        - ``gsl_error_tol <= 0``
    """

    def _violates(value, is_bad):
        # Tracers have no concrete value to inspect; skip just this value.
        if is_tracer(value):
            return False
        return bool(np.any(is_bad(value)))

    if _violates(self.C_m, lambda v: v <= 0.0 * u.pF):
        raise ValueError('Capacitance must be strictly positive.')
    if _violates(self.t_ref, lambda v: v < 0.0 * u.ms):
        raise ValueError('Refractory time cannot be negative.')
    time_constants = (
        self.tau_syn_ex,
        self.tau_syn_in,
        self.tau_u_bar_plus,
        self.tau_u_bar_minus,
        self.tau_u_bar_bar,
    )
    if any(_violates(tau, lambda v: v <= 0.0 * u.ms) for tau in time_constants):
        raise ValueError('All time constants must be strictly positive.')
    conductances = (self.g_Na, self.g_K, self.g_L)
    if any(_violates(g, lambda v: v < 0.0 * u.nS) for g in conductances):
        raise ValueError('All conductances must be non-negative.')
    if _violates(self.gsl_error_tol, lambda v: v <= 0.0):
        raise ValueError('The gsl_error_tol must be strictly positive.')
[docs]
def init_state(self, **kwargs):
    r"""Initialize all neuron state variables.

    Creates and initializes the 11-dimensional state vector for each neuron:
    membrane potential, three gating variables (m, h, n), two pairs of
    alpha-kernel states (excitatory/inhibitory), and three Clopath filtered
    voltages. It also creates auxiliary states for refractory handling and
    spike timing.

    **Initialization logic:**

    - **Membrane potential** (``V``): set to ``V_m_init``.
    - **Gating variables** (``m``, ``h``, ``n``): when ``Act_m_init``,
      ``Inact_h_init`` or ``Act_n_init`` is provided (scalar or array
      broadcastable to ``self.varshape``), that value is used. Otherwise the
      equilibrium value at each neuron's own ``V_m_init`` is used.
    - **Alpha-kernel states** (``dI_syn_ex``, ``I_syn_ex``, ``dI_syn_in``,
      ``I_syn_in``): initialized to zero.
    - **Clopath filtered voltages** (``u_bar_plus``, ``u_bar_minus``,
      ``u_bar_bar``): set to ``u_bar_plus_init``, ``u_bar_minus_init`` and
      ``u_bar_bar_init`` (default 0 mV).
    - **Auxiliary states**: ``I_stim`` set to 0 pA, ``refractory_step_count``
      set to 0, ``last_spike_time`` set to -1e7 ms (far past), and
      ``integration_step`` set to dt.

    Parameters
    ----------
    **kwargs
        Unused compatibility parameters accepted by the base-state API.

    Notes
    -----
    - Equilibrium gating values are computed elementwise with
      :func:`_hh_psc_alpha_clopath_equilibrium`, evaluated in float64.
      A heterogeneous ``V_m_init`` therefore yields per-neuron equilibria.
    - Initial Clopath filtered voltages default to 0 mV. For long-running
      simulations starting from rest, consider setting them to ``V_m_init``
      to avoid initial transient artifacts in voltage-based plasticity.
    - This method must be called before the first :meth:`update` call.

    See Also
    --------
    _hh_psc_alpha_clopath_equilibrium : Computes equilibrium gating variables.
    """
    ditype = brainstate.environ.ditype()
    dftype = brainstate.environ.dftype()
    dt = brainstate.environ.get_dt()

    # Initial potential as a unitless (mV) NumPy array (scalar or per-neuron).
    V_init_mV = np.asarray(u.math.asarray(
        braintools.init.param(self.V_m_init, self.varshape) / u.mV
    ), dtype=dftype)
    # Evaluate equilibria per element in float64, independent of the
    # simulation dtype.
    m_eq, h_eq, n_eq = _hh_psc_alpha_clopath_equilibrium(
        V_init_mV.astype(np.float64)
    )

    def _gating_value(explicit, equilibrium):
        # Keep explicit values as arrays: ``float(...)`` would reject the
        # per-neuron arrays that the parameter documentation allows.
        if explicit is None:
            return equilibrium
        return np.asarray(u.math.asarray(explicit), dtype=dftype)

    m_init = _gating_value(self.Act_m_init, m_eq)
    h_init = _gating_value(self.Inact_h_init, h_eq)
    n_init = _gating_value(self.Act_n_init, n_eq)

    V = braintools.init.param(braintools.init.Constant(self.V_m_init), self.varshape)

    # Clopath filtered voltage initial values
    u_bar_plus_init_val = braintools.init.param(
        braintools.init.Constant(self.u_bar_plus_init), self.varshape
    )
    u_bar_minus_init_val = braintools.init.param(
        braintools.init.Constant(self.u_bar_minus_init), self.varshape
    )
    u_bar_bar_init_val = braintools.init.param(
        braintools.init.Constant(self.u_bar_bar_init), self.varshape
    )

    zeros_pA = u.math.zeros(self.varshape, dtype=dftype) * u.pA
    zeros_pA_per_ms = u.math.zeros(self.varshape, dtype=dftype) * (u.pA / u.ms)

    self.V = brainstate.HiddenState(V)
    self.m = brainstate.HiddenState(
        braintools.init.param(braintools.init.Constant(m_init), self.varshape)
    )
    self.h = brainstate.HiddenState(
        braintools.init.param(braintools.init.Constant(h_init), self.varshape)
    )
    self.n = brainstate.HiddenState(
        braintools.init.param(braintools.init.Constant(n_init), self.varshape)
    )
    self.dI_syn_ex = brainstate.ShortTermState(zeros_pA_per_ms)
    self.I_syn_ex = brainstate.ShortTermState(zeros_pA)
    self.dI_syn_in = brainstate.ShortTermState(zeros_pA_per_ms)
    self.I_syn_in = brainstate.ShortTermState(zeros_pA)
    self.u_bar_plus = brainstate.HiddenState(u_bar_plus_init_val)
    self.u_bar_minus = brainstate.HiddenState(u_bar_minus_init_val)
    self.u_bar_bar = brainstate.HiddenState(u_bar_bar_init_val)
    self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
    self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
    self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
    self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
[docs]
def get_spike(self, V: ArrayLike = None):
    r"""Return a differentiable (surrogate) spike signal for a membrane potential.

    The membrane potential is made unitless by dividing by 1 mV, which keeps
    the HH spike threshold at 0. The result is passed through the surrogate
    nonlinearity ``self.spk_fun``. The output is intended for gradient-based
    training only and does not feed back into the neuron dynamics.

    The discrete spikes emitted by :meth:`update` follow the
    threshold-and-local-maximum criterion described in the class docstring.
    This method provides a separate, smooth-gradient signal suitable for
    backpropagation through time or surrogate-gradient learning.

    Parameters
    ----------
    V : ArrayLike or None, optional
        Membrane potential in mV. When omitted (``None``), the current
        state ``self.V.value`` is used.

    Returns
    -------
    ArrayLike
        ``self.spk_fun(V / (1 mV))``; its shape follows ``V`` and its value
        range depends on the chosen surrogate function.

    Notes
    -----
    Surrogate functions such as ``braintools.surrogate.ReluGrad()``,
    ``braintools.surrogate.Sigmoid()`` or ``braintools.surrogate.ATan()``
    expect dimensionless input, hence the division by ``1 mV``.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy_state as bps
        >>> import saiunit as u
        >>> neurons = bps.hh_psc_alpha_clopath(in_size=10)
        >>> neurons.init_state()
        >>> V_test = u.math.array([[-70., -10., 0., 5., 20.]]) * u.mV
        >>> spikes_surrogate = neurons.get_spike(V_test)
    """
    if V is None:
        V = self.V.value
    # Spike threshold for HH cells is 0 mV, so stripping the unit suffices.
    return self.spk_fun(V / (1. * u.mV))
def _vector_field(self, state, extra):
    """Evaluate the time derivatives of the 11-variable ODE system.

    The computation is vectorized over all neurons. It combines:

    - the Hodgkin-Huxley membrane equation with sodium, potassium and leak
      currents;
    - first-order kinetics for the gating variables ``m``, ``h`` and ``n``;
    - alpha-shaped excitatory and inhibitory synaptic currents, each stored
      as a ``(dI, I)`` pair;
    - three exponentially filtered copies of the membrane voltage used by
      the Clopath plasticity rule.

    Parameters
    ----------
    state : DotDict
        ODE state with keys ``V, m, h, n, dI_ex, I_ex, dI_in, I_in,
        u_bar_plus, u_bar_minus, u_bar_bar``.
    extra : DotDict
        Auxiliary integrator data. Only ``extra.i_stim`` (the buffered
        external stimulation current) is read here.

    Returns
    -------
    DotDict
        Time derivatives, with the same keys (in the same order) as ``state``.
    """
    v = state.V
    m_gate = state.m
    h_gate = state.h
    n_gate = state.n

    # The rate expressions take a dimensionless voltage in mV; the resulting
    # rates are given units of 1/ms.
    # NOTE: the n- and m-opening rates have removable 0/0 points at
    # V = -55 mV and V = -40 mV respectively (as in NEST's implementation).
    v_mv = v / u.mV
    n_open = (0.01 * (v_mv + 55.0)) / (1.0 - u.math.exp(-(v_mv + 55.0) / 10.0)) / u.ms
    n_close = 0.125 * u.math.exp(-(v_mv + 65.0) / 80.0) / u.ms
    m_open = (0.1 * (v_mv + 40.0)) / (1.0 - u.math.exp(-(v_mv + 40.0) / 10.0)) / u.ms
    m_close = 4.0 * u.math.exp(-(v_mv + 65.0) / 18.0) / u.ms
    h_open = 0.07 * u.math.exp(-(v_mv + 65.0) / 20.0) / u.ms
    h_close = 1.0 / (1.0 + u.math.exp(-(v_mv + 35.0) / 10.0)) / u.ms

    # Ionic membrane currents.
    sodium = self.g_Na * m_gate * m_gate * m_gate * h_gate * (v - self.E_Na)
    potassium = self.g_K * n_gate * n_gate * n_gate * n_gate * (v - self.E_K)
    leak = self.g_L * (v - self.E_L)
    ionic_total = sodium + potassium + leak

    # Membrane equation: ionic currents oppose the injected and synaptic ones.
    dv_dt = (-ionic_total + extra.i_stim + self.I_e + state.I_ex + state.I_in) / self.C_m

    def _relax(opening, closing, gate):
        # First-order two-state kinetics for a gating variable.
        return opening * (1.0 - gate) - closing * gate

    dm_dt = _relax(m_open, m_close, m_gate)
    dh_dt = _relax(h_open, h_close, h_gate)
    dn_dt = _relax(n_open, n_close, n_gate)

    # Alpha kernels: the dI term decays and drives the I term.
    tau_ex = self.tau_syn_ex
    tau_in = self.tau_syn_in
    d_dI_ex = -state.dI_ex / tau_ex
    d_I_ex = state.dI_ex - state.I_ex / tau_ex
    d_dI_in = -state.dI_in / tau_in
    d_I_in = state.dI_in - state.I_in / tau_in

    # Clopath traces: u_bar_plus and u_bar_minus low-pass V;
    # u_bar_bar low-passes u_bar_minus.
    d_u_plus = (-state.u_bar_plus + v) / self.tau_u_bar_plus
    d_u_minus = (-state.u_bar_minus + v) / self.tau_u_bar_minus
    d_u_bar_bar = (-state.u_bar_bar + state.u_bar_minus) / self.tau_u_bar_bar

    return DotDict(
        V=dv_dt,
        m=dm_dt,
        h=dh_dt,
        n=dn_dt,
        dI_ex=d_dI_ex,
        I_ex=d_I_ex,
        dI_in=d_dI_in,
        I_in=d_I_in,
        u_bar_plus=d_u_plus,
        u_bar_minus=d_u_minus,
        u_bar_bar=d_u_bar_bar,
    )
def _event_fn(self, state, extra, accept):
    """Per-substep callback invoked by the adaptive integrator.

    HH neurons have no voltage reset, so the ODE state is returned untouched.
    The only bookkeeping done here is refreshing ``extra.V_old``: for each
    neuron whose RK substep was accepted it becomes that neuron's current
    ``state.V``; rejected neurons keep their previous ``V_old``. No other
    entry of ``extra`` (``spike_mask``, ``r``, ``i_stim``) is modified.

    Note that :meth:`update` performs spike detection with its own
    pre-integration voltage, not with the ``V_old`` tracked here.

    Parameters
    ----------
    state : DotDict
        ODE state (keys ``V, m, h, n, dI_ex, I_ex, dI_in, I_in, u_bar_plus,
        u_bar_minus, u_bar_bar``).
    extra : DotDict
        Auxiliary integrator data (keys ``spike_mask, r, V_old, i_stim``).
    accept : array of bool
        Per-neuron mask of accepted RK substeps.

    Returns
    -------
    tuple of DotDict
        ``(state, new_extra)``: ``state`` unchanged, and a copy of ``extra``
        with ``V_old`` replaced.
    """
    merged = dict(extra)
    # Overwriting an existing key keeps its original position in the mapping.
    merged['V_old'] = u.math.where(accept, state.V, extra.V_old)
    return state, DotDict(merged)
def update(self, x=0. * u.pA):
    r"""Advance the neuron by one simulation step.

    Integrates the full 11-dimensional Hodgkin-Huxley dynamics by one time
    step. This covers the membrane potential, gating variables, alpha-kernel
    synaptic currents and the three Clopath filtered voltage traces. The
    update order follows NEST's ``hh_psc_alpha_clopath``.

    **Update sequence:**

    1. Record the pre-integration membrane potential (``V_old``) for the
       local-maximum spike test.
    2. Integrate the ODE system
       :math:`(V_m, m, h, n, dI_{ex}, I_{ex}, dI_{in}, I_{in}, \bar{u}_+, \bar{u}_-, \bar{\bar{u}})`
       over :math:`[t, t+dt]` with ``self.integrator``.
       - This is an adaptive Runge-Kutta-Fehlberg (RKF45) step.
       - It starts from the step size stored in ``integration_step``.
       - The adapted step size is written back for the next call.
    3. Add arriving spike inputs to :math:`dI_{syn,ex}` and
       :math:`dI_{syn,in}`:
       - Inputs registered under the delta-input label ``'w_ex'`` are
         added with factor :math:`e / \tau_{syn,ex}`.
       - Inputs registered under ``'w_in'`` are **subtracted** with factor
         :math:`e / \tau_{syn,in}`, so a positive ``'w_in'`` value acts as
         inhibition.
    4. Check the spike condition:
       ``V_m >= 0 mV`` **and** ``V_old > V_m`` **and**
       ``refractory_step_count == 0``.
    5. Update the refractory counter:
       - on a spike, set it to ``self.ref_count``
         (presumably ``ceil(t_ref / dt)``);
       - otherwise, decrement it while positive.
    6. Record the spike time as ``t + dt`` for spiking neurons.
    7. Store ``x`` (plus any registered current inputs) as ``I_stim`` for
       the **next** step. This is a one-step delay, matching NEST
       ring-buffer semantics.

    **Integration details:**

    - All neurons are integrated simultaneously by the vectorized
      :class:`~brainpy_state._nest._utils.AdaptiveRungeKuttaStep`.
    - The right-hand side includes all 11 coupled state equations.
    - Alpha-kernel normalization (factor :math:`e/\tau`) makes a weight of
      1 pA produce a peak PSC of 1 pA.

    **Spike detection semantics:**

    - **No hard reset**: unlike IAF models, the membrane potential is not
      clamped after a spike. The potassium current :math:`I_K` repolarizes
      the cell.
    - **Local maximum criterion**: a spike is emitted only when the voltage
      is at or above 0 mV **and** has fallen since the start of the step
      (``V_old > V_m``).
    - **Refractory suppression**: emission is blocked while the refractory
      counter is positive. All state variables, including the Clopath
      filters, keep evolving.

    Parameters
    ----------
    x : ArrayLike, optional
        External stimulation current in pA; a scalar or an array
        broadcastable to ``self.varshape``.
        - It enters the membrane equation together with ``I_e`` and the
          synaptic currents.
        - It is buffered for one step, i.e. applied in the **next** update
          call.
        - Default is ``0. * u.pA``.

    Returns
    -------
    jax.Array
        Spike indicator with shape ``self.V.value.shape``.
        - The floating dtype is ``brainstate.environ.dftype()``.
        - The value is ``1.0`` where the spike condition held at the end of
          this step, and ``0.0`` elsewhere.
        - The condition is evaluated once, after integration, so at most one
          spike per neuron per step is reported.

    Notes
    -----
    - The external current ``x`` is **buffered**: the current passed at
      step :math:`t` affects the dynamics at step :math:`t+1`.
    - Spike inputs and continuous current inputs are gathered with
      :meth:`sum_delta_inputs` and :meth:`sum_current_inputs`.
    - Spike weights are interpreted as current amplitudes (pA).
    - ``u_bar_plus``, ``u_bar_minus`` and ``u_bar_bar`` are advanced as part
      of the ODE integration. External code, such as a Clopath synapse, can
      read them after :meth:`update` returns.
    - All arithmetic is unit-aware via ``saiunit``.
    """
    t = brainstate.environ.get('t')
    dt = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()

    # Read state variables with their natural units.
    V = self.V.value  # mV
    m = self.m.value  # dimensionless
    h_gate = self.h.value  # dimensionless (Na inactivation)
    n = self.n.value  # dimensionless
    dI_ex = self.dI_syn_ex.value  # pA/ms
    I_ex = self.I_syn_ex.value  # pA
    dI_in = self.dI_syn_in.value  # pA/ms
    I_in = self.I_syn_in.value  # pA
    u_bp = self.u_bar_plus.value  # mV
    u_bm = self.u_bar_minus.value  # mV
    u_bb = self.u_bar_bar.value  # mV
    r = self.refractory_step_count.value  # int
    i_stim = self.I_stim.value  # pA, buffered during the previous step
    step_size = self.integration_step.value  # ms, adaptive RK step size

    # Pre-integration voltage: NEST's ``U_old`` for the local-maximum test.
    V_old = V

    # Current input for the next step (one-step delay).
    new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

    # Adaptive RKF45 integration over [t, t + dt].
    ode_state = DotDict(
        V=V, m=m, h=h_gate, n=n,
        dI_ex=dI_ex, I_ex=I_ex,
        dI_in=dI_in, I_in=I_in,
        u_bar_plus=u_bp, u_bar_minus=u_bm, u_bar_bar=u_bb,
    )
    extra = DotDict(
        spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
        r=r,
        V_old=V_old,
        i_stim=i_stim,
    )
    ode_state, step_size, extra = self.integrator(state=ode_state, h=step_size, extra=extra)

    V = ode_state.V
    m = ode_state.m
    h_gate = ode_state.h
    n = ode_state.n
    dI_ex = ode_state.dI_ex
    I_ex = ode_state.I_ex
    dI_in = ode_state.dI_in
    I_in = ode_state.I_in
    u_bp = ode_state.u_bar_plus
    u_bm = ode_state.u_bar_minus
    u_bb = ode_state.u_bar_bar
    r = extra.r

    # Synaptic spike inputs are applied after integration (NEST order).
    w_ex = self.sum_delta_inputs(u.math.zeros_like(self.I_syn_ex.value), label='w_ex')
    w_in = self.sum_delta_inputs(u.math.zeros_like(self.I_syn_in.value), label='w_in')
    pscon_ex = np.e / self.tau_syn_ex  # 1/ms
    pscon_in = np.e / self.tau_syn_in  # 1/ms
    # 'w_ex' inputs are added. 'w_in' inputs are subtracted, so a positive
    # 'w_in' value drives dI_in negative (inhibitory convention).
    # NOTE(review): confirm callers register 'w_in' as a positive magnitude.
    dI_ex = dI_ex + pscon_ex * w_ex  # pA/ms
    dI_in = dI_in - pscon_in * w_in  # pA/ms

    # Spike detection: threshold crossing + local maximum.
    # Uses the pre-integration V_old, not extra.V_old (integrator-internal),
    # matching hh_psc_alpha spike detection logic.
    not_refractory = r == 0
    crossed_threshold = V >= 0.0 * u.mV
    local_max = V_old > V
    spike_cond = not_refractory & crossed_threshold & local_max

    # Refractory update: reload on spike, otherwise count down to zero.
    r_new = u.math.where(spike_cond, self.ref_count, u.math.where(r > 0, r - 1, r))

    # Write back state.
    self.V.value = V
    self.m.value = m
    self.h.value = h_gate
    self.n.value = n
    self.I_syn_ex.value = I_ex
    self.I_syn_in.value = I_in
    self.dI_syn_ex.value = dI_ex
    self.dI_syn_in.value = dI_in
    self.u_bar_plus.value = u_bp
    self.u_bar_minus.value = u_bm
    self.u_bar_bar.value = u_bb
    self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r_new), dtype=ditype)
    self.integration_step.value = step_size
    # Broadcast to the full neuron shape in case the input was a scalar.
    self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
    last_spike_time = u.math.where(spike_cond, t + dt, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    return u.math.asarray(spike_cond, dtype=dftype)