Source code for brainpy_state._nest.inhomogeneous_poisson_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import math
from typing import Sequence

import brainstate
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

__all__ = [
    'inhomogeneous_poisson_generator',
]

_UNSET = object()


class inhomogeneous_poisson_generator(NESTDevice):
    r"""Inhomogeneous Poisson spike generator with NEST-compatible scheduling.

    Emit Poisson-distributed spike multiplicities from a piecewise-constant
    rate schedule and replicate NEST update ordering for future rate changes.

    **1. Stochastic model and one-step-ahead schedule semantics**

    Let :math:`\Delta t` be the simulation resolution in ms and
    :math:`n \in \mathbb{N}` the current step index with
    :math:`t_n = n \Delta t`. The generator maintains an internal rate
    :math:`r_n` in spikes/s. For each configured pair
    :math:`(t_k, v_k) =` ``(rate_times[k], rate_values[k])``, the requested
    time is aligned to a grid step :math:`s_k`:

    .. math::

       s_k =
       \begin{cases}
         \mathrm{round}(t_k / \Delta t), & \text{if representable on grid}, \\
         \left\lceil t_k / \Delta t \right\rceil, &
         \text{if off-grid and ``allow\_offgrid\_times`` is True}.
       \end{cases}

    During :meth:`update`, entries with :math:`s_k \le n` are skipped as past
    events. The next unapplied entry is consumed exactly when
    :math:`s_k = n + 1`, i.e., one simulation step ahead of delivery. This
    one-step-ahead convention reproduces NEST device ordering and avoids
    retroactive rate jumps.

    For active steps with :math:`r_n > 0`, per-output spike multiplicities are
    sampled independently as

    .. math::

       K_n \sim \mathrm{Poisson}(\lambda_n), \quad
       \lambda_n = \frac{r_n \,\Delta t}{1000},

    where the factor of 1000 converts Hz × ms to a dimensionless Poisson mean.
    Returned values are non-negative integers and may exceed 1 for high firing
    rates or large time steps.

    **2. Activity window, assumptions, and constraints**

    Activity is gated by the NEST spike-device convention using a
    half-open-on-the-left interval:

    .. math::

       t_{\min} < t_n \le t_{\max}, \quad
       t_{\min} = t_0 + t_{\mathrm{start,rel}},\;
       t_{\max} = t_0 + t_{\mathrm{stop,rel}}.

    Therefore, ``start`` is an exclusive lower bound and ``stop`` is an
    inclusive upper bound in timestamp space. If ``stop is None``,
    :math:`t_{\max} = +\infty` and no upper cutoff is applied.

    The following schedule constraints are enforced at :meth:`set` call time:

    - ``rate_times`` and ``rate_values`` must always be provided together.
    - Flattened lengths of both arrays must match after conversion.
    - Aligned schedule steps :math:`s_k` must form a strictly increasing
      sequence; duplicate grid positions are rejected.
    - Each configured rate time must lie strictly in the future relative to
      the environment time reported by ``brainstate.environ`` at the moment
      :meth:`set` is called.

    **3. Computational implications**

    Schedule preprocessing in :meth:`set` is :math:`O(K)`, where :math:`K` is
    the number of configured change points. The per-step :meth:`update` cost is
    :math:`O(M + \prod \mathrm{varshape})`, where :math:`M` is the number of
    outdated entries skipped in that call (amortized :math:`O(1)` over a
    full simulation). Poisson sampling is vectorized over ``self.varshape``
    via ``jax.random.poisson``, yielding statistically independent output
    trains for each element in the output array.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape specification for :class:`brainstate.nn.Dynamics`.
        ``self.varshape`` derived from ``in_size`` gives the shape of the
        sampled multiplicity array returned by each :meth:`update` call.
        Default is ``1``.
    rate_times : Sequence[ArrayLike] or ArrayLike or None, optional
        Scheduled rate-change times with logical shape ``(K,)``. Entries are
        interpreted as milliseconds and stored internally as a flattened
        ``np.ndarray`` with dtype ``float64`` after grid alignment. ``None``
        means no schedule is configured at construction time. Must be provided
        together with ``rate_values``. Default is ``None``.
    rate_values : Sequence[ArrayLike] or ArrayLike or None, optional
        Scheduled firing rates in spikes/s (Hz) paired one-to-one with
        ``rate_times``, logical shape ``(K,)``. Stored as a flattened
        ``np.ndarray`` with dtype ``float64``. Must be provided together with
        ``rate_times``. Default is ``None``.
    allow_offgrid_times : bool, optional
        Grid-alignment policy for ``rate_times`` entries that do not fall
        exactly on a simulation time step. If ``False``, any off-grid time
        raises :class:`ValueError`. If ``True``, off-grid times are aligned
        upward (ceiling) to the nearest representable grid step, subject to a
        small absolute tolerance of ``1e-12`` to absorb floating-point round-
        off. Default is ``False``.
    start : ArrayLike, optional
        Scalar relative start time :math:`t_{\mathrm{start,rel}}` in ms.
        Added to ``origin`` to form the exclusive lower bound of the active
        interval. Unitless scalars are treated as ms; :class:`saiunit.Quantity`
        values are converted automatically. Default is ``0. * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative stop time :math:`t_{\mathrm{stop,rel}}` in ms. Added
        to ``origin`` to form the inclusive upper bound of the active interval.
        ``None`` disables the upper bound (:math:`t_{\max} = +\infty`).
        Default is ``None``.
    origin : ArrayLike, optional
        Scalar time offset :math:`t_0` in ms applied to both ``start`` and
        ``stop``. Allows shifting the activity window without modifying the
        relative ``start``/``stop`` values. Default is ``0. * u.ms``.
    rng_seed : int, optional
        Integer seed used to initialize the ``jax.random.PRNGKey`` for Poisson
        sampling. Changing the seed produces a statistically independent output
        spike train for otherwise identical parameters. Default is ``0``.
    name : str or None, optional
        Optional human-readable name for the dynamics node passed to
        :class:`brainstate.nn.Dynamics`. Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 24 18 20 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate_times``
         - ``None``
         - :math:`t_k`
         - Scheduled rate-change times, aligned to grid steps :math:`s_k`.
       * - ``rate_values``
         - ``None``
         - :math:`v_k`
         - Scheduled firing rates (spikes/s) applied when :math:`s_k = n + 1`.
       * - ``start``
         - ``0. * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the active interval.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound; ``None`` means no upper cutoff.
       * - ``origin``
         - ``0. * u.ms``
         - :math:`t_0`
         - Global time offset added to both ``start`` and ``stop``.
       * - ``allow_offgrid_times``
         - ``False``
         - —
         - Off-grid policy: strict grid validation or upward ceiling alignment.
       * - ``rng_seed``
         - ``0``
         - —
         - Seed for the JAX PRNG key used in Poisson sampling.

    Raises
    ------
    ValueError
        If ``stop < start`` at construction time; if ``rate_times`` and
        ``rate_values`` are not provided together; if their flattened lengths
        differ; if any configured time is not strictly in the future; if
        aligned grid steps are not strictly increasing; if an off-grid time
        is supplied while ``allow_offgrid_times`` is ``False``; or if any
        time-like parameter is not scalar-convertible.
    TypeError
        If unit conversion or numeric coercion fails for any time or rate
        input (e.g., incompatible ``saiunit.Quantity`` dimensions).
    KeyError
        At runtime during :meth:`update`, if the simulation context accessed
        via ``brainstate.environ`` is missing the required ``dt`` key.

    Notes
    -----
    - Output values are spike counts per step (``0, 1, 2, ...``), not binary
      spikes. High firing rates or large time steps may produce multiplicities
      greater than one.
    - Calling :meth:`set` with a new non-empty schedule atomically resets the
      internal schedule pointer to index 0, matching NEST setter semantics.
    - Calling :meth:`update` without a prior :meth:`init_state` call will
      lazily initialize state variables on the first invocation.
    - The ``rng_key`` state is split (not folded) at each call, so the Poisson
      samples are statistically independent across time steps and across
      different elements of ``self.varshape``.

    See Also
    --------
    poisson_generator : Homogeneous Poisson stimulation device.
    sinusoidal_poisson_generator : Sinusoidally modulated Poisson device.
    step_rate_generator : Piecewise-constant deterministic rate generator.

    References
    ----------
    .. [1] NEST Simulator model documentation: ``inhomogeneous_poisson_generator``.
           https://nest-simulator.readthedocs.io/en/stable/models/inhomogeneous_poisson_generator.html

    Examples
    --------
    Create a generator that fires at 800 Hz during ``(5, 20]`` ms then goes
    silent, and read out the per-neuron spike counts at step ``t = 6 ms``:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.inhomogeneous_poisson_generator(
       ...         in_size=4,
       ...         rate_times=[5.0 * u.ms, 20.0 * u.ms],
       ...         rate_values=[800.0 * u.Hz, 0.0 * u.Hz],
       ...         start=0.0 * u.ms,
       ...         stop=30.0 * u.ms,
       ...         rng_seed=7,
       ...     )
       ...     gen.init_state()
       ...     with brainstate.environ.context(t=6.0 * u.ms):
       ...         counts = gen.update()
       ...     _ = counts.shape  # (4,), dtype int64

    Allow off-grid rate times and inspect the aligned schedule via
    :meth:`get`:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.inhomogeneous_poisson_generator(
       ...         allow_offgrid_times=True,
       ...     )
       ...     gen.set(
       ...         rate_times=[1.23 * u.ms, 2.34 * u.ms],
       ...         rate_values=[10.0 * u.Hz, 20.0 * u.Hz],
       ...     )
       ...     params = gen.get()
       ...     # params['rate_times'] contains ceil-aligned ms values
       ...     _ = params['allow_offgrid_times']  # True
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate_times: Sequence[ArrayLike] | ArrayLike | None = None,
        rate_values: Sequence[ArrayLike] | ArrayLike | None = None,
        allow_offgrid_times: bool = False,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.allow_offgrid_times = bool(allow_offgrid_times)
        self.start = self._to_scalar_time_ms(start)
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        if self.stop < self.start:
            raise ValueError('stop must be greater than or equal to start.')

        dftype = brainstate.environ.dftype()
        self._rate_times_ms = np.asarray([], dtype=dftype)
        self._rate_values_hz = np.asarray([], dtype=dftype)
        ditype = brainstate.environ.ditype()
        self._rate_steps = np.asarray([], dtype=ditype)

        if (rate_times is None) ^ (rate_values is None):
            raise ValueError('Rate times and values must be reset together.')
        if rate_times is not None:
            self.set(rate_times=rate_times, rate_values=rate_values)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        if isinstance(value, u.Quantity):
            dftype = brainstate.environ.dftype()
            arr = np.asarray(value.to_decimal(u.ms), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_time_array_ms(values: Sequence[ArrayLike] | ArrayLike) -> np.ndarray:
        dftype = brainstate.environ.dftype()
        if not isinstance(values, u.Quantity):
            arr0 = np.asarray(values)
            if arr0.size == 0:
                return np.asarray([], dtype=dftype)
        if isinstance(values, u.Quantity):
            arr = values.to_decimal(u.ms)
        else:
            arr = u.math.asarray(values, dtype=dftype)
        return np.asarray(arr, dtype=dftype).reshape(-1)

    @staticmethod
    def _to_rate_array_hz(values: Sequence[ArrayLike] | ArrayLike) -> np.ndarray:
        dftype = brainstate.environ.dftype()
        if not isinstance(values, u.Quantity):
            arr0 = np.asarray(values)
            if arr0.size == 0:
                return np.asarray([], dtype=dftype)
        if isinstance(values, u.Quantity):
            arr = values.to_decimal(u.Hz)
        else:
            arr = u.math.asarray(values, dtype=dftype)
        return np.asarray(arr, dtype=dftype).reshape(-1)

    @staticmethod
    def _array_to_public(value: np.ndarray):
        if value.size == 1:
            return float(value[0])
        return value.tolist()

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        return int(np.rint(time_ms / dt_ms))

    def _dt_ms(self) -> float:
        dt = brainstate.environ.get_dt()
        return self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        t = brainstate.environ.get('t', default=0. * u.ms)
        if t is None:
            return 0.0
        return self._to_scalar_time_ms(t)

    def _align_rate_time_to_grid(self, time_ms: float, dt_ms: float) -> tuple[int, float]:
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)

        if math.isclose(ratio, nearest, rel_tol=0.0, abs_tol=1e-12):
            step = int(nearest)
        elif self.allow_offgrid_times:
            step = int(math.ceil(ratio - 1e-12))
        else:
            raise ValueError(
                f'inhomogeneous_poisson_generator: Time point {time_ms} '
                f'is not representable in current resolution.'
            )

        return step, float(step) * dt_ms

[docs] def init_state(self, batch_size: int = None, **kwargs): r"""Initialize transient schedule pointer and RNG state. Creates the three :class:`brainstate.ShortTermState` objects required by :meth:`update`: the schedule pointer ``_rate_idx`` (``int64`` scalar), the currently active firing rate ``_rate_hz`` (``float64`` scalar, initialized to ``0.0``), and the JAX PRNG key ``rng_key`` seeded from ``self.rng_seed``. This method is idempotent with respect to the configured schedule: the existing ``_rate_times_ms``, ``_rate_values_hz``, and ``_rate_steps`` arrays are left unchanged; only the runtime-mutable state variables are (re-)created. Parameters ---------- batch_size : int or None, optional Unused. Present only for :class:`brainstate.nn.Dynamics` API compatibility. Default is ``None``. **kwargs Additional keyword arguments accepted for API compatibility and silently ignored. """ del batch_size, kwargs ditype = brainstate.environ.ditype() self._rate_idx = brainstate.ShortTermState(jnp.asarray(0, dtype=ditype)) dftype = brainstate.environ.dftype() self._rate_hz = brainstate.ShortTermState(jnp.asarray(0.0, dtype=dftype)) self.rng_key = brainstate.ShortTermState(jax.random.PRNGKey(self.rng_seed))
[docs] def set( self, *, rate_times: Sequence[ArrayLike] | ArrayLike | object = _UNSET, rate_values: Sequence[ArrayLike] | ArrayLike | object = _UNSET, allow_offgrid_times: bool | object = _UNSET, ): r"""Update the rate schedule and/or off-grid policy with NEST-compatible validation. Replaces the current piecewise-constant rate schedule with a new one, optionally updating the off-grid alignment policy at the same time. All provided times are validated against the current simulation clock (must be strictly in the future), aligned to the simulation grid, and checked for strict monotonicity. Passing ``rate_times=[]`` and ``rate_values=[]`` clears the schedule: internal arrays are set to empty and the schedule pointer is reset to 0. Parameters ---------- rate_times : Sequence[ArrayLike] or ArrayLike, optional New rate-change times in ms. Inputs are flattened to shape ``(K,)`` and stored as ``np.ndarray[float64]`` after grid alignment. Must be provided together with ``rate_values``; omitting one while supplying the other raises :class:`ValueError`. If omitted entirely (sentinel ``_UNSET``), the existing schedule is left unchanged. rate_values : Sequence[ArrayLike] or ArrayLike, optional New firing rates in spikes/s (Hz) paired one-to-one with ``rate_times``. Stored as ``np.ndarray[float64]``. Must have exactly the same flattened length as ``rate_times``. allow_offgrid_times : bool, optional If supplied, updates ``self.allow_offgrid_times``. Changing this flag is only permitted when ``rate_times`` is also being set in the same call, or when no schedule has been configured yet. Attempting to change the flag with an existing non-empty schedule and without new times raises :class:`ValueError`. Raises ------ ValueError If exactly one of ``rate_times`` / ``rate_values`` is provided (must supply both or neither); if their flattened lengths differ; if ``allow_offgrid_times`` is changed while an existing non-empty schedule is in place without also providing new times; if any time value is not strictly greater than the current environment time; if any two adjacent aligned grid steps are not strictly increasing; or if a time is off-grid and ``allow_offgrid_times`` is ``False``. TypeError If unit conversion fails for ``rate_times`` or ``rate_values`` inputs (e.g., incompatible ``saiunit.Quantity`` dimensions). """ times_given = rate_times is not _UNSET rates_given = rate_values is not _UNSET if allow_offgrid_times is not _UNSET: new_flag = bool(allow_offgrid_times) if ( new_flag != self.allow_offgrid_times and not (times_given or self._rate_times_ms.size == 0) ): raise ValueError( 'Option can only be set together with rate times ' 'or if no rate times have been set.' ) self.allow_offgrid_times = new_flag if times_given ^ rates_given: raise ValueError('Rate times and values must be reset together.') if not (times_given or rates_given): return times_ms = self._to_time_array_ms(rate_times) values_hz = self._to_rate_array_hz(rate_values) if times_ms.size != values_hz.size: raise ValueError('Rate times and values have to be the same size.') dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() if times_ms.size == 0: self._rate_times_ms = np.asarray([], dtype=dftype) self._rate_values_hz = np.asarray([], dtype=dftype) self._rate_steps = np.asarray([], dtype=ditype) if hasattr(self, '_rate_idx'): self._rate_idx.value = jnp.asarray(0, dtype=ditype) return dt_ms = self._dt_ms() now_ms = self._current_time_ms() aligned_times = np.empty_like(times_ms, dtype=dftype) aligned_steps = np.empty_like(times_ms, dtype=ditype) for i, t_ms in enumerate(times_ms): if t_ms <= now_ms: raise ValueError('Time points must lie strictly in the future.') step, aligned_ms = self._align_rate_time_to_grid(float(t_ms), dt_ms) aligned_steps[i] = step aligned_times[i] = aligned_ms if i > 0 and aligned_steps[i - 1] >= aligned_steps[i]: raise ValueError('Rate times must be strictly increasing.') self._rate_times_ms = aligned_times self._rate_values_hz = values_hz self._rate_steps = aligned_steps # Match NEST setter semantics: schedule index is reset on new data. if hasattr(self, '_rate_idx'): self._rate_idx.value = jnp.asarray(0, dtype=ditype)
[docs] def get(self) -> dict: r"""Return current schedule and timing parameters in NEST-style format. Serializes all user-configurable generator parameters into a plain Python dict. This mirrors the ``nest.GetStatus`` interface so that parameter introspection and round-tripping via :meth:`set` / :meth:`get` work as expected. Returns ------- params : dict Dictionary with the following keys: - ``'rate_times'`` (``float`` or ``list[float]``): Grid-aligned rate-change times in ms. A single-entry schedule is returned as a bare ``float``; a multi-entry schedule as a Python ``list``. An empty schedule returns an empty ``list``. - ``'rate_values'`` (``float`` or ``list[float]``): Corresponding firing rates in spikes/s (Hz), same shape convention as ``'rate_times'``. - ``'allow_offgrid_times'`` (``bool``): Current off-grid alignment policy. - ``'start'`` (``float``): Relative exclusive lower activity bound in ms. - ``'stop'`` (``float``): Inclusive upper activity bound in ms, or ``float('inf')`` if no upper bound was set. - ``'origin'`` (``float``): Global time offset in ms. """ return { 'rate_times': self._array_to_public(self._rate_times_ms), 'rate_values': self._array_to_public(self._rate_values_hz), 'allow_offgrid_times': bool(self.allow_offgrid_times), 'start': float(self.start), 'stop': float(self.stop), 'origin': float(self.origin), }
def _is_active(self, curr_step: int, dt_ms: float) -> bool: t_ms = curr_step * dt_ms t_min = self.origin + self.start t_max = self.origin + self.stop return (t_min < t_ms) and (t_ms <= t_max) def _sample_poisson(self, lam: float) -> jax.Array: key, subkey = jax.random.split(self.rng_key.value) self.rng_key.value = key dftype = brainstate.environ.dftype() return jax.random.poisson( subkey, lam=jnp.asarray(lam, dtype=dftype), shape=self.varshape, ).astype(jnp.int64)
[docs] def update(self): r"""Advance one simulation step and emit Poisson spike multiplicities. Reads the current simulation time from ``brainstate.environ``, advances the schedule pointer past any entries whose grid step :math:`s_k \le n`, then applies the next scheduled rate change if :math:`s_k = n + 1`. When the generator is active (current time inside the activity window) and the current rate is positive, samples a Poisson multiplicity array over ``self.varshape``. Otherwise returns a zero array. Lazy initialization: if :meth:`init_state` has not been called, this method initializes state variables on the first invocation. Returns ------- spikes : jax.Array, shape ``self.varshape``, dtype ``int64`` Per-output Poisson spike multiplicity for the current time step. Each element :math:`K_i \sim \mathrm{Poisson}(\lambda_n)` where :math:`\lambda_n = r_n \Delta t / 1000`. Returns all-zero array when the generator is inactive or the current rate is zero. """ if not hasattr(self, '_rate_idx'): self.init_state() ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() # Extract dt and t as JAX values (traced-compatible for for_loop). dt = brainstate.environ.get_dt() if isinstance(dt, u.Quantity): dt_ms_jax = dt.to_decimal(u.ms) else: dt_ms_jax = jnp.asarray(dt, dtype=dftype) t = brainstate.environ.get('t', default=0. * u.ms) if t is None: t_ms_jax = jnp.asarray(0.0, dtype=dftype) elif isinstance(t, u.Quantity): t_ms_jax = t.to_decimal(u.ms) else: t_ms_jax = jnp.asarray(t, dtype=dftype) # curr_step as a JAX integer — works under both eager and JIT. curr_step = jnp.asarray(jnp.rint(t_ms_jax / dt_ms_jax), dtype=ditype) n_entries = self._rate_steps.size if n_entries > 0: rate_steps_jax = jnp.asarray(self._rate_steps, dtype=ditype) rate_values_jax = jnp.asarray(self._rate_values_hz, dtype=dftype) # Find first schedule index whose step > curr_step (skipping past entries). new_idx = jnp.searchsorted(rate_steps_jax, curr_step, side='right') # Clamp to valid range for safe indexing. safe_idx = jnp.minimum(new_idx, n_entries - 1) # Apply the next entry if it falls exactly one step ahead. next_step_val = rate_steps_jax[safe_idx] in_bounds = new_idx < n_entries applies_next = in_bounds & (next_step_val == curr_step + 1) new_rate = jnp.where(applies_next, rate_values_jax[safe_idx], self._rate_hz.value) final_idx = jnp.where(applies_next, new_idx + 1, new_idx) self._rate_idx.value = final_idx.astype(ditype) self._rate_hz.value = new_rate.astype(dftype) rate_hz = self._rate_hz.value # Activity gating using JAX comparisons (no Python bool — JIT-safe). t_min = jnp.asarray(self.origin + self.start, dtype=dftype) t_max = jnp.asarray(self.origin + self.stop, dtype=dftype) is_active = (t_ms_jax > t_min) & (t_ms_jax <= t_max) # Always sample (RNG key advances every step); gate result with jnp.where. lam = rate_hz * dt_ms_jax / 1000.0 samples = self._sample_poisson(lam) zeros = jnp.zeros(self.varshape, dtype=ditype) return jnp.where(is_active & (rate_hz > 0.0), samples, zeros)