# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
import brainstate
import saiunit as u
import jax
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
# Public names exported by this module.
__all__ = [
'poisson_generator',
]
# Private sentinel meaning "argument not passed". It lets
# ``poisson_generator.set`` tell an omitted keyword apart from an explicit
# ``None`` (for ``stop``, ``None`` means "no deactivation time", i.e. +inf).
_UNSET = object()
class poisson_generator(NESTDevice):
    r"""Poisson spike generator compatible with NEST.

    Description
    -----------
    ``poisson_generator`` re-implements NEST's stimulation device of the same
    name and emits per-step spike multiplicities.

    **1. Point-process model and discretization**

    Let ``r`` be the configured homogeneous rate in spikes/s and
    :math:`\Delta t` be the simulation step in ms. For one output train, the
    count in one discrete bin is sampled as

    .. math::

       K_n \sim \mathrm{Poisson}(\lambda_n), \qquad
       \lambda_n = r \, \Delta t / 1000.

    The factor ``1000`` converts milliseconds to seconds, so
    :math:`\lambda_n` is dimensionless. This is the standard bin-count
    reduction of a homogeneous Poisson process where
    :math:`\mathbb{P}(K_n=k)=e^{-\lambda_n}\lambda_n^k/k!`.

    Implementation detail: :meth:`update` draws one vectorized Poisson sample
    with ``shape=self.varshape`` via ``jax.random.poisson``. Each element is an
    independent train; values are integer multiplicities ``0, 1, 2, ...`` and
    are not clipped to binary spikes.

    **2. Activity window and NEST timing semantics**

    The active interval follows NEST ``StimulationDevice::is_active`` for spike
    generators:

    .. math::

       t_{\min} < t \le t_{\max}, \qquad
       t_{\min} = \mathrm{origin} + \mathrm{start},\quad
       t_{\max} = \mathrm{origin} + \mathrm{stop}.

    Therefore ``start`` is exclusive and ``stop`` is inclusive.
    Internally, times are projected to integer steps with
    ``round(time_ms / dt_ms)`` and activity is evaluated as
    ``t_min_step < curr_step <= t_max_step``.

    **3. Assumptions, constraints, and computational implications**

    Scalar parameters are converted to ``float64`` in public units (Hz or ms).
    If ``dt`` is available, finite ``origin``, ``start``, and ``stop`` must lie
    on the simulation grid: the ratio ``time/dt`` must be within ``1e-12`` of
    an integer, either absolutely or relative to the ratio's magnitude. The
    relative term absorbs floating-point division error for large times.
    Cache refresh is triggered when ``dt`` changes. Per-step runtime is
    :math:`O(\prod \text{varshape})` for sampling and memory is proportional to
    output size. When ``rate <= 0`` the update path returns zeros without
    Poisson sampling. When ``rate > 0`` a sample is drawn on every step, which
    keeps the update traceable under ``jit`` and advances the PRNG every step.
    The sample is masked to zero while the generator is inactive.

    Parameters
    ----------
    in_size : Size, optional
        Output size specification for :class:`brainstate.nn.Dynamics`.
        The derived ``self.varshape`` is the exact shape of arrays returned by
        :meth:`update`. Each element corresponds to one independent output
        train. Default is ``1``.
    rate : ArrayLike, optional
        Scalar firing rate in spikes/s (Hz). Accepted forms are any
        ``ArrayLike`` with exactly one element, optionally a
        :class:`saiunit.Quantity` convertible to ``u.Hz``.
        Must satisfy ``rate >= 0``. Default is ``0.0 * u.Hz``.
    start : ArrayLike, optional
        Scalar relative start time in ms (exclusive lower bound after adding
        ``origin``). Must be scalar-convertible to ``float64`` and, when
        ``dt`` is available, grid representable. Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Scalar relative stop time in ms (inclusive upper bound after adding
        ``origin``). ``None`` is mapped to ``+inf``. If finite, must be
        scalar-convertible and grid representable when ``dt`` is available.
        Must satisfy ``stop >= start`` after conversion. Default is ``None``.
    origin : ArrayLike, optional
        Scalar time origin offset in ms added to both ``start`` and ``stop``.
        Must be scalar-convertible and grid representable when ``dt`` is
        available. Default is ``0.0 * u.ms``.
    rng_seed : int, optional
        Seed used to initialize ``jax.random.PRNGKey`` inside
        :meth:`init_state`. Different seeds lead to different stochastic
        realizations for otherwise identical parameters. Default is ``0``.
    name : str or None, optional
        Optional dynamics node name.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 18 42

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``rate``
         - ``0.0 * u.Hz``
         - :math:`r`
         - Homogeneous firing rate in spikes/s.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of activity.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound; ``None`` maps to ``+inf``.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global offset added to ``start`` and ``stop``.
       * - ``in_size``
         - ``1``
         - -
         - Defines ``self.varshape`` (number/shape of independent trains).
       * - ``rng_seed``
         - ``0``
         - -
         - Seed for JAX key state used by Poisson sampling.

    Raises
    ------
    ValueError
        If ``rate < 0``; if ``stop < start``; if time/rate inputs are not
        scalar-convertible; or if finite ``origin``/``start``/``stop`` are not
        multiples of simulation resolution when ``dt`` is available.
    TypeError
        If unit conversion to ``u.Hz`` or ``u.ms`` fails for supplied inputs.
    KeyError
        At runtime, if required simulation context entries (for example ``dt``
        via ``brainstate.environ.get_dt()``) are missing.

    Notes
    -----
    - ``update`` lazily initializes RNG state if :meth:`init_state` has not
      been called explicitly.
    - Parameter updates through :meth:`set` recompute cached step bounds when
      ``dt`` is present in the environment. If ``dt`` is absent, the cache is
      marked stale and rebuilt on the next :meth:`update`.
    - Returned spike counts use the default integer dtype
      ``brainstate.environ.ditype()``.
    - As in NEST, one generator can fan out to many targets while maintaining
      independent trains per output element.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     gen = brainpy.state.poisson_generator(
        ...         in_size=(2, 3),
        ...         rate=1200.0 * u.Hz,
        ...         start=5.0 * u.ms,
        ...         stop=20.0 * u.ms,
        ...         rng_seed=11,
        ...     )
        ...     with brainstate.environ.context(t=10.0 * u.ms):
        ...         counts = gen.update()
        ...     _ = counts.shape

    .. code-block:: python

        >>> import brainpy
        >>> import saiunit as u
        >>> gen = brainpy.state.poisson_generator(rate=500.0 * u.Hz)
        >>> gen.set(start=2.0 * u.ms, stop=None, origin=1.0 * u.ms)
        >>> params = gen.get()
        >>> _ = params['rate'], params['stop']

    See Also
    --------
    poisson_generator_ps : Precise-time Poisson generator with dead time.
    inhomogeneous_poisson_generator : Piecewise-constant time-varying Poisson rate.
    sinusoidal_poisson_generator : Sinusoidally modulated Poisson rate.

    References
    ----------
    .. [1] NEST source: ``models/poisson_generator.cpp`` and
           ``models/poisson_generator.h``.
    .. [2] NEST source: ``nestkernel/stimulation_device.h`` and
           ``nestkernel/stimulation_device.cpp``.
    .. [3] NEST model docs:
           https://nest-simulator.readthedocs.io/en/stable/models/poisson_generator.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        rate: ArrayLike = 0. * u.Hz,
        start: ArrayLike = 0. * u.ms,
        stop: ArrayLike | None = None,
        origin: ArrayLike = 0. * u.ms,
        rng_seed: int = 0,
        name: str | None = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.rate = self._to_scalar_rate_hz(rate)
        self.start = self._to_scalar_time_ms(start)
        self.stop = np.inf if stop is None else self._to_scalar_time_ms(stop)
        self.origin = self._to_scalar_time_ms(origin)
        self.rng_seed = int(rng_seed)

        if self.rate < 0.0:
            raise ValueError('The rate cannot be negative.')
        if self.stop < self.start:
            raise ValueError('stop >= start required.')

        # Step bounds depend on ``dt``. A NaN cached dt marks the bounds as
        # stale so that ``update`` recomputes them on its first call.
        self._dt_cache_ms = np.nan
        self._t_min_step = 0
        self._t_max_step = np.iinfo(np.int64).max
        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            self._refresh_timing_cache(dt_ms)

    @staticmethod
    def _to_scalar_time_ms(value: ArrayLike) -> float:
        """Convert a one-element time (Quantity or unitless ms) to a ``float`` in ms.

        Raises ``ValueError`` if ``value`` does not have exactly one element.
        """
        # Always float64, as documented, independent of the environment's
        # default float precision.
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.ms), dtype=np.float64)
        else:
            arr = np.asarray(u.math.asarray(value), dtype=np.float64)
        if arr.size != 1:
            raise ValueError('Time parameters must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _to_scalar_rate_hz(value: ArrayLike) -> float:
        """Convert a one-element rate (Quantity or unitless Hz) to a ``float`` in Hz.

        Raises ``ValueError`` if ``value`` does not have exactly one element.
        """
        if isinstance(value, u.Quantity):
            arr = np.asarray(value.to_decimal(u.Hz), dtype=np.float64)
        else:
            arr = np.asarray(u.math.asarray(value), dtype=np.float64)
        if arr.size != 1:
            raise ValueError('rate must be scalar.')
        return float(arr.reshape(()))

    @staticmethod
    def _time_to_step(time_ms: float, dt_ms: float) -> int:
        """Round a time in ms to the nearest integer simulation step."""
        return int(np.rint(time_ms / dt_ms))

    @staticmethod
    def _assert_grid_time(name: str, time_ms: float, dt_ms: float):
        """Raise ``ValueError`` if a finite ``time_ms`` is not a multiple of ``dt_ms``.

        Non-finite times (``+inf`` for an unbounded ``stop``) are accepted.
        """
        if not np.isfinite(time_ms):
            return
        ratio = time_ms / dt_ms
        nearest = np.rint(ratio)
        # The rounding error of ``time_ms / dt_ms`` grows with |ratio|. For
        # example, 1000.3 / 0.1 == 10002.999999999998, which is off by ~1.8e-12.
        # An absolute tolerance alone would reject such valid grid times, so a
        # relative tolerance of the same size is accepted as well.
        if not math.isclose(ratio, nearest, rel_tol=1e-12, abs_tol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')

    def _dt_ms(self) -> float:
        """Return the environment's ``dt`` in ms via ``brainstate.environ.get_dt()``."""
        dt = brainstate.environ.get_dt()
        return self._to_scalar_time_ms(dt)

    def _maybe_dt_ms(self) -> float | None:
        """Return the environment's ``dt`` in ms, or ``None`` if it is not set."""
        dt = brainstate.environ.get('dt', default=None)
        if dt is None:
            return None
        return self._to_scalar_time_ms(dt)

    def _current_time_ms(self) -> float:
        """Return the environment's ``t`` in ms (``0.0`` if unset or ``None``)."""
        t = brainstate.environ.get('t', default=0. * u.ms)
        if t is None:
            return 0.0
        return self._to_scalar_time_ms(t)

    def _refresh_timing_cache(self, dt_ms: float):
        """Validate the timing parameters on the grid and recompute the step bounds."""
        self._assert_grid_time('origin', self.origin, dt_ms)
        self._assert_grid_time('start', self.start, dt_ms)
        self._assert_grid_time('stop', self.stop, dt_ms)
        self._t_min_step = self._time_to_step(self.origin + self.start, dt_ms)
        if np.isfinite(self.stop):
            self._t_max_step = self._time_to_step(self.origin + self.stop, dt_ms)
        else:
            # Sentinel for "no upper bound".
            self._t_max_step = np.iinfo(np.int64).max
        self._dt_cache_ms = float(dt_ms)

    def _is_active(self, curr_step: int) -> bool:
        """Return whether an integer step lies in ``(t_min_step, t_max_step]``."""
        return (self._t_min_step < curr_step) and (curr_step <= self._t_max_step)

    def init_state(self, batch_size: int | None = None, **kwargs):
        r"""Initialize the RNG state used by Poisson sampling.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused. Present for framework API compatibility with
            :class:`brainstate.nn.Dynamics`. Default is ``None``.
        **kwargs : Any
            Unused keyword arguments accepted for API compatibility.

        Notes
        -----
        :meth:`update` lazily calls this method on the first step if
        ``init_state`` has not been invoked explicitly. Calling ``init_state``
        resets the RNG to the original seed, so repeated calls restart the
        stochastic sequence from the beginning.

        See Also
        --------
        poisson_generator.update : Consumes ``rng_key`` populated here.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>> import saiunit as u
            >>> from brainpy.state import poisson_generator
            >>> with brainstate.environ.context(dt=0.1 * u.ms):
            ...     gen = poisson_generator(in_size=4, rate=800.0 * u.Hz, rng_seed=7)
            ...     gen.init_state()
        """
        del batch_size, kwargs
        self.rng_key = brainstate.ShortTermState(jax.random.PRNGKey(self.rng_seed))

    def set(
        self,
        *,
        rate: ArrayLike | object = _UNSET,
        start: ArrayLike | object = _UNSET,
        stop: ArrayLike | object = _UNSET,
        origin: ArrayLike | object = _UNSET,
    ):
        r"""Update public parameters and refresh the timing cache when needed.

        Only keyword arguments that are explicitly passed are modified; omitted
        arguments retain their current values. All validation happens before
        any attribute is changed, so a call that raises leaves the generator
        unchanged.

        Parameters
        ----------
        rate : ArrayLike or object, optional
            New scalar firing rate in spikes/s (Hz). Accepts any
            ``ArrayLike`` with exactly one element, or a
            :class:`saiunit.Quantity` convertible to ``u.Hz``.
            Must satisfy ``rate >= 0`` after conversion. Omit to keep the
            current value.
        start : ArrayLike or object, optional
            New scalar relative start time in ms (exclusive lower bound after
            adding ``origin``). Must be scalar-convertible and, when ``dt`` is
            in the environment, grid-representable. Omit to keep the current
            value.
        stop : ArrayLike or None or object, optional
            New scalar relative stop time in ms (inclusive upper bound after
            adding ``origin``). ``None`` maps to ``+inf``. Must satisfy
            ``stop >= start`` after conversion. Omit to keep the current value.
        origin : ArrayLike or object, optional
            New scalar time origin offset in ms added to both ``start`` and
            ``stop``. Must be scalar-convertible and grid-representable when
            ``dt`` is available. Omit to keep the current value.

        Raises
        ------
        ValueError
            If ``rate < 0`` after conversion; if ``stop < start`` after
            conversion; or if any finite timing parameter is not representable
            on the current simulation grid (checked via
            :meth:`_assert_grid_time`).
        TypeError
            If unit conversion to ``u.Hz`` or ``u.ms`` fails for any supplied
            value.

        Notes
        -----
        If ``dt`` is not in the environment, the cached step bounds are marked
        stale. They are then rebuilt (and grid-checked) on the next
        :meth:`update`.

        See Also
        --------
        poisson_generator.get : Read-back current parameter values.

        Examples
        --------
        .. code-block:: python

            >>> import brainpy
            >>> import saiunit as u
            >>> gen = brainpy.state.poisson_generator(rate=500.0 * u.Hz)
            >>> gen.set(rate=1000.0 * u.Hz, stop=50.0 * u.ms)
            >>> params = gen.get()
            >>> _ = params['rate'], params['stop']
        """
        new_rate = self.rate if rate is _UNSET else self._to_scalar_rate_hz(rate)
        new_start = self.start if start is _UNSET else self._to_scalar_time_ms(start)
        if stop is _UNSET:
            new_stop = self.stop
        elif stop is None:
            new_stop = np.inf
        else:
            new_stop = self._to_scalar_time_ms(stop)
        new_origin = self.origin if origin is _UNSET else self._to_scalar_time_ms(origin)

        if new_rate < 0.0:
            raise ValueError('The rate cannot be negative.')
        if new_stop < new_start:
            raise ValueError('stop >= start required.')

        dt_ms = self._maybe_dt_ms()
        if dt_ms is not None:
            # Grid-check the *new* values before committing anything, so a
            # rejected call cannot leave the object half-updated.
            self._assert_grid_time('origin', new_origin, dt_ms)
            self._assert_grid_time('start', new_start, dt_ms)
            self._assert_grid_time('stop', new_stop, dt_ms)

        self.rate = new_rate
        self.start = new_start
        self.stop = new_stop
        self.origin = new_origin

        if dt_ms is not None:
            self._refresh_timing_cache(dt_ms)
        else:
            # The bounds cannot be recomputed without ``dt``. Invalidate the
            # cache so ``update`` does not reuse bounds derived from the
            # previous parameters when ``dt`` is unchanged.
            self._dt_cache_ms = np.nan

    def get(self) -> dict:
        r"""Return current public parameters as scalar SI-compatible values.

        Returns
        -------
        params : dict
            Dictionary with four ``float`` entries:

            - ``'rate'`` -- firing rate in spikes/s (Hz).
            - ``'start'`` -- relative exclusive lower bound in ms.
            - ``'stop'`` -- relative inclusive upper bound in ms; ``inf``
              when no deactivation time has been set.
            - ``'origin'`` -- time origin offset in ms.

        Notes
        -----
        Returned values are plain Python ``float`` scalars (``float64``
        precision). They mirror the internal scalar attributes set in
        :meth:`__init__` or updated by :meth:`set` and are not bound to any
        ``saiunit`` quantities.

        See Also
        --------
        poisson_generator.set : Update one or more parameters in place.

        Examples
        --------
        .. code-block:: python

            >>> import brainpy
            >>> import saiunit as u
            >>> gen = brainpy.state.poisson_generator(
            ...     rate=800.0 * u.Hz,
            ...     start=5.0 * u.ms,
            ...     stop=100.0 * u.ms,
            ...     origin=2.0 * u.ms,
            ... )
            >>> params = gen.get()
            >>> params['rate']
            800.0
            >>> params['stop']
            100.0
        """
        return {
            'rate': float(self.rate),
            'start': float(self.start),
            'stop': float(self.stop),
            'origin': float(self.origin),
        }

    def _sample_poisson(self, lam: float) -> jax.Array:
        """Draw one Poisson sample of shape ``self.varshape`` and advance the PRNG key."""
        key, subkey = jax.random.split(self.rng_key.value)
        self.rng_key.value = key
        # Use the environment's default integer dtype, which is the same dtype
        # as the zero path in ``update``. This also avoids requesting int64 when
        # JAX x64 mode is disabled.
        return jax.random.poisson(
            subkey,
            lam=lam,
            shape=self.varshape,
        ).astype(brainstate.environ.ditype())

    def update(self):
        r"""Advance one simulation step and return per-step spike multiplicities.

        Returns
        -------
        spikes : jax.Array
            Integer array with dtype ``brainstate.environ.ditype()`` and shape
            ``self.varshape``. Each element is the number of spikes emitted by
            the corresponding independent output train in the current time
            step.

            - **Active and** ``rate > 0``: entries are i.i.d.
              Poisson(:math:`\lambda`) samples with
              :math:`\lambda = r \cdot \Delta t / 1000`.
            - **Inactive or** ``rate <= 0``: all entries are exactly ``0``.

        Raises
        ------
        ValueError
            If the timing cache is stale and a finite ``origin``, ``start``,
            or ``stop`` is not representable on the current simulation grid
            (checked by :meth:`_assert_grid_time`).
        KeyError
            If ``dt`` is unavailable from ``brainstate.environ.get_dt()``.

        Notes
        -----
        Each call proceeds as follows:

        1. **Lazy init** -- If ``rng_key`` has not been created by
           :meth:`init_state`, it is initialized automatically with
           ``self.rng_seed``.
        2. **Cache refresh** -- When the cache is stale or ``dt`` changed from
           the previously cached value, :meth:`_refresh_timing_cache`
           recomputes the integer step bounds :math:`t_{\min}` and
           :math:`t_{\max}`.
        3. **Rate guard** -- If ``rate <= 0``, an all-zero array is returned
           without touching the PRNG state.
        4. **Activity check** -- The current step index (``t`` defaults to
           ``0 ms`` when unset) is compared against the cached step bounds:
           active iff :math:`t_{\min,\mathrm{step}} < \mathrm{curr\_step} \le
           t_{\max,\mathrm{step}}`. The check uses array operations, so it
           also works when ``t`` is traced.
        5. **Poisson draw** -- One vectorized sample
           ``jax.random.poisson(lam, shape=self.varshape)`` is drawn via
           :meth:`_sample_poisson`, consuming one PRNG split. This happens on
           every step with ``rate > 0``, including inactive ones. Inactive
           steps then mask the sample to zeros.

        See Also
        --------
        poisson_generator.init_state : RNG initialization called lazily here.
        poisson_generator.set : Update parameters between runs.
        """
        if not hasattr(self, 'rng_key'):
            self.init_state()

        dt_ms = self._dt_ms()
        if (not np.isfinite(self._dt_cache_ms)) or (
            not math.isclose(dt_ms, self._dt_cache_ms, rel_tol=0.0, abs_tol=1e-15)
        ):
            self._refresh_timing_cache(dt_ms)

        ditype = brainstate.environ.ditype()
        if self.rate <= 0.0:
            return jax.numpy.zeros(self.varshape, dtype=ditype)

        # JAX-compatible activity check (works under jit / for_loop tracing):
        # ``t`` may be a traced abstract value, so avoid ``float()`` and Python
        # branching on it.
        t = brainstate.environ.get('t', default=0. * u.ms)
        if t is None:
            # Same convention as ``_current_time_ms``.
            t = 0. * u.ms
        if isinstance(t, u.Quantity):
            t_ms_num = t.to_decimal(u.ms)
        else:
            t_ms_num = jax.numpy.asarray(t)
        curr_step = jax.numpy.rint(t_ms_num / dt_ms).astype(ditype)

        is_active = curr_step > self._t_min_step
        # An unbounded ``stop`` is stored as the int64 maximum. That value does
        # not fit JAX's default 32-bit integers, and comparing against it would
        # raise OverflowError. Apply the upper bound only when it is finite.
        if np.isfinite(self.stop):
            is_active = is_active & (curr_step <= self._t_max_step)

        lam = self.rate * dt_ms / 1000.0
        spikes = self._sample_poisson(lam)
        return jax.numpy.where(is_active, spikes, jax.numpy.zeros_like(spikes))