# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from dataclasses import dataclass
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
# Public API of this module: only the recorder device is exported.
__all__ = ['spike_recorder']
@dataclass
class _StepCalibration:
dt_ms: float
t_min_steps: int
t_max_steps: float
class spike_recorder(NESTDevice):
    r"""NEST-compatible spike recording device.

    ``spike_recorder`` accumulates spike events into an in-memory ``events``
    dictionary, with timestamping and activity-window semantics matching NEST
    ``spike_recorder``. The NEST recording-device timing model is reproduced
    while exposing a Python batch API:

    - Incoming spike arrays are timestamped at step :math:`n + 1` where
      :math:`n = \mathrm{round}(t / dt)` is the current simulation step.
    - Recording is gated by a window
      :math:`(\mathrm{origin} + \mathrm{start},\;\mathrm{origin} + \mathrm{stop}]`
      (start exclusive, stop inclusive) evaluated in simulation steps.
    - Event writes are immediate — there is no one-step delivery lag,
      unlike the request/reply mechanism of ``multimeter``.

    **1. Step-Stamp and Physical-Time Model**

    Let :math:`dt > 0` be the simulation resolution (ms), and let
    :math:`n = \mathrm{round}(t / dt)` be the current step index when
    :meth:`update` is called at simulation time :math:`t`. Incoming events are
    stamped at

    .. math::

        s = n + 1,

    i.e., spikes are interpreted as generated during :math:`(t,\, t + dt]`. If
    per-event offsets :math:`\delta_j` (ms) are provided, the stored physical
    event time for item :math:`j` is

    .. math::

        t_j = s \cdot dt - \delta_j.

    With ``time_in_steps=True``, storage is split into integer stamps
    ``events['times']`` (step index :math:`s`) and continuous offsets
    ``events['offsets']`` (:math:`\delta_j`, ms), preserving sub-step timing.

    **2. Activity-Window Gate on the Step Lattice**

    Define step bounds

    .. math::

        s_{\min} = \frac{\mathrm{origin} + \mathrm{start}}{dt}, \qquad
        s_{\max} = \frac{\mathrm{origin} + \mathrm{stop}}{dt}
        \quad (\text{or } +\infty \text{ if stop is None}).

    The recorder is active for stamp step :math:`s` iff

    .. math::

        s > s_{\min} \;\land\; s \le s_{\max}.

    Therefore, ``start`` is exclusive and ``stop`` is inclusive, exactly as in
    NEST recording devices.

    **3. Multiplicity Inference and Payload Normalization**

    Incoming arrays are flattened to one-dimensional vectors of length
    :math:`N`. Size-1 inputs are broadcast to :math:`(N,)` for ``senders``,
    ``offsets``, and ``multiplicities``. Let :math:`x_j` denote ``spikes[j]``:

    - If ``multiplicities is None`` and all ``spikes`` are integer-like
      (within ``1e-12`` tolerance), event counts are
      :math:`c_j = \max(\mathrm{round}(x_j),\, 0)`.
    - If ``multiplicities is None`` and ``spikes`` contains non-integer values,
      :math:`c_j = \mathbf{1}[x_j > 0]`.
    - If ``multiplicities`` is provided with non-negative integers :math:`m_j`,
      then :math:`c_j = m_j \,\mathbf{1}[x_j > 0]`.

    Each item contributes exactly :math:`c_j` stored events by repetition.

    **4. Constraints and Computational Implications**

    ``start``, ``stop`` (when not ``None``), ``origin``, current ``t``, and
    ``dt`` must be scalar-convertible and aligned to the simulation grid.
    Alignment is enforced by round-trip integer checks with ``1e-12``
    tolerance, evaluated in float64 regardless of the environment precision.
    Per :meth:`update` call, normalization is :math:`O(N)` and event
    expansion is :math:`O(E_{\mathrm{new}})` where
    :math:`E_{\mathrm{new}} = \sum_j c_j`. Persistent memory usage is linear in
    the total number of stored events.

    Parameters
    ----------
    in_size : Size, optional
        Shape/size argument forwarded to the base-class constructor. The
        recorder returns event dictionaries rather than dense tensors;
        ``in_size`` is retained for API compatibility only. Default is ``1``.
    start : saiunit.Quantity or float, optional
        Scalar relative exclusive lower bound of the recording window,
        convertible to ms. Must be finite and an integer multiple of ``dt``.
        The effective gate is ``stamp_step > (origin + start) / dt``.
        Default is ``0.0 * u.ms``.
    stop : saiunit.Quantity, float, or None, optional
        Scalar relative inclusive upper bound of the recording window,
        convertible to ms. Must be ``None``, ``+inf``, or finite and aligned
        to ``dt``. Must satisfy ``stop >= start`` when finite. The effective
        gate is ``stamp_step <= (origin + stop) / dt``. ``None`` (or ``+inf``)
        means no upper bound (:math:`s_{\max} = +\infty`). Default is ``None``.
    origin : saiunit.Quantity or float, optional
        Scalar global time-origin shift added to both ``start`` and ``stop``
        when constructing the active window, convertible to ms. Shifting the
        origin displaces the entire recording window without changing its
        duration. Must be finite and aligned to ``dt``. Default is
        ``0.0 * u.ms``.
    time_in_steps : bool, optional
        Controls the time representation in ``events``. If ``False``,
        ``events['times']`` stores milliseconds computed as
        :math:`s \cdot dt - \delta_j`, with dtype
        ``brainstate.environ.dftype()``. If ``True``, ``events['times']``
        stores integer step stamps (dtype ``brainstate.environ.ditype()``) and
        ``events['offsets']`` stores the corresponding offsets in ms (dtype
        ``brainstate.environ.dftype()``). Becomes immutable after the first
        successful :meth:`update` call. Default is ``False``.
    frozen : bool, optional
        NEST-compatibility flag. ``True`` is unconditionally rejected because
        this recorder cannot be frozen. Default is ``False``.
    name : str or None, optional
        Optional node name forwarded to the base-class constructor.
        Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Mapping of constructor parameters to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the active window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the active window.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global origin shift applied before window gating.
       * - ``time_in_steps``
         - ``False``
         - :math:`\mathrm{repr}_t`
         - Time storage mode: physical ms or integer ``(step, offset)`` pair.

    Raises
    ------
    ValueError
        If ``frozen=True``; if any time parameter (``start``, ``stop``,
        ``origin``, ``dt``, or current ``t``) is non-scalar, NaN, infinite
        (only ``stop`` may be ``+inf``), not aligned to ``dt``, or violates
        ``stop >= start``; if ``time_in_steps`` is modified after
        :meth:`update` has been called; if ``n_events`` is assigned a value
        other than ``0``; if payload array sizes are incompatible with
        ``spikes`` length; or if explicit ``multiplicities`` contain negative
        entries.
    TypeError
        If unit conversion or numeric casting of any payload or time
        parameter fails.
    KeyError
        If :meth:`get` is called with an unsupported key, or if required
        simulation context entries (``'t'`` or ``dt``) are not available via
        ``brainstate.environ``.

    Notes
    -----
    - Event writes are immediate (no one-step delivery lag), unlike
      the request/reply mechanism of ``multimeter``.
    - ``time_in_steps`` becomes immutable after the first :meth:`update`
      call whose simulation context (``t``, ``dt``, window parameters)
      validates successfully, matching NEST backend constraints.
    - ``spikes=None`` is treated as a no-op update that returns the
      current ``events`` without writing any new events.
    - :meth:`init_state` clears all accumulated events; it can be used to
      reset the recorder between simulation segments without reconstructing
      the object.

    References
    ----------
    .. [1] NEST Simulator, ``spike_recorder`` device.
           https://nest-simulator.readthedocs.io/en/stable/models/spike_recorder.html

    Examples
    --------
    Record spikes from a three-neuron population over a 1 ms window at
    0.1 ms resolution, using integer-like spike counts:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> import numpy as np
        >>> dftype = brainstate.environ.dftype()
        >>> ditype = brainstate.environ.ditype()
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     sr = brainpy.state.spike_recorder(start=0.0 * u.ms, stop=1.0 * u.ms)
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         _ = sr.update(
        ...             spikes=np.array([1.0, 0.0, 2.0], dtype=dftype),
        ...             senders=np.array([3, 4, 5], dtype=ditype),
        ...         )
        ...     ev = sr.flush()
        >>> ev['senders'].tolist()
        [3, 5, 5]

    Record a single spike with a sub-step offset using ``time_in_steps=True``,
    which splits the timestamp into an integer step index and a float offset:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> import numpy as np
        >>> dftype = brainstate.environ.dftype()
        >>> ditype = brainstate.environ.ditype()
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     sr = brainpy.state.spike_recorder(time_in_steps=True)
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         _ = sr.update(
        ...             spikes=np.array([1.0], dtype=dftype),
        ...             senders=np.array([9], dtype=ditype),
        ...             offsets=np.array([0.03], dtype=dftype) * u.ms,
        ...         )
        ...     ev = sr.events
        >>> int(ev['times'][0])
        1
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        time_in_steps: bool = False,
        frozen: bool = False,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)
        if frozen:
            raise ValueError('spike_recorder cannot be frozen.')
        # Window parameters are stored as given and validated against ``dt``
        # lazily in ``_get_step_calibration``: ``dt`` is read from the
        # simulation environment only when ``update`` runs.
        self.start = start
        self.stop = stop
        self.origin = origin
        self._time_in_steps = bool(time_in_steps)
        # Flipped by ``update`` once the simulation context validated; locks
        # ``time_in_steps`` from then on.
        self._has_been_simulated = False
        self.clear_events()

    @property
    def time_in_steps(self) -> bool:
        """Whether ``events['times']`` holds step stamps (True) or ms (False)."""
        return self._time_in_steps

    @time_in_steps.setter
    def time_in_steps(self, value: bool):
        if self._has_been_simulated:
            raise ValueError('Property time_in_steps cannot be set after Simulate has been called.')
        self._time_in_steps = bool(value)

    @property
    def n_events(self) -> int:
        """Total number of stored events."""
        return len(self._events_senders)

    @n_events.setter
    def n_events(self, value: int):
        # NEST only allows resetting the counter to zero, which clears events.
        value = int(value)
        if value != 0:
            raise ValueError('Property n_events can only be set to 0 (which clears all stored events).')
        self.clear_events()

    @property
    def events(self) -> dict[str, np.ndarray]:
        """Snapshot of all stored events as newly allocated NumPy arrays.

        Always contains ``'senders'`` and ``'times'``; ``'offsets'`` is added
        when ``time_in_steps`` is True. Integer arrays use
        ``brainstate.environ.ditype()`` and float arrays use
        ``brainstate.environ.dftype()``.
        """
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        out = {'senders': np.asarray(self._events_senders, dtype=ditype)}
        if self.time_in_steps:
            out['times'] = np.asarray(self._events_times_steps, dtype=ditype)
            out['offsets'] = np.asarray(self._events_offsets, dtype=dftype)
        else:
            out['times'] = np.asarray(self._events_times_ms, dtype=dftype)
        return out

    def get(self, key: str = 'events'):
        """NEST-style status accessor for ``'events'``, ``'n_events'``, ``'time_in_steps'``.

        Raises
        ------
        KeyError
            If ``key`` is not one of the supported names.
        """
        if key == 'events':
            return self.events
        if key == 'n_events':
            return self.n_events
        if key == 'time_in_steps':
            return self.time_in_steps
        raise KeyError(f'Unsupported key "{key}" for spike_recorder.get().')

    def clear_events(self):
        """Drop all stored events (both the ms and the step/offset buffers)."""
        self._events_senders: list[int] = []
        self._events_times_ms: list[float] = []
        self._events_times_steps: list[int] = []
        self._events_offsets: list[float] = []

    def init_state(self, batch_size: int = None, **kwargs):
        """Reset the recorder by clearing all stored events; arguments are ignored."""
        del batch_size, kwargs
        self.clear_events()

    def connect(self):
        # Kept for API symmetry with multimeter.
        return None

    def flush(self):
        """Return the current ``events`` dictionary (stored events are kept)."""
        return self.events

    def update(
        self,
        spikes: ArrayLike = None,
        senders: ArrayLike = None,
        offsets: ArrayLike = None,
        multiplicities: ArrayLike = None,
    ):
        r"""Record spike events for the current simulation step.

        Reads the current simulation time ``t`` and resolution ``dt`` from
        ``brainstate.environ``, computes the stamp step :math:`s = n + 1`
        where :math:`n = \mathrm{round}(t / dt)`, applies the activity-window
        gate, expands the spike payload into individual events, and appends
        them to the internal buffers.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            Input spike payload, flattened to shape ``(N,)``. Accepted dtypes
            include boolean, integer, and floating-point values.

            - ``None``: no new events are written; current ``events`` dict is
              returned immediately.
            - Integer-like values (all within ``1e-12`` of an integer) with
              ``multiplicities is None``: each element :math:`j` contributes
              :math:`c_j = \max(\mathrm{round}(x_j),\, 0)` events.
            - Non-integer floating values with ``multiplicities is None``:
              each element contributes :math:`c_j = \mathbf{1}[x_j > 0]`
              events (binary threshold).
        senders : ArrayLike or None, optional
            Sender node IDs cast to integers, shape ``(N,)`` or size-1
            broadcastable to ``(N,)``. Default sender ID is ``1`` for all
            entries when ``None``.
        offsets : ArrayLike or None, optional
            Per-event sub-step timing offsets :math:`\delta_j` in ms, shape
            ``(N,)`` or size-1 broadcastable to ``(N,)``. Values may carry a
            ``saiunit`` time unit and are converted to ms. Must contain only
            finite values. Default is ``0.0`` ms for all entries.
        multiplicities : ArrayLike or None, optional
            Explicit non-negative integer event multiplicities, shape
            ``(N,)`` or size-1 broadcastable to ``(N,)``. When provided, the
            integer-like inference path from ``spikes`` is disabled; the
            count rule becomes :math:`c_j = m_j \,\mathbf{1}[x_j > 0]`.
            Negative values raise ``ValueError``. Default is ``None``.

        Returns
        -------
        events : dict[str, np.ndarray]
            Current accumulated events dictionary after processing this step.
            All arrays are one-dimensional with length :math:`E` equal to
            the total number of stored events:

            - ``'senders'`` — dtype ``brainstate.environ.ditype()``.
            - ``'times'`` — ms (dtype ``brainstate.environ.dftype()``) when
              ``time_in_steps=False``; integer step stamps (dtype
              ``brainstate.environ.ditype()``) when ``time_in_steps=True``.
            - ``'offsets'`` — ms, dtype ``brainstate.environ.dftype()``
              (only present when ``time_in_steps=True``).

        Raises
        ------
        ValueError
            If ``t`` is not grid-aligned to ``dt``; if ``start``, ``stop``,
            or ``origin`` are invalid with respect to ``dt``; if ``dt <= 0``;
            if provided payload array sizes are incompatible with the ``N``
            inferred from ``spikes``; if ``spikes`` or ``offsets`` contain
            non-finite values; or if explicit ``multiplicities`` contain
            negative entries.
        TypeError
            If numeric or unit conversion of any payload or time parameter
            fails.
        KeyError
            If required simulation context entries (``'t'`` or ``dt``) are
            not available via ``brainstate.environ``.

        Notes
        -----
        Events are written at stamp step :math:`s = \mathrm{round}(t / dt) + 1`
        and then gated by the active window
        :math:`(s_{\min},\, s_{\max}]` in step space. The payload is validated
        even when the stamp step falls outside the window; in that case the
        method returns the unchanged ``events`` dict without writing data.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._get_step_calibration(dt)
        # Spikes observed during (t, t + dt] are stamped with the next step.
        stamp_step = self._time_to_step(t, calib.dt_ms) + 1
        self._has_been_simulated = True

        if spikes is None:
            return self.events
        spike_arr = self._to_float_array(spikes, name='spikes')
        n_items = spike_arr.size
        if n_items == 0:
            return self.events

        sender_arr = self._to_int_array(senders, name='senders', default=1, size=n_items)
        offset_arr = self._to_float_array(offsets, name='offsets', default=0.0, size=n_items, unit=u.ms)
        if multiplicities is None:
            mult_arr = None
        else:
            mult_arr = self._to_int_array(multiplicities, name='multiplicities', size=n_items)
        counts = self._compute_counts(spike_arr, mult_arr)

        if not self._is_active(stamp_step, calib.t_min_steps, calib.t_max_steps):
            return self.events

        active = counts > 0
        if not np.any(active):
            return self.events
        active_counts = counts[active]
        # Item j is repeated c_j times so each stored event is one spike.
        out_senders = np.repeat(sender_arr[active], active_counts)
        out_offsets = np.repeat(offset_arr[active], active_counts)

        self._events_senders.extend(out_senders.tolist())
        if self.time_in_steps:
            self._events_times_steps.extend([stamp_step] * out_senders.size)
            self._events_offsets.extend(out_offsets.tolist())
        else:
            out_times_ms = stamp_step * calib.dt_ms - out_offsets
            self._events_times_ms.extend(out_times_ms.tolist())
        return self.events

    @staticmethod
    def _compute_counts(spike_arr: np.ndarray, mult_arr) -> np.ndarray:
        """Return per-item int64 event counts ``c_j`` per the multiplicity rules.

        ``mult_arr`` is ``None`` (infer counts from ``spike_arr``) or an
        integer array of the same length as ``spike_arr``.
        """
        if mult_arr is None:
            rounded = np.rint(spike_arr)
            if np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12):
                # Integer-like payload: values are spike counts per item.
                return np.maximum(rounded.astype(np.int64), 0)
            # Non-integer payload: threshold positive entries to one spike.
            return (spike_arr > 0.0).astype(np.int64)
        if np.any(mult_arr < 0):
            raise ValueError('multiplicities must be non-negative.')
        return np.where(spike_arr > 0.0, mult_arr, 0)

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time (Quantity or number in ms) to a float in ms.

        Raises ``ValueError`` for non-scalar input, for NaN/``-inf``, and for
        ``+inf`` unless ``allow_inf`` is True.
        """
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        # Always float64: the 1e-12 grid-alignment checks downstream are
        # meaningless after rounding through the environment's float32.
        arr = np.asarray(u.math.asarray(value), dtype=np.float64).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if math.isfinite(val) or (allow_inf and val == math.inf):
            return val
        raise ValueError(f'{name} must be finite.')

    @staticmethod
    def _steps_on_grid(ms: float, dt_ms: float, error_msg: str) -> int:
        """Convert ``ms`` to an integer step count; raise ``ValueError(error_msg)`` if off-grid."""
        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(error_msg)
        return steps_i

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a window parameter to whole steps of ``dt_ms``.

        Returns ``math.inf`` for ``None`` or ``+inf`` when ``allow_inf`` is True.
        """
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')
        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        if ms == math.inf:  # only reachable when allow_inf is True
            return math.inf
        return cls._steps_on_grid(ms, dt_ms, f'{name} must be a multiple of the simulation resolution.')

    def _get_step_calibration(self, dt) -> _StepCalibration:
        """Validate ``dt`` and the window parameters and express the window in steps."""
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')
        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_steps = self._to_step_count(self.stop, dt_ms, 'stop', allow_inf=True)
        if not math.isinf(stop_steps) and stop_steps < start_steps:
            raise ValueError('stop >= start required.')
        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps
        return _StepCalibration(
            dt_ms=dt_ms,
            t_min_steps=t_min_steps,
            t_max_steps=t_max_steps,
        )

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Return the step index of simulation time ``t``; ``t`` must be grid-aligned."""
        t_ms = self._to_ms_scalar(t, name='t')
        return self._steps_on_grid(
            t_ms, dt_ms, 'Current simulation time t must be aligned to the simulation grid.'
        )

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """Window gate: start exclusive, stop inclusive (``t_max_steps`` may be inf)."""
        return t_min_steps < stamp_step <= t_max_steps

    @staticmethod
    def _to_float_array(
        x,
        name: str,
        default: float = None,
        size: int = None,
        unit=None,
    ) -> np.ndarray:
        """Flatten ``x`` to a finite float64 vector, optionally sized/broadcast to ``size``.

        Quantities are divided by ``unit`` when given; otherwise their unit is
        stripped. ``None`` falls back to ``default`` (error if also ``None``).
        A size-1 input is broadcast to ``size``; any other size mismatch
        raises ``ValueError`` (including an empty input for ``size > 0``).
        """
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.float64)
        else:
            if isinstance(x, u.Quantity):
                x = x / unit if unit is not None else u.get_mantissa(x)
            arr = np.asarray(u.math.asarray(x), dtype=np.float64).reshape(-1)
        if not np.all(np.isfinite(arr)):
            raise ValueError(f'{name} must contain finite values.')
        if size is None:
            return arr
        if arr.size == 1 and size != 1:
            return np.full((size,), arr[0], dtype=np.float64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr

    @staticmethod
    def _to_int_array(
        x,
        name: str,
        default: int = None,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to an int64 vector, optionally sized/broadcast to ``size``.

        ``None`` falls back to ``default`` (error if also ``None``). A size-1
        input is broadcast to ``size``; any other size mismatch raises
        ``ValueError``.
        """
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.int64)
        else:
            arr = np.asarray(u.math.asarray(x), dtype=np.int64).reshape(-1)
        if size is None:
            return arr
        if arr.size == 1 and size != 1:
            return np.full((size,), int(arr[0]), dtype=np.int64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr