iaf_bw_2001

class brainpy.state.iaf_bw_2001(in_size, E_L=Quantity(-70., "mV"), E_ex=Quantity(0., "mV"), E_in=Quantity(-70., "mV"), V_th=Quantity(-55., "mV"), V_reset=Quantity(-60., "mV"), C_m=Quantity(500., "pF"), g_L=Quantity(25., "nS"), t_ref=Quantity(2., "ms"), tau_AMPA=Quantity(2., "ms"), tau_GABA=Quantity(5., "ms"), tau_decay_NMDA=Quantity(100., "ms"), tau_rise_NMDA=Quantity(2., "ms"), alpha=Quantity(0.5, "kHz"), conc_Mg2=Quantity(1., "mM"), gsl_error_tol=0.001, V_initializer=Constant(value=-70. mV), s_AMPA_initializer=Constant(value=0. nS), s_GABA_initializer=Constant(value=0. nS), s_NMDA_initializer=Constant(value=0. nS), spk_fun=ReluGrad(alpha=0.3, width=1.0), spk_reset='hard', ref_var=False, name=None)

NEST-compatible iaf_bw_2001 neuron model.

Conductance-based leaky integrate-and-fire neuron with AMPA, GABA, and approximate NMDA synaptic dynamics from Brunel-Wang style cortical models.

This model reproduces the NEST iaf_bw_2001 neuron, including adaptive RKF45 integration of the subthreshold ODEs, receptor-routed AMPA/GABA/NMDA spike inputs, a one-step delay on external currents, NEST's refractory countdown and reset ordering, and the NMDA presynaptic jump approximation carried by spike-event offsets.

1. Mathematical Model

The continuous-time state vector is:

\[y = (V_m, s_{AMPA}, s_{GABA}, s_{NMDA}).\]

Membrane dynamics:

\[C_m \frac{dV_m}{dt} = -g_L(V_m - E_L) - I_{syn} + I_{stim},\]

where the total synaptic current is:

\[I_{syn} = I_{AMPA} + I_{GABA} + I_{NMDA}.\]

Synaptic currents:

AMPA and GABA currents use Ohmic conductance:

\[I_{AMPA} = (V_m - E_{ex}) s_{AMPA}, \quad I_{GABA} = (V_m - E_{in}) s_{GABA}.\]

NMDA current includes voltage-dependent Mg²⁺ block:

\[I_{NMDA} = \frac{(V_m - E_{ex}) s_{NMDA}} {1 + [Mg^{2+}]\exp(-0.062 V_m)/3.57}.\]

Synaptic kinetics:

Between incoming spikes, all three synaptic conductance states decay exponentially:

\[\frac{ds_{AMPA}}{dt} = -\frac{s_{AMPA}}{\tau_{AMPA}}, \quad \frac{ds_{GABA}}{dt} = -\frac{s_{GABA}}{\tau_{GABA}}, \quad \frac{ds_{NMDA}}{dt} = -\frac{s_{NMDA}}{\tau_{NMDA,decay}}.\]
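The right-hand side of these equations fits in a small function. This is a unit-free sketch, assuming all values are plain floats in mV, ms, nS, pF, pA, and mM (so that mV·nS = pA and pA/pF = mV/ms); `bw_rhs`, `mg_block_denominator`, and `PARAMS` are hypothetical names, not part of brainpy.state.

```python
import math

# Hypothetical default parameters (mV, nS, pF, ms, mM), taken from the class signature.
PARAMS = dict(E_L=-70.0, E_ex=0.0, E_in=-70.0, C_m=500.0, g_L=25.0,
              tau_AMPA=2.0, tau_GABA=5.0, tau_decay_NMDA=100.0, conc_Mg2=1.0)


def mg_block_denominator(V, conc_Mg2):
    """Denominator 1 + [Mg2+] exp(-0.062 V) / 3.57 of the NMDA current (V in mV)."""
    return 1.0 + conc_Mg2 * math.exp(-0.062 * V) / 3.57


def bw_rhs(y, I_stim, p=PARAMS):
    """Return dy/dt for y = (V_m, s_AMPA, s_GABA, s_NMDA); I_stim in pA."""
    V, s_ampa, s_gaba, s_nmda = y
    I_ampa = (V - p["E_ex"]) * s_ampa  # pA
    I_gaba = (V - p["E_in"]) * s_gaba  # pA
    I_nmda = (V - p["E_ex"]) * s_nmda / mg_block_denominator(V, p["conc_Mg2"])  # pA
    I_syn = I_ampa + I_gaba + I_nmda
    dV = (-p["g_L"] * (V - p["E_L"]) - I_syn + I_stim) / p["C_m"]  # mV/ms
    return (dV,
            -s_ampa / p["tau_AMPA"],
            -s_gaba / p["tau_GABA"],
            -s_nmda / p["tau_decay_NMDA"])
```

At V_m = -70 mV with [Mg²⁺] = 1 mM the denominator is about 22.5, so only roughly 4% of the NMDA conductance drives current at rest; the block is progressively relieved as the membrane depolarizes.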

2. NMDA Approximation and Spike Offsets

NMDA recurrent coupling uses a presynaptic auxiliary variable s_NMDA_pre updated only when this neuron spikes. At spike time \(t_{spike}\):

\[s_{pre} \leftarrow s_{pre} \exp\left(-\frac{t_{spike} - t_{last}}{\tau_{NMDA,decay}}\right),\]
\[\Delta s_{NMDA} = k_0 + k_1 s_{pre}, \quad s_{pre} \leftarrow s_{pre} + \Delta s_{NMDA},\]

where the jump constants are:

\[k_1 = \exp(-\alpha\tau_{NMDA,rise}) - 1,\]
\[k_0 = (\alpha\tau_{NMDA,rise})^{\tau_{NMDA,rise}/\tau_{NMDA,decay}} \gamma\Big(1 - \tau_{NMDA,rise}/\tau_{NMDA,decay}, \alpha\tau_{NMDA,rise}\Big),\]

where \(\gamma\) is the lower incomplete gamma function. The per-spike \(\Delta s_{NMDA}\) is exposed as spike_offset and used by NMDA receptor events as weight * spike_offset (matching NEST SpikeEvent semantics for iaf_bw_2001).
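The jump constants and the per-spike offset take only a few lines to compute. This is a minimal pure-Python sketch, not the brainpy.state implementation: the helper names are hypothetical, times are plain floats in ms and alpha is in 1/ms, and the lower incomplete gamma function is evaluated with its standard power series (SciPy's `gammainc` is the regularized version, so it would need to be multiplied by `gamma(s)` to give \(\gamma\)).

```python
import math


def lower_incomplete_gamma(s, x, tol=1e-15, max_terms=1000):
    """gamma(s, x) = x^s e^{-x} * sum_n x^n / (s (s+1) ... (s+n)), for s > 0, x >= 0."""
    term = 1.0 / s
    total = term
    for n in range(1, max_terms):
        term *= x / (s + n)
        total += term
        if term < tol * total:
            break
    return x ** s * math.exp(-x) * total


def nmda_jump_constants(alpha, tau_rise, tau_decay):
    """Return (k0, k1) of the NMDA jump approximation."""
    a = alpha * tau_rise          # dimensionless
    r = tau_rise / tau_decay      # dimensionless
    k1 = math.exp(-a) - 1.0
    k0 = a ** r * lower_incomplete_gamma(1.0 - r, a)
    return k0, k1


def nmda_spike_offset(s_pre, t_spike, t_last, k0, k1, tau_decay):
    """Decay s_pre to t_spike, then return (offset, updated s_pre)."""
    s_pre = s_pre * math.exp(-(t_spike - t_last) / tau_decay)
    delta = k0 + k1 * s_pre
    return delta, s_pre + delta
```

With the defaults, alpha * tau_rise_NMDA = 1, so k1 = e^{-1} - 1 ≈ -0.632 and k0 = γ(0.98, 1) ≈ 0.65. Because k1 < 0, the offset shrinks as s_NMDA_pre builds up during a burst, which is how this approximation saturates.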

3. Update Order (NEST Semantics)

Per simulation step:

  1. Integration: Integrate the ODEs on \((t, t+dt]\) using adaptive Runge-Kutta-Fehlberg 4(5) with a persistent internal step size.

  2. Spike reception: Add arriving AMPA/GABA/NMDA spike increments to s_AMPA, s_GABA, s_NMDA.

  3. Threshold/reset: Either advance the refractory countdown or, outside the refractory period, check the threshold; if \(V_m \geq V_{th}\), emit a spike and reset.

  4. Current buffering: Store the external current in the delayed buffer I_stim for the next step (one-step ring-buffer delay).

Ordering notes:

  • Refractory clamping is applied after integration, as in the NEST source.

  • I_stim uses one-step delay to match NEST’s ring-buffer semantics.

  • During refractory period, \(V_m\) is clamped to \(V_{reset}\).
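The four-step order above, including the one-step current delay and the refractory clamp, can be sketched for a single scalar neuron. This is an illustration under stated simplifications, not the brainpy.state code: the function names and dict-based state are hypothetical, fixed Euler substeps stand in for the adaptive RKF45 integrator, incoming spikes arrive pre-summed per receptor, and the NMDA offset bookkeeping is omitted.

```python
import math

# Hypothetical default parameters in mV, ms, nS, pF, mM.
P = dict(E_L=-70.0, E_ex=0.0, E_in=-70.0, V_th=-55.0, V_reset=-60.0, C_m=500.0,
         g_L=25.0, tau_AMPA=2.0, tau_GABA=5.0, tau_decay_NMDA=100.0, conc_Mg2=1.0)


def derivs(y, I_stim):
    """dy/dt for y = [V, s_AMPA, s_GABA, s_NMDA]."""
    V, sa, sg, sn = y
    mg = 1.0 + P["conc_Mg2"] * math.exp(-0.062 * V) / 3.57
    I_syn = (V - P["E_ex"]) * sa + (V - P["E_in"]) * sg + (V - P["E_ex"]) * sn / mg
    dV = (-P["g_L"] * (V - P["E_L"]) - I_syn + I_stim) / P["C_m"]
    return [dV, -sa / P["tau_AMPA"], -sg / P["tau_GABA"], -sn / P["tau_decay_NMDA"]]


def step(state, dt, x, spikes, ref_counts, n_sub=10):
    """One step in NEST order; `spikes` maps receptor name -> summed increment (nS)."""
    # 1. Integrate with the current buffered on the PREVIOUS step
    #    (Euler substeps: a simplification standing in for RKF45).
    y = [state["V"], state["s_AMPA"], state["s_GABA"], state["s_NMDA"]]
    h = dt / n_sub
    for _ in range(n_sub):
        d = derivs(y, state["I_stim"])
        y = [yi + h * di for yi, di in zip(y, d)]
    state["V"], state["s_AMPA"], state["s_GABA"], state["s_NMDA"] = y
    # 2. Spike reception: add arriving increments after integration.
    for rec in ("AMPA", "GABA", "NMDA"):
        state["s_" + rec] += spikes.get(rec, 0.0)
    # 3. Refractory countdown with clamping, else threshold check and reset.
    spiked = False
    if state["r"] > 0:
        state["r"] -= 1
        state["V"] = P["V_reset"]
    elif state["V"] >= P["V_th"]:
        spiked = True
        state["r"] = ref_counts
        state["V"] = P["V_reset"]
    # 4. Buffer this step's external current; it enters the ODE next step.
    state["I_stim"] = x
    return spiked
```

With a constant 1000 pA input, the first call leaves V_m at rest because the current is only buffered; it reaches the ODE one step later, mirroring the ring-buffer delay.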

4. Receptor Types and Event Semantics

Receptor types (matching NEST names and IDs):

  • AMPA = 1 (excitatory, fast)

  • GABA = 2 (inhibitory)

  • NMDA = 3 (excitatory, slow, voltage-dependent)

The spike_events parameter passed to update() may contain tuples or dictionaries:

  • Tuple format: (receptor, weight) or (receptor, weight, offset) or (receptor, weight, offset, sender_model)

  • Dict format: {'receptor_type': ..., 'weight': ..., 'offset': ..., 'sender_model': ...}

For NMDA events, sender_model must be 'iaf_bw_2001'; otherwise a ValueError is raised (mirroring NEST’s illegal-connection check, as only iaf_bw_2001 neurons compute the NMDA spike offset).

Registered add_delta_input entries can be receptor-labeled using label='AMPA', label='GABA', or label='NMDA'. Unlabeled delta inputs default to AMPA.
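The event formats above can be normalized into per-receptor increments roughly as follows. This is a hypothetical helper, not the brainpy.state parser: weights are plain floats in nS, the offset defaults to 1.0 as documented for update(), offsets are ignored for AMPA and GABA, and an NMDA event with no sender_model at all is assumed to be rejected like one from another model.

```python
RECEPTOR_IDS = {"AMPA": 1, "GABA": 2, "NMDA": 3}
RECEPTOR_NAMES = {v: k for k, v in RECEPTOR_IDS.items()}


def normalize_event(ev):
    """Return (receptor_name, increment) for a tuple- or dict-format spike event."""
    if isinstance(ev, dict):
        receptor = ev["receptor_type"]
        weight = ev["weight"]
        offset = ev.get("offset", 1.0)
        sender = ev.get("sender_model")
    else:
        receptor, weight, *rest = ev
        offset = rest[0] if len(rest) > 0 else 1.0
        sender = rest[1] if len(rest) > 1 else None
    # Accept either the receptor name or its integer ID.
    name = RECEPTOR_NAMES.get(receptor) if isinstance(receptor, int) else receptor
    if name not in RECEPTOR_IDS:
        raise ValueError(f"unknown receptor type: {receptor!r}")
    if name == "NMDA":
        # Only iaf_bw_2001 senders compute a meaningful NMDA spike offset.
        if sender != "iaf_bw_2001":
            raise ValueError("NMDA spikes must come from an iaf_bw_2001 sender")
        return name, weight * offset
    return name, weight


def accumulate(events):
    """Sum the increments per receptor, ready to add to s_AMPA, s_GABA, s_NMDA."""
    totals = {"AMPA": 0.0, "GABA": 0.0, "NMDA": 0.0}
    for ev in events:
        name, inc = normalize_event(ev)
        totals[name] += inc
    return totals
```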

Parameters:
  • in_size (int, tuple of int) – Population shape (number of neurons). Can be an integer or tuple for multi-dimensional populations.

  • E_L (saiunit.Quantity, optional) – Leak reversal potential. Default: -70 mV.

  • E_ex (saiunit.Quantity, optional) – Excitatory reversal potential (AMPA, NMDA). Default: 0 mV.

  • E_in (saiunit.Quantity, optional) – Inhibitory reversal potential (GABA). Default: -70 mV.

  • V_th (saiunit.Quantity, optional) – Spike threshold potential. Default: -55 mV.

  • V_reset (saiunit.Quantity, optional) – Reset potential after spike. Must be strictly less than V_th. Default: -60 mV.

  • C_m (saiunit.Quantity, optional) – Membrane capacitance. Must be strictly positive. Default: 500 pF.

  • g_L (saiunit.Quantity, optional) – Leak conductance. Default: 25 nS.

  • t_ref (saiunit.Quantity, optional) – Absolute refractory period duration. Must be non-negative. Default: 2 ms.

  • tau_AMPA (saiunit.Quantity, optional) – AMPA receptor decay time constant. Must be strictly positive. Default: 2 ms.

  • tau_GABA (saiunit.Quantity, optional) – GABA receptor decay time constant. Must be strictly positive. Default: 5 ms.

  • tau_decay_NMDA (saiunit.Quantity, optional) – NMDA receptor slow decay time constant. Must be strictly positive. Default: 100 ms.

  • tau_rise_NMDA (saiunit.Quantity, optional) – NMDA receptor fast rise time constant for jump approximation. Must be strictly positive. Default: 2 ms.

  • alpha (saiunit.Quantity, optional) – NMDA jump-shape parameter (rate constant). Must be strictly positive. Default: 0.5 / ms.

  • conc_Mg2 (saiunit.Quantity, optional) – Extracellular magnesium concentration for NMDA voltage-dependent block. Must be strictly positive. Default: 1 mM.

  • gsl_error_tol (float, optional) – RKF45 local error tolerance (analog to NEST’s gsl_error_tol). Smaller values increase integration accuracy but decrease performance. Must be strictly positive. Default: 1e-3.

  • V_initializer (callable, optional) – Membrane potential initializer function. Default: Constant(-70 mV).

  • s_AMPA_initializer (callable, optional) – AMPA conductance state initializer. Default: Constant(0 nS).

  • s_GABA_initializer (callable, optional) – GABA conductance state initializer. Default: Constant(0 nS).

  • s_NMDA_initializer (callable, optional) – NMDA conductance state initializer. Default: Constant(0 nS).

  • spk_fun (callable, optional) – Surrogate gradient function for spike generation. Default: ReluGrad().

  • spk_reset (str, optional) – Spike reset mode. 'hard' (stop gradient) matches NEST behavior; 'soft' (subtract threshold) is differentiable. Default: 'hard'.

  • ref_var (bool, optional) – If True, expose boolean refractory state variable. Default: False.

  • name (str, optional) – Name of the neuron group. Default: None.

Parameter Mapping

The following table maps brainpy.state parameter names to their NEST equivalents:

==============  ==============  =================================
brainpy.state   NEST            Description
==============  ==============  =================================
E_L             E_L             Leak reversal potential
E_ex            E_ex            Excitatory reversal potential
E_in            E_in            Inhibitory reversal potential
V_th            V_th            Spike threshold
V_reset         V_reset         Reset potential
C_m             C_m             Membrane capacitance
g_L             g_L             Leak conductance
t_ref           t_ref           Refractory period
tau_AMPA        tau_AMPA        AMPA decay time constant
tau_GABA        tau_GABA        GABA decay time constant
tau_decay_NMDA  tau_decay_NMDA  NMDA slow decay time constant
tau_rise_NMDA   tau_rise_NMDA   NMDA fast rise time constant
alpha           alpha           NMDA jump-shape parameter
conc_Mg2        conc_Mg2        Extracellular Mg²⁺ concentration
gsl_error_tol   gsl_error_tol   RKF45 error tolerance
==============  ==============  =================================

Recordables

The following state variables can be recorded during simulation:

  • V_m : membrane potential (mV)

  • s_AMPA : AMPA conductance state (nS)

  • s_GABA : GABA conductance state (nS)

  • s_NMDA : NMDA conductance state (nS)

  • I_AMPA : AMPA synaptic current (pA)

  • I_GABA : GABA synaptic current (pA)

  • I_NMDA : NMDA synaptic current (pA)

Additional State Variables

The following internal state variables are maintained but typically not recorded:

  • s_NMDA_pre : presynaptic NMDA helper state (unitless)

  • spike_offset : per-step NMDA offset emitted on spike (unitless)

  • refractory_step_count : absolute refractory countdown (int)

  • integration_step : persistent adaptive RKF45 step size (ms)

  • I_stim : one-step delayed external current buffer (pA)

  • last_spike_time : time of last spike (ms)

  • refractory : boolean refractory indicator (only if ref_var=True)

Raises:
  • ValueError – If V_reset >= V_th (reset must be below threshold).

  • ValueError – If C_m <= 0 (capacitance must be positive).

  • ValueError – If t_ref < 0 (refractory period cannot be negative).

  • ValueError – If any time constant (tau_AMPA, tau_GABA, tau_decay_NMDA, tau_rise_NMDA) is non-positive.

  • ValueError – If alpha <= 0 (NMDA shape parameter must be positive).

  • ValueError – If conc_Mg2 <= 0 (Mg²⁺ concentration must be positive).

  • ValueError – If gsl_error_tol <= 0 (error tolerance must be positive).

  • ValueError – If NMDA spike event has sender_model != 'iaf_bw_2001' (only iaf_bw_2001 neurons can compute NMDA spike offsets).
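The constructor checks listed above amount to a handful of comparisons. The hypothetical `validate_params` below sketches them on plain floats in consistent units (mV, pF, ms, 1/ms, mM); the real constructor works on saiunit quantities, and its exact messages may differ.

```python
def validate_params(V_th, V_reset, C_m, t_ref, tau_AMPA, tau_GABA, tau_decay_NMDA,
                    tau_rise_NMDA, alpha, conc_Mg2, gsl_error_tol):
    """Raise ValueError for each invalid parameter combination listed above."""
    if V_reset >= V_th:
        raise ValueError("V_reset must be below V_th.")
    if C_m <= 0:
        raise ValueError("C_m must be positive.")
    if t_ref < 0:
        raise ValueError("t_ref cannot be negative.")
    taus = dict(tau_AMPA=tau_AMPA, tau_GABA=tau_GABA,
                tau_decay_NMDA=tau_decay_NMDA, tau_rise_NMDA=tau_rise_NMDA)
    for name, value in taus.items():
        if value <= 0:
            raise ValueError(f"{name} must be positive.")
    if alpha <= 0:
        raise ValueError("alpha must be positive.")
    if conc_Mg2 <= 0:
        raise ValueError("conc_Mg2 must be positive.")
    if gsl_error_tol <= 0:
        raise ValueError("gsl_error_tol must be positive.")
```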

Examples

Create a simple network with AMPA and NMDA recurrent connections:

>>> import brainpy.state as bst
>>> import saiunit as u
>>> import brainstate
>>>
>>> # Create neuron population
>>> neurons = bst.iaf_bw_2001(100, V_th=-50*u.mV, t_ref=3*u.ms)
>>>
>>> # Initialize states
>>> with brainstate.environ.context(dt=0.1*u.ms):
...     neurons.init_all_states()
>>>
>>> # Simulate with external input
>>> with brainstate.environ.context(dt=0.1*u.ms, t=0*u.ms):
...     spike = neurons(x=500*u.pA)  # External current input

Simulate with explicit spike events (receptor-routed), again inside an environment context that supplies dt and t:

>>> with brainstate.environ.context(dt=0.1*u.ms, t=0.1*u.ms):
...     # AMPA spike event (tuple format)
...     ampa_event = ('AMPA', 1.0*u.nS)
...     spike = neurons(spike_events=[ampa_event])
>>>
>>> # NMDA spike event (dict format with offset)
>>> nmda_event = {
...     'receptor_type': 'NMDA',
...     'weight': 0.5*u.nS,
...     'offset': 0.8,  # Presynaptic NMDA offset from sender
...     'sender_model': 'iaf_bw_2001'
... }
>>> with brainstate.environ.context(dt=0.1*u.ms, t=0.2*u.ms):
...     spike = neurons(spike_events=[nmda_event])

Notes

  • Integration method: This model uses adaptive Runge-Kutta-Fehlberg 4(5) (RKF45) with local error control, matching NEST’s GSL integration. The internal step size integration_step is persistent and adapted per neuron.

  • NMDA offset computation: Only iaf_bw_2001 neurons compute the NMDA spike offset. If connecting other neuron types, NMDA connections will raise a ValueError. Use AMPA for inter-model connectivity.

  • Surrogate gradients: Unlike NEST (which is not differentiable), this implementation supports gradient-based learning via surrogate spike functions.

  • Performance: Adaptive RKF45 integration is accurate but can be costly for large populations. For performance-critical applications that do not need NMDA dynamics, consider simpler models such as iaf_cond_exp or iaf_psc_alpha.

  • Refractory semantics: During refractory period, \(V_m\) is clamped to \(V_{reset}\), and threshold crossing is disabled. This matches NEST behavior.


See also

iaf_cond_exp

Simpler conductance-based LIF without NMDA dynamics.

iaf_psc_alpha

Current-based LIF with alpha-function PSCs.

iaf_bw_2001_exact

Exact integration variant (if available).

get_spike(V=None)

Compute spike output using surrogate gradient function.

Converts membrane potential to a differentiable spike signal using the configured surrogate gradient function (spk_fun). The membrane potential is scaled relative to threshold and reset before applying the surrogate.

Parameters:

V (ArrayLike, optional) – Membrane potential (mV). If None, uses current self.V.value. Shape: (*in_size,) or (batch_size, *in_size).

Returns:

Spike signal (differentiable), with the same shape as V. In the forward pass the surrogate acts as a Heaviside step, so values are binary {0, 1}; the choice of surrogate function only affects the gradient.

Return type:

jax.numpy.ndarray

Notes

  • Scaling factor: \((V - V_{th}) / (V_{th} - V_{reset})\).

  • In the forward pass the surrogate acts as a step function; during backpropagation its smooth derivative replaces the step's zero gradient, so gradients can flow through spikes.

  • This method is called internally by update() after integration and threshold checking.
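The scaling and the split between forward and backward passes can be illustrated in plain Python. This sketch assumes that the default `ReluGrad(alpha=0.3, width=1.0)` has derivative alpha * max(width - |x|, 0) and that the step fires at x >= 0; the function names are hypothetical, and the real method operates on JAX arrays with automatic differentiation rather than a hand-written gradient.

```python
def scaled_voltage(V, V_th=-55.0, V_reset=-60.0):
    """(V - V_th) / (V_th - V_reset): 0 at threshold, -1 at reset (V in mV)."""
    return (V - V_th) / (V_th - V_reset)


def spike_forward(V, V_th=-55.0, V_reset=-60.0):
    """Forward pass: Heaviside step of the scaled voltage (assumed to fire at >= 0)."""
    return 1.0 if scaled_voltage(V, V_th, V_reset) >= 0.0 else 0.0


def relu_surrogate_grad(x, alpha=0.3, width=1.0):
    """Assumed ReluGrad derivative: alpha * max(width - |x|, 0)."""
    return alpha * max(width - abs(x), 0.0)


def spike_grad_wrt_V(V, V_th=-55.0, V_reset=-60.0, alpha=0.3, width=1.0):
    """Backward pass via the chain rule: surrogate'(x) / (V_th - V_reset)."""
    x = scaled_voltage(V, V_th, V_reset)
    return relu_surrogate_grad(x, alpha, width) / (V_th - V_reset)
```

Dividing by V_th - V_reset makes the surrogate's width refer to the reset-to-threshold distance, so the same surrogate settings behave comparably across parameter choices.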

init_state(**kwargs)

Initialize all state variables for the neuron population.

Creates and initializes membrane potential, synaptic conductance states (AMPA, GABA, NMDA), synaptic currents, refractory counters, NMDA presynaptic helper state, adaptive RKF45 step size, and delayed current buffer.

Parameters:

**kwargs – Unused compatibility parameters accepted by the base-state API.

Notes

  • All synaptic conductances initialize to 0 nS by default.

  • Membrane potential initializes to -70 mV by default, equal to the default E_L.

  • integration_step initializes to the simulation timestep dt.

  • last_spike_time initializes to -1e7 ms (far in the past).

  • If ref_var=True, a boolean refractory state is also created.

property receptor_types

Return dictionary of available receptor types.

Returns:

Mapping from receptor name (str) to receptor ID (int). Keys: 'AMPA', 'GABA', 'NMDA'. Values: 1, 2, 3.

Return type:

dict

property recordables

Return list of recordable state variable names.

Returns:

State variables that can be recorded during simulation: ['V_m', 's_AMPA', 's_GABA', 's_NMDA', 'I_NMDA', 'I_AMPA', 'I_GABA'].

Return type:

list of str

update(x=Quantity(0., 'pA'), spike_events=None)

Advance the neuron state by one simulation timestep.

Performs a complete update cycle: (1) RKF45 integration of the ODEs, (2) reception of AMPA/GABA/NMDA spike events, (3) refractory handling or, outside the refractory period, threshold detection with spike emission, reset, and NMDA spike offset computation, and (4) delayed current buffering.

Parameters:
  • x (saiunit.Quantity, optional) – External input current (pA). Can be scalar or array matching population shape. This current is buffered and applied in the next timestep (one-step delay, matching NEST ring-buffer semantics). Default: 0 pA.

  • spike_events (list of tuple or dict, optional) –

    Incoming spike events from presynaptic neurons. Each event can be:

    • Tuple: (receptor, weight) or (receptor, weight, offset) or (receptor, weight, offset, sender_model)

    • Dict: {'receptor_type': ..., 'weight': ..., 'offset': ..., 'sender_model': ...}

    Receptor types: 'AMPA' or 1, 'GABA' or 2, 'NMDA' or 3. Weight units: nS (conductance). Offset (for NMDA only): presynaptic NMDA spike offset (unitless, default 1.0). Sender model (for NMDA only): must be 'iaf_bw_2001'.

    If None, no spike events are processed. Default: None.

Returns:

Spike output (differentiable). Shape: (*in_size,). Forward values are binary {0, 1}; the surrogate function only shapes the gradient.

Return type:

jax.numpy.ndarray

Raises:

ValueError – If an NMDA spike event has sender_model != 'iaf_bw_2001'. Only iaf_bw_2001 neurons compute NMDA spike offsets; other neuron types cannot send NMDA spikes to this model.

Notes

Update order (matching NEST):

  1. Integration: Integrate ODEs using adaptive RKF45 from \(t\) to \(t + dt\). The persistent integration_step is adapted per neuron based on local error.

  2. Spike reception: Add incoming spike weights (scaled by offset for NMDA) to s_AMPA, s_GABA, s_NMDA.

  3. Refractory/threshold:

    • If in refractory period (refractory_step_count > 0): clamp \(V_m\) to \(V_{reset}\), decrement counter.

    • Else: check threshold \(V_m \geq V_{th}\). If crossed, emit spike, reset \(V_m \leftarrow V_{reset}\), set refractory counter, compute NMDA spike offset.

  4. Current buffering: Store input current x (plus any registered current inputs) into I_stim buffer for next step.

NMDA spike offset computation:

When this neuron spikes, the NMDA spike offset \(\Delta s_{NMDA}\) is computed using the presynaptic helper state s_NMDA_pre:

\[s_{pre} \leftarrow s_{pre} \exp(-\Delta t / \tau_{NMDA,decay}),\]
\[\Delta s_{NMDA} = k_0 + k_1 s_{pre},\]

where \(\Delta t = t_{spike} - t_{last}\) and \(k_0, k_1\) are precomputed constants. The updated s_NMDA_pre is stored for the next spike. The offset \(\Delta s_{NMDA}\) is exposed as spike_offset and should be passed to downstream NMDA connections.

Current delay:

The external current x is stored in I_stim and applied in the next timestep. This one-step delay matches NEST’s ring-buffer semantics. Current inputs registered via add_current_input are summed with x and delayed together.

Integration notes:

  • RKF45 uses local error tolerance gsl_error_tol (default 1e-3).

  • The adaptive step size integration_step is persistent per neuron and typically stabilizes after a few milliseconds.

  • Maximum iterations: 10000 per timestep (prevents infinite loops).

  • Minimum step size: 1e-8 ms (prevents numerical instability).
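Adaptive integration with a persistent step size can be sketched as follows. This is a generic Runge-Kutta-Fehlberg 4(5) integrator in plain Python, not the brainpy.state or GSL code: the function names are hypothetical, the error norm is the maximum absolute difference between the embedded 4th- and 5th-order solutions, and the step controller (safety factor 0.9, change clamped to [0.2, 5]) is a common textbook choice rather than GSL's. The defaults for tol, h_min, and max_iter mirror the values listed above.

```python
# Fehlberg 4(5) Butcher tableau (autonomous systems, so the c_i nodes are not needed).
A = (
    (),
    (1 / 4,),
    (3 / 32, 9 / 32),
    (1932 / 2197, -7200 / 2197, 7296 / 2197),
    (439 / 216, -8.0, 3680 / 513, -845 / 4104),
    (-8 / 27, 2.0, -3544 / 2565, 1859 / 4104, -11 / 40),
)
B4 = (25 / 216, 0.0, 1408 / 2565, 2197 / 4104, -1 / 5, 0.0)
B5 = (16 / 135, 0.0, 6656 / 12825, 28561 / 56430, -9 / 50, 2 / 55)


def rkf45_trial(f, y, h):
    """One Fehlberg trial step; returns the 4th-order solution and max |y5 - y4|."""
    ks = []
    for row in A:
        stage = [y[j] + h * sum(a * k[j] for a, k in zip(row, ks)) for j in range(len(y))]
        ks.append(f(stage))
    y4 = [y[j] + h * sum(b * k[j] for b, k in zip(B4, ks)) for j in range(len(y))]
    y5 = [y[j] + h * sum(b * k[j] for b, k in zip(B5, ks)) for j in range(len(y))]
    return y4, max(abs(p - q) for p, q in zip(y4, y5))


def integrate_window(f, y, dt, h, tol=1e-3, h_min=1e-8, max_iter=10000):
    """Integrate y' = f(y) over one window of length dt, starting from step size h.

    Returns (y, h); the returned h is meant to be stored and reused next window.
    """
    t, n = 0.0, 0
    while t < dt and n < max_iter:          # max_iter guards against endless loops
        n += 1
        last = h >= dt - t                  # would this step reach the window end?
        step = dt - t if last else h
        y_new, err = rkf45_trial(f, y, step)
        accepted = err <= tol or step <= h_min
        if accepted:
            y = y_new
            t = dt if last else t + step
        if not (accepted and last):         # a clipped final step does not overwrite h
            factor = 5.0 if err == 0.0 else min(5.0, max(0.2, 0.9 * (tol / err) ** 0.2))
            h = max(h_min, step * factor)
    return y, h
```

Passing the returned h into the next window is what makes the step size persistent: it shrinks where the error estimate demands it and can grow back afterwards, while the clipped last step of each window lands exactly on t + dt.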