iaf_psc_exp_ps

class brainpy.state.iaf_psc_exp_ps(in_size, E_L=Quantity(-70., "mV"), C_m=Quantity(250., "pF"), tau_m=Quantity(10., "ms"), t_ref=Quantity(2., "ms"), V_th=Quantity(-55., "mV"), V_reset=Quantity(-70., "mV"), tau_syn_ex=Quantity(2., "ms"), tau_syn_in=Quantity(2., "ms"), I_e=Quantity(0., "pA"), V_min=None, V_initializer=Constant(value=-70. mV), spk_fun=ReluGrad(alpha=0.3, width=1.0), spk_reset='hard', ref_var=False, name=None)

NEST-compatible iaf_psc_exp_ps with precise spike times.

Description

iaf_psc_exp_ps is a current-based leaky integrate-and-fire neuron with exponential excitatory/inhibitory PSC states and off-grid event/spike timing. The implementation follows NEST models/iaf_psc_exp_ps.{h,cpp} semantics: within-step event ordering by precise offsets, exact closed-form mini-step propagation, sub-step threshold localization by root search, and refractory release modeled as an explicit pseudo-event.

1. Continuous-time dynamics and exact integration

Let \(U = V_m - E_L\) be the membrane potential relative to rest (mV), \(I_{ex}\) and \(I_{in}\) the excitatory and inhibitory PSC states (pA), and \(y_0\) the one-step buffered continuous input current (pA). The subthreshold dynamics are

\[\frac{dU}{dt} = -\frac{U}{\tau_m} + \frac{I_e + y_0 + I_{ex} + I_{in}}{C_m},\]
\[\frac{dI_{ex}}{dt} = -\frac{I_{ex}}{\tau_{syn,ex}}, \qquad \frac{dI_{in}}{dt} = -\frac{I_{in}}{\tau_{syn,in}}.\]

Over a mini-interval \(\Delta t\), exact integration gives

\[U(t+\Delta t) = P_{20}(\Delta t)\,(I_e+y_0) + P_{21,ex}(\Delta t)\,I_{ex}(t) + P_{21,in}(\Delta t)\,I_{in}(t) + U(t)e^{-\Delta t/\tau_m},\]

where \(P_{20}=-\frac{\tau_m}{C_m}\left(e^{-\Delta t/\tau_m}-1\right)\) and each \(P_{21,X}\), \(X \in \{ex, in\}\), is evaluated by propagator_exp() (from _utils). The PSC states decay exactly via \(I_X(t+\Delta t)=I_X(t)e^{-\Delta t/\tau_{syn,X}}\).
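The closed-form update can be illustrated with a minimal scalar sketch. It assumes plain floats in ms, pF, pA and mV instead of unit-carrying quantities, and it uses the textbook expression for \(P_{21}\) obtained by integrating the equations above; the helper names p20, p21 and propagate are hypothetical, and the library's propagator_exp() may handle the near-singular case \(\tau_{syn} \approx \tau_m\) differently.

```python
import math

def p20(dt, tau_m, c_m):
    """Constant-current propagator: -(tau_m / C_m) * (exp(-dt / tau_m) - 1)."""
    return -(tau_m / c_m) * math.expm1(-dt / tau_m)

def p21(dt, tau_m, tau_syn, c_m):
    """Textbook PSC-to-voltage propagator (not the library's propagator_exp)."""
    if math.isclose(tau_m, tau_syn, rel_tol=1e-9):
        # Limit tau_syn -> tau_m of the general expression below.
        return dt * math.exp(-dt / tau_m) / c_m
    scale = tau_m * tau_syn / (c_m * (tau_syn - tau_m))
    return scale * (math.exp(-dt / tau_syn) - math.exp(-dt / tau_m))

def propagate(u, i_ex, i_in, dt, i_const=0.0,
              tau_m=10.0, c_m=250.0, tau_ex=2.0, tau_in=2.0):
    """Advance (U, I_ex, I_in) by dt; i_const stands for I_e + y_0."""
    u_new = (p20(dt, tau_m, c_m) * i_const
             + p21(dt, tau_m, tau_ex, c_m) * i_ex
             + p21(dt, tau_m, tau_in, c_m) * i_in
             + u * math.exp(-dt / tau_m))
    return u_new, i_ex * math.exp(-dt / tau_ex), i_in * math.exp(-dt / tau_in)

# One 0.1 ms step versus the same interval split at 0.04 ms.
one_step = propagate(1.0, 100.0, -50.0, 0.1, i_const=200.0)
mid = propagate(1.0, 100.0, -50.0, 0.04, i_const=200.0)
two_steps = propagate(*mid, 0.06, i_const=200.0)
```

Because the propagation is exact, one_step and two_steps agree up to floating-point rounding. This is the property that allows a global step to be cut into mini-intervals at arbitrary event offsets without introducing discretization error.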

2. Precise-time event processing

Event offsets use NEST convention: offset=dt at step start and offset=0 at step end. For each global step:

  1. Build local event list from spike_events and on-grid delta input (always added at offset=0).

  2. Sort events in descending offset and split the step into mini-intervals.

  3. Propagate exactly on each mini-interval.

  4. If \(U\) reaches threshold, solve \(f(\delta)=U(\delta)-U_{th}=0\) with bounded bisection (64 iterations) to obtain off-grid spike time.

  5. Reset to V_reset and enter refractory state; release from refractory occurs through a pseudo-event when step_idx + 1 - last_spike_step == ceil(t_ref / dt).
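Steps 2 and 5 above can be sketched in a few lines. This is an illustrative reading of the description, not the library's code: split_step and is_release_step are hypothetical helpers, and offsets and dt are plain floats in ms.

```python
def split_step(dt, offsets):
    """Order events and cut one global step into mini-intervals.

    Offsets are measured from the right edge of the step, so the largest
    offset belongs to the earliest event.
    """
    ordered = sorted(offsets, reverse=True)
    edges = [dt] + ordered + [0.0]
    intervals = [a - b for a, b in zip(edges[:-1], edges[1:])]
    return ordered, intervals

def is_release_step(step_idx, last_spike_step, ref_steps):
    """Release condition quoted above, with ref_steps = ceil(t_ref / dt)."""
    return step_idx + 1 - last_spike_step == ref_steps

# Two events inside a 0.1 ms step, given in arbitrary order.
ordered, intervals = split_step(0.1, [0.03, 0.08])

# A spike in step 5 with 20 refractory steps is released in step 24.
released = is_release_step(24, 5, 20)
```

The mini-interval lengths sum to dt. An event at offset dt or 0, or two events at the same offset, simply produces a zero-length interval in this sketch.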

3. Assumptions, constraints, and computational complexity

  • Parameters are scalar or broadcastable to self.varshape.

  • Construction-time constraints enforce V_reset < V_th, C_m > 0, tau_m > 0, tau_syn_ex > 0, tau_syn_in > 0, and, when V_min is provided, V_reset >= V_min.

  • Runtime requires ceil(t_ref / dt) >= 1.

  • All precise offsets must satisfy 0 <= offset <= dt.

  • Continuous input x is buffered (stored into y0 for the next global step), matching NEST current-event timing.

  • Per-step complexity is \(O(|\mathrm{state}| \cdot K)\) for K local events, plus root search cost on threshold-crossing mini-intervals.
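The two runtime constraints in this list can be expressed as simple checks. The sketch below assumes plain floats in ms; refractory_steps and check_offsets are hypothetical names, and the error messages are illustrative only.

```python
import math

def refractory_steps(t_ref, dt):
    """Convert t_ref to whole steps; at least one step is required."""
    steps = math.ceil(t_ref / dt)
    if steps < 1:
        raise ValueError("ceil(t_ref / dt) must be >= 1")
    return steps

def check_offsets(offsets, dt):
    """Every precise offset must lie in the closed interval [0, dt]."""
    for off in offsets:
        if not 0.0 <= off <= dt:
            raise ValueError(f"offset {off} outside [0, {dt}]")

steps = refractory_steps(2.0, 0.5)    # 4 steps
check_offsets([0.0, 0.25, 0.5], 0.5)  # passes silently
```

Note that a refractory time that is not a multiple of dt is rounded up, so t_ref = 1.2 ms with dt = 0.5 ms gives 3 steps.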

Parameters:
  • in_size (Size) – Population shape specification. Model parameters and states are broadcast to self.varshape derived from in_size.

  • E_L (ArrayLike, optional) – Resting potential \(E_L\) in mV, broadcastable to self.varshape. Default is -70. * u.mV.

  • C_m (ArrayLike, optional) – Membrane capacitance \(C_m\) in pF, broadcastable to self.varshape. Must be strictly positive elementwise. Default is 250. * u.pF.

  • tau_m (ArrayLike, optional) – Membrane time constant \(\tau_m\) in ms, broadcastable to self.varshape. Must be strictly positive elementwise. Default is 10. * u.ms.

  • t_ref (ArrayLike, optional) – Absolute refractory duration \(t_{ref}\) in ms, broadcastable to self.varshape. Converted at runtime to steps using ceil(t_ref / dt) and must produce at least one step. Default is 2. * u.ms.

  • V_th (ArrayLike, optional) – Threshold voltage \(V_{th}\) in mV, broadcastable to self.varshape. Default is -55. * u.mV.

  • V_reset (ArrayLike, optional) – Reset voltage \(V_{reset}\) in mV, broadcastable to self.varshape. Must satisfy V_reset < V_th elementwise. Default is -70. * u.mV.

  • tau_syn_ex (ArrayLike, optional) – Excitatory PSC decay constant \(\tau_{syn,ex}\) in ms, broadcastable to self.varshape and strictly positive. Default is 2. * u.ms.

  • tau_syn_in (ArrayLike, optional) – Inhibitory PSC decay constant \(\tau_{syn,in}\) in ms, broadcastable to self.varshape and strictly positive. Default is 2. * u.ms.

  • I_e (ArrayLike, optional) – Constant external current \(I_e\) in pA, broadcastable to self.varshape. Added in each mini-step propagation. Default is 0. * u.pA.

  • V_min (ArrayLike or None, optional) – Optional lower bound \(V_{min}\) in mV, broadcastable to self.varshape. If None, no lower clip is applied. Default is None.

  • V_initializer (Callable, optional) – Initializer used by init_state() for membrane state V. Must return mV-compatible values with shape compatible with self.varshape (and optional batch prefix). Default is braintools.init.Constant(-70. * u.mV).

  • spk_fun (Callable, optional) – Surrogate spike function used by get_spike() and update(). Receives normalized threshold distance tensor. Default is braintools.surrogate.ReluGrad().

  • spk_reset (str, optional) – Reset policy forwarded to Neuron. 'hard' matches NEST hard reset. Default is 'hard'.

  • ref_var (bool, optional) – If True, creates exposed self.refractory mirroring self.is_refractory for inspection. Default is False.

  • name (str or None, optional) – Optional node name.

Parameter Mapping

Table 9 Parameter mapping to model symbols

| Parameter | Type / shape / unit | Default | Math symbol | Semantics |
|---|---|---|---|---|
| in_size | Size; scalar/tuple | required | | Defines self.varshape for parameter/state broadcasting. |
| E_L | ArrayLike, broadcastable to self.varshape (mV) | -70. * u.mV | \(E_L\) | Resting potential and voltage-offset origin. |
| C_m | ArrayLike, broadcastable (pF), > 0 | 250. * u.pF | \(C_m\) | Converts current terms to membrane-rate contribution. |
| tau_m | ArrayLike, broadcastable (ms), > 0 | 10. * u.ms | \(\tau_m\) | Leak time constant in exact subthreshold propagation. |
| t_ref | ArrayLike, broadcastable (ms), runtime ceil(t_ref/dt) >= 1 | 2. * u.ms | \(t_{ref}\) | Absolute refractory duration. |
| V_th and V_reset | ArrayLike, broadcastable (mV), with V_reset < V_th | -55. * u.mV, -70. * u.mV | \(V_{th}\), \(V_{reset}\) | Threshold and post-spike reset levels. |
| tau_syn_ex and tau_syn_in | ArrayLike, broadcastable (ms), each > 0 | 2. * u.ms | \(\tau_{syn,ex}\), \(\tau_{syn,in}\) | Exponential PSC decay constants. |
| I_e | ArrayLike, broadcastable (pA) | 0. * u.pA | \(I_e\) | Constant injected current added every mini-step. |
| V_min | ArrayLike broadcastable (mV) or None | None | \(V_{min}\) | Optional lower clamp applied after membrane propagation. |
| V_initializer | Callable returning mV-compatible values | Constant(-70. * u.mV) | | Initializes membrane state V. |
| spk_fun | Callable | ReluGrad() | | Surrogate spike output nonlinearity. |
| spk_reset | str | 'hard' | | Reset mode inherited from base Neuron. |
| ref_var | bool | False | | Allocate exposed refractory mirror state. |
| name | str or None | None | | Optional node name. |

Raises:
  • ValueError – If validated constraints fail (for example V_reset >= V_th, non-positive capacitance/time constants, V_reset < V_min, ceil(t_ref / dt) < 1, or event offsets outside [0, dt]).

  • TypeError – If provided arguments are incompatible with expected units/callables (mV, pA, pF, ms).

  • KeyError – If simulation context values t and/or dt are missing when update() is called.

  • AttributeError – If update() is called before init_state() creates required runtime states.

V

Membrane potential state in mV.

Type:

HiddenState

I_syn_ex

Excitatory PSC state in pA.

Type:

ShortTermState

I_syn_in

Inhibitory PSC state in pA.

Type:

ShortTermState

y0

One-step buffered continuous current in pA.

Type:

ShortTermState

is_refractory

Boolean refractory mask.

Type:

ShortTermState

last_spike_step

Step index of latest emitted spike.

Type:

ShortTermState

last_spike_offset

Precise offset (ms) from right step boundary for latest spike.

Type:

ShortTermState

last_spike_time

Absolute precise spike time in ms.

Type:

ShortTermState

refractory

Optional mirror of is_refractory when ref_var=True.

Type:

ShortTermState

Notes

  • spike_events accepts (offset, weight) tuples or {'offset': ..., 'weight': ...} dicts.

  • Offsets are in ms and measured from the right edge of the current step.

  • Positive event weights contribute to excitatory PSC state; negative weights contribute to inhibitory PSC state.

  • Internal propagation and root finding are evaluated in NumPy float64 and written back into BrainUnit states at end of step.
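The two accepted event forms and the sign-based routing can be sketched as follows. The helpers normalize_event and route_weights are hypothetical, and the sketch uses plain floats (ms for offsets, pA for weights) where the real API takes unit-carrying quantities.

```python
def normalize_event(event):
    """Accept (offset, weight) tuples or {'offset': ..., 'weight': ...} dicts."""
    if isinstance(event, dict):
        return float(event["offset"]), float(event["weight"])
    offset, weight = event
    return float(offset), float(weight)

def route_weights(events):
    """Sum event weights into (excitatory, inhibitory) PSC increments."""
    d_ex = 0.0
    d_in = 0.0
    for event in events:
        _, weight = normalize_event(event)
        if weight > 0.0:
            d_ex += weight   # positive weights feed the excitatory state
        elif weight < 0.0:
            d_in += weight   # negative weights feed the inhibitory state
    return d_ex, d_in

events = [(0.08, 120.0), {"offset": 0.02, "weight": -30.0}, (0.0, 5.0)]
d_ex, d_in = route_weights(events)
```

For brevity this sums all weights of a step at once. In the model itself each event is applied at its own offset, after the states have been propagated to that point.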

Examples

>>> # Constant-current drive: advance a two-neuron population by one step.
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     neu = brainpy.state.iaf_psc_exp_ps(in_size=2, I_e=200.0 * u.pA)
...     neu.init_state()
...     with brainstate.environ.context(t=0.0 * u.ms):
...         spk = neu.update()
...     _ = spk.shape

>>> # Off-grid input: one excitatory event 0.08 ms before the end of the step.
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     neu = brainpy.state.iaf_psc_exp_ps(in_size=1)
...     neu.init_state()
...     ev = [{'offset': 0.08 * u.ms, 'weight': 120.0 * u.pA}]
...     with brainstate.environ.context(t=0.0 * u.ms):
...         _ = neu.update(spike_events=ev)


get_spike(V=None)

Evaluate surrogate spike output from membrane potential.

Applies the surrogate spike function (typically braintools.surrogate.ReluGrad or similar) to a normalized threshold-distance metric. This enables differentiable spike generation for gradient-based learning while maintaining biological spike semantics.

The normalized threshold distance is computed as \((V - V_{th}) / (V_{th} - V_{reset})\), which maps \(V_{reset}\) to -1 and \(V_{th}\) to 0, so the voltage range between reset and threshold corresponds to [-1, 0]. Voltages above threshold give positive values, which the surrogate function turns into positive outputs.
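As a quick numeric check of the formula \((V - V_{th}) / (V_{th} - V_{reset})\), the sketch below evaluates it with the default V_th = -55 mV and V_reset = -70 mV. It uses plain floats in mV, and normalized_distance is a hypothetical helper rather than part of the class.

```python
def normalized_distance(v, v_th=-55.0, v_reset=-70.0):
    """(V - V_th) / (V_th - V_reset), all voltages in mV."""
    return (v - v_th) / (v_th - v_reset)

at_reset = normalized_distance(-70.0)      # -1.0
at_threshold = normalized_distance(-55.0)  # 0.0
above = normalized_distance(-52.0)         # positive, 3 mV above threshold
```

With these defaults one unit of normalized distance corresponds to 15 mV, so the surrogate function receives a dimensionless argument whose sign tells whether the voltage is above or below threshold.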

Parameters:

V (ArrayLike or None, optional) – Voltage tensor in mV, broadcast-compatible with self.varshape (or current batched state shape). If None, uses self.V.value. Default is None.

Returns:

out – Output of self.spk_fun applied to normalized threshold distance (V - V_th) / (V_th - V_reset) with same shape as input V. Typically float values in [0, 1] or similar range depending on the surrogate function’s output characteristics.

Return type:

jax.Array

Raises:
  • TypeError – If V is not compatible with unit arithmetic in mV or if unit conversion operations fail.

  • AttributeError – If self.spk_fun is not callable or if required parameters (V_th, V_reset) are not available.

init_state(**kwargs)

Initialize membrane, synaptic, and precise-timing runtime states.

This method allocates all internal state variables required for precise spike-time simulation. Membrane potential V is initialized using self.V_initializer, synaptic currents and buffered inputs are initialized to zero, and spike-tracking states are initialized to sentinel values (last_spike_step = -1, last_spike_time = -1e7 ms) indicating no prior spike events.

Parameters:

**kwargs (Any) – Unused compatibility arguments for subclass extension.

Raises:
  • ValueError – If initializer outputs cannot be broadcast to state shape self.varshape or if shapes are incompatible.

  • TypeError – If initializer outputs are not unit-compatible with expected state units (mV for voltage, pA for currents, ms for time, bool for flags).

  • AttributeError – If self.V_initializer is not callable or does not produce valid output for the requested shape.

update(x=Quantity(0., 'pA'), spike_events=None)

Advance one global step with precise within-step event handling.

This method implements the complete NEST-compatible precise-spike-time algorithm for iaf_psc_exp_ps. Each global time step is subdivided into mini-intervals determined by spike event offsets. Within each mini-interval, membrane potential and synaptic currents are propagated exactly using closed-form exponential solutions. When the membrane potential crosses threshold, bisection root-finding (64 iterations) localizes the precise sub-step spike time.

Update sequence:

  1. Parse and validate spike_events and on-grid delta inputs.

  2. Sort events in descending offset (from step start to step end).

  3. For each neuron, process events sequentially:

    1. Propagate states exactly over each mini-interval.

    2. Apply event weights to PSC states (ex/in channels by sign).

    3. Check for threshold crossing and localize spike time if needed.

    4. Apply hard reset and enter refractory state on spike.

    5. Release from refractory via pseudo-event at calculated step.

  4. Buffer incoming current x into y0 for next step.

  5. Compute surrogate spike output for gradient-based learning.

Implementation notes:

  • All propagation uses NumPy float64 for numerical stability.

  • Event offsets follow NEST convention: offset=dt at step start, offset=0 at step end.

  • Refractory neurons clamp membrane potential but allow PSC decay.

  • Root finding uses bounded bisection over [0, dt] with 64 iterations.
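The bounded bisection mentioned above can be sketched generically. The example assumes a hypothetical single-exponential trajectory for \(U\) (no synaptic input) so that the result can be compared with an analytic crossing time; bisect_crossing and gap are illustrative names, and the bracket conventions of the actual update() may differ.

```python
import math

def bisect_crossing(f, lo, hi, iterations=64):
    """Locate the sign change of f in [lo, hi], assuming f(lo) < 0 <= f(hi)."""
    for _ in range(iterations):
        mid = 0.5 * (lo + hi)
        if f(mid) >= 0.0:
            hi = mid   # crossing is at or before mid
        else:
            lo = mid   # crossing is after mid
    return hi

# Hypothetical trajectory: U relaxes from 14.99 mV towards 20 mV,
# with the threshold at U_th = 15 mV and tau_m = 10 ms.
tau_m, u0, u_inf, u_th = 10.0, 14.99, 20.0, 15.0

def gap(delta):
    """f(delta) = U(delta) - U_th for the trajectory above."""
    return u_inf + (u0 - u_inf) * math.exp(-delta / tau_m) - u_th

delta_star = bisect_crossing(gap, 0.0, 0.1)
analytic = tau_m * math.log((u_inf - u0) / (u_inf - u_th))
```

Sixty-four halvings of a 0.1 ms bracket shrink it far below double-precision resolution, so a fixed iteration count behaves like a convergence guarantee without needing a tolerance parameter.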

Parameters:
  • x (ArrayLike, optional) – Continuous current input in pA for the current global step. Aggregated through sum_current_inputs() and stored in y0 for use in the next step (one-step buffering). Scalar or array-like broadcastable to self.V.value.shape. Default is 0. * u.pA.

  • spike_events (Iterable[tuple[Any, Any] | dict[str, Any]] or None, optional) – Optional off-grid events inside the current step. Each entry is (offset, weight) or {'offset': ..., 'weight': ...}, where offset is in ms measured from the right step boundary and weight is in pA. Offsets must satisfy 0 <= offset <= dt. Positive weights update excitatory PSC; negative weights update inhibitory PSC. None means no extra precise events. On-grid delta inputs are automatically included at offset=0. Default is None.

Returns:

out – Surrogate spike output from get_spike(), shape self.V.value.shape. Values correspond to self.spk_fun((V - V_th) / (V_th - V_reset)) after exact piecewise propagation, event application, refractory logic, and precise spike-time localization. For neurons that spiked, the voltage is clamped slightly above threshold to ensure differentiable spike detection; for non-spiking neurons, voltage is clamped below threshold.

Return type:

jax.Array

Raises:
  • ValueError – If ceil(t_ref / dt) < 1 (refractory period too short for time step), or if any event offset lies outside [0, dt], or if parameter constraints are violated at runtime.

  • KeyError – If simulation context values t (current time) or dt (time step) are unavailable from brainstate.environ.

  • TypeError – If x or spike_events entries are not unit-compatible with pA/ms conversions, or if type conversions fail during numerical computation.

  • AttributeError – If required runtime states (V, I_syn_ex, I_syn_in, y0, is_refractory, etc.) are missing because init_state() has not been called.