# Source code for brainpy_state._nest.spin_detector

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import math
from dataclasses import dataclass

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Public API of this module: only the detector class is exported.
__all__ = ['spin_detector']


@dataclass(init=True, repr=True, eq=True)
class _StepCalibration:
    """Simulation-resolution and recording-window bounds expressed in steps.

    Built once per ``update`` call from ``dt``, ``start``, ``stop`` and
    ``origin``; the window accepts stamps ``s`` with
    ``t_min_steps < s <= t_max_steps``.
    """

    # Simulation resolution in milliseconds.
    dt_ms: float
    # Exclusive lower window bound, in steps (origin + start).
    t_min_steps: int
    # Inclusive upper window bound, in steps; ``math.inf`` when unbounded.
    t_max_steps: float


class spin_detector(NESTDevice):
    r"""NEST-compatible detector for binary state decoding from spikes.

    ``spin_detector`` decodes binary activity (``state`` :math:`\in \{0, 1\}`)
    from spike-event multiplicities and stores a chronological event log
    containing ``senders``, ``times``, and decoded ``state`` for every
    emitted event.  The decode logic mirrors NEST
    ``models/spin_detector.{h,cpp}`` with explicit per-event buffering:
    a single provisional event is held in a one-slot buffer and revised
    from state ``0`` to ``1`` before being written whenever a same-sender,
    same-stamp event with multiplicity ``1`` arrives, while multiplicity
    ``2`` events bypass the buffer and are written immediately as state ``1``.

    **1. Event Decoding on a Sender-Time Lattice**

    Let incoming normalized events be
    :math:`e_j=(i_j, s_j, \delta_j, m_j)` with sender :math:`i_j \in \mathbb{N}`,
    step stamp :math:`s_j \in \mathbb{Z}`, offset :math:`\delta_j` (ms), and
    multiplicity :math:`m_j \ge 0`. The detector maintains one buffered tuple
    :math:`b=(i_b, s_b, \delta_b, x_b)` where :math:`x_b \in \{0,1\}` is the
    provisional decoded state.

    For each accepted event in order:

    - If :math:`m_j = 1` and :math:`(i_j, s_j) = (i_b, s_b)`, revise
      :math:`x_b \leftarrow 1` before writing.
    - If a buffer exists, write :math:`b` to output.
    - If :math:`m_j = 2`, write current event immediately with state ``1``
      and clear the buffer.
    - Otherwise, set buffer to current event with provisional state ``0``
      when the buffer is empty; if the buffer is not empty, clear it instead
      (the current event is then not buffered).

    This ordering ensures that a possible ``0 -> 1`` revision is applied
    before the buffered-write emission, exactly as in the NEST C++ reference.

    **2. Time Model and Activity Window**

    With simulation resolution :math:`dt > 0` (ms), current simulation time
    :math:`t`, and step index :math:`n = \mathrm{round}(t/dt)`, the default
    event stamp for events received at step :math:`n` is

    .. math::

       s = n + 1.

    The physical event time in milliseconds is reconstructed as

    .. math::

       t_{\mathrm{event}} = s \cdot dt - \delta.

    Recording is gated on stamps by the half-open interval
    :math:`(s_{\min},\, s_{\max}]` where

    .. math::

       s_{\min} = \frac{\mathrm{origin} + \mathrm{start}}{dt}, \qquad
       s_{\max} = \frac{\mathrm{origin} + \mathrm{stop}}{dt}
       \quad (\text{or } +\infty \text{ when stop is None}),

    so an event is accepted iff :math:`s > s_{\min} \land s \le s_{\max}`.
    The ``start`` bound is exclusive and ``stop`` is inclusive.

    **3. Input Normalization and Multiplicity Inference**

    Runtime ``update`` arrays are flattened to one-dimensional vectors of
    length :math:`N`. Scalars for ``senders``, ``offsets``,
    ``multiplicities``, and ``stamp_steps`` are broadcast to :math:`(N,)`;
    any other length mismatch raises ``ValueError``.

    Let :math:`a_j = \mathrm{spikes}[j]`. Per-item event multiplicity
    :math:`c_j` is determined as follows:

    - If ``multiplicities is None`` and all :math:`a_j` are integer-like
      (within ``1e-12`` tolerance):
      :math:`c_j = \max(\mathrm{round}(a_j),\, 0)`.
    - If ``multiplicities is None`` and any :math:`a_j` is non-integer:
      :math:`c_j = \mathbf{1}[a_j > 0]` (binary threshold).
    - If ``multiplicities`` is provided with non-negative integers
      :math:`m_j`: :math:`c_j = m_j \,\mathbf{1}[a_j > 0]`.

    Each event item contributes **at most one decode step** because
    :math:`c_j` is passed as the multiplicity to :meth:`_handle_event`
    rather than used for repeated writes.

    **4. Assumptions, Constraints, and Computational Implications**

    ``dt``, ``t``, ``start``, ``stop`` (when finite), and ``origin`` must be
    scalar-convertible and aligned to the simulation lattice. They are
    converted to ``float64`` milliseconds and alignment is verified by
    round-trip integer checks with ``1e-12`` tolerance.  ``senders``,
    ``stamp_steps``, and ``multiplicities`` must be integer-valued.  Per
    :meth:`update` call, normalization is :math:`O(N)` and decoding is
    :math:`O(N)`, with persistent storage cost linear in the total number of
    emitted events :math:`E`.

    Parameters
    ----------
    in_size : Size, optional
        Shape/size metadata consumed by :class:`brainstate.nn.Dynamics`.
        The detector is event-driven and does not return dense tensors, so
        ``in_size`` is retained for API compatibility only. Default is ``1``.
    start : saiunit.Quantity or float, optional
        Scalar relative exclusive lower bound of the recording window,
        convertible to ms. Must be finite and an integer multiple of ``dt``.
        The effective gate is ``stamp_step > (origin + start) / dt``.
        Default is ``0.0 * u.ms``.
    stop : saiunit.Quantity, float, or None, optional
        Scalar relative inclusive upper bound of the recording window,
        convertible to ms. Must be ``None`` or finite and aligned to ``dt``.
        Must satisfy ``stop >= start`` when not ``None``. The effective gate
        is ``stamp_step <= (origin + stop) / dt``. ``None`` means no upper
        bound (:math:`s_{\max} = +\infty`). Default is ``None``.
    origin : saiunit.Quantity or float, optional
        Scalar global time-origin shift added to both ``start`` and ``stop``
        when constructing the active window, convertible to ms. Shifting the
        origin displaces the entire recording window without changing its
        duration. Must be finite and aligned to ``dt``. Default is
        ``0.0 * u.ms``.
    time_in_steps : bool, optional
        Controls the time representation in ``events``. If ``False``,
        ``events['times']`` stores floating-point milliseconds
        (``brainstate.environ.dftype()``) computed as
        :math:`s \cdot dt - \delta_j`. If ``True``, ``events['times']``
        stores integer step stamps (``brainstate.environ.ditype()``) and
        ``events['offsets']`` stores the corresponding floating-point offsets
        in ms. Becomes immutable after the first :meth:`update` call.
        Default is ``False``.
    frozen : bool, optional
        NEST-compatibility flag. ``True`` is unconditionally rejected because
        this device cannot be frozen. Default is ``False``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.
        Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Mapping of constructor parameters to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the recording window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the recording window.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global shift applied to both window boundaries.
       * - ``time_in_steps``
         - ``False``
         - :math:`\mathrm{repr}_t`
         - Output-time representation: ms float or integer ``(step, offset)`` pair.

    Raises
    ------
    ValueError
        If ``frozen=True``; if ``dt`` is non-positive; if time parameters are
        non-scalar, NaN, non-finite where finite values are required,
        misaligned to the simulation step, or violate ``stop >= start``; if
        ``t`` is not on the simulation grid; if ``time_in_steps`` is modified
        after simulation begins; if ``n_events`` is set to any value other
        than ``0``; if provided arrays have inconsistent sizes; if
        ``spikes``/``offsets`` contain non-finite values; if ``senders``,
        ``stamp_steps`` or ``multiplicities`` contain non-integral values; or
        if explicit ``multiplicities`` contain negative entries.
    TypeError
        If unit conversion or numeric coercion of scalar/array inputs fails.
    KeyError
        If :meth:`get` is called with an unsupported key, or if required
        simulation context values (``'t'`` or ``dt``) are unavailable via
        ``brainstate.environ``.

    Notes
    -----
    - Input events are processed strictly in the order supplied, and one
      buffered event is finalized at the end of every :meth:`update` call.
    - Connection weight and delay do not participate in decode logic.
    - ``time_in_steps`` becomes immutable after the first :meth:`update`
      call that accesses simulation context, matching NEST backend constraints.
    - NEST semantics are defined for multiplicities ``1`` and ``2``. This
      implementation also accepts other non-negative values, which follow the
      ``m != 2`` branch in :meth:`_handle_event`.
    - ``spikes=None`` processes no new events: the buffer is flushed and the
      current ``events`` are returned without writing any new events.
    - Sender ID ``0`` is used internally to mark the buffer as empty.
    - :meth:`init_state` clears all accumulated events and the one-slot
      buffer; it can be used to reset the detector between simulation
      segments without reconstructing the object.

    References
    ----------
    .. [1] NEST Simulator, ``spin_detector`` device.
           https://nest-simulator.readthedocs.io/en/stable/models/spin_detector.html

    Examples
    --------
    Detect binary state for two same-sender, same-stamp events — the second
    event (multiplicity 1, matching sender and stamp) upgrades the state to
    ``1`` before the buffered event is written:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> import numpy as np
       >>> dftype = brainstate.environ.dftype()
       >>> ditype = brainstate.environ.ditype()
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     det = brainpy.state.spin_detector(start=0.0 * u.ms, stop=1.0 * u.ms)
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = det.update(
       ...             spikes=np.array([1.0, 1.0], dtype=dftype),
       ...             senders=np.array([7, 7], dtype=ditype),
       ...             stamp_steps=np.array([1, 1], dtype=ditype),
       ...         )
       ...     ev = det.flush()
       ...     _ = (ev['senders'][0], ev['state'][0])

    Record a multiplicity-2 event with a sub-step offset using
    ``time_in_steps=True``, which splits the timestamp into an integer step
    index and a float offset — multiplicity ``2`` events are written
    immediately with state ``1``:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> import numpy as np
       >>> dftype = brainstate.environ.dftype()
       >>> ditype = brainstate.environ.ditype()
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     det = brainpy.state.spin_detector(time_in_steps=True)
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = det.update(
       ...             spikes=np.array([2.0], dtype=dftype),
       ...             senders=np.array([3], dtype=ditype),
       ...             offsets=np.array([0.02], dtype=dftype) * u.ms,
       ...         )
       ...     ev = det.events
       ...     _ = (ev['times'][0], ev['offsets'][0], ev['state'][0])
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        time_in_steps: bool = False,
        frozen: bool = False,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        if frozen:
            raise ValueError('spin_detector cannot be frozen.')

        self.start = start
        self.stop = stop
        self.origin = origin

        self._time_in_steps = bool(time_in_steps)
        # Flipped to True by ``update``; afterwards ``time_in_steps`` is locked.
        self._has_been_simulated = False

        self._clear_last_event()
        self.clear_events()

    # ------------------------------------------------------------------
    # Public properties / accessors
    # ------------------------------------------------------------------

    @property
    def time_in_steps(self) -> bool:
        """Whether ``events['times']`` holds integer step stamps instead of ms."""
        return self._time_in_steps

    @time_in_steps.setter
    def time_in_steps(self, value: bool):
        if self._has_been_simulated:
            raise ValueError('Property time_in_steps cannot be set after Simulate has been called.')
        self._time_in_steps = bool(value)

    @property
    def n_events(self) -> int:
        """Number of events recorded so far."""
        return len(self._events_senders)

    @n_events.setter
    def n_events(self, value: int):
        # NEST only allows resetting the counter to zero, which clears the log.
        value = int(value)
        if value != 0:
            raise ValueError('Property n_events can only be set to 0 (which clears all stored events).')
        self.clear_events()

    @property
    def events(self) -> dict[str, np.ndarray]:
        """Recorded events as arrays in the environment int/float dtypes.

        Always contains ``'senders'``, ``'state'`` and ``'times'``; also
        ``'offsets'`` when :attr:`time_in_steps` is ``True``.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        out = {
            'senders': np.asarray(self._events_senders, dtype=ditype),
            'state': np.asarray(self._events_state, dtype=ditype),
        }
        if self.time_in_steps:
            out['times'] = np.asarray(self._events_times_steps, dtype=ditype)
            out['offsets'] = np.asarray(self._events_offsets, dtype=dftype)
        else:
            out['times'] = np.asarray(self._events_times_ms, dtype=dftype)
        return out

    def get(self, key: str = 'events'):
        """Return ``events``, ``n_events`` or ``time_in_steps`` by name.

        Raises
        ------
        KeyError
            If ``key`` is not one of the supported names.
        """
        if key == 'events':
            return self.events
        if key == 'n_events':
            return self.n_events
        if key == 'time_in_steps':
            return self.time_in_steps
        raise KeyError(f'Unsupported key "{key}" for spin_detector.get().')

    def clear_events(self):
        """Discard all recorded events (the one-slot buffer is left untouched)."""
        self._events_senders: list[int] = []
        self._events_state: list[int] = []
        self._events_times_ms: list[float] = []
        self._events_times_steps: list[int] = []
        self._events_offsets: list[float] = []

    def init_state(self, batch_size: int = None, **kwargs):
        """Reset the one-slot buffer and discard all recorded events.

        ``batch_size`` and extra keyword arguments are accepted for interface
        compatibility and ignored.
        """
        del batch_size, kwargs
        self._clear_last_event()
        self.clear_events()

    def connect(self):
        """No-op connection hook; returns ``None``."""
        return None

    def flush(self):
        """Return the current ``events`` dictionary (same as :attr:`events`)."""
        return self.events

    # ------------------------------------------------------------------
    # Simulation step
    # ------------------------------------------------------------------

    def update(
        self,
        spikes: ArrayLike = None,
        senders: ArrayLike = None,
        offsets: ArrayLike = None,
        multiplicities: ArrayLike = None,
        stamp_steps: ArrayLike = None,
    ):
        r"""Decode binary states from spike events for the current simulation step.

        Reads the current simulation time ``t`` and resolution ``dt`` from
        ``brainstate.environ``, derives the default stamp step
        :math:`s = \mathrm{round}(t/dt) + 1`, normalizes the input arrays,
        applies the activity-window gate, and passes each accepted event
        through the one-slot decode buffer via :meth:`_handle_event`. After
        all items are processed, :meth:`_flush_last_event` finalizes any
        remaining buffered event.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            Input spike payload, flattened to shape ``(N,)``. Accepted dtypes
            include boolean, integer, and floating-point values.

            - ``None``: no new events are processed; the buffer is flushed
              and the current ``events`` dict is returned.
            - Integer-like values (all within ``1e-12`` of an integer) with
              ``multiplicities is None``: each element :math:`j` contributes
              multiplicity :math:`c_j = \max(\mathrm{round}(a_j),\, 0)`.
            - Non-integer floating values with ``multiplicities is None``:
              each element contributes :math:`c_j = \mathbf{1}[a_j > 0]`
              (binary threshold).
        senders : ArrayLike or None, optional
            Integer sender node IDs, shape ``(N,)`` or scalar broadcastable
            to ``(N,)``. Default sender ID is ``1`` for all entries when
            ``None``.
        offsets : ArrayLike or None, optional
            Per-event sub-step timing offsets :math:`\delta_j` in ms, shape
            ``(N,)`` or scalar broadcastable to ``(N,)``. Values may carry a
            ``saiunit`` time unit and are converted to ms. Must contain only
            finite values. Default is ``0.0`` ms for all entries.
        multiplicities : ArrayLike or None, optional
            Explicit non-negative integer event multiplicities, shape
            ``(N,)`` or scalar broadcastable to ``(N,)``. When provided, the
            integer-like inference path from ``spikes`` is disabled; the
            count rule becomes :math:`c_j = m_j \,\mathbf{1}[a_j > 0]`.
            Negative values raise ``ValueError``. Default is ``None``.
        stamp_steps : ArrayLike or None, optional
            Explicit integer step stamps :math:`s_j` for each event, shape
            ``(N,)`` or scalar broadcastable to ``(N,)``. When ``None``, all
            events are stamped at :math:`s = n + 1` where
            :math:`n = \mathrm{round}(t/dt)`. Providing custom stamps allows
            events generated at different simulation steps to be injected in
            a single call. Default is ``None``.

        Returns
        -------
        events : dict[str, np.ndarray]
            Current accumulated events dictionary after processing this step.
            All arrays are one-dimensional with length :math:`E` equal to the
            total number of stored events:

            - ``'senders'`` — integer (``brainstate.environ.ditype()``).
            - ``'state'`` — integer: decoded binary state (:math:`0` or
              :math:`1`).
            - ``'times'`` — floating-point ms when ``time_in_steps=False``;
              integer step stamps when ``time_in_steps=True``.
            - ``'offsets'`` — floating-point ms (only present when
              ``time_in_steps=True``).

        Raises
        ------
        ValueError
            If ``t`` is not grid-aligned to ``dt``; if ``start``, ``stop``,
            or ``origin`` are invalid with respect to ``dt``; if ``dt <= 0``;
            if provided payload array sizes are incompatible with the ``N``
            inferred from ``spikes``; if ``spikes``/``offsets`` contain
            non-finite values; if ``senders``, ``stamp_steps`` or
            ``multiplicities`` are not integer-valued; or if explicit
            ``multiplicities`` contain negative entries.
        TypeError
            If numeric or unit conversion of any payload or time parameter
            fails.
        KeyError
            If required simulation context entries (``'t'`` or ``dt``) are
            not available via ``brainstate.environ``.

        Notes
        -----
        Events are stamped at :math:`s = \mathrm{round}(t/dt) + 1` by default
        and then gated by the active window :math:`(s_{\min},\, s_{\max}]` in
        step space. Events outside the window are discarded before reaching
        :meth:`_handle_event`. The one-slot buffer is always flushed at the
        end of each call regardless of how many new events were processed.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._get_step_calibration(dt)
        step_now = self._time_to_step(t, calib.dt_ms)
        self._has_been_simulated = True

        if spikes is not None:
            spike_arr = self._to_float_array(spikes, name='spikes')
            n_items = spike_arr.size
            if n_items > 0:
                sender_arr = self._to_int_array(senders, name='senders', default=1, size=n_items)
                offset_arr = self._to_float_array(offsets, name='offsets', default=0.0, size=n_items, unit=u.ms)
                counts = self._event_counts(spike_arr, multiplicities)
                if stamp_steps is None:
                    # Events received during step n are stamped n + 1.
                    stamp_arr = np.full((n_items,), step_now + 1, dtype=np.int64)
                else:
                    stamp_arr = self._to_int_array(stamp_steps, name='stamp_steps', size=n_items)

                # All four arrays have exactly n_items entries at this point.
                for sender, stamp, offset, count in zip(sender_arr, stamp_arr, offset_arr, counts):
                    multiplicity = int(count)
                    if multiplicity <= 0:
                        continue
                    stamp_step = int(stamp)
                    if not self._is_active(stamp_step, calib.t_min_steps, calib.t_max_steps):
                        continue
                    self._handle_event(
                        sender=int(sender),
                        stamp_step=stamp_step,
                        offset_ms=float(offset),
                        multiplicity=multiplicity,
                        dt_ms=calib.dt_ms,
                    )

        self._flush_last_event(dt_ms=calib.dt_ms)
        return self.events

    @classmethod
    def _event_counts(cls, spike_arr: np.ndarray, multiplicities) -> np.ndarray:
        """Per-item multiplicities ``c_j`` inferred from ``spikes`` (see class docs, section 3)."""
        if multiplicities is None:
            rounded = np.rint(spike_arr)
            if np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12):
                # Integer-like payload: the value itself is the multiplicity.
                return np.maximum(rounded.astype(np.int64), 0)
            # Non-integer payload: treat as a binary "spiked" indicator.
            return (spike_arr > 0.0).astype(np.int64)
        mult_arr = cls._to_int_array(multiplicities, name='multiplicities', size=spike_arr.size)
        if np.any(mult_arr < 0):
            raise ValueError('multiplicities must be non-negative.')
        return np.where(spike_arr > 0.0, mult_arr, 0)

    # ------------------------------------------------------------------
    # One-slot decode buffer
    # ------------------------------------------------------------------

    def _handle_event(
        self,
        sender: int,
        stamp_step: int,
        offset_ms: float,
        multiplicity: int,
        dt_ms: float,
    ):
        """Feed one accepted event through the one-slot decode buffer.

        ``self._last_sender == 0`` marks the buffer as empty, so sender IDs
        are presumably positive (as NEST node IDs are); this is not validated.
        """
        # A second single spike from the same sender at the same stamp
        # upgrades the buffered event from state 0 to state 1.
        if multiplicity == 1 and sender == self._last_sender and stamp_step == self._last_stamp_step:
            self._last_state = 1
        if self._last_sender != 0:
            self._write_event(
                sender=self._last_sender,
                stamp_step=self._last_stamp_step,
                offset_ms=self._last_offset_ms,
                state=self._last_state,
                dt_ms=dt_ms,
            )
        if multiplicity == 2:
            self._write_event(
                sender=sender,
                stamp_step=stamp_step,
                offset_ms=offset_ms,
                state=1,
                dt_ms=dt_ms,
            )
            self._clear_last_event()
        else:
            if self._last_sender == 0:
                self._last_sender = sender
                self._last_stamp_step = stamp_step
                self._last_offset_ms = offset_ms
                self._last_state = 0
            else:
                # NOTE(review): this also discards the current event when the
                # buffered one came from a different sender/stamp; kept as-is
                # to mirror the documented decode rule -- confirm against NEST.
                self._clear_last_event()

    def _flush_last_event(self, dt_ms: float):
        """Write the buffered event, if any, and empty the buffer."""
        if self._last_sender != 0:
            self._write_event(
                sender=self._last_sender,
                stamp_step=self._last_stamp_step,
                offset_ms=self._last_offset_ms,
                state=self._last_state,
                dt_ms=dt_ms,
            )
        self._clear_last_event()

    def _write_event(
        self,
        sender: int,
        stamp_step: int,
        offset_ms: float,
        state: int,
        dt_ms: float,
    ):
        """Append one decoded event to the log in the configured time format."""
        self._events_senders.append(int(sender))
        self._events_state.append(int(state))
        if self.time_in_steps:
            self._events_times_steps.append(int(stamp_step))
            self._events_offsets.append(float(offset_ms))
        else:
            self._events_times_ms.append(float(stamp_step) * dt_ms - float(offset_ms))

    def _clear_last_event(self):
        """Empty the one-slot buffer (sender ``0`` is the empty marker)."""
        self._last_sender = 0
        self._last_stamp_step = 0
        self._last_offset_ms = 0.0
        self._last_state = 0

    # ------------------------------------------------------------------
    # Time conversion and window gating
    # ------------------------------------------------------------------

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time (Quantity or plain number in ms) to a float in ms.

        Raises
        ------
        ValueError
            If ``value`` is not a scalar, is NaN, or is infinite while
            ``allow_inf`` is ``False``.
        """
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        # Convert in float64 rather than the environment float dtype: callers
        # check grid alignment with a 1e-12 tolerance, far below float32
        # resolution, so an extra float32 down-cast would reject aligned times.
        arr = np.asarray(value, dtype=np.float64).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        if math.isnan(val):
            raise ValueError(f'{name} must not be NaN.')
        return val

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a time to an integer number of ``dt`` steps (or ``inf`` if allowed).

        Raises
        ------
        ValueError
            If ``value`` is ``None``/infinite while not allowed, or is not an
            integer multiple of ``dt`` (1e-12 tolerance).
        """
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')
        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        if math.isinf(ms):
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} must be finite.')
        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _get_step_calibration(self, dt) -> _StepCalibration:
        """Validate ``dt``/``start``/``stop``/``origin`` and express the window in steps."""
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')

        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_value = math.inf if self.stop is None else self.stop
        stop_steps = self._to_step_count(stop_value, dt_ms, 'stop', allow_inf=True)
        if not math.isinf(stop_steps) and stop_steps < start_steps:
            raise ValueError('stop >= start required.')

        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps
        return _StepCalibration(
            dt_ms=dt_ms,
            t_min_steps=t_min_steps,
            t_max_steps=t_max_steps,
        )

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Return ``round(t / dt)``, requiring ``t`` to lie on the simulation grid."""
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """Return True iff ``t_min_steps < stamp_step <= t_max_steps``."""
        if stamp_step <= t_min_steps:
            return False
        if math.isinf(t_max_steps):
            return True
        return stamp_step <= t_max_steps

    # ------------------------------------------------------------------
    # Payload normalization
    # ------------------------------------------------------------------

    @staticmethod
    def _to_float_array(
        x,
        name: str,
        default: float = None,
        size: int = None,
        unit=None,
    ) -> np.ndarray:
        """Flatten ``x`` to a finite 1-D ``float64`` array, broadcasting scalars to ``size``.

        Raises
        ------
        ValueError
            If ``x`` is ``None`` without a default, contains non-finite
            values, or its length is neither 1 nor ``size``.
        """
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.float64)
        else:
            if unit is not None and isinstance(x, u.Quantity):
                x = x / unit
            elif isinstance(x, u.Quantity):
                x = u.get_mantissa(x)
            arr = np.asarray(u.math.asarray(x), dtype=np.float64).reshape(-1)

        if not np.all(np.isfinite(arr)):
            raise ValueError(f'{name} must contain finite values.')
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), arr[0], dtype=np.float64)
        # An empty array for a non-empty spike batch is a size mismatch, not
        # something to silently return (it would break per-item indexing).
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr

    @staticmethod
    def _to_int_array(
        x,
        name: str,
        default: int = None,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to a 1-D ``int64`` array, broadcasting scalars to ``size``.

        Raises
        ------
        ValueError
            If ``x`` is ``None`` without a default, contains non-integral or
            non-finite floating values, or its length is neither 1 nor
            ``size``.
        """
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.int64)
        else:
            raw = np.asarray(u.math.asarray(x)).reshape(-1)
            # A plain integer cast would silently truncate 1.5 -> 1 (or turn
            # NaN/inf into garbage), so reject such inputs explicitly.
            if raw.dtype.kind == 'f' and not (
                np.all(np.isfinite(raw)) and np.array_equal(raw, np.rint(raw))
            ):
                raise ValueError(f'{name} must contain integer values.')
            arr = raw.astype(np.int64)

        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), int(arr[0]), dtype=np.int64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr