iaf_bw_2001_exact

class brainpy.state.iaf_bw_2001_exact(in_size, E_L=Quantity(-70., "mV"), E_ex=Quantity(0., "mV"), E_in=Quantity(-70., "mV"), V_th=Quantity(-55., "mV"), V_reset=Quantity(-60., "mV"), C_m=Quantity(500., "pF"), g_L=Quantity(25., "nS"), t_ref=Quantity(2., "ms"), tau_AMPA=Quantity(2., "ms"), tau_GABA=Quantity(5., "ms"), tau_rise_NMDA=Quantity(2., "ms"), tau_decay_NMDA=Quantity(100., "ms"), alpha=Quantity(0.5, "kHz"), conc_Mg2=Quantity(1., "mM"), gsl_error_tol=0.001, V_initializer=Constant(value=-70. mV), s_AMPA_initializer=Constant(value=0. nS), s_GABA_initializer=Constant(value=0. nS), spk_fun=ReluGrad(alpha=0.3, width=1.0), spk_reset='hard', ref_var=False, name=None)

NEST-compatible conductance-based LIF neuron with exact per-synapse NMDA dynamics.

This model implements the Brunel-Wang (2001) neuron with exact NMDA kinetics: it keeps separate rise and decay variables for each NMDA synapse instead of using the presynaptic-jump approximation. Each NMDA connection is assigned a unique port with a fixed weight, and ports cannot be added after the first simulation step, which mirrors NEST’s constraint on NMDA connections.

Parameters:
  • in_size (int, tuple of int, Sequence of int) – Population shape. Defines the number and arrangement of neurons.

  • E_L (ArrayLike, optional) – Leak reversal potential. Default: -70 mV. Determines the resting potential in the absence of input.

  • E_ex (ArrayLike, optional) – Excitatory reversal potential. Default: 0 mV. Reversal potential for AMPA and NMDA receptors.

  • E_in (ArrayLike, optional) – Inhibitory reversal potential. Default: -70 mV. Reversal potential for GABA receptors.

  • V_th (ArrayLike, optional) – Spike threshold potential. Default: -55 mV. Membrane potential at which a spike is emitted.

  • V_reset (ArrayLike, optional) – Reset potential. Default: -60 mV. Membrane potential immediately after spike emission. Must be < V_th.

  • C_m (ArrayLike, optional) – Membrane capacitance. Default: 500 pF. Must be strictly positive.

  • g_L (ArrayLike, optional) – Leak conductance. Default: 25 nS. Conductance through passive leak channels.

  • t_ref (ArrayLike, optional) – Absolute refractory period duration. Default: 2 ms. Time after spike during which membrane is clamped to V_reset.

  • tau_AMPA (ArrayLike, optional) – AMPA decay time constant. Default: 2 ms. Governs exponential decay of AMPA conductance. Must be > 0.

  • tau_GABA (ArrayLike, optional) – GABA decay time constant. Default: 5 ms. Governs exponential decay of GABA conductance. Must be > 0.

  • tau_rise_NMDA (ArrayLike, optional) – NMDA rise time constant. Default: 2 ms. Time constant for NMDA activation variable x_j. Must be > 0.

  • tau_decay_NMDA (ArrayLike, optional) – NMDA decay time constant. Default: 100 ms. Time constant for NMDA gating variable s_j. Must be > 0.

  • alpha (ArrayLike, optional) – NMDA rise coupling strength. Default: 0.5 / ms. Scales the coupling between rise (x_j) and gating (s_j) variables. Must be > 0.

  • conc_Mg2 (ArrayLike, optional) – Extracellular magnesium concentration. Default: 1 mM. Controls voltage-dependent NMDA blockade. Must be > 0.

  • gsl_error_tol (ArrayLike, optional) – RKF45 local error tolerance. Default: 1e-3. Controls adaptive step size in Runge-Kutta-Fehlberg integration. Must be > 0. Smaller values improve accuracy at the cost of more iterations.

  • V_initializer (Callable, optional) – Membrane potential initializer. Default: Constant(-70 mV). Function that generates initial V_m values.

  • s_AMPA_initializer (Callable, optional) – AMPA conductance state initializer. Default: Constant(0 nS). Function that generates initial s_AMPA values.

  • s_GABA_initializer (Callable, optional) – GABA conductance state initializer. Default: Constant(0 nS). Function that generates initial s_GABA values.

  • spk_fun (Callable, optional) – Surrogate gradient function for spike generation. Default: ReluGrad(). Maps scaled voltage to differentiable spike output.

  • spk_reset (str, optional) – Spike reset mode. Default: ‘hard’.

    • ‘hard’: Stop gradient through reset (matches NEST)

    • ‘soft’: Gradient flows through reset (V -= V_th)

  • ref_var (bool, optional) – If True, expose boolean refractory state variable. Default: False. Adds a refractory attribute for monitoring refractory state.

  • name (str, optional) – Module name. Default: None (auto-generated).

Raises:
  • ValueError – If V_reset >= V_th, or any of C_m, tau_*, alpha, conc_Mg2, gsl_error_tol <= 0.

  • ValueError – If attempting to change NMDA port weights after first registration.

  • ValueError – If attempting to add new NMDA ports after first update() call.

  • ValueError – If NMDA port is not hashable.

  • ValueError – If spike event format is invalid.

See also

iaf_bw_2001 – Approximate version using presynaptic-jump NMDA dynamics

iaf_cond_exp – Simpler conductance-based LIF without NMDA

aeif_cond_alpha – Adaptive exponential IF with alpha-shaped conductances

Parameter Mapping

NEST Parameter    brainpy.state     Notes
--------------    --------------    -------------------------------
E_L               E_L               Leak reversal potential (mV)
E_ex              E_ex              Excitatory reversal (mV)
E_in              E_in              Inhibitory reversal (mV)
V_th              V_th              Spike threshold (mV)
V_reset           V_reset           Reset potential (mV)
C_m               C_m               Membrane capacitance (pF)
g_L               g_L               Leak conductance (nS)
t_ref             t_ref             Refractory period (ms)
tau_AMPA          tau_AMPA          AMPA decay time (ms)
tau_GABA          tau_GABA          GABA decay time (ms)
tau_rise_NMDA     tau_rise_NMDA     NMDA rise time (ms)
tau_decay_NMDA    tau_decay_NMDA    NMDA decay time (ms)
alpha             alpha             NMDA coupling (1/ms)
conc_Mg2          conc_Mg2          Mg2+ concentration (mM)
gsl_error_tol     gsl_error_tol     RKF45 tolerance (dimensionless)

Mathematical Model

1. Membrane Dynamics

The subthreshold membrane potential evolves according to:

\[C_m \frac{dV_m}{dt} = -g_L(V_m - E_L) - I_{syn} + I_{stim}\]

where \(I_{syn} = I_{AMPA} + I_{GABA} + I_{NMDA}\) is the total synaptic current.
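As a quick worked check of the membrane equation, the sketch below evaluates its right-hand side with plain floats and derives two quantities implied by the default parameters: the passive time constant \(C_m / g_L\) and the constant current that holds \(V_m\) exactly at threshold. The function name `dVdt` and the use of bare floats in pF, nS, mV, pA, and ms are assumptions made for illustration; the model itself works with unit-carrying quantities.

```python
# Plain-float sketch of the subthreshold membrane equation.
# Units: pF, nS, mV, pA, ms (default parameter values from this page).
C_M = 500.0   # membrane capacitance, pF
G_L = 25.0    # leak conductance, nS
E_L = -70.0   # leak reversal potential, mV
V_TH = -55.0  # spike threshold, mV


def dVdt(v_m, i_syn=0.0, i_stim=0.0):
    """Return dV_m/dt in mV/ms for C_m dV/dt = -g_L (V_m - E_L) - I_syn + I_stim."""
    # nS * mV = pA and pA / pF = mV/ms, so no conversion factors are needed.
    return (-G_L * (v_m - E_L) - i_syn + i_stim) / C_M


tau_m = C_M / G_L                 # pF / nS = ms, gives 20.0 ms
i_threshold = G_L * (V_TH - E_L)  # nS * mV = pA, gives 375.0 pA
```

With no synaptic input, a constant stimulus above `i_threshold` is what eventually drives the neuron across threshold.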

2. Synaptic Currents

AMPA and GABA currents are ohmic:

\[\begin{split}I_{AMPA} &= (V_m - E_{ex}) s_{AMPA} \\ I_{GABA} &= (V_m - E_{in}) s_{GABA}\end{split}\]

NMDA current includes voltage-dependent Mg2+ blockade:

\[I_{NMDA} = \frac{(V_m - E_{ex})}{1 + [Mg^{2+}]\exp(-0.062V_m)/3.57} \sum_j w_j s_j\]

where \(j\) indexes individual NMDA synapses, \(w_j\) is the fixed weight for port \(j\), and \(s_j\) is the gating variable for that synapse.
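The NMDA current formula can be evaluated directly. The following is a minimal sketch with plain floats (mV, nS, mM, pA); the helper names `mg_block` and `i_nmda` are hypothetical and not part of brainpy.state.

```python
import math


def mg_block(v_m, conc_mg2=1.0):
    """Unblocked fraction of the NMDA conductance at membrane potential v_m (mV)."""
    return 1.0 / (1.0 + conc_mg2 * math.exp(-0.062 * v_m) / 3.57)


def i_nmda(v_m, weights, gates, e_ex=0.0, conc_mg2=1.0):
    """I_NMDA in pA for per-port weights w_j (nS) and gating variables s_j."""
    g_total = sum(w * s for w, s in zip(weights, gates))  # sum_j w_j s_j, in nS
    return (v_m - e_ex) * mg_block(v_m, conc_mg2) * g_total


# The block is strong near rest and weakens as the membrane depolarizes.
block_at_rest = mg_block(-70.0)
block_at_threshold = mg_block(-55.0)
```

Because \(I_{NMDA}\) enters the membrane equation with a minus sign, the negative values this returns below \(E_{ex}\) are depolarizing.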

3. Synaptic Gating Variables

AMPA and GABA conductances decay exponentially:

\[\begin{split}\frac{ds_{AMPA}}{dt} &= -\frac{s_{AMPA}}{\tau_{AMPA}} \\ \frac{ds_{GABA}}{dt} &= -\frac{s_{GABA}}{\tau_{GABA}}\end{split}\]
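Between spikes these two equations have the closed-form solution \(s(t + \Delta t) = s(t)\,e^{-\Delta t/\tau}\). The sketch below applies that solution and then adds a spike jump of weight times multiplicity, which is a convenient reference when checking recorded traces. It is an illustration only: the model integrates these variables with RKF45 rather than the closed form, and the name `decay_then_jump` is hypothetical.

```python
import math


def decay_then_jump(s, dt, tau, weight=0.0, multiplicity=0.0):
    """Decay a conductance (nS) for dt ms, then add an incoming spike jump."""
    s = s * math.exp(-dt / tau)      # closed-form solution of ds/dt = -s / tau
    return s + weight * multiplicity


# AMPA conductance of 100 nS after one time constant (tau_AMPA = 2 ms).
s_after_tau = decay_then_jump(100.0, dt=2.0, tau=2.0)
```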

Each NMDA synapse \(j\) has dual-timescale kinetics:

\[\begin{split}\frac{dx_j}{dt} &= -\frac{x_j}{\tau_{NMDA,rise}} \\ \frac{ds_j}{dt} &= -\frac{s_j}{\tau_{NMDA,decay}} + \alpha x_j (1-s_j)\end{split}\]

where \(x_j\) is the fast rise variable and \(s_j\) is the gating variable, which decays slowly and saturates at 1 through the \((1-s_j)\) factor.
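To see how the two NMDA variables interact, the sketch below follows a single synapse after one presynaptic spike using forward Euler with a small fixed step. This is an illustration under stated assumptions: the model itself uses adaptive RKF45, the function name `nmda_euler_step` is hypothetical, and the spike is represented by setting \(x_j\) to 1 (multiplicity 1).

```python
def nmda_euler_step(x, s, dt, tau_rise=2.0, tau_decay=100.0, alpha=0.5):
    """Advance (x_j, s_j) by dt ms with forward Euler (illustration only)."""
    dx = -x / tau_rise
    ds = -s / tau_decay + alpha * x * (1.0 - s)
    return x + dt * dx, s + dt * ds


x, s = 1.0, 0.0       # state right after a presynaptic spike
dt = 0.01             # ms, small enough for forward Euler with these constants
trace = []
for _ in range(2000):  # simulate 20 ms
    x, s = nmda_euler_step(x, s, dt)
    trace.append(s)

s_peak = max(trace)   # s_j rises while x_j is large, then decays slowly
```

The gating variable stays below 1 because the drive term vanishes as \(s_j\) approaches saturation.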

4. Spike Generation and Reset

When \(V_m \geq V_{th}\) and the neuron is not refractory:

  • Emit a spike

  • Set \(V_m \leftarrow V_{reset}\)

  • Enter refractory state for \(t_{ref}\) ms

During refractoriness, \(V_m\) is clamped to \(V_{reset}\).
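The threshold, reset, and refractory rules can be summarized for a single neuron as follows. This is a sketch with plain floats, not the library's implementation: the function name is hypothetical, and converting \(t_{ref}\) to a step count with `round(t_ref / dt)` is an assumption.

```python
def threshold_and_reset(v_m, refr_steps_left, v_th=-55.0, v_reset=-60.0,
                        t_ref=2.0, dt=0.1):
    """Return (v_m, refr_steps_left, spiked) after the threshold stage of one step."""
    if refr_steps_left > 0:
        # Refractory: clamp to V_reset and count down; no threshold check.
        return v_reset, refr_steps_left - 1, False
    if v_m >= v_th:
        # Spike: reset the voltage and start the refractory countdown.
        return v_reset, int(round(t_ref / dt)), True
    return v_m, 0, False


spiking = threshold_and_reset(-54.0, 0)      # (-60.0, 20, True)
refractory = threshold_and_reset(-50.0, 3)   # (-60.0, 2, False)
```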

5. Numerical Integration

The continuous dynamics are integrated using adaptive Runge-Kutta-Fehlberg (RKF45) with:

  • 4th and 5th order embedded methods for error estimation

  • Persistent step size \(h\) that adapts to maintain local error < gsl_error_tol

  • Minimum step size \(h_{min} = 10^{-8}\) ms

  • Maximum iterations per simulation step: 10,000
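The list above can be illustrated with a scalar Runge-Kutta-Fehlberg integrator. The sketch below uses the standard Fehlberg 4(5) coefficients together with the tolerance, minimum step, and iteration cap named above. Several details are assumptions rather than facts about the model: the function names, the step-size controller (safety factor 0.9, growth limited to the range 0.2 to 2), keeping the 5th-order value on acceptance, and using an absolute error measure.

```python
import math


def rkf45_trial(f, t, y, h):
    """One Fehlberg trial step for a scalar ODE; returns (4th-order, 5th-order)."""
    k1 = f(t, y)
    k2 = f(t + h / 4, y + h * k1 / 4)
    k3 = f(t + 3 * h / 8, y + h * (3 * k1 + 9 * k2) / 32)
    k4 = f(t + 12 * h / 13,
           y + h * (1932 * k1 - 7200 * k2 + 7296 * k3) / 2197)
    k5 = f(t + h,
           y + h * (439 * k1 / 216 - 8 * k2 + 3680 * k3 / 513 - 845 * k4 / 4104))
    k6 = f(t + h / 2,
           y + h * (-8 * k1 / 27 + 2 * k2 - 3544 * k3 / 2565
                    + 1859 * k4 / 4104 - 11 * k5 / 40))
    y4 = y + h * (25 * k1 / 216 + 1408 * k3 / 2565 + 2197 * k4 / 4104 - k5 / 5)
    y5 = y + h * (16 * k1 / 135 + 6656 * k3 / 12825 + 28561 * k4 / 56430
                  - 9 * k5 / 50 + 2 * k6 / 55)
    return y4, y5


def integrate(f, y, t0, t1, h, tol=1e-3, h_min=1e-8, max_iters=10_000):
    """Integrate y' = f(t, y) over (t0, t1]; returns (y, h) so h can persist."""
    t = t0
    for _ in range(max_iters):
        remaining = t1 - t
        if remaining <= 0.0:
            return y, h
        last = h >= remaining             # clip the final step to end at t1
        h_try = remaining if last else h
        y4, y5 = rkf45_trial(f, t, y, h_try)
        err = abs(y5 - y4)                # embedded local error estimate
        accepted = err <= tol or h_try <= h_min
        # Assumed controller: safety factor 0.9, change limited to [0.2, 2.0].
        scale = 2.0 if err == 0.0 else 0.9 * (tol / err) ** 0.2
        h_new = max(h_min, h_try * min(2.0, max(0.2, scale)))
        if accepted:
            y = y5                        # assumption: keep the 5th-order value
            if last:
                return y, max(h, h_new)   # do not shrink h because of clipping
            t += h_try
        h = h_new
    raise RuntimeError("maximum number of RKF45 iterations reached")


# AMPA decay ds/dt = -s / tau_AMPA over one 0.1 ms simulation step.
tau_ampa = 2.0
s_end, h_next = integrate(lambda t, s: -s / tau_ampa, 1.0, 0.0, 0.1, h=0.1)
```

Returning the step size alongside the state is what lets the next simulation step start from the previously adapted value instead of re-tuning from scratch.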

NMDA Port Semantics

NEST assigns each NMDA connection a unique receptor port at connect time and prohibits adding new NMDA connections after the first simulation step. This implementation mirrors that behavior:

  • Each NMDA event requires a port identifier (any hashable value)

  • The first event for a new port registers that port with the provided weight

  • Subsequent events to the same port must use the same weight (enforced)

  • New ports can only be added before the first update() call

  • AMPA/GABA events do not use ports (weights accumulate directly)
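The port rules above amount to a small registry with two guards. The following sketch illustrates those rules in plain Python; the class name `NMDAPortRegistry` and its attributes are hypothetical and do not describe the library's internal data structures.

```python
class NMDAPortRegistry:
    """Maps hashable port identifiers to column indices with fixed weights."""

    def __init__(self):
        self._index = {}              # port identifier -> column index
        self._weights = []            # fixed weight for each registered port
        self.updates_started = False  # set to True once the first update is done

    def resolve(self, port, weight):
        """Return the column index for port, registering it if allowed."""
        try:
            hash(port)
        except TypeError:
            raise ValueError(f"NMDA port {port!r} is not hashable") from None
        if port in self._index:
            idx = self._index[port]
            if self._weights[idx] != weight:
                raise ValueError(f"weight of NMDA port {port!r} cannot change")
            return idx
        if self.updates_started:
            raise ValueError("cannot add NMDA ports after the first update")
        self._index[port] = len(self._weights)
        self._weights.append(weight)
        return self._index[port]


registry = NMDAPortRegistry()
idx_a = registry.resolve("port_A", 50.0)  # registers port_A as column 0
idx_b = registry.resolve("port_B", 75.0)  # registers port_B as column 1
```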

Spike Event Formats

The spike_events parameter accepts multiple formats:

Tuple formats:

  • (receptor, weight) — receptor in {1, 2, 3} or {‘AMPA’, ‘GABA’, ‘NMDA’}

  • (receptor, weight, third), where third is multiplicity for AMPA/GABA, port for NMDA

  • (receptor, weight, port, multiplicity) — full NMDA specification

Dict format:

  • Required keys: receptor_type or receptor (1/2/3 or ‘AMPA’/‘GABA’/‘NMDA’), weight

  • Optional keys: multiplicity (default 1.0), port/rport/synapse_id (for NMDA)
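One way to read the formats above is as a normalization into a single (receptor, weight, port, multiplicity) tuple. The sketch below shows that reading with plain Python values; `normalize_event` is a hypothetical helper, not a brainpy.state function, and it leaves the port as None when none is given rather than guessing how the model treats that case.

```python
RECEPTORS = {"AMPA": 1, "GABA": 2, "NMDA": 3}


def normalize_event(event):
    """Return (receptor_id, weight, port, multiplicity) for a tuple or dict event."""
    port, multiplicity = None, 1.0
    if isinstance(event, dict):
        receptor = event.get("receptor_type", event.get("receptor"))
        if receptor is None or "weight" not in event:
            raise ValueError("dict events need a receptor and a weight")
        weight = event["weight"]
        multiplicity = event.get("multiplicity", 1.0)
        for key in ("port", "rport", "synapse_id"):
            if key in event:
                port = event[key]
                break
    elif isinstance(event, tuple) and 2 <= len(event) <= 4:
        receptor, weight = event[0], event[1]
        if len(event) == 3:
            # Third entry is the port for NMDA, the multiplicity otherwise.
            if receptor in (3, "NMDA"):
                port = event[2]
            else:
                multiplicity = event[2]
        elif len(event) == 4:
            port, multiplicity = event[2], event[3]
    else:
        raise ValueError(f"invalid spike event: {event!r}")
    receptor_id = RECEPTORS.get(receptor, receptor)
    if receptor_id not in (1, 2, 3):
        raise ValueError(f"unknown receptor: {receptor!r}")
    return receptor_id, weight, port, multiplicity


ampa = normalize_event((1, 100.0))
nmda = normalize_event({"receptor": "NMDA", "weight": 50.0, "port": 0})
```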

Update Ordering (matches NEST)

Each update() call executes in this order:

  1. Integrate ODEs on \((t, t+dt]\) using RKF45 with persistent step size

  2. Apply spike jumps: add to s_AMPA, s_GABA, and x_j for each NMDA port

  3. Threshold check and reset: emit spikes, reset voltage, update refractory countdown

  4. Store external current: buffer I_stim for next step (one-step delay)
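The one-step delay in the last stage is easy to overlook, so here is a minimal sketch of the buffering alone. The class name `StimulusBuffer` is hypothetical; it only mimics how a current supplied in one call drives the integration of the following call.

```python
class StimulusBuffer:
    """Holds the external current for one step before it reaches the ODE."""

    def __init__(self):
        self.i_stim = 0.0  # pA; nothing is buffered before the first update

    def exchange(self, x):
        """Return the current to integrate with now, then buffer x for the next step."""
        driving = self.i_stim
        self.i_stim = x
        return driving


buf = StimulusBuffer()
seen = [buf.exchange(x) for x in (100.0, 200.0, 300.0)]
# seen == [0.0, 100.0, 200.0]: each input acts one step after it is supplied
```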

Recordable Variables

  • V_m — Membrane potential (mV)

  • s_AMPA — AMPA conductance state (nS)

  • s_GABA — GABA conductance state (nS)

  • s_NMDA — Weighted sum of NMDA gating variables (nS), \(\sum_j w_j s_j\)

  • I_AMPA — AMPA current (pA)

  • I_GABA — GABA current (pA)

  • I_NMDA — NMDA current (pA)

Additional State Variables

  • x_NMDA — NMDA rise variables for each port (shape: [*in_size, n_ports])

  • s_NMDA_components — NMDA gating variables for each port (shape: [*in_size, n_ports])

  • nmda_weights — Fixed weights for each NMDA port (shape: [*in_size, n_ports])

  • refractory_step_count — Remaining refractory steps (int32)

  • integration_step — Persistent RKF45 step size (ms)

  • I_stim — One-step delayed external current buffer (pA)

  • refractory — Boolean refractory indicator (only if ref_var=True)

Performance Considerations:

  • RKF45 integration is performed per-neuron in NumPy (not vectorized)

  • Computational cost scales linearly with the number of NMDA ports

  • A larger gsl_error_tol speeds up integration at the cost of accuracy

  • This model is significantly slower than iaf_bw_2001 due to per-synapse state

Comparison to iaf_bw_2001:

  • iaf_bw_2001 approximates all NMDA synapses with a single pair of state variables

  • iaf_bw_2001_exact tracks rise and decay for each NMDA connection separately

  • Use iaf_bw_2001_exact when NMDA synapse heterogeneity matters (e.g., detailed working memory models)

  • Use iaf_bw_2001 for large-scale simulations where approximation is acceptable

Examples

Basic usage with AMPA input:

>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=10)
>>> net.init_all_states()
>>> # Apply AMPA input spike; the call returns the spike output for this step
>>> spike = net(spike_events=[(1, 100*u.nS)])
>>> print(net.V.value)

NMDA connections with unique ports:

>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=5)
>>> net.init_all_states()
>>> # Register two NMDA ports with different weights
>>> events = [
...     (3, 50*u.nS, 'port_A', 1.0),  # NMDA port A, weight 50 nS
...     (3, 75*u.nS, 'port_B', 1.0),  # NMDA port B, weight 75 nS
... ]
>>> spike = net(spike_events=events)
>>> # State shape is (neurons, ports): 5 neurons x 2 NMDA ports
>>> print(net.s_NMDA_components.value.shape)
(5, 2)

Mixing AMPA, GABA, and NMDA:

>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=1, V_th=-50*u.mV)
>>> net.init_all_states()
>>> events = [
...     {'receptor': 'AMPA', 'weight': 200*u.nS, 'multiplicity': 2.0},
...     {'receptor': 'GABA', 'weight': 100*u.nS},
...     {'receptor': 'NMDA', 'weight': 50*u.nS, 'port': 0},
... ]
>>> # Deliver the events on the first call: NMDA ports cannot be added later
>>> for step in range(100):
...     spike = net(spike_events=events if step == 0 else None)
>>> print(net.last_spike_time.value)

Monitoring refractory state:

>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>> net = bp.iaf_bw_2001_exact(in_size=3, ref_var=True, t_ref=5*u.ms)
>>> net.init_all_states()
>>> net.V.value = net.V_th + 1*u.mV  # Force spike
>>> spike = net()
>>> print(net.refractory.value)
[True True True]

get_spike(V=None)

Generate differentiable spike output from membrane potential.

Scales voltage relative to threshold and applies the surrogate gradient function for gradient-based learning. The scaling is linear, so that V_reset maps to -1 and V_th maps to 0 (see the formula in the notes below).

Parameters:

V (ArrayLike, optional) – Membrane potential (mV). Default: None (uses current self.V.value). Shape must match self.varshape or be broadcastable to it.

Returns:

Differentiable spike output in [0, 1]. Shape matches input voltage. Values close to 1 indicate spiking; values close to 0 indicate quiescence. Exact output depends on self.spk_fun (e.g., ReLU or sigmoid).

Return type:

ArrayLike

Notes

  • Used internally during update() to compute spike output before reset

  • Scaling formula: \(v_{scaled} = (V - V_{th}) / (V_{th} - V_{reset})\)

  • For hard reset mode, actual spike detection uses \(V \geq V_{th}\)
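Evaluating the scaling formula from the notes above with the default V_th and V_reset shows where the reference voltages land before the surrogate function is applied. The function name `scaled_voltage` is hypothetical, and plain floats in mV stand in for unit-carrying quantities.

```python
def scaled_voltage(v, v_th=-55.0, v_reset=-60.0):
    """Return (V - V_th) / (V_th - V_reset) for voltages in mV."""
    return (v - v_th) / (v_th - v_reset)


at_reset = scaled_voltage(-60.0)      # -1.0
at_threshold = scaled_voltage(-55.0)  # 0.0
above = scaled_voltage(-50.0)         # 1.0
```

Zero therefore marks the threshold, and positive values indicate a voltage above it.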

init_state(**kwargs)

Initialize all state variables.

Creates and initializes membrane potential, synaptic conductances, currents, NMDA port arrays (initially empty), refractory state, and integration step size. NMDA port registry is cleared.

Parameters:

**kwargs – Unused compatibility parameters accepted by the base-state API.

Notes

  • NMDA port arrays (x_NMDA, s_NMDA_components, nmda_weights) start empty (shape: […, 0])

  • Ports are allocated dynamically when first NMDA spike arrives

  • Clears the internal _nmda_port_index registry

  • Resets _updates_started flag to False

property receptor_types

Mapping of receptor names to numeric identifiers.

Returns:

Dictionary mapping {‘AMPA’: 1, ‘GABA’: 2, ‘NMDA’: 3}.

Return type:

dict

property recordables

List of variables available for recording.

Returns:

[‘V_m’, ‘s_AMPA’, ‘s_GABA’, ‘s_NMDA’, ‘I_NMDA’, ‘I_AMPA’, ‘I_GABA’].

Return type:

list of str

reset_state(batch_size=None, **kwargs)

Reset all state variables to initial values.

Unlike init_state(), this preserves NMDA port structure (number of ports and their weights remain unchanged). Resets voltage, conductances, currents, NMDA gating variables, refractory state, and integration step size.

Parameters:
  • batch_size (int, optional) – Batch dimension size for state variables. Default: None (no batching). If provided, reshapes state variables with a leading batch dimension.

  • **kwargs – Additional keyword arguments (currently unused).

Notes

  • NMDA port count and weights are preserved (but x_NMDA and s_NMDA_components are zeroed)

  • Does NOT clear _nmda_port_index (port registry persists)

  • Does NOT reset _updates_started flag

update(x=Quantity(0., 'pA'), spike_events=None)

Advance neuron state by one simulation time step.

Performs RKF45 integration of ODEs, applies spike jumps to conductances, checks threshold, resets spiking neurons, and updates refractory state. External current is buffered with one-step delay (NEST compatibility).

Parameters:
  • x (ArrayLike, optional) – External input current (pA). Default: 0 pA. Shape must match self.varshape or be broadcastable to it. Summed with registered current_inputs to form total stimulus.

  • spike_events (iterable, optional) –

    Collection of synaptic spike events. Default: None (no spikes). Each event can be a tuple or dict specifying receptor, weight, multiplicity, and port.

    Tuple formats:

    • (receptor, weight)

    • (receptor, weight, third) where third is multiplicity for AMPA/GABA, port for NMDA

    • (receptor, weight, port, multiplicity) for full NMDA specification

    Dict format:

    • receptor_type or receptor: int (1/2/3) or str (‘AMPA’/‘GABA’/‘NMDA’)

    • weight: ArrayLike (nS), synaptic weight

    • multiplicity: float, optional (default 1.0)

    • port / rport / synapse_id: Hashable, optional (required for NMDA)

Returns:

Differentiable spike output for current time step. Shape: self.varshape. Computed from voltage before reset using self.get_spike().

Return type:

ArrayLike

Raises:
  • ValueError – If attempting to add new NMDA ports after first update() call.

  • ValueError – If NMDA port weight changes after initial registration.

  • ValueError – If spike event format is invalid.

Notes

Update sequence (matches NEST ordering):

  1. RKF45 integration: Integrate V_m, s_AMPA, s_GABA, x_NMDA, s_NMDA on (t, t+dt]

  2. Spike jumps: Add to s_AMPA, s_GABA (weight × multiplicity), x_NMDA (multiplicity only)

  3. Threshold check: If V_m >= V_th and not refractory, emit spike and reset

  4. Refractory update: Decrement refractory countdown or clamp V_m to V_reset

  5. Buffer stimulus: Store current input in I_stim for next step (one-step delay)

NMDA port constraints:

  • New ports can only be added before first update() call

  • Port weights are fixed at first registration and cannot change

  • Attempting to violate these constraints raises ValueError

Integration details:

  • Uses adaptive RKF45 with per-neuron step size (not vectorized)

  • Local error tolerance controlled by gsl_error_tol

  • Minimum step size: 1e-8 ms; maximum iterations: 10,000

  • Step size persists across time steps in integration_step state

Refractory behavior:

  • During refractory period, V_m is clamped to V_reset

  • Refractory countdown decrements each time step

  • Threshold check bypassed while refractory