# Source code for brainpy_state._nest.gamma_sup_generator

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import math

import brainstate
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Public API of this module.
__all__ = [
    'gamma_sup_generator',
]

# Sentinel distinguishing "argument not provided" from an explicit ``None``
# in keyword-only setters (``None`` is a meaningful value for ``stop``).
_UNSET = object()


class gamma_sup_generator(NESTDevice):
    r"""Superposition of independent gamma processes (NEST-compatible).

    Description
    -----------
    ``gamma_sup_generator`` re-implements NEST's stimulation device of the
    same name. It emits, per output train and simulation step, the
    multiplicity of spikes produced by superimposing ``n_proc`` independent
    component renewal processes with gamma-distributed inter-spike intervals.

    **1. State-space model, derivation, and update equations**

    Let :math:`k = \mathrm{gamma\_shape}` and :math:`r = \mathrm{rate}` in Hz.
    Each component gamma process is represented as a cyclic chain of ``k``
    exponential phases. Using an occupation vector per output train,

    .. math::

       \mathbf{occ} = (occ_0, \dots, occ_{k-1}),
       \qquad \sum_{i=0}^{k-1} occ_i = n_{\mathrm{proc}},

    a process in phase ``i`` transitions to phase ``i+1`` (mod ``k``) with
    per-step probability

    .. math::

       p = r \cdot k \cdot \Delta t / 1000.

    This is the discrete-time hazard form used by NEST. For each bin ``i``,

    .. math::

       n_i \sim \mathrm{Binomial}(occ_i, p),

    except for NEST's sparse/high-count approximation branch:

    - if ``occ_i >= 100 and p <= 0.01``, or
    - if ``occ_i >= 500 and p * occ_i <= 0.1``,

    use ``Poisson(p * occ_i)`` and clip to ``occ_i``.

    After sampling, all ``n_i`` are moved simultaneously to preserve integer
    mass and avoid order-dependent updates. The emitted spike multiplicity for
    one train is

    .. math::

       K_n = n_{k-1},

    i.e., transitions leaving the last phase and re-entering phase 0. This
    allows per-step counts larger than 1, matching NEST ``SpikeEvent``
    multiplicity semantics.

    **2. Timing semantics and activity window**

    Activity follows NEST ``StimulationDevice::is_active`` for spike
    generators:

    .. math::

       t_{\min} < t \le t_{\max},
       \qquad
       t_{\min} = origin + start,\quad t_{\max} = origin + stop.

    Therefore ``start`` is exclusive and ``stop`` is inclusive. Internally,
    finite times are projected to integer steps with
    ``round(time_ms / dt_ms)`` and checked as
    ``t_min_step < curr_step <= t_max_step``.

    **3. Assumptions, constraints, and computational implications**

    Parameters are scalarized to ``float64``/``int`` before simulation.
    Enforced constraints are ``rate >= 0``, ``gamma_shape >= 1``,
    ``n_proc >= 1``, and ``stop >= start``. If ``dt`` is available, finite
    ``origin``, ``start``, and ``stop`` must lie on the simulation grid
    (absolute tolerance ``1e-12`` in ``time/dt`` ratio).

    Runtime complexity of :meth:`update` is
    :math:`O(\prod \mathrm{varshape} \cdot \mathrm{gamma\_shape})`, with state
    memory :math:`O(\prod \mathrm{varshape} \cdot \mathrm{gamma\_shape})` from
    the occupation matrix. RNG sampling uses ``numpy.random.Generator``
    (seeded by ``rng_seed``), so stochastic draws are CPU NumPy based rather
    than JAX-key based.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification consumed by
        :class:`brainstate.nn.Dynamics`. The exact shape returned by
        :meth:`update` is ``self.varshape`` derived from ``in_size``; each
        element corresponds to one independent output train. Default is ``1``.
    rate : ArrayLike, optional
        Scalar component-process rate in spikes/s (Hz), shape ``()`` after
        conversion. Accepts a single-element numeric ``ArrayLike`` or a
        :class:`saiunit.Quantity` convertible to ``u.Hz``.
        Must satisfy ``rate >= 0``. Default is ``0.0 * u.Hz``.
    gamma_shape : ArrayLike, optional
        Scalar integer gamma shape :math:`k`, shape ``()`` after conversion.
        Parsed via nearest-integer check with absolute tolerance ``1e-12``.
        Must satisfy ``gamma_shape >= 1``. Default is ``1``.
    n_proc : ArrayLike, optional
        Scalar integer number of independent component processes per output
        train, shape ``()`` after conversion. Parsed by nearest-integer check
        with absolute tolerance ``1e-12``. Must satisfy ``n_proc >= 1``.
        Default is ``1``.
    start : ArrayLike, optional
        Scalar relative activation time in ms, shape ``()`` after conversion.
        Effective lower activity bound is ``origin + start`` and is exclusive.
        Must be grid-representable when ``dt`` is available.
        Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative deactivation time in ms, shape ``()`` after
        conversion. Effective upper activity bound is ``origin + stop`` and is
        inclusive. ``None`` maps to ``+inf``. Must satisfy ``stop >= start``
        and be grid-representable when finite and ``dt`` is available.
        Default is ``None``.
    origin : ArrayLike, optional
        Scalar time-origin offset in ms, shape ``()`` after conversion, added
        to ``start`` and ``stop`` to compute absolute active bounds.
        Must be grid-representable when finite and ``dt`` is available.
        Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed used to initialize ``numpy.random.default_rng`` in
        :meth:`init_state`. Default is ``0``.
    name : str or None, optional
        Optional node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 20 40

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate``
         - ``0.0 * u.Hz``
         - :math:`r`
         - Component-process rate in spikes/s.
       * - ``gamma_shape``
         - ``1``
         - :math:`k`
         - Number of cyclic exponential phases per component process.
       * - ``n_proc``
         - ``1``
         - :math:`n_{\mathrm{proc}}`
         - Number of superimposed component processes per output train.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of activity.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound; ``None`` maps to :math:`+\infty`.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global offset added to ``start`` and ``stop``.
       * - ``in_size``
         - ``1``
         - -
         - Defines ``self.varshape`` for independent output trains.
       * - ``rng_seed``
         - ``0``
         - -
         - Seed of the NumPy generator used for transition draws.

    Raises
    ------
    ValueError
        If scalar conversion fails due to non-scalar inputs; if ``rate < 0``;
        if ``gamma_shape < 1``; if ``n_proc < 1``; if ``stop < start``; if
        integer-valued inputs are non-integral beyond tolerance; or if finite
        ``origin``/``start``/``stop`` are not multiples of simulation
        resolution when ``dt`` is available.
    TypeError
        If unit conversion to ``u.Hz`` or ``u.ms`` fails for supplied inputs.
    KeyError
        At runtime, if required simulation-context fields (for example ``dt``
        used by ``brainstate.environ.get_dt()``) are unavailable.

    Notes
    -----
    - Initial occupation is the NEST equilibrium approximation used in
      ``pre_run_hook()``:
      ``floor(n_proc / gamma_shape)`` in all bins, with the remainder added
      to the last bin.
    - As in NEST, each output train maintains independent internal occupation
      states.
    - In the direct binomial branch, ``transition_prob`` is numerically
      clamped to ``[0, 1]`` before calling NumPy RNG to avoid invalid
      probability inputs in edge cases.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     gen = brainpy.state.gamma_sup_generator(
       ...         in_size=(2, 3),
       ...         rate=20.0 * u.Hz,
       ...         gamma_shape=3,
       ...         n_proc=50,
       ...         start=5.0 * u.ms,
       ...         stop=40.0 * u.ms,
       ...         rng_seed=7,
       ...     )
       ...     with brainstate.environ.context(t=12.0 * u.ms):
       ...         counts = gen.update()
       ...     _ = counts.shape

    .. code-block:: python

       >>> import brainpy
       >>> import saiunit as u
       >>> gen = brainpy.state.gamma_sup_generator(rate=15.0 * u.Hz, gamma_shape=2)
       >>> gen.set(n_proc=20, stop=None, origin=1.0 * u.ms)
       >>> params = gen.get()
       >>> _ = params['gamma_shape'], params['n_proc']

    See Also
    --------
    ppd_sup_generator : Superposition with dead-time component processes.
    sinusoidal_gamma_generator : Inhomogeneous gamma generator with sinusoidal rate.

    References
    ----------
    .. [1] NEST source: ``models/gamma_sup_generator.cpp`` and
           ``models/gamma_sup_generator.h``.
    .. [2] NEST model docs:
           https://nest-simulator.readthedocs.io/en/stable/models/gamma_sup_generator.html
    .. [3] Deger M, Helias M, Boucsein C, Rotter S (2012).
           Statistical properties of superimposed stationary spike trains.
           Journal of Computational Neuroscience.
           https://doi.org/10.1007/s10827-011-0362-8
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate: ArrayLike = 0. * u.Hz,
        gamma_shape: ArrayLike = 1,
        n_proc: ArrayLike = 1,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike | None = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)

        # Normalize public parameters to plain Python scalars (float Hz/ms,
        # int counts) so later arithmetic is unit-free.
        self.rate = self._to_scalar_rate_hz(rate)
        self.gamma_shape = self._to_scalar_int(gamma_shape, name='gamma_shape')
        self.n_proc = self._to_scalar_int(n_proc, name='n_proc')
        self.start = self._to_scalar_time_ms(start)
        # ``stop=None`` means an open-ended activity window.
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        # Fail fast on constraint violations (NEST BadProperty equivalents).
        self._validate_parameters(
            rate=self.rate,
            gamma_shape=self.gamma_shape,
            n_proc=self.n_proc,
            start=self.start,
            stop=self.stop,
        )

        # One independent output train per element of ``varshape``.
        self._num_targets = int(np.prod(self.varshape))
        # Runtime cache; NaN dt marks it as "not yet computed".
        self._transition_prob = 0.0
        self._dt_cache_ms = np.nan
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max
        # Eagerly fill the cache when a simulation dt is already configured.
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a scalar time (Quantity or plain numeric) to ``float`` ms.

        Raises
        ------
        ValueError
            If ``value`` is not scalar (``size != 1``).
        """
        # Bug fix: ``dftype`` was previously bound only inside the Quantity
        # branch, so plain numeric inputs hit NameError on the else path.
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.ms), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_rate_hz(value: ArrayLike) -> float:
        """Convert a scalar rate (Quantity or plain numeric) to ``float`` Hz.

        Raises
        ------
        ValueError
            If ``value`` is not scalar (``size != 1``).
        """
        # Bug fix: ``dftype`` was previously bound only inside the Quantity
        # branch, so plain numeric inputs hit NameError on the else path.
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.Hz), dtype=dftype)
        else:
            arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError('rate must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_int(value: ArrayLike, *, name: str) -> int:
        """Coerce ``value`` to a Python ``int``.

        Rejects non-scalar inputs and values that are not integral within
        absolute tolerance ``1e-12`` of the nearest integer.
        """
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        as_float = float(arr.reshape(()))
        rounded = np.rint(as_float)
        if math.isclose(as_float, rounded, rel_tol=0.0, abs_tol=1e-12):
            return int(rounded)
        raise ValueError(f'{name} must be an integer.')

    @staticmethod
    def _validate_parameters(
        *,
        rate: float,
        gamma_shape: int,
        n_proc: int,
        start: float,
        stop: float,
    ):
        """Check NEST-compatible parameter constraints.

        Enforces ``gamma_shape >= 1``, ``rate >= 0``, ``n_proc >= 1`` and
        ``stop >= start``; raises ``ValueError`` on the first violation.
        """
        if gamma_shape < 1:
            raise ValueError('The shape must be larger or equal 1')
        if rate < 0.0:
            # Message fixed: the check permits rate == 0 (the default), so
            # the old text 'must be larger than 0' contradicted the actual
            # constraint enforced here.
            raise ValueError('The rate cannot be negative.')
        if n_proc < 1:
            raise ValueError('The number of component processes cannot be smaller than one')
        if stop < start:
            raise ValueError('stop >= start required.')

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        """Project a time in ms onto the integer simulation-step grid."""
        steps = np.rint(time_ms / dt_ms)
        return int(steps)

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        """Require ``time_ms`` to lie on the ``dt`` grid (tol ``1e-12``).

        Non-finite times (e.g. the ``+inf`` open-ended ``stop``) are exempt.
        """
        if np.isfinite(time_ms):
            steps = time_ms / dt_ms
            if not math.isclose(steps, np.rint(steps), rel_tol=0.0, abs_tol=1e-12):
                raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    def _dt_ms(self) -> float:
        """Return the current simulation resolution ``dt`` in ms (may raise)."""
        return self._to_scalar_time_ms(brainstate.environ.get_dt())

    def _maybe_dt_ms(self) -> float | None:
        """Return ``dt`` in ms, or ``None`` when no ``dt`` is configured."""
        dt = brainstate.environ.get('dt', default=None)
        return None if dt is None else self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        """Return the current simulation time ``t`` in ms (0.0 when unset)."""
        t = brainstate.environ.get('t', default=None)
        if t is None:
            return 0.0
        if isinstance(t, u.Quantity):
            # Scalar Quantity fast path: skip the np.asarray round-trip.
            return float(t.to_decimal(u.ms))
        return float(t)

    def _refresh_runtime_cache(self, dt_ms: float):
        """Recompute step bounds and per-step hazard for resolution ``dt_ms``.

        Validates that ``origin``/``start``/``stop`` sit on the ``dt`` grid,
        converts the absolute activity window to integer steps, and caches
        the per-step transition probability ``rate * gamma_shape * dt / 1000``.
        """
        for label, value in (
            ('origin', self.origin),
            ('start', self.start),
            ('stop', self.stop),
        ):
            self._assert_grid_time(label, value, dt_ms)

        self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
        # Open-ended window: map an infinite stop to the largest step index.
        self._t_max_step = (
            self._time_to_step(self.origin + self.stop, dt_ms)
            if np.isfinite(self.stop)
            else np.iinfo(np.int64).max
        )
        self._transition_prob = self.rate * self.gamma_shape * dt_ms / 1000.0
        self._dt_cache_ms = float(dt_ms)

    def _is_active(self, curr_step: int) -> bool:
        """NEST window test: exclusive at ``start``, inclusive at ``stop``."""
        return self._t_min_step < curr_step <= self._t_max_step

[docs] def init_state(self, batch_size: int = None, **kwargs): r"""Initialize occupancy and RNG state for all output trains. Parameters ---------- batch_size : int or None, optional Unused API placeholder for compatibility with Dynamics interfaces. Ignored. **kwargs Additional unused keyword arguments. Ignored. """ del batch_size, kwargs ini_occ_ref = int(self.n_proc // self.gamma_shape) ini_occ_act = int(self.n_proc - ini_occ_ref * self.gamma_shape) ditype = brainstate.environ.ditype() occ = np.full( (self._num_targets, self.gamma_shape), ini_occ_ref, dtype=ditype, ) occ[:, -1] += ini_occ_act self.occ = brainstate.ShortTermState(occ) self._rng = np.random.default_rng(self.rng_seed)
[docs] def set( self, *, rate: ArrayLike | object = _UNSET, gamma_shape: ArrayLike | object = _UNSET, n_proc: ArrayLike | object = _UNSET, start: ArrayLike | object = _UNSET, stop: ArrayLike | object = _UNSET, origin: ArrayLike | object = _UNSET, ): r"""Set public parameters with NEST-compatible semantics. Parameters ---------- rate : ArrayLike or object, optional New scalar component rate in Hz. ``_UNSET`` keeps current value. gamma_shape : ArrayLike or object, optional New scalar integer gamma shape ``>= 1``. ``_UNSET`` keeps current value. n_proc : ArrayLike or object, optional New scalar integer number of component processes ``>= 1``. ``_UNSET`` keeps current value. start : ArrayLike or object, optional New scalar relative start time in ms. ``_UNSET`` keeps current value. stop : ArrayLike, object, or None, optional New scalar relative stop time in ms; ``None`` maps to ``+inf``. ``_UNSET`` keeps current value. origin : ArrayLike or object, optional New scalar origin time in ms. ``_UNSET`` keeps current value. Raises ------ ValueError If converted values violate model constraints or are non-scalar. TypeError If unit conversion to Hz/ms fails. 
""" new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate) new_gamma_shape = ( self.gamma_shape if gamma_shape is _UNSET else self._to_scalar_int(gamma_shape, name='gamma_shape') ) new_n_proc = ( self.n_proc if n_proc is _UNSET else self._to_scalar_int(n_proc, name='n_proc') ) new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start) if stop is _UNSET: new_stop = self.stop elif stop is None: new_stop = np.inf else: new_stop = self._to_scalar_time_ms(stop) new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin) self._validate_parameters( rate=new_rate, gamma_shape=new_gamma_shape, n_proc=new_n_proc, start=new_start, stop=new_stop, ) self.rate = new_rate self.gamma_shape = new_gamma_shape self.n_proc = new_n_proc self.start = new_start self.stop = new_stop self.origin = new_origin dt_ms = self._maybe_dt_ms() if dt_ms is not None: self._refresh_runtime_cache(dt_ms)
[docs] def get(self) -> dict: r"""Return current public parameters as plain Python scalars. Returns ------- out : dict ``dict`` with keys ``rate``, ``gamma_shape``, ``n_proc``, ``start``, ``stop``, and ``origin``. Values are ``float``/``int`` in public units (Hz for ``rate``, ms for times). """ return { 'rate': float(self.rate), 'gamma_shape': int(self.gamma_shape), 'n_proc': int(self.n_proc), 'start': float(self.start), 'stop': float(self.stop), 'origin': float(self.origin), }
def _sample_poisson(self, lam: float) -> int: return int(self._rng.poisson(lam)) def _sample_binomial(self, n: int, p: float) -> int: return int(self._rng.binomial(n, p)) def _update_internal_states(self, occ_row: np.ndarray, transition_prob: float, ditype=None) -> int: n_bins = occ_row.size if ditype is None: ditype = brainstate.environ.ditype() n_trans = np.zeros(n_bins, dtype=ditype) for i in range(n_bins): occ_i = int(occ_row[i]) if occ_i <= 0: continue use_poisson_approx = ( (occ_i >= 100 and transition_prob <= 0.01) or (occ_i >= 500 and transition_prob * occ_i <= 0.1) ) if use_poisson_approx: n_i = self._sample_poisson(transition_prob * occ_i) if n_i > occ_i: n_i = occ_i else: # NEST uses std::binomial_distribution directly. # Clamp p numerically to avoid invalid values in Python RNG. if transition_prob <= 0.0: n_i = 0 elif transition_prob >= 1.0: n_i = occ_i else: n_i = self._sample_binomial(occ_i, transition_prob) n_trans[i] = int(n_i) # Vectorized occupation update: subtract from source bins, # add to destination bins (cyclic: last phase wraps to phase 0). occ_row -= n_trans occ_row[1:] += n_trans[:-1] occ_row[0] += int(n_trans[-1]) return int(n_trans[-1])
[docs] def update(self): r"""Advance one simulation step and return per-train spike multiplicity. The method lazily initializes state, refreshes timing/probability cache when ``dt`` changes, applies the active-window test, then updates each train's occupation vector using NEST-equivalent transition logic. Returns ------- out : jax.Array JAX array with dtype ``int64`` and shape ``self.varshape``. Each element is the number of emitted spikes for one output train in the current step. Raises ------ ValueError If cached times fail simulation-grid consistency checks during cache refresh. KeyError If required simulation context values (for example ``dt``) are unavailable in ``brainstate.environ``. """ if not hasattr(self, 'occ'): self.init_state() if not np.isfinite(self._dt_cache_ms): self._refresh_runtime_cache(self._dt_ms()) dt_ms = self._dt_cache_ms ditype = brainstate.environ.ditype() if self.rate <= 0.0 or self._num_targets == 0: return jnp.zeros(self.varshape, dtype=ditype) curr_step = self._time_to_step(self._current_time_ms(), dt_ms) if not self._is_active(curr_step): return jnp.zeros(self.varshape, dtype=ditype) occ = np.asarray(self.occ.value, dtype=ditype).copy() counts = np.zeros(self._num_targets, dtype=ditype) for idx in range(self._num_targets): counts[idx] = self._update_internal_states(occ[idx], self._transition_prob, ditype) self.occ.value = occ return jnp.array(counts.reshape(self.varshape), dtype=ditype)