# Source code for brainpy_state._nest.ppd_sup_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import math

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

__all__ = ["ppd_sup_generator"]

# Private sentinel meaning "argument not supplied"; it lets ``set`` tell an
# omitted argument apart from an explicit ``None`` (``stop=None`` -> +inf).
_UNSET = object()


class ppd_sup_generator(NESTDevice):
    r"""Superposition of Poisson processes with dead time (NEST-compatible).

    Description
    -----------
    ``ppd_sup_generator`` re-implements NEST's stimulation device with the
    same name. For each output train, it emits the per-step multiplicity of a
    superposition of ``n_proc`` independent Poisson-like component processes
    with absolute dead time.

    **1. State model, derivation, and update equations**

    Let :math:`r=\mathrm{rate}` (Hz), :math:`\tau_d=\mathrm{dead\_time}` (ms),
    and :math:`\Delta t` be the simulation resolution in ms. For each output
    train, the internal state is an age-discretized occupancy model:

    - ``occ_active``: number of currently active component processes.
    - ``occ_refractory[a]`` for
      :math:`a=0,\dots,\lfloor\tau_d/\Delta t\rfloor-1`: number of processes
      in refractory age bin ``a``.
    - ``activate``: rotating pointer indicating the bin whose occupants become
      active at the current step.

    NEST's per-step hazard for one active process is

    .. math::

       h_{\mathrm{step}} =
       \frac{\Delta t}{1000/r-\tau_d}.

    This discretization is valid under NEST's model constraint
    :math:`1000/r>\tau_d` (or ``rate == 0``). If sinusoidal modulation is
    enabled, the instantaneous hazard becomes

    .. math::

       h_t = h_{\mathrm{step}}
       \left(1 + A \sin\left(2\pi f t / 1000\right)\right),

    where :math:`A=\mathrm{relative\_amplitude}\in[0,1]` and
    :math:`f=\mathrm{frequency}` in Hz.

    For each train and step, emitted multiplicity ``n_spikes`` is sampled from
    the active pool using NEST's branch logic:

    - Binomial branch: ``Binomial(occ_active, h_t)``.
    - Poisson approximation branch when
      ``(occ_active >= 100 and h_t <= 0.01)`` or
      ``(occ_active >= 500 and h_t * occ_active <= 0.1)``:
      sample ``Poisson(h_t * occ_active)`` and clip to ``occ_active``.

    When at least one refractory bin exists (:math:`B>0`), the state
    transition is

    .. math::

       occ\_active' = occ\_active + occ\_refractory[p] - n\_spikes,
       \quad
       occ\_refractory[p]' = n\_spikes,

    with pointer update :math:`p'=(p+1)\bmod B`,
    :math:`B=\lfloor\tau_d/\Delta t\rfloor`.
    If ``B == 0`` (i.e. ``dead_time < dt``, which includes zero dead time),
    no refractory bookkeeping takes place: ``occ_active`` is never
    decremented, and each component can contribute at most one spike per
    step through the binomial/Poisson draw.

    **2. Timing semantics and activity window**

    Activity follows NEST ``StimulationDevice`` semantics for generators:

    .. math::

       t_{\min} < t \le t_{\max},
       \qquad
       t_{\min} = origin + start,\quad t_{\max} = origin + stop.

    Therefore ``start`` is exclusive and ``stop`` is inclusive. Internally,
    finite times are projected to steps with ``round(time_ms / dt_ms)`` and
    checked as ``t_min_step < curr_step <= t_max_step``.

    **3. Assumptions, constraints, and computational implications**

    All physical parameters are converted through
    ``brainstate.environ.dftype()`` and stored as host-side Python ``float``
    (or ``int`` for ``n_proc``) before simulation. Enforced constraints are:

    - ``dead_time >= 0``.
    - ``n_proc >= 1``.
    - ``relative_amplitude in [0, 1]``.
    - ``stop >= start``.
    - ``1000 / rate > dead_time`` (or ``rate == 0``).

    If ``dt`` is available, finite ``origin``, ``start``, and ``stop`` must be
    exact grid multiples (absolute tolerance ``1e-12`` in ``time/dt`` ratio).
    Runtime of :meth:`update` is
    :math:`O(\prod \mathrm{varshape})` per step (a Python loop over output
    trains); memory is
    :math:`O(\prod \mathrm{varshape} \cdot \lfloor\tau_d/\Delta t\rfloor)`.
    Random draws are produced by ``numpy.random.Generator`` (seeded by
    ``rng_seed``), so stochasticity is NumPy-host based rather than JAX-key
    based.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification consumed by
        :class:`brainstate.nn.Dynamics`. ``self.varshape`` derived from this
        value is the exact shape returned by :meth:`update`, and each element
        corresponds to one independent output train. Default is ``1``.
    rate : ArrayLike, optional
        Scalar component-process rate in spikes/s (Hz), shape ``()`` after
        conversion. Accepts a single-element numeric ``ArrayLike`` or a
        :class:`saiunit.Quantity` convertible to ``u.Hz``.
        Must satisfy ``1000 / rate > dead_time`` when ``rate > 0``.
        Default is ``0.0 * u.Hz``.
    dead_time : ArrayLike, optional
        Scalar absolute refractory time in ms, shape ``()`` after conversion.
        Accepts a single-element numeric ``ArrayLike`` or a
        :class:`saiunit.Quantity` convertible to ``u.ms``.
        Must satisfy ``dead_time >= 0``. Default is ``0.0 * u.ms``.
    n_proc : ArrayLike, optional
        Scalar integer number of independent component processes per output
        train, shape ``()`` after conversion. Parsed by nearest-integer check
        with absolute tolerance ``1e-12``. Must satisfy ``n_proc >= 1``.
        Default is ``1``.
    frequency : ArrayLike, optional
        Scalar sinusoidal modulation frequency in Hz, shape ``()`` after
        conversion. ``frequency == 0`` disables sinusoidal variation even when
        ``relative_amplitude > 0``. Default is ``0.0 * u.Hz``.
    relative_amplitude : ArrayLike, optional
        Scalar dimensionless modulation amplitude :math:`A`, shape ``()``
        after conversion. Must satisfy ``0 <= relative_amplitude <= 1``.
        Default is ``0.0``.
    start : ArrayLike, optional
        Scalar relative activation time in ms, shape ``()`` after conversion.
        Effective lower activity bound is ``origin + start`` and is exclusive.
        Must be grid-representable when ``dt`` is available.
        Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative deactivation time in ms, shape ``()`` after
        conversion. Effective upper activity bound is ``origin + stop`` and is
        inclusive. ``None`` maps to ``+inf``. Must satisfy ``stop >= start``
        and be grid-representable when finite and ``dt`` is available.
        Default is ``None``.
    origin : ArrayLike, optional
        Scalar time-origin offset in ms, shape ``()`` after conversion.
        Added to ``start`` and ``stop`` to compute absolute active bounds.
        Must be grid-representable when finite and ``dt`` is available.
        Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed used to initialize ``numpy.random.default_rng`` in
        :meth:`init_state`. Default is ``0``.
    name : str or None, optional
        Optional node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 20 18 24 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate``
         - ``0.0 * u.Hz``
         - :math:`r`
         - Component-process rate in spikes/s.
       * - ``dead_time``
         - ``0.0 * u.ms``
         - :math:`\tau_d`
         - Absolute refractory duration in ms.
       * - ``n_proc``
         - ``1``
         - :math:`n_{\mathrm{proc}}`
         - Number of component processes per output train.
       * - ``frequency``
         - ``0.0 * u.Hz``
         - :math:`f`
         - Modulation frequency in Hz.
       * - ``relative_amplitude``
         - ``0.0``
         - :math:`A`
         - Relative sinusoidal modulation amplitude.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower activity bound.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound; ``None`` maps to ``+inf``.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global offset added to ``start`` and ``stop``.
       * - ``in_size``
         - ``1``
         - -
         - Defines ``self.varshape`` for independent output trains.
       * - ``rng_seed``
         - ``0``
         - -
         - Seed for NumPy RNG used by stochastic transition draws.

    Raises
    ------
    ValueError
        If scalar conversion fails due to non-scalar inputs; if ``dead_time``
        is negative; if ``n_proc < 1``; if ``relative_amplitude`` is outside
        ``[0, 1]``; if ``stop < start``; if ``1000 / rate <= dead_time`` for
        nonzero ``rate``; if integer-valued inputs are non-integral beyond
        tolerance; or if finite ``origin``/``start``/``stop`` are not
        multiples of simulation resolution when ``dt`` is available.
    TypeError
        If conversion to ``u.Hz``/``u.ms`` or numeric casting fails for
        provided parameter types.
    KeyError
        At runtime, if required simulation-context fields (for example ``dt``
        used by ``brainstate.environ.get_dt()``) are unavailable.

    Notes
    -----
    - Initial occupancy matches NEST ``pre_run_hook()``:
      ``floor(rate / 1000 * n_proc * dt)`` in each refractory bin and the
      remainder in ``occ_active``.
    - NEST does not initialize to sinusoidal equilibrium, so modulation can
      show startup transients.
    - Stimulation-backend parameter order in NEST is
      ``[dead_time, rate, n_proc, frequency, relative_amplitude]``.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.ppd_sup_generator(
       ...         in_size=(2, 2),
       ...         rate=20.0 * u.Hz,
       ...         dead_time=2.0 * u.ms,
       ...         n_proc=80,
       ...         frequency=8.0 * u.Hz,
       ...         relative_amplitude=0.25,
       ...         start=5.0 * u.ms,
       ...         stop=50.0 * u.ms,
       ...         rng_seed=3,
       ...     )
       ...     with brainstate.environ.context(t=12.0 * u.ms):
       ...         counts = gen.update()
       ...     _ = counts.shape

    .. code-block:: python

       >>> import brainpy
       >>> import saiunit as u
       >>> gen = brainpy.state.ppd_sup_generator(rate=15.0 * u.Hz, n_proc=30)
       >>> gen.set(dead_time=1.5 * u.ms, stop=None, origin=2.0 * u.ms)
       >>> params = gen.get()
       >>> _ = params['dead_time'], params['stop']

    See Also
    --------
    gamma_sup_generator : Superposition of gamma-process component trains.
    sinusoidal_gamma_generator : Inhomogeneous gamma generator with sinusoidal rate modulation.
    poisson_generator : Independent Poisson trains without dead time.

    References
    ----------
    .. [1] NEST source: ``models/ppd_sup_generator.h`` and
           ``models/ppd_sup_generator.cpp``.
    .. [2] NEST docs:
           https://nest-simulator.readthedocs.io/en/stable/models/ppd_sup_generator.html
    .. [3] Deger M, Helias M, Boucsein C, Rotter S (2011).
           Statistical properties of superimposed stationary spike trains.
           Journal of Computational Neuroscience.
           https://doi.org/10.1007/s10827-011-0362-8
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate: ArrayLike = 0. * u.Hz,
        dead_time: ArrayLike = 0. * u.ms,
        n_proc: ArrayLike = 1,
        frequency: ArrayLike = 0. * u.Hz,
        relative_amplitude: ArrayLike = 0.0,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike | None = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.rate = self._to_scalar_rate_hz(rate)
        self.dead_time = self._to_scalar_time_ms(dead_time)
        self.n_proc = self._to_scalar_int(n_proc, name='n_proc')
        self.frequency = self._to_scalar_rate_hz(frequency)
        self.relative_amplitude = self._to_scalar_float(
            relative_amplitude,
            name='relative_amplitude',
        )
        self.start = self._to_scalar_time_ms(start)
        # ``None`` means "never stop"; represent it as +inf internally.
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        self._validate_parameters(
            rate=self.rate,
            dead_time=self.dead_time,
            n_proc=self.n_proc,
            relative_amplitude=self.relative_amplitude,
            start=self.start,
            stop=self.stop,
        )

        # Runtime caches derived from ``dt``; NaN marks "not computed yet".
        self._num_targets = int(np.prod(self.varshape))
        self._hazard_step = 0.0
        self._omega_rad_per_ms = 0.0
        self._num_age_bins = 0
        self._dt_cache_ms = np.nan
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max

        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a scalar time (Quantity or plain number, in ms) to ``float``."""
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.ms), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_rate_hz(value: ArrayLike) -> float:
        """Convert a scalar rate (Quantity or plain number, in Hz) to ``float``."""
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.Hz), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('rate must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
        """Convert a dimensionless single-element value to ``float``."""
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_int(value: ArrayLike, *, name: str) -> int:
        """Convert a single-element value to ``int``; reject non-integral input."""
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        scalar = float(arr.reshape(()))
        nearest = np.rint(scalar)
        if not math.isclose(scalar, nearest, rel_tol=0.0, abs_tol=1e-12):
            raise ValueError(f'{name} must be an integer.')
        return int(nearest)

    @staticmethod
    def _validate_parameters(
        *,
        rate: float,
        dead_time: float,
        n_proc: int,
        relative_amplitude: float,
        start: float,
        stop: float,
    ):
        """Raise ``ValueError`` when a parameter combination violates NEST rules."""
        if dead_time < 0.0:
            raise ValueError('The dead time cannot be negative.')

        # rate == 0 is treated as an infinite inter-spike interval.
        inv_rate = np.inf if rate == 0.0 else (1000.0 / rate)
        if inv_rate <= dead_time:
            raise ValueError('The inverse rate has to be larger than the dead time.')

        if n_proc < 1:
            raise ValueError('The number of component processes cannot be smaller than one')

        if relative_amplitude < 0.0 or relative_amplitude > 1.0:
            raise ValueError('The relative amplitude of the rate modulation must be in [0,1].')

        if stop < start:
            raise ValueError('stop >= start required.')

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        """Project a time in ms to the nearest simulation step index."""
        return int(np.rint(time_ms / dt_ms))

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        """Raise ``ValueError`` if a finite ``time_ms`` is not a multiple of ``dt_ms``."""
        if not np.isfinite(time_ms):
            return
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)
        if not math.isclose(ratio, nearest, rel_tol=0.0, abs_tol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    def _dt_ms(self) -> float:
        """Return the environment ``dt`` in ms via ``brainstate.environ.get_dt()``."""
        dt = brainstate.environ.get_dt()
        return self._to_scalar_time_ms(dt)

    def _maybe_dt_ms(self) -> float | None:
        """Return the environment ``dt`` in ms, or ``None`` when it is not set."""
        dt = brainstate.environ.get('dt', default=None)
        if dt is None:
            return None
        return self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        """Return the environment time ``t`` in ms (``0.0`` when unset)."""
        t = brainstate.environ.get('t', default=None)
        if t is None:
            return 0.0
        # Fast path for scalar Quantity (avoids np.asarray round-trip).
        if isinstance(t, u.Quantity):
            return float(t.to_decimal(u.ms))
        return float(t)

    def _refresh_runtime_cache(self, dt_ms: float):
        """Check grid alignment and recompute all ``dt``-dependent cached values."""
        self._assert_grid_time('origin', self.origin, dt_ms)
        self._assert_grid_time('start', self.start, dt_ms)
        self._assert_grid_time('stop', self.stop, dt_ms)

        self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
        if np.isfinite(self.stop):
            self._t_max_step = self._time_to_step(self.origin + self.stop, dt_ms)
        else:
            self._t_max_step = np.iinfo(np.int64).max

        # Truncation mirrors NEST's ``static_cast<unsigned long>(dead_time / h)``.
        self._num_age_bins = int(self.dead_time / dt_ms)
        self._omega_rad_per_ms = 2.0 * math.pi * self.frequency / 1000.0
        if self.rate > 0.0:
            self._hazard_step = dt_ms / (1000.0 / self.rate - self.dead_time)
        else:
            self._hazard_step = 0.0
        self._dt_cache_ms = float(dt_ms)

    def _is_active(self, curr_step: int) -> bool:
        """Return True iff ``t_min_step < curr_step <= t_max_step``."""
        return (self._t_min_step < curr_step) and (curr_step <= self._t_max_step)

    def init_state(self, batch_size: int | None = None, **kwargs):
        r"""Initialize occupancy arrays and NumPy RNG for all output trains.

        Allocates three :class:`brainstate.ShortTermState` arrays representing
        the age-discretized occupation model and seeds the NumPy random
        generator. The initial occupancy follows NEST's ``pre_run_hook()``
        logic: ``floor(rate / 1000 * n_proc * dt)`` processes are placed in
        each refractory age bin, and the remainder fills ``occ_active``.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused API placeholder for compatibility with the
            :class:`brainstate.nn.Dynamics` interface. Ignored.
        **kwargs
            Additional unused keyword arguments accepted for API
            compatibility. Ignored.

        Notes
        -----
        If ``dt`` is not available in the simulation environment at call time,
        ``dt_ms`` defaults to ``0.0`` so that ``ini_occ_ref == 0`` and all
        ``n_proc`` processes start in ``occ_active``. The refractory array is
        then allocated with the number of age bins held in the cached
        ``_num_age_bins`` value, which is zero if no ``dt`` context was ever
        set.
        """
        del batch_size, kwargs
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)
        else:
            dt_ms = 0.0

        ini_occ_ref = int(self.rate / 1000.0 * self.n_proc * dt_ms)
        ini_occ_act = int(self.n_proc - ini_occ_ref * self._num_age_bins)

        ditype = brainstate.environ.ditype()
        self.occ_refractory = brainstate.ShortTermState(
            np.full(
                (self._num_targets, self._num_age_bins),
                ini_occ_ref,
                dtype=ditype,
            )
        )
        self.occ_active = brainstate.ShortTermState(
            np.full(self._num_targets, ini_occ_act, dtype=ditype)
        )
        self.activate = brainstate.ShortTermState(
            np.zeros(self._num_targets, dtype=ditype)
        )
        self._rng = np.random.default_rng(self.rng_seed)

    def set(
        self,
        *,
        rate: ArrayLike | object = _UNSET,
        dead_time: ArrayLike | object = _UNSET,
        n_proc: ArrayLike | object = _UNSET,
        frequency: ArrayLike | object = _UNSET,
        relative_amplitude: ArrayLike | object = _UNSET,
        start: ArrayLike | object = _UNSET,
        stop: ArrayLike | object = _UNSET,
        origin: ArrayLike | object = _UNSET,
    ):
        r"""Update public generator parameters with NEST-compatible semantics.

        Any parameter left at the internal sentinel ``_UNSET`` retains its
        current value. All supplied values are converted and cross-validated
        before any attribute is mutated. When ``dt`` is currently available in
        ``brainstate.environ``, this also includes the grid-alignment check of
        the new ``origin``/``start``/``stop``. On failure the generator
        therefore stays unchanged. After mutation, the cached hazard step,
        angular frequency, number of age bins, and timing step bounds are
        recomputed if ``dt`` is available.

        Parameters
        ----------
        rate : ArrayLike or object, optional
            New scalar component-process rate in Hz. If omitted, keep the
            current value. Must satisfy ``1000 / rate > dead_time`` for
            ``rate > 0`` after scalar conversion.
        dead_time : ArrayLike or object, optional
            New scalar absolute refractory duration in ms. If omitted, keep
            current value. Must be ``>= 0`` and satisfy
            ``1000 / rate > dead_time`` for nonzero ``rate``.
        n_proc : ArrayLike or object, optional
            New scalar integer number of component processes ``>= 1``. If
            omitted, keep current value.
        frequency : ArrayLike or object, optional
            New scalar sinusoidal modulation frequency in Hz. ``0`` disables
            modulation even when ``relative_amplitude > 0``. If omitted, keep
            current value.
        relative_amplitude : ArrayLike or object, optional
            New scalar dimensionless modulation amplitude in ``[0, 1]``. If
            omitted, keep current value.
        start : ArrayLike or object, optional
            New scalar relative start time in ms (exclusive lower bound). If
            omitted, keep current value.
        stop : ArrayLike, None, or object, optional
            New scalar relative stop time in ms (inclusive upper bound).
            ``None`` maps to ``+inf`` (unbounded). If omitted, keep current
            value.
        origin : ArrayLike or object, optional
            New scalar time-origin offset in ms. If omitted, keep current
            value.

        Raises
        ------
        ValueError
            If any provided value is non-scalar; if ``dead_time < 0``; if
            ``n_proc < 1``; if ``relative_amplitude`` is outside ``[0, 1]``;
            if ``stop < start``; if ``1000 / rate <= dead_time`` for nonzero
            ``rate``; if integer inputs are non-integral beyond tolerance; or
            if finite timing values are off the simulation grid when ``dt`` is
            available.
        TypeError
            If unit conversion to ``u.Hz`` or ``u.ms`` fails for supplied
            inputs.
        """
        new_dead_time = (
            self.dead_time if dead_time is _UNSET else self._to_scalar_time_ms(dead_time)
        )
        new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate)
        new_n_proc = (
            self.n_proc if n_proc is _UNSET else self._to_scalar_int(n_proc, name='n_proc')
        )
        new_frequency = (
            self.frequency if frequency is _UNSET else self._to_scalar_rate_hz(frequency)
        )
        new_relative_amplitude = (
            self.relative_amplitude
            if relative_amplitude is _UNSET
            else self._to_scalar_float(relative_amplitude, name='relative_amplitude')
        )
        new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start)
        if stop is _UNSET:
            new_stop = self.stop
        elif stop is None:
            new_stop = np.inf
        else:
            new_stop = self._to_scalar_time_ms(stop)
        new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin)

        self._validate_parameters(
            rate=new_rate,
            dead_time=new_dead_time,
            n_proc=new_n_proc,
            relative_amplitude=new_relative_amplitude,
            start=new_start,
            stop=new_stop,
        )

        # Grid alignment must also be checked on the *new* values before any
        # attribute is assigned; otherwise an off-grid value would raise from
        # ``_refresh_runtime_cache`` after the generator was already mutated.
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._assert_grid_time('origin', new_origin, dt_ms)
            self._assert_grid_time('start', new_start, dt_ms)
            self._assert_grid_time('stop', new_stop, dt_ms)

        self.dead_time = new_dead_time
        self.rate = new_rate
        self.n_proc = new_n_proc
        self.frequency = new_frequency
        self.relative_amplitude = new_relative_amplitude
        self.start = new_start
        self.stop = new_stop
        self.origin = new_origin

        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)

    def get(self) -> dict:
        r"""Return current public parameters as plain Python scalars.

        Returns
        -------
        out : dict
            ``dict`` with the following keys and value types:

            - ``'rate'`` — ``float``, component-process rate in Hz.
            - ``'dead_time'`` — ``float``, absolute refractory duration in ms.
            - ``'n_proc'`` — ``int``, number of component processes.
            - ``'frequency'`` — ``float``, sinusoidal modulation frequency in Hz.
            - ``'relative_amplitude'`` — ``float``, modulation depth in ``[0, 1]``.
            - ``'start'`` — ``float``, relative exclusive lower activity bound in ms.
            - ``'stop'`` — ``float``, relative inclusive upper activity bound in
              ms; ``math.inf`` when the generator was constructed or set with
              ``stop=None``.
            - ``'origin'`` — ``float``, time-origin offset in ms.
        """
        return {
            'rate': float(self.rate),
            'dead_time': float(self.dead_time),
            'n_proc': int(self.n_proc),
            'frequency': float(self.frequency),
            'relative_amplitude': float(self.relative_amplitude),
            'start': float(self.start),
            'stop': float(self.stop),
            'origin': float(self.origin),
        }

    def _sample_poisson(self, lam: float) -> int:
        """Draw one Poisson sample with mean ``lam`` from the host RNG."""
        return int(self._rng.poisson(lam))

    def _sample_binomial(self, n: int, p: float) -> int:
        """Draw one ``Binomial(n, p)`` sample, clamping ``p`` outside (0, 1)."""
        # Clamp only for numerical safety around invalid domain boundaries.
        if p <= 0.0:
            return 0
        if p >= 1.0:
            return int(n)
        return int(self._rng.binomial(n, p))

    def _update_age_distribution_single(
        self,
        occ_ref_row: np.ndarray,
        occ_active: int,
        activate_idx: int,
        hazard_step_t: float,
    ) -> tuple[int, int, int]:
        """Advance one train's occupancy by one step.

        ``occ_ref_row`` is mutated in place. Returns
        ``(n_spikes, new_occ_active, new_activate_idx)``.
        """
        if occ_active > 0:
            # NEST's rule of thumb for when Poisson approximates the binomial.
            use_poisson_approx = (
                (occ_active >= 100 and hazard_step_t <= 0.01)
                or (occ_active >= 500 and hazard_step_t * occ_active <= 0.1)
            )
            if use_poisson_approx:
                n_spikes = self._sample_poisson(hazard_step_t * occ_active)
                if n_spikes > occ_active:
                    n_spikes = occ_active
            else:
                n_spikes = self._sample_binomial(occ_active, hazard_step_t)
        else:
            n_spikes = 0

        if occ_ref_row.size > 0:
            # Processes leaving the oldest refractory bin become active again;
            # the ones that just spiked take their place in that bin.
            occ_active = int(occ_active + occ_ref_row[activate_idx] - n_spikes)
            occ_ref_row[activate_idx] = n_spikes
            activate_idx = int((activate_idx + 1) % occ_ref_row.size)
        return int(n_spikes), int(occ_active), int(activate_idx)

    def update(self):
        r"""Advance one simulation step and return per-train spike multiplicity.

        Lazily initializes state on the first call, refreshes the runtime
        cache when ``dt`` changes, applies the active-window test, computes an
        instantaneous hazard (with optional sinusoidal modulation), and updates
        each output train's age-discretized occupation model using NEST's
        branch logic (binomial or Poisson approximation).

        The method mirrors NEST's ``ppd_sup_generator`` update/event-hook
        procedure:

        1. Ensure internal state is initialized. If ``dt`` is present in the
           environment and differs from the cached value, recompute the cached
           step quantities. Existing occupancy arrays keep their number of
           age bins. If no ``dt`` was ever cached, read it via
           ``brainstate.environ.get_dt()``.
        2. Return zeros immediately when ``rate <= 0`` or no output trains
           exist.
        3. Evaluate the active-window guard: :math:`t_{\min} < t \le t_{\max}`.
        4. Compute the per-step hazard:

           .. math::

              h_t = h_{\mathrm{step}}
              \left(1 + A \sin(2\pi f\, t / 1000)\right),

           skipping the sinusoidal term when ``relative_amplitude == 0`` or
           ``frequency == 0``.
        5. For each output train, call :meth:`_update_age_distribution_single`,
           which draws ``n_spikes`` from the active pool and rotates the
           refractory ring buffer.

        Returns
        -------
        out : numpy.ndarray
            NumPy integer array of dtype ``brainstate.environ.ditype()`` and
            shape ``self.varshape``. Each element is the number of spikes
            emitted by the corresponding output train during the current step.
            Returns all-zeros when inactive, when ``rate <= 0``, or when no
            targets are defined.

        Raises
        ------
        KeyError
            If required simulation-context fields (for example ``dt`` via
            ``brainstate.environ.get_dt()``) are unavailable.
        ValueError
            If finite timing parameters are not on the simulation grid after a
            ``dt`` change triggers cache refresh.
        TypeError
            If simulation-time values in the environment cannot be converted
            to scalar milliseconds.
        """
        if not hasattr(self, 'occ_refractory'):
            self.init_state()

        # Keep dt-derived caches in sync with the environment. A NaN cache
        # never compares equal, so a first-time dt is always picked up.
        env_dt_ms = self._maybe_dt_ms()
        if env_dt_ms is None:
            if not np.isfinite(self._dt_cache_ms):
                self._refresh_runtime_cache(self._dt_ms())
        elif env_dt_ms != self._dt_cache_ms:
            self._refresh_runtime_cache(env_dt_ms)
        dt_ms = self._dt_cache_ms

        ditype = brainstate.environ.ditype()
        if self.rate <= 0.0 or self._num_targets == 0:
            return np.zeros(self.varshape, dtype=ditype)

        curr_t_ms = self._current_time_ms()
        curr_step = self._time_to_step(curr_t_ms, dt_ms)
        if not self._is_active(curr_step):
            return np.zeros(self.varshape, dtype=ditype)

        if self.relative_amplitude > 0.0 and self.frequency != 0.0:
            hazard_step_t = self._hazard_step * (
                1.0 + self.relative_amplitude * math.sin(self._omega_rad_per_ms * curr_t_ms)
            )
            # 1 + A*sin(.) >= 0 for A <= 1; only round-off can push it below.
            if -1e-15 < hazard_step_t < 0.0:
                hazard_step_t = 0.0
        else:
            hazard_step_t = self._hazard_step

        # Work on host copies, then write back once per step.
        occ_ref = np.asarray(self.occ_refractory.value, dtype=ditype).copy()
        occ_active = np.asarray(self.occ_active.value, dtype=ditype).copy()
        activate = np.asarray(self.activate.value, dtype=ditype).copy()
        counts = np.zeros(self._num_targets, dtype=ditype)

        for idx in range(self._num_targets):
            n_spikes, occ_act_i, activate_i = self._update_age_distribution_single(
                occ_ref[idx],
                int(occ_active[idx]),
                int(activate[idx]),
                hazard_step_t,
            )
            counts[idx] = n_spikes
            occ_active[idx] = occ_act_i
            activate[idx] = activate_i

        self.occ_refractory.value = occ_ref
        self.occ_active.value = occ_active
        self.activate.value = activate
        return counts.reshape(self.varshape)