# Source code for brainpy_state._nest.correlospinmatrix_detector

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import math
from collections import deque
from dataclasses import dataclass

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Names exported by ``from ... import *``; only the detector class is public.
__all__ = ["correlospinmatrix_detector"]


@dataclass
class _BinaryPulse:
    t_on: int
    t_off: int
    receptor_channel: int


@dataclass
class _Calibration:
    dt_ms: float
    t_min_steps: int
    t_max_steps: float
    delta_tau_ms: float
    delta_tau_steps: int
    tau_max_ms: float
    tau_max_steps: int
    tau_edge_steps: int
    tstart_ms: float
    tstop_ms: float
    n_channels: int
    n_bins: int
    min_delay_steps: int
    signature: tuple


class correlospinmatrix_detector(NESTDevice):
    r"""NEST-compatible ``correlospinmatrix_detector`` device.

    **1. Overview**

    ``correlospinmatrix_detector`` receives binary-state spike streams from
    multiple receptor channels and accumulates raw auto/cross covariance
    histograms over negative, zero, and positive lags. It mirrors NEST
    ``models/correlospinmatrix_detector.{h,cpp}`` for event decoding,
    pulse finalization, and lag-bin accumulation.

    **2. Binary-State Decoding and Pulse Construction**

    For receptor channel :math:`c`, define a binary state
    :math:`x_c(t)\in\{0,1\}` on integer simulation steps.
    Runtime events are triples :math:`e=(c, t, m)`, where :math:`m` is
    multiplicity.

    Decoding rule (NEST compatible):

    - :math:`m=1` — mark a tentative down-transition.
    - A second event with identical ``(channel, stamp_step)`` or
      :math:`m=2`: confirm up-transition (:math:`x_c\leftarrow 1`) and cancel
      tentative-down handling for that event pair.
    - Other multiplicities change no channel state.

    A covariance update is triggered only when a previous tentative
    down-transition becomes confirmed by a subsequent event. The finalized
    pulse is :math:`p_i=(i, t_i^{\mathrm{on}}, t_i^{\mathrm{off}})`, where
    :math:`t_i^{\mathrm{on}}` is taken from ``_last_change[i]`` and
    :math:`t_i^{\mathrm{off}}` from the confirmed down-transition stamp.

    **3. Lag-Bin Accumulation Equations**

    Let :math:`\Delta=\Delta_{\tau,\mathrm{steps}}`,
    :math:`H=\tau_{\max,\mathrm{steps}}/\Delta`, and
    :math:`k_0=H` (zero-lag bin index).
    For each finalized pulse :math:`p_i`, iterate the retained history pulses
    :math:`p_j=(j, t_j^{\mathrm{on}}, t_j^{\mathrm{off}})`, which include
    :math:`p_i` itself (it is inserted into the history first).

    The admissible lag range, in simulation steps, is

    .. math::

       d_{\min}=\max\!\bigl(t_j^{\mathrm{on}}-t_i^{\mathrm{off}},
       -\tau_{\max,\mathrm{steps}}\bigr),\qquad
       d_{\max}=\min\!\bigl(t_j^{\mathrm{off}}-t_i^{\mathrm{on}},
       \tau_{\max,\mathrm{steps}}\bigr),

    and integer bin offsets :math:`\delta` run over
    :math:`\operatorname{trunc}(d_{\min}/\Delta)\le\delta\le
    \operatorname{trunc}(d_{\max}/\Delta)`, where ``trunc`` rounds toward
    zero as in C++ integer division. For an offset :math:`\delta`, the
    overlap length in simulation steps is

    .. math::

       L_{ij}(\delta)=
       \min\!\bigl(t_i^{\mathrm{off}},\; t_j^{\mathrm{off}}-\delta\Delta\bigr)
       -\max\!\bigl(t_i^{\mathrm{on}},\; t_j^{\mathrm{on}}-\delta\Delta\bigr).

    If :math:`L_{ij}(\delta)>0`, add this integer duration to
    ``count_covariance`` bins:

    - **Zero lag** (:math:`\delta=0`): update ``(i, j, k_0)`` and, when
      :math:`i\neq j`, the mirrored entry ``(j, i, k_0)``.
    - **Negative offsets** (:math:`\delta<0`): update both mirrored bins
      ``(i, j, k_0-\delta)`` and ``(j, i, k_0+\delta)``.
    - **Positive offsets** (:math:`\delta>0`): update the same mirrored bins
      ``(i, j, k_0-\delta)`` and ``(j, i, k_0+\delta)``, but only for
      :math:`i\neq j`, matching NEST triangular-edge conventions.

    ``count_covariance`` has shape ``(N_channels, N_channels, N_bins)`` with

    .. math::

       N_{\mathrm{bins}} = 1 + 2\,\frac{\tau_{\max,\mathrm{steps}}}
                                        {\Delta_{\tau,\mathrm{steps}}},

    and stores overlap lengths in simulation-step units (not milliseconds).

    **4. Windowing, Assumptions, and Constraints**

    Activity filtering follows the half-open interval
    ``(origin + start, origin + stop]`` in step space. Events outside this
    interval are discarded before decoding.

    ``Tstart`` and ``Tstop`` are validated and included in calibration
    signatures for reset behavior, but they do not gate accumulation in
    :meth:`update`, matching current NEST source behavior.

    Calibration constraints:

    - ``dt > 0``.
    - ``start``, ``stop`` (if finite), ``origin``, ``delta_tau``, and
      ``tau_max`` must each align exactly to integer multiples of ``dt``.
    - ``delta_tau`` must be finite, strictly positive, and span at least one
      simulation step.
    - ``tau_max`` must be finite, non-negative, and divisible by
      ``delta_tau``.
    - ``N_channels >= 1``; runtime receptor IDs must be in
      ``[0, N_channels - 1]``.

    **5. Computational Implications**

    Event decoding is constant-time per accepted event. Each finalized pulse
    additionally scans all channels once (to find the earliest active onset),
    prunes and inserts into the retained history queue of length :math:`Q`,
    and correlates against every retained pulse over up to
    :math:`N_{\mathrm{bins}}` lag offsets, i.e.
    :math:`O\!\left(N_{\mathrm{channels}} + Q\,N_{\mathrm{bins}}\right)`
    work. Memory scales as
    :math:`O\!\left(Q + N_{\mathrm{channels}}^2 N_{\mathrm{bins}}\right)`.
    Queue pruning uses NEST minimum-delay semantics with
    ``min_delay_steps = 1``.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape metadata used by :class:`brainstate.nn.Dynamics`.
        The detector is event-driven and stores internal tensors, so
        ``in_size`` does not change ``count_covariance`` shape.
        Default is ``1``.
    delta_tau : quantity (ms) or float or None, optional
        Lag-bin width :math:`\Delta_\tau`. Unitful ``saiunit`` quantities are
        accepted and converted to ms; bare floats are interpreted as ms.
        Must be finite, strictly positive, and an integer multiple of
        simulation ``dt``. ``None`` auto-selects ``dt``.
        Default is ``None``.
    tau_max : quantity (ms) or float or None, optional
        One-sided lag horizon :math:`\tau_{\max}`. Unitful quantities accepted.
        Must be finite, non-negative, an integer multiple of ``dt``, and an
        exact integer multiple of ``delta_tau``. ``None`` auto-selects
        ``10 * delta_tau``. Default is ``None``.
    Tstart : quantity (ms) or float, optional
        Non-negative scalar lower time bound in ms retained for NEST API
        compatibility. Participates in calibration signature and triggers a
        state reset when changed. Default is ``0.0 * u.ms``.
    Tstop : quantity (ms) or float or None, optional
        Non-negative scalar upper time bound in ms retained for NEST API
        compatibility. ``None`` means :math:`+\infty`. Participates in
        calibration signature and triggers a state reset when changed.
        Default is ``None``.
    N_channels : int or ArrayLike, optional
        Number of receptor channels. Must resolve to a scalar integer ``>= 1``.
        Runtime channel IDs must be in ``[0, N_channels - 1]``.
        Default is ``1``.
    start : quantity (ms) or float, optional
        Exclusive lower bound of the activity window relative to ``origin``
        in ms. Must be scalar-convertible and aligned to simulation ``dt``.
        Default is ``0.0 * u.ms``.
    stop : quantity (ms) or float or None, optional
        Inclusive upper bound of the activity window relative to ``origin``
        in ms. Must be scalar-convertible and aligned to ``dt`` when finite.
        ``None`` means :math:`+\infty`. Default is ``None``.
    origin : quantity (ms) or float, optional
        Activity-window origin shift in ms. The effective activity window
        becomes ``(origin + start, origin + stop]``. Must be
        scalar-convertible and aligned to ``dt``.
        Default is ``0.0 * u.ms``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.
        If ``None``, a name is auto-generated. Default is ``None``.

    Parameter Mapping
    -----------------

    .. list-table::
       :header-rows: 1
       :widths: 18 16 24 42

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``delta_tau``
         - ``None`` → ``dt``
         - :math:`\Delta_\tau`
         - Lag-bin width; auto-resolved to simulation ``dt`` when omitted.
       * - ``tau_max``
         - ``None`` → ``10 * delta_tau``
         - :math:`\tau_{\max}`
         - One-sided lag horizon; auto-resolved when omitted.
       * - ``Tstart``
         - ``0.0 ms``
         - :math:`T_{\mathrm{start}}`
         - Calibration/reset compatibility parameter (not a runtime gate).
       * - ``Tstop``
         - ``None`` (:math:`+\infty`)
         - :math:`T_{\mathrm{stop}}`
         - Calibration/reset compatibility parameter (not a runtime gate).
       * - ``N_channels``
         - ``1``
         - :math:`N_{\mathrm{channels}}`
         - Number of receptor channels and covariance matrix axes.
       * - ``start``
         - ``0.0 ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the activity window.
       * - ``stop``
         - ``None`` (:math:`+\infty`)
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the activity window.
       * - ``origin``
         - ``0.0 ms``
         - :math:`t_0`
         - Global shift applied to ``start`` and ``stop`` boundaries.

    Raises
    ------
    ValueError
        If scalar parameters are non-scalar, non-finite where finite values
        are required, misaligned to simulation resolution, or violate
        constraints (e.g. ``tau_max % delta_tau != 0``, ``stop < start``,
        ``N_channels < 1``, invalid runtime receptor IDs, negative
        multiplicities, non-finite ``spikes``, non-integer or non-finite
        values in integer-valued arrays, or mismatched event-array sizes).
    KeyError
        If runtime environment keys such as simulation time ``'t'`` or ``dt``
        are unavailable during calibration or update.

    Notes
    -----
    - Connection delays and weights are ignored, matching NEST.
    - Runtime events are provided through :meth:`update` arrays
      (``spikes``, ``receptor_ports``/``receptor_types``, ``multiplicities``,
      ``stamp_steps``), each scalar-broadcastable to one event axis.
    - Optional ``multiplicities`` emulate NEST ``SpikeEvent`` multiplicity.
    - History pruning uses NEST minimum-delay semantics with
      ``min_delay_steps = 1``.
    - Calibration is cached and reused across steps; it is automatically
      invalidated if ``dt`` changes or any window parameter is reassigned
      between calls. In-place mutation of a parameter object (e.g. writing
      into a numpy array) is not detected.

    Examples
    --------
    Two-channel detector with explicit events:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> import numpy as np
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     det = brainpy.state.correlospinmatrix_detector(
       ...         N_channels=2,
       ...         delta_tau=1.0 * u.ms,
       ...         tau_max=3.0 * u.ms,
       ...     )
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         det.update(
       ...             spikes=np.array([1.0, 1.0]),
       ...             receptor_ports=np.array([0, 0]),
       ...             multiplicities=np.array([1, 1]),
       ...             stamp_steps=np.array([1, 1]),
       ...         )
       ...     out = det.flush()
       >>> out['count_covariance'].shape
       (2, 2, 7)

    Default parameters with no input events and explicit state reset:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     det = brainpy.state.correlospinmatrix_detector()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = det.update()  # no events; returns current state
       ...     det.init_state()  # explicit reset

    References
    ----------
    .. [1] NEST Simulator, ``correlospinmatrix_detector`` model.
           https://nest-simulator.readthedocs.io/en/stable/models/correlospinmatrix_detector.html
    """

    __module__ = 'brainpy.state'

    # Sentinel "last spike" stamp meaning that no event has been received yet.
    _NO_SPIKE_STEP = -2 ** 62

    def __init__(
        self,
        in_size: Size = 1,
        delta_tau: ArrayLike = None,
        tau_max: ArrayLike = None,
        Tstart: ArrayLike = 0.0 * u.ms,
        Tstop: ArrayLike = None,
        N_channels: int = 1,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.delta_tau = delta_tau
        self.tau_max = tau_max
        self.Tstart = Tstart
        self.Tstop = Tstop
        self.N_channels = N_channels

        self.start = start
        self.stop = stop
        self.origin = origin

        # Calibration cache: the resolved calibration plus the parameter
        # objects and dt key it was computed from (see ``_ensure_calibrated``).
        self._calib: _Calibration | None = None
        self._calib_params: tuple | None = None
        self._calib_dt_key: float | None = None

        # With no calibration yet this sets the default decoder state and
        # allocates zero-sized state arrays.
        self._reset_state()

        # Calibrate right away if ``dt`` is already available in the environment.
        self._ensure_calibrated_from_env_if_available()

    @property
    def count_covariance(self) -> np.ndarray:
        r"""Return accumulated raw covariance histogram.

        Returns
        -------
        np.ndarray
            Integer tensor (dtype ``brainstate.environ.ditype()``) with shape
            ``(N_channels, N_channels, N_bins)``. Entries are accumulated
            overlap durations measured in simulation steps.
        """
        self._ensure_calibrated_from_env_if_available()
        ditype = brainstate.environ.ditype()
        return np.asarray(self._count_covariance, dtype=ditype)

    def get(self, key: str = 'count_covariance'):
        r"""Read a detector attribute using NEST-style string keys.

        Parameters
        ----------
        key : str, optional
            Attribute name. Supported values are ``'count_covariance'``,
            ``'delta_tau'``, ``'tau_max'``, ``'Tstart'``, ``'Tstop'``,
            ``'N_channels'``, ``'start'``, ``'stop'``, and ``'origin'``.
            Default is ``'count_covariance'``.

        Returns
        -------
        out : np.ndarray or float or int or None
            Requested value. ``count_covariance`` is an integer ``np.ndarray``;
            ``N_channels`` is an ``int``; time-valued parameters are ``float``
            milliseconds (``math.inf`` for an unset ``Tstop``/``stop``).
            ``delta_tau`` and ``tau_max`` are the resolved calibration values,
            or ``None`` when no calibration exists yet (e.g. ``dt`` is not set
            in the environment).

        Raises
        ------
        KeyError
            If ``key`` is unsupported.
        ValueError
            If stored parameter values fail scalar/unit conversion checks.
        """
        if key == 'count_covariance':
            return self.count_covariance
        if key == 'delta_tau':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.delta_tau_ms) if self._calib is not None else None
        if key == 'tau_max':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.tau_max_ms) if self._calib is not None else None
        if key == 'Tstart':
            return self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        if key == 'Tstop':
            stop_val = math.inf if self.Tstop is None else self.Tstop
            return self._to_ms_scalar(stop_val, name='Tstop', allow_inf=True)
        if key == 'N_channels':
            return int(self._to_int_scalar(self.N_channels, name='N_channels'))
        if key == 'start':
            return self._to_ms_scalar(self.start, name='start')
        if key == 'stop':
            stop_val = math.inf if self.stop is None else self.stop
            return self._to_ms_scalar(stop_val, name='stop', allow_inf=True)
        if key == 'origin':
            return self._to_ms_scalar(self.origin, name='origin')
        raise KeyError(f'Unsupported key "{key}" for correlospinmatrix_detector.get().')

    def connect(self):
        r"""Compatibility no-op for device connection API."""
        return None

    def flush(self):
        r"""Return current detector outputs without mutating state.

        Returns
        -------
        dict[str, np.ndarray]
            Mapping with one key, ``'count_covariance'``, containing the
            current integer covariance tensor.
        """
        return {
            'count_covariance': self.count_covariance,
        }

    def init_state(self, batch_size: int = None, **kwargs):
        r"""Reset internal detector state.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused. Present for :class:`brainstate.nn.Dynamics` compatibility.
        **kwargs
            Unused keyword arguments.
        """
        del batch_size, kwargs
        self._ensure_calibrated_from_env_if_available()
        self._reset_state()

    def update(
        self,
        spikes: ArrayLike = None,
        receptor_ports: ArrayLike = None,
        receptor_types: ArrayLike = None,
        multiplicities: ArrayLike = None,
        stamp_steps: ArrayLike = None,
    ):
        r"""Process one simulation tick worth of incoming events.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            Event-presence values for this call. Scalar or 1-D array of shape
            ``(n_events,)``. Values ``<= 0`` are ignored. If
            ``multiplicities is None``, integer-like positive values are used
            as multiplicities; otherwise positive entries act as an event
            mask. ``None`` means no new events.
        receptor_ports : ArrayLike or None, optional
            Receptor channel IDs, scalar-broadcastable or shape
            ``(n_events,)``. Required when events target channels other than
            ``0``. Each value must be an integer in ``[0, N_channels - 1]``.
        receptor_types : ArrayLike or None, optional
            Alias of ``receptor_ports`` for NEST naming compatibility. Used
            only when ``receptor_ports is None``.
        multiplicities : ArrayLike or None, optional
            Non-negative integer multiplicities, scalar-broadcastable or shape
            ``(n_events,)``. If provided, effective multiplicity is
            ``multiplicities[i]`` for positive ``spikes[i]`` and ``0``
            otherwise. ``None`` triggers multiplicity inference from
            ``spikes``.
        stamp_steps : ArrayLike or None, optional
            Integer simulation stamps (1-based step count),
            scalar-broadcastable or shape ``(n_events,)``. ``None`` uses
            ``current_step + 1`` for all events.

        Returns
        -------
        out : dict[str, np.ndarray]
            Mapping ``{'count_covariance': ndarray}``, where the ndarray is an
            integer tensor with shape ``(N_channels, N_channels, N_bins)``.

        Raises
        ------
        ValueError
            If runtime arrays are non-finite, non-integer where integers are
            required, negative where forbidden (for example
            ``multiplicities``), size-incompatible, or contain unknown
            receptor channels; or if ``t`` is not aligned to the simulation
            grid.
        KeyError
            If simulation time ``'t'`` or ``dt`` is unavailable in
            :mod:`brainstate.environ`.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._ensure_calibrated(dt)
        # Also validates that t lies on the simulation grid.
        step_now = self._time_to_step(t, calib.dt_ms)

        if spikes is None:
            return self.flush()
        spike_arr = self._to_float_array(spikes, name='spikes')
        if spike_arr.size == 0:
            return self.flush()
        n_items = spike_arr.size

        if receptor_ports is None and receptor_types is not None:
            receptor_ports = receptor_types
        port_arr = self._to_int_array(receptor_ports, name='receptor_ports', default=0, size=n_items)

        if multiplicities is None:
            # Integer-valued spike arrays carry multiplicities directly;
            # otherwise every positive entry counts as a single event.
            rounded = np.rint(spike_arr)
            if np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12):
                counts = np.maximum(rounded.astype(np.int64), 0)
            else:
                counts = (spike_arr > 0.0).astype(np.int64)
        else:
            mult_arr = self._to_int_array(multiplicities, name='multiplicities', size=n_items)
            if np.any(mult_arr < 0):
                raise ValueError('multiplicities must be non-negative.')
            counts = np.where(spike_arr > 0.0, mult_arr, 0)

        if stamp_steps is None:
            stamp_arr = np.full((n_items,), step_now + 1, dtype=np.int64)
        else:
            stamp_arr = self._to_int_array(stamp_steps, name='stamp_steps', size=n_items)

        for idx in range(n_items):
            multiplicity = int(counts[idx])
            if multiplicity <= 0:
                continue
            curr_i = int(port_arr[idx])
            if curr_i < 0 or curr_i >= calib.n_channels:
                raise ValueError(f'Unknown receptor_type {curr_i} for correlospinmatrix_detector.')
            stamp_step = int(stamp_arr[idx])
            if not self._is_active(stamp_step, calib.t_min_steps, calib.t_max_steps):
                continue
            self._handle_event(curr_i=curr_i, stamp_step=stamp_step, multiplicity=multiplicity, calib=calib)

        return self.flush()

    def _handle_event(
        self,
        curr_i: int,
        stamp_step: int,
        multiplicity: int,
        calib: _Calibration,
    ):
        """Decode one accepted event into binary-state transitions.

        Multiplicities other than 1 or 2 change no channel state, but every
        event still becomes the "last event" (``_last_i``/``_t_last_in_spike``)
        used to recognise paired up-transitions.
        """
        down_transition = False
        if multiplicity == 1:
            if curr_i == self._last_i and stamp_step == self._t_last_in_spike:
                # Second single event with the same (channel, stamp): the pair
                # encodes an up-transition, so the tentative down is revoked.
                self._curr_state[curr_i] = True
                self._last_change[curr_i] = int(stamp_step)
                self._tentative_down = False
            else:
                # A lone event is provisionally a down-transition; a pending
                # tentative down from the previous event is now confirmed.
                if self._tentative_down:
                    down_transition = True
                self._tentative_down = True
        elif multiplicity == 2:
            # Explicit up-transition.
            self._curr_state[curr_i] = True
            if self._tentative_down:
                down_transition = True
                self._curr_state[self._last_i] = False
            self._last_change[curr_i] = int(stamp_step)
            self._tentative_down = False

        if down_transition:
            # Uses the *previous* event's channel/stamp, so run before updating them.
            self._process_down_transition(calib=calib)

        self._last_i = curr_i
        self._t_last_in_spike = int(stamp_step)

    def _process_down_transition(self, calib: _Calibration):
        """Finalize the pulse of channel ``_last_i`` and correlate it with history."""
        i = int(self._last_i)
        t_i_on = int(self._last_change[i])
        t_i_off = int(self._t_last_in_spike)

        # Earliest onset among the finalized pulse and all channels currently on.
        t_min_on = t_i_on
        for n in range(calib.n_channels):
            if bool(self._curr_state[n]) and int(self._last_change[n]) < t_min_on:
                t_min_on = int(self._last_change[n])

        self._prune_history(t_min_on, calib)
        self._insert_pulse(_BinaryPulse(t_on=t_i_on, t_off=t_i_off, receptor_channel=i))

        # The history now contains the finalized pulse itself.
        for pulse_j in self._incoming:
            self._accumulate_pair(i, t_i_on, t_i_off, pulse_j, calib)

        self._last_change[i] = int(t_i_off)

    def _prune_history(self, t_min_on: int, calib: _Calibration):
        """Drop head pulses whose offset lies too far before ``t_min_on``."""
        horizon = calib.tau_edge_steps + calib.min_delay_steps
        while self._incoming and (t_min_on - self._incoming[0].t_off) >= horizon:
            self._incoming.popleft()

    def _insert_pulse(self, pulse: _BinaryPulse):
        """Insert ``pulse`` keeping the history ordered by ``t_off`` (stable)."""
        insert_pos = next(
            (idx for idx, other in enumerate(self._incoming) if other.t_off > pulse.t_off),
            len(self._incoming),
        )
        self._incoming.insert(insert_pos, pulse)

    def _accumulate_pair(
        self,
        i: int,
        t_i_on: int,
        t_i_off: int,
        pulse_j: _BinaryPulse,
        calib: _Calibration,
    ):
        """Add the lagged overlaps of pulse ``i`` and ``pulse_j`` to the histogram."""
        cov = self._count_covariance
        j = int(pulse_j.receptor_channel)
        t_j_on = int(pulse_j.t_on)
        t_j_off = int(pulse_j.t_off)
        t0 = calib.tau_max_steps // calib.delta_tau_steps  # zero-lag bin index
        width = calib.delta_tau_steps

        # Admissible lag range in steps, clipped to +/- tau_max.
        delta_ij_min = max(t_j_on - t_i_off, -calib.tau_max_steps)
        delta_ij_max = min(t_j_off - t_i_on, calib.tau_max_steps)

        lag = min(t_i_off, t_j_off) - max(t_i_on, t_j_on)
        if lag > 0:
            cov[i, j, t0] += int(lag)
            if i != j:
                cov[j, i, t0] += int(lag)

        def add_offset(delta: int) -> None:
            # Overlap of pulse i with pulse j shifted by delta bins.
            shift = delta * width
            overlap = min(t_i_off, t_j_off - shift) - max(t_i_on, t_j_on - shift)
            if overlap > 0:
                cov[i, j, t0 - delta] += int(overlap)
                cov[j, i, t0 + delta] += int(overlap)

        for delta in range(self._trunc_div(delta_ij_min, width), 0):
            add_offset(delta)

        # Positive offsets are only accumulated for distinct channels (NEST convention).
        if i != j:
            for delta in range(1, self._trunc_div(delta_ij_max, width) + 1):
                add_offset(delta)

    @staticmethod
    def _trunc_div(a: int, b: int) -> int:
        """Integer division truncating toward zero (C/C++ semantics), exactly."""
        quotient = abs(a) // abs(b)
        return quotient if (a < 0) == (b < 0) else -quotient

    def _ensure_calibrated_from_env_if_available(self):
        """Calibrate when ``dt`` is set in the environment; otherwise do nothing."""
        try:
            dt = brainstate.environ.get_dt()
        except KeyError:
            return
        self._ensure_calibrated(dt)

    def _ensure_calibrated(self, dt) -> _Calibration:
        """Return a calibration valid for ``dt``, resetting state if it changed.

        A fast path reuses the cached calibration when ``dt`` (by its cheap
        ms value) and every parameter object are unchanged, avoiding the
        comparatively expensive full recomputation on every step.
        """
        if self._calib is not None and self._calib_params is not None:
            # Compare like with like: the key was stored with _fast_dt_ms too.
            if self._fast_dt_ms(dt) == self._calib_dt_key and self._params_unchanged():
                return self._calib
        new_calib = self._compute_calibration(dt)
        if self._calib is None or self._calib.signature != new_calib.signature:
            self._calib = new_calib
            self._reset_state()
        self._calib_params = self._param_snapshot()
        self._calib_dt_key = self._fast_dt_ms(dt)
        return self._calib

    def _param_snapshot(self) -> tuple:
        """Return references to the user-facing parameters used by calibration.

        References are stored (rather than ``id()`` values) so the snapshotted
        objects stay alive: CPython may reuse the address of a freed object,
        which made ``id()``-based change detection unreliable.
        """
        return (
            self.delta_tau,
            self.tau_max,
            self.Tstart,
            self.Tstop,
            self.N_channels,
            self.start,
            self.stop,
            self.origin,
        )

    def _params_unchanged(self) -> bool:
        """True if every parameter attribute still refers to the snapshotted object."""
        return all(
            current is cached
            for current, cached in zip(self._param_snapshot(), self._calib_params)
        )

    @staticmethod
    def _fast_dt_ms(dt) -> float:
        """Extract dt in ms cheaply, avoiding the full JAX/numpy pipeline."""
        if isinstance(dt, u.Quantity):
            return float(u.get_mantissa(dt / u.ms))
        return float(dt)

    def _reset_state(self):
        """Clear decoder state, pulse history, and covariance counts.

        Arrays are sized from the current calibration; without one they are
        allocated with zero size.
        """
        self._incoming = deque()
        self._last_i = 0
        self._t_last_in_spike = self._NO_SPIKE_STEP
        self._tentative_down = False
        ditype = brainstate.environ.ditype()
        if self._calib is None:
            n_channels, n_bins = 0, 0
        else:
            n_channels = int(self._calib.n_channels)
            n_bins = int(self._calib.n_bins)
        self._curr_state = np.zeros((n_channels,), dtype=np.bool_)
        self._last_change = np.zeros((n_channels,), dtype=ditype)
        self._count_covariance = np.zeros((n_channels, n_channels, n_bins), dtype=ditype)

    def _compute_calibration(self, dt) -> _Calibration:
        """Validate all parameters against ``dt`` and derive step-based values.

        Raises
        ------
        ValueError
            If any parameter violates the calibration constraints.
        """
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')

        # Activity window (origin + start, origin + stop] in steps.
        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_value = math.inf if self.stop is None else self.stop
        stop_steps = self._to_step_count(stop_value, dt_ms, 'stop', allow_inf=True)
        if (not math.isinf(stop_steps)) and (stop_steps < start_steps):
            raise ValueError('stop >= start required.')
        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps

        # Lag-bin width.
        if self.delta_tau is None:
            delta_tau_ms = dt_ms
        else:
            delta_tau_ms = self._to_ms_scalar(self.delta_tau, name='delta_tau')
        if (not math.isfinite(delta_tau_ms)) or (delta_tau_ms <= 0.0):
            raise ValueError('/delta_tau must be positive and finite.')
        delta_tau_steps = self._to_step_count(delta_tau_ms, dt_ms, 'delta_tau')
        if delta_tau_steps < 1:
            # A positive delta_tau far below dt can round to zero steps within
            # the alignment tolerance; zero would later divide by zero.
            raise ValueError('/delta_tau must span at least one simulation step.')

        # One-sided lag horizon.
        if self.tau_max is None:
            tau_max_ms = 10.0 * delta_tau_ms
        else:
            tau_max_ms = self._to_ms_scalar(self.tau_max, name='tau_max')
        if (not math.isfinite(tau_max_ms)) or (tau_max_ms < 0.0):
            raise ValueError('/tau_max must be finite and non-negative.')
        tau_max_steps = self._to_step_count(tau_max_ms, dt_ms, 'tau_max')
        if tau_max_steps % delta_tau_steps != 0:
            raise ValueError('tau_max must be a multiple of delta_tau.')

        # NEST-compatibility bounds: validated only.
        tstart_ms = self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        tstop_value = math.inf if self.Tstop is None else self.Tstop
        tstop_ms = self._to_ms_scalar(tstop_value, name='Tstop', allow_inf=True)
        if tstart_ms < 0.0:
            raise ValueError('/Tstart must not be negative.')
        if tstop_ms < 0.0:
            raise ValueError('/Tstop must not be negative.')

        n_channels = self._to_int_scalar(self.N_channels, name='N_channels')
        if n_channels < 1:
            raise ValueError('/N_channels can only be larger than zero.')

        n_bins = int(1 + (2 * tau_max_steps) // delta_tau_steps)
        min_delay_steps = 1

        signature = (
            float(dt_ms),
            int(t_min_steps),
            float(t_max_steps),
            float(delta_tau_ms),
            int(delta_tau_steps),
            float(tau_max_ms),
            int(tau_max_steps),
            float(tstart_ms),
            float(tstop_ms),
            int(n_channels),
            int(n_bins),
            int(min_delay_steps),
        )
        return _Calibration(
            dt_ms=float(dt_ms),
            t_min_steps=int(t_min_steps),
            t_max_steps=float(t_max_steps),
            delta_tau_ms=float(delta_tau_ms),
            delta_tau_steps=int(delta_tau_steps),
            tau_max_ms=float(tau_max_ms),
            tau_max_steps=int(tau_max_steps),
            tau_edge_steps=int(tau_max_steps + delta_tau_steps),
            tstart_ms=float(tstart_ms),
            tstop_ms=float(tstop_ms),
            n_channels=int(n_channels),
            n_bins=int(n_bins),
            min_delay_steps=int(min_delay_steps),
            signature=signature,
        )

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time (quantity or bare number, bare = ms) to float ms."""
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        return val

    @staticmethod
    def _integral_values(x, name: str) -> np.ndarray:
        """Flatten ``x`` to a 1-D ``int64`` array, rejecting non-integral floats.

        Raises
        ------
        ValueError
            If a floating-point input contains non-finite or fractional values
            (previously these were silently truncated by the integer cast).
        """
        raw = np.asarray(u.math.asarray(x)).reshape(-1)
        if raw.dtype.kind == 'f':
            if not np.all(np.isfinite(raw)):
                raise ValueError(f'{name} must contain finite values.')
            if not np.array_equal(raw, np.rint(raw)):
                raise ValueError(f'{name} must contain integer values.')
        return raw.astype(np.int64)

    @classmethod
    def _to_int_scalar(cls, value, name: str) -> int:
        """Convert a scalar integer-valued parameter to ``int``."""
        arr = cls._integral_values(value, name)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar integer value.')
        return int(arr[0])

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a time to an integer step count; ``math.inf`` if allowed and infinite."""
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')
        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        if math.isinf(ms):
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} must be finite.')
        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Convert simulation time ``t`` to its integer step index."""
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """True if ``stamp_step`` lies in the half-open window ``(t_min, t_max]``."""
        if stamp_step <= t_min_steps:
            return False
        if math.isinf(t_max_steps):
            return True
        return stamp_step <= t_max_steps

    @staticmethod
    def _to_float_array(
        x,
        name: str,
        default: float = None,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to a finite 1-D ``float64`` array, broadcasting scalars to ``size``."""
        dftype = brainstate.environ.dftype()
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=dftype)
        else:
            if isinstance(x, u.Quantity):
                x = u.get_mantissa(x)
            arr = np.asarray(u.math.asarray(x), dtype=dftype).reshape(-1)
        # Values are first rounded to dftype (as before), then widened losslessly.
        arr = arr.astype(np.float64, copy=False)
        if arr.size == 0 and size is not None:
            return np.zeros((0,), dtype=np.float64)
        if not np.all(np.isfinite(arr)):
            raise ValueError(f'{name} must contain finite values.')
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), arr[0], dtype=np.float64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr

    @classmethod
    def _to_int_array(
        cls,
        x,
        name: str,
        default: int = None,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to a 1-D ``int64`` array, broadcasting scalars to ``size``."""
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.int64)
        else:
            arr = cls._integral_values(x, name)
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), int(arr[0]), dtype=np.int64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr