# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Hashable, Iterable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep
__all__ = [
'iaf_bw_2001_exact',
]
class iaf_bw_2001_exact(NESTNeuron):
r"""NEST-compatible conductance-based LIF neuron with exact per-synapse NMDA dynamics.
This model implements the Brunel-Wang (2001) neuron with exact NMDA kinetics, maintaining
separate rise and decay variables for each NMDA synapse without presynaptic-jump approximation.
Each NMDA connection is assigned a unique port with a fixed weight, enforcing NEST's constraint
that NMDA connections cannot be added after the first simulation step.
Parameters
----------
in_size : int, tuple of int, Sequence of int
Population shape. Defines the number and arrangement of neurons.
E_L : ArrayLike, optional
Leak reversal potential. Default: -70 mV.
Determines the resting potential in the absence of input.
E_ex : ArrayLike, optional
Excitatory reversal potential. Default: 0 mV.
Reversal potential for AMPA and NMDA receptors.
E_in : ArrayLike, optional
Inhibitory reversal potential. Default: -70 mV.
Reversal potential for GABA receptors.
V_th : ArrayLike, optional
Spike threshold potential. Default: -55 mV.
Membrane potential at which a spike is emitted.
V_reset : ArrayLike, optional
Reset potential. Default: -60 mV.
Membrane potential immediately after spike emission. Must be < V_th.
C_m : ArrayLike, optional
Membrane capacitance. Default: 500 pF.
Must be strictly positive.
g_L : ArrayLike, optional
Leak conductance. Default: 25 nS.
Conductance through passive leak channels.
t_ref : ArrayLike, optional
Absolute refractory period duration. Default: 2 ms.
Time after spike during which membrane is clamped to V_reset.
tau_AMPA : ArrayLike, optional
AMPA decay time constant. Default: 2 ms.
Governs exponential decay of AMPA conductance. Must be > 0.
tau_GABA : ArrayLike, optional
GABA decay time constant. Default: 5 ms.
Governs exponential decay of GABA conductance. Must be > 0.
tau_rise_NMDA : ArrayLike, optional
NMDA rise time constant. Default: 2 ms.
Time constant for NMDA activation variable x_j. Must be > 0.
tau_decay_NMDA : ArrayLike, optional
NMDA decay time constant. Default: 100 ms.
Time constant for NMDA gating variable s_j. Must be > 0.
alpha : ArrayLike, optional
NMDA rise coupling strength. Default: 0.5 / ms.
Scales the coupling between rise (x_j) and gating (s_j) variables. Must be > 0.
conc_Mg2 : ArrayLike, optional
Extracellular magnesium concentration. Default: 1 mM.
Controls voltage-dependent NMDA blockade. Must be > 0.
gsl_error_tol : ArrayLike, optional
RKF45 local error tolerance. Default: 1e-3.
Controls adaptive step size in Runge-Kutta-Fehlberg integration. Must be > 0.
Smaller values improve accuracy at the cost of more iterations.
V_initializer : Callable, optional
Membrane potential initializer. Default: Constant(-70 mV).
Function that generates initial V_m values.
s_AMPA_initializer : Callable, optional
AMPA conductance state initializer. Default: Constant(0 nS).
Function that generates initial s_AMPA values.
s_GABA_initializer : Callable, optional
GABA conductance state initializer. Default: Constant(0 nS).
Function that generates initial s_GABA values.
spk_fun : Callable, optional
Surrogate gradient function for spike generation. Default: ReluGrad().
Maps scaled voltage to differentiable spike output.
spk_reset : str, optional
Spike reset mode. Default: 'hard'.
- 'hard': Stop gradient through reset (matches NEST)
- 'soft': Gradient flows through reset (V -= V_th)
ref_var : bool, optional
If True, expose boolean refractory state variable. Default: False.
Adds a `refractory` attribute for monitoring refractory state.
name : str, optional
Module name. Default: None (auto-generated).
Raises
------
ValueError
If V_reset >= V_th, t_ref < 0, or any of C_m, tau_*, alpha, conc_Mg2, gsl_error_tol <= 0.
ValueError
If attempting to change NMDA port weights after first registration.
ValueError
If attempting to add new NMDA ports after first :meth:`update` call.
ValueError
If NMDA port is not hashable.
ValueError
If spike event format is invalid.
See Also
--------
iaf_bw_2001 : Approximate version using presynaptic-jump NMDA dynamics
iaf_cond_exp : Simpler conductance-based LIF without NMDA
aeif_cond_alpha : Adaptive exponential IF with alpha-shaped conductances
Parameter Mapping
-----------------
============================ ======================== ============================================
**NEST Parameter** **brainpy.state** **Notes**
============================ ======================== ============================================
``E_L`` ``E_L`` Leak reversal potential (mV)
``E_ex`` ``E_ex`` Excitatory reversal (mV)
``E_in`` ``E_in`` Inhibitory reversal (mV)
``V_th`` ``V_th`` Spike threshold (mV)
``V_reset`` ``V_reset`` Reset potential (mV)
``C_m`` ``C_m`` Membrane capacitance (pF)
``g_L`` ``g_L`` Leak conductance (nS)
``t_ref`` ``t_ref`` Refractory period (ms)
``tau_AMPA`` ``tau_AMPA`` AMPA decay time (ms)
``tau_GABA`` ``tau_GABA`` GABA decay time (ms)
``tau_rise_NMDA`` ``tau_rise_NMDA`` NMDA rise time (ms)
``tau_decay_NMDA`` ``tau_decay_NMDA`` NMDA decay time (ms)
``alpha`` ``alpha`` NMDA coupling (1/ms)
``conc_Mg2`` ``conc_Mg2`` Mg2+ concentration (mM)
``gsl_error_tol`` ``gsl_error_tol`` RKF45 tolerance (dimensionless)
============================ ======================== ============================================
Mathematical Model
------------------
**1. Membrane Dynamics**
The subthreshold membrane potential evolves according to:
.. math::
C_m \frac{dV_m}{dt} = -g_L(V_m - E_L) - I_{syn} + I_{stim}
where :math:`I_{syn} = I_{AMPA} + I_{GABA} + I_{NMDA}` is the total synaptic current.
**2. Synaptic Currents**
AMPA and GABA currents are ohmic:
.. math::
I_{AMPA} &= (V_m - E_{ex}) s_{AMPA} \\
I_{GABA} &= (V_m - E_{in}) s_{GABA}
NMDA current includes voltage-dependent Mg2+ blockade:
.. math::
I_{NMDA} = \frac{(V_m - E_{ex})}{1 + [Mg^{2+}]\exp(-0.062V_m)/3.57} \sum_j w_j s_j
where :math:`j` indexes individual NMDA synapses, :math:`w_j` is the fixed weight for port :math:`j`,
and :math:`s_j` is the gating variable for that synapse.
**3. Synaptic Gating Variables**
AMPA and GABA conductances decay exponentially:
.. math::
\frac{ds_{AMPA}}{dt} &= -\frac{s_{AMPA}}{\tau_{AMPA}} \\
\frac{ds_{GABA}}{dt} &= -\frac{s_{GABA}}{\tau_{GABA}}
Each NMDA synapse :math:`j` has dual-timescale kinetics:
.. math::
\frac{dx_j}{dt} &= -\frac{x_j}{\tau_{NMDA,rise}} \\
\frac{ds_j}{dt} &= -\frac{s_j}{\tau_{NMDA,decay}} + \alpha x_j (1-s_j)
where :math:`x_j` is the rise variable (fast activation) and :math:`s_j` is the gating variable
(slow inactivation with saturation).
**4. Spike Generation and Reset**
When :math:`V_m \geq V_{th}` and the neuron is not refractory:
- Emit a spike
- Set :math:`V_m \leftarrow V_{reset}`
- Enter refractory state for :math:`t_{ref}` ms
During refractoriness, :math:`V_m` is clamped to :math:`V_{reset}`.
**5. Numerical Integration**
The continuous dynamics are integrated using adaptive Runge-Kutta-Fehlberg (RKF45) with:
- 4th and 5th order embedded methods for error estimation
- Persistent step size :math:`h` that adapts to maintain local error < ``gsl_error_tol``
- Minimum step size :math:`h_{min} = 10^{-8}` ms
- Maximum iterations per simulation step: 10,000
**NMDA Port Semantics**
NEST assigns each NMDA connection a unique receptor port at connect time and prohibits adding
new NMDA connections after the first simulation step. This implementation mirrors that behavior:
- Each NMDA event requires a ``port`` identifier (any hashable value)
- The first event for a new port registers that port with the provided weight
- Subsequent events to the same port must use the same weight (enforced)
- New ports can only be added before the first :meth:`update` call
- AMPA/GABA events do not use ports (weights accumulate directly)
**Spike Event Formats**
The ``spike_events`` parameter accepts multiple formats:
**Tuple formats:**
- ``(receptor, weight)`` --- receptor in {1, 2, 3} or {'AMPA', 'GABA', 'NMDA'}
- ``(receptor, weight, third)`` --- ``third`` is multiplicity for AMPA/GABA, port for NMDA
- ``(receptor, weight, port, multiplicity)`` --- full NMDA specification
**Dict format:**
- Required keys: ``receptor_type`` or ``receptor`` (1/2/3 or 'AMPA'/'GABA'/'NMDA'), ``weight``
- Optional keys: ``multiplicity`` (default 1.0), ``port``/``rport``/``synapse_id`` (for NMDA)
**Update Ordering (matches NEST)**
Each :meth:`update` call executes in this order:
1. **Integrate ODEs** on :math:`(t, t+dt]` using RKF45 with persistent step size
2. **Apply spike jumps**: add to ``s_AMPA``, ``s_GABA``, and ``x_j`` for each NMDA port
3. **Threshold check and reset**: emit spikes, reset voltage, update refractory countdown
4. **Store external current**: buffer ``I_stim`` for next step (one-step delay)
**Recordable Variables**
- ``V_m`` --- Membrane potential (mV)
- ``s_AMPA`` --- AMPA conductance state (nS)
- ``s_GABA`` --- GABA conductance state (nS)
- ``s_NMDA`` --- Weighted sum of NMDA gating variables (nS), :math:`\sum_j w_j s_j`
- ``I_AMPA`` --- AMPA current (pA)
- ``I_GABA`` --- GABA current (pA)
- ``I_NMDA`` --- NMDA current (pA)
Additional State Variables
--------------------------
- ``x_NMDA`` --- NMDA rise variables for each port (shape: ``[*in_size, n_ports]``)
- ``s_NMDA_components`` --- NMDA gating variables for each port (shape: ``[*in_size, n_ports]``)
- ``nmda_weights`` --- Fixed weights for each NMDA port (shape: ``[*in_size, n_ports]``)
- ``refractory_step_count`` --- Remaining refractory steps (int32)
- ``integration_step`` --- Persistent RKF45 step size (ms)
- ``I_stim`` --- One-step delayed external current buffer (pA)
- ``refractory`` --- Boolean refractory indicator (only if ``ref_var=True``)
**Performance Considerations:**
- The RKF45 right-hand side is evaluated for all neurons at once (vectorized); spike-event parsing and NMDA port bookkeeping run in NumPy
- Computational cost scales linearly with the number of NMDA ports
- Large ``gsl_error_tol`` reduces accuracy but improves speed
- This model is significantly slower than ``iaf_bw_2001`` due to per-synapse state
**Comparison to iaf_bw_2001:**
- ``iaf_bw_2001`` approximates all NMDA synapses with a single pair of state variables
- ``iaf_bw_2001_exact`` tracks rise and decay for each NMDA connection separately
- Use ``iaf_bw_2001_exact`` when NMDA synapse heterogeneity matters (e.g., detailed working memory models)
- Use ``iaf_bw_2001`` for large-scale simulations where approximation is acceptable
References
----------
.. [1] Wang X-J (1999). Synaptic basis of cortical persistent activity:
The importance of NMDA receptors to working memory.
Journal of Neuroscience, 19(21):9587-9603.
DOI: https://doi.org/10.1523/JNEUROSCI.19-21-09587.1999
.. [2] Brunel N, Wang X-J (2001). Effects of neuromodulation in a cortical
network model of object working memory dominated by recurrent
inhibition. Journal of Computational Neuroscience, 11(1):63-85.
DOI: https://doi.org/10.1023/A:1011204814320
.. [3] Wang X-J (2002). Probabilistic decision making by slow
reverberation in cortical circuits. Neuron, 36(5):955-968.
DOI: https://doi.org/10.1016/S0896-6273(02)01092-9
.. [4] NEST Simulator. Models: iaf_bw_2001_exact.
https://nest-simulator.readthedocs.io/en/stable/models/iaf_bw_2001_exact.html
Examples
--------
**Basic usage with AMPA input:**
.. code-block:: python
>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.context(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=10)
>>> net.init_all_states()
>>> # Apply AMPA input spike
>>> spike = net(spike_events=[(1, 100*u.nS)])
>>> print(net.V.value) # doctest: +SKIP
**NMDA connections with unique ports:**
.. code-block:: python
>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.context(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=5)
>>> net.init_all_states()
>>> # Register two NMDA ports with different weights
>>> events = [
... (3, 50*u.nS, 'port_A', 1.0), # NMDA port A, weight 50 nS
... (3, 75*u.nS, 'port_B', 1.0), # NMDA port B, weight 75 nS
... ]
>>> spike = net(spike_events=events)
>>> print(net.s_NMDA_components.value.shape) # doctest: +SKIP
(5, 2) # 5 neurons x 2 NMDA ports
**Mixing AMPA, GABA, and NMDA:**
.. code-block:: python
>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.context(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=1, V_th=-50*u.mV)
>>> net.init_all_states()
>>> events = [
... {'receptor': 'AMPA', 'weight': 200*u.nS, 'multiplicity': 2.0},
... {'receptor': 'GABA', 'weight': 100*u.nS},
... {'receptor': 'NMDA', 'weight': 50*u.nS, 'port': 0},
... ]
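>>> # Deliver the events (this registers NMDA port 0) on the first call:
>>> # new NMDA ports are rejected once update() has started running.
>>> for i in range(100):
...     spike = net(spike_events=events if i == 0 else None)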
>>> print(net.last_spike_time.value) # doctest: +SKIP
**Monitoring refractory state:**
.. code-block:: python
>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.context(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=3, ref_var=True, t_ref=5*u.ms)
>>> net.init_all_states()
>>> net.V.value = net.V_th + 1*u.mV # Force spike
>>> spike = net()
>>> print(net.refractory.value) # doctest: +SKIP
[True True True]
"""
__module__ = 'brainpy.state'
AMPA = 1
GABA = 2
NMDA = 3
RECEPTOR_TYPES = {
'AMPA': AMPA,
'GABA': GABA,
'NMDA': NMDA,
}
RECORDABLES = (
'V_m',
's_AMPA',
's_GABA',
's_NMDA',
'I_NMDA',
'I_AMPA',
'I_GABA',
)
_ATOL = 1e-3
_MIN_H = 1e-8 * u.ms # ms
_MAX_ITERS = 10000
def __init__(
    self,
    in_size: Size,
    E_L: ArrayLike = -70. * u.mV,
    E_ex: ArrayLike = 0. * u.mV,
    E_in: ArrayLike = -70. * u.mV,
    V_th: ArrayLike = -55. * u.mV,
    V_reset: ArrayLike = -60. * u.mV,
    C_m: ArrayLike = 500. * u.pF,
    g_L: ArrayLike = 25. * u.nS,
    t_ref: ArrayLike = 2. * u.ms,
    tau_AMPA: ArrayLike = 2. * u.ms,
    tau_GABA: ArrayLike = 5. * u.ms,
    tau_rise_NMDA: ArrayLike = 2. * u.ms,
    tau_decay_NMDA: ArrayLike = 100. * u.ms,
    alpha: ArrayLike = 0.5 / u.ms,
    conc_Mg2: ArrayLike = 1.0 * u.mM,
    gsl_error_tol: ArrayLike = 1e-3,
    V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
    s_AMPA_initializer: Callable = braintools.init.Constant(0. * u.nS),
    s_GABA_initializer: Callable = braintools.init.Constant(0. * u.nS),
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    ref_var: bool = False,
    name: str = None,
):
    """Build the population, validate its parameters and create the RKF45 integrator.

    See the class docstring for the meaning and units of every argument.

    Raises
    ------
    ValueError
        Propagated from ``_validate_parameters`` when a parameter is out of range.
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    shape = self.varshape

    def as_param(value):
        # Broadcast a user-supplied parameter to the population shape.
        return braintools.init.param(value, shape)

    # Reversal, threshold and reset potentials.
    self.E_L = as_param(E_L)
    self.E_ex = as_param(E_ex)
    self.E_in = as_param(E_in)
    self.V_th = as_param(V_th)
    self.V_reset = as_param(V_reset)
    # Passive membrane properties and refractory period.
    self.C_m = as_param(C_m)
    self.g_L = as_param(g_L)
    self.t_ref = as_param(t_ref)
    # Synaptic kinetics and Mg2+ block.
    self.tau_AMPA = as_param(tau_AMPA)
    self.tau_GABA = as_param(tau_GABA)
    self.tau_rise_NMDA = as_param(tau_rise_NMDA)
    self.tau_decay_NMDA = as_param(tau_decay_NMDA)
    self.alpha = as_param(alpha)
    self.conc_Mg2 = as_param(conc_Mg2)
    # The solver tolerance is stored as given (not broadcast per neuron).
    self.gsl_error_tol = gsl_error_tol

    self.V_initializer = V_initializer
    self.s_AMPA_initializer = s_AMPA_initializer
    self.s_GABA_initializer = s_GABA_initializer
    self.ref_var = ref_var

    # NMDA port registry (port id -> channel index) and the flag that forbids
    # registering new ports once updates have begun.
    self._nmda_port_index = {}
    self._updates_started = False

    self._validate_parameters()

    dt = brainstate.environ.get_dt()
    self.integrator = AdaptiveRungeKuttaStep(
        method='RKF45',
        vf=self._vector_field,
        event_fn=self._event_fn,
        min_h=self._MIN_H,
        max_iters=self._MAX_ITERS,
        atol=self.gsl_error_tol,
        dt=dt,
    )
    # Refractory period as a whole number of simulation steps (rounded up).
    self.ref_count = u.math.asarray(
        u.math.ceil(self.t_ref / dt), dtype=brainstate.environ.ditype()
    )
@property
def receptor_types(self):
    r"""Table translating receptor names to the numeric codes used by spike events.

    Returns
    -------
    dict
        A fresh plain-dict copy of ``RECEPTOR_TYPES``
        (``{'AMPA': 1, 'GABA': 2, 'NMDA': 3}`` for this class). Mutating it
        does not affect the class attribute.
    """
    return {**self.RECEPTOR_TYPES}
@property
def recordables(self):
    r"""Names of the quantities that can be recorded from this model.

    Returns
    -------
    list of str
        A new list built from ``RECORDABLES``:
        ['V_m', 's_AMPA', 's_GABA', 's_NMDA', 'I_NMDA', 'I_AMPA', 'I_GABA'].
    """
    return [*self.RECORDABLES]
@classmethod
def _normalize_spike_receptor(cls, receptor):
r"""Normalize receptor identifier to numeric code.
Parameters
----------
receptor : str or int
Receptor identifier. Accepts 'AMPA', 'GABA', 'NMDA', or numeric codes 1/2/3.
Returns
-------
int
Numeric receptor code (1=AMPA, 2=GABA, 3=NMDA).
Raises
------
ValueError
If receptor is not recognized or is out of valid range [1, 3].
"""
if isinstance(receptor, str):
key = receptor.strip()
if key in cls.RECEPTOR_TYPES:
return cls.RECEPTOR_TYPES[key]
if key.isdigit():
receptor = int(key)
else:
raise ValueError(f'Unknown receptor label: {receptor}')
receptor = int(receptor)
if receptor < cls.AMPA or receptor > cls.NMDA:
raise ValueError(f'Receptor type must be in [1, 3], got {receptor}.')
return receptor
@staticmethod
def _normalize_nmda_port(port) -> Hashable:
r"""Normalize NMDA port identifier to hashable value.
Parameters
----------
port : Hashable or None
NMDA port identifier. Can be int, str, or any hashable type.
If None, defaults to port 0.
Returns
-------
Hashable
Normalized port identifier. Numeric strings converted to int,
None converted to 0, other hashable values returned as-is.
Raises
------
ValueError
If port is not hashable.
"""
if port is None:
return 0
if isinstance(port, str):
p = port.strip()
if p.isdigit():
return int(p)
return p
try:
hash(port)
except TypeError as e:
raise ValueError(f'NMDA port must be hashable, got {type(port)}.') from e
return port
def _validate_parameters(self):
    r"""Validate model parameters at initialization.

    Validation is skipped when any of the checked parameters is a JAX tracer
    (e.g. the model is constructed inside ``jit``), because the concrete
    comparisons below cannot be evaluated on tracers.

    Raises
    ------
    ValueError
        If V_reset >= V_th.
    ValueError
        If C_m, tau_AMPA, tau_GABA, tau_rise_NMDA, tau_decay_NMDA, alpha,
        conc_Mg2, or gsl_error_tol are non-positive.
    ValueError
        If t_ref is negative.
    """
    checked = (
        self.V_reset, self.V_th, self.C_m, self.t_ref,
        self.tau_AMPA, self.tau_GABA, self.tau_rise_NMDA, self.tau_decay_NMDA,
        self.alpha, self.conc_Mg2, self.gsl_error_tol,
    )
    # Every compared value must be concrete; a single tracer would make the
    # np.any(...) calls below fail, so check all of them, not just a few.
    if any(is_tracer(v) for v in checked):
        return
    if np.any(self.V_reset >= self.V_th):
        raise ValueError('Reset potential must be smaller than threshold.')
    if np.any(self.C_m <= 0.0 * u.pF):
        raise ValueError('Capacitance must be strictly positive.')
    if np.any(self.t_ref < 0.0 * u.ms):
        raise ValueError('Refractory time cannot be negative.')
    time_constants = (self.tau_AMPA, self.tau_GABA, self.tau_rise_NMDA, self.tau_decay_NMDA)
    if any(np.any(tau <= 0.0 * u.ms) for tau in time_constants):
        raise ValueError('All time constants must be strictly positive.')
    if np.any(self.alpha <= 0.0 / u.ms):
        raise ValueError('alpha > 0 required.')
    if np.any(self.conc_Mg2 <= 0.0 * u.mM):
        raise ValueError('Mg2 concentration must be strictly positive.')
    if np.any(self.gsl_error_tol <= 0.0):
        raise ValueError('The gsl_error_tol must be strictly positive.')
def _nmda_num_ports(self):
if hasattr(self, 'x_NMDA'):
return int(self.x_NMDA.value.shape[-1])
return 0
def init_state(self, **kwargs):
    r"""Initialize all state variables.

    Creates and initializes membrane potential, synaptic conductances, currents,
    NMDA port arrays (initially empty), refractory state, and integration step size.
    NMDA port registry is cleared.

    Parameters
    ----------
    **kwargs
        Unused compatibility parameters accepted by the base-state API.

    Notes
    -----
    - NMDA port arrays (x_NMDA, s_NMDA_components, nmda_weights) start empty (shape: [..., 0])
    - Ports are appended by ``_ensure_nmda_port`` when an NMDA event names a new port
    - Clears the internal ``_nmda_port_index`` registry
    - Resets ``_updates_started`` flag to False
    """
    ditype = brainstate.environ.ditype()
    dftype = brainstate.environ.dftype()
    dt = brainstate.environ.get_dt()
    V = braintools.init.param(self.V_initializer, self.varshape)
    s_ampa = braintools.init.param(self.s_AMPA_initializer, self.varshape)
    s_gaba = braintools.init.param(self.s_GABA_initializer, self.varshape)
    self.V = brainstate.HiddenState(V)
    self.s_AMPA = brainstate.HiddenState(s_ampa)
    self.s_GABA = brainstate.HiddenState(s_gaba)
    zeros = u.math.zeros(self.varshape, dtype=dftype)
    self.s_NMDA = brainstate.ShortTermState(zeros * u.nS)
    self.I_NMDA = brainstate.ShortTermState(zeros * u.pA)
    self.I_AMPA = brainstate.ShortTermState(zeros * u.pA)
    self.I_GABA = brainstate.ShortTermState(zeros * u.pA)
    # Per-port NMDA arrays carry a trailing port axis of length 0 until the
    # first NMDA port is registered.
    self.x_NMDA = brainstate.ShortTermState(np.zeros(self.varshape + (0,), dtype=dftype))
    self.s_NMDA_components = brainstate.ShortTermState(np.zeros(self.varshape + (0,), dtype=dftype))
    self.nmda_weights = brainstate.ShortTermState(np.zeros(self.varshape + (0,), dtype=dftype))
    # Far in the past, so no neuron starts out as "recently spiked".
    self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
    self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
    # RKF45 step size persists across update() calls; start from one full dt.
    self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
    self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
    self._nmda_port_index = {}
    self._updates_started = False
    if self.ref_var:
        refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
        self.refractory = brainstate.ShortTermState(refractory)
def reset_state(self, batch_size: int = None, **kwargs):
    r"""Reset all state variables to initial values.

    Unlike :meth:`init_state`, this preserves NMDA port structure (number of ports
    and their weights remain unchanged). Resets voltage, conductances, currents,
    NMDA gating variables, refractory state, and integration step size.

    Parameters
    ----------
    batch_size : int, optional
        Batch dimension size for state variables. Default: None (no batching).
        If provided, reshapes state variables with a leading batch dimension.
    **kwargs
        Additional keyword arguments (currently unused).

    Notes
    -----
    - NMDA port count and weights are preserved (``nmda_weights`` is not touched),
      but x_NMDA and s_NMDA_components are zeroed
    - Does NOT clear ``_nmda_port_index`` (port registry persists)
    - Does NOT reset ``_updates_started`` flag
    """
    self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
    self.s_AMPA.value = braintools.init.param(self.s_AMPA_initializer, self.varshape, batch_size)
    self.s_GABA.value = braintools.init.param(self.s_GABA_initializer, self.varshape, batch_size)
    # The (possibly batched) shape of V defines the shape of every other state.
    state_shape = self.V.value.shape
    dftype = brainstate.environ.dftype()
    zeros = np.zeros(state_shape, dtype=dftype)
    self.s_NMDA.value = zeros * u.nS
    self.I_NMDA.value = zeros * u.pA
    self.I_AMPA.value = zeros * u.pA
    self.I_GABA.value = zeros * u.pA
    n_ports = self._nmda_num_ports()
    self.x_NMDA.value = np.zeros(state_shape + (n_ports,), dtype=dftype)
    self.s_NMDA_components.value = np.zeros(state_shape + (n_ports,), dtype=dftype)
    self.last_spike_time.value = braintools.init.param(
        braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size
    )
    ref_steps = braintools.init.param(braintools.init.Constant(0), self.varshape, batch_size)
    ditype = brainstate.environ.ditype()
    self.refractory_step_count.value = u.math.asarray(ref_steps, dtype=ditype)
    dt = brainstate.environ.get_dt()
    self.integration_step.value = braintools.init.param(
        braintools.init.Constant(dt), self.varshape, batch_size
    )
    self.I_stim.value = braintools.init.param(
        braintools.init.Constant(0. * u.pA), self.varshape, batch_size
    )
    if self.ref_var:
        refractory = braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size)
        self.refractory.value = refractory
[docs]
def get_spike(self, V: ArrayLike = None):
r"""Generate differentiable spike output from membrane potential.
Scales voltage relative to threshold and applies surrogate gradient function
for gradient-based learning. Voltage is scaled linearly between V_reset (0)
and V_th (1).
Parameters
----------
V : ArrayLike, optional
Membrane potential (mV). Default: None (uses current ``self.V.value``).
Shape must match ``self.varshape`` or be broadcastable to it.
Returns
-------
ArrayLike
Differentiable spike output in [0, 1]. Shape matches input voltage.
Values close to 1 indicate spiking; values close to 0 indicate quiescence.
Exact output depends on ``self.spk_fun`` (e.g., ReLU, sigmoid, etc.).
Notes
-----
- Used internally during :meth:`update` to compute spike output before reset
- Scaling formula: :math:`v_{scaled} = (V - V_{th}) / (V_{th} - V_{reset})`
- For hard reset mode, actual spike detection uses :math:`V \geq V_{th}`
"""
V = self.V.value if V is None else V
v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
return self.spk_fun(v_scaled)
def _ensure_nmda_port(self, port: Hashable, weight_np: np.ndarray, state_shape):
    """Return the channel index of NMDA ``port``, registering it if it is new.

    Parameters
    ----------
    port : Hashable
        Normalized NMDA port identifier (output of ``_normalize_nmda_port``).
    weight_np : np.ndarray
        Port weight as a plain number in nS, already broadcast to ``state_shape``
        by the caller in ``_parse_spike_events``.
    state_shape : tuple of int
        Shape of the per-neuron state, without the port axis.

    Returns
    -------
    int
        Index of the port along the last axis of ``x_NMDA``,
        ``s_NMDA_components`` and ``nmda_weights``.

    Raises
    ------
    ValueError
        If an existing port is used with a different weight, or a new port is
        requested after ``_updates_started`` has been set.
    """
    dftype = brainstate.environ.dftype()
    idx = self._nmda_port_index.get(port)
    if idx is not None:
        current_weight = np.asarray(self.nmda_weights.value[..., idx], dtype=dftype)
        if np.any(current_weight != weight_np):
            raise ValueError('iaf_bw_2001_exact requires constant weights per NMDA port.')
        return idx
    if self._updates_started:
        raise ValueError('NMDA ports can only be added before the first call to update().')
    idx = self._nmda_num_ports()
    zero_channel = np.zeros(state_shape + (1,), dtype=dftype)
    new_x = np.concatenate(
        [np.asarray(self.x_NMDA.value, dtype=dftype), zero_channel], axis=-1)
    new_s = np.concatenate(
        [np.asarray(self.s_NMDA_components.value, dtype=dftype), zero_channel], axis=-1)
    new_w = np.concatenate(
        [np.asarray(self.nmda_weights.value, dtype=dftype), np.expand_dims(weight_np, axis=-1)],
        axis=-1)
    # Commit only after all three arrays were built: if any concatenation
    # fails (shape mismatch), neither the registry nor the state arrays are
    # left half-updated.
    self.x_NMDA.value = new_x
    self.s_NMDA_components.value = new_s
    self.nmda_weights.value = new_w
    self._nmda_port_index[port] = idx
    return idx
def _parse_spike_events(self, spike_events: Iterable, state_shape):
    """Accumulate one step's explicit spike events into per-receptor jumps.

    Parameters
    ----------
    spike_events : iterable or None
        Events as 2/3/4-tuples or dicts (see the class docstring).
    state_shape : tuple of int
        Per-neuron state shape that weights and multiplicities are broadcast to.

    Returns
    -------
    tuple of np.ndarray
        ``(ds_ampa, ds_gaba, nmda_mult)``: AMPA and GABA increments (numbers in
        nS, shape ``state_shape``) and the summed NMDA multiplicities per port
        (shape ``state_shape + (n_ports,)``).

    Raises
    ------
    ValueError
        For tuples of unsupported length, unknown receptors, or NMDA port
        problems reported by ``_ensure_nmda_port``.
    """
    float_dtype = brainstate.environ.dftype()
    ampa_jump = np.zeros(state_shape, dtype=float_dtype)
    gaba_jump = np.zeros(state_shape, dtype=float_dtype)
    nmda_counts = np.zeros(state_shape + (self._nmda_num_ports(),), dtype=float_dtype)
    if spike_events is None:
        return ampa_jump, gaba_jump, nmda_counts

    def unpack(event):
        # -> (receptor, weight, multiplicity, port) with the documented defaults.
        if isinstance(event, dict):
            return (
                event.get('receptor_type', event.get('receptor', 'AMPA')),
                event.get('weight', 0.0 * u.nS),
                event.get('multiplicity', 1.0),
                event.get('port', event.get('rport', event.get('synapse_id', None))),
            )
        n_fields = len(event)
        if n_fields == 2:
            rec, w = event
            return rec, w, 1.0, None
        if n_fields == 3:
            # The third field is a port for NMDA, a multiplicity otherwise.
            rec, w, third = event
            if self._normalize_spike_receptor(rec) == self.NMDA:
                return rec, w, 1.0, third
            return rec, w, third, None
        if n_fields == 4:
            rec, w, prt, mult = event
            return rec, w, mult, prt
        raise ValueError('Spike event tuples must have length 2, 3, or 4.')

    for event in spike_events:
        receptor, weight, multiplicity, port = unpack(event)
        receptor_id = self._normalize_spike_receptor(receptor)
        w = np.broadcast_to(self._value_to_float(weight, u.nS), state_shape)
        m = np.broadcast_to(self._value_to_float(multiplicity, None), state_shape)
        if receptor_id == self.AMPA:
            ampa_jump = ampa_jump + w * m
        elif receptor_id == self.GABA:
            gaba_jump = gaba_jump + w * m
        else:
            # NMDA: the weight lives on the port; only multiplicity drives x_j.
            slot = self._ensure_nmda_port(self._normalize_nmda_port(port), w, state_shape)
            missing = slot + 1 - nmda_counts.shape[-1]
            if missing > 0:
                filler = np.zeros(state_shape + (missing,), dtype=float_dtype)
                nmda_counts = np.concatenate([nmda_counts, filler], axis=-1)
            nmda_counts[..., slot] = nmda_counts[..., slot] + m
    return ampa_jump, gaba_jump, nmda_counts
def _parse_registered_spike_inputs(self, state_shape):
    """Drain ``self.delta_inputs`` into AMPA and GABA increments.

    Callable entries are evaluated and kept registered; non-callable entries
    are consumed (popped). A key of the form ``'LABEL // ...'`` routes its
    value to receptor ``LABEL``; keys without ``' // '`` go to AMPA.

    Parameters
    ----------
    state_shape : tuple of int
        Per-neuron state shape that each value is broadcast to.

    Returns
    -------
    tuple of np.ndarray
        ``(ds_ampa, ds_gaba)`` as numbers in nS, each of shape ``state_shape``.

    Raises
    ------
    ValueError
        If an entry targets NMDA (which requires explicit port information
        via ``spike_events``) or carries an unknown receptor label.
    """
    float_dtype = brainstate.environ.dftype()
    jumps = {
        self.AMPA: np.zeros(state_shape, dtype=float_dtype),
        self.GABA: np.zeros(state_shape, dtype=float_dtype),
    }
    if self.delta_inputs is None:
        return jumps[self.AMPA], jumps[self.GABA]
    for key in tuple(self.delta_inputs.keys()):
        entry = self.delta_inputs[key]
        if callable(entry):
            entry = entry()
        else:
            self.delta_inputs.pop(key)
        receptor = self.AMPA
        if ' // ' in key:
            receptor = self._normalize_spike_receptor(key.split(' // ', 1)[0])
        if receptor == self.NMDA:
            raise ValueError('Use spike_events with NMDA port specification for iaf_bw_2001_exact.')
        contribution = np.broadcast_to(self._value_to_float(entry, u.nS), state_shape)
        jumps[receptor] = jumps[receptor] + contribution
    return jumps[self.AMPA], jumps[self.GABA]
@staticmethod
def _value_to_float(x, unit=None):
    r"""Convert a possibly unit-carrying value to a NumPy array of ``dftype``.

    Parameters
    ----------
    x : ArrayLike
        Value to convert. A quantity with physical dimensions is expressed in
        ``unit``. A bare number/array, or a unitless quantity, is taken to
        already be expressed in ``unit``.
    unit : saiunit.Unit, optional
        Target unit. If None, ``x`` is converted without unit handling.

    Returns
    -------
    np.ndarray
        Array with dtype ``brainstate.environ.dftype()``.

    Raises
    ------
    Exception
        saiunit's unit-mismatch error (presumably ``u.UnitMismatchError``) when
        ``x`` has a dimension incompatible with ``unit``, e.g. a weight given
        in mV where nS is expected. Such input is rejected rather than being
        silently reinterpreted as a raw number.
    """
    dftype = brainstate.environ.dftype()
    if unit is not None and isinstance(x, u.Quantity) and not x.is_unitless:
        # to_decimal applies the unit scale and raises on dimension mismatch.
        return np.asarray(x.to_decimal(unit), dtype=dftype)
    return np.asarray(u.math.asarray(x), dtype=dftype)
@staticmethod
def _broadcast_to_state(x_np: np.ndarray, shape):
r"""Broadcast array to target state shape.
Parameters
----------
x_np : np.ndarray
Input array.
shape : tuple of int
Target shape.
Returns
-------
np.ndarray
Broadcasted view of input array with target shape.
"""
return np.broadcast_to(x_np, shape)
def _vector_field(self, state, extra):
    """Time derivatives of the continuous state for all neurons simultaneously.

    Nothing here clamps ``V`` during refractoriness or resets it on threshold
    crossing. Those operations are applied after integration by
    :meth:`update`, so ``V`` evolves freely inside the solver.

    Parameters
    ----------
    state : DotDict
        ``V``, ``s_AMPA``, ``s_GABA``, ``x_NMDA`` and ``s_NMDA_components``.
        The two NMDA entries carry a trailing per-port axis.
    extra : DotDict
        Auxiliary inputs; this function reads ``i_stim`` and ``nmda_weights``.

    Returns
    -------
    DotDict
        Same keys as ``state``, holding the corresponding time derivatives.
    """
    v = state.V
    drive_ex = v - self.E_ex  # driving force shared by AMPA and NMDA

    # Ohmic fast synapses.
    i_ampa = state.s_AMPA * drive_ex
    i_gaba = state.s_GABA * (v - self.E_in)

    # NMDA: port weights (plain numbers in nS) times gating, summed over ports,
    # scaled by the Mg2+ block 1 + [Mg2+] * exp(-0.062 * V/mV) / 3.57.
    weighted_gating = u.math.sum(extra.nmda_weights * state.s_NMDA_components, axis=-1)
    mg_block = 1.0 + (self.conc_Mg2 / u.mM) * u.math.exp(-0.062 * (v / u.mV)) / 3.57
    i_nmda = drive_ex / mg_block * weighted_gating * u.nS

    leak = -self.g_L * (v - self.E_L)
    dV = (leak - (i_ampa + i_gaba + i_nmda) + extra.i_stim) / self.C_m

    # Per-port NMDA kinetics; per-neuron constants gain a trailing axis so
    # they broadcast across ports.
    x = state.x_NMDA
    s = state.s_NMDA_components
    tau_rise = u.math.expand_dims(self.tau_rise_NMDA, axis=-1)
    tau_decay = u.math.expand_dims(self.tau_decay_NMDA, axis=-1)
    alpha = u.math.expand_dims(self.alpha, axis=-1)

    return DotDict(
        V=dV,
        s_AMPA=-state.s_AMPA / self.tau_AMPA,
        s_GABA=-state.s_GABA / self.tau_GABA,
        x_NMDA=-x / tau_rise,
        s_NMDA_components=-s / tau_decay + alpha * x * (1.0 - s),
    )
def _event_fn(self, state, extra, accept):
    """Track numerical instability only; no in-ODE spike/reset logic.

    Spike detection and refractory clamping are handled post-integration
    in :meth:`update` to match NEST's semantics (currents recorded from
    freely-evolved, pre-reset V).

    Parameters
    ----------
    state : DotDict
        Keys: V, s_AMPA, s_GABA, x_NMDA, s_NMDA_components.
    extra : DotDict
        Keys: unstable, i_stim, nmda_weights.
    accept : array, bool
        Mask of neurons whose RK substep was accepted.

    Returns
    -------
    (state, new_extra)
        ``state`` is returned unchanged. ``new_extra`` is a copy of ``extra``
        whose ``unstable`` flag is OR-ed with "some accepted neuron has
        V < -1000 mV", so once set it stays set for the rest of the call.
    """
    # Convert V to millivolts explicitly instead of reading its raw mantissa:
    # ``u.get_mantissa`` reports the value in whatever unit the Quantity
    # currently carries, so the -1e3 (mV) threshold would silently change
    # meaning if the integrator's arithmetic left V in a different scale.
    v_mV = state.V / u.mV
    diverged = jnp.any(accept & (v_mV < -1e3))
    return state, DotDict({**extra, 'unstable': extra.unstable | diverged})
def update(self, x=0. * u.pA, spike_events=None):
    r"""Advance neuron state by one simulation time step.

    Integrates the ODE system over one step with ``self.integrator``, applies
    spike jumps to the synaptic variables, checks the threshold, resets
    spiking neurons, and updates the refractory state. External current is
    buffered with a one-step delay (NEST compatibility).

    Parameters
    ----------
    x : ArrayLike, optional
        External input current (pA). Default: 0 pA.
        Shape must match ``self.varshape`` or be broadcastable to it.
        Summed with registered ``current_inputs`` to form total stimulus.
    spike_events : iterable, optional
        Collection of synaptic spike events. Default: None (no spikes).
        Each event can be a tuple or dict specifying receptor, weight,
        multiplicity, and port.

        **Tuple formats:**

        - ``(receptor, weight)``
        - ``(receptor, weight, third)`` where ``third`` is multiplicity for
          AMPA/GABA, port for NMDA
        - ``(receptor, weight, port, multiplicity)`` for full NMDA specification

        **Dict format:**

        - ``receptor_type`` or ``receptor``: int (1/2/3) or str ('AMPA'/'GABA'/'NMDA')
        - ``weight``: ArrayLike (nS), synaptic weight
        - ``multiplicity``: float, optional (default 1.0)
        - ``port`` / ``rport`` / ``synapse_id``: Hashable, optional (required for NMDA)

    Returns
    -------
    ArrayLike
        Spike indicator for this step, shape ``self.varshape``: 1.0 where a
        non-refractory neuron's post-integration voltage reached ``V_th``,
        0.0 elsewhere, cast to the environment float dtype. It is a hard
        boolean threshold (``self.get_spike()`` is not used), so it carries
        no gradient with respect to the membrane potential.

    Raises
    ------
    ValueError
        If attempting to add new NMDA ports after the first :meth:`update` call.
    ValueError
        If an NMDA port weight changes after initial registration.
    ValueError
        If a spike event format is invalid.

    Notes
    -----
    **Update sequence (NEST ordering):**

    1. **Integration**: integrate V_m, s_AMPA, s_GABA, x_NMDA and the NMDA
       ``s`` components on (t, t+dt]. V_m evolves freely (no in-ODE reset or
       refractory clamp). The stimulus is the one buffered on the previous step.
    2. **Current recording**: I_AMPA, I_GABA and I_NMDA are computed from the
       freely-evolved, pre-reset V_m.
    3. **Spike jumps**: add the parsed AMPA/GABA weight deltas to s_AMPA /
       s_GABA and the parsed NMDA multiplicities to x_NMDA.
    4. **Refractory update**: refractory neurons are clamped to V_reset and
       their countdown is decremented.
    5. **Threshold check**: non-refractory neurons with V_m >= V_th spike,
       are reset to V_reset, and (if ``ref_count > 0``) enter refractoriness.
    6. **Buffer stimulus**: store this step's input in ``I_stim`` for the
       next step (one-step delay).

    **NMDA port constraints:**

    - New ports can only be added before the first :meth:`update` call
    - Port weights are fixed at first registration and cannot change
    - Attempting to violate these constraints raises ``ValueError``

    **Integration details:**

    - Integration is delegated to ``self.integrator`` (configured elsewhere).
      NOTE(review): earlier notes cited ``gsl_error_tol``, a 1e-8 ms minimum
      step and 10,000 maximum iterations; confirm against the integrator setup.
    - The adaptive step size persists across time steps in the
      ``integration_step`` state.
    - If instability is flagged during integration, the step fails via
      ``brainstate.transform.jit_error_if``.

    **Refractory behavior:**

    - During the refractory period, V_m is clamped to V_reset
    - The refractory countdown decrements each time step
    - The threshold check is bypassed while refractory
    """
    t = brainstate.environ.get('t')
    dt = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()
    state_shape = self.V.value.shape

    # Snapshot the state at the start of the step, in its natural units.
    V = self.V.value  # mV
    s_AMPA = self.s_AMPA.value  # nS
    s_GABA = self.s_GABA.value  # nS
    x_NMDA = self.x_NMDA.value  # dimensionless; trailing axis indexes NMDA ports when present
    s_NMDA_components = self.s_NMDA_components.value  # dimensionless
    r = self.refractory_step_count.value  # int
    i_stim = self.I_stim.value  # pA, buffered on the previous step
    h = self.integration_step.value  # ms

    # Current input for the next step (one-step delay).
    new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

    # Parse spike events (AMPA/GABA weight deltas and NMDA multiplicities).
    ds_ampa_ev, ds_gaba_ev, dx_nmda_ev = self._parse_spike_events(spike_events, state_shape)
    ds_ampa_reg, ds_gaba_reg = self._parse_registered_spike_inputs(state_shape)
    ds_ampa = ds_ampa_ev + ds_ampa_reg
    ds_gaba = ds_gaba_ev + ds_gaba_reg

    # Parsing may register new NMDA ports, so the port count and the NMDA
    # weights are read from the states only now. ``n_nmda_pre`` is the port
    # count of the local snapshot taken before parsing.
    n_nmda_pre = int(x_NMDA.shape[-1]) if len(x_NMDA.shape) > len(state_shape) else 0
    n_nmda = int(self.x_NMDA.value.shape[-1]) if len(self.x_NMDA.value.shape) > len(state_shape) else 0
    nmda_weights_val = self.nmda_weights.value  # nS (dimensionless float)

    # Newly registered ports start from zero rise/decay state.
    if n_nmda > n_nmda_pre:
        n_new = n_nmda - n_nmda_pre
        x_NMDA = np.concatenate(
            [np.asarray(x_NMDA, dtype=dftype), np.zeros(state_shape + (n_new,), dtype=dftype)], axis=-1
        )
        s_NMDA_components = np.concatenate(
            [np.asarray(s_NMDA_components, dtype=dftype), np.zeros(state_shape + (n_new,), dtype=dftype)], axis=-1
        )
    # Align the parsed NMDA increments with the current port count.
    if n_nmda > 0 and dx_nmda_ev.shape[-1] != n_nmda:
        if dx_nmda_ev.shape[-1] < n_nmda:
            pad = np.zeros(state_shape + (n_nmda - dx_nmda_ev.shape[-1],), dtype=dftype)
            dx_nmda_ev = np.concatenate([dx_nmda_ev, pad], axis=-1)
        else:
            dx_nmda_ev = dx_nmda_ev[..., :n_nmda]

    # Adaptive integration via the generic integrator. V evolves freely
    # (no in-ODE refractory clamp or spike reset) to match NEST's GSL
    # integration semantics.
    ode_state = DotDict(
        V=V,
        s_AMPA=s_AMPA,
        s_GABA=s_GABA,
        x_NMDA=x_NMDA,
        s_NMDA_components=s_NMDA_components,
    )
    ode_extra = DotDict(
        unstable=jnp.array(False),
        i_stim=i_stim,
        nmda_weights=nmda_weights_val,
    )
    ode_state, h, ode_extra = self.integrator(state=ode_state, h=h, extra=ode_extra)
    V = ode_state.V  # freely-evolved post-ODE V (may exceed V_th)
    s_AMPA = ode_state.s_AMPA
    s_GABA = ode_state.s_GABA
    x_NMDA = ode_state.x_NMDA
    s_NMDA_components = ode_state.s_NMDA_components
    unstable = ode_extra.unstable

    # Post-loop stability check.
    brainstate.transform.jit_error_if(
        jnp.any(unstable), 'Numerical instability in iaf_bw_2001_exact dynamics.'
    )

    # Synaptic currents for recording, from the freely-evolved post-ODE V
    # (before any spike reset or refractory clamp), matching NEST's recording
    # semantics where currents are snapshotted pre-reset.
    if n_nmda > 0:
        s_nmda_sum = u.math.sum(nmda_weights_val * s_NMDA_components, axis=-1)
    else:
        s_nmda_sum = u.math.zeros(self.varshape, dtype=dftype)
    v_for_current = V  # pre-reset, freely-evolved V
    i_ampa = s_AMPA * (v_for_current - self.E_ex)
    i_gaba = s_GABA * (v_for_current - self.E_in)
    v_mV = v_for_current / u.mV
    conc_Mg2_mM = self.conc_Mg2 / u.mM
    denom = 1.0 + conc_Mg2_mM * u.math.exp(-0.062 * v_mV) / 3.57
    i_nmda = (v_for_current - self.E_ex) / denom * s_nmda_sum * u.nS

    # Apply synaptic spike inputs (after integration).
    s_AMPA = s_AMPA + ds_ampa * u.nS
    s_GABA = s_GABA + ds_gaba * u.nS
    if n_nmda > 0:
        x_NMDA = x_NMDA + dx_nmda_ev

    # Refractory neurons: clamp V to V_reset and decrement the counter.
    is_refractory = r > 0
    V = u.math.where(is_refractory, self.V_reset, V)
    r = u.math.where(is_refractory, r - 1, r)
    # Non-refractory neurons: threshold check, reset, enter refractoriness.
    spike_mask = (~is_refractory) & (V >= self.V_th)
    V = u.math.where(spike_mask, self.V_reset, V)
    r = u.math.where(spike_mask & (self.ref_count > 0), self.ref_count, r)

    # Write back state.
    self.V.value = V
    self.s_AMPA.value = s_AMPA
    self.s_GABA.value = s_GABA
    self.s_NMDA.value = s_nmda_sum * u.nS
    self.x_NMDA.value = x_NMDA
    self.s_NMDA_components.value = s_NMDA_components
    self.I_AMPA.value = i_ampa
    self.I_GABA.value = i_gaba
    self.I_NMDA.value = i_nmda
    self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
    self.integration_step.value = h
    self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA
    last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    if self.ref_var:
        self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)
    self._updates_started = True
    return u.math.asarray(spike_mask, dtype=dftype)