# Source code for brainpy_state._nest.spike_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Sequence

import brainstate
import braintools
import saiunit as u
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest._base import NESTDevice

# Public API of this module: only the NEST-compatible spike generator device.
__all__ = ['spike_generator']


class spike_generator(NESTDevice):
    r"""Spike generator -- NEST-compatible stimulation device.

    Emit deterministic spike-like outputs at prescribed times with optional
    per-event amplitudes, while respecting a half-open activity window.

    **1. Model equations**

    Let :math:`\{t_i\}_{i=1}^{K}` be configured spike times in ms
    (non-descending after conversion), and :math:`\{w_i\}_{i=1}^{K}` optional
    spike weights. At simulation time :math:`t` with step :math:`\Delta t`
    (both in ms), define the matching indicator

    .. math::

        m_i(t) = \mathbf{1}\!\left[|t - t_i| < \frac{\Delta t}{2}\right].

    The active-window gate is

    .. math::

        g(t) = \mathbf{1}\!\left[t \ge t_0 + t_{\mathrm{start,rel}}\right]
        \cdot
        \mathbf{1}\!\left[t < t_0 + t_{\mathrm{stop,rel}}\right],

    where the second indicator is omitted when ``stop is None``.

    This implementation computes a scalar amplitude :math:`a(t)` as follows:

    - no ``spike_weights``: :math:`a(t)=1` if any :math:`m_i(t)=1`, else
      :math:`a(t)=0`;
    - with ``spike_weights``: :math:`a(t)` equals the weight associated with
      the *last* matching index (iteration order through ``spike_times``),
      or :math:`0` if no index matches.

    The returned output is broadcast to ``self.varshape``:

    .. math::

        y(t) = g(t)\,a(t)\,\mathbf{1}_{\mathrm{varshape}}.

    **2. Timing semantics, assumptions, and constraints**

    A configured spike at :math:`t_s` is intended for the step satisfying
    :math:`t_s-\Delta t < t \le t_s` under grid-aligned simulation. The
    implementation uses :math:`|t-t_s| < \Delta t/2` for robust floating-point
    matching, which is equivalent to :math:`t - \Delta t/2 < t_s < t + \Delta
    t/2`.

    Enforced constraints:

    - ``spike_times`` must be sorted in non-descending order after conversion.
    - ``spike_weights`` must be empty or have exactly
      ``len(spike_times)`` elements.

    Accepted but not additionally constrained:

    - Unitless ``spike_times`` are interpreted as ms: they are multiplied by
      ``u.ms`` at construction so they can be compared with the unitful
      simulation time.
    - Duplicate spike times are allowed. Without weights, duplicates remain
      binary output. With weights, the last duplicate's weight is used.

    **3. Computational implications**

    Each :meth:`update` call uses :func:`u.math.searchsorted` to locate the
    spike-time range matching the current step, then selects the last matching
    weight with :func:`u.math.clip` and :func:`u.math.where`. Per-call
    complexity is :math:`O(\log K + \prod\mathrm{varshape})`, where :math:`K`
    is the number of configured spike times. The sort-order check at
    construction is a single vectorized comparison over the :math:`K` times.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape specification consumed by
        :class:`brainstate.nn.Dynamics`. The emitted array has shape
        ``self.varshape`` derived from ``in_size``. Default is ``1``.
    spike_times : Sequence, optional
        Sequence of spike times with length ``K``. Entries may be unitful
        times (typically ms) or unitless numerics interpreted as ms. Passed
        directly to :func:`u.math.asarray`, which validates unit consistency
        across all entries; a unitless result is then multiplied by ``u.ms``.
        Must be non-descending. Default is ``()``.
    spike_weights : Sequence, optional
        Optional sequence of per-spike amplitudes with length ``K`` matching
        ``spike_times`` exactly, or empty to use binary spikes. Entries are
        passed to :func:`u.math.asarray` (dimensionless). Default is ``()``.
    start : ArrayLike, optional
        Relative activation time :math:`t_{\mathrm{start,rel}}` (typically ms),
        initialized through :func:`braintools.init.param`. Effective lower
        bound is ``origin + start`` (inclusive). Default is ``0. * u.ms``.
    stop : ArrayLike or None, optional
        Relative deactivation time :math:`t_{\mathrm{stop,rel}}` (typically
        ms), initialized through :func:`braintools.init.param` when provided.
        Effective upper bound is ``origin + stop`` (exclusive). ``None`` means
        no upper bound. Default is ``None``.
    origin : ArrayLike, optional
        Global time origin :math:`t_0` (typically ms) added to ``start`` and
        ``stop``, broadcast to ``self.varshape``. Default is ``0. * u.ms``.
    name : str or None, optional
        Optional node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``spike_times``
         - ``()``
         - :math:`t_i`
         - Scheduled spike times in ms, checked by ``|t - t_i| < dt/2``.
       * - ``spike_weights``
         - ``()``
         - :math:`w_i`
         - Per-spike amplitude; when multiple indices match, the last wins.
       * - ``start``
         - ``0. * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative inclusive lower bound of active window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative exclusive upper bound of active window.
       * - ``origin``
         - ``0. * u.ms``
         - :math:`t_0`
         - Global offset applied to ``start`` and ``stop``.

    Raises
    ------
    ValueError
        If ``spike_times`` is not non-descending, or if
        ``len(spike_weights)`` is non-zero and differs from
        ``len(spike_times)``.
    TypeError
        If :func:`u.math.asarray` detects unit inconsistency across entries,
        or if unitful/unitless arithmetic is invalid during time-window
        comparisons.
    KeyError
        At update time, if the simulation context lacks ``'t'`` or ``dt``
        (raised by ``brainstate.environ``; confirm the exact exception type
        there).

    Notes
    -----
    Unlike current generators (``dc_generator``, ``step_current_generator``),
    ``spike_generator`` emits dimensionless impulses (or weighted real values)
    rather than physical current quantities. The output is intended to be
    consumed directly as pre-synaptic spike events or injected into a synapse
    model that scales by connection weight.

    NEST's ``spike_generator`` supports spike multiplicity: several spikes may
    be emitted in one time step. This implementation does **not** reproduce
    that. When several configured spikes match the same step, the output is
    ``1`` without weights, or the weight of the *last* matching spike with
    weights. Matches are never summed.

    The :meth:`update` method is written to be compatible with ``jax.jit``:
    both the spike-time lookup and the activity-window check use purely
    functional operations with no Python control flow over traced values.

    Spike times should ideally be aligned to the simulation grid (multiples of
    ``dt``) to avoid off-by-one steps due to floating-point comparison. The
    open matching window of half-width ``dt/2`` generally covers one-ULP
    rounding errors for grid-aligned times.

    See Also
    --------
    dc_generator : Constant-current stimulation device.
    ac_generator : Sinusoidal current stimulation device.
    step_current_generator : Piecewise-constant current stimulation device.
    spike_train_injector : Inject pre-recorded spike trains into the network.

    References
    ----------
    .. [1] NEST Simulator, ``spike_generator`` device.
           https://nest-simulator.readthedocs.io/en/stable/models/spike_generator.html

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     sg = brainpy.state.spike_generator(
       ...         spike_times=[5.0 * u.ms, 10.0 * u.ms, 15.0 * u.ms],
       ...     )
       ...     with brainstate.environ.context(t=10.0 * u.ms):
       ...         spk = sg.update()
       ...     _ = spk.shape

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     sg = brainpy.state.spike_generator(
       ...         spike_times=[5.0 * u.ms, 5.0 * u.ms, 10.0 * u.ms],
       ...         spike_weights=[0.25, 0.5, 2.0],
       ...     )
       ...     with brainstate.environ.context(t=5.0 * u.ms):
       ...         spk = sg.update()
       ...     _ = spk.shape
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        spike_times: Sequence = (),
        spike_weights: Sequence = (),
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        name: str = None,
    ):
        """Validate and store the generator configuration.

        See the class docstring for the meaning of every parameter and the
        ``ValueError`` conditions.
        """
        super().__init__(in_size=in_size, name=name)

        n_times = len(spike_times)
        n_weights = len(spike_weights)
        if n_weights > 0 and n_weights != n_times:
            raise ValueError(
                "spike_weights must have the same length as spike_times "
                f"or be empty. Got {n_weights} and {n_times}."
            )

        # Store spike_times as a time Quantity array of shape (K,), or None
        # when no spike is configured.
        if n_times > 0:
            # u.math.asarray validates that all entries share a consistent unit.
            times = u.math.asarray(spike_times)
            if not isinstance(times, u.Quantity):
                # Documented contract: unitless spike times are milliseconds.
                # Attach the unit so searchsorted against the unitful
                # simulation time ``t`` in update() is well defined.
                times = times * u.ms

            # Validate non-descending order with one vectorized comparison
            # instead of K Python-level element reads.
            if times.shape[0] > 1:
                descending = times[1:] < times[:-1]
                if bool(u.math.any(descending)):
                    # argmax on a bool array gives the first offending pair.
                    i = int(u.math.argmax(descending)) + 1
                    raise ValueError(
                        "spike_times must be sorted in non-descending order. "
                        f"Got {times[i - 1]} > {times[i]} at index {i}."
                    )
            self.spike_times = times
        else:
            self.spike_times = None

        # Store spike weights as a dimensionless (K,) array, or None for
        # binary (unweighted) output.
        self.spike_weights = u.math.asarray(spike_weights) if n_weights > 0 else None

        self.start = braintools.init.param(start, self.varshape)
        self.stop = None if stop is None else braintools.init.param(stop, self.varshape)
        self.origin = braintools.init.param(origin, self.varshape)

    def update(self):
        r"""Compute spike output for the current simulation step.

        The spike-time matching uses :func:`u.math.searchsorted` on the static
        ``spike_times`` array while ``t`` and ``dt`` can remain traced values.
        The activity-window check uses :func:`u.math.logical_and` and
        :func:`u.math.where` with no Python branching over traced values.

        Returns
        -------
        out : jax.Array
            Float-valued JAX array with shape ``self.varshape``. Output
            semantics:

            - ``0`` when outside ``[origin + start, origin + stop)`` (or
              ``[origin + start, +inf)`` if ``stop is None``),
            - ``0`` when active but no configured spike matches
              ``|t - t_i| < dt/2``,
            - ``1`` at a matching spike time without weights,
            - last matching weight when ``spike_weights`` is configured.

        Raises
        ------
        KeyError
            If required simulation context values are missing from
            ``brainstate.environ`` (e.g. ``'t'`` or ``dt``); the exact
            exception type is determined by brainstate.

        Notes
        -----
        ``spike_times`` is stored as a time Quantity (unitless inputs were
        converted to ms at construction). :func:`u.math.searchsorted` is called
        directly with the Quantity bounds ``t - dt/2`` and ``t + dt/2``, so
        ``t`` and ``dt`` are expected to carry time units compatible with
        ``spike_times``.

        The matching condition ``|t - t_s| < dt/2`` is rewritten as the open
        interval ``(t - dt/2, t + dt/2)`` and located with two
        ``searchsorted`` calls:

        - ``idx_lo = searchsorted(times, t - dt/2, side='right')`` -- first
          index strictly greater than the lower bound.
        - ``idx_hi = searchsorted(times, t + dt/2, side='left')`` -- first
          index at or above the upper bound.

        Any spike exists when ``idx_hi > idx_lo``; the last matching spike
        index is ``idx_hi - 1``, clamped to a valid range for the gather.
        Start is inclusive and stop is exclusive, matching NEST semantics.

        See Also
        --------
        spike_generator : Class-level parameter definitions and model equations.
        dc_generator.update : Windowed constant-current update rule.
        step_current_generator.update : Windowed piecewise-constant update rule.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        zeros = u.math.zeros(self.varshape)

        if self.spike_times is None:
            # No spike times configured: output is always zero.
            return zeros

        # Locate the open interval (t - dt/2, t + dt/2) via two searchsorted calls.
        # idx_lo: first index where spike_times > t - dt/2 (side='right')
        # idx_hi: first index where spike_times >= t + dt/2 (side='left')
        # Matching range is [idx_lo, idx_hi).
        idx_lo = u.math.searchsorted(self.spike_times, t - dt / 2, side='right')
        idx_hi = u.math.searchsorted(self.spike_times, t + dt / 2, side='left')
        any_match = idx_hi > idx_lo

        # Last matching spike index, clamped to [0, K-1] so the gather below is
        # always in bounds; the gathered value is only used when any_match holds.
        last_idx = u.math.clip(idx_hi - 1, 0, self.spike_times.shape[0] - 1)
        if self.spike_weights is not None:
            spike_val = u.math.where(any_match, self.spike_weights[last_idx], 0.0)
        else:
            spike_val = u.math.where(any_match, 1.0, 0.0)

        # NEST-compatible half-open activity window [origin+start, origin+stop).
        t_start = self.origin + self.start
        if self.stop is not None:
            t_stop = self.origin + self.stop
            active = u.math.logical_and(t >= t_start, t < t_stop)
        else:
            active = t >= t_start

        return u.math.where(active, spike_val * u.math.ones(self.varshape), zeros)