# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from dataclasses import dataclass
from typing import Mapping, Sequence
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
# Public names of this module: only the recorder class is listed for
# ``from <module> import *``; the underscore-prefixed dataclasses are not.
__all__ = [
'multimeter',
]
@dataclass
class _PendingSample:
    """One buffered sample waiting to be written to the event store.

    Instances are created by ``multimeter._pack_sample`` and consumed by
    ``multimeter._emit_pending``.
    """

    # Step index the sample is stamped with (current step + 1 at enqueue time).
    stamp_step: int
    # Flat array of sender IDs, one entry per recorded item.
    senders: np.ndarray
    # Recordable name -> flat value array with the same length as ``senders``.
    values: dict[str, np.ndarray]
@dataclass
class _StepCalibration:
    """Timing parameters of the recorder expressed on the simulation step grid.

    Built by ``multimeter._get_step_calibration`` from the current ``dt``.
    """

    # Simulation resolution in milliseconds.
    dt_ms: float
    # Sampling interval in steps (validated to be >= 1).
    interval_steps: int
    # Sampling lattice offset in steps (0 or >= 1).
    offset_steps: int
    # Exclusive lower bound of the recording window: origin + start, in steps.
    t_min_steps: int
    # Inclusive upper bound of the recording window: origin + stop, in steps.
    # ``math.inf`` when no stop is configured, hence the ``float`` annotation.
    t_max_steps: float
class multimeter(NESTDevice):
    r"""NEST-compatible analog recorder for neuron/device state variables.

    ``multimeter`` records analog state samples from connected targets into an
    in-memory ``events`` dictionary compatible with NEST ``multimeter``
    semantics. The NEST device-level timing model is reproduced while
    exposing a Python update API:

    - Sampling times are constrained to a step-grid lattice defined by
      ``interval`` and ``offset``.
    - Recording is gated by a window
      :math:`(\mathrm{origin}+\mathrm{start},\;\mathrm{origin}+\mathrm{stop}]`
      (start exclusive, stop inclusive) evaluated in simulation steps.
    - Samples are enqueued at the current step and emitted on the next
      :meth:`update` call (or immediately by :meth:`flush`), reproducing the
      one-step request/reply lag used by NEST multimeters.

    **1. Step-Grid Sampling Model**

    Let :math:`dt` be simulation resolution in ms, and let step index
    :math:`n = \mathrm{round}(t/dt)`. During :meth:`update`, sampled values are
    stamped at

    .. math::

        s = n + 1.

    Define integer grid parameters

    .. math::

        m = \frac{\mathrm{interval}}{dt}, \qquad
        o = \frac{\mathrm{offset}}{dt}.

    A sample is enqueued iff :math:`s` lies on the lattice:

    .. math::

        s \equiv 0 \ (\mathrm{mod}\ m) \quad \text{if}\ o = 0, \qquad
        s \equiv o \ (\mathrm{mod}\ m),\ s \ge o \quad \text{if}\ o > 0.

    Both ``interval`` and ``offset`` must be exact integer multiples of
    ``dt`` (verified to within ``1e-12`` tolerance in floating conversion).

    **2. Active Window and Delivery Lag**

    With :math:`s_\min = (\mathrm{origin}+\mathrm{start})/dt` and
    :math:`s_\max = (\mathrm{origin}+\mathrm{stop})/dt` (or :math:`+\infty`
    when ``stop`` is ``None``), a pending sample is written to ``events``
    only when

    .. math::

        s > s_\min \quad \land \quad s \le s_\max.

    Because pending samples are emitted before new sampling in each
    :meth:`update`, values observed at step :math:`n` become visible in
    ``events`` at step :math:`n+1` unless :meth:`flush` is called.

    **3. Payload Normalization and Shape Constraints**

    For each requested recordable ``k`` in ``record_from``, ``data[k]`` is
    converted to the float dtype returned by ``brainstate.environ.dftype()``
    and flattened to shape ``(N,)``. All recordables must share the same
    ``N``; scalar payloads (size 1) are broadcast to ``(N,)`` when another
    recordable defines ``N > 1``, regardless of the order of names in
    ``record_from``. ``senders`` is converted to the integer dtype returned by
    ``brainstate.environ.ditype()`` with the same broadcast rule, defaulting
    to ones when omitted. Stored event arrays are one-dimensional with length
    equal to the total number of emitted samples across all steps.

    **4. Computational Implications**

    Per :meth:`update` call with payload size ``N`` and
    ``R = len(record_from)``, enqueue work is :math:`O(RN)`. Pending
    emission is linear in the number of buffered items and the appended event
    count. Memory usage grows linearly with total emitted events for
    ``times``, ``senders``, and each requested recordable channel.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape argument consumed by :class:`brainstate.nn.Dynamics`.
        This recorder is stateful and returns event dictionaries; ``in_size``
        is retained for API consistency only. Default is ``1``.
    record_from : Sequence[str], optional
        Ordered names of recordable state variables expected as keys in
        ``data`` during :meth:`update`. If empty, incoming payloads are
        silently ignored and no values are stored. Default is ``()``.
    interval : saiunit.Quantity or float, optional
        Scalar sampling interval in time units convertible to ms
        (typically ``u.ms``). Must satisfy ``interval >= dt`` and be an exact
        integer multiple of ``dt`` (checked to within ``1e-12`` tolerance).
        Default is ``1.0 * u.ms``.
    offset : saiunit.Quantity or float, optional
        Scalar phase offset of the sampling lattice relative to the simulation
        origin, convertible to ms. Must be ``0.0`` or a positive integer
        multiple of ``dt``; non-zero offsets shift the first sample to step
        :math:`o` and every :math:`m`-th step thereafter.
        Default is ``0.0 * u.ms``.
    start : saiunit.Quantity or float, optional
        Scalar exclusive lower bound of the recording window relative to
        ``origin``, convertible to ms. A pending sample at stamp step
        :math:`s` is discarded when :math:`s \le s_\min`.
        Default is ``0.0 * u.ms``.
    stop : saiunit.Quantity, float, or None, optional
        Scalar inclusive upper bound of the recording window relative to
        ``origin``, convertible to ms. Must satisfy ``stop >= start`` when
        not ``None``. ``None`` means no upper bound
        (:math:`s_\max = +\infty`). Default is ``None``.
    origin : saiunit.Quantity or float, optional
        Scalar global time-origin shift added to both ``start`` and ``stop``
        when constructing the active window, convertible to ms. Shifting the
        origin displaces the entire recording window without changing its
        duration. Default is ``0.0 * u.ms``.
    time_in_steps : bool, optional
        Controls the unit of ``events['times']``. If ``False``, timestamps
        are stored as float milliseconds (``stamp_step * dt``). If ``True``,
        timestamps are stored as integer-valued step numbers cast to the
        float dtype of the event arrays, and an additional
        ``events['offsets']`` key is emitted as a zero-filled array of
        matching shape. Default is ``False``.
    frozen : bool, optional
        NEST-compatibility flag. ``True`` is unconditionally rejected because
        multimeters cannot be frozen. Default is ``False``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.
        Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Mapping of constructor parameters to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``interval``
         - ``1.0 * u.ms``
         - :math:`m \cdot dt`
         - Sampling period on the simulation step grid.
       * - ``offset``
         - ``0.0 * u.ms``
         - :math:`o \cdot dt`
         - Phase shift of the sampling lattice.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the activity window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the activity window.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global time-origin shift for the recording window.
       * - ``record_from``
         - ``()``
         - :math:`\{x_r\}_{r=1}^{R}`
         - Ordered recordable channels expected in each payload.

    Raises
    ------
    ValueError
        If ``frozen=True``; if any timing parameter (``interval``,
        ``offset``, ``start``, ``stop``, ``origin``, ``dt``) is not
        scalar-convertible, not finite when required, not aligned to
        ``dt``, or violates ordering constraints (e.g. ``interval < dt``
        or ``stop < start``); if ``data`` passed to :meth:`update` is not
        a mapping; if a required recordable key is absent from ``data``;
        if a recordable payload is empty after conversion; or if
        recordable/sender lengths are inconsistent after
        flattening/broadcasting.
    TypeError
        If unit conversion or array casting of any time parameter or
        payload value to a numeric type fails.
    KeyError
        If :meth:`get` is called with a key other than ``'events'`` or
        ``'n_events'``.

    Notes
    -----
    - After the first :meth:`connect` call or the first data-carrying
      :meth:`update`, properties ``interval``, ``offset``, and
      ``record_from`` become immutable and further assignments raise
      ``ValueError``.
    - This recorder does not read neuron states autonomously; the caller
      is responsible for extracting state values and passing them via
      ``data`` in each :meth:`update` call after state integration.
    - :meth:`init_state` clears all accumulated events and the pending
      buffer; it can be used to reset the recorder between simulation
      segments without reconstructing the object.

    References
    ----------
    .. [1] NEST Simulator, ``multimeter`` device.
       https://nest-simulator.readthedocs.io/en/stable/models/multimeter.html

    Examples
    --------
    Record membrane potential from a single ``iaf_psc_delta`` neuron for
    50 steps at 0.1 ms resolution, with the recording window clipped to
    the first 5 ms:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> import numpy as np
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     neuron = brainpy.state.iaf_psc_delta(1, I_e=500.0 * u.pA)
        ...     neuron.init_state()
        ...     mm = brainpy.state.multimeter(
        ...         record_from=['V_m'],
        ...         interval=0.1 * u.ms,
        ...         start=0.0 * u.ms,
        ...         stop=5.0 * u.ms,
        ...     )
        ...     for k in range(50):
        ...         with brainstate.environ.context(t=k * 0.1 * u.ms):
        ...             neuron.update()
        ...             vm = float(neuron.V.value[0] / u.mV)
        ...             dftype = brainstate.environ.dftype()
        ...             _ = mm.update(
        ...                 {'V_m': np.array([vm], dtype=dftype)},
        ...                 senders=np.array([1]),
        ...             )
        ...     events = mm.flush()
        ...     _ = events['V_m'].shape
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        record_from: Sequence[str] = (),
        interval: ArrayLike = 1.0 * u.ms,
        offset: ArrayLike = 0.0 * u.ms,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        time_in_steps: bool = False,
        frozen: bool = False,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        if frozen:
            raise ValueError('multimeter cannot be frozen.')

        self._has_targets = False
        self._interval = interval
        self._offset = offset
        self._record_from = ()

        self.start = start
        self.stop = stop
        self.origin = origin
        self.time_in_steps = bool(time_in_steps)

        # The pending buffer must exist before the ``record_from`` setter
        # runs, because the setter clears it.
        self._pending: list[_PendingSample] = []
        self.record_from = tuple(record_from)
        self.clear_events()

    def _reject_change_if_connected(self):
        """Raise ``ValueError`` once the recorder has targets.

        Shared guard for the ``interval``, ``offset`` and ``record_from``
        setters, which all become read-only after :meth:`connect` or the
        first data-carrying :meth:`update`.
        """
        if self._has_targets:
            raise ValueError(
                'The recording interval, the interval offset and the list of '
                'properties to record cannot be changed after the multimeter '
                'has been connected to nodes.'
            )

    @property
    def interval(self):
        """Sampling interval as supplied by the caller (not yet converted to steps)."""
        return self._interval

    @interval.setter
    def interval(self, value):
        self._reject_change_if_connected()
        self._interval = value

    @property
    def offset(self):
        """Sampling lattice offset as supplied by the caller (not yet converted to steps)."""
        return self._offset

    @offset.setter
    def offset(self, value):
        self._reject_change_if_connected()
        self._offset = value

    @property
    def record_from(self):
        """Tuple of recordable names expected in each :meth:`update` payload."""
        return self._record_from

    @record_from.setter
    def record_from(self, value):
        self._reject_change_if_connected()
        self._record_from = tuple(str(v) for v in value)
        # The per-recordable stores and any buffered samples are keyed by the
        # old names, so both are reset together with the name list.
        self._events_values = {name: [] for name in self._record_from}
        self._pending.clear()

    @property
    def n_events(self) -> int:
        """Number of individual samples written to the event store so far."""
        return len(self._events_times)

    @property
    def events(self) -> dict[str, np.ndarray]:
        """Snapshot of the event store as freshly built NumPy arrays.

        Keys are ``'times'`` and ``'senders'``, one entry per name in
        ``record_from``, and ``'offsets'`` (all zeros) when
        ``time_in_steps`` is true. Float arrays use the dtype returned by
        ``brainstate.environ.dftype()``; ``'senders'`` uses
        ``brainstate.environ.ditype()``.
        """
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        out = {
            'times': np.asarray(self._events_times, dtype=dftype),
            'senders': np.asarray(self._events_senders, dtype=ditype),
        }
        if self.time_in_steps:
            out['offsets'] = np.zeros(out['times'].shape, dtype=dftype)
        for key in self._record_from:
            out[key] = np.asarray(self._events_values[key], dtype=dftype)
        return out

    def get(self, key: str = 'events'):
        """Return ``events`` or ``n_events`` by name.

        Raises
        ------
        KeyError
            If ``key`` is neither ``'events'`` nor ``'n_events'``.
        """
        if key == 'events':
            return self.events
        if key == 'n_events':
            return self.n_events
        raise KeyError(f'Unsupported key "{key}" for multimeter.get().')

    def clear_events(self):
        """Drop all stored events; the pending buffer is left untouched."""
        self._events_times: list[float] = []
        self._events_senders: list[int] = []
        self._events_values = {name: [] for name in self._record_from}

    def init_state(self, batch_size: int = None, **kwargs):
        """Reset the recorder: clear stored events and the pending buffer.

        ``batch_size`` and any extra keyword arguments are accepted and
        ignored.
        """
        del batch_size, kwargs
        self.clear_events()
        self._pending.clear()

    def connect(self):
        """Mark the recorder as connected, freezing ``interval``, ``offset`` and ``record_from``."""
        self._has_targets = True

    def flush(self):
        r"""Emit all buffered pending samples and return the current event store.

        Reads ``dt`` from :func:`brainstate.environ.get_dt`, validates timing
        calibration, converts pending step stamps to output times, and appends
        all active samples to the internal event arrays. After the call the
        pending buffer is empty.

        Returns
        -------
        events : dict[str, np.ndarray]
            Event dictionary identical to :attr:`events`, reflecting all
            samples emitted up to and including this call. It contains
            ``'times'``, ``'senders'``, one array per name in
            ``record_from``, and ``'offsets'`` when ``time_in_steps`` is
            true.

        Raises
        ------
        ValueError
            If ``dt`` obtained from the simulation environment is non-positive,
            not scalar-convertible, or incompatible with the configured timing
            parameters (``interval``, ``offset``, ``start``, ``stop``,
            ``origin``).
        TypeError
            If ``dt`` cannot be converted to a scalar ``float`` ms value.
        """
        dt = brainstate.environ.get_dt()
        calib = self._get_step_calibration(dt)
        self._emit_pending(calib)
        return self.events

    def update(
        self,
        data: Mapping[str, ArrayLike] = None,
        senders: ArrayLike = None,
    ):
        r"""Process one simulation step and optionally enqueue a new sample.

        Parameters
        ----------
        data : Mapping[str, ArrayLike] or None, optional
            Mapping from each name in ``record_from`` to its current analog
            value payload. Each payload is converted to the float dtype
            returned by ``brainstate.environ.dftype()`` and flattened to
            shape ``(N,)``. Scalars (size 1) are broadcast to ``(N,)`` when
            another recordable defines ``N > 1``. If ``None``, no new sample
            is enqueued and only pending samples are emitted.
            Default is ``None``.
        senders : ArrayLike or None, optional
            Sender IDs associated with the payload. Converted to the integer
            dtype returned by ``brainstate.environ.ditype()`` and flattened
            to shape ``(N,)`` using the same scalar-broadcast rule as
            recordables. If ``None``, all sender IDs default to ``1``.
            Default is ``None``.

        Returns
        -------
        events : dict[str, np.ndarray]
            Event dictionary identical to :attr:`events` after emitting all
            pending samples and optionally enqueuing the new payload. It
            contains ``'times'``, ``'senders'``, one array per name in
            ``record_from``, and ``'offsets'`` when ``time_in_steps`` is
            true.

        Raises
        ------
        ValueError
            If current simulation time ``t`` is not aligned to the simulation
            grid; if timing parameters are incompatible with ``dt``; or, on
            steps where a sample is actually enqueued, if ``data`` is not a
            ``Mapping``, a required recordable key is absent from ``data``,
            a recordable payload is empty after conversion, or
            recordable/sender lengths are inconsistent after the
            scalar-broadcast rule.
        TypeError
            If conversion of ``t``, ``dt``, or any payload value to a numeric
            array fails.
        KeyError
            If ``brainstate.environ`` does not provide the ``'t'`` or ``dt``
            context keys required for step computation.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._get_step_calibration(dt)

        # Deliver what was sampled on the previous call before looking at the
        # new payload; this is what produces the one-step delivery lag.
        self._emit_pending(calib)

        if data is None:
            return self.events

        # Receiving data counts as being connected: timing parameters and
        # ``record_from`` are frozen from here on.
        self._has_targets = True
        if len(self._record_from) == 0:
            return self.events

        step_now = self._time_to_step(t, calib.dt_ms)
        stamp_step = step_now + 1
        if self._should_sample(stamp_step, calib.interval_steps, calib.offset_steps):
            self._pending.append(self._pack_sample(stamp_step, data, senders))
        return self.events

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time value to a Python ``float`` in milliseconds.

        ``saiunit.Quantity`` inputs are divided by ``u.ms`` first; anything
        else is taken to be in ms already.

        Raises
        ------
        ValueError
            If ``value`` does not hold exactly one element, or is not finite
            while ``allow_inf`` is false.
        """
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        return val

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a time value to an integer number of ``dt_ms`` steps.

        Returns ``math.inf`` for ``None`` or infinite input when
        ``allow_inf`` is true.

        Raises
        ------
        ValueError
            If ``value`` is ``None`` or infinite while ``allow_inf`` is
            false, is NaN, or is not a multiple of ``dt_ms`` within the
            ``1e-12`` tolerances.
        """
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')

        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        # With allow_inf=True the finiteness check above is skipped, so NaN
        # can reach this point; reject it explicitly instead of letting the
        # integer conversion below fail with an unrelated message.
        if math.isnan(ms):
            raise ValueError(f'{name} must not be NaN.')
        if math.isinf(ms):
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} must be finite.')

        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _get_step_calibration(self, dt) -> _StepCalibration:
        """Validate the timing parameters against ``dt`` and express them in steps.

        Raises
        ------
        ValueError
            If ``dt`` is not positive, ``interval`` is shorter than ``dt``,
            ``offset`` is negative, ``stop`` precedes ``start``, or any
            parameter is not a multiple of ``dt``.
        """
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')

        interval_steps = self._to_step_count(self.interval, dt_ms, 'interval')
        if interval_steps < 1:
            raise ValueError('The sampling interval must be at least as long as the simulation resolution.')

        offset_steps = self._to_step_count(self.offset, dt_ms, 'offset')
        if offset_steps != 0 and offset_steps < 1:
            raise ValueError(
                'The offset for the sampling interval must be at least as long as the simulation resolution.')

        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_value = math.inf if self.stop is None else self.stop
        stop_steps = self._to_step_count(stop_value, dt_ms, 'stop', allow_inf=True)
        if not math.isinf(stop_steps) and stop_steps < start_steps:
            raise ValueError('stop >= start required.')

        # ``origin`` shifts both window bounds by the same amount.
        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps

        return _StepCalibration(
            dt_ms=dt_ms,
            interval_steps=interval_steps,
            offset_steps=offset_steps,
            t_min_steps=t_min_steps,
            t_max_steps=t_max_steps,
        )

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Convert simulation time ``t`` to its integer step index.

        Raises
        ------
        ValueError
            If ``t`` is not a multiple of ``dt_ms`` within the ``1e-12``
            tolerances.
        """
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _should_sample(stamp_step: int, interval_steps: int, offset_steps: int) -> bool:
        """Return whether ``stamp_step`` lies on the sampling lattice."""
        if offset_steps == 0:
            return (stamp_step % interval_steps) == 0
        # With a non-zero offset the lattice starts at ``offset_steps``.
        if stamp_step < offset_steps:
            return False
        return ((stamp_step - offset_steps) % interval_steps) == 0

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """Return whether ``stamp_step`` is inside ``(t_min_steps, t_max_steps]``."""
        if stamp_step <= t_min_steps:
            return False
        if math.isinf(t_max_steps):
            return True
        return stamp_step <= t_max_steps

    @staticmethod
    def _to_float_array(x, name: str) -> np.ndarray:
        """Convert a payload to a flat float array, stripping any unit.

        Raises
        ------
        ValueError
            If the converted array is empty.
        """
        if isinstance(x, u.Quantity):
            x = u.get_mantissa(x)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(x), dtype=dftype).reshape(-1)
        if arr.size == 0:
            raise ValueError(f'Recordable "{name}" must contain at least one value.')
        return arr

    def _pack_sample(
        self,
        stamp_step: int,
        data: Mapping[str, ArrayLike],
        senders: ArrayLike = None,
    ) -> _PendingSample:
        """Normalize one payload into a :class:`_PendingSample`.

        Every recordable is flattened; the common length ``N`` is the size
        of the first payload that holds more than one value (or 1 if all are
        scalars). Size-1 payloads and a size-1 ``senders`` are broadcast to
        ``N``, independent of their position in ``record_from``.

        Raises
        ------
        ValueError
            If ``data`` is not a mapping, a name from ``record_from`` is
            missing, a payload is empty, or payload/sender sizes disagree
            after broadcasting.
        """
        if not isinstance(data, Mapping):
            raise ValueError('data must be a mapping from recordable names to values.')

        # Pass 1: convert every payload and settle on the common length N.
        values: dict[str, np.ndarray] = {}
        n_items = None
        for key in self._record_from:
            if key not in data:
                raise ValueError(f'Missing recordable "{key}" in data.')
            arr = self._to_float_array(data[key], key)
            if n_items is None or (n_items == 1 and arr.size > 1):
                # First payload, or the first non-scalar payload after a run
                # of scalars: this one defines N.
                n_items = arr.size
            elif arr.size != 1 and arr.size != n_items:
                raise ValueError(f'All recordables must have the same size, got "{key}" with size {arr.size}.')
            values[key] = arr

        if n_items is None:
            n_items = 0

        # Pass 2: broadcast scalar payloads now that N is known. Doing this
        # after the scan is what makes the rule order-independent.
        if n_items > 1:
            dftype = brainstate.environ.dftype()
            values = {
                key: np.full((n_items,), arr[0], dtype=dftype) if arr.size == 1 else arr
                for key, arr in values.items()
            }

        ditype = brainstate.environ.ditype()
        if senders is None:
            sender_arr = np.ones((n_items,), dtype=ditype)
        else:
            sender_arr = np.asarray(u.math.asarray(senders), dtype=ditype).reshape(-1)
            if sender_arr.size == 1 and n_items > 1:
                sender_arr = np.full((n_items,), sender_arr[0], dtype=ditype)
            elif sender_arr.size != n_items:
                raise ValueError(
                    f'senders size ({sender_arr.size}) does not match recordable size ({n_items}).'
                )

        return _PendingSample(
            stamp_step=stamp_step,
            senders=sender_arr,
            values=values,
        )

    def _emit_pending(self, calib: _StepCalibration):
        """Move buffered samples inside the active window into the event store.

        Samples outside the window are dropped. The pending buffer is empty
        afterwards.
        """
        if not self._pending:
            return

        for sample in self._pending:
            if not self._is_active(sample.stamp_step, calib.t_min_steps, calib.t_max_steps):
                continue

            if self.time_in_steps:
                timestamp = float(sample.stamp_step)
            else:
                timestamp = sample.stamp_step * calib.dt_ms

            # One event row per sender: the timestamp is repeated so that all
            # event columns keep the same length.
            n_items = sample.senders.size
            self._events_times.extend([timestamp] * n_items)
            self._events_senders.extend(sample.senders.tolist())
            for key in self._record_from:
                self._events_values[key].extend(sample.values[key].tolist())

        self._pending.clear()