# Source code for brainpy_state._nest.step_rate_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Sequence

import brainstate
import braintools
import saiunit as u
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice
from ._utils import stack_schedule_values

__all__ = [
    'step_rate_generator',
]


class step_rate_generator(NESTDevice):
    r"""Piecewise-constant rate generator -- NEST-compatible stimulation device.

    Generate a deterministic piecewise-constant rate trace and gate it with a
    half-open activity window using NEST-compatible parameter semantics.

    **1. Model equations and schedule selection**

    Let :math:`\{(t_k, a_k)\}_{k=1}^{K}` be configured change-time/rate pairs,
    where :math:`t_k` are times in ms and :math:`a_k` are rates in spikes/s
    (Hz). The scheduled rate is

    .. math::

        A(t) =
        \begin{cases}
            0, & t < t_1, \\
            a_k, & t_k \le t < t_{k+1},\ k=1,\dots,K-1, \\
            a_K, & t \ge t_K.
        \end{cases}

    With an empty schedule (:math:`K = 0`), :math:`A(t) = 0` for all :math:`t`.

    The output is gated by

    .. math::

        g(t) = \mathbf{1}\!\left[t \ge t_0+t_{\mathrm{start,rel}}\right]
        \cdot
        \mathbf{1}\!\left[t < t_0+t_{\mathrm{stop,rel}}\right],

    with the second indicator omitted when ``stop is None``. Final output:

    .. math::

        r_{\mathrm{out}}(t) = g(t)\,A(t).

    **2. Timing semantics, assumptions, and constraints**

    At environment time ``t``, this implementation selects the latest
    schedule entry satisfying ``t_k <= t``. With discrete simulation time on a
    grid, this reproduces NEST-compatible step semantics: a configured change
    time marks the first step at which the new rate is emitted.

    Enforced constraints:

    - ``len(amplitude_times) == len(amplitude_values)``.
    - ``amplitude_times`` is one-dimensional and strictly increasing.

    Accepted but not additionally constrained:

    - An empty schedule (the default) is allowed; the generator then emits
      zero at all times.
    - Unitless ``amplitude_times`` are interpreted as ms. During the schedule
      look-up, unitful times (both ``amplitude_times`` and ``t``) are converted
      to ms, and unitless ones are used as-is.
    - Unitless ``amplitude_values`` are interpreted as spikes/s.
    - NEST documentation recommends positive change times; positivity is not
      explicitly enforced here.

    **3. Computational implications**

    Each :meth:`update` call uses :func:`u.math.searchsorted` to find the
    active plateau, then selects the pre-broadcast rate array for
    ``self.varshape`` and applies one boolean activity mask. Per-call
    complexity is :math:`O(\log K + \prod \mathrm{varshape})`, where
    :math:`K` is the number of schedule entries.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape specification consumed by
        :class:`brainstate.nn.Dynamics`. The emitted rate has shape
        ``self.varshape`` derived from ``in_size``. Default is ``1``.
    amplitude_times : Sequence, optional
        One-dimensional ordered sequence of change times with length ``K``.
        Each value may be a unitful time (typically ms) or a unitless numeric,
        which is interpreted as ms. A non-empty sequence is passed to
        :func:`u.math.asarray`, which validates unit consistency across all
        entries. Must be strictly increasing. Default is ``()`` (empty
        schedule, zero output).
    amplitude_values : Sequence, optional
        Sequence of plateau rates with length ``K``, matching
        ``amplitude_times`` elementwise. Values represent spikes/s (Hz) and
        may be unitful or unitless. Each entry is converted via
        :func:`u.math.asarray` and expanded to the maximum ndim found across
        all entries (by prepending size-1 axes). The results are stacked to a
        shape that is broadcastable to ``(K, *varshape)``. The emitted rate
        carries the unit of these values (plain numbers when they are
        unitless). Default is ``()``.
    start : ArrayLike, optional
        Relative start time :math:`t_{\mathrm{start,rel}}` (typically ms),
        broadcast to ``self.varshape`` via :func:`braintools.init.param`.
        The effective lower bound is ``origin + start`` (inclusive).
        Default is ``0. * u.ms``.
    stop : ArrayLike or None, optional
        Relative stop time :math:`t_{\mathrm{stop,rel}}` (typically ms),
        broadcast to ``self.varshape`` when provided. The effective upper
        bound is ``origin + stop`` (exclusive). ``None`` means no upper bound.
        Default is ``None``.
    origin : ArrayLike, optional
        Time origin :math:`t_0` (typically ms) added to ``start`` and ``stop``,
        broadcast to ``self.varshape``. Default is ``0. * u.ms``.
    name : str or None, optional
        Optional node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``amplitude_times``
         - ``()``
         - :math:`t_k`
         - Change times for piecewise-constant rate plateaus.
       * - ``amplitude_values``
         - ``()``
         - :math:`a_k`
         - Plateau rates (spikes/s) selected at and after each ``t_k``.
       * - ``start``
         - ``0. * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative inclusive lower bound of the active output window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative exclusive upper bound of the active output window.
       * - ``origin``
         - ``0. * u.ms``
         - :math:`t_0`
         - Global time offset added to ``start`` and ``stop``.

    Raises
    ------
    ValueError
        If ``amplitude_times`` and ``amplitude_values`` lengths differ, if
        ``amplitude_times`` is not one-dimensional, or if ``amplitude_times``
        is not strictly increasing.
    TypeError
        If :func:`u.math.asarray` detects unit inconsistency across entries,
        or if unitful/unitless arithmetic is invalid during broadcasting or
        time-window comparisons.
    KeyError
        At update time, if simulation time ``'t'`` is missing from
        ``brainstate.environ``.

    Notes
    -----
    NEST recommends specifying ``amplitude_times`` on a grid of simulation
    resolution ``dt``. Off-grid change times are allowed, but floating-point
    rounding when comparing ``t >= amp_time`` may shift the effective change
    by up to one ``dt`` step.

    Use ``dc_generator`` when only a constant current drive is needed. Use
    ``step_rate_generator`` when a rate-coded drive must take different
    values at different simulation intervals.

    Unlike ``step_current_generator``, the emitted quantity is not multiplied
    by a unit before output. It is a plain array in spikes/s when
    ``amplitude_values`` are unitless, and otherwise carries their unit.

    See Also
    --------
    step_current_generator : Piecewise-constant current stimulation device.
    dc_generator : Constant current stimulation device.
    ac_generator : Sinusoidal current stimulation device.

    References
    ----------
    .. [1] NEST Simulator documentation for ``step_rate_generator``:
           https://nest-simulator.readthedocs.io/en/stable/models/step_rate_generator.html

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.step_rate_generator(
       ...         amplitude_times=[10.0 * u.ms, 110.0 * u.ms, 210.0 * u.ms],
       ...         amplitude_values=[400.0, 1000.0, 200.0],
       ...         start=0.0 * u.ms,
       ...         stop=300.0 * u.ms,
       ...     )
       ...     with brainstate.environ.context(t=160.0 * u.ms):
       ...         rate = gen.update()
       ...     _ = rate.shape

    .. code-block:: python

       >>> import brainpy
       >>> import saiunit as u
       >>> gen1 = brainpy.state.step_rate_generator(
       ...     amplitude_times=[0.0 * u.ms, 100.0 * u.ms, 200.0 * u.ms],
       ...     amplitude_values=[50.0, 0.0, 80.0],
       ... )
       >>> gen2 = brainpy.state.step_rate_generator(
       ...     in_size=10,
       ...     amplitude_times=[50.0 * u.ms, 150.0 * u.ms],
       ...     amplitude_values=[120.0, 40.0],
       ...     start=40.0 * u.ms,
       ...     stop=180.0 * u.ms,
       ...     origin=10.0 * u.ms,
       ... )
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        amplitude_times: Sequence = (),
        amplitude_values: Sequence = (),
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        if len(amplitude_times) != len(amplitude_values):
            raise ValueError(
                "amplitude_times and amplitude_values must have the same length. "
                f"Got {len(amplitude_times)} and {len(amplitude_values)}."
            )

        if len(amplitude_times) == 0:
            # An empty schedule is the documented default; as in NEST it means
            # "emit zero at all times". Build empty placeholders directly rather
            # than sending an empty sequence through the unit-inferring helpers.
            # update() returns early on K == 0, so only the leading axis and
            # the (unitless) unit of these arrays are ever read.
            self.amplitude_times = u.math.zeros((0,))
            self.amplitude_values = u.math.zeros((0, *self.varshape))
        else:
            # Store amplitude_times as an array; u.math.asarray validates that
            # all entries share a consistent unit. Shape: (K,)
            self.amplitude_times = u.math.asarray(amplitude_times)
            if self.amplitude_times.ndim != 1:
                # Without this check, the monotonicity test below would
                # compare sub-arrays and fail with an ambiguous truth value.
                raise ValueError(
                    "amplitude_times must be a one-dimensional sequence of times. "
                    f"Got an array with shape {self.amplitude_times.shape}."
                )

            # Validate strictly increasing times. K is small, and a loop lets
            # the error name the first offending index.
            times = self.amplitude_times
            for i in range(1, times.shape[0]):
                prev, cur = times[i - 1], times[i]
                if cur <= prev:
                    raise ValueError(
                        "amplitude_times must be strictly increasing. "
                        f"Got {prev} >= {cur} at index {i}."
                    )

            # Stacked plateau rates, broadcastable to (K, *varshape).
            self.amplitude_values = stack_schedule_values(amplitude_values, self.varshape)

        self.start = braintools.init.param(start, self.varshape)
        self.stop = None if stop is None else braintools.init.param(stop, self.varshape)
        self.origin = braintools.init.param(origin, self.varshape)

    def update(self):
        r"""Compute the scheduled rate at environment time ``t``.

        The implementation is compatible with ``jax.jit``. The schedule
        look-up uses :func:`u.math.searchsorted` on the static
        ``amplitude_times`` array, while ``t`` remains a traced value
        throughout.

        Returns
        -------
        out : jax.Array or Quantity
            Rate array with shape ``self.varshape`` and values in spikes/s. It
            carries the unit of ``amplitude_values``, or is a plain array when
            those values are unitless. For each output channel, the value is
            the latest scheduled plateau whose change time is ``<= t``.
            Channels outside the active window ``[origin + start,
            origin + stop)`` are set to zero. When ``stop is None``, the
            window is ``t >= origin + start``. An empty schedule always yields
            zeros.

        Raises
        ------
        KeyError
            If ``brainstate.environ`` has no ``'t'`` entry.

        Notes
        -----
        Before calling :func:`u.math.searchsorted`, both ``amplitude_times``
        and ``t`` are converted to plain millisecond magnitudes. Unitful values
        are expressed in ``u.ms``; unitless values are used as-is, i.e.
        interpreted as ms.

        ``u.math.searchsorted(..., side='right') - 1`` returns the index of
        the most-recently-passed change point, or ``-1`` when ``t`` precedes
        all change times (zero rate). :func:`u.math.clip` keeps the index in
        bounds for the gather; :func:`u.math.where` then suppresses the result
        when the index is negative.

        Start is inclusive and stop is exclusive, matching NEST semantics.

        See Also
        --------
        step_rate_generator : Class-level parameter definitions and model equations.
        step_current_generator.update : Windowed piecewise-constant current update rule.
        dc_generator.update : Windowed constant-current update rule.
        """
        t = brainstate.environ.get('t')

        # zeros has shape varshape so that u.math.where always broadcasts the
        # selected rate value to the full output shape.
        zeros = u.math.zeros(self.varshape, unit=u.get_unit(self.amplitude_values))

        if self.amplitude_times.shape[0] == 0:
            # No schedule entries: output is always zero.
            return zeros

        # amplitude_times is a static array (compile-time constant under jit);
        # t_ms is the only traced value in the look-up.
        t_ms = self._ms_magnitude(t)
        times_ms = self._ms_magnitude(self.amplitude_times)

        # Last index k such that amplitude_times[k] <= t, or -1 if none.
        idx = u.math.searchsorted(times_ms, t_ms, side='right') - 1
        # Clamp to a valid index for the gather (idx = -1 is handled by where).
        safe_idx = u.math.clip(idx, 0, self.amplitude_values.shape[0] - 1)
        # amplitude_values has shape (K, *broadcast_shape); indexing with a
        # scalar safe_idx yields shape (*broadcast_shape,), broadcastable to varshape.
        rate = u.math.where(idx >= 0, self.amplitude_values[safe_idx], zeros)

        # NEST-compatible half-open activity window [origin+start, origin+stop).
        t_start = self.origin + self.start
        if self.stop is not None:
            t_stop = self.origin + self.stop
            active = u.math.logical_and(t >= t_start, t < t_stop)
        else:
            active = t >= t_start
        return u.math.where(active, rate, u.math.zeros_like(rate))

    @staticmethod
    def _ms_magnitude(x):
        """Return time value(s) ``x`` as a plain numeric array in milliseconds.

        Unitful inputs are converted to ms. Unitless inputs (plain numbers,
        arrays, or unitless quantities) are assumed to already be in ms, and
        only their numeric value is kept. Normalizing both sides of the
        schedule look-up this way lets unitless ``amplitude_times`` be compared
        with a unitful environment time ``t`` (and vice versa).
        """
        if isinstance(x, u.Quantity) and not x.is_unitless:
            return u.math.asarray(x.to_decimal(u.ms))
        return u.math.asarray(u.get_mantissa(x))