aeif_cond_exp

class brainpy.state.aeif_cond_exp(in_size, V_peak=Quantity(0., "mV"), V_reset=Quantity(-60., "mV"), t_ref=Quantity(0., "ms"), g_L=Quantity(30., "nS"), C_m=Quantity(281., "pF"), E_ex=Quantity(0., "mV"), E_in=Quantity(-85., "mV"), E_L=Quantity(-70.6, "mV"), Delta_T=Quantity(2., "mV"), tau_w=Quantity(144., "ms"), a=Quantity(4., "nS"), b=Quantity(80.5, "pA"), V_th=Quantity(-50.4, "mV"), tau_syn_ex=Quantity(0.2, "ms"), tau_syn_in=Quantity(2., "ms"), I_e=Quantity(0., "pA"), gsl_error_tol=1e-06, V_initializer=Constant(value=-70.6 mV), g_ex_initializer=Constant(value=0. nS), g_in_initializer=Constant(value=0. nS), w_initializer=Constant(value=0. pA), spk_fun=ReluGrad(alpha=0.3, width=1.0), spk_reset='hard', ref_var=False, name=None)

NEST-compatible aeif_cond_exp neuron model.

Conductance-based adaptive exponential integrate-and-fire neuron with exponential synaptic conductances.

This implementation follows NEST models/aeif_cond_exp.{h,cpp} and combines an exponential spike-initiation current (AdEx), an adaptation current with spike-triggered and subthreshold components, and exponentially decaying excitatory and inhibitory conductances.

1. Membrane, Synapse, and Adaptation Dynamics

The membrane potential \(V\) evolves according to:

\[C_m \frac{dV}{dt} = -g_L (V - E_L) + g_L \Delta_T \exp\!\left(\frac{V - V_{th}}{\Delta_T}\right) - g_{ex}(V - E_{ex}) - g_{in}(V - E_{in}) - w + I_e + I_{stim}\]

where the first term is the leak current, the second term is the exponential spike-initiation current (the hallmark of the AdEx model), the third and fourth terms are the excitatory and inhibitory synaptic currents, \(w\) is the adaptation current, \(I_e\) is a constant external current, and \(I_{stim}\) is the time-varying stimulation current.

The adaptation current \(w\) follows:

\[\tau_w \frac{dw}{dt} = a (V - E_L) - w\]

where \(a\) controls subthreshold adaptation (coupling between \(V\) and \(w\)) and \(\tau_w\) is the adaptation time constant.
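
As a worked case (a sketch that assumes \(V\) is held fixed, which the full model does not do), the adaptation equation is linear in \(w\) and relaxes exponentially toward a steady state:

\[w(t) = w_\infty + \bigl(w(0) - w_\infty\bigr)\, e^{-t/\tau_w}, \qquad w_\infty = a (V - E_L)\]

With the defaults a = 4 nS and E_L = -70.6 mV, holding \(V\) at -60.6 mV gives \(w_\infty = 4\,\mathrm{nS} \times 10\,\mathrm{mV} = 40\,\mathrm{pA}\), approached with time constant tau_w = 144 ms.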

Excitatory and inhibitory conductances decay exponentially:

\[\frac{d g_{ex}}{dt} = -\frac{g_{ex}}{\tau_{syn,ex}}, \qquad \frac{d g_{in}}{dt} = -\frac{g_{in}}{\tau_{syn,in}}\]
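
The three equations above can be written as a single right-hand-side function. The following is a minimal sketch in plain Python floats, not the library's implementation: the names DEFAULTS and aeif_rhs are hypothetical, units are implied (mV, ms, nS, pF, pA) rather than carried by Quantity objects, and Delta_T > 0 is assumed.

```python
import math

# Default parameter values copied from the class signature, as plain floats.
# The dict itself is a hypothetical convenience for this sketch.
DEFAULTS = dict(C_m=281.0, g_L=30.0, E_L=-70.6, Delta_T=2.0, V_th=-50.4,
                E_ex=0.0, E_in=-85.0, tau_w=144.0, a=4.0,
                tau_syn_ex=0.2, tau_syn_in=2.0, I_e=0.0)


def aeif_rhs(V, w, g_ex, g_in, I_stim=0.0, p=DEFAULTS):
    """Return (dV/dt, dw/dt, dg_ex/dt, dg_in/dt); assumes Delta_T > 0."""
    # nS * mV = pA, so every current below is in pA.
    I_leak = -p["g_L"] * (V - p["E_L"])
    I_spike = p["g_L"] * p["Delta_T"] * math.exp((V - p["V_th"]) / p["Delta_T"])
    I_syn = -g_ex * (V - p["E_ex"]) - g_in * (V - p["E_in"])
    # pA / pF = mV/ms
    dV = (I_leak + I_spike + I_syn - w + p["I_e"] + I_stim) / p["C_m"]
    dw = (p["a"] * (V - p["E_L"]) - w) / p["tau_w"]
    dg_ex = -g_ex / p["tau_syn_ex"]
    dg_in = -g_in / p["tau_syn_in"]
    return dV, dw, dg_ex, dg_in
```

At rest (V = E_L, w = 0) with no conductances, only the small exponential term contributes to dV/dt, which is why the default resting state is nearly stationary.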

Incoming spike weights (in nS) are split by sign and added to the respective conductances:

\[g_{ex} \leftarrow g_{ex} + w_+, \qquad g_{in} \leftarrow g_{in} + |w_-|\]
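
The sign-based split can be sketched as follows. This is an illustration only: apply_spike_weights is a hypothetical helper operating on plain floats in nS, and treating a weight of exactly zero as excitatory is an assumption (it adds nothing either way).

```python
def apply_spike_weights(g_ex, g_in, weights):
    """Add positive weights to g_ex and magnitudes of negative weights to g_in."""
    for wgt in weights:
        if wgt >= 0.0:
            g_ex += wgt        # excitatory: g_ex <- g_ex + w_+
        else:
            g_in += -wgt       # inhibitory: g_in <- g_in + |w_-|
    return g_ex, g_in
```

Both conductances therefore stay non-negative regardless of the sign of the incoming weights.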

2. Refractory Period and Spike Handling (NEST Semantics)

During refractory integration (when refractory_step_count > 0), the effective membrane voltage is clamped to V_reset and \(dV/dt = 0\). Outside refractory periods, the right-hand side uses \(\min(V, V_{peak})\) as the effective voltage to prevent numerical overflow in the exponential term.

Spike detection threshold:

  • If Delta_T > 0: spike when \(V \geq V_{peak}\)

  • If Delta_T == 0 (IAF-like limit): spike when \(V \geq V_{th}\)

Upon spike detection:

  1. \(V\) is reset to V_reset

  2. Adaptation jump \(w \leftarrow w + b\) is applied immediately

  3. Refractory counter is set to ceil(t_ref / dt) + 1 if t_ref > 0

Spike detection and reset occur inside the adaptive RKF45 integration substep loop. Therefore, with t_ref = 0, multiple spikes can occur within one simulation step, matching NEST behavior.
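
The threshold choice and the reset rules above can be sketched as two small functions. These are hypothetical helpers on plain floats (mV, ms, pA), not the library's code; leaving the counter at 0 when t_ref = 0 is an assumption inferred from the "if t_ref > 0" condition.

```python
import math


def spike_threshold(V_peak, V_th, Delta_T):
    """Voltage at which a spike is detected."""
    # Delta_T > 0: detect at V_peak; Delta_T == 0 (IAF-like limit): detect at V_th.
    return V_peak if Delta_T > 0.0 else V_th


def on_spike(w, b, V_reset, t_ref, dt):
    """Return (V, w, refractory_step_count) right after a detected spike."""
    # Counter is ceil(t_ref / dt) + 1 when t_ref > 0; assumed 0 otherwise.
    steps = math.ceil(t_ref / dt) + 1 if t_ref > 0.0 else 0
    return V_reset, w + b, steps
```

For example, t_ref = 2 ms with dt = 0.5 ms sets the counter to 5; the extra step accounts for the decrement that happens at the end of the same simulation step.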

3. Update Order per Simulation Step

Each call to update(x) performs the following sequence:

  1. Integrate ODEs on \((t, t+dt]\) via adaptive RKF45 with local error control

  2. Inside integration loop: apply refractory clamp and spike/reset/adaptation as needed

  3. After integration loop: decrement refractory counter once (if > 0)

  4. Apply arriving spike weights (from delta_inputs) to g_ex / g_in

  5. Store external current input x into one-step delayed buffer I_stim (for use in the next time step)
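
The effect of step 5 can be illustrated with a toy buffer. This sketch is not the library's code: DelayedCurrent is a hypothetical class that only mimics how the value passed at one step becomes the I_stim used by the next step's integration.

```python
class DelayedCurrent:
    """Toy one-step delay line mirroring the I_stim buffer."""

    def __init__(self):
        self.I_stim = 0.0  # buffer starts empty (0 pA)

    def step(self, x):
        used = self.I_stim  # value seen by this step's integration
        self.I_stim = x     # stored for the next step
        return used
```

Feeding the inputs 1, 2, 3 on consecutive steps yields 0, 1, 2 as the currents actually integrated, so a current passed to update() first influences the dynamics one step later.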

4. Numerical Integration Details

The model uses an adaptive Runge-Kutta-Fehlberg 4(5) integrator (RKF45) with local error control. The step size is adjusted dynamically according to gsl_error_tol; it is stored in integration_step and persists across simulation steps for efficiency. The step size is never reduced below _MIN_H = 1e-8 ms, which prevents stalling, and at most _MAX_ITERS = 100000 iterations are performed per simulation step. If the membrane potential drops below -1000 mV or the magnitude of the adaptation current exceeds 1e6 pA, a numerical instability error is raised.
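
To make the scheme concrete, here is a minimal scalar sketch of adaptive RKF45 over one simulation step. It uses the standard Fehlberg 4(5) coefficients, but the step-size controller (safety factor 0.9, growth limited to the range 0.2 to 5) is a generic textbook choice and an assumption, not the controller used by this class or by GSL. The names rkf45_step and integrate_interval are hypothetical; MIN_H and MAX_ITERS mirror the constants quoted above.

```python
import math

MIN_H = 1e-8        # ms, mirrors _MIN_H
MAX_ITERS = 100000  # mirrors _MAX_ITERS


def rkf45_step(f, t, y, h):
    """One Fehlberg 4(5) step; returns (5th-order estimate, error estimate)."""
    k1 = f(t, y)
    k2 = f(t + h / 4, y + h * (k1 / 4))
    k3 = f(t + 3 * h / 8, y + h * (3 * k1 / 32 + 9 * k2 / 32))
    k4 = f(t + 12 * h / 13, y + h * (1932 * k1 - 7200 * k2 + 7296 * k3) / 2197)
    k5 = f(t + h, y + h * (439 * k1 / 216 - 8 * k2 + 3680 * k3 / 513
                           - 845 * k4 / 4104))
    k6 = f(t + h / 2, y + h * (-8 * k1 / 27 + 2 * k2 - 3544 * k3 / 2565
                               + 1859 * k4 / 4104 - 11 * k5 / 40))
    y4 = y + h * (25 * k1 / 216 + 1408 * k3 / 2565 + 2197 * k4 / 4104 - k5 / 5)
    y5 = y + h * (16 * k1 / 135 + 6656 * k3 / 12825 + 28561 * k4 / 56430
                  - 9 * k5 / 50 + 2 * k6 / 55)
    return y5, abs(y5 - y4)


def integrate_interval(f, t0, y0, dt, tol=1e-6, h=None):
    """Integrate dy/dt = f(t, y) over (t0, t0 + dt]; returns (y, next_h)."""
    t, y, t_end = t0, y0, t0 + dt
    h = dt if h is None else h      # persistent step size from the last call
    iters = 0
    while t < t_end and iters < MAX_ITERS:
        iters += 1
        h_try = min(max(h, MIN_H), t_end - t)   # never step past t_end
        y_new, err = rkf45_step(f, t, y, h_try)
        if err <= tol:                          # accept the substep
            t, y = t + h_try, y_new
        # Generic controller (assumption): rescale by (tol/err)^(1/5).
        scale = 0.9 * (tol / err) ** 0.2 if err > 0.0 else 5.0
        h = max(MIN_H, h_try * min(5.0, max(0.2, scale)))
    return y, h
```

Calling integrate_interval(lambda t, y: -y, 0.0, 1.0, 1.0, tol=1e-8) approximates exp(-1) closely, and the returned step size can be passed back in as h on the next call, which is the role integration_step plays in the class.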

5. Computational Constraints and Assumptions

  • Overflow guard: The exponential term can overflow if (V_peak - V_th) / Delta_T is too large. The model validates that this ratio stays below log(max_float64 / 1e20) at initialization, mirroring NEST’s safeguard.

  • Refractory clamp: During refractory period, \(V\) is clamped to V_reset and \(dV/dt = 0\), but all other variables (g_ex, g_in, w) continue to evolve normally.

  • Hard spike reset: By default, spk_reset='hard' uses jax.lax.stop_gradient to prevent gradient flow through spike times, matching typical neuroscience practice.

  • Delayed input: The current input x from time \(t\) is stored in I_stim and used during integration from \(t+dt\) to \(t+2dt\). This one-step delay matches NEST’s input handling.
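
The overflow guard in the first bullet can be sketched as a standalone check. This is an illustration under assumptions: exp_arg_is_safe is a hypothetical helper on plain floats in mV, and skipping the check when Delta_T is 0 (where no exponential term is evaluated) is an assumption rather than documented behavior.

```python
import math
import sys


def exp_arg_is_safe(V_peak, V_th, Delta_T):
    """True if (V_peak - V_th) / Delta_T stays below log(max_float64 / 1e20)."""
    if Delta_T == 0.0:
        return True  # assumption: no exponential term in the IAF-like limit
    limit = math.log(sys.float_info.max / 1e20)  # roughly 663.7
    return (V_peak - V_th) / Delta_T < limit
```

With the defaults the ratio is 50.4 / 2 = 25.2, far below the limit; shrinking Delta_T to 0.01 mV raises it to 5040, which the guard rejects.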

Parameters:
  • in_size (Size (int, tuple of int, or callable returning shape)) – Neuron population shape. Supports integer (1D), tuple (multi-dimensional), or callable returning shape.

  • V_peak (ArrayLike, optional) – Spike detection threshold (mV). Used when Delta_T > 0. Default: 0 mV. Must satisfy V_peak >= V_th and V_peak > V_reset.

  • V_reset (ArrayLike, optional) – Reset potential (mV) after spike. Default: -60 mV. Must satisfy V_reset < V_peak.

  • t_ref (ArrayLike, optional) – Absolute refractory period (ms). Default: 0 ms. When 0, multiple spikes per simulation step are possible. Must be non-negative.

  • g_L (ArrayLike, optional) – Leak conductance (nS). Default: 30 nS. Must be positive.

  • C_m (ArrayLike, optional) – Membrane capacitance (pF). Default: 281 pF. Must be positive.

  • E_ex (ArrayLike, optional) – Excitatory reversal potential (mV). Default: 0 mV.

  • E_in (ArrayLike, optional) – Inhibitory reversal potential (mV). Default: -85 mV.

  • E_L (ArrayLike, optional) – Leak reversal potential (mV). Default: -70.6 mV.

  • Delta_T (ArrayLike, optional) – Exponential slope factor (mV) controlling sharpness of spike initiation. Default: 2 mV. Must be non-negative. Set to 0 to recover IAF-like behavior.

  • tau_w (ArrayLike, optional) – Adaptation time constant (ms). Default: 144 ms. Must be positive.

  • a (ArrayLike, optional) – Subthreshold adaptation coupling (nS). Default: 4 nS. Controls how strongly membrane potential drives adaptation current.

  • b (ArrayLike, optional) – Spike-triggered adaptation increment (pA). Default: 80.5 pA. Added to w on each spike.

  • V_th (ArrayLike, optional) – Spike initiation threshold (mV) appearing in exponential term. Default: -50.4 mV. Must satisfy V_th <= V_peak.

  • tau_syn_ex (ArrayLike, optional) – Excitatory conductance decay time constant (ms). Default: 0.2 ms. Must be positive.

  • tau_syn_in (ArrayLike, optional) – Inhibitory conductance decay time constant (ms). Default: 2.0 ms. Must be positive.

  • I_e (ArrayLike, optional) – Constant external current (pA). Default: 0 pA.

  • gsl_error_tol (ArrayLike, optional) – RKF45 local error tolerance (unitless). Default: 1e-6. Smaller values increase accuracy but slow integration. Must be positive.

  • V_initializer (Callable, optional) – Initializer for membrane potential. Default: Constant(-70.6 mV).

  • g_ex_initializer (Callable, optional) – Initializer for excitatory conductance. Default: Constant(0 nS).

  • g_in_initializer (Callable, optional) – Initializer for inhibitory conductance. Default: Constant(0 nS).

  • w_initializer (Callable, optional) – Initializer for adaptation current. Default: Constant(0 pA).

  • spk_fun (Callable, optional) – Surrogate gradient function for differentiable spike generation. Default: ReluGrad(). Used in get_spike() for gradient-based learning.

  • spk_reset (str, optional) – Spike reset mode. Default: 'hard' (stop gradient). Use 'soft' to allow gradient flow through spike times.

  • ref_var (bool, optional) – If True, expose boolean refractory state variable indicating whether neuron is in refractory period. Default: False.

  • name (str, optional) – Name of the neuron group. Default: None.

Parameter Mapping

This table shows the correspondence between brainpy.state parameters, NEST parameters, and mathematical notation:

| brainpy.state | NEST | Math Symbol | Description |
| --- | --- | --- | --- |
| in_size | (N/A) | | Population shape |
| V_peak | V_peak | \(V_\mathrm{peak}\) | Spike detection threshold (if Delta_T > 0) |
| V_reset | V_reset | \(V_\mathrm{reset}\) | Reset potential |
| t_ref | t_ref | \(t_\mathrm{ref}\) | Absolute refractory duration |
| g_L | g_L | \(g_\mathrm{L}\) | Leak conductance |
| C_m | C_m | \(C_\mathrm{m}\) | Membrane capacitance |
| E_ex | E_ex | \(E_\mathrm{ex}\) | Excitatory reversal potential |
| E_in | E_in | \(E_\mathrm{in}\) | Inhibitory reversal potential |
| E_L | E_L | \(E_\mathrm{L}\) | Leak reversal potential |
| Delta_T | Delta_T | \(\Delta_T\) | Exponential slope factor |
| tau_w | tau_w | \(\tau_w\) | Adaptation time constant |
| a | a | \(a\) | Subthreshold adaptation coupling |
| b | b | \(b\) | Spike-triggered adaptation increment |
| V_th | V_th | \(V_\mathrm{th}\) | Spike initiation threshold |
| tau_syn_ex | tau_syn_ex | \(\tau_{\mathrm{syn,ex}}\) | Excitatory conductance time constant |
| tau_syn_in | tau_syn_in | \(\tau_{\mathrm{syn,in}}\) | Inhibitory conductance time constant |
| I_e | I_e | \(I_\mathrm{e}\) | Constant external current |
| gsl_error_tol | gsl_error_tol | | RKF45 solver tolerance |

V

Membrane potential (mV). Shape: (batch_size,) + varshape.

Type: HiddenState

g_ex

Excitatory conductance (nS). Shape: (batch_size,) + varshape.

Type: HiddenState

g_in

Inhibitory conductance (nS). Shape: (batch_size,) + varshape.

Type: HiddenState

w

Adaptation current (pA). Shape: (batch_size,) + varshape.

Type: HiddenState

refractory_step_count

Remaining refractory time steps (int32). Shape: (batch_size,) + varshape.

Type: ShortTermState

integration_step

Persistent RKF45 internal step size (ms). Shape: (batch_size,) + varshape.

Type: ShortTermState

I_stim

One-step delayed current buffer (pA). Shape: (batch_size,) + varshape.

Type: ShortTermState

last_spike_time

Last emitted spike time (ms). Updated to \(t + dt\) on spike. Shape: (batch_size,) + varshape.

Type: ShortTermState

refractory

Boolean refractory indicator. Only present if ref_var=True. Shape: (batch_size,) + varshape.

Type: ShortTermState (optional)

Raises:
  • ValueError – If V_peak < V_th, Delta_T < 0, V_reset >= V_peak, C_m <= 0, t_ref < 0, any time constant <= 0, gsl_error_tol <= 0, or if (V_peak - V_th) / Delta_T would cause exponential overflow.

  • ValueError – During integration, if the membrane potential drops below -1000 mV or the magnitude of the adaptation current exceeds 1e6 pA, indicating numerical instability.
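
The simple inequality checks listed above can be sketched as follows. This is a hedged illustration, not the library's validation code: validate_params and check_state are hypothetical names, arguments are plain floats with implied units, the error messages are invented for the sketch, and the exponential-overflow ratio check is omitted here.

```python
def validate_params(V_peak=0.0, V_th=-50.4, V_reset=-60.0, Delta_T=2.0,
                    C_m=281.0, t_ref=0.0, tau_w=144.0, tau_syn_ex=0.2,
                    tau_syn_in=2.0, gsl_error_tol=1e-6):
    """Raise ValueError for the parameter conditions listed under Raises."""
    if V_peak < V_th:
        raise ValueError("V_peak must be >= V_th")
    if Delta_T < 0.0:
        raise ValueError("Delta_T must be non-negative")
    if V_reset >= V_peak:
        raise ValueError("V_reset must be < V_peak")
    if C_m <= 0.0:
        raise ValueError("C_m must be positive")
    if t_ref < 0.0:
        raise ValueError("t_ref must be non-negative")
    if min(tau_w, tau_syn_ex, tau_syn_in) <= 0.0:
        raise ValueError("all time constants must be positive")
    if gsl_error_tol <= 0.0:
        raise ValueError("gsl_error_tol must be positive")


def check_state(V, w):
    """Raise ValueError when the state leaves the documented sane range."""
    if V < -1000.0 or abs(w) > 1e6:
        raise ValueError("numerical instability detected")
```

The default arguments pass all checks; changing a single argument, for example V_reset=5.0, triggers the corresponding error.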

Notes

  • Default refractory period: t_ref = 0 matches NEST and can allow multiple spikes per simulation step. Set t_ref > 0 to enforce absolute refractory period.

  • Spike output: The returned spike tensor is binary per step (0 or 1), even if multiple spikes occur internally. Use last_spike_time to read the time of the most recent spike, which is recorded as \(t + dt\), i.e. at simulation-step resolution.

  • Gradient-based learning: Use get_spike() method for differentiable spike generation with surrogate gradients, suitable for gradient-based learning.

  • NEST compatibility: This implementation closely follows NEST’s C++ source, including refractory clamping, spike detection logic, and overflow guards.

Examples

Create and simulate a population of AdEx neurons:

>>> import brainpy.state as bst
>>> import saiunit as u
>>> import brainstate
>>> # Create 100 AdEx neurons
>>> neurons = bst.aeif_cond_exp(100)
>>> # Initialize states
>>> neurons.init_all_states()
>>> # Simulate with constant current input
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     for _ in range(1000):
...         spikes = neurons.update(x=500 * u.pA)

Create with custom parameters matching cortical pyramidal cells:

>>> neurons = bst.aeif_cond_exp(
...     in_size=100,
...     V_peak=0.0 * u.mV,
...     V_reset=-70.0 * u.mV,
...     t_ref=2.0 * u.ms,
...     g_L=30.0 * u.nS,
...     C_m=281.0 * u.pF,
...     Delta_T=2.0 * u.mV,
...     tau_w=144.0 * u.ms,
...     a=4.0 * u.nS,
...     b=80.5 * u.pA,
... )

Access state variables:

>>> neurons.init_all_states()
>>> print(neurons.V.value.shape)  # Membrane potential
>>> print(neurons.g_ex.value.shape)  # Excitatory conductance
>>> print(neurons.w.value.shape)  # Adaptation current
>>> print(neurons.refractory_step_count.value.shape)  # Refractory counter

get_spike(V=None)

Generate differentiable spike signal using surrogate gradient.

Computes a continuous spike probability using the surrogate gradient function (spk_fun) applied to scaled membrane potential. This enables gradient-based learning through spike generation.

Parameters:

V (ArrayLike, optional) – Membrane potential (mV). If None, uses current self.V.value. Shape: arbitrary, but typically (batch_size,) + varshape.

Returns:

spike_prob – Continuous spike signal, nominally in [0, 1]; the exact range depends on spk_fun (e.g., ReluGrad returns values in [0, 1]). Values near 0 indicate no spike and values near 1 indicate a spike. Shape matches input V.

Return type:

ArrayLike

Notes

  • The membrane potential is scaled as (V - V_th) / (V_th - V_reset) before applying the surrogate function.

  • This method is primarily used for gradient-based learning and does NOT affect the hard spike detection used in update().

  • For binary spike output matching NEST semantics, use the return value of update().
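
The scaling in the first note can be checked with a few numbers. The sketch below is illustrative only: scaled_potential is a hypothetical helper on plain floats in mV, and it stops before the surrogate function, whose exact form is not reproduced here.

```python
def scaled_potential(V, V_th=-50.4, V_reset=-60.0):
    """Dimensionless input to the surrogate function: (V - V_th) / (V_th - V_reset)."""
    return (V - V_th) / (V_th - V_reset)
```

With the defaults, V = V_th maps to 0 and V = V_reset maps to -1, so positive values correspond to potentials above the spike initiation threshold.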

init_state(**kwargs)

State initialization function.

update(x=Quantity(0., 'pA'))

Advance neuron state by one simulation time step.

Integrates the AdEx ODE system over interval \((t, t+dt]\) using adaptive RKF45 with local error control. Handles spike detection, reset, adaptation jumps, refractory clamping, and synaptic input processing following NEST semantics.

Parameters:

x (ArrayLike, optional) – External current input (pA) for the current time step. Default: 0 pA. Shape: scalar, varshape, or (batch_size,) + varshape. This input is stored in I_stim and will be used during the next time step (one-step delay, matching NEST).

Returns:

spike – Binary spike indicator (0 or 1, dtype float64). Shape: (batch_size,) + varshape. Value is 1 if at least one spike occurred during this time step, 0 otherwise. Multiple spikes within one step (when t_ref = 0) are compressed to a single binary flag.

Return type:

ArrayLike

Notes

Update sequence:

  1. ODE integration: Integrate \((t, t+dt]\) via adaptive RKF45. Inside the integration loop:

    • Apply refractory clamp if refractory_step_count > 0

    • Check for spike when \(V \geq V_{peak}\) (or \(V \geq V_{th}\) if Delta_T = 0)

    • On spike: reset \(V \leftarrow V_{reset}\), jump \(w \leftarrow w + b\), set refractory_step_count = ceil(t_ref / dt) + 1

  2. Post-integration: Decrement refractory_step_count once (if > 0)

  3. Synaptic input: Process delta_inputs (spike weights from projections), split by sign, and add to g_ex / g_in

  4. Delayed input buffer: Store current external input x in I_stim for use in the next time step

  5. Spike time tracking: Update last_spike_time to \(t + dt\) for neurons that spiked

Numerical integration details:

  • Uses Runge-Kutta-Fehlberg 4(5) with embedded error estimation

  • Step size is adaptive based on gsl_error_tol

  • Minimum step size: _MIN_H = 1e-8 ms

  • Maximum iterations: _MAX_ITERS = 100000 per simulation step

  • Step size is persistent across time steps (stored in integration_step)

Failure modes:

  • Raises ValueError if the membrane potential drops below -1000 mV or the magnitude of the adaptation current exceeds 1e6 pA, indicating numerical instability (typically from bad parameters or extreme inputs)

  • Does NOT raise an error if the maximum number of iterations is exceeded; instead, integration completes with accumulated error (silent degradation)

Computational cost:

  • Per-neuron scalar integration (no vectorization across neurons)

  • Cost grows as gsl_error_tol shrinks (smaller tolerance = more substeps)

  • Typical: 1-10 substeps per simulation step for standard parameters