# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from collections import deque
from typing import Sequence
import brainstate
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
__all__ = ['pulsepacket_generator']

# Private sentinel meaning "argument not supplied" for keyword arguments of
# ``pulsepacket_generator.set``; it is needed because ``None`` is itself a
# meaningful value there (``stop=None`` maps to ``+inf``).
_UNSET = object()
class pulsepacket_generator(NESTDevice):
    r"""Gaussian pulse-packet spike generator compatible with NEST.

    Description
    -----------
    ``pulsepacket_generator`` re-implements NEST's stimulation device of the
    same name and emits integer per-step spike multiplicities.

    **1. Pulse model and grid projection**

    For each configured pulse center :math:`t_c` (ms), this model generates
    exactly ``activity`` sampled spike times per output generator:

    .. math::

       x_{i,j} \sim \mathcal{N}(t_c, \mathrm{sdev}^2),
       \quad i=1,\dots,N,\ j=1,\dots,\mathrm{activity},

    where :math:`N=\prod \mathrm{varshape}` is the number of independent
    generators. For ``sdev == 0``, the Gaussian draw degenerates to the
    deterministic value :math:`x_{i,j}=t_c`.

    Sampled times are converted to NEST-like integer tics and delivery steps:

    .. math::

       \tau = \left\lfloor x \cdot 1000 + 0.5 \right\rfloor,\qquad
       k = \left\lceil \tau / \Delta\tau \right\rceil,

    where :math:`\Delta\tau` is the resolution in tics per simulation step
    (the floor is implemented as ``int()`` truncation, which coincides with
    the floor for non-negative times). Samples with
    :math:`\tau < \tau_{\mathrm{now}}` are discarded; the remaining samples are
    queued and emitted as multiplicity counts at their delivery steps.

    **2. NEST update ordering (source-equivalent)**

    This implementation mirrors ``models/pulsepacket_generator.cpp``:

    1. Keep indices ``start_center_idx``/``stop_center_idx`` into the sorted
       ``pulse_times`` for a moving window of centers around the current time.
    2. At each update step, extend the right edge of that center window while
       ``center_time - t <= tolerance``.
    3. For each newly entered center, sample ``activity`` Gaussian times,
       keep only samples with ``sample_time >= t``, convert them to integer
       steps, and append them to a per-generator queue.
    4. Sort each queue.
    5. Pop all queued spikes whose delivery step is at or before the current
       step and return the per-step multiplicity.

    As in NEST, ``tolerance = sdev * sdev_tolerance`` (``sdev_tolerance``
    defaults to ``10``) for ``sdev > 0`` and ``tolerance = 1.0 ms`` otherwise.

    **3. Timing semantics (CURRENT_GENERATOR shift)**

    NEST classifies this model as ``CURRENT_GENERATOR`` in ``get_type()``.
    Therefore activity is evaluated with the ``StimulationDevice``
    current-generator shift:

    .. math::

       t_{\min} < (n + 2) \le t_{\max},

    where ``n`` is the current simulation step and
    ``t_min = origin + start``, ``t_max = origin + stop`` (in steps).
    This differs from regular spike generators and is intentionally preserved
    for behavioral parity.

    **4. Assumptions, constraints, and computational implications**

    Enforced constraints:

    - ``activity`` is a finite integer scalar with ``activity >= 0``.
    - ``sdev`` is a scalar in ms with ``sdev >= 0`` (NaN is rejected).
    - ``stop >= start`` after scalar conversion.
    - ``sdev_tolerance > 0``.

    Runtime constraints:

    - If ``dt`` is available, finite ``origin``, ``start``, and ``stop`` must
      be exact grid multiples (absolute tolerance ``1e-12`` in ``time / dt``).
    - ``pulse_times`` are flattened to 1-D and sorted ascending before use.

    Computational implications:

    - Let ``C_new`` be the pulse centers newly entered in one step. New pulse
      generation costs
      :math:`O(C_{\mathrm{new}} \cdot N \cdot \mathrm{activity})` sampling
      plus a per-queue sort when new events are appended.
    - Emission costs :math:`O(N + M_{\mathrm{pop}})` where
      :math:`M_{\mathrm{pop}}` is the number of queued spikes emitted in the
      step.
    - Memory is proportional to the total number of queued future spikes
      across all output generators.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification consumed by
        :class:`brainstate.nn.Dynamics`. ``self.varshape`` derived from this
        value is the exact shape returned by :meth:`update`. Each element is
        one independent output generator. Default is ``1``.
    pulse_times : Sequence[ArrayLike] or ArrayLike or None, optional
        Pulse center times in ms. Accepted inputs are array-like values
        (including sequences of :class:`saiunit.Quantity`) that flatten to
        shape ``(K,)``, or a :class:`saiunit.Quantity` convertible to
        ``u.ms``. ``None`` creates an empty schedule. Values are sorted
        internally in ascending order. Default is ``None``.
    activity : ArrayLike, optional
        Scalar integer count per pulse center, shape ``()`` after conversion.
        Parsed through a nearest-integer check with absolute tolerance
        ``1e-12`` and must satisfy ``activity >= 0``. Default is ``0``.
    sdev : ArrayLike, optional
        Scalar standard deviation in ms, shape ``()`` after conversion.
        Accepts unitful time convertible to ``u.ms`` or a plain scalar
        (interpreted as ms). Must satisfy ``sdev >= 0``. Default is
        ``0.0 * u.ms``.
    start : ArrayLike, optional
        Scalar relative start time in ms, shape ``()`` after conversion.
        Effective lower bound is ``origin + start`` under current-generator
        semantics and must be grid-representable when ``dt`` is available.
        Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative stop time in ms, shape ``()`` after conversion.
        ``None`` maps to ``+inf``. When finite, must satisfy ``stop >= start``
        and be grid-representable when ``dt`` is available.
        Default is ``None``.
    origin : ArrayLike, optional
        Scalar origin offset in ms, shape ``()`` after conversion.
        Added to ``start`` and ``stop`` to form absolute activity bounds.
        Must be grid-representable when finite and ``dt`` is available.
        Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed used to initialize ``numpy.random.default_rng`` in
        :meth:`init_state`. Default is ``0``.
    sdev_tolerance : float, optional
        Positive multiplicative factor used to compute the tolerance window
        ``sdev * sdev_tolerance`` when ``sdev > 0``. NEST default is ``10.0``.
    name : str, optional
        Optional node name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 20 18 24 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``pulse_times``
         - ``None``
         - :math:`t_c`
         - Pulse-center schedule in ms (internally sorted ascending).
       * - ``activity``
         - ``0``
         - :math:`n_{\mathrm{spk}}`
         - Number of sampled spikes generated per center and output train.
       * - ``sdev``
         - ``0.0 * u.ms``
         - :math:`\sigma_t`
         - Temporal jitter standard deviation in ms.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative lower activity bound (with CURRENT_GENERATOR shift).
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative upper activity bound; ``None`` maps to :math:`+\infty`.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global offset added to ``start`` and ``stop``.
       * - ``sdev_tolerance``
         - ``10.0``
         - :math:`\kappa`
         - Tolerance factor for center-window inclusion, :math:`\kappa \sigma_t`.
       * - ``in_size``
         - ``1``
         - -
         - Defines ``self.varshape`` / number of independent generators.
       * - ``rng_seed``
         - ``0``
         - -
         - Seed for the NumPy RNG used for Gaussian pulse sampling.

    Raises
    ------
    ValueError
        If ``activity`` is negative, non-integral or non-finite; if ``sdev``
        is negative or NaN; if ``stop < start``; if ``sdev_tolerance`` is not
        positive; if scalar conversion fails due to non-scalar input shape;
        if non-empty backend data has fewer than three values; if finite
        ``origin``/``start``/``stop`` are not grid multiples when ``dt`` is
        available; or if the simulation resolution is non-positive.
    TypeError
        If time-valued arguments cannot be converted to ``u.ms``-compatible
        values or numeric arrays.
    KeyError
        At runtime, if required simulation context entries (for example
        ``dt`` from ``brainstate.environ.get_dt()``) are unavailable.

    Notes
    -----
    - Changing ``activity`` or ``sdev`` through :meth:`set`, or passing new
      ``pulse_times``, clears queued spikes (matching NEST). If state is
      already initialized, the pulse-center window is then re-synchronized
      to the current environment time, as NEST's ``pre_run_hook`` does
      before the next run.
    - Stimulation-backend parameter order in NEST is
      ``[activity, sdev_ms, pulse_time_0_ms, ...]`` and is exposed via
      :meth:`set_data_from_stimulation_backend`.
    - Sampled spike times that fall in the past (``sample_time < t``) are
      silently discarded during generation; no error is raised.
    - Outputs are integer multiplicities ``0, 1, 2, ...`` per step,
      matching NEST ``SpikeEvent`` multiplicity semantics rather than
      binary spike flags.

    See Also
    --------
    poisson_generator : Independent Poisson spike trains at fixed rate.
    mip_generator : Correlated spike trains via Multiple Interaction Process.
    inhomogeneous_poisson_generator : Poisson generator with time-varying rate.
    gamma_sup_generator : Superposition of stationary gamma-process trains.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     gen = brainpy.state.pulsepacket_generator(
        ...         in_size=(2, 3),
        ...         pulse_times=[10.0 * u.ms, 20.0 * u.ms],
        ...         activity=5,
        ...         sdev=1.5 * u.ms,
        ...         start=0.0 * u.ms,
        ...         stop=40.0 * u.ms,
        ...         rng_seed=7,
        ...     )
        ...     with brainstate.environ.context(t=10.0 * u.ms):
        ...         counts = gen.update()
        ...     _ = counts.shape

    .. code-block:: python

        >>> import brainpy
        >>> import saiunit as u
        >>> gen = brainpy.state.pulsepacket_generator(activity=3, sdev=0.5 * u.ms)
        >>> gen.set_data_from_stimulation_backend([4.0, 0.8, 5.0, 15.0, 25.0])
        >>> params = gen.get()
        >>> _ = params['activity'], params['pulse_times']

    References
    ----------
    .. [1] NEST source: ``models/pulsepacket_generator.cpp`` and
           ``models/pulsepacket_generator.h``.
    .. [2] NEST source: ``nestkernel/stimulation_device.cpp``.
    .. [3] NEST model docs:
           https://nest-simulator.readthedocs.io/en/stable/models/pulsepacket_generator.html
    """
    __module__ = 'brainpy.state'

    # Integer tics per millisecond used for NEST-style time rounding.
    _TICS_PER_MS = 1000.0

    def __init__(
        self,
        in_size: Size = 1,
        pulse_times: Sequence[ArrayLike] | ArrayLike | None = None,
        activity: ArrayLike = 0,
        sdev: ArrayLike = 0. * u.ms,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        sdev_tolerance: float = 10.0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.activity = self._to_scalar_int(activity, name='activity')
        self.sdev = self._to_scalar_time_ms(sdev)
        self.start = self._to_scalar_time_ms(start)
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        self.sdev_tolerance = float(sdev_tolerance)
        # Written as ``not (x > 0)`` so that NaN is rejected as well.
        if not (self.sdev_tolerance > 0.0):
            raise ValueError('sdev_tolerance must be positive.')

        dftype = brainstate.environ.dftype()
        self._pulse_times_ms = np.asarray([], dtype=dftype)
        if pulse_times is not None:
            self._pulse_times_ms = np.sort(self._to_time_array_ms(pulse_times))

        self._validate_parameters(
            activity=self.activity,
            sdev=self.sdev,
            start=self.start,
            stop=self.stop,
        )

        self._num_generators = int(np.prod(self.varshape))

        # dt-dependent caches; filled by ``_refresh_runtime_cache`` once dt is known.
        self._dt_cache_ms = np.nan
        self._dt_tics = 0
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max
        # The tolerance depends only on ``sdev``, so compute it even without dt.
        self._update_tolerance()

        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)

    # ------------------------------------------------------------------
    # Conversion / validation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a scalar time (unitful or plain number, in ms) to ``float`` ms.

        Raises
        ------
        ValueError
            If ``value`` does not contain exactly one element.
        """
        # Resolve the dtype up front: it is needed by both branches.
        dftype = brainstate.environ.dftype()
        if isinstance(value, u.Quantity):
            arr = value.to_decimal(u.ms)
        else:
            arr = u.math.asarray(value, dtype=dftype)
            # ``u.math.asarray`` may yield a Quantity (e.g. ``[5 * u.ms]``).
            if isinstance(arr, u.Quantity):
                arr = arr.to_decimal(u.ms)
        arr = np.asarray(arr, dtype=dftype)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_time_array_ms(values: Sequence[ArrayLike] | ArrayLike) -> np.ndarray:
        """Convert times (unitful or plain numbers, in ms) to a flat ms array."""
        dftype = brainstate.environ.dftype()
        if isinstance(values, u.Quantity):
            arr = values.to_decimal(u.ms)
        elif isinstance(values, (list, tuple)) and len(values) == 0:
            return np.asarray([], dtype=dftype)
        else:
            # ``u.math.asarray`` understands sequences of Quantities, which
            # plain ``np.asarray`` may not; convert the result to ms if unitful.
            arr = u.math.asarray(values, dtype=dftype)
            if isinstance(arr, u.Quantity):
                arr = arr.to_decimal(u.ms)
        return np.asarray(arr, dtype=dftype).reshape(-1)

    @staticmethod
    def _to_scalar_int(value: ArrayLike, *, name: str) -> int:
        """Convert a scalar to ``int``, requiring it to be finite and integral.

        Raises
        ------
        ValueError
            If ``value`` is not scalar, not finite, or not within ``1e-12`` of
            an integer.
        """
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value, dtype=dftype), dtype=dftype)
        if arr.size != 1:
            raise ValueError(f'{name} must be scalar.')
        scalar = float(arr.reshape(()))
        nearest = np.rint(scalar)
        # ``isclose(inf, inf)`` is True and ``int(inf)`` would overflow, so
        # non-finite values are rejected explicitly.
        if (not math.isfinite(scalar)) or (
            not math.isclose(scalar, nearest, rel_tol=0.0, abs_tol=1e-12)
        ):
            raise ValueError(f'{name} must be an integer.')
        return int(nearest)

    @staticmethod
    def _validate_parameters(
        *,
        activity: int,
        sdev: float,
        start: float,
        stop: float,
    ):
        """Raise ``ValueError`` if the public parameters are inconsistent."""
        if activity < 0:
            raise ValueError('The activity cannot be negative.')
        # ``not (sdev >= 0)`` also rejects NaN.
        if not (sdev >= 0.0):
            raise ValueError('The standard deviation cannot be negative.')
        if stop < start:
            raise ValueError('stop >= start required.')

    @classmethod
    def _ms_to_tics(cls, time_ms: float) -> int:
        """Round a time in ms to integer tics."""
        # Match NEST Time(ms): static_cast<long>(ms * TICS_PER_MS + 0.5).
        return int(time_ms * cls._TICS_PER_MS + 0.5)

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        """Raise ``ValueError`` if a finite ``time_ms`` is not a multiple of ``dt_ms``."""
        if not np.isfinite(time_ms):
            return
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)
        if not math.isclose(ratio, nearest, rel_tol=0.0, abs_tol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    # ------------------------------------------------------------------
    # Environment access
    # ------------------------------------------------------------------

    def _dt_ms(self) -> float:
        """Return the environment resolution ``dt`` in ms (required)."""
        dt = brainstate.environ.get_dt()
        return self._to_scalar_time_ms(dt)

    def _maybe_dt_ms(self) -> float | None:
        """Return the environment ``dt`` in ms, or ``None`` if it is not set."""
        dt = brainstate.environ.get('dt', default=None)
        if dt is None:
            return None
        return self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        """Return the environment time ``t`` in ms (``0.0`` when unset)."""
        t = brainstate.environ.get('t', default=0. * u.ms)
        if t is None:
            return 0.0
        return self._to_scalar_time_ms(t)

    # ------------------------------------------------------------------
    # Time-grid helpers and runtime caches
    # ------------------------------------------------------------------

    def _time_to_step(self, time_ms: float, dt_ms: float) -> int:
        """Return the nearest integer step index for ``time_ms``."""
        return int(np.rint(time_ms / dt_ms))

    def _tics_to_delivery_step(self, tic: int) -> int:
        """Map an integer tic to its delivery step, ``ceil(tic / dt_tics)``.

        Returns ``0`` while the resolution cache has not been filled.
        """
        if self._dt_tics <= 0:
            return 0
        return int(math.ceil(float(tic) / float(self._dt_tics)))

    def _time_to_delivery_step(self, time_ms: float) -> int:
        """Map a time in ms to its delivery step (via tic rounding)."""
        return self._tics_to_delivery_step(self._ms_to_tics(time_ms))

    def _update_tolerance(self):
        """Recompute the center-window half-width in ms from ``sdev``."""
        if self.sdev > 0.0:
            self._tolerance_ms = self.sdev * self.sdev_tolerance
        else:
            self._tolerance_ms = 1.0

    def _refresh_runtime_cache(self, dt_ms: float):
        """Recompute every dt-dependent cache for resolution ``dt_ms``.

        Raises
        ------
        ValueError
            If a finite bound is off-grid or the resolution rounds to
            non-positive tics.
        """
        self._assert_grid_time('origin', self.origin, dt_ms)
        self._assert_grid_time('start', self.start, dt_ms)
        self._assert_grid_time('stop', self.stop, dt_ms)

        self._dt_tics = int(np.rint(dt_ms * self._TICS_PER_MS))
        if self._dt_tics <= 0:
            raise ValueError('Simulation resolution must be positive.')

        self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
        if np.isfinite(self.stop):
            self._t_max_step = self._time_to_step(self.origin + self.stop, dt_ms)
        else:
            self._t_max_step = np.iinfo(np.int64).max

        self._update_tolerance()
        self._dt_cache_ms = float(dt_ms)

    def _is_active(self, curr_step: int) -> bool:
        """NEST ``StimulationDevice::is_active`` with the CURRENT_GENERATOR +2 shift."""
        shifted_step = curr_step + 2
        return self._t_min_step < shifted_step <= self._t_max_step

    # ------------------------------------------------------------------
    # Spike-queue management
    # ------------------------------------------------------------------

    def _clear_spike_queues(self):
        """Drop all pending spikes (no-op before :meth:`init_state`)."""
        for queue in getattr(self, '_spike_queues', ()):
            queue.clear()

    def _all_queues_empty(self) -> bool:
        """Return ``True`` when no generator has pending spikes."""
        return not any(self._spike_queues)

    def _pre_run_hook(self, now_ms: float):
        """Rebuild the center window around ``now_ms`` (NEST ``pre_run_hook``).

        After this call ``[start_center_idx, stop_center_idx)`` spans the
        centers with ``|center - now| <= tolerance``. Sorted centers further
        in the past are skipped by advancing ``start_center_idx``.
        """
        assert self._start_center_idx <= self._stop_center_idx
        self._start_center_idx = 0
        self._stop_center_idx = 0
        now_tic = self._ms_to_tics(now_ms)
        n_centers = self._pulse_times_ms.size
        while self._stop_center_idx < n_centers:
            center_tic = self._ms_to_tics(float(self._pulse_times_ms[self._stop_center_idx]))
            if ((center_tic - now_tic) / self._TICS_PER_MS) > self._tolerance_ms:
                break
            if (abs(center_tic - now_tic) / self._TICS_PER_MS) > self._tolerance_ms:
                self._start_center_idx += 1
            self._stop_center_idx += 1

    def init_state(self, batch_size: int | None = None, **kwargs):
        r"""Initialize runtime state for stochastic pulse generation.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused by this implementation. Present to match the base-class
            interface. Default is ``None``.
        **kwargs
            Additional unused keyword arguments accepted for interface
            compatibility.

        Raises
        ------
        ValueError
            If ``dt`` is available and finite timing parameters are not grid
            multiples, or if the computed simulation resolution is
            non-positive.
        TypeError
            If environment times cannot be converted to scalar milliseconds.

        Notes
        -----
        Re-initialization resets the queues and re-seeds the random state from
        ``rng_seed``; pending queued spikes are discarded. The pulse-center
        window is then synchronized to the current environment time.
        """
        del batch_size, kwargs
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)
        self._rng = np.random.default_rng(self.rng_seed)
        self._spike_queues = [deque() for _ in range(self._num_generators)]
        self._start_center_idx = 0
        self._stop_center_idx = 0
        self._pre_run_hook(self._current_time_ms())

    def set(
        self,
        *,
        pulse_times: Sequence[ArrayLike] | ArrayLike | object = _UNSET,
        activity: ArrayLike | object = _UNSET,
        sdev: ArrayLike | object = _UNSET,
        start: ArrayLike | object = _UNSET,
        stop: ArrayLike | object = _UNSET,
        origin: ArrayLike | object = _UNSET,
    ):
        r"""Set public parameters with NEST-compatible update semantics.

        Parameters
        ----------
        pulse_times : Sequence[ArrayLike] or ArrayLike or object, optional
            New pulse-center schedule in ms. Any provided value is converted to
            a flattened array of dtype ``brainstate.environ.dftype()`` and
            sorted ascending. Pass ``_UNSET`` (default) to keep the current
            pulse times.
        activity : ArrayLike or object, optional
            New scalar integer spikes-per-center value, shape ``()`` after
            conversion, with ``activity >= 0``. Pass ``_UNSET`` to keep the
            current value.
        sdev : ArrayLike or object, optional
            New scalar temporal jitter standard deviation in ms, shape ``()``
            after conversion, with ``sdev >= 0``. Pass ``_UNSET`` to keep the
            current value.
        start : ArrayLike or object, optional
            New scalar relative start time in ms, shape ``()`` after
            conversion. Pass ``_UNSET`` to keep the current value.
        stop : ArrayLike or None or object, optional
            New scalar relative stop time in ms, shape ``()`` after conversion.
            ``None`` maps to ``+inf``. Pass ``_UNSET`` to keep the current
            value.
        origin : ArrayLike or object, optional
            New scalar origin offset in ms, shape ``()`` after conversion.
            Pass ``_UNSET`` to keep the current value.

        Raises
        ------
        ValueError
            If integer/scalar validation fails, if ``activity < 0``,
            ``sdev < 0`` (or NaN), ``stop < start``, or if finite time bounds
            are not aligned to the simulation grid when ``dt`` is available.
            All checks run before any attribute is modified, so a failed call
            leaves the device unchanged.
        TypeError
            If provided values cannot be converted to the expected numeric or
            time forms.

        Notes
        -----
        Matching NEST, changing ``activity`` or ``sdev`` (or passing
        ``pulse_times``) clears queued spikes. If state is already
        initialized, the pulse-center window is then re-synchronized to the
        current environment time (as NEST's ``pre_run_hook`` does before the
        next run), so that the center indices refer to the new schedule.
        """
        new_activity = (
            self.activity
            if activity is _UNSET
            else self._to_scalar_int(activity, name='activity')
        )
        new_sdev = self.sdev if sdev is _UNSET else self._to_scalar_time_ms(sdev)
        new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start)
        if stop is _UNSET:
            new_stop = self.stop
        elif stop is None:
            new_stop = np.inf
        else:
            new_stop = self._to_scalar_time_ms(stop)
        new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin)

        self._validate_parameters(
            activity=new_activity,
            sdev=new_sdev,
            start=new_start,
            stop=new_stop,
        )

        if pulse_times is _UNSET:
            new_pulse_times = self._pulse_times_ms
        else:
            new_pulse_times = self._to_time_array_ms(pulse_times)

        # Check grid alignment of the new bounds *before* mutating anything so
        # that a failed ``set`` cannot leave a half-updated device behind.
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._assert_grid_time('origin', new_origin, dt_ms)
            self._assert_grid_time('start', new_start, dt_ms)
            self._assert_grid_time('stop', new_stop, dt_ms)

        need_new_pulse = (new_activity != self.activity) or (new_sdev != self.sdev)
        regenerate = need_new_pulse or (pulse_times is not _UNSET)

        # ---- commit ----
        self.activity = new_activity
        self.sdev = new_sdev
        self.start = new_start
        self.stop = new_stop
        self.origin = new_origin
        self._update_tolerance()

        if regenerate:
            dftype = brainstate.environ.dftype()
            self._pulse_times_ms = np.sort(np.asarray(new_pulse_times, dtype=dftype).reshape(-1))
            self._clear_spike_queues()

        if dt_ms is not None:
            self._refresh_runtime_cache(dt_ms)

        # The old center indices refer to the previous schedule/queue contents;
        # rebuild the window so new pulses are neither skipped nor mis-indexed.
        if regenerate and hasattr(self, '_spike_queues'):
            self._pre_run_hook(self._current_time_ms())

    def get(self) -> dict:
        r"""Return current public parameter values.

        Returns
        -------
        out : dict
            ``dict`` with keys ``pulse_times``, ``activity``, ``sdev``,
            ``start``, ``stop``, and ``origin``. Time values are returned in
            milliseconds as Python ``float`` values, and ``pulse_times`` is a
            Python ``list[float]``.
        """
        return {
            'pulse_times': self._pulse_times_ms.tolist(),
            'activity': int(self.activity),
            'sdev': float(self.sdev),
            'start': float(self.start),
            'stop': float(self.stop),
            'origin': float(self.origin),
        }

    def set_data_from_stimulation_backend(self, input_param: Sequence[float] | np.ndarray):
        r"""Update parameters from a stimulation-backend payload.

        Parameters
        ----------
        input_param : Sequence[float] or numpy.ndarray
            One-dimensional backend payload with shape ``(M,)`` in NEST order
            ``[activity, sdev_ms, pulse_time_0_ms, ...]``. Entries are parsed
            with dtype ``brainstate.environ.dftype()``; ``sdev`` and
            ``pulse_times`` are interpreted in ms. An empty payload is a
            no-op.

        Raises
        ------
        ValueError
            If the payload length is 1 or 2, since at least ``activity``,
            ``sdev``, and one pulse time are required by this backend
            contract.
        TypeError
            If the payload cannot be cast to numeric values.
        """
        dftype = brainstate.environ.dftype()
        data = np.asarray(input_param, dtype=dftype).reshape(-1)
        if data.size == 0:
            return
        if data.size < 3:
            raise ValueError(
                'The size of the data for pulsepacket_generator must be at least 3 '
                '[activity, sdev, pulse_times...].'
            )
        self.set(
            activity=data[0],
            sdev=data[1] * u.ms,
            pulse_times=data[2:] * u.ms,
        )

    def _generate_new_pulses(self, curr_tic: int):
        """Sample spikes for centers in ``[start_center_idx, stop_center_idx)``.

        Samples earlier than ``curr_tic`` are dropped; the rest are queued as
        delivery steps and each touched queue is re-sorted. Consumed centers
        advance ``start_center_idx`` up to ``stop_center_idx``.
        """
        if self._start_center_idx >= self._stop_center_idx:
            return
        if self.activity <= 0:
            # Nothing to draw, but the entered centers are still consumed (as
            # in NEST) so the "schedule exhausted" fast path in ``update`` works.
            self._start_center_idx = self._stop_center_idx
            return

        shape = (self._num_generators, self.activity)
        need_sort = False
        while self._start_center_idx < self._stop_center_idx:
            center = float(self._pulse_times_ms[self._start_center_idx])
            if self.sdev > 0.0:
                sampled = self._rng.normal(loc=center, scale=self.sdev, size=shape)
            else:
                dftype = brainstate.environ.dftype()
                sampled = np.full(shape, center, dtype=dftype)
            for queue, row in zip(self._spike_queues, sampled):
                for x in row:
                    x_tic = self._ms_to_tics(float(x))
                    if x_tic >= curr_tic:
                        queue.append(self._tics_to_delivery_step(x_tic))
            need_sort = True
            self._start_center_idx += 1

        if need_sort:
            for i, queue in enumerate(self._spike_queues):
                if len(queue) > 1:
                    self._spike_queues[i] = deque(sorted(queue))

    def update(self):
        r"""Advance one simulation step and emit spike multiplicities.

        Returns
        -------
        out : jax.Array
            Integer JAX array (dtype ``brainstate.environ.ditype()``) of shape
            ``self.varshape``. Each element is the number of spikes emitted by
            one output generator in the current step. Returns all zeros when
            inactive or when no spikes are due.

        Raises
        ------
        ValueError
            If runtime ``dt`` is non-positive or finite activity bounds are
            not grid multiples.
        TypeError
            If runtime time values cannot be converted to scalar milliseconds.
        KeyError
            If required simulation context entries are missing.

        Notes
        -----
        If state has not been initialized explicitly, :meth:`update` performs
        lazy initialization by calling :meth:`init_state`.
        """
        if not hasattr(self, '_rng'):
            self.init_state()

        dt_ms = self._dt_ms()
        if (not np.isfinite(self._dt_cache_ms)) or (
            not math.isclose(dt_ms, self._dt_cache_ms, rel_tol=0.0, abs_tol=1e-15)
        ):
            self._refresh_runtime_cache(dt_ms)

        ditype = brainstate.environ.ditype()
        curr_t_ms = self._current_time_ms()
        curr_step = self._time_to_step(curr_t_ms, dt_ms)

        schedule_done = (
            self._start_center_idx == self._pulse_times_ms.size
            and self._all_queues_empty()
        )
        if schedule_done or not self._is_active(curr_step):
            return jnp.zeros(self.varshape, dtype=ditype)

        # Extend the right edge of the center window (NEST update step 2).
        curr_tic = self._ms_to_tics(curr_t_ms)
        n_centers = self._pulse_times_ms.size
        while self._stop_center_idx < n_centers:
            center_tic = self._ms_to_tics(float(self._pulse_times_ms[self._stop_center_idx]))
            if ((center_tic - curr_tic) / self._TICS_PER_MS) > self._tolerance_ms:
                break
            self._stop_center_idx += 1

        self._generate_new_pulses(curr_tic)

        # Pop every queued spike whose delivery step is <= the current step.
        step_limit = curr_step + 1
        counts = np.zeros(self._num_generators, dtype=ditype)
        for i, queue in enumerate(self._spike_queues):
            n = 0
            while queue and queue[0] < step_limit:
                queue.popleft()
                n += 1
            counts[i] = n
        return jnp.asarray(counts.reshape(self.varshape), dtype=ditype)