Source code for brainpy_state._nest.volume_transmitter

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import math
from dataclasses import dataclass

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Names exported by ``from <module> import *``; only the device class is
# listed here (the helper records are not part of the star-import surface).
__all__ = [
    'volume_transmitter',
]


@dataclass(frozen=True)
class spikecounter:
    r"""One entry of the neuromodulatory spike history kept by :class:`volume_transmitter`.

    Instances are immutable (``frozen=True``) and compare by value.  An entry
    is appended for every delivery stamp that received a positive
    multiplicity.  The pseudo-spikes written by
    :meth:`volume_transmitter.init_state` and by each trigger reset carry a
    multiplicity of exactly ``0.0``.

    Attributes
    ----------
    spike_time : float
        Grid-aligned event time in milliseconds, equal to
        :math:`s \cdot \Delta t` for delivery-stamp index :math:`s` and
        resolution :math:`\Delta t` (ms).
    multiplicity : float
        Total spike count accumulated at ``spike_time``.  It is never
        negative, and it is ``0.0`` for pseudo-spikes.
    """

    # Grid-aligned event time in milliseconds.
    spike_time: float
    # Summed spike count at ``spike_time`` (0.0 for pseudo-spikes).
    multiplicity: float


@dataclass(frozen=True)
class _StepCalibration:
    r"""Immutable discrete-time calibration record used by :class:`volume_transmitter`.

    Computed once per :meth:`volume_transmitter.update` call from the
    simulation environment's ``dt`` and the transmitter's ``min_delay`` and
    ``deliver_interval`` parameters.

    Attributes
    ----------
    dt_ms : float
        Simulation resolution :math:`\Delta t` in milliseconds, derived from
        ``brainstate.environ.get_dt()`` and converted to ms.  Strictly positive.
    min_delay_steps : int
        ``min_delay`` converted to an integer number of simulation steps via
        :math:`\mathrm{round}(d_{\min} / \Delta t)`.  Must be ``>= 1``.
    delivery_period_steps : int
        Trigger period in steps:
        :math:`T_s = \texttt{deliver\_interval} \times d_{\min,s}`.
        A delivery trigger fires when
        :math:`\mathrm{stamp} \bmod T_s = 0`.
    """

    dt_ms: float
    min_delay_steps: int
    delivery_period_steps: int


class volume_transmitter(NESTDevice):
    r"""NEST-compatible ``volume_transmitter`` support device.

    ``volume_transmitter`` collects neuromodulatory spikes and periodically
    exposes their cumulative spike history to dopamine-modulated synapses
    (e.g. ``stdp_dopamine_synapse``).  It reproduces the NEST ring-buffer
    scheduling, trigger logic, and pseudo-spike reset conventions while
    exposing a Python batch-update API.

    **1. Discrete-Time State**

    Let the simulation resolution be :math:`\Delta t` (ms).  For the current
    step index :math:`n = \mathrm{round}(t / \Delta t)`, define the on-grid
    delivery stamp as :math:`s = n + 1`.  The internal mutable state consists
    of:

    - :math:`P[s]` — pending multiplicity map, ``dict[int, float]``,
      accumulating contributions scheduled for future delivery stamp ``s``.
    - :math:`H` — ordered spike-history list of :class:`spikecounter` entries
      :math:`(t_i, m_i)`, with time in ms and non-negative multiplicity.
    - Delivery metadata: ``last_delivery_spikes``, ``last_delivery_time_ms``,
      and ``delivery_count``.

    Immediately after :meth:`init_state`, :math:`H = [(0.0,\; 0.0)]`
    (a NEST-compatible pseudo-spike with zero multiplicity).

    **2. Update Equations and Trigger Rule**

    For each input item :math:`i` with spike indicator :math:`x_i > 0`, an
    effective count :math:`c_i \ge 0` is accumulated into :math:`P[s_i]`.
    Here :math:`s_i` is either ``stamp_steps[i]`` or the current stamp
    :math:`s`.

    At each :meth:`update` call, the pending multiplicity for stamp :math:`s`
    is consumed:

    .. math::

       m_s = P[s] \quad (\text{or } 0 \text{ if absent}),
       \qquad
       t_s = s \cdot \Delta t.

    If :math:`m_s > 0`, append :math:`(t_s,\, m_s)` to :math:`H`.

    The delivery period in steps is

    .. math::

       T_s = k \cdot d_{\min,s},
       \qquad
       d_{\min,s} = \mathrm{round}\!\left(\frac{d_{\min}}{\Delta t}\right),

    where :math:`k = \texttt{deliver\_interval}`.  A delivery trigger fires
    when :math:`s \bmod T_s = 0`.

    On trigger:

    1. Capture :math:`D = H` as the delivered history.  Spikes at stamp
       :math:`s` are included because they are appended before the reset.
    2. Store :math:`D` and trigger time :math:`t_s` in delivery metadata.
    3. Increment the delivery counter.
    4. Reset :math:`H = [(t_s,\; 0.0)]` (new pseudo-spike).

    **3. Multiplicity Inference**

    Incoming scalars or arrays of any shape are flattened (C order) to
    one-dimensional vectors of length :math:`N`.  Let :math:`x_j` denote
    ``spikes[j]``:

    - If ``multiplicities is None`` and all :math:`x_j` are finite and
      integer-like (within ``1e-12`` tolerance):
      :math:`c_j = \max(\mathrm{round}(x_j),\; 0)`.
    - If ``multiplicities is None`` and :math:`x` contains non-integer or
      non-finite values: :math:`c_j = \mathbf{1}[x_j > 0]`.
    - If ``multiplicities`` is provided with non-negative integers
      :math:`m_j`: :math:`c_j = m_j \,\mathbf{1}[x_j > 0]`.

    **4. Assumptions and Constraints**

    - ``deliver_interval`` must be a finite scalar integer ``>= 1``.
    - ``min_delay`` must be scalar, strictly positive, and an integer multiple
      of ``dt``.
    - Simulation time ``t`` must be aligned to the simulation grid.  This is
      enforced at each :meth:`update` call.
    - If ``stamp_steps`` is provided, every entry must be a finite integer
      satisfying ``stamp_steps[i] >= current_stamp``.

    **5. Computational Implications**

    For :math:`N` incoming items per call, scheduling is :math:`O(N)` via
    dictionary accumulation by target stamp.  History memory grows linearly
    with the number of unique stamped events between consecutive triggers.
    Each trigger resets the history to a single pseudo-spike, bounding
    worst-case growth to one trigger period.

    Parameters
    ----------
    in_size : Size, optional
        Shape/size argument forwarded to the base device constructor.  It is
        stored for API compatibility with other device models and does not
        affect transmitter state-update logic.  Default is ``1``.
    deliver_interval : ArrayLike, optional
        Scalar integer-valued number (unitless, within ``1e-12`` tolerance).
        It specifies the trigger period in units of ``min_delay``.
        Non-integer or non-finite values raise ``ValueError``, and the value
        must be ``>= 1``.  Increasing it reduces how often connected synapses
        receive delivered spike histories.  Default is ``1``.
    min_delay : saiunit.Quantity or float, optional
        Scalar effective global minimal synaptic delay.  Unitful values are
        converted to ms; plain floats are interpreted as ms.  It must be
        strictly positive and an integer multiple of the simulation ``dt`` at
        the time :meth:`update` is called.  Default is ``1.0 * u.ms``.
    name : str or None, optional
        Optional node name forwarded to the base device constructor.
        Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Mapping of constructor parameters to model symbols
       :header-rows: 1
       :widths: 24 16 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``deliver_interval``
         - ``1``
         - :math:`k`
         - Number of minimal-delay intervals per delivery trigger.
       * - ``min_delay``
         - ``1.0 * u.ms``
         - :math:`d_{\min}`
         - Effective global minimal synaptic delay for trigger period.
       * - ``dt`` (environment)
         - runtime
         - :math:`\Delta t`
         - Simulation resolution for stamp conversion and ms time computation.
       * - ``delivery_period_steps``
         - runtime
         - :math:`T_s`
         - :math:`k \cdot \mathrm{round}(d_{\min} / \Delta t)`.

    Raises
    ------
    ValueError
        During ``__init__``, if ``deliver_interval`` is non-scalar,
        non-finite, not integer-like, or ``< 1``.
    ValueError
        At :meth:`update` time, in any of these cases:

        - ``dt <= 0``;
        - ``min_delay`` is not a positive integer multiple of ``dt``;
        - time ``t`` is not grid-aligned;
        - ``multiplicities`` contains negative values;
        - ``stamp_steps`` contains past stamps;
        - payload arrays are non-integer or non-finite where integer values
          are required;
        - payload arrays have mismatched flattened sizes.
    TypeError
        If provided scalar/array inputs cannot be converted by
        ``saiunit`` or NumPy conversion paths.
    KeyError
        At :meth:`update` time, if the simulation context is missing the
        required ``'t'`` or ``dt`` entries (depends on
        :mod:`brainstate.environ` behaviour).

    Notes
    -----
    - :meth:`deliver_spikes` returns the current (undelivered) history vector.
    - :attr:`last_delivery_spikes` stores the history snapshot delivered at
      the most recent trigger.
    - :attr:`last_delivery_time` stores the most recent trigger time in ms.
    - :meth:`update` aggregates multiplicities exactly by delivery step,
      mirroring NEST's internal ring-buffer logic.
    - Pending contributions are consumed only by an :meth:`update` call whose
      current stamp equals their target stamp.  If successive calls skip
      over a stamp, contributions scheduled for it are never added to the
      history.  Likewise, a trigger stamp that is skipped does not fire.
    - :meth:`handles_test_event` accepts only receptor type ``0``, matching
      the NEST ``volume_transmitter`` interface.
    - :meth:`set_local_device_id` / :meth:`get_local_device_id` are provided
      for compatibility with NEST's device duplication logic.

    References
    ----------
    .. [1] NEST Simulator, ``volume_transmitter`` model.
           https://github.com/nest/nest-simulator/blob/master/models/volume_transmitter.cpp

    Examples
    --------
    Instantiate a transmitter with a six-step delivery period, inject two
    simultaneous spikes, and advance to the trigger stamp:

    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> import numpy as np
       >>> from brainpy.state import volume_transmitter
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     vt = volume_transmitter(deliver_interval=2, min_delay=0.3 * u.ms)
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         y0 = vt.update(
       ...             spikes=np.array([1.0, 1.0]),
       ...             multiplicities=np.array([1, 2]),
       ...         )
       ...     with brainstate.environ.context(t=0.5 * u.ms):
       ...         y1 = vt.update()
       ...     _ = (y0['triggered'], y1['triggered'])

    Query transmitter state and delivery metadata via :meth:`get`:

    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> from brainpy.state import volume_transmitter
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     vt = volume_transmitter(deliver_interval=1, min_delay=0.1 * u.ms)
       ...     _ = vt.get('deliver_interval')
       ...     _ = vt.get('spike_history')
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        deliver_interval: ArrayLike = 1,
        min_delay: ArrayLike = 1.0 * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.deliver_interval = int(self._to_int_scalar(deliver_interval, name='deliver_interval'))
        if self.deliver_interval < 1:
            raise ValueError('deliver_interval must be >= 1.')

        # min_delay is validated lazily in update(): checking that it is a
        # multiple of dt requires the environment resolution.
        self.min_delay = min_delay
        self._local_device_id = 0

        self._pending_multiplicities: dict[int, float] = {}
        self._spikecounter: list[spikecounter] = []
        self._last_delivery_spikes: tuple[spikecounter, ...] = ()
        self._last_delivery_time_ms: float = 0.0
        self._delivery_count = 0

        self.init_state()

    @property
    def local_device_id(self) -> int:
        r"""Local device ID used for NEST-compatible device-duplication logic.

        Returns
        -------
        int
            Current local device ID as a Python ``int``.
        """
        return int(self._local_device_id)

    @property
    def last_delivery_time(self) -> float:
        r"""Most recent delivery trigger time in milliseconds.

        Returns ``0.0`` if no trigger has occurred since :meth:`init_state`.

        Returns
        -------
        float
            Trigger time :math:`t_s = s \cdot \Delta t` in ms at the most
            recent trigger step, or ``0.0`` before the first trigger.
        """
        return float(self._last_delivery_time_ms)

    @property
    def last_delivery_spikes(self) -> tuple[spikecounter, ...]:
        r"""Spike-history tuple delivered at the most recent trigger.

        Returns an empty tuple before the first trigger.

        Returns
        -------
        tuple[spikecounter, ...]
            Immutable copy of the history vector :math:`H` captured at the
            most recent delivery trigger.  It starts with the pseudo-spike
            left by the previous reset and includes any spike at the trigger
            stamp.  Empty tuple if no trigger has fired yet.
        """
        return tuple(self._last_delivery_spikes)

    @property
    def n_deliveries(self) -> int:
        r"""Number of completed delivery triggers since initialization.

        Returns
        -------
        int
            Count of trigger events fired since the last :meth:`init_state`
            call.  Increments only when :math:`s \bmod T_s = 0` and the
            history vector is non-empty.
        """
        return int(self._delivery_count)

    def set_local_device_id(self, ldid: ArrayLike):
        r"""Set the local device ID from a scalar integer-like value.

        Parameters
        ----------
        ldid : ArrayLike
            Scalar integer-like value for the new local device ID.
            Non-integer or non-finite values raise ``ValueError``.

        Raises
        ------
        ValueError
            If ``ldid`` is non-scalar, non-finite, or not integer-like.
        TypeError
            If ``ldid`` cannot be converted to a numeric array.
        """
        self._local_device_id = int(self._to_int_scalar(ldid, name='local_device_id'))

    def get_local_device_id(self) -> int:
        r"""Return the current local device ID as a Python ``int``.

        Returns
        -------
        int
            Current value of the local device ID.
        """
        return int(self._local_device_id)

    def handles_test_event(self, receptor_type: ArrayLike) -> int:
        r"""Validate a receptor type identifier and return the accepted index.

        Mirrors the NEST ``volume_transmitter::handles_test_event`` method.
        Only receptor type ``0`` is accepted; all other values raise
        ``ValueError``.

        Parameters
        ----------
        receptor_type : ArrayLike
            Scalar integer-like receptor identifier to validate.

        Returns
        -------
        int
            Always ``0`` when the receptor type is valid.

        Raises
        ------
        ValueError
            If ``receptor_type`` is non-scalar, non-finite, not integer-like,
            or not equal to ``0``.
        TypeError
            If ``receptor_type`` cannot be converted to a numeric array.
        """
        r = int(self._to_int_scalar(receptor_type, name='receptor_type'))
        if r != 0:
            raise ValueError(f'Unknown receptor_type {r} for volume_transmitter.')
        return 0

    def deliver_spikes(self) -> tuple[spikecounter, ...]:
        r"""Return the current (undelivered) spike-history vector.

        After :meth:`init_state`, the history always contains at least one
        entry: the pseudo-spike ``spikecounter(t, 0.0)`` inserted by
        :meth:`init_state` or by the most recent trigger reset.

        Returns
        -------
        tuple[spikecounter, ...]
            Immutable copy of the current internal history list :math:`H`,
            in append order.  This order is chronological when the simulation
            time advances monotonically.
        """
        return tuple(self._spikecounter)

    def get(self, key: str = 'deliver_interval'):
        r"""Query transmitter parameters and mutable state by string key.

        Parameters
        ----------
        key : str, optional
            Selector string.  Supported values:

            - ``'deliver_interval'`` — returns ``int``.
            - ``'min_delay'`` — returns the stored ``min_delay`` value as
              passed to the constructor (``saiunit.Quantity`` or ``float``).
            - ``'local_device_id'`` — returns ``int``.
            - ``'spike_history'`` — returns ``tuple[spikecounter, ...]``
              (same as :meth:`deliver_spikes`).
            - ``'last_delivery_spikes'`` — returns
              ``tuple[spikecounter, ...]`` (same as
              :attr:`last_delivery_spikes`).
            - ``'last_delivery_time'`` — returns ``float`` ms.
            - ``'n_deliveries'`` — returns ``int``.

            Default is ``'deliver_interval'``.

        Returns
        -------
        int or float or saiunit.Quantity or tuple[spikecounter, ...]
            The selected value.  Type depends on ``key`` as described above.

        Raises
        ------
        KeyError
            If ``key`` is not one of the supported selector strings.
        """
        if key == 'deliver_interval':
            return int(self.deliver_interval)
        if key == 'min_delay':
            return self.min_delay
        if key == 'local_device_id':
            return int(self._local_device_id)
        if key == 'spike_history':
            return self.deliver_spikes()
        if key == 'last_delivery_spikes':
            return self.last_delivery_spikes
        if key == 'last_delivery_time':
            return self.last_delivery_time
        if key == 'n_deliveries':
            return self.n_deliveries
        raise KeyError(f'Unsupported key "{key}" for volume_transmitter.get().')

    def init_state(self, batch_size: int = None, **kwargs):
        r"""Reset all queue and history state to NEST-compatible initial conditions.

        This method:

        - clears the pending-multiplicity map :math:`P`;
        - resets the spike-history vector :math:`H` to the single pseudo-spike
          ``spikecounter(0.0, 0.0)``;
        - resets all delivery metadata (``last_delivery_spikes``,
          ``last_delivery_time_ms``, ``delivery_count``) to initial values.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused placeholder accepted for API compatibility.  Ignored.
            Default is ``None``.
        **kwargs
            Additional keyword arguments accepted for API compatibility and
            silently ignored.
        """
        del batch_size, kwargs
        self._pending_multiplicities.clear()
        self._spikecounter = [spikecounter(0.0, 0.0)]
        self._last_delivery_spikes = ()
        self._last_delivery_time_ms = 0.0
        self._delivery_count = 0

    def connect(self):
        r"""No-op compatibility hook matching the NEST device interface.

        Lets code that calls ``connect()`` on NEST devices work unchanged
        with :class:`volume_transmitter`.
        """
        return None

    def flush(self):
        r"""Return a non-triggering snapshot of the current state.

        Unlike :meth:`update`, this method does not advance the simulation
        step, consume pending multiplicities, or reset the history.  Use it to
        inspect state between :meth:`update` calls, e.g. at the end of a
        simulation run.

        Returns
        -------
        dict
            Dictionary with the following keys:

            - ``'triggered'`` — ``False`` (no trigger occurs on flush).
            - ``'t_trig'`` — ``None`` (no trigger time).
            - ``'delivered_spikes'`` — empty ``tuple`` (no delivery).
            - ``'spike_history'`` — ``tuple[spikecounter, ...]``, the current
              history returned by :meth:`deliver_spikes`.
        """
        return {
            'triggered': False,
            't_trig': None,
            'delivered_spikes': (),
            'spike_history': self.deliver_spikes(),
        }

    def update(
        self,
        spikes: ArrayLike = None,
        multiplicities: ArrayLike = None,
        stamp_steps: ArrayLike = None,
    ):
        r"""Advance transmitter state by one simulation step.

        Reads the current simulation time ``t`` and resolution ``dt`` from
        :mod:`brainstate.environ`.  It then schedules incoming spike
        contributions into the pending map :math:`P`, consumes the
        current-stamp entry, optionally appends a new history entry, and
        evaluates the delivery trigger.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            Scalar or array of spike event indicators/counts for the current
            call.  Any shape is accepted and flattened to ``(N,)``.  Unitful
            inputs are accepted; only the mantissa is used.  Multiplicity
            inference rules:

            - ``multiplicities is None`` and all values are finite and
              integer-like (within ``1e-12``):
              :math:`c_j = \max(\mathrm{round}(x_j),\; 0)`.
            - ``multiplicities is None`` and values contain non-integers or
              non-finite entries: :math:`c_j = \mathbf{1}[x_j > 0]` (binary
              threshold).
            - ``multiplicities`` provided:
              :math:`c_j = m_j \,\mathbf{1}[x_j > 0]`.

            ``None`` means no incoming events for this step.
        multiplicities : ArrayLike or None, optional
            Finite integer-like scalar or array whose flattened size must
            equal that of ``spikes``.  Each value must be non-negative.
            Values are applied only where ``spikes[j] > 0``; a non-positive
            spike indicator forces zero contribution regardless of
            ``multiplicities[j]``.  ``None`` enables implicit inference from
            ``spikes``.
        stamp_steps : ArrayLike or None, optional
            Finite integer-like scalar or array whose flattened size must
            equal that of ``spikes``.  Values are absolute delivery-stamp
            indices in step-space and must satisfy
            ``stamp_steps[j] >= current_stamp``; past stamps raise
            ``ValueError``.  ``None`` assigns all contributions to the
            current stamp :math:`s`.

        Returns
        -------
        dict
            Dictionary with the following keys:

            - ``'triggered'`` — ``bool``: whether the current stamp fires the
              delivery trigger (:math:`s \bmod T_s = 0`).
            - ``'t_trig'`` — ``float`` ms or ``None``: trigger time
              :math:`t_s = s \cdot \Delta t` if triggered, else ``None``.
            - ``'delivered_spikes'`` — ``tuple[spikecounter, ...]``: history
              :math:`H` captured before the trigger reset, or an empty tuple
              if not triggered.
            - ``'spike_history'`` — ``tuple[spikecounter, ...]``: current
              history after all processing (post-reset pseudo-spike if
              triggered).

        Raises
        ------
        ValueError
            In any of these cases:

            - ``dt <= 0``;
            - ``min_delay`` is not a positive integer multiple of ``dt``;
            - ``t`` is not grid-aligned to ``dt``;
            - ``multiplicities`` contains negative values;
            - ``stamp_steps`` contains stamps earlier than the current stamp;
            - an array payload is non-integer or non-finite where integer
              values are required.
        ValueError
            If ``multiplicities`` or ``stamp_steps`` has a flattened size
            different from that of ``spikes``.
        TypeError
            If numeric or unit conversion fails for any payload or
            environment time value.
        KeyError
            If required environment values (``'t'`` or ``dt``) are
            unavailable from :mod:`brainstate.environ`.

        Notes
        -----
        Trigger evaluation uses stamp :math:`s = n + 1` (one ahead of the
        step index) and period :math:`T_s = k \cdot d_{\min,s}`.  Spikes
        stamped exactly at the trigger stamp are included in
        ``'delivered_spikes'`` before the history reset, matching NEST
        ordering semantics.

        Examples
        --------
        Schedule spikes at a future stamp and confirm delivery:

        .. code-block:: python

           >>> import brainstate
           >>> import saiunit as u
           >>> import numpy as np
           >>> from brainpy.state import volume_transmitter
           >>> with brainstate.environ.context(dt=0.1 * u.ms):
           ...     vt = volume_transmitter(deliver_interval=1, min_delay=0.2 * u.ms)
           ...     with brainstate.environ.context(t=0.0 * u.ms):
           ...         out0 = vt.update(
           ...             spikes=np.array([1.0, 1.0, 0.0]),
           ...             multiplicities=np.array([2, 3, 7]),
           ...             stamp_steps=np.array([2, 2, 2]),
           ...         )
           ...     with brainstate.environ.context(t=0.1 * u.ms):
           ...         out1 = vt.update()
           ...     _ = (out0['triggered'], out1['delivered_spikes'])
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._get_step_calibration(dt)
        step_now = self._time_to_step(t, calib.dt_ms)
        # Events handled during step n are stamped one step ahead (s = n + 1).
        stamp_now = step_now + 1

        self._schedule_incoming(
            spikes=spikes,
            multiplicities=multiplicities,
            stamp_steps=stamp_steps,
            stamp_now=stamp_now,
        )

        # Only the exact current stamp is consumed; entries for stamps that a
        # caller skipped over stay in the map and are never delivered.
        multiplicity = float(self._pending_multiplicities.pop(stamp_now, 0.0))
        if multiplicity > 0.0:
            t_spike = float(stamp_now) * calib.dt_ms
            self._spikecounter.append(spikecounter(t_spike, multiplicity))

        triggered = (stamp_now % calib.delivery_period_steps) == 0
        delivered_spikes: tuple[spikecounter, ...] = ()
        t_trig = None
        if triggered:
            t_trig = float(stamp_now) * calib.dt_ms
            if self._spikecounter:
                delivered_spikes = tuple(self._spikecounter)
                self._last_delivery_spikes = delivered_spikes
                self._last_delivery_time_ms = t_trig
                self._delivery_count += 1
            # As in NEST, the history is reset on every trigger.  The
            # pseudo-spike at t_trig marks the time up to which downstream
            # synapses have already been updated.
            self._spikecounter.clear()
            self._spikecounter.append(spikecounter(t_trig, 0.0))

        return {
            'triggered': bool(triggered),
            't_trig': t_trig,
            'delivered_spikes': delivered_spikes,
            'spike_history': self.deliver_spikes(),
        }

    def _schedule_incoming(
        self,
        spikes: ArrayLike,
        multiplicities: ArrayLike,
        stamp_steps: ArrayLike,
        stamp_now: int,
    ):
        """Accumulate this call's per-item spike counts into the pending map.

        Parameters
        ----------
        spikes, multiplicities, stamp_steps : ArrayLike or None
            Raw payloads as passed to :meth:`update` (see its docstring).
        stamp_now : int
            Current delivery stamp; target stamps must not precede it.

        Raises
        ------
        ValueError
            On negative multiplicities, past stamps, size mismatches, or
            non-integer / non-finite integer payloads.
        """
        if spikes is None:
            return
        spike_arr = self._to_float_array(spikes, name='spikes')
        if spike_arr.size == 0:
            return
        n_items = spike_arr.size

        if multiplicities is None:
            rounded = np.rint(spike_arr)
            # np.allclose treats equal infinities as close, and casting inf to
            # int64 is undefined.  Only finite arrays count as integer counts;
            # anything else falls back to the binary threshold rule.
            is_integer_like = bool(np.all(np.isfinite(spike_arr))) and bool(
                np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12)
            )
            if is_integer_like:
                counts = np.maximum(rounded.astype(np.int64), 0)
            else:
                counts = (spike_arr > 0.0).astype(np.int64)
        else:
            mult_arr = self._to_int_array(multiplicities, name='multiplicities', size=n_items)
            if np.any(mult_arr < 0):
                raise ValueError('multiplicities must be non-negative.')
            # A non-positive spike indicator suppresses its multiplicity.
            counts = np.where(spike_arr > 0.0, mult_arr, 0)

        if stamp_steps is None:
            ditype = brainstate.environ.ditype()
            stamp_arr = np.full((n_items,), stamp_now, dtype=ditype)
        else:
            stamp_arr = self._to_int_array(stamp_steps, name='stamp_steps', size=n_items)
            if np.any(stamp_arr < stamp_now):
                raise ValueError('stamp_steps must be >= current delivery step.')

        pending = self._pending_multiplicities
        for count, stamp in zip(counts.tolist(), stamp_arr.tolist()):
            if count <= 0:
                continue
            pending[stamp] = float(pending.get(stamp, 0.0) + count)

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time value to a float in ms.

        ``saiunit.Quantity`` inputs are converted to ms; plain numbers are
        taken as ms.  Any shape holding exactly one element is accepted.
        Unless ``allow_inf`` is set, non-finite values raise ``ValueError``.
        """
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        return val

    @staticmethod
    def _to_int_scalar(value, name: str) -> int:
        """Convert a one-element value to ``int``.

        The value must be finite and integer-valued within ``1e-12``;
        otherwise ``ValueError`` is raised.
        """
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar integer value.')
        val = float(arr[0])
        # Without this guard, int(np.rint(inf)) raises OverflowError and NaN
        # yields a cryptic conversion error instead of a clear ValueError.
        if not math.isfinite(val):
            raise ValueError(f'{name} must be an integer value.')
        ival = int(np.rint(val))
        if not np.isclose(val, ival, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be an integer value.')
        return ival

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
    ) -> int:
        """Convert a time value to an integer number of ``dt_ms`` steps.

        Raises ``ValueError`` if the value is not a multiple of ``dt_ms``.
        """
        ms = cls._to_ms_scalar(value, name=name)
        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _get_step_calibration(self, dt) -> _StepCalibration:
        """Validate ``dt``/``min_delay`` and build the step-space constants."""
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')
        min_delay_steps = self._to_step_count(self.min_delay, dt_ms=dt_ms, name='min_delay')
        if min_delay_steps < 1:
            raise ValueError('min_delay must be at least one simulation step.')
        period = int(self.deliver_interval) * int(min_delay_steps)
        # Defensive: both factors are already validated to be >= 1.
        if period < 1:
            raise ValueError('deliver_interval * min_delay_steps must be >= 1.')
        return _StepCalibration(
            dt_ms=dt_ms,
            min_delay_steps=min_delay_steps,
            delivery_period_steps=period,
        )

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Convert the simulation time ``t`` to a step index.

        Raises ``ValueError`` if ``t`` is not on the ``dt_ms`` grid.
        """
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _to_float_array(x, name: str) -> np.ndarray:
        """Flatten ``x`` (scalar or array of any shape) to a 1-D float array.

        Quantity inputs are reduced to their mantissa.  ``name`` is unused
        and kept only so the array helpers share one call signature.
        """
        del name
        if isinstance(x, u.Quantity):
            x = u.get_mantissa(x)
        dftype = brainstate.environ.dftype()
        # reshape(-1) always yields a 1-D array, so no rank check is needed.
        return np.asarray(u.math.asarray(x), dtype=dftype).reshape(-1)

    @classmethod
    def _to_int_array(
        cls,
        x,
        name: str,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to a 1-D ``int64`` array of finite integer values.

        Raises ``ValueError`` if ``size`` is given and does not match, or if
        any entry is non-finite or non-integer (tolerance ``1e-12``).
        """
        arr = cls._to_float_array(x, name=name)
        if size is not None and arr.size != size:
            raise ValueError(f'{name} must have size {size}, got {arr.size}.')
        rounded = np.rint(arr)
        # np.allclose(inf, inf) is True and inf -> int64 is undefined, so
        # non-finite entries must be rejected explicitly.
        if not (
            np.all(np.isfinite(arr))
            and np.allclose(arr, rounded, atol=1e-12, rtol=1e-12)
        ):
            raise ValueError(f'{name} must contain integer values.')
        return rounded.astype(np.int64)