# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
import brainstate
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
# Public API of this module: only the generator class is exported.
__all__ = [
    'sinusoidal_gamma_generator',
]

# Unique sentinel used by ``set()`` to distinguish "argument not supplied"
# from an explicit ``None`` (``None`` is a meaningful value for ``stop``).
_UNSET = object()
class sinusoidal_gamma_generator(NESTDevice):
r"""Sinusoidally modulated gamma spike generator compatible with NEST.
Description
-----------
``sinusoidal_gamma_generator`` re-implements NEST's stimulation device of
the same name. It emits binary spikes from an inhomogeneous gamma renewal
process whose instantaneous rate is sinusoidally modulated.
**1. Instantaneous-rate model**
The internal rate in spikes/ms is
.. math::
\lambda(t) = r + a \sin(\omega t + \phi),
with parameter-to-symbol conversion:
- :math:`r = \mathrm{rate}/1000`,
- :math:`a = \mathrm{amplitude}/1000`,
- :math:`\omega = 2\pi \cdot \mathrm{frequency}/1000` (rad/ms),
- :math:`\phi = \mathrm{phase}\cdot\pi/180` (rad).
The validated constraint ``0 <= amplitude <= rate`` guarantees
:math:`\lambda(t) \ge 0` for all :math:`t`.
**2. Renewal integral, closed-form increment, and hazard**
For gamma order :math:`k = \mathrm{order}` and train-specific renewal
origin :math:`t_0`, define the scaled integrated hazard as
.. math::
\Lambda(t) = k \int_{t_0}^{t} \lambda(s)\,ds.
The implementation keeps ``t0_ms`` and ``Lambda_t0`` as per-train state
variables and advances :math:`\Lambda` each step via the closed-form
increment computed in :meth:`_delta_lambda`:
.. math::
\Delta\Lambda = k r (t_b - t_a)
- \frac{k a}{\omega}\Bigl[
\cos(\omega t_b + \phi) - \cos(\omega t_a + \phi)
\Bigr].
When ``amplitude == 0`` or ``frequency == 0`` (i.e. :math:`\omega = 0`),
the cosine term is omitted and :math:`\Delta\Lambda = k r (t_b - t_a)`,
which avoids division by zero and recovers the homogeneous Poisson limit
(:math:`k = 1`) or homogeneous gamma limit (:math:`k > 1`).
The per-step hazard (already multiplied by ``dt``) evaluated at time
:math:`t` is
.. math::
h(t) = \Delta t \cdot
\frac{k\,\lambda(t)\,\Lambda(t)^{k-1}\,e^{-\Lambda(t)}}
{\Gamma(k,\,\Lambda(t))},
where :math:`\Gamma(k, \Lambda)` is the upper incomplete gamma function
evaluated via ``jax.lax.igammac`` and ``math.gamma``. The ratio
:math:`h(t)` approximates :math:`\Pr(\text{spike in } [t, t+\Delta t))`
under the gamma renewal model.
**3. Update ordering and activity-window semantics**
Each call to :meth:`update` mirrors the ordering in NEST
``models/sinusoidal_gamma_generator.cpp``:
1. Evaluate time at the right edge of the current step:
``t_eval = (step + 1) * dt``.
2. Compute :math:`\lambda(t_{\mathrm{eval}})` and cache the value as the
step-end instantaneous rate in spikes/s (accessible via
:meth:`get_recorded_rate`).
3. If the generator is active and :math:`\lambda(t_{\mathrm{eval}}) > 0`,
compute the per-train hazard and sample Bernoulli decisions.
4. Reset ``t0_ms`` and ``Lambda_t0`` to ``t_eval`` / ``0`` for any train
that fired.
5. Return binary spike outputs as ``int64`` with shape ``self.varshape``.
NEST spike-generator activity semantics use the half-open-right window
.. math::
t_{\min} < n \le t_{\max},
where :math:`n` is the current integer step index,
``t_min = round((origin + start) / dt)``, and
``t_max = round((origin + stop) / dt)`` after projection to grid steps.
**4. Piecewise-integral semantics on parameter changes**
When :meth:`set` is called after initialization, the existing per-train
renewal state is first advanced to the change time :math:`t_c` using the
*previous* process parameters, then future increments use the *new*
parameters:
.. math::
\Lambda(t) = \Lambda_{\mathrm{old}}(t_c)
+ k_{\mathrm{new}} \int_{t_c}^{t} \lambda_{\mathrm{new}}(s)\,ds.
This preserves renewal history across parameter updates, matching NEST
``SetStatus`` behavior.
**5. Assumptions, constraints, and computational implications**
- Public parameters are scalarized to ``float64`` (or ``int`` for
``rng_seed``); non-scalar inputs raise :class:`ValueError`.
- Enforced constraints: ``order >= 1``, ``0 <= amplitude <= rate``,
and ``stop >= start``.
- When ``dt`` is available at construction time, finite
``origin`` / ``start`` / ``stop`` must be representable on the
simulation grid (absolute tolerance ``1e-12`` in ``time / dt`` ratio).
- ``individual_spike_trains=True`` allocates one independent renewal
state per element of ``self.varshape`` and draws independent Bernoulli
samples; ``individual_spike_trains=False`` maintains one shared renewal
state and broadcasts the single Bernoulli draw to all outputs.
- Per-step runtime is :math:`O(n_{\mathrm{trains}})` for hazard
evaluation and sampling, with memory
:math:`O(n_{\mathrm{trains}})` for ``t0_ms`` and ``Lambda_t0``.
- At most one spike per train can be emitted per step because spike
decisions are independent Bernoulli trials against the per-step hazard.
Parameters
----------
in_size : Size, optional
Output size specification for :class:`brainstate.nn.Dynamics`.
``self.varshape`` derived from ``in_size`` is the exact output shape
of :meth:`update`; each element corresponds to one emitted train.
Default is ``1``.
rate : ArrayLike, optional
Scalar mean firing rate in spikes/s (Hz), shape ``()`` after
conversion. Accepted as a scalar ``ArrayLike`` or a
:class:`saiunit.Quantity` convertible to ``u.Hz``.
Must satisfy ``0 <= amplitude <= rate``.
Default is ``0.0 * u.Hz``.
amplitude : ArrayLike, optional
Scalar sinusoidal modulation amplitude in spikes/s (Hz), shape ``()``
after conversion. Must satisfy ``0 <= amplitude <= rate`` after
conversion; the constraint ensures :math:`\lambda(t) \ge 0`.
Default is ``0.0 * u.Hz``.
frequency : ArrayLike, optional
Scalar modulation frequency in Hz, shape ``()`` after conversion.
Internally mapped to angular frequency
:math:`\omega = 2\pi f / 1000` (rad/ms).
Default is ``0.0 * u.Hz``.
phase : ArrayLike, optional
Scalar initial phase in degrees, shape ``()`` after conversion;
internally converted to radians as :math:`\phi = \phi_{\deg} \pi / 180`.
Default is ``0.0``.
order : ArrayLike, optional
Scalar gamma renewal order :math:`k`, shape ``()`` after conversion.
Must satisfy ``order >= 1``; order ``1`` recovers an inhomogeneous
Poisson process.
Default is ``1.0``.
individual_spike_trains : bool, optional
Spike-generation mode selector.
If ``True``, each output index in ``self.varshape`` maintains
independent renewal state ``(t0_ms, Lambda_t0)`` and receives
independent Bernoulli draws.
If ``False``, one shared renewal process is maintained and the same
binary spike decision is broadcast to all outputs.
Default is ``True``.
start : ArrayLike, optional
Scalar relative activation start time in ms, shape ``()`` after
conversion. Effective lower activity bound is ``origin + start``;
the bound is exclusive in step space (``t_min_step < curr_step``).
Default is ``0.0 * u.ms``.
stop : ArrayLike or None, optional
Scalar relative deactivation stop time in ms, shape ``()`` after
conversion. ``None`` maps to ``+inf``. Effective upper activity bound
is ``origin + stop``; the bound is inclusive in step space
(``curr_step <= t_max_step``). Must satisfy ``stop >= start`` after
conversion.
Default is ``None``.
origin : ArrayLike, optional
Scalar origin offset in ms, shape ``()`` after conversion. Added to
``start`` and ``stop`` to obtain the absolute activity bounds
``t_min`` and ``t_max``.
Default is ``0.0 * u.ms``.
rng_seed : int, optional
Seed used to initialize ``jax.random.PRNGKey`` during
:meth:`init_state` and lazy initialization in :meth:`update`.
Default is ``0``.
name : str or None, optional
Optional node name passed to :class:`brainstate.nn.Dynamics`.
Default is ``None``.
Parameter Mapping
-----------------
.. list-table:: Parameter mapping to model symbols
:header-rows: 1
:widths: 22 18 20 40
* - Parameter
- Default
- Math symbol
- Semantics
* - ``rate``
- ``0.0 * u.Hz``
- :math:`r`
- Baseline firing-rate term in spikes/ms after division by ``1000``.
* - ``amplitude``
- ``0.0 * u.Hz``
- :math:`a`
- Sinusoidal modulation amplitude in spikes/ms after division by ``1000``.
* - ``frequency``
- ``0.0 * u.Hz``
- :math:`f`
- Frequency in Hz mapped to :math:`\omega = 2\pi f/1000` (rad/ms).
* - ``phase``
- ``0.0``
- :math:`\phi`
- Phase in degrees mapped to radians as :math:`\phi_{\deg}\pi/180`.
* - ``order``
- ``1.0``
- :math:`k`
- Gamma renewal order; ``1`` = inhomogeneous Poisson.
* - ``start``
- ``0.0 * u.ms``
- :math:`t_{\mathrm{start,rel}}`
- Relative exclusive lower bound of activity window.
* - ``stop``
- ``None``
- :math:`t_{\mathrm{stop,rel}}`
- Relative inclusive upper bound; ``None`` maps to :math:`+\infty`.
* - ``origin``
- ``0.0 * u.ms``
- :math:`t_{\mathrm{origin}}`
- Global time offset added to ``start`` and ``stop``.
* - ``in_size``
- ``1``
- —
- Defines ``self.varshape`` and the total output train count.
* - ``individual_spike_trains``
- ``True``
- —
- Independent per-output renewal states vs. shared broadcast process.
* - ``rng_seed``
- ``0``
- —
- Seed for JAX random key initialization and splitting.
Raises
------
ValueError
If scalar-conversion fails due to non-scalar inputs; if
``0 <= amplitude <= rate`` is violated; if ``order < 1``; if
``stop < start``; or if finite ``origin`` / ``start`` / ``stop`` are
not multiples of the simulation resolution when ``dt`` is available.
TypeError
If provided values cannot be converted to numeric values or to the
required units (e.g. a non-convertible ``u.Hz`` or ``u.ms`` quantity).
KeyError
At runtime in :meth:`update`, if required simulation-context entries
(notably ``dt``) are unavailable from ``brainstate.environ``.
Notes
-----
- Hazard values are computed in ``float64``; tiny negative
:math:`\Lambda` values arising from floating-point roundoff are clamped
to zero before hazard evaluation.
- The value returned by :meth:`get_recorded_rate` is the step-end
instantaneous rate in spikes/s, matching NEST's ``rate`` recordable.
- Renewal state is revalidated against the timing grid whenever ``dt``
changes between :meth:`update` calls.
Examples
--------
Simulate a 2×3 array of independent sinusoidally modulated gamma trains:
.. code-block:: python
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... gen = brainpy.state.sinusoidal_gamma_generator(
... in_size=(2, 3),
... rate=50.0 * u.Hz,
... amplitude=20.0 * u.Hz,
... frequency=8.0 * u.Hz,
... phase=30.0,
... order=3.0,
... start=5.0 * u.ms,
... stop=80.0 * u.ms,
... rng_seed=9,
... )
... with brainstate.environ.context(t=12.0 * u.ms):
... spikes = gen.update()
... _ = spikes.shape
Use ``individual_spike_trains=False`` and update parameters at runtime:
.. code-block:: python
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... gen = brainpy.state.sinusoidal_gamma_generator(
... individual_spike_trains=False
... )
... gen.set(rate=40.0 * u.Hz, amplitude=10.0 * u.Hz, order=2.0)
... params = gen.get()
... _ = params['rate'], params['order']
See Also
--------
sinusoidal_poisson_generator : Sinusoidally modulated Poisson generator.
gamma_sup_generator : Superposition of independent gamma-renewal processes.
References
----------
.. [1] NEST source:
``models/sinusoidal_gamma_generator.h`` and
``models/sinusoidal_gamma_generator.cpp``.
.. [2] NEST docs:
https://nest-simulator.readthedocs.io/en/stable/models/sinusoidal_gamma_generator.html
.. [3] NEST source:
``nestkernel/stimulation_device.cpp``.
"""
__module__ = 'brainpy.state'

def __init__(
    self,
    in_size: Size = 1,
    rate: ArrayLike = 0. * u.Hz,
    amplitude: ArrayLike = 0. * u.Hz,
    frequency: ArrayLike = 0. * u.Hz,
    phase: ArrayLike = 0.0,
    order: ArrayLike = 1.0,
    individual_spike_trains: bool = True,
    start: ArrayLike = 0. * u.ms,
    stop: ArrayLike = None,
    origin: ArrayLike = 0. * u.ms,
    rng_seed: int = 0,
    name: str | None = None,
):
    # See the class docstring for full parameter semantics. Ordering below
    # matters: scalarize public parameters first, then validate them jointly,
    # then derive internal caches from the validated values.
    super().__init__(in_size=in_size, name=name)
    # Public parameters stored as plain Python floats (Hz / ms / unitless).
    self.rate = self._to_scalar_rate_hz(rate)
    self.amplitude = self._to_scalar_rate_hz(amplitude)
    self.frequency = self._to_scalar_rate_hz(frequency)
    self.phase = self._to_scalar_float(phase, name='phase')
    self.order = self._to_scalar_float(order, name='order')
    self.individual_spike_trains = bool(individual_spike_trains)
    self.start = self._to_scalar_time_ms(start)
    # stop=None means "never deactivate" and is encoded as +inf ms.
    self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
    self.origin = self._to_scalar_time_ms(origin)
    self.rng_seed = int(rng_seed)
    # Enforce order >= 1, 0 <= amplitude <= rate, and stop >= start.
    self._validate_parameters(
        rate_hz=self.rate,
        amplitude_hz=self.amplitude,
        order=self.order,
        start_ms=self.start,
        stop_ms=self.stop,
    )
    # One renewal process per output element, or a single shared process.
    self._num_targets = int(np.prod(self.varshape))
    self._num_trains = self._num_targets if self.individual_spike_trains else 1
    # Derived per-ms process parameters; placeholders are overwritten by
    # _refresh_process_parameter_cache() below.
    self._rate_per_ms = 0.0
    self._amplitude_per_ms = 0.0
    self._om_rad_per_ms = 0.0
    self._phi_rad = 0.0
    self._proc_params = (0.0, 0.0, 1.0, 0.0, 0.0)
    self._proc_params_prev = self._proc_params
    self._refresh_process_parameter_cache()
    # Step-space activity window; fully populated once dt is known.
    self._dt_cache_ms = np.nan
    self._t_min_step = 0
    self._t_max_step = np.iinfo(np.int64).max
    dt_ms = self._maybe_dt_ms()
    if dt_ms is not None:
        self._refresh_timing_cache(dt_ms)
@staticmethod
def _to_scalar_time_ms(value: ArrayLike) -> float:
    """Convert *value* to a plain scalar ``float`` in milliseconds.

    ``u.Quantity`` inputs are converted to ms; anything else is taken as a
    raw number already in ms. Raises ``ValueError`` for non-scalar input.
    """
    ftype = brainstate.environ.dftype()
    raw = value.to_decimal(u.ms) if isinstance(value, u.Quantity) else u.math.asarray(value)
    as_array = np.asarray(raw, dtype=ftype)
    if as_array.size != 1:
        raise ValueError('Time parameters must be scalar.')
    return float(as_array.reshape(()))
@staticmethod
def _to_scalar_rate_hz(value: ArrayLike) -> float:
    """Convert *value* to a plain scalar ``float`` in spikes/s (Hz).

    ``u.Quantity`` inputs are converted to Hz; anything else is taken as a
    raw number already in Hz. Raises ``ValueError`` for non-scalar input.
    """
    ftype = brainstate.environ.dftype()
    raw = value.to_decimal(u.Hz) if isinstance(value, u.Quantity) else u.math.asarray(value)
    as_array = np.asarray(raw, dtype=ftype)
    if as_array.size != 1:
        raise ValueError('Rate parameters must be scalar.')
    return float(as_array.reshape(()))
@staticmethod
def _to_scalar_float(value: ArrayLike, *, name: str) -> float:
    """Convert a dimensionless *value* to a plain scalar ``float``.

    Raises ``ValueError`` (mentioning *name*) for non-scalar input.
    """
    ftype = brainstate.environ.dftype()
    as_array = np.asarray(u.math.asarray(value, dtype=ftype), dtype=ftype)
    if as_array.size != 1:
        raise ValueError(f'{name} must be scalar.')
    return float(as_array.reshape(()))
@staticmethod
def _time_to_step(time_ms: float, dt_ms: float) -> int:
    """Project a grid-aligned time (ms) onto its integer step index."""
    nearest_step = np.rint(time_ms / dt_ms)
    return int(nearest_step)
@staticmethod
def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
    """Raise ``ValueError`` if a finite time is off the ``dt`` grid.

    Non-finite times (e.g. ``stop`` encoded as +inf) are always accepted.
    The tolerance is absolute (1e-12) on the ``time / dt`` ratio.
    """
    if not np.isfinite(time_ms):
        return
    quotient = time_ms / dt_ms
    on_grid = math.isclose(quotient, np.rint(quotient), rel_tol=0.0, abs_tol=1e-12)
    if not on_grid:
        raise ValueError(f'{name} must be a multiple of the simulation resolution.')
@staticmethod
def _validate_parameters(
    *,
    rate_hz: float,
    amplitude_hz: float,
    order: float,
    start_ms: float,
    stop_ms: float,
):
    """Enforce the joint parameter constraints, raising ``ValueError``.

    Checked in order: ``order >= 1``, ``0 <= amplitude <= rate``, and
    ``stop >= start``.
    """
    if order < 1.0:
        raise ValueError('The gamma order must be at least 1.')
    # The chained comparison also rejects NaN amplitude/rate values.
    if not (0.0 <= amplitude_hz <= rate_hz):
        raise ValueError('Rate parameters must fulfill 0 <= amplitude <= rate.')
    if stop_ms < start_ms:
        raise ValueError('stop >= start required.')
def _dt_ms(self) -> float:
    """Current simulation resolution in ms; requires ``dt`` in environ."""
    return self._to_scalar_time_ms(brainstate.environ.get_dt())
def _maybe_dt_ms(self) -> float | None:
    """Like :meth:`_dt_ms`, but returns ``None`` when ``dt`` is unset."""
    dt = brainstate.environ.get('dt', default=None)
    return None if dt is None else self._to_scalar_time_ms(dt)
def _current_time_ms(self) -> float:
    """Current simulation time in ms; ``0.0`` when ``t`` is unavailable."""
    now = brainstate.environ.get('t', default=0. * u.ms)
    return 0.0 if now is None else self._to_scalar_time_ms(now)
def _refresh_timing_cache(self, dt_ms: float):
    """Recompute the step-space activity window for resolution *dt_ms*.

    Validates that finite ``origin``/``start``/``stop`` lie on the grid,
    then caches ``_t_min_step``/``_t_max_step`` and the dt used.
    """
    for label, value in (
        ('origin', self.origin),
        ('start', self.start),
        ('stop', self.stop),
    ):
        self._assert_grid_time(label, value, dt_ms)
    self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
    # An infinite stop means "never deactivate": use the largest step index.
    self._t_max_step = (
        self._time_to_step(self.origin + self.stop, dt_ms)
        if np.isfinite(self.stop)
        else np.iinfo(np.int64).max
    )
    self._dt_cache_ms = float(dt_ms)
def _refresh_process_parameter_cache(self):
    """Map the public Hz/degree parameters onto per-ms process symbols.

    Computes r, a (spikes/ms), omega (rad/ms) and phi (rad), and packs
    them with ``order`` into the ``_proc_params`` tuple used by
    :meth:`_delta_lambda` / :meth:`_delta_lambda_jax`.
    """
    rate_per_ms = self.rate / 1000.0
    amplitude_per_ms = self.amplitude / 1000.0
    omega = self.frequency * (2.0 * math.pi / 1000.0)
    phi = self.phase * (math.pi / 180.0)
    self._rate_per_ms = rate_per_ms
    self._amplitude_per_ms = amplitude_per_ms
    self._om_rad_per_ms = omega
    self._phi_rad = phi
    self._proc_params = (omega, phi, self.order, rate_per_ms, amplitude_per_ms)
def _is_active(self, curr_step: int) -> bool:
    """NEST half-open-right window test: ``t_min < step <= t_max``."""
    return self._t_min_step < curr_step <= self._t_max_step
@staticmethod
def _delta_lambda(params: tuple[float, float, float, float, float], t_a, t_b):
    """Closed-form increment of the scaled integrated hazard.

    Returns ``k * integral_{t_a}^{t_b} lambda(s) ds`` for the sinusoidal
    rate; the cosine term is skipped when amplitude or omega is zero
    (homogeneous limit, avoids division by zero). ``t_a`` may be a
    scalar or an array of per-train origins; ``t_b`` is a scalar.
    """
    om, phi, order, rate, amplitude = params
    ftype = brainstate.environ.dftype()
    origin = np.asarray(t_a, dtype=ftype)
    # Zero-length interval: nothing accumulates.
    if origin.ndim == 0:
        if float(origin) == float(t_b):
            return np.asarray(0.0, dtype=ftype)
    elif np.all(origin == float(t_b)):
        return np.zeros_like(origin, dtype=ftype)
    increment = order * rate * (t_b - origin)
    if abs(amplitude) > 0.0 and abs(om) > 0.0:
        oscillation = np.cos(om * t_b + phi) - np.cos(om * origin + phi)
        increment += -order * amplitude / om * oscillation
    return increment
def _accumulate_lambda_to_time(self, t_ms: float):
    """Advance every train's ``(t0, Lambda)`` to *t_ms*.

    Uses the *previous* process parameters, as required by the
    piecewise-integral semantics on parameter changes. No-op when there
    are no trains.
    """
    if self._num_trains == 0:
        return
    ftype = brainstate.environ.dftype()
    origins = np.asarray(self.t0_ms.value, dtype=ftype).reshape(-1).copy()
    accumulated = np.asarray(self.Lambda_t0.value, dtype=ftype).reshape(-1).copy()
    step = self._delta_lambda(self._proc_params_prev, origins, t_ms)
    accumulated += np.asarray(step, dtype=ftype)
    origins.fill(t_ms)
    self.t0_ms.value = origins
    self.Lambda_t0.value = accumulated
def _resize_train_state(self, now_ms: float, new_num_trains: int):
    """Resize per-train renewal state to *new_num_trains* entries.

    Shrinking keeps the leading trains unchanged; growing appends fresh
    trains that start at *now_ms* with zero accumulated hazard.
    """
    ftype = brainstate.environ.dftype()
    t0_flat = np.asarray(self.t0_ms.value, dtype=ftype).reshape(-1)
    lam_flat = np.asarray(self.Lambda_t0.value, dtype=ftype).reshape(-1)
    current = t0_flat.size
    if new_num_trains == current:
        return
    if new_num_trains < current:
        self.t0_ms.value = t0_flat[:new_num_trains].copy()
        self.Lambda_t0.value = lam_flat[:new_num_trains].copy()
        return
    extra = new_num_trains - current
    self.t0_ms.value = np.concatenate([t0_flat, np.full(extra, now_ms, dtype=ftype)])
    self.Lambda_t0.value = np.concatenate([lam_flat, np.zeros(extra, dtype=ftype)])
def init_state(self, batch_size: int | None = None, **kwargs):
    r"""Initialize random key and per-train renewal state.

    Allocates four :class:`brainstate.ShortTermState` variables:

    - ``rng_key``: a JAX ``PRNGKey`` seeded from ``self.rng_seed``.
    - ``t0_ms``: per-train renewal origin, initialized to the current
      simulation time (length ``self._num_trains``).
    - ``Lambda_t0``: per-train accumulated scaled hazard at ``t0_ms``,
      initialized to zero (length ``self._num_trains``).
    - ``_recorded_rate_hz``: cached step-end instantaneous rate in
      spikes/s, initialized to ``0.0``.

    The timing cache (``_t_min_step``, ``_t_max_step``) is also refreshed
    if ``dt`` is available in the current ``brainstate.environ`` context.

    Parameters
    ----------
    batch_size : int or None, optional
        Unused placeholder kept for :class:`brainstate.nn.Dynamics`
        signature compatibility. Default is ``None``.
    **kwargs
        Unused extra keyword arguments; silently ignored.

    Raises
    ------
    ValueError
        If finite ``origin`` / ``start`` / ``stop`` do not lie on the
        simulation grid when ``dt`` is available.
    """
    del batch_size, kwargs
    self.rng_key = brainstate.ShortTermState(jax.random.PRNGKey(self.rng_seed))
    curr_t_ms = self._current_time_ms()
    dftype = brainstate.environ.dftype()
    self.t0_ms = brainstate.ShortTermState(
        np.full(self._num_trains, curr_t_ms, dtype=dftype)
    )
    self.Lambda_t0 = brainstate.ShortTermState(
        np.zeros(self._num_trains, dtype=dftype)
    )
    self._recorded_rate_hz = brainstate.ShortTermState(jnp.asarray(0.0, dtype=dftype))
    # Future Lambda increments start from the parameters now in effect.
    self._proc_params_prev = self._proc_params
    dt_ms = self._maybe_dt_ms()
    if dt_ms is not None:
        self._refresh_timing_cache(dt_ms)
def set(
    self,
    *,
    rate: ArrayLike | object = _UNSET,
    amplitude: ArrayLike | object = _UNSET,
    frequency: ArrayLike | object = _UNSET,
    phase: ArrayLike | object = _UNSET,
    order: ArrayLike | object = _UNSET,
    individual_spike_trains: bool | object = _UNSET,
    start: ArrayLike | object = _UNSET,
    stop: ArrayLike | object = _UNSET,
    origin: ArrayLike | object = _UNSET,
):
    r"""Update public parameters and refresh the process and timing caches.

    Any parameter left at its sentinel value ``_UNSET`` retains its
    current value. When called after :meth:`init_state`, the internal
    renewal state is first advanced to the current simulation time using
    the *previous* process parameters before switching to the new ones,
    preserving the piecewise-integral semantics described in the class
    docstring. If ``individual_spike_trains`` changes in a way that alters
    the required number of trains, ``t0_ms`` and ``Lambda_t0`` are
    grown (new trains start fresh) or truncated accordingly.

    Parameters
    ----------
    rate : ArrayLike, optional
        New scalar mean rate in spikes/s (Hz), shape ``()`` after
        conversion. ``_UNSET`` retains the current value.
    amplitude : ArrayLike, optional
        New scalar modulation amplitude in spikes/s (Hz), shape ``()``
        after conversion. ``_UNSET`` retains the current value.
    frequency : ArrayLike, optional
        New scalar modulation frequency in Hz, shape ``()`` after
        conversion. ``_UNSET`` retains the current value.
    phase : ArrayLike, optional
        New scalar modulation phase in degrees, shape ``()`` after
        conversion. ``_UNSET`` retains the current value.
    order : ArrayLike, optional
        New scalar gamma order :math:`k`, shape ``()`` after conversion.
        ``_UNSET`` retains the current value.
    individual_spike_trains : bool, optional
        New spike-generation mode. ``_UNSET`` retains the current value.
    start : ArrayLike, optional
        New scalar relative activation start time in ms, shape ``()``
        after conversion. ``_UNSET`` retains the current value.
    stop : ArrayLike, None, or sentinel, optional
        New scalar relative stop time in ms, shape ``()`` after
        conversion; explicit ``None`` maps to ``+inf``. ``_UNSET``
        retains the current value.
    origin : ArrayLike, optional
        New scalar origin offset in ms, shape ``()`` after conversion.
        ``_UNSET`` retains the current value.

    Raises
    ------
    ValueError
        If scalar conversion fails due to non-scalar inputs; if the
        constraints ``order >= 1``, ``0 <= amplitude <= rate``, or
        ``stop >= start`` are violated after resolving new values; or if
        finite timing parameters do not lie on the simulation grid when
        ``dt`` is available.
    TypeError
        If numeric or unit conversion fails for any supplied input.
    """
    # Advance existing renewal state to "now" under the OLD parameters so
    # the piecewise-integral semantics hold across the change point.
    has_state = hasattr(self, 't0_ms')
    now_ms = self._current_time_ms() if has_state else 0.0
    if has_state:
        self._accumulate_lambda_to_time(now_ms)
    # Resolve each new value; _UNSET keeps the current one.
    new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate)
    new_amplitude = (
        self.amplitude if amplitude is _UNSET else self._to_scalar_rate_hz(amplitude)
    )
    new_frequency = (
        self.frequency if frequency is _UNSET else self._to_scalar_rate_hz(frequency)
    )
    new_phase = self.phase if phase is _UNSET else self._to_scalar_float(phase, name='phase')
    new_order = self.order if order is _UNSET else self._to_scalar_float(order, name='order')
    new_individual = (
        self.individual_spike_trains
        if individual_spike_trains is _UNSET
        else bool(individual_spike_trains)
    )
    new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start)
    if stop is _UNSET:
        new_stop = self.stop
    elif stop is None:
        new_stop = np.inf
    else:
        new_stop = self._to_scalar_time_ms(stop)
    new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin)
    # Validate jointly before committing, so a bad call leaves the
    # public parameters unchanged.
    self._validate_parameters(
        rate_hz=new_rate,
        amplitude_hz=new_amplitude,
        order=new_order,
        start_ms=new_start,
        stop_ms=new_stop,
    )
    self.rate = new_rate
    self.amplitude = new_amplitude
    self.frequency = new_frequency
    self.phase = new_phase
    self.order = new_order
    self.individual_spike_trains = new_individual
    self.start = new_start
    self.stop = new_stop
    self.origin = new_origin
    self._num_trains = self._num_targets if self.individual_spike_trains else 1
    self._refresh_process_parameter_cache()
    if has_state:
        self._resize_train_state(now_ms, self._num_trains)
    # From here on, Lambda increments use the NEW parameters.
    self._proc_params_prev = self._proc_params
    dt_ms = self._maybe_dt_ms()
    if dt_ms is not None:
        self._refresh_timing_cache(dt_ms)
def get(self) -> dict:
    r"""Return current public parameters as plain Python scalars.

    Returns
    -------
    out : dict
        Mapping of parameter names to their current values. Keys and
        value types are:

        - ``'rate'`` : ``float`` — mean firing rate in spikes/s.
        - ``'amplitude'`` : ``float`` — sinusoidal modulation amplitude
          in spikes/s.
        - ``'frequency'`` : ``float`` — modulation frequency in Hz.
        - ``'phase'`` : ``float`` — modulation phase in degrees.
        - ``'order'`` : ``float`` — gamma renewal order :math:`k`.
        - ``'individual_spike_trains'`` : ``bool`` — spike-generation
          mode flag.
        - ``'start'`` : ``float`` — relative activation start in ms.
        - ``'stop'`` : ``float`` — relative deactivation stop in ms
          (``float('inf')`` when no stop was set).
        - ``'origin'`` : ``float`` — time-origin offset in ms.
    """
    return {
        'rate': float(self.rate),
        'frequency': float(self.frequency),
        'phase': float(self.phase),
        'amplitude': float(self.amplitude),
        'order': float(self.order),
        'individual_spike_trains': bool(self.individual_spike_trains),
        'start': float(self.start),
        'stop': float(self.stop),
        'origin': float(self.origin),
    }
def get_recorded_rate(self) -> float:
    r"""Return the latest step-end instantaneous rate in spikes/s.

    The value is updated by :meth:`update` at each simulation step to
    :math:`\lambda(t_{\mathrm{eval}}) \times 1000` spikes/s, where
    :math:`t_{\mathrm{eval}} = (\mathrm{step} + 1) \times dt` is the
    right edge of the current step. This matches NEST's ``rate``
    recordable quantity.

    Returns
    -------
    out : float
        Most recently cached instantaneous rate in spikes/s. Returns
        ``0.0`` if :meth:`init_state` has not been called yet.
    """
    if not hasattr(self, '_recorded_rate_hz'):
        return 0.0
    dftype = brainstate.environ.dftype()
    return float(np.asarray(self._recorded_rate_hz.value, dtype=dftype).reshape(()))
def _sample_uniform(self, shape=()):
    """Draw uniform samples in ``[0, 1)``, advancing the stored PRNG key."""
    next_key, draw_key = jax.random.split(self.rng_key.value)
    self.rng_key.value = next_key
    ftype = brainstate.environ.dftype()
    return jax.random.uniform(draw_key, shape=shape, dtype=ftype)
def _compute_hazard(self, lambda_val: np.ndarray, rate_per_ms: float, dt_ms: float) -> np.ndarray:
    """NumPy per-step hazard: ``dt * k * lambda(t) * Lam^(k-1) * exp(-Lam) / Gamma(k, Lam)``.

    Entries with genuinely negative ``lambda_val`` or a vanishing
    denominator yield a hazard of ``0.0``. See the class docstring,
    section 2, for the renewal-theory derivation.
    """
    dftype = brainstate.environ.dftype()
    hazard = np.zeros_like(lambda_val, dtype=dftype)
    # Guard tiny negative values caused by floating-point roundoff only.
    tiny_neg = np.logical_and(lambda_val < 0.0, lambda_val > -1e-15)
    if np.any(tiny_neg):
        # Copy before clamping so the caller's array is never mutated.
        lambda_val = lambda_val.copy()
        lambda_val[tiny_neg] = 0.0
    valid = lambda_val >= 0.0
    if not np.any(valid):
        return hazard
    lam = lambda_val[valid]
    # Regularized upper incomplete gamma Q(k, Lam) = Gamma(k, Lam) / Gamma(k).
    q = np.asarray(
        jax.lax.igammac(
            jnp.asarray(self.order, dtype=dftype),
            jnp.asarray(lam, dtype=dftype),
        ),
        dtype=dftype,
    )
    denom = math.gamma(self.order) * q
    numer = (
        dt_ms
        * self.order
        * rate_per_ms
        * np.power(lam, self.order - 1.0)
        * np.exp(-lam)
    )
    # Masked divide: entries where denom == 0 keep the zero fill value.
    hazard_valid = np.divide(
        numer,
        denom,
        out=np.zeros_like(numer, dtype=dftype),
        where=denom > 0.0,
    )
    hazard[valid] = hazard_valid
    return hazard
@staticmethod
def _delta_lambda_jax(params: tuple, t_a, t_b):
    """JAX-traceable version of :meth:`_delta_lambda`.

    Works under ``jax.jit`` / ``for_loop``; the amplitude/omega guard is
    a Python-level (static) branch, so zero-modulation cases never emit
    the cosine term into the trace.
    """
    om, phi, order, rate, amplitude = params
    base = order * rate * (t_b - t_a)
    if not (abs(amplitude) > 0.0 and abs(om) > 0.0):
        return base
    oscillation = jnp.cos(om * t_b + phi) - jnp.cos(om * t_a + phi)
    return base - order * amplitude / om * oscillation
def _compute_hazard_jax(self, lambda_val, rate_per_ms, dt_ms: float):
    """JAX-traceable version of :meth:`_compute_hazard`.

    Computes ``dt * k * lambda(t) * Lam^(k-1) * exp(-Lam) / Gamma(k, Lam)``
    with the same clamping rules; works under ``jax.jit`` / ``for_loop``.
    """
    ftype = brainstate.environ.dftype()
    lam_raw = jnp.asarray(lambda_val, dtype=ftype)
    # Clamp tiny negatives (floating-point roundoff); large negatives stay.
    rounding_noise = (lam_raw < 0.0) & (lam_raw > -1e-15)
    lam = jnp.where(rounding_noise, jnp.zeros_like(lam_raw), lam_raw)
    lam_pos = jnp.maximum(lam, 0.0)
    upper_q = jax.lax.igammac(jnp.asarray(self.order, dtype=ftype), lam_pos)
    denominator = math.gamma(self.order) * upper_q
    numerator = (
        dt_ms
        * self.order
        * rate_per_ms
        * jnp.power(lam_pos, self.order - 1.0)
        * jnp.exp(-lam_pos)
    )
    ratio = jnp.where(denominator > 0.0, numerator / denominator, jnp.zeros_like(numerator))
    # Genuinely negative Lambda (not rounding noise) yields zero hazard.
    return jnp.where(lam >= 0.0, ratio, jnp.zeros_like(ratio))
def update(self):
    r"""Advance one simulation step and emit binary spike events.

    Reads the current time ``t`` and resolution ``dt`` from
    ``brainstate.environ``, lazily calls :meth:`init_state` if state has
    not been allocated, and refreshes the timing cache if ``dt`` has
    changed since the last call. The step-end time
    ``t_eval = (step + 1) * dt`` is used for rate evaluation and
    :math:`\Lambda` accumulation. Trains outside the active window or
    with :math:`\lambda(t_{\mathrm{eval}}) \le 0` receive zero spikes.

    Returns
    -------
    out : jax.Array
        Integer JAX array with shape ``self.varshape``. Each element is
        ``0`` or ``1``, giving the binary spike decision for the
        corresponding output train in the current step. When
        ``individual_spike_trains=False``, all elements share the same
        value.

    Raises
    ------
    KeyError
        If ``dt`` is not available in the current ``brainstate.environ``
        context.
    ValueError
        If timing-parameter grid validation fails after the simulation
        resolution changes between calls.
    """
    if not hasattr(self, 'rng_key'):
        self.init_state()
    dt_ms = self._dt_ms()
    # Revalidate the activity window whenever dt changes between calls.
    if (not np.isfinite(self._dt_cache_ms)) or (
        not math.isclose(dt_ms, self._dt_cache_ms, rel_tol=0.0, abs_tol=1e-15)
    ):
        self._refresh_timing_cache(dt_ms)
    ditype = brainstate.environ.ditype()
    dftype = brainstate.environ.dftype()
    # Get current time as a JAX-compatible scalar so this method works under
    # jax.jit / brainstate.transform.for_loop tracing.
    t = brainstate.environ.get('t', default=0. * u.ms)
    if isinstance(t, u.Quantity):
        t_ms = t.to_decimal(u.ms)
    else:
        t_ms = jnp.asarray(t, dtype=dftype)
    curr_step = jnp.rint(t_ms / dt_ms).astype(jnp.int64)
    t_eval_ms = (curr_step + 1) * dt_ms
    # Instantaneous rate at t_eval (jnp.sin handles traced t_eval_ms).
    sin_val = jnp.sin(
        jnp.asarray(self._om_rad_per_ms * t_eval_ms + self._phi_rad, dtype=dftype)
    )
    rate_per_ms = self._rate_per_ms + self._amplitude_per_ms * sin_val
    # Cache the step-end rate (always updated, even during inactivity).
    self._recorded_rate_hz.value = jnp.asarray(rate_per_ms * 1000.0, dtype=dftype)
    # Static early exits that don't depend on traced values.
    if self._num_trains == 0:
        return jnp.zeros(self.varshape, dtype=ditype)
    if self._rate_per_ms == 0.0 and self._amplitude_per_ms == 0.0:
        return jnp.zeros(self.varshape, dtype=ditype)
    # JAX-compatible activity check (works with traced curr_step).
    is_active = (self._t_min_step < curr_step) & (curr_step <= self._t_max_step)
    # Fetch renewal state as JAX arrays.
    t0 = jnp.asarray(self.t0_ms.value, dtype=dftype)
    lam0 = jnp.asarray(self.Lambda_t0.value, dtype=dftype)
    delta = self._delta_lambda_jax(self._proc_params, t0, t_eval_ms)
    lambda_eval = lam0 + delta
    hazard = self._compute_hazard_jax(lambda_eval, rate_per_ms, dt_ms)
    if self.individual_spike_trains:
        draws = self._sample_uniform(shape=(self._num_trains,))
        spikes = draws < hazard
        # Only reset renewal state for trains that spiked AND are active.
        active_spikes = jnp.where(is_active, spikes, jnp.zeros_like(spikes))
        self.t0_ms.value = jnp.where(active_spikes, jnp.full_like(t0, t_eval_ms), t0)
        self.Lambda_t0.value = jnp.where(active_spikes, jnp.zeros_like(lam0), lam0)
        return jnp.asarray(active_spikes.reshape(self.varshape), dtype=ditype)
    # Shared-train mode: one Bernoulli draw broadcast to every output.
    draw = self._sample_uniform(shape=())
    spike = draw < hazard[0]
    active_spike = is_active & spike
    self.t0_ms.value = jnp.where(active_spike, jnp.full_like(t0, t_eval_ms), t0)
    self.Lambda_t0.value = jnp.where(active_spike, jnp.zeros_like(lam0), lam0)
    spike_val = jnp.asarray(active_spike, dtype=ditype)
    return jnp.full(self.varshape, spike_val, dtype=ditype)