# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep
# Public API of this module: only the neuron model class is exported.
__all__ = [
    'iaf_bw_2001',
]
class iaf_bw_2001(NESTNeuron):
r"""NEST-compatible ``iaf_bw_2001`` neuron model.
Conductance-based leaky integrate-and-fire neuron with AMPA, GABA, and
approximate NMDA synaptic dynamics from Brunel-Wang style cortical models.
This model implements the NEST ``iaf_bw_2001`` neuron with full compatibility,
including adaptive RKF45 integration of subthreshold ODEs, receptor-routed
AMPA/GABA/NMDA spike inputs, one-step delayed external current buffering,
refractory countdown and reset ordering, and NMDA presynaptic jump approximation
using spike-event offsets.
**1. Mathematical Model**
The continuous-time state vector is:
.. math::
y = (V_m, s_{AMPA}, s_{GABA}, s_{NMDA}).
**Membrane dynamics:**
.. math::
C_m \frac{dV_m}{dt} = -g_L(V_m - E_L) - I_{syn} + I_{stim},
where the total synaptic current is:
.. math::
I_{syn} = I_{AMPA} + I_{GABA} + I_{NMDA}.
**Synaptic currents:**
AMPA and GABA currents use Ohmic conductance:
.. math::
I_{AMPA} = (V_m - E_{ex}) s_{AMPA},
\quad
I_{GABA} = (V_m - E_{in}) s_{GABA}.
NMDA current includes voltage-dependent Mg²⁺ block:
.. math::
I_{NMDA} = \frac{(V_m - E_{ex}) s_{NMDA}}
{1 + [Mg^{2+}]\exp(-0.062 V_m)/3.57}.
**Synaptic kinetics:**
All three receptor types decay exponentially:
.. math::
\frac{ds_{AMPA}}{dt} = -\frac{s_{AMPA}}{\tau_{AMPA}},
\quad
\frac{ds_{GABA}}{dt} = -\frac{s_{GABA}}{\tau_{GABA}},
\quad
\frac{ds_{NMDA}}{dt} = -\frac{s_{NMDA}}{\tau_{NMDA,decay}}.
**2. NMDA Approximation and Spike Offsets**
NMDA recurrent coupling uses a presynaptic auxiliary variable ``s_NMDA_pre``
updated only when this neuron spikes. At spike time :math:`t_{spike}`:
.. math::
s_{pre} \leftarrow s_{pre}
\exp\left(-\frac{t_{spike} - t_{last}}{\tau_{NMDA,decay}}\right),
.. math::
\Delta s_{NMDA} = k_0 + k_1 s_{pre},
\quad
s_{pre} \leftarrow s_{pre} + \Delta s_{NMDA},
where the jump constants are:
.. math::
k_1 = \exp(-\alpha\tau_{NMDA,rise}) - 1,
.. math::
k_0 = (\alpha\tau_{NMDA,rise})^{\tau_{NMDA,rise}/\tau_{NMDA,decay}}
\gamma\Big(1 - \tau_{NMDA,rise}/\tau_{NMDA,decay},
\alpha\tau_{NMDA,rise}\Big),
where :math:`\gamma` is the lower incomplete gamma function. The per-spike
:math:`\Delta s_{NMDA}` is exposed as ``spike_offset`` and used by NMDA
receptor events as ``weight * spike_offset`` (matching NEST ``SpikeEvent``
semantics for ``iaf_bw_2001``).
**3. Update Order (NEST Semantics)**
Per simulation step:
1. **Integration**: Integrate ODEs on :math:`(t, t+dt]` using adaptive
Runge-Kutta-Fehlberg 4(5) with persistent internal step size.
2. **Spike reception**: Add arriving AMPA/GABA/NMDA spike increments to
``s_AMPA``, ``s_GABA``, ``s_NMDA``.
3. **Threshold/reset**: Apply refractory countdown or check threshold,
emit spike, and reset if :math:`V_m \geq V_{th}`.
4. **Current buffering**: Store external current into delayed buffer
``I_stim`` for next step (one-step ring-buffer delay).
Ordering notes:
- Refractory clamping is applied after integration (as in NEST source).
- ``I_stim`` uses one-step delay to match NEST's ring-buffer semantics.
- During refractory period, :math:`V_m` is clamped to :math:`V_{reset}`.
**4. Receptor Types and Event Semantics**
Receptor types (matching NEST names and IDs):
- ``AMPA`` = 1 (excitatory, fast)
- ``GABA`` = 2 (inhibitory)
- ``NMDA`` = 3 (excitatory, slow, voltage-dependent)
The ``spike_events`` parameter passed to :meth:`update` may contain tuples
or dictionaries:
- Tuple format: ``(receptor, weight)`` or ``(receptor, weight, offset)``
or ``(receptor, weight, offset, sender_model)``
- Dict format: ``{'receptor_type': ..., 'weight': ..., 'offset': ...,
'sender_model': ...}``
For NMDA events, ``sender_model`` must be ``'iaf_bw_2001'``; otherwise a
``ValueError`` is raised (mirroring NEST's illegal-connection check, as
only ``iaf_bw_2001`` neurons compute the NMDA spike offset).
Registered ``add_delta_input`` entries can be receptor-labeled using
``label='AMPA'``, ``label='GABA'``, or ``label='NMDA'``. Unlabeled delta
inputs default to AMPA.
Parameters
----------
in_size : int, tuple of int
Population shape (number of neurons). Can be an integer or tuple for
multi-dimensional populations.
E_L : saiunit.Quantity, optional
Leak reversal potential. Default: -70 mV.
E_ex : saiunit.Quantity, optional
Excitatory reversal potential (AMPA, NMDA). Default: 0 mV.
E_in : saiunit.Quantity, optional
Inhibitory reversal potential (GABA). Default: -70 mV.
V_th : saiunit.Quantity, optional
Spike threshold potential. Default: -55 mV.
V_reset : saiunit.Quantity, optional
Reset potential after spike. Must be strictly less than ``V_th``.
Default: -60 mV.
C_m : saiunit.Quantity, optional
Membrane capacitance. Must be strictly positive. Default: 500 pF.
g_L : saiunit.Quantity, optional
Leak conductance. Default: 25 nS.
t_ref : saiunit.Quantity, optional
Absolute refractory period duration. Must be non-negative. Default: 2 ms.
tau_AMPA : saiunit.Quantity, optional
AMPA receptor decay time constant. Must be strictly positive. Default: 2 ms.
tau_GABA : saiunit.Quantity, optional
GABA receptor decay time constant. Must be strictly positive. Default: 5 ms.
tau_decay_NMDA : saiunit.Quantity, optional
NMDA receptor slow decay time constant. Must be strictly positive.
Default: 100 ms.
tau_rise_NMDA : saiunit.Quantity, optional
NMDA receptor fast rise time constant for jump approximation. Must be
strictly positive. Default: 2 ms.
alpha : saiunit.Quantity, optional
NMDA jump-shape parameter (rate constant). Must be strictly positive.
Default: 0.5 / ms.
conc_Mg2 : saiunit.Quantity, optional
Extracellular magnesium concentration for NMDA voltage-dependent block.
Must be strictly positive. Default: 1 mM.
gsl_error_tol : float, optional
RKF45 local error tolerance (analog to NEST's ``gsl_error_tol``).
Smaller values increase integration accuracy but decrease performance.
Must be strictly positive. Default: 1e-3.
V_initializer : callable, optional
Membrane potential initializer function. Default: Constant(-70 mV).
s_AMPA_initializer : callable, optional
AMPA conductance state initializer. Default: Constant(0 nS).
s_GABA_initializer : callable, optional
GABA conductance state initializer. Default: Constant(0 nS).
s_NMDA_initializer : callable, optional
NMDA conductance state initializer. Default: Constant(0 nS).
spk_fun : callable, optional
Surrogate gradient function for spike generation. Default: ReluGrad().
spk_reset : str, optional
Spike reset mode. ``'hard'`` (stop gradient) matches NEST behavior;
``'soft'`` (subtract threshold) is differentiable. Default: 'hard'.
ref_var : bool, optional
If True, expose boolean ``refractory`` state variable. Default: False.
name : str, optional
Name of the neuron group.
Parameter Mapping
-----------------
The following table maps brainpy.state parameter names to their NEST equivalents:
==================== =================== ===========================================================
**brainpy.state** **NEST** **Description**
==================== =================== ===========================================================
``E_L`` ``E_L`` Leak reversal potential
``E_ex`` ``E_ex`` Excitatory reversal potential
``E_in`` ``E_in`` Inhibitory reversal potential
``V_th`` ``V_th`` Spike threshold
``V_reset`` ``V_reset`` Reset potential
``C_m`` ``C_m`` Membrane capacitance
``g_L`` ``g_L`` Leak conductance
``t_ref`` ``t_ref`` Refractory period
``tau_AMPA`` ``tau_AMPA`` AMPA decay time constant
``tau_GABA`` ``tau_GABA`` GABA decay time constant
``tau_decay_NMDA`` ``tau_decay_NMDA`` NMDA slow decay time constant
``tau_rise_NMDA`` ``tau_rise_NMDA`` NMDA fast rise time constant
``alpha`` ``alpha`` NMDA jump-shape parameter
``conc_Mg2`` ``conc_Mg2`` Extracellular Mg²⁺ concentration
``gsl_error_tol`` ``gsl_error_tol`` RKF45 error tolerance
==================== =================== ===========================================================
Recordables
-----------
The following state variables can be recorded during simulation:
- ``V_m`` : membrane potential (mV)
- ``s_AMPA`` : AMPA conductance state (nS)
- ``s_GABA`` : GABA conductance state (nS)
- ``s_NMDA`` : NMDA conductance state (nS)
- ``I_AMPA`` : AMPA synaptic current (pA)
- ``I_GABA`` : GABA synaptic current (pA)
- ``I_NMDA`` : NMDA synaptic current (pA)
Additional State Variables
--------------------------
The following internal state variables are maintained but typically not recorded:
- ``s_NMDA_pre`` : presynaptic NMDA helper state (unitless)
- ``spike_offset`` : per-step NMDA offset emitted on spike (unitless)
- ``refractory_step_count`` : absolute refractory countdown (int)
- ``integration_step`` : persistent adaptive RKF45 step size (ms)
- ``I_stim`` : one-step delayed external current buffer (pA)
- ``last_spike_time`` : time of last spike (ms)
- ``refractory`` : boolean refractory indicator (only if ``ref_var=True``)
Raises
------
ValueError
If ``V_reset >= V_th`` (reset must be below threshold).
ValueError
If ``C_m <= 0`` (capacitance must be positive).
ValueError
If ``t_ref < 0`` (refractory period cannot be negative).
ValueError
If any time constant (``tau_AMPA``, ``tau_GABA``, ``tau_decay_NMDA``,
``tau_rise_NMDA``) is non-positive.
ValueError
If ``alpha <= 0`` (NMDA shape parameter must be positive).
ValueError
If ``conc_Mg2 <= 0`` (Mg²⁺ concentration must be positive).
ValueError
If ``gsl_error_tol <= 0`` (error tolerance must be positive).
ValueError
If NMDA spike event has ``sender_model != 'iaf_bw_2001'`` (only
``iaf_bw_2001`` neurons can compute NMDA spike offsets).
Examples
--------
Create a simple network with AMPA and NMDA recurrent connections:
.. code-block:: python
>>> import brainpy.state as bst
>>> import saiunit as u
>>> import brainstate
>>>
>>> # Create the population and initialize its states. The constructor
>>> # reads ``dt`` from the environment, so build it inside the context.
>>> with brainstate.environ.context(dt=0.1*u.ms):
...     neurons = bst.iaf_bw_2001(100, V_th=-50*u.mV, t_ref=3*u.ms)
...     neurons.init_all_states()
>>>
>>> # Simulate with external input
>>> with brainstate.environ.context(dt=0.1*u.ms, t=0*u.ms):
...     spike = neurons(x=500*u.pA)  # External current input
Simulate with explicit spike events (receptor-routed):
.. code-block:: python
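>>> # AMPA spike event (tuple format); ``update`` reads ``t`` and ``dt``
>>> # from the environment, so call the model inside a context.
>>> ampa_event = ('AMPA', 1.0*u.nS)
>>> with brainstate.environ.context(dt=0.1*u.ms, t=0.1*u.ms):
...     spike = neurons(spike_events=[ampa_event])
>>>
>>> # NMDA spike event (dict format with offset)
>>> nmda_event = {
...     'receptor_type': 'NMDA',
...     'weight': 0.5*u.nS,
...     'offset': 0.8,  # Presynaptic NMDA offset from sender
...     'sender_model': 'iaf_bw_2001',
... }
>>> with brainstate.environ.context(dt=0.1*u.ms, t=0.2*u.ms):
...     spike = neurons(spike_events=[nmda_event])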
Notes
-----
- **Integration method**: This model uses adaptive Runge-Kutta-Fehlberg 4(5)
(RKF45) with local error control, matching NEST's GSL integration. The
internal step size ``integration_step`` is persistent and adapted per neuron.
- **NMDA offset computation**: Only ``iaf_bw_2001`` neurons compute the NMDA
spike offset. If connecting other neuron types, NMDA connections will raise
a ``ValueError``. Use AMPA for inter-model connectivity.
- **Surrogate gradients**: Unlike NEST (which is not differentiable), this
implementation supports gradient-based learning via surrogate spike functions.
- **Performance**: RKF45 integration is accurate but slow for large populations.
For performance-critical applications, consider using fixed-step models
(e.g., ``iaf_cond_exp``, ``iaf_psc_alpha``) when NMDA dynamics are not required.
- **Refractory semantics**: During refractory period, :math:`V_m` is clamped to
:math:`V_{reset}`, and threshold crossing is disabled. This matches NEST behavior.
References
----------
.. [1] Wang X-J (1999). Synaptic basis of cortical persistent activity:
The importance of NMDA receptors to working memory.
Journal of Neuroscience, 19(21):9587-9603.
DOI: https://doi.org/10.1523/JNEUROSCI.19-21-09587.1999
.. [2] Brunel N, Wang X-J (2001). Effects of neuromodulation in a cortical
network model of object working memory dominated by recurrent
inhibition. Journal of Computational Neuroscience, 11(1):63-85.
DOI: https://doi.org/10.1023/A:1011204814320
.. [3] Wang X-J (2002). Probabilistic decision making by slow
reverberation in cortical circuits. Neuron, 36(5):955-968.
DOI: https://doi.org/10.1016/S0896-6273(02)01092-9
.. [4] NEST source: ``models/iaf_bw_2001.h`` and ``models/iaf_bw_2001.cpp``.
See Also
--------
iaf_cond_exp : Simpler conductance-based LIF without NMDA dynamics.
iaf_psc_alpha : Current-based LIF with alpha-function PSCs.
iaf_bw_2001_exact : Exact integration variant (if available).
"""
# Module path under which this class reports itself.
__module__ = 'brainpy.state'
# Integer receptor identifiers accepted by ``_normalize_spike_receptor``.
AMPA = 1
GABA = 2
NMDA = 3
# Receptor label -> receptor identifier lookup.
RECEPTOR_TYPES = {
    'AMPA': AMPA,
    'GABA': GABA,
    'NMDA': NMDA,
}
# Names returned (as a list) by the ``recordables`` property.
RECORDABLES = (
    'V_m',
    's_AMPA',
    's_GABA',
    's_NMDA',
    'I_NMDA',
    'I_AMPA',
    'I_GABA',
)
# Passed to the adaptive integrator as ``min_h`` (smallest internal step).
_MIN_H = 1e-8 * u.ms  # ms
# Passed to the adaptive integrator as ``max_iters``.
_MAX_ITERS = 10000
def __init__(
    self,
    in_size: Size,
    E_L: ArrayLike = -70. * u.mV,
    E_ex: ArrayLike = 0. * u.mV,
    E_in: ArrayLike = -70. * u.mV,
    V_th: ArrayLike = -55. * u.mV,
    V_reset: ArrayLike = -60. * u.mV,
    C_m: ArrayLike = 500. * u.pF,
    g_L: ArrayLike = 25. * u.nS,
    t_ref: ArrayLike = 2. * u.ms,
    tau_AMPA: ArrayLike = 2. * u.ms,
    tau_GABA: ArrayLike = 5. * u.ms,
    tau_decay_NMDA: ArrayLike = 100. * u.ms,
    tau_rise_NMDA: ArrayLike = 2. * u.ms,
    alpha: ArrayLike = 0.5 / u.ms,
    conc_Mg2: ArrayLike = 1.0 * u.mM,
    gsl_error_tol: ArrayLike = 1e-3,
    V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
    s_AMPA_initializer: Callable = braintools.init.Constant(0. * u.nS),
    s_GABA_initializer: Callable = braintools.init.Constant(0. * u.nS),
    s_NMDA_initializer: Callable = braintools.init.Constant(0. * u.nS),
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    ref_var: bool = False,
    name: str = None,
):
    """Store parameters, validate them and build the integrator.

    Parameter meanings are given in the class docstring. Besides storing
    the arguments, the constructor:

    - broadcasts the biophysical parameters to ``self.varshape``;
    - calls ``_validate_parameters``;
    - creates the RKF45 ``AdaptiveRungeKuttaStep`` integrator;
    - converts ``t_ref`` into a step count (``ref_count``) using the ``dt``
      read from ``brainstate.environ`` at construction time;
    - precomputes the NMDA jump constants as float64 NumPy arrays
      (``_k0_np``, ``_k1_np``).
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    # Biophysical parameters, broadcast to the population shape.
    as_param = braintools.init.param
    shape = self.varshape
    self.E_L = as_param(E_L, shape)
    self.E_ex = as_param(E_ex, shape)
    self.E_in = as_param(E_in, shape)
    self.V_th = as_param(V_th, shape)
    self.V_reset = as_param(V_reset, shape)
    self.C_m = as_param(C_m, shape)
    self.g_L = as_param(g_L, shape)
    self.t_ref = as_param(t_ref, shape)
    self.tau_AMPA = as_param(tau_AMPA, shape)
    self.tau_GABA = as_param(tau_GABA, shape)
    self.tau_decay_NMDA = as_param(tau_decay_NMDA, shape)
    self.tau_rise_NMDA = as_param(tau_rise_NMDA, shape)
    self.alpha = as_param(alpha, shape)
    self.conc_Mg2 = as_param(conc_Mg2, shape)

    # Solver tolerance, state initializers and options are stored as given.
    self.gsl_error_tol = gsl_error_tol
    self.V_initializer = V_initializer
    self.s_AMPA_initializer = s_AMPA_initializer
    self.s_GABA_initializer = s_GABA_initializer
    self.s_NMDA_initializer = s_NMDA_initializer
    self.ref_var = ref_var

    self._validate_parameters()

    # The simulation step is read once and shared by the integrator and
    # the refractory step count.
    step = brainstate.environ.get_dt()
    self.integrator = AdaptiveRungeKuttaStep(
        method='RKF45',
        vf=self._vector_field,
        event_fn=self._event_fn,
        min_h=self._MIN_H,
        max_iters=self._MAX_ITERS,
        atol=self.gsl_error_tol,
        dt=step,
    )

    # Refractory duration expressed in whole simulation steps (rounded up).
    self.ref_count = u.math.asarray(
        u.math.ceil(self.t_ref / step),
        dtype=brainstate.environ.ditype(),
    )

    # NMDA jump constants depend only on parameters, so compute them once
    # from the unit-free (ms-based) values.
    jump_k0, jump_k1 = self._nmda_jump_constants(
        np.asarray(u.get_mantissa(self.alpha * u.ms)),
        np.asarray(u.get_mantissa(self.tau_rise_NMDA / u.ms)),
        np.asarray(u.get_mantissa(self.tau_decay_NMDA / u.ms)),
    )
    self._k0_np = np.asarray(jump_k0, dtype=np.float64)
    self._k1_np = np.asarray(jump_k1, dtype=np.float64)
@property
def receptor_types(self):
r"""Return dictionary of available receptor types.
Returns
-------
dict
Mapping from receptor name (str) to receptor ID (int).
Keys: ``'AMPA'``, ``'GABA'``, ``'NMDA'``. Values: 1, 2, 3.
"""
return dict(self.RECEPTOR_TYPES)
@property
def recordables(self):
r"""Return list of recordable state variable names.
Returns
-------
list of str
State variables that can be recorded during simulation:
``['V_m', 's_AMPA', 's_GABA', 's_NMDA', 'I_NMDA', 'I_AMPA', 'I_GABA']``.
"""
return list(self.RECORDABLES)
@classmethod
def _normalize_spike_receptor(cls, receptor):
if isinstance(receptor, str):
key = receptor.strip()
if key in cls.RECEPTOR_TYPES:
return cls.RECEPTOR_TYPES[key]
if key.isdigit():
receptor = int(key)
else:
raise ValueError(f'Unknown receptor label: {receptor}')
receptor = int(receptor)
if receptor < 1 or receptor > 3:
raise ValueError(f'Receptor type must be in [1, 3], got {receptor}.')
return receptor
def _validate_parameters(self):
    """Check parameter constraints, raising on the first violation found.

    Validation needs concrete values. If ``V_reset``, ``C_m`` or
    ``tau_AMPA`` is a JAX tracer the whole validation is skipped, as
    before. In addition, every individual check is skipped when one of
    its own operands is a tracer, so a traced ``V_th``, ``t_ref``,
    ``tau_GABA``, ``alpha`` etc. no longer reaches ``np.any`` (which
    cannot concretize a traced value).

    Raises
    ------
    ValueError
        If any element satisfies ``V_reset >= V_th``, ``C_m <= 0``,
        ``t_ref < 0``, a non-positive ``tau_AMPA`` / ``tau_GABA`` /
        ``tau_decay_NMDA`` / ``tau_rise_NMDA``, ``alpha <= 0``,
        ``conc_Mg2 <= 0`` or ``gsl_error_tol <= 0``.
    """
    # Skip validation when parameters are JAX tracers (e.g. during jit).
    if any(is_tracer(v) for v in (self.V_reset, self.C_m, self.tau_AMPA)):
        return

    def violated(condition, *operands):
        # The comparison is wrapped in a callable so it is never evaluated
        # on traced operands.
        if any(is_tracer(v) for v in operands):
            return False
        return bool(np.any(condition()))

    tau_message = 'All time constants must be strictly positive.'
    # (condition, operands, message); the order matches the original checks
    # so the first reported violation is unchanged.
    checks = (
        (lambda: self.V_reset >= self.V_th, (self.V_reset, self.V_th),
         'Reset potential must be smaller than threshold.'),
        (lambda: self.C_m <= 0.0 * u.pF, (self.C_m,),
         'Capacitance must be strictly positive.'),
        (lambda: self.t_ref < 0.0 * u.ms, (self.t_ref,),
         'Refractory time cannot be negative.'),
        (lambda: self.tau_AMPA <= 0.0 * u.ms, (self.tau_AMPA,), tau_message),
        (lambda: self.tau_GABA <= 0.0 * u.ms, (self.tau_GABA,), tau_message),
        (lambda: self.tau_decay_NMDA <= 0.0 * u.ms, (self.tau_decay_NMDA,), tau_message),
        (lambda: self.tau_rise_NMDA <= 0.0 * u.ms, (self.tau_rise_NMDA,), tau_message),
        (lambda: self.alpha <= 0.0 / u.ms, (self.alpha,),
         'alpha > 0 required.'),
        (lambda: self.conc_Mg2 <= 0.0 * u.mM, (self.conc_Mg2,),
         'Mg2 concentration must be strictly positive.'),
        (lambda: self.gsl_error_tol <= 0.0, (self.gsl_error_tol,),
         'The gsl_error_tol must be strictly positive.'),
    )
    for condition, operands, message in checks:
        if violated(condition, *operands):
            raise ValueError(message)
def init_state(self, **kwargs):
    r"""Create every state variable of the population.

    Parameters
    ----------
    **kwargs
        Accepted for API compatibility; not used here.

    Notes
    -----
    States created, in order:

    - ``V``, ``s_AMPA``, ``s_GABA``, ``s_NMDA`` (``HiddenState``), filled
      from the corresponding initializers;
    - ``I_NMDA``, ``I_AMPA``, ``I_GABA``: zeros in pA;
    - ``s_NMDA_pre``, ``spike_offset``: unit-free zeros;
    - ``last_spike_time``: -1e7 ms everywhere;
    - ``refractory_step_count``: integer zeros;
    - ``integration_step``: constant equal to the environment ``dt``;
    - ``I_stim``: zeros in pA;
    - ``refractory``: all ``False``, only when ``ref_var`` is true.
    """
    shape = self.varshape
    int_dtype = brainstate.environ.ditype()
    float_dtype = brainstate.environ.dftype()
    step = brainstate.environ.get_dt()

    def short_term(value):
        return brainstate.ShortTermState(value)

    def zeros_pA():
        return u.math.zeros(shape, dtype=float_dtype) * u.pA

    def zeros_plain():
        return u.math.zeros(shape, dtype=float_dtype)

    # Evaluate the four ODE-state initializers first, then wrap them.
    initial_values = [
        braintools.init.param(initializer, shape)
        for initializer in (
            self.V_initializer,
            self.s_AMPA_initializer,
            self.s_GABA_initializer,
            self.s_NMDA_initializer,
        )
    ]
    self.V = brainstate.HiddenState(initial_values[0])
    self.s_AMPA = brainstate.HiddenState(initial_values[1])
    self.s_GABA = brainstate.HiddenState(initial_values[2])
    self.s_NMDA = brainstate.HiddenState(initial_values[3])

    # Synaptic currents stored by ``update``.
    self.I_NMDA = short_term(zeros_pA())
    self.I_AMPA = short_term(zeros_pA())
    self.I_GABA = short_term(zeros_pA())

    # NMDA presynaptic helper and the offset of the most recent step.
    self.s_NMDA_pre = short_term(zeros_plain())
    self.spike_offset = short_term(zeros_plain())

    # Spike timing and refractory bookkeeping.
    self.last_spike_time = short_term(u.math.full(shape, -1e7 * u.ms))
    self.refractory_step_count = short_term(u.math.full(shape, 0, dtype=int_dtype))

    # Internal solver step and buffered stimulus current.
    self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(step), shape)
    self.I_stim = short_term(u.math.full(shape, 0.0 * u.pA, dtype=float_dtype))

    if self.ref_var:
        not_refractory = braintools.init.param(braintools.init.Constant(False), shape)
        self.refractory = short_term(not_refractory)
def get_spike(self, V: ArrayLike = None):
r"""Compute spike output using surrogate gradient function.
Converts membrane potential to a differentiable spike signal using the
configured surrogate gradient function (``spk_fun``). The membrane potential
is scaled relative to threshold and reset before applying the surrogate.
Parameters
----------
V : ArrayLike, optional
Membrane potential (mV). If None, uses current ``self.V.value``.
Shape: ``(*in_size,)`` or ``(batch_size, *in_size)``.
Returns
-------
jax.numpy.ndarray
Spike signal (differentiable). Shape matches input ``V``.
Values in [0, 1] for typical surrogate functions (e.g., sigmoid-based).
Hard thresholding (Heaviside) gives binary {0, 1} values.
Notes
-----
- Scaling factor: :math:`(V - V_{th}) / (V_{th} - V_{reset})`.
- The surrogate function is differentiable during backpropagation but
appears as a step function during forward pass (for gradient flow).
- This method is called internally by :meth:`update` after integration
and threshold checking.
"""
V = self.V.value if V is None else V
v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
return self.spk_fun(v_scaled)
@staticmethod
def _nmda_jump_constants(alpha, tau_rise, tau_decay):
    r"""Compute the NMDA jump constants :math:`k_0` and :math:`k_1`.

    Parameters
    ----------
    alpha : float or numpy.ndarray
        NMDA jump-shape parameter, unit-free (expected in 1/ms).
    tau_rise : float or numpy.ndarray
        NMDA rise time constant, unit-free (expected in ms).
    tau_decay : float or numpy.ndarray
        NMDA decay time constant, unit-free (expected in ms).

    Returns
    -------
    k0 : numpy.ndarray
        Constant part of the jump.
    k1 : float or numpy.ndarray
        Coefficient of the presynaptic state in the jump.

    Notes
    -----
    .. math::

        k_1 = \exp(-\alpha\tau_{rise}) - 1,

    .. math::

        k_0 = (\alpha\tau_{rise})^{\tau_{rise}/\tau_{decay}}
              \gamma\Big(1 - \tau_{rise}/\tau_{decay}, \alpha\tau_{rise}\Big),

    with :math:`\gamma(a, x)` the (unregularized) lower incomplete gamma
    function. ``jax.scipy.special.gammainc`` is the regularized form, so it
    is multiplied by :math:`\Gamma(a) = \exp(\mathrm{gammaln}(a))`. The
    gamma terms are evaluated in the default float dtype from
    ``brainstate.environ.dftype()``.
    """
    float_dtype = brainstate.environ.dftype()
    rise_scaled = alpha * tau_rise
    rise_over_decay = tau_rise / tau_decay

    shape_a = jnp.asarray(1.0 - rise_over_decay, dtype=float_dtype)
    bound_x = jnp.asarray(rise_scaled, dtype=float_dtype)
    regularized = jsp.special.gammainc(shape_a, bound_x)
    gamma_of_a = jnp.exp(jsp.special.gammaln(shape_a))
    lower_gamma = np.asarray(regularized * gamma_of_a, dtype=float_dtype)

    k0 = np.power(rise_scaled, rise_over_decay) * lower_gamma
    k1 = np.expm1(-rise_scaled)
    return k0, k1
def _vector_field(self, state, extra):
    """Right-hand side of the subthreshold ODE system.

    Parameters
    ----------
    state : DotDict
        ODE variables ``V``, ``s_AMPA``, ``s_GABA``, ``s_NMDA``.
    extra : DotDict
        Auxiliary data carried through the integrator; only ``i_stim``
        is read here.

    Returns
    -------
    DotDict
        Time derivatives under the same four keys as ``state``.
    """
    v = state.V
    excitatory_drive = v - self.E_ex

    # Voltage-dependent Mg2+ block of the NMDA current, on unit-free values.
    mg_block = 1.0 + (self.conc_Mg2 / u.mM) * u.math.exp(-0.062 * (v / u.mV)) / 3.57

    current_ampa = state.s_AMPA * excitatory_drive
    current_gaba = state.s_GABA * (v - self.E_in)
    current_nmda = state.s_NMDA * excitatory_drive / mg_block
    synaptic_total = current_ampa + current_gaba + current_nmda

    leak = -self.g_L * (v - self.E_L)
    return DotDict(
        V=(leak - synaptic_total + extra.i_stim) / self.C_m,
        s_AMPA=-state.s_AMPA / self.tau_AMPA,
        s_GABA=-state.s_GABA / self.tau_GABA,
        s_NMDA=-state.s_NMDA / self.tau_decay_NMDA,
    )
def _event_fn(self, state, extra, accept):
    """Detect spikes and update NMDA bookkeeping after an RK substep.

    Parameters
    ----------
    state : DotDict
        ODE variables ``V``, ``s_AMPA``, ``s_GABA``, ``s_NMDA``.
    extra : DotDict
        Auxiliary data. Read here: ``unstable``, ``r``, ``spike_mask``,
        ``t_spike``, ``last_spike_time``, ``s_nmda_pre``, ``k0``, ``k1``,
        ``spike_offset``.
    accept : array of bool
        Neurons whose substep was accepted.

    Returns
    -------
    state : DotDict
        The input ``state``, returned unchanged (reset and refractory
        clamping are not applied here).
    new_extra : DotDict
        Copy of ``extra`` with ``spike_mask``, ``unstable``,
        ``s_nmda_pre``, ``spike_offset`` and ``last_spike_time`` replaced.
    """
    # Flag divergence if any accepted neuron left the plausible range.
    out_of_range = (
        (state.V < -1e3 * u.mV)
        | (state.s_AMPA < -1e6 * u.nS)
        | (state.s_AMPA > 1e6 * u.nS)
    )
    unstable_flag = extra.unstable | jnp.any(accept & out_of_range)

    # A neuron fires at most once per step, and only when it was not
    # refractory at step start (r <= 0) and has reached threshold.
    spiked_before = extra.spike_mask
    fires = accept & (extra.r <= 0) & (state.V >= self.V_th) & ~spiked_before

    # Candidate NMDA update: decay the presynaptic helper since the last
    # spike, then add the jump k0 + k1 * s_pre.
    elapsed = extra.t_spike - extra.last_spike_time
    decayed_pre = extra.s_nmda_pre * u.math.exp(-elapsed / self.tau_decay_NMDA)
    jump = extra.k0 + extra.k1 * decayed_pre

    # The candidate values are committed only for neurons firing now.
    replacements = {
        'spike_mask': spiked_before | fires,
        'unstable': unstable_flag,
        's_nmda_pre': u.math.where(fires, decayed_pre + jump, extra.s_nmda_pre),
        'spike_offset': u.math.where(fires, jump, extra.spike_offset),
        'last_spike_time': u.math.where(fires, extra.t_spike, extra.last_spike_time),
    }
    return state, DotDict({**extra, **replacements})
def _parse_spike_events(self, spike_events: Iterable, state_shape):
    r"""Sum explicit spike events into one increment per receptor.

    Parameters
    ----------
    spike_events : Iterable or None
        Events given as dicts (keys ``receptor_type`` or ``receptor``,
        ``weight``, ``offset`` or ``nmda_offset``, ``sender_model``) or as
        sequences ``(receptor, weight[, offset[, sender_model]])``.
        Missing fields default to receptor ``'AMPA'`` (dict form only),
        weight 0 nS (dict form only), offset 1.0 and sender
        ``'iaf_bw_2001'``.
    state_shape : tuple
        Shape of the zero arrays the increments start from.

    Returns
    -------
    s_ampa, s_gaba, s_nmda
        Accumulated increments (in nS). NMDA events contribute
        ``weight * offset``; AMPA and GABA events contribute ``weight``.

    Raises
    ------
    ValueError
        If a sequence event does not have 2, 3 or 4 items, if the receptor
        cannot be normalized, or if an NMDA event names a sender model
        other than ``'iaf_bw_2001'``.
    """
    float_dtype = brainstate.environ.dftype()

    def zeros_nS():
        return jnp.zeros(state_shape, dtype=float_dtype) * u.nS

    totals = {self.AMPA: zeros_nS(), self.GABA: zeros_nS(), self.NMDA: zeros_nS()}

    def as_triple():
        return totals[self.AMPA], totals[self.GABA], totals[self.NMDA]

    if spike_events is None:
        return as_triple()

    # Defaults appended to short sequence events: (offset, sender_model).
    tail_defaults = (1.0, 'iaf_bw_2001')

    for event in spike_events:
        if isinstance(event, dict):
            receptor = event.get('receptor_type', event.get('receptor', 'AMPA'))
            weight = event.get('weight', 0.0 * u.nS)
            sender_model = event.get('sender_model', 'iaf_bw_2001')
            offset = event.get('offset', event.get('nmda_offset', 1.0))
        else:
            n_fields = len(event)
            if n_fields not in (2, 3, 4):
                raise ValueError('Spike event tuples must have length 2, 3, or 4.')
            # Pad with the defaults for whichever trailing fields are absent.
            receptor, weight, offset, sender_model = (*event, *tail_defaults[n_fields - 2:])

        receptor_id = self._normalize_spike_receptor(receptor)
        if receptor_id == self.AMPA or receptor_id == self.GABA:
            increment = weight
        else:
            if sender_model != 'iaf_bw_2001':
                raise ValueError(
                    'For NMDA synapses in iaf_bw_2001, pre-synaptic neuron must also be of type iaf_bw_2001.'
                )
            increment = weight * offset
        totals[receptor_id] = totals[receptor_id] + increment

    return as_triple()
def update(self, x=0. * u.pA, spike_events=None):
    r"""Advance the neuron state by one simulation timestep.

    Parameters
    ----------
    x : saiunit.Quantity, optional
        External input current. It is combined with registered current
        inputs via ``sum_current_inputs`` and written to ``I_stim`` at the
        end of this call; the integration in this call uses the ``I_stim``
        value stored by the previous call (one-step delay). Default: 0 pA.
    spike_events : iterable of tuple or dict, optional
        Incoming spike events, each either

        - a tuple ``(receptor, weight)``, ``(receptor, weight, offset)`` or
          ``(receptor, weight, offset, sender_model)``; or
        - a dict ``{'receptor_type': ..., 'weight': ..., 'offset': ...,
          'sender_model': ...}``.

        Receptors are ``'AMPA'``/``1``, ``'GABA'``/``2``, ``'NMDA'``/``3``.
        NMDA events contribute ``weight * offset`` (offset defaults to 1.0).
        ``None`` means no explicit events. Default: None.

    Returns
    -------
    ArrayLike
        The boolean spike mask accumulated during integration, cast to the
        default float dtype (``brainstate.environ.dftype()``). The mask is
        created with shape ``self.varshape``. No surrogate spike function
        is applied in this method.

    Raises
    ------
    ValueError
        Propagated from ``_parse_spike_events``: malformed event tuple,
        unknown receptor, or an NMDA event whose ``sender_model`` is not
        ``'iaf_bw_2001'``.

    Notes
    -----
    **Order of operations in this method:**

    1. Read ``t`` and ``dt`` from the environment and the current states.
    2. Compute next step's stimulus current from ``x`` and the
       pre-integration membrane potential.
    3. Run the adaptive integrator. Spike detection and the NMDA
       presynaptic bookkeeping happen inside the integrator through
       ``_event_fn``; the spike time used there is ``t + dt``.
    4. Report numerical instability through
       ``brainstate.transform.jit_error_if``.
    5. Compute ``I_AMPA``, ``I_GABA``, ``I_NMDA`` from the integrator
       output, before any reset or clamping of the membrane potential.
    6. Set :math:`V_m` to :math:`V_{reset}` for neurons that spiked this
       step or entered the step refractory.
    7. Refractory counter: set to ``ref_count`` for neurons that spiked
       (when ``ref_count > 0``), otherwise decrement if positive.
    8. Add explicit spike events and registered delta inputs (labels
       ``'AMPA'``, ``'GABA'``, ``'NMDA'``) to the conductance states.
    9. Write all states back, including the delayed ``I_stim`` buffer.

    **NMDA spike offset:**

    For a neuron that spikes,

    .. math::

        s_{pre} \leftarrow s_{pre} \exp(-\Delta t / \tau_{NMDA,decay}),

    .. math::

        \Delta s_{NMDA} = k_0 + k_1 s_{pre},
        \quad
        s_{pre} \leftarrow s_{pre} + \Delta s_{NMDA},

    with :math:`\Delta t = t_{spike} - t_{last}`. :math:`\Delta s_{NMDA}`
    is stored in ``spike_offset``; for neurons that did not spike this
    step ``spike_offset`` is zero, because the integrator starts each step
    from a zero offset array.
    """
    t = brainstate.environ.get('t')
    dt = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()
    # Read state variables with their natural units.
    V = self.V.value  # mV
    s_AMPA = self.s_AMPA.value  # nS
    s_GABA = self.s_GABA.value  # nS
    s_NMDA = self.s_NMDA.value  # nS
    r = self.refractory_step_count.value  # int
    i_stim = self.I_stim.value  # pA
    h = self.integration_step.value  # ms
    s_nmda_pre = self.s_NMDA_pre.value
    last_spike_time = self.last_spike_time.value
    # NOTE(review): read but not used anywhere below in this method.
    spike_offset_prev = self.spike_offset.value
    # Current input for next step (one-step delay). Evaluated against the
    # pre-integration membrane potential.
    new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA
    # Use pre-computed NMDA jump constants (computed once in __init__).
    k0 = jnp.asarray(self._k0_np, dtype=dftype)
    k1 = jnp.asarray(self._k1_np, dtype=dftype)
    # Adaptive RKF45 integration via generic integrator.
    ode_state = DotDict(V=V, s_AMPA=s_AMPA, s_GABA=s_GABA, s_NMDA=s_NMDA)
    # Auxiliary data threaded through the integrator and ``_event_fn``.
    # ``spike_mask`` and ``spike_offset`` start from zero on every step.
    extra = DotDict(
        spike_mask=jnp.zeros(self.varshape, dtype=jnp.bool_),
        r=r,
        unstable=jnp.array(False),
        i_stim=i_stim,
        s_nmda_pre=s_nmda_pre,
        spike_offset=jnp.zeros(self.varshape, dtype=dftype),
        last_spike_time=last_spike_time,
        k0=k0,
        k1=k1,
        t_spike=t + dt,
    )
    ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
    V_ode = ode_state.V  # Free ODE output — no refractory/spike reset applied yet.
    s_AMPA, s_GABA, s_NMDA = ode_state.s_AMPA, ode_state.s_GABA, ode_state.s_NMDA
    # ``r_init`` is the refractory count as passed into the integrator.
    spike_mask, r_init, unstable = extra.spike_mask, extra.r, extra.unstable
    s_nmda_pre = extra.s_nmda_pre
    spike_offset_new = extra.spike_offset
    last_spike_time = extra.last_spike_time
    # Post-loop stability check.
    brainstate.transform.jit_error_if(
        jnp.any(unstable), 'Numerical instability in iaf_bw_2001 dynamics.'
    )
    # Compute synaptic currents from FREE ODE output (before V reset/refractory clamping).
    # This matches NEST semantics: recorded currents reflect the ODE solution.
    v_mV = V_ode / u.mV
    conc_mM = self.conc_Mg2 / u.mM
    denom = 1.0 + conc_mM * u.math.exp(-0.062 * v_mV) / 3.57
    I_AMPA_val = s_AMPA * (V_ode - self.E_ex)
    I_GABA_val = s_GABA * (V_ode - self.E_in)
    I_NMDA_val = s_NMDA * (V_ode - self.E_ex) / denom
    # Apply refractory / spike reset post-integration (matching NEST).
    # V is clamped to V_reset if the neuron spiked this step OR is still refractory.
    V = u.math.where(spike_mask | (r_init > 0), self.V_reset, V_ode)
    # Update refractory counter:
    # - spike this step and t_ref > 0 → start refractory (ref_count steps)
    # - already refractory → decrement
    # - otherwise → keep at 0
    r = u.math.where(
        spike_mask & (self.ref_count > 0),
        self.ref_count,
        u.math.where(r_init > 0, r_init - 1, r_init),
    )
    # Synaptic spike inputs (applied after integration and current recording).
    # Parse explicit spike events.
    ev_ampa, ev_gaba, ev_nmda = self._parse_spike_events(spike_events, self.varshape)
    # Parse registered delta inputs by receptor label.
    w_ampa = self.sum_delta_inputs(u.math.zeros_like(self.s_AMPA.value), label='AMPA')
    w_gaba = self.sum_delta_inputs(u.math.zeros_like(self.s_GABA.value), label='GABA')
    w_nmda = self.sum_delta_inputs(u.math.zeros_like(self.s_NMDA.value), label='NMDA')
    # Apply synaptic spike inputs.
    s_AMPA = s_AMPA + ev_ampa + w_ampa
    s_GABA = s_GABA + ev_gaba + w_gaba
    s_NMDA = s_NMDA + ev_nmda + w_nmda
    # Write back state.
    self.V.value = V
    self.s_AMPA.value = s_AMPA
    self.s_GABA.value = s_GABA
    self.s_NMDA.value = s_NMDA
    self.I_AMPA.value = I_AMPA_val
    self.I_GABA.value = I_GABA_val
    self.I_NMDA.value = I_NMDA_val
    self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
    self.integration_step.value = h
    # Adding a zero array broadcasts a scalar input to the population shape.
    self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
    self.s_NMDA_pre.value = s_nmda_pre
    self.spike_offset.value = spike_offset_new
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    if self.ref_var:
        self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)
    return u.math.asarray(spike_mask, dtype=dftype)