# Source code for brainpy_state._nest.sinusoidal_poisson_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import math

import brainstate
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Public API of this module: only the generator class is exported.
__all__ = [
    'sinusoidal_poisson_generator',
]

# Sentinel used as the default for the keyword arguments of
# ``sinusoidal_poisson_generator.set``. It distinguishes "argument not supplied"
# from an explicit ``None``, which ``set`` accepts for ``stop`` (mapped to +inf).
_UNSET = object()


class sinusoidal_poisson_generator(NESTDevice):
    r"""Sinusoidally modulated Poisson spike generator compatible with NEST.

    Description
    -----------
    ``sinusoidal_poisson_generator`` re-implements NEST's stimulation device
    of the same name and emits per-step spike multiplicities.

    **1. Stochastic model and discretization**

    The instantaneous firing rate in spikes/s is

    .. math::

       f(t) = \max\left(
         0,\ r + a \sin\left( 2\pi f_{\mathrm{mod}} t / 1000 + \phi \right)
       \right),

    where:

    - :math:`r` is ``rate`` (spikes/s),
    - :math:`a` is ``amplitude`` (spikes/s),
    - :math:`f_{\mathrm{mod}}` is ``frequency`` (Hz),
    - :math:`\phi` is ``phase`` (deg, internally converted to radians),
    - :math:`t` is simulation time in ms.

    For simulation resolution :math:`\Delta t` in ms, each output train
    samples a Poisson multiplicity

    .. math::

       K_n \sim \mathrm{Poisson}(\lambda_n), \qquad
       \lambda_n = f_n \Delta t / 1000,

    where the ``1000`` factor converts Hz * ms to a dimensionless mean.
    ``K_n`` is an integer count ``0, 1, 2, ...`` and may exceed ``1``.

    **2. Oscillator-state recurrence and derivation**

    Following NEST, sinusoidal modulation is stored in a rotated two-component
    oscillator state:

    .. math::

       y_0(t) = a/1000 \cdot \cos(\omega t + \phi), \qquad
       y_1(t) = a/1000 \cdot \sin(\omega t + \phi),

    with :math:`\omega = 2\pi f_{\mathrm{mod}}/1000` (rad/ms). One-step
    propagation by :math:`\Delta t` uses a rotation matrix
    :math:`R(\omega\Delta t)`:

    .. math::

       \begin{bmatrix}
       y_0' \\
       y_1'
       \end{bmatrix}
       =
       \begin{bmatrix}
       \cos(\omega\Delta t) & -\sin(\omega\Delta t) \\
       \sin(\omega\Delta t) &  \cos(\omega\Delta t)
       \end{bmatrix}
       \begin{bmatrix}
       y_0 \\
       y_1
       \end{bmatrix}.

    The post-rotation ``y_1'`` is then added to ``rate/1000`` and clamped at
    zero before Poisson sampling. This avoids recomputing trigonometric
    functions each step and keeps per-step modulation update constant-time.

    **3. Update ordering (NEST source order)**

    The internal two-component oscillator state is updated exactly in the
    order used by NEST ``models/sinusoidal_poisson_generator.cpp``:

    1. Start from the DC component ``rate``.
    2. Rotate oscillator state ``(y_0, y_1)`` by one step.
    3. Add the rotated ``y_1`` to obtain instantaneous rate.
    4. Clamp rate at zero.
    5. Sample Poisson multiplicities if active.

    The per-step recorded ``rate`` value in NEST corresponds to this updated
    post-rotation rate. This implementation exposes it via
    :meth:`get_recorded_rate`.

    **4. Timing semantics**

    NEST currently classifies this model as ``CURRENT_GENERATOR`` in
    ``get_type()``. Consequently, activity is evaluated with a two-step shift
    in ``StimulationDevice::is_active``:

    .. math::

       t_{\min} < (n + 2) \le t_{\max},

    where ``n`` is current simulation step index and
    ``t_{\min} = \mathrm{origin} + \mathrm{start}``,
    ``t_{\max} = \mathrm{origin} + \mathrm{stop}`` (in steps).

    This differs from regular spike generators and is intentionally replicated
    here to match NEST behavior.

    **5. Assumptions, constraints, and computational implications**

    - Public parameters are scalar-only; non-scalar values raise
      :class:`ValueError`.
    - ``stop`` must satisfy ``stop >= start`` after unit conversion.
    - When ``dt`` is available, finite ``origin``, ``start``, and ``stop``
      must be representable on the simulation grid.
    - If ``dt`` changes, timing caches and oscillator state are recomputed from
      absolute simulation time to preserve NEST-compatible behavior.
    - Per-step complexity is :math:`O(\prod \mathrm{varshape})` for Poisson
      sampling and :math:`O(1)` for oscillator/timing updates.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification for :class:`brainstate.nn.Dynamics`.
        The derived ``self.varshape`` is the shape of values returned by
        :meth:`update`; each element corresponds to one emitted train.
        Default is ``1``.
    rate : ArrayLike, optional
        Scalar baseline firing rate in spikes/s (Hz), shape ``()`` after
        conversion. Accepted inputs include scalar ``ArrayLike`` and
        :class:`saiunit.Quantity` convertible to ``u.Hz``.
        Default is ``0.0 * u.Hz``.
    amplitude : ArrayLike, optional
        Scalar sinusoidal modulation amplitude in spikes/s (Hz), shape ``()``
        after conversion. Units and conversion rules match ``rate``.
        Default is ``0.0 * u.Hz``.
    frequency : ArrayLike, optional
        Scalar modulation frequency in Hz, shape ``()`` after conversion.
        Internally converted to angular frequency in rad/ms.
        Default is ``0.0 * u.Hz``.
    phase : ArrayLike, optional
        Scalar modulation phase in degrees, shape ``()`` after conversion.
        Internally converted to radians.
        Default is ``0.0``.
    individual_spike_trains : bool, optional
        Sampling mode selector.
        If ``True``, Poisson sampling is independent for each index of
        ``self.varshape``.
        If ``False``, one sampled multiplicity is broadcast to all outputs.
        Default is ``True``.
    start : ArrayLike, optional
        Scalar relative activation start time in ms, shape ``()`` after
        conversion. Activity uses NEST current-generator semantics with a
        two-step shifted check. Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative deactivation stop time in ms, shape ``()`` after
        conversion. ``None`` maps to ``+inf``.
        Must satisfy ``stop >= start`` after conversion.
        Default is ``None``.
    origin : ArrayLike, optional
        Scalar global time offset in ms, shape ``()`` after conversion.
        Added to ``start`` and ``stop`` for activity-window bounds.
        Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed used to initialize ``jax.random.PRNGKey`` in :meth:`init_state`
        and lazy initialization in :meth:`update`. Default is ``0``.
    name : str, optional
        Optional node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 24 18 20 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate``
         - ``0.0 * u.Hz``
         - :math:`r`
         - Baseline firing rate in spikes/s.
       * - ``amplitude``
         - ``0.0 * u.Hz``
         - :math:`a`
         - Sinusoidal modulation amplitude in spikes/s.
       * - ``frequency``
         - ``0.0 * u.Hz``
         - :math:`f_{\mathrm{mod}}`
         - Modulation frequency in Hz.
       * - ``phase``
         - ``0.0``
         - :math:`\phi`
         - Modulation phase in degrees (internally radians).
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative lower activity bound (NEST shifted semantics).
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative upper activity bound; ``None`` maps to ``+\infty``.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global time offset applied to start/stop.
       * - ``in_size``
         - ``1``
         - -
         - Defines output train count/shape via ``self.varshape``.
       * - ``individual_spike_trains``
         - ``True``
         - -
         - Independent-per-output sampling vs shared broadcast sample.
       * - ``rng_seed``
         - ``0``
         - -
         - Seed for JAX random key evolution.

    Raises
    ------
    ValueError
        If any scalar-constrained parameter cannot be reduced to one value; if
        ``stop < start``; or if finite ``origin``/``start``/``stop`` are not
        representable on the simulation grid when ``dt`` is available.
    TypeError
        If numeric/unit conversion fails for provided rate/time inputs.
    KeyError
        At runtime, if required simulation context keys (for example ``dt`` in
        :meth:`update`) are unavailable through ``brainstate.environ``.

    Notes
    -----
    - Time parameters are validated on the simulation grid when ``dt`` is
      available, matching repository conventions used by other NEST-compatible
      generators.
    - The oscillator state is re-initialized from absolute simulation time
      whenever the simulation resolution changes, matching NEST pre-run
      calibration behavior.
    - Recorded rate from :meth:`get_recorded_rate` is the post-rotation,
      post-clamp value in spikes/s used for current-step sampling logic.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.sinusoidal_poisson_generator(
       ...         in_size=4,
       ...         rate=800.0 * u.Hz,
       ...         amplitude=200.0 * u.Hz,
       ...         frequency=10.0 * u.Hz,
       ...         phase=90.0,
       ...         start=5.0 * u.ms,
       ...         stop=50.0 * u.ms,
       ...         rng_seed=123,
       ...     )
       ...     with brainstate.environ.context(t=10.0 * u.ms):
       ...         counts = gen.update()
       ...     _ = counts.shape

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.sinusoidal_poisson_generator(
       ...         individual_spike_trains=False
       ...     )
       ...     gen.set(rate=500.0 * u.Hz, amplitude=300.0 * u.Hz, phase=45.0)
       ...     params = gen.get()
       ...     _ = params['rate'], params['amplitude']

    See Also
    --------
    poisson_generator : Homogeneous Poisson generator.
    inhomogeneous_poisson_generator : Piecewise-constant inhomogeneous Poisson generator.
    sinusoidal_gamma_generator : Sinusoidally modulated gamma-renewal generator.

    References
    ----------
    .. [1] NEST source:
           ``models/sinusoidal_poisson_generator.h`` and
           ``models/sinusoidal_poisson_generator.cpp``.
    .. [2] NEST source:
           ``nestkernel/stimulation_device.cpp``.
    .. [3] NEST docs:
           https://nest-simulator.readthedocs.io/en/stable/models/sinusoidal_poisson_generator.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate: ArrayLike = 0. * u.Hz,
        amplitude: ArrayLike = 0. * u.Hz,
        frequency: ArrayLike = 0. * u.Hz,
        phase: ArrayLike = 0.0,
        individual_spike_trains: bool = True,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        """Convert parameters to plain floats and prime the derived caches.

        All rate/time inputs are reduced to Python floats (Hz and ms
        respectively); ``stop=None`` is stored as ``+inf``. Timing and
        rotation caches are only filled here when ``dt`` is already present in
        the ``brainstate`` environment.

        Raises
        ------
        ValueError
            If a parameter is not scalar, if ``stop < start``, or if grid
            validation fails while ``dt`` is available.
        """
        super().__init__(in_size=in_size, name=name)

        as_hz = self._to_scalar_rate_hz
        as_ms = self._to_scalar_time_ms

        # --- modulation parameters (stored in spikes/s, Hz, degrees) ---------
        self.rate = as_hz(rate)
        self.amplitude = as_hz(amplitude)
        self.frequency = as_hz(frequency)
        self.phase = self._to_scalar_float(phase, name='phase')
        self.individual_spike_trains = bool(individual_spike_trains)

        # --- activity window (stored in ms) -----------------------------------
        self.start = as_ms(start)
        if stop is None:
            self.stop = np.inf
        else:
            self.stop = as_ms(stop)
        self.origin = as_ms(origin)
        self.rng_seed = int(rng_seed)

        if self.stop < self.start:
            raise ValueError('stop >= start required.')

        # --- derived quantities in per-ms / radian units ----------------------
        hz_to_rad_per_ms = 2.0 * math.pi / 1000.0
        deg_to_rad = math.pi / 180.0
        self._rate_per_ms = self.rate / 1000.0
        self._amplitude_per_ms = self.amplitude / 1000.0
        self._om_rad_per_ms = self.frequency * hz_to_rad_per_ms
        self._phi_rad = self.phase * deg_to_rad

        # --- cache placeholders; NaN marks "no dt seen yet" -------------------
        self._dt_cache_ms = np.nan
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max
        self._sin_step = 0.0
        self._cos_step = 1.0

        resolution_ms = self._maybe_dt_ms()
        if resolution_ms is not None:
            self._refresh_timing_cache(resolution_ms)
            self._refresh_step_rotation_cache(resolution_ms)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a scalar time-like value to a Python float in milliseconds.

        Parameters
        ----------
        value : ArrayLike
            Either a :class:`saiunit.Quantity` (converted with
            ``to_decimal(u.ms)``) or a plain numeric value, which is passed
            through ``u.math.asarray`` without unit conversion.

        Returns
        -------
        float
            The single contained value.

        Raises
        ------
        ValueError
            If the converted array does not contain exactly one element.
        """
        # The dtype must be resolved before branching: both paths use it, and
        # binding it only in the Quantity branch makes plain numeric inputs
        # fail with UnboundLocalError.
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.ms), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_rate_hz(value: ArrayLike) -> float:
        """Convert a scalar rate-like value to a Python float in Hz.

        Parameters
        ----------
        value : ArrayLike
            Either a :class:`saiunit.Quantity` (converted with
            ``to_decimal(u.Hz)``) or a plain numeric value, which is passed
            through ``u.math.asarray`` without unit conversion.

        Returns
        -------
        float
            The single contained value.

        Raises
        ------
        ValueError
            If the converted array does not contain exactly one element.
        """
        # The dtype must be resolved before branching: both paths use it, and
        # binding it only in the Quantity branch makes plain numeric inputs
        # fail with UnboundLocalError.
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.Hz), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Rate parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
        """Reduce a unitless scalar-like ``value`` to a Python float.

        ``name`` is only used to label the error message.

        Raises
        ------
        ValueError
            If ``value`` does not hold exactly one element.
        """
        float_dtype = brainstate.environ.dftype()
        scalar = np.asarray(u.math.asarray(value, dtype=float_dtype), dtype=float_dtype)
        if scalar.size == 1:
            return float(scalar.reshape(()))
        raise ValueError(f'{name} must be scalar.')

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        return int(np.rint(time_ms / dt_ms))

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        if not np.isfinite(time_ms):
            return
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)
        if not math.isclose(ratio, nearest, rel_tol=0.0, abs_tol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    def _dt_ms(self) -> float:
        """Return the environment's simulation resolution as a float in ms."""
        return self._to_scalar_time_ms(brainstate.environ.get_dt())

    def _maybe_dt_ms(self) -> float | None:
        """Return ``dt`` in ms, or ``None`` when the environment has no ``dt``."""
        resolution = brainstate.environ.get('dt', default=None)
        return None if resolution is None else self._to_scalar_time_ms(resolution)

    def _current_time_ms(self) -> float:
        """Return the environment time ``t`` in ms, defaulting to ``0.0``.

        A missing ``t`` falls back to ``0 ms``; an explicit ``None`` is also
        treated as time zero.
        """
        now = brainstate.environ.get('t', default=0. * u.ms)
        return 0.0 if now is None else self._to_scalar_time_ms(now)

    def _refresh_timing_cache(self, dt_ms: float):
        self._assert_grid_time('origin', self.origin, dt_ms)
        self._assert_grid_time('start', self.start, dt_ms)
        self._assert_grid_time('stop', self.stop, dt_ms)

        self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
        if np.isfinite(self.stop):
            self._t_max_step = self._time_to_step(self.origin + self.stop, dt_ms)
        else:
            self._t_max_step = np.iinfo(np.int64).max
        self._dt_cache_ms = float(dt_ms)

    def _refresh_step_rotation_cache(self, dt_ms: float):
        self._sin_step = math.sin(dt_ms * self._om_rad_per_ms)
        self._cos_step = math.cos(dt_ms * self._om_rad_per_ms)

    def _reset_oscillator_state(self, t_ms: float):
        """Set the oscillator pair to its closed-form value at time ``t_ms``.

        ``y_0`` and ``y_1`` become the per-ms amplitude times the cosine and
        sine of ``omega * t_ms + phi``; the recorded rate is cleared to zero.
        """
        angle = self._om_rad_per_ms * t_ms + self._phi_rad
        cos_component = self._amplitude_per_ms * math.cos(angle)
        sin_component = self._amplitude_per_ms * math.sin(angle)

        float_dtype = brainstate.environ.dftype()
        self.y_0.value = jnp.asarray(cos_component, dtype=float_dtype)
        self.y_1.value = jnp.asarray(sin_component, dtype=float_dtype)
        self._recorded_rate_hz.value = jnp.asarray(0.0, dtype=float_dtype)

    def _is_active(self, curr_step: int) -> bool:
        # Match NEST's current-generator activity handling for this model:
        # StimulationDevice::is_active uses step+2 for CURRENT_GENERATOR.
        shifted_step = curr_step + 2
        return (self._t_min_step < shifted_step) and (shifted_step <= self._t_max_step)

[docs] def init_state(self, batch_size: int = None, **kwargs): r"""Initialize RNG, oscillator states, and cached recorded rate. Parameters ---------- batch_size : int or None, optional Unused placeholder for :class:`brainstate.nn.Dynamics` compatibility. **kwargs Unused extra keyword arguments. """ del batch_size, kwargs self.rng_key = brainstate.ShortTermState(jax.random.PRNGKey(self.rng_seed)) dftype = brainstate.environ.dftype() self.y_0 = brainstate.ShortTermState(jnp.asarray(0.0, dtype=dftype)) self.y_1 = brainstate.ShortTermState(jnp.asarray(0.0, dtype=dftype)) self._recorded_rate_hz = brainstate.ShortTermState(jnp.asarray(0.0, dtype=dftype)) dt_ms = self._maybe_dt_ms() if dt_ms is not None: self._refresh_timing_cache(dt_ms) self._refresh_step_rotation_cache(dt_ms) self._reset_oscillator_state(self._current_time_ms())
[docs] def set( self, *, rate: ArrayLike | object = _UNSET, amplitude: ArrayLike | object = _UNSET, frequency: ArrayLike | object = _UNSET, phase: ArrayLike | object = _UNSET, individual_spike_trains: bool | object = _UNSET, start: ArrayLike | object = _UNSET, stop: ArrayLike | object = _UNSET, origin: ArrayLike | object = _UNSET, ): r"""Set public parameters and refresh dependent cached state. Parameters ---------- rate : ArrayLike or object, optional Scalar rate in spikes/s (Hz). ``_UNSET`` keeps current value. amplitude : ArrayLike or object, optional Scalar sinusoidal amplitude in spikes/s (Hz). ``_UNSET`` keeps current value. frequency : ArrayLike or object, optional Scalar frequency in Hz. ``_UNSET`` keeps current value. phase : ArrayLike or object, optional Scalar phase in degrees. ``_UNSET`` keeps current value. individual_spike_trains : bool or object, optional Sampling mode flag. ``_UNSET`` keeps current value. start : ArrayLike or object, optional Scalar relative start time in ms. ``_UNSET`` keeps current value. stop : ArrayLike, None, or object, optional Scalar relative stop time in ms, or ``None`` for ``+inf``. ``_UNSET`` keeps current value. origin : ArrayLike or object, optional Scalar origin time in ms. ``_UNSET`` keeps current value. Raises ------ ValueError If scalar conversion fails, ``stop < start``, or grid-time validation fails when ``dt`` is available. TypeError If unit or numeric conversion fails for supplied inputs. 
""" new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate) new_amplitude = ( self.amplitude if amplitude is _UNSET else self._to_scalar_rate_hz(amplitude) ) new_frequency = ( self.frequency if frequency is _UNSET else self._to_scalar_rate_hz(frequency) ) new_phase = self.phase if phase is _UNSET else self._to_scalar_float(phase, name='phase') new_individual = ( self.individual_spike_trains if individual_spike_trains is _UNSET else bool(individual_spike_trains) ) new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start) if stop is _UNSET: new_stop = self.stop elif stop is None: new_stop = np.inf else: new_stop = self._to_scalar_time_ms(stop) new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin) if new_stop < new_start: raise ValueError('stop >= start required.') self.rate = new_rate self.amplitude = new_amplitude self.frequency = new_frequency self.phase = new_phase self.individual_spike_trains = new_individual self.start = new_start self.stop = new_stop self.origin = new_origin self._rate_per_ms = self.rate / 1000.0 self._amplitude_per_ms = self.amplitude / 1000.0 self._om_rad_per_ms = self.frequency * (2.0 * math.pi / 1000.0) self._phi_rad = self.phase * (math.pi / 180.0) dt_ms = self._maybe_dt_ms() if dt_ms is not None: self._refresh_timing_cache(dt_ms) self._refresh_step_rotation_cache(dt_ms) if hasattr(self, 'y_0'): self._reset_oscillator_state(self._current_time_ms())
[docs] def get(self) -> dict: r"""Return current public parameters and oscillator state snapshot. Returns ------- dict Dictionary with keys ``rate``, ``frequency``, ``phase``, ``amplitude``, ``individual_spike_trains``, ``start``, ``stop``, ``origin``, ``y_0``, and ``y_1``. Rates are in spikes/s, times are in ms, and oscillator states are in spikes/ms. """ y0 = 0.0 y1 = 0.0 if hasattr(self, 'y_0'): dftype = brainstate.environ.dftype() y0 = float(np.asarray(self.y_0.value, dtype=dftype).reshape(())) y1 = float(np.asarray(self.y_1.value, dtype=dftype).reshape(())) return { 'rate': float(self.rate), 'frequency': float(self.frequency), 'phase': float(self.phase), 'amplitude': float(self.amplitude), 'individual_spike_trains': bool(self.individual_spike_trains), 'start': float(self.start), 'stop': float(self.stop), 'origin': float(self.origin), 'y_0': y0, 'y_1': y1, }
[docs] def get_recorded_rate(self) -> float: r"""Return latest post-update instantaneous rate in spikes/s. Returns ------- float Most recent stored value of the post-rotation, post-clamp instantaneous rate. Returns ``0.0`` before state initialization. """ if not hasattr(self, '_recorded_rate_hz'): return 0.0 dftype = brainstate.environ.dftype() return float(np.asarray(self._recorded_rate_hz.value, dtype=dftype).reshape(()))
def _sample_poisson_individual(self, lam: float) -> jax.Array: key, subkey = jax.random.split(self.rng_key.value) self.rng_key.value = key dftype = brainstate.environ.dftype() return jax.random.poisson( subkey, lam=jnp.asarray(lam, dtype=dftype), shape=self.varshape, ).astype(jnp.int64) def _sample_poisson_shared(self, lam) -> jax.Array: key, subkey = jax.random.split(self.rng_key.value) self.rng_key.value = key dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() sample = jax.random.poisson( subkey, lam=jnp.asarray(lam, dtype=dftype), shape=(), ).astype(jnp.int64) return jnp.full(self.varshape, sample, dtype=ditype)
[docs] def update(self): r"""Advance generator by one simulation step and emit spike counts. Returns ------- jax.Array ``int64`` array with shape ``self.varshape``. Values are per-step spike multiplicities sampled from the configured sinusoidal Poisson process, or zeros when inactive/non-positive-rate. Raises ------ KeyError If required environment entries (for example ``dt``) are not available through ``brainstate.environ`` at runtime. ValueError If cached timing constraints become invalid after environment changes (for example non-grid-aligned finite time bounds). """ if not hasattr(self, 'rng_key'): self.init_state() dt_ms = self._dt_ms() if (not np.isfinite(self._dt_cache_ms)) or ( not math.isclose(dt_ms, self._dt_cache_ms, rel_tol=0.0, abs_tol=1e-15) ): curr_t_ms = self._current_time_ms() self._refresh_timing_cache(dt_ms) self._refresh_step_rotation_cache(dt_ms) self._reset_oscillator_state(curr_t_ms) dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() # Get current time as a JAX-compatible scalar so this method works under # jax.jit / brainstate.transform.for_loop tracing. t = brainstate.environ.get('t', default=0. * u.ms) if isinstance(t, u.Quantity): t_ms_jax = t.to_decimal(u.ms) else: t_ms_jax = jnp.asarray(t, dtype=dftype) curr_step_jax = jnp.rint(t_ms_jax / dt_ms).astype(jnp.int64) # Update oscillator state using JAX operations for JIT compatibility. cos_s = jnp.asarray(self._cos_step, dtype=dftype) sin_s = jnp.asarray(self._sin_step, dtype=dftype) y0 = self.y_0.value y1 = self.y_1.value new_y0 = cos_s * y0 - sin_s * y1 new_y1 = sin_s * y0 + cos_s * y1 rate_val = jnp.maximum( jnp.asarray(0.0, dtype=dftype), jnp.asarray(self._rate_per_ms, dtype=dftype) + new_y1, ) self.y_0.value = jnp.asarray(new_y0, dtype=dftype) self.y_1.value = jnp.asarray(new_y1, dtype=dftype) self._recorded_rate_hz.value = rate_val * jnp.asarray(1000.0, dtype=dftype) # Activity check using JAX comparisons for JIT compatibility. 
shifted_step = curr_step_jax + jnp.asarray(2, dtype=jnp.int64) t_min = jnp.asarray(self._t_min_step, dtype=jnp.int64) t_max = jnp.asarray(self._t_max_step, dtype=jnp.int64) active = jnp.logical_and(t_min < shifted_step, shifted_step <= t_max) positive_rate = rate_val > jnp.asarray(0.0, dtype=dftype) should_fire = jnp.logical_and(active, positive_rate) # Always sample (masking via jnp.where keeps this JIT-compatible). lam = rate_val * jnp.asarray(dt_ms, dtype=dftype) if self.individual_spike_trains: spikes = self._sample_poisson_individual(lam) else: spikes = self._sample_poisson_shared(lam) zeros = jnp.zeros(self.varshape, dtype=ditype) return jnp.where(should_fire, spikes, zeros)