# Source code for brainpy_state._nest.spike_train_injector

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Sequence

import brainstate
import braintools
import saiunit as u
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Public API of this module: only the device class is exported.
__all__ = ['spike_train_injector']


class spike_train_injector(NESTDevice):
    r"""Spike train injector -- NEST-compatible event source device.

    Emit deterministic spike events at configured times with optional
    per-time multiplicity, then gate output by a half-open activity window.
    Unlike :class:`spike_generator`, which selects the last matching weight,
    this device *accumulates* all multiplicities that match the current step,
    making it suitable for injecting pre-recorded spike trains where multiple
    events may be scheduled at the same simulation time.

    **1. Model equations**

    Let :math:`\{t_i\}_{i=1}^{K}` be configured spike times in ms after
    conversion from unitful or unitless inputs. Let :math:`m_i` denote
    multiplicity (``spike_multiplicities``) when provided, otherwise
    :math:`m_i = 1`. At simulation time :math:`t` with step :math:`\Delta t`
    (both in ms), define the matching indicator

    .. math::

        q_i(t) = \mathbf{1}\!\left[|t - t_i| < \frac{\Delta t}{2}\right].

    The scalar emitted spike count before window gating is

    .. math::

        a(t) = \sum_{i=1}^{K} m_i\, q_i(t).

    The activity gate is

    .. math::

        g(t) =
        \mathbf{1}\!\left[t \ge t_0 + t_{\mathrm{start,rel}}\right]
        \cdot
        \mathbf{1}\!\left[t < t_0 + t_{\mathrm{stop,rel}}\right],

    where the second indicator is omitted when ``stop is None``.
    The returned output is broadcast to node shape ``self.varshape``:

    .. math::

        y(t) = g(t)\,a(t)\,\mathbf{1}_{\mathrm{varshape}}.

    **2. Timing derivation, assumptions, and constraints**

    The :math:`|t - t_i| < \Delta t / 2` rule corresponds to nearest-grid
    assignment under uniform-step simulation. For exact half-step offsets,
    strict inequality means no match at that boundary. If multiple
    ``spike_times`` entries map to the same step, their multiplicities are
    *summed*, giving :math:`a(t) > 1` for bursts.

    Enforced constraints:

    - ``spike_times`` must be non-descending after conversion.
    - ``spike_multiplicities`` must be empty or have exactly
      ``len(spike_multiplicities) == len(spike_times)`` elements.
    - ``precise_times=True`` cannot be combined with
      ``allow_offgrid_times=True`` or ``shift_now_spikes=True``.

    Implementation-specific constraints:

    - NEST option flags ``precise_times``, ``allow_offgrid_times``, and
      ``shift_now_spikes`` are accepted for API compatibility but the current
      update rule always uses the fixed tolerance test above regardless of
      their values.
    - NEST documentation states spikes should be strictly in the future. This
      implementation does not perform explicit future-time validation in
      :meth:`__init__` and instead relies on runtime matching combined with
      active-window gating.

    **3. Computational implications**

    Each :meth:`update` call uses :func:`u.math.searchsorted` to locate the
    open interval :math:`(t - \Delta t/2,\, t + \Delta t/2)` in the sorted
    ``spike_times`` array, yielding a range :math:`[\textit{idx\_lo},
    \textit{idx\_hi})` of matching indices. With implicit multiplicities the
    spike count is simply the width of that range. With explicit
    multiplicities, a Boolean mask over :math:`\{0,\ldots,K-1\}` is used to
    sum the multiplicities of all matching entries. Per-call complexity is
    therefore :math:`O(\log K + \prod \mathrm{varshape})` without and
    :math:`O(\log K + K + \prod \mathrm{varshape})` with explicit
    multiplicities. The :meth:`update` method is fully compatible with
    ``jax.jit``: no Python control flow branches on traced values. The
    construction-time ordering check is a single vectorised comparison.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape consumed by :class:`brainstate.nn.Dynamics`. The
        emitted array has shape ``self.varshape`` derived from ``in_size``.
        Default is ``1``.
    spike_times : Sequence, optional
        Sequence of spike times with length ``K``. Entries may be unitful
        times (typically ``saiunit`` ms quantities) or bare numerics
        interpreted as ms. Passed directly to :func:`u.math.asarray`, which
        validates unit consistency across all entries. Must be non-descending.
        Duplicate times are allowed and their multiplicities are accumulated.
        Default is ``()``.
    spike_multiplicities : Sequence, optional
        Sequence of integer multiplicities with length ``K`` matching
        ``spike_times``, or empty to use implicit unit multiplicities
        (:math:`m_i = 1`). Entries are converted with ``int(m)`` and stored
        as a dimensionless float JAX array; accumulated across all indices
        matching the same step. Default is ``()``.
    precise_times : bool, optional
        NEST compatibility flag for sub-step precise timing. Stored and
        validated against ``allow_offgrid_times`` / ``shift_now_spikes`` but
        not used to alter runtime matching in this implementation.
        Default is ``False``.
    allow_offgrid_times : bool, optional
        NEST compatibility flag permitting off-grid spike times. Stored and
        validated but not used to alter runtime matching in this
        implementation. Default is ``False``.
    shift_now_spikes : bool, optional
        NEST compatibility flag for shifting spikes that would fire at the
        current step to the next. Stored and validated but not used to alter
        runtime matching in this implementation. Default is ``False``.
    start : ArrayLike, optional
        Relative activation time :math:`t_{\mathrm{start,rel}}` (typically
        ms), initialized via :func:`braintools.init.param`. The effective
        inclusive lower bound of the active window is ``origin + start``.
        Default is ``0. * u.ms``.
    stop : ArrayLike or None, optional
        Relative deactivation time :math:`t_{\mathrm{stop,rel}}` (typically
        ms), initialized via :func:`braintools.init.param` when not ``None``.
        The effective exclusive upper bound is ``origin + stop``. ``None``
        disables the upper bound. Default is ``None``.
    origin : ArrayLike, optional
        Global time origin :math:`t_0` (typically ms) added to both ``start``
        and ``stop`` to obtain absolute window bounds. Default is ``0. * u.ms``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``spike_times``
         - ``()``
         - :math:`t_i`
         - Spike schedule; matched by ``|t - t_i| < dt/2``.
       * - ``spike_multiplicities``
         - ``()``
         - :math:`m_i`
         - Per-time spike count; empty means implicit :math:`m_i = 1`.
       * - ``start``
         - ``0. * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative inclusive lower bound of active window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative exclusive upper bound; ``None`` means unbounded.
       * - ``origin``
         - ``0. * u.ms``
         - :math:`t_0`
         - Global offset applied to ``start`` and ``stop``.

    Raises
    ------
    ValueError
        If ``precise_times=True`` is combined with ``allow_offgrid_times=True``
        or ``shift_now_spikes=True``, if ``spike_times`` is not non-descending
        after conversion, or if ``spike_multiplicities`` is non-empty and has
        a different length than ``spike_times``.
    TypeError
        If :func:`u.math.asarray` detects unit inconsistency across entries,
        or if unitful/unitless arithmetic is invalid during time-window
        comparisons.
    KeyError
        At update time, if required simulation context entries (e.g. ``'t'``
        or ``dt``) are absent from ``brainstate.environ``.

    Notes
    -----
    This device does not accept incoming synaptic or current connections; it
    only emits scheduled events. The output is dimensionless (spike count per
    step) and is typically consumed by a downstream synapse model that scales
    by connection weight.

    The key behavioral difference from :class:`spike_generator` is
    *accumulation*: when two entries in ``spike_times`` round to the same
    step, ``spike_train_injector`` sums their multiplicities while
    ``spike_generator`` retains only the last matching weight. Use
    ``spike_train_injector`` when replaying recorded spike trains that may
    contain bursts, and ``spike_generator`` when a single weighted event per
    step is intended.

    Spike times should ideally be aligned to the simulation grid (multiples
    of ``dt``) to avoid off-by-one steps. The tolerance ``dt/2`` covers
    one-ULP rounding for grid-aligned times in typical float64 arithmetic.

    See Also
    --------
    spike_generator : Deterministic spike device with per-spike weights
        (last-match semantics).
    dc_generator : Constant-current stimulation device.
    ac_generator : Sinusoidal current stimulation device.
    step_current_generator : Piecewise-constant current stimulation device.

    References
    ----------
    .. [1] NEST Simulator, ``spike_train_injector`` device.
           https://nest-simulator.readthedocs.io/en/stable/models/spike_train_injector.html

    Examples
    --------
    Inject a burst of five spikes at ``t = 2 ms`` (two entries map to the same
    step, multiplicities are accumulated to give ``a = 2 + 3 = 5``):

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     inj = brainpy.state.spike_train_injector(
       ...         spike_times=[1.0 * u.ms, 2.0 * u.ms, 2.0 * u.ms],
       ...         spike_multiplicities=[1, 2, 3],
       ...         start=0.0 * u.ms,
       ...         stop=5.0 * u.ms,
       ...     )
       ...     with brainstate.environ.context(t=2.0 * u.ms):
       ...         out = inj.update()
       ...     _ = out.shape

    Inject a single spike at ``t = 10 ms`` using NEST's ``precise_times``
    flag for API compatibility (sub-step resolution not enforced here):

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     inj = brainpy.state.spike_train_injector(
       ...         spike_times=[10.0 * u.ms],
       ...         precise_times=True,
       ...     )
       ...     with brainstate.environ.context(t=10.0 * u.ms):
       ...         out = inj.update()
       ...     _ = out.shape
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        spike_times: Sequence = (),
        spike_multiplicities: Sequence = (),
        precise_times: bool = False,
        allow_offgrid_times: bool = False,
        shift_now_spikes: bool = False,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        # ---- Validate option flags ----
        # The flags are stored for NEST API compatibility only; the runtime
        # matching rule in ``update`` does not depend on them.
        if precise_times and (allow_offgrid_times or shift_now_spikes):
            raise ValueError(
                "Option precise_times cannot be set to True when either "
                "allow_offgrid_times or shift_now_spikes is set to True."
            )
        self.precise_times = precise_times
        self.allow_offgrid_times = allow_offgrid_times
        self.shift_now_spikes = shift_now_spikes

        n_times = len(spike_times)
        n_mult = len(spike_multiplicities)

        # ---- Store spike times as a Quantity array ----
        # u.math.asarray validates unit consistency across all entries.
        if n_times > 0:
            times = u.math.asarray(spike_times)
            if not isinstance(times, u.Quantity):
                # Plain numerics are interpreted as milliseconds.
                times = times * u.ms

            # Validate non-descending order with one vectorised comparison
            # (rather than K eager element-by-element comparisons). argmax on
            # the Boolean array returns the first violating pair, matching
            # the index a sequential scan would report.
            descending = times[1:] < times[:-1]
            if bool(u.math.any(descending)):
                i = int(u.math.argmax(descending)) + 1
                raise ValueError(
                    "spike_times must be sorted in non-descending order. "
                    f"Got {times[i - 1]} > {times[i]} at index {i}."
                )
            self.spike_times = times
        else:
            self.spike_times = None

        # ---- Validate and store spike multiplicities as a JAX array ----
        # Empty multiplicities mean implicit m_i = 1 (stored as None).
        if n_mult > 0:
            if n_mult != n_times:
                raise ValueError(
                    "spike_multiplicities must have the same number of elements "
                    "as spike_times or 0 elements to clear the property. "
                    f"Got {n_mult} and {n_times}."
                )
            self.spike_multiplicities = u.math.asarray(
                [int(m) for m in spike_multiplicities], dtype=float
            )
        else:
            self.spike_multiplicities = None

        # ---- Device window parameters ----
        self.start = braintools.init.param(start, self.varshape)
        self.stop = None if stop is None else braintools.init.param(stop, self.varshape)
        self.origin = braintools.init.param(origin, self.varshape)

    def update(self):
        r"""Compute the accumulated spike output for the current simulation step.

        The implementation is fully compatible with ``jax.jit``: spike-time
        matching uses :func:`u.math.searchsorted` on the static
        ``spike_times`` array while ``t`` and ``dt`` remain traced values
        throughout. The only Python branches test static attributes
        (``spike_times``, ``spike_multiplicities``, ``stop``), never traced
        values.

        Returns
        -------
        out : jax.Array
            Float-valued JAX array with shape ``self.varshape``. Output
            semantics:

            - ``0`` when outside ``[origin + start, origin + stop)`` (or
              ``[origin + start, +inf)`` if ``stop is None``),
            - ``0`` when active but no configured spike satisfies
              ``|t - t_i| < dt/2``,
            - accumulated integer multiplicity
              :math:`a(t) = \sum_i m_i\, q_i(t)` when active and one or more
              spikes match.

        Raises
        ------
        KeyError
            If required simulation context entries are missing from
            ``brainstate.environ`` (e.g. ``'t'`` or ``dt``).

        Notes
        -----
        Matching uses the open interval
        :math:`(t - \Delta t/2,\, t + \Delta t/2)` located via two
        :func:`u.math.searchsorted` calls:

        - ``idx_lo = searchsorted(times, t - dt/2, side='right')`` — first
          index strictly greater than the lower bound.
        - ``idx_hi = searchsorted(times, t + dt/2, side='left')`` — first
          index at or above the upper bound.

        The matching entries are exactly the indices in ``[idx_lo, idx_hi)``.
        With implicit multiplicities the count is the width of that range,
        ``max(idx_hi - idx_lo, 0)``. With explicit multiplicities, a Boolean
        mask over all ``K`` indices selects the matching entries and their
        multiplicities are summed to obtain the scalar count :math:`a(t)`.

        Start is inclusive and stop is exclusive.

        Unlike :meth:`spike_generator.update`, which keeps only the last
        matching weight, this method *accumulates* all matching
        multiplicities. A burst of three spikes scheduled at the same time
        thus returns ``3`` (or the sum of their individual multiplicities).

        See Also
        --------
        spike_train_injector : Class-level parameter definitions and equations.
        spike_generator.update : Weight-selection (last-match) update rule.
        dc_generator.update : Windowed constant-current update rule.
        step_current_generator.update : Windowed piecewise-constant update rule.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        zeros = u.math.zeros(self.varshape)
        if self.spike_times is None:
            return zeros

        # Locate the open interval (t - dt/2, t + dt/2) via two searchsorted calls.
        # idx_lo: first index where spike_times > t - dt/2  (side='right')
        # idx_hi: first index where spike_times >= t + dt/2 (side='left')
        # Matching range is [idx_lo, idx_hi).
        idx_lo = u.math.searchsorted(self.spike_times, t - dt / 2, side='right')
        idx_hi = u.math.searchsorted(self.spike_times, t + dt / 2, side='left')

        if self.spike_multiplicities is not None:
            # Sum the explicit multiplicities over all matching indices.
            K = self.spike_times.shape[0]
            indices = u.math.arange(K)
            in_range = u.math.logical_and(indices >= idx_lo, indices < idx_hi)
            spike_val = u.math.sum(
                u.math.where(in_range, self.spike_multiplicities, 0.0)
            )
        else:
            # Implicit unit multiplicities: the count is just the range width.
            # Clamp at zero so a degenerate dt <= 0 (empty interval, where
            # idx_hi may fall below idx_lo) still yields no match.
            spike_val = u.math.maximum(idx_hi - idx_lo, 0).astype(float)

        # Half-open activity window [origin+start, origin+stop).
        t_start = self.origin + self.start
        if self.stop is not None:
            t_stop = self.origin + self.stop
            active = u.math.logical_and(t >= t_start, t < t_stop)
        else:
            active = t >= t_start
        return u.math.where(active, spike_val * u.math.ones(self.varshape), zeros)