spike_train_injector

class brainpy.state.spike_train_injector(in_size=1, spike_times=(), spike_multiplicities=(), precise_times=False, allow_offgrid_times=False, shift_now_spikes=False, start=Quantity(0., 'ms'), stop=None, origin=Quantity(0., 'ms'), name=None)

Spike train injector – NEST-compatible event source device.

Emit deterministic spike events at configured times with optional per-time multiplicity, then gate output by a half-open activity window. Unlike spike_generator, which selects the last matching weight, this device accumulates all multiplicities that match the current step, making it suitable for injecting pre-recorded spike trains where multiple events may be scheduled at the same simulation time.

1. Model equations

Let \(\{t_i\}_{i=1}^{K}\) be configured spike times in ms after conversion from unitful or unitless inputs. Let \(m_i\) denote multiplicity (spike_multiplicities) when provided, otherwise \(m_i = 1\). At simulation time \(t\) with step \(\Delta t\) (both in ms), define the matching indicator

\[q_i(t) = \mathbf{1}\!\left[|t - t_i| < \frac{\Delta t}{2}\right].\]

The scalar emitted spike count before window gating is

\[a(t) = \sum_{i=1}^{K} m_i\, q_i(t).\]

The activity gate is

\[g(t) = \mathbf{1}\!\left[t \ge t_0 + t_{\mathrm{start,rel}}\right] \cdot \mathbf{1}\!\left[t < t_0 + t_{\mathrm{stop,rel}}\right],\]

where the second indicator is omitted when stop is None. The returned output is broadcast to node shape self.varshape:

\[y(t) = g(t)\,a(t)\,\mathbf{1}_{\mathrm{varshape}}.\]
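The equations above can be evaluated directly. The following NumPy sketch (function and variable names are illustrative, not the library's API) reproduces \(g(t)\,a(t)\) for a scalar node:

```python
# Minimal sketch of the model equations above, assuming all times in ms.
# Names (emitted_count, spike_times_ms, mult, ...) are illustrative only.
import numpy as np

def emitted_count(t, dt, spike_times_ms, mult, t0, start_rel, stop_rel=None):
    """Return g(t) * a(t) for a scalar simulation time t."""
    q = np.abs(t - spike_times_ms) < dt / 2      # matching indicator q_i(t)
    a = np.sum(mult * q)                         # accumulated multiplicity a(t)
    g = t >= t0 + start_rel                      # inclusive lower window bound
    if stop_rel is not None:
        g = g and (t < t0 + stop_rel)            # exclusive upper window bound
    return a if g else 0

times = np.array([1.0, 2.0, 2.0])
mult = np.array([1, 2, 3])
emitted_count(2.0, 0.1, times, mult, t0=0.0, start_rel=0.0, stop_rel=5.0)  # -> 5
```

Broadcasting the scalar result to self.varshape then yields \(y(t)\).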

2. Timing derivation, assumptions, and constraints

The \(|t - t_i| < \Delta t / 2\) rule corresponds to nearest-grid assignment under uniform-step simulation. For exact half-step offsets, strict inequality means no match at that boundary. If multiple spike_times entries map to the same step, their multiplicities are summed, giving \(a(t) > 1\) for bursts.
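The boundary behavior of the strict inequality can be checked with exactly representable binary fractions, so the comparison is float-exact (illustrative values, not the device's defaults):

```python
# Strict-inequality boundary: with dt = 0.5 ms, a spike at 1.0 ms tested at
# t = 1.25 ms sits exactly dt/2 away and does NOT match; a point strictly
# inside the open interval does.  All values are exact binary fractions.
dt = 0.5
t_i = 1.0
on_boundary = abs(1.25 - t_i) < dt / 2    # 0.25 < 0.25 -> False: no match
inside = abs(1.125 - t_i) < dt / 2        # 0.125 < 0.25 -> True: match
```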

Enforced constraints:

  • spike_times must be non-descending after conversion.

  • spike_multiplicities must be empty or satisfy len(spike_multiplicities) == len(spike_times).

  • precise_times=True cannot be combined with allow_offgrid_times=True or shift_now_spikes=True.

Implementation-specific constraints:

  • NEST option flags precise_times, allow_offgrid_times, and shift_now_spikes are accepted for API compatibility but the current update rule always uses the fixed tolerance test above regardless of their values.

  • NEST documentation states spikes should be strictly in the future. This implementation does not perform explicit future-time validation in __init__() and instead relies on runtime matching combined with active-window gating.

3. Computational implications

Each update() call uses u.math.searchsorted() to locate the open interval \((t - \Delta t/2,\, t + \Delta t/2)\) in the sorted spike_times array, yielding a range \([\textit{idx\_lo}, \textit{idx\_hi})\) of matching indices. A Boolean mask over \(\{0,\ldots,K-1\}\) is then used to sum the multiplicities of all matching entries. Per-call complexity is \(O(\log K + K + \prod \mathrm{varshape})\). The update() method is fully compatible with jax.jit: no Python control flow branches on traced values.
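The interval-location scheme can be sketched with plain NumPy (this mirrors the description above, not the library's internals):

```python
# Sketch of the searchsorted-based matching: locate the open interval
# (t - dt/2, t + dt/2) in the sorted spike times, then mask-sum the
# multiplicities over the half-open index range [idx_lo, idx_hi).
import numpy as np

times = np.array([1.0, 2.0, 2.0, 3.5])   # sorted spike times in ms
mult = np.array([1, 2, 3, 1])
t, dt = 2.0, 0.1

idx_lo = np.searchsorted(times, t - dt / 2, side='right')  # first index > lower bound
idx_hi = np.searchsorted(times, t + dt / 2, side='left')   # first index >= upper bound
idx = np.arange(len(times))
mask = (idx >= idx_lo) & (idx < idx_hi)
a = int(np.sum(mult * mask))             # both 2.0 ms entries match: a = 2 + 3
```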

Parameters:
  • in_size (Size, optional) – Output size/shape consumed by brainstate.nn.Dynamics. The emitted array has shape self.varshape derived from in_size. Default is 1.

  • spike_times (Sequence, optional) – Sequence of spike times with length K. Entries may be unitful times (typically saiunit ms quantities) or bare numerics interpreted as ms. Passed directly to u.math.asarray(), which validates unit consistency across all entries. Must be non-descending. Duplicate times are allowed and their multiplicities are accumulated. Default is ().

  • spike_multiplicities (Sequence, optional) – Sequence of integer multiplicities with length K matching spike_times, or empty to use implicit unit multiplicities (\(m_i = 1\)). Entries are converted with int(m) and stored as a dimensionless JAX array; accumulated across all indices matching the same step. Default is ().

  • precise_times (bool, optional) – NEST compatibility flag for sub-step precise timing. Stored and validated against allow_offgrid_times / shift_now_spikes but not used to alter runtime matching in this implementation. Default is False.

  • allow_offgrid_times (bool, optional) – NEST compatibility flag permitting off-grid spike times. Stored and validated but not used to alter runtime matching in this implementation. Default is False.

  • shift_now_spikes (bool, optional) – NEST compatibility flag for shifting spikes that would fire at the current step to the next. Stored and validated but not used to alter runtime matching in this implementation. Default is False.

  • start (ArrayLike, optional) – Relative activation time \(t_{\mathrm{start,rel}}\) (typically ms), initialized via braintools.init.param(). The effective inclusive lower bound of the active window is origin + start. Default is 0. * u.ms.

  • stop (ArrayLike or None, optional) – Relative deactivation time \(t_{\mathrm{stop,rel}}\) (typically ms), initialized via braintools.init.param() when not None. The effective exclusive upper bound is origin + stop. None disables the upper bound. Default is None.

  • origin (ArrayLike, optional) – Global time origin \(t_0\) (typically ms) added to both start and stop to obtain absolute window bounds. Default is 0. * u.ms.

  • name (str or None, optional) – Optional node name forwarded to brainstate.nn.Dynamics.

Parameter Mapping

Table 29 Parameter mapping to model symbols

  • spike_times (default ()): \(t_i\). Spike schedule; matched by \(|t - t_i| < \Delta t/2\).

  • spike_multiplicities (default ()): \(m_i\). Per-time spike count; empty means implicit \(m_i = 1\).

  • start (default 0. * u.ms): \(t_{\mathrm{start,rel}}\). Relative inclusive lower bound of the active window.

  • stop (default None): \(t_{\mathrm{stop,rel}}\). Relative exclusive upper bound; None means unbounded.

  • origin (default 0. * u.ms): \(t_0\). Global offset applied to start and stop.

Raises:
  • ValueError – If precise_times=True is combined with allow_offgrid_times=True or shift_now_spikes=True, if spike_times is not non-descending after conversion, or if spike_multiplicities is non-empty and has a different length than spike_times.

  • TypeError – If u.math.asarray() detects unit inconsistency across entries, or if unitful/unitless arithmetic is invalid during time-window comparisons.

  • KeyError – At update time, if required simulation context entries (e.g. 't' or dt) are absent from brainstate.environ.

Notes

This device does not accept incoming synaptic or current connections; it only emits scheduled events. The output is dimensionless (spike count per step) and is typically consumed by a downstream synapse model that scales by connection weight.

The key behavioral difference from spike_generator is accumulation: when two entries in spike_times round to the same step, spike_train_injector sums their multiplicities while spike_generator retains only the last matching weight. Use spike_train_injector when replaying recorded spike trains that may contain bursts, and spike_generator when a single weighted event per step is intended.
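The two semantics can be contrasted with a hypothetical event list in plain Python (not either device's actual code):

```python
# Two events land on the same step: spike_train_injector accumulates their
# multiplicities, while spike_generator keeps only the last matching value.
t, dt = 2.0, 0.1
events = [(1.0, 1), (2.0, 2), (2.0, 3)]   # (time_ms, multiplicity) pairs
matches = [m for t_i, m in events if abs(t - t_i) < dt / 2]
injector_out = sum(matches)      # accumulation semantics: 2 + 3
generator_out = matches[-1]      # last-match semantics: 3
```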

Spike times should ideally be aligned to the simulation grid (multiples of dt) to avoid off-by-one steps. The tolerance dt/2 covers one-ULP rounding for grid-aligned times in typical float64 arithmetic.

See also

spike_generator

Deterministic spike device with per-spike weights (last-match semantics).

dc_generator

Constant-current stimulation device.

ac_generator

Sinusoidal current stimulation device.

step_current_generator

Piecewise-constant current stimulation device.

Examples

Inject a burst of five spikes at t = 2 ms (two entries map to the same step, multiplicities are accumulated to give a = 2 + 3 = 5):

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     inj = brainpy.state.spike_train_injector(
...         spike_times=[1.0 * u.ms, 2.0 * u.ms, 2.0 * u.ms],
...         spike_multiplicities=[1, 2, 3],
...         start=0.0 * u.ms,
...         stop=5.0 * u.ms,
...     )
...     with brainstate.environ.context(t=2.0 * u.ms):
...         out = inj.update()
...     _ = out.shape

Inject a single spike at t = 10 ms using NEST’s precise_times flag for API compatibility (sub-step resolution not enforced here):

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
...     inj = brainpy.state.spike_train_injector(
...         spike_times=[10.0 * u.ms],
...         precise_times=True,
...     )
...     with brainstate.environ.context(t=10.0 * u.ms):
...         out = inj.update()
...     _ = out.shape
update()

Compute the accumulated spike output for the current simulation step.

The implementation is fully compatible with jax.jit: spike-time matching uses u.math.searchsorted() on the static spike_times array while t and dt remain traced values throughout. The multiplicity sum uses a Boolean mask with no Python branching over traced values.

Returns:

out – Float-valued JAX array with shape self.varshape. Output semantics:

  • 0 when outside [origin + start, origin + stop) (or [origin + start, +inf) if stop is None),

  • 0 when active but no configured spike satisfies |t - t_i| < dt/2,

  • accumulated integer multiplicity \(a(t) = \sum_i m_i\, q_i(t)\) when active and one or more spikes match.

Return type:

jax.Array

Raises:

KeyError – If required simulation context entries are missing from brainstate.environ (e.g. 't' or dt).

Notes

Matching uses the open interval \((t - \Delta t/2,\, t + \Delta t/2)\) located via two u.math.searchsorted() calls:

  • idx_lo = searchsorted(times, t - dt/2, side='right') — first index strictly greater than the lower bound.

  • idx_hi = searchsorted(times, t + dt/2, side='left') — first index at or above the upper bound.

A Boolean mask over indices in [idx_lo, idx_hi) selects all matching entries; their multiplicities (or 1s if none configured) are summed to obtain the scalar count \(a(t)\). Start is inclusive and stop is exclusive, matching NEST semantics.

Unlike spike_generator.update(), which keeps only the last matching weight, this method accumulates all matching multiplicities. A burst of three spikes scheduled at the same time thus returns 3 (or the sum of their individual multiplicities).
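A minimal jit-compatible sketch of this masking scheme, assuming JAX and treating the spike-time and multiplicity arrays as module-level constants (names are illustrative, not the method's actual code):

```python
# The index range from two searchsorted calls is turned into a mask-sum,
# so t and dt can stay traced with no Python branching on traced values.
import jax
import jax.numpy as jnp

times = jnp.array([1.0, 2.0, 2.0])   # sorted spike times in ms
mult = jnp.array([1.0, 2.0, 3.0])

@jax.jit
def count(t, dt):
    lo = jnp.searchsorted(times, t - dt / 2, side='right')
    hi = jnp.searchsorted(times, t + dt / 2, side='left')
    idx = jnp.arange(times.shape[0])
    return jnp.sum(jnp.where((idx >= lo) & (idx < hi), mult, 0.0))

count(2.0, 0.1)   # -> 5.0: both 2.0 ms entries match and accumulate
```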

See also

spike_train_injector

Class-level parameter definitions and equations.

spike_generator.update

Weight-selection (last-match) update rule.

dc_generator.update

Windowed constant-current update rule.

step_current_generator.update

Windowed piecewise-constant update rule.