# Source code for brainpy_state._nest.correlomatrix_detector

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import math
from collections import deque
from dataclasses import dataclass

import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTDevice

# Names exported by ``from <module> import *``; only the detector is public.
__all__ = ["correlomatrix_detector"]


@dataclass
class _Spike:
    timestep: int
    weight: float
    receptor_channel: int


@dataclass
class _Calibration:
    dt_ms: float
    start_step: int
    stop_step: float
    origin_step: int
    t_min_steps: int
    t_max_steps: float
    delta_tau_ms: float
    delta_tau_steps: int
    tau_max_ms: float
    tau_max_steps: int
    tau_edge_steps: float
    tstart_ms: float
    tstop_ms: float
    n_channels: int
    n_bins: int
    min_delay_steps: int
    signature: tuple


class correlomatrix_detector(NESTDevice):
    r"""NEST-compatible ``correlomatrix_detector`` device.

    **1. Overview**

    ``correlomatrix_detector`` receives spikes from ``N_channels`` receptor
    pools and accumulates binned auto/cross-covariance matrices for
    non-negative lags. It mirrors the semantics of the NEST
    ``correlomatrix_detector`` recording device, including dual-window
    filtering (activity window and counting window) and NEST-compatible
    bin-edge and matrix-ordering conventions.

    **2. Event Model and Covariance Tensor Equations**

    Let an accepted event be
    :math:`e=(c, t, m, w)` with receptor channel :math:`c`,
    integer simulation step :math:`t`, multiplicity :math:`m`, and
    connection weight :math:`w`. The queued event weight is
    :math:`\hat{w}=m\cdot w`.

    The detector stores all accepted events in one queue sorted by time.
    For each new accepted event :math:`i`:

    1. Insert :math:`i` into the queue (sorted by ``stamp_step``; events
       with equal stamps keep their arrival order).
    2. Prune queued events :math:`j` whose age :math:`t_i - t_j` (in steps)
       reaches :math:`\tau_{\mathrm{edge}} + d_{\min}`, where
       :math:`\tau_{\mathrm{edge}}=\tau_{\max}+\Delta_\tau/2` and
       :math:`d_{\min}` is the minimum delay (one step).
    3. If :math:`t_i \in [T_{\mathrm{start}}, T_{\mathrm{stop}}]`, update
       covariance bins against every remaining queued event :math:`j`,
       including :math:`i` itself.

    For pair :math:`(i,j)`, define :math:`d=|t_i-t_j|` (in steps).
    Channel ordering follows NEST matrix layout:

    - if :math:`t_i \ge t_j`, write into ``(c_i, c_j, b)``,
    - otherwise write into ``(c_j, c_i, b)``.

    The bin index :math:`b` (step domain) is computed as

    .. math::

       b =
       \begin{cases}
       -\left\lfloor \dfrac{\Delta_{\tau,\mathrm{steps}}/2 - d}
       {\Delta_{\tau,\mathrm{steps}}} \right\rfloor,
       & c_{\mathrm{row}} \le c_{\mathrm{col}} \\
       \left\lfloor \dfrac{\Delta_{\tau,\mathrm{steps}}/2 + d}
       {\Delta_{\tau,\mathrm{steps}}} \right\rfloor,
       & c_{\mathrm{row}} > c_{\mathrm{col}}
       \end{cases}

    and contributes

    .. math::

       \mathrm{cov}[c_{\mathrm{row}}, c_{\mathrm{col}}, b]
       \leftarrow
       \mathrm{cov}[c_{\mathrm{row}}, c_{\mathrm{col}}, b]
       + (m_i w_i)\hat{w}_j,

       \mathrm{count}[c_{\mathrm{row}}, c_{\mathrm{col}}, b]
       \leftarrow
       \mathrm{count}[c_{\mathrm{row}}, c_{\mathrm{col}}, b] + m_i.

    Contributions with :math:`b \ge N_{\mathrm{bins}}` are dropped. Bin ``0``
    collects lags :math:`d < \Delta_{\tau,\mathrm{steps}}/2`. For bin ``0``,
    every pair except one with equal time stamps *and* equal channels also
    adds the same contribution to the transposed entry
    ``(c_col, c_row, 0)``, reproducing NEST's symmetric zero-lag handling.

    The number of bins is

    .. math::

       N_{\mathrm{bins}} = 1 + \frac{\tau_{\max,\mathrm{steps}}}
                                    {\Delta_{\tau,\mathrm{steps}}}.

    Output tensor shapes are
    ``(N_channels, N_channels, N_bins)`` where bin ``0`` corresponds to zero
    lag and bin ``k`` to lag :math:`k \cdot \Delta_\tau`.

    **3. Windowing, Assumptions, and Constraints**

    Two windows are applied:

    - *Activity window*:
      :math:`(\mathrm{origin}+\mathrm{start},\ \mathrm{origin}+\mathrm{stop}]`.
      Events outside this interval are discarded and never queued.
    - *Counting window*:
      :math:`[T_{\mathrm{start}},\ T_{\mathrm{stop}}]`. Only accepted events
      in this interval update ``n_events``, ``covariance``, and
      ``count_covariance``.

    Calibration constraints mirror NEST semantics in this implementation:

    - ``dt > 0`` and all finite time parameters are scalar-convertible.
    - ``start``, ``stop`` (when finite), ``origin``, ``delta_tau``, and
      ``tau_max`` must align to integer simulation steps.
    - ``delta_tau`` must be an odd multiple of ``dt``.
    - ``tau_max`` must be a non-negative multiple of ``delta_tau``.
    - ``N_channels >= 1``.

    **4. Computational Implications**

    Per accepted event, the insertion position is found by scanning from
    the newest end of the queue. In-order arrivals are therefore placed in
    :math:`O(1)` and out-of-order arrivals in up to :math:`O(Q)` steps,
    where :math:`Q` is the queue length. Correlation updates are
    :math:`O(Q)` over retained events, so total update work scales linearly
    with the active queue size. Memory usage is
    :math:`O(Q + N_{\mathrm{channels}}^2 \cdot N_{\mathrm{bins}})`.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape metadata consumed by :class:`brainstate.nn.Dynamics`.
        This detector stores internal tensors and does not emit batch-shaped
        arrays through ``update``. Default is ``1``.
    delta_tau : ArrayLike or None, optional
        Lag bin width :math:`\Delta_\tau` in milliseconds. Accepts a scalar
        float-like value or a ``saiunit`` quantity convertible to ms.
        Must be finite, strictly positive, aligned to ``dt``, and an odd
        multiple of ``dt``. ``None`` resolves to ``5 * dt``.
        Default is ``None``.
    tau_max : ArrayLike or None, optional
        One-sided lag horizon :math:`\tau_{\max}` in milliseconds. Accepts a
        scalar float-like value or a quantity convertible to ms. Must be
        finite, non-negative, aligned to ``dt``, and an exact multiple of
        ``delta_tau``. ``None`` resolves to ``10 * delta_tau``.
        Default is ``None``.
    Tstart : ArrayLike, optional
        Inclusive lower bound of the counting window in milliseconds.
        Must be scalar-convertible; ``saiunit`` quantities are converted
        to ms. Default is ``0.0 * u.ms``.
    Tstop : ArrayLike or None, optional
        Inclusive upper bound of the counting window in milliseconds.
        Must be scalar-convertible when provided. ``None`` means
        :math:`+\infty`. Default is ``None``.
    N_channels : int or ArrayLike, optional
        Number of receptor channels. Must resolve to a scalar integer
        ``>= 1``. Channel IDs accepted at runtime are
        ``0, 1, ..., N_channels - 1``. Default is ``1``.
    start : ArrayLike, optional
        Exclusive lower bound of the activity window relative to ``origin``
        in milliseconds. Must be scalar-convertible and aligned to ``dt``.
        Default is ``0.0 * u.ms``.
    stop : ArrayLike or None, optional
        Inclusive upper bound of the activity window relative to ``origin``
        in milliseconds. Must be scalar-convertible and aligned to ``dt``
        when finite. ``None`` means :math:`+\infty`. Default is ``None``.
    origin : ArrayLike, optional
        Activity-window origin shift in milliseconds. Must be
        scalar-convertible and aligned to ``dt``.
        Default is ``0.0 * u.ms``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.
        Default is ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 18 17 24 41

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``delta_tau``
         - ``None``
         - :math:`\Delta_\tau`
         - Lag-bin width; resolved as ``5 * dt`` when omitted.
       * - ``tau_max``
         - ``None``
         - :math:`\tau_{\max}`
         - One-sided lag horizon; resolved as ``10 * delta_tau`` when omitted.
       * - ``Tstart``
         - ``0.0 * u.ms``
         - :math:`T_{\mathrm{start}}`
         - Inclusive start of covariance/count update window.
       * - ``Tstop``
         - ``None``
         - :math:`T_{\mathrm{stop}}`
         - Inclusive end of covariance/count update window.
       * - ``N_channels``
         - ``1``
         - :math:`N_{\mathrm{channels}}`
         - Number of receptor channels and matrix axes.
       * - ``start``
         - ``0.0 * u.ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the activity window.
       * - ``stop``
         - ``None``
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the activity window.
       * - ``origin``
         - ``0.0 * u.ms``
         - :math:`t_0`
         - Global shift added to ``start`` and ``stop`` boundaries.

    Raises
    ------
    ValueError
        If scalar parameters are invalid (non-scalar, non-finite where finite
        values are required, or misaligned to ``dt``), if consistency
        constraints are violated (e.g., ``delta_tau`` even in steps,
        ``tau_max`` not divisible by ``delta_tau``, ``stop < start``, or
        ``N_channels < 1``), or if runtime event arrays contain invalid
        values/sizes (negative multiplicities, non-finite ``weights``,
        unknown receptor channel, or mismatched vector lengths).
    KeyError
        If runtime environment keys such as simulation time ``'t'`` or
        resolution ``dt`` are unavailable when calibration or update is
        called.

    Notes
    -----
    - Unlike some NEST recording devices, ``n_events`` is read-only here,
      matching ``correlomatrix_detector`` semantics.
    - This implementation uses default NEST kernel minimum delay semantics in
      pruning (``min_delay = 1`` simulation step).
    - Optional ``multiplicities`` emulate NEST ``SpikeEvent`` multiplicity.
    - Runtime event arguments accepted by :meth:`update` are one-dimensional
      scalar-broadcastable arrays over the same event axis:
      ``spikes``, ``receptor_ports``/``receptor_types``, ``weights``,
      ``multiplicities``, and ``stamp_steps``.
    - Changing any parameter attribute changes the calibration signature;
      the next access then resets all accumulated state.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> import numpy as np
       >>> with brainstate.environ.context(dt=0.1 * u.ms, t=0.0 * u.ms):
       ...     det = brainpy.state.correlomatrix_detector(
       ...         N_channels=2,
       ...         delta_tau=0.5 * u.ms,
       ...         tau_max=2.0 * u.ms,
       ...     )
       ...     det.init_state()
       ...     _ = det.update(
       ...         spikes=np.array([1.0, 1.0]),
       ...         receptor_ports=np.array([0, 1]),
       ...         weights=np.array([1.0, 2.0]),
       ...         stamp_steps=np.array([11, 12]),
       ...     )
       ...     out = det.flush()
       >>> out["covariance"].shape
       (2, 2, 5)
       >>> out["count_covariance"].dtype
       dtype('int64')

    References
    ----------
    .. [1] NEST Simulator, ``correlomatrix_detector`` model.
           https://nest-simulator.readthedocs.io/en/stable/models/correlomatrix_detector.html
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        delta_tau: ArrayLike = None,
        tau_max: ArrayLike = None,
        Tstart: ArrayLike = 0.0 * u.ms,
        Tstop: ArrayLike = None,
        N_channels: int = 1,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)

        # Raw user parameters. They are validated and converted to steps
        # lazily by ``_compute_calibration`` once ``dt`` is known.
        self.delta_tau = delta_tau
        self.tau_max = tau_max
        self.Tstart = Tstart
        self.Tstop = Tstop
        self.N_channels = N_channels

        self.start = start
        self.stop = stop
        self.origin = origin

        self._calib: _Calibration | None = None
        # Time-sorted queue of accepted events still inside the lag horizon.
        self._incoming: deque[_Spike] = deque()
        ditype = brainstate.environ.ditype()
        self._n_events = np.zeros((0,), dtype=ditype)
        dftype = brainstate.environ.dftype()
        self._covariance = np.zeros((0, 0, 0), dtype=dftype)
        self._count_covariance = np.zeros((0, 0, 0), dtype=ditype)

        # Size the accumulators now if ``dt`` is already in the environment.
        self._ensure_calibrated_from_env_if_available()

    @property
    def n_events(self) -> np.ndarray:
        """Per-channel counts of events accepted inside the counting window."""
        self._ensure_calibrated_from_env_if_available()
        ditype = brainstate.environ.ditype()
        return np.asarray(self._n_events, dtype=ditype)

    @property
    def covariance(self) -> np.ndarray:
        """Weighted covariance tensor of shape ``(N_channels, N_channels, N_bins)``."""
        self._ensure_calibrated_from_env_if_available()
        dftype = brainstate.environ.dftype()
        return np.asarray(self._covariance, dtype=dftype)

    @property
    def count_covariance(self) -> np.ndarray:
        """Unweighted count tensor of shape ``(N_channels, N_channels, N_bins)``."""
        self._ensure_calibrated_from_env_if_available()
        ditype = brainstate.environ.ditype()
        return np.asarray(self._count_covariance, dtype=ditype)

    def get(self, key: str = 'covariance'):
        r"""Retrieve a named scalar or array from the detector.

        Parameters
        ----------
        key : str, optional
            Name of the quantity to retrieve. Supported keys:

            - ``'covariance'`` — accumulated weighted covariance tensor,
              shape ``(N_channels, N_channels, N_bins)``, float dtype from
              ``brainstate.environ.dftype()``.
            - ``'count_covariance'`` — unweighted spike-count covariance
              tensor, same shape, integer dtype from
              ``brainstate.environ.ditype()``.
            - ``'n_events'`` — per-channel accepted event counts, shape
              ``(N_channels,)``, integer dtype from
              ``brainstate.environ.ditype()``.
            - ``'delta_tau'`` — calibrated lag-bin width in ms, scalar float
              or ``None`` if not yet calibrated.
            - ``'tau_max'`` — calibrated one-sided lag horizon in ms, scalar
              float or ``None`` if not yet calibrated.
            - ``'Tstart'`` — counting-window lower bound in ms, scalar float
              (may be ``-inf``).
            - ``'Tstop'`` — counting-window upper bound in ms, scalar float
              (may be ``+inf``).
            - ``'N_channels'`` — number of receptor channels, scalar int.
            - ``'start'`` — activity-window lower bound (relative) in ms,
              scalar float.
            - ``'stop'`` — activity-window upper bound (relative) in ms,
              scalar float (may be ``+inf``).
            - ``'origin'`` — activity-window origin shift in ms, scalar float.

            Default is ``'covariance'``.

        Returns
        -------
        value : np.ndarray or float or int or None
            The requested quantity. Array types match the shapes and dtypes
            described above; scalar keys return Python numeric scalars.

        Raises
        ------
        KeyError
            If ``key`` is not one of the supported strings listed above.
        ValueError
            If the underlying parameter is non-scalar or non-convertible
            during retrieval.
        """
        if key == 'covariance':
            return self.covariance
        if key == 'count_covariance':
            return self.count_covariance
        if key == 'n_events':
            return self.n_events
        if key == 'delta_tau':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.delta_tau_ms) if self._calib is not None else None
        if key == 'tau_max':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.tau_max_ms) if self._calib is not None else None
        if key == 'Tstart':
            return self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        if key == 'Tstop':
            stop_val = math.inf if self.Tstop is None else self.Tstop
            return self._to_ms_scalar(stop_val, name='Tstop', allow_inf=True)
        if key == 'N_channels':
            return int(self._to_int_scalar(self.N_channels, name='N_channels'))
        if key == 'start':
            return self._to_ms_scalar(self.start, name='start')
        if key == 'stop':
            stop_val = math.inf if self.stop is None else self.stop
            return self._to_ms_scalar(stop_val, name='stop', allow_inf=True)
        if key == 'origin':
            return self._to_ms_scalar(self.origin, name='origin')
        raise KeyError(f'Unsupported key "{key}" for correlomatrix_detector.get().')

    def connect(self):
        """No-op connection hook; always returns ``None``."""
        return None

    def flush(self):
        r"""Return the current accumulated state as a dictionary.

        Snapshots all three accumulated arrays without modifying internal
        state. This is equivalent to calling :meth:`get` for each of the
        three primary output keys.

        Returns
        -------
        out : dict
            A dictionary with the following keys:

            - ``'covariance'`` : np.ndarray, shape
              ``(N_channels, N_channels, N_bins)``. Weighted auto/cross
              covariance accumulated since the last reset.
            - ``'count_covariance'`` : np.ndarray, shape
              ``(N_channels, N_channels, N_bins)``. Unweighted spike-count
              covariance accumulated since the last reset.
            - ``'n_events'`` : np.ndarray, shape ``(N_channels,)``. Number of
              accepted events per channel within the counting window since
              the last reset.
        """
        return {
            'covariance': self.covariance,
            'count_covariance': self.count_covariance,
            'n_events': self.n_events,
        }

    def init_state(self, batch_size: int = None, **kwargs):
        r"""Reset accumulated state and recalibrate from the environment.

        Clears the event queue, zeroes all accumulated arrays
        (``covariance``, ``count_covariance``, ``n_events``), and recomputes
        calibration from the current ``brainstate`` environment if ``dt`` is
        available. :meth:`update` also calibrates on demand, so calling this
        first is optional but gives a clean start.

        Parameters
        ----------
        batch_size : int or None, optional
            Ignored. Accepted for API compatibility with
            :class:`brainstate.nn.Dynamics`. Default is ``None``.
        **kwargs
            Ignored. Accepted for API compatibility.
        """
        del batch_size, kwargs
        self._ensure_calibrated_from_env_if_available()
        self._reset_state()

    def update(
        self,
        spikes: ArrayLike = None,
        receptor_ports: ArrayLike = None,
        receptor_types: ArrayLike = None,
        weights: ArrayLike = None,
        multiplicities: ArrayLike = None,
        stamp_steps: ArrayLike = None,
    ):
        r"""Process one batch of incoming spike events and update accumulators.

        Reads the current simulation time ``'t'`` and resolution ``dt`` from
        the ``brainstate`` environment, calibrates if necessary, then
        iterates over each event in the batch. Events outside the activity
        window are silently discarded. Events inside the counting window
        update ``covariance``, ``count_covariance``, and ``n_events``.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            1-D array of spike indicators over a batch of ``n_items``
            senders. A value ``> 0`` is treated as a spike. ``None`` or an
            empty array causes an immediate return of :meth:`flush` output.
        receptor_ports : ArrayLike or None, optional
            1-D integer array of receptor channel indices, shape
            ``(n_items,)`` or a size-1 array broadcast to all events. Values
            must be in ``[0, N_channels - 1]``. Alias ``receptor_types`` is
            also accepted; if both are provided, ``receptor_ports`` takes
            precedence. Default (``None``) maps all events to channel ``0``.
        receptor_types : ArrayLike or None, optional
            Alias for ``receptor_ports``. Ignored when ``receptor_ports`` is
            also provided.
        weights : ArrayLike or None, optional
            1-D float array of connection weights, shape ``(n_items,)`` or
            size 1 (broadcast). Must contain finite values. Default
            (``None``) uses weight ``1.0`` for all events.
        multiplicities : ArrayLike or None, optional
            1-D non-negative integer array of NEST ``SpikeEvent``
            multiplicities, shape ``(n_items,)`` or size 1 (broadcast).
            Entries whose ``spikes`` value is ``<= 0`` are ignored. When
            ``None``, multiplicities are inferred from ``spikes``: if every
            entry is integer-like, the rounded values (clipped at ``0``) are
            used; otherwise every entry is binarized to ``0`` or ``1``.
        stamp_steps : ArrayLike or None, optional
            1-D integer array of simulation step stamps for each event,
            shape ``(n_items,)`` or size 1 (broadcast). When ``None``, all
            events are stamped at ``step_now + 1`` (next step).

        Returns
        -------
        out : dict
            Same mapping as :meth:`flush`:
            ``{'covariance': ..., 'count_covariance': ..., 'n_events': ...}``.

        Raises
        ------
        ValueError
            If any of the following occur:

            - ``multiplicities`` contains negative values.
            - ``weights`` contains non-finite values.
            - ``receptor_ports`` contains a channel index outside
              ``[0, N_channels - 1]`` for an event with positive
              multiplicity.
            - ``receptor_ports``, ``weights``, ``multiplicities`` or
              ``stamp_steps`` has a size that is neither ``1`` nor equal to
              the size of ``spikes``.
        KeyError
            If the ``brainstate`` environment does not expose ``'t'`` or
            ``dt`` at call time.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._ensure_calibrated(dt)
        step_now = self._time_to_step(t, calib.dt_ms)

        if spikes is None:
            return self.flush()

        spike_arr = self._to_float_array(spikes, name='spikes')
        if spike_arr.size == 0:
            return self.flush()

        n_items = spike_arr.size
        if receptor_ports is None and receptor_types is not None:
            receptor_ports = receptor_types

        port_arr = self._to_int_array(receptor_ports, name='receptor_ports', default=0, size=n_items)
        weight_arr = self._to_float_array(weights, name='weights', default=1.0, size=n_items)

        if multiplicities is None:
            # Integer-valued ``spikes`` double as multiplicities; otherwise
            # every positive entry counts as exactly one spike.
            rounded = np.rint(spike_arr)
            is_integer_like = np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12)
            if is_integer_like:
                counts = np.maximum(rounded.astype(np.int64), 0)
            else:
                counts = (spike_arr > 0.0).astype(np.int64)
        else:
            mult_arr = self._to_int_array(multiplicities, name='multiplicities', size=n_items)
            if np.any(mult_arr < 0):
                raise ValueError('multiplicities must be non-negative.')
            counts = np.where(spike_arr > 0.0, mult_arr, 0)

        if stamp_steps is None:
            # Default delivery stamp: one step after the current step.
            ditype = brainstate.environ.ditype()
            stamp_arr = np.full((n_items,), step_now + 1, dtype=ditype)
        else:
            stamp_arr = self._to_int_array(stamp_steps, name='stamp_steps', size=n_items)

        for i in range(n_items):
            multiplicity = int(counts[i])
            if multiplicity <= 0:
                continue

            sender = int(port_arr[i])
            if sender < 0 or sender >= calib.n_channels:
                raise ValueError(f'Unknown receptor_type {sender} for correlomatrix_detector.')

            stamp_step = int(stamp_arr[i])
            if not self._is_active(stamp_step, calib.t_min_steps, calib.t_max_steps):
                continue

            self._handle_event(
                sender=sender,
                stamp_step=stamp_step,
                weight=float(weight_arr[i]),
                multiplicity=multiplicity,
                calib=calib,
            )

        return self.flush()

    def _handle_event(
        self,
        sender: int,
        stamp_step: int,
        weight: float,
        multiplicity: int,
        calib: _Calibration,
    ):
        """Queue one active event and, if it is in the counting window, bin it.

        Inserts the event into the time-sorted queue, drops queued events at
        least ``tau_edge_steps + min_delay_steps`` older than it, then
        correlates it against every event still queued (including itself).
        """
        own_weight = float(multiplicity) * float(weight)
        spike_i = _Spike(
            timestep=stamp_step,
            weight=own_weight,
            receptor_channel=sender,
        )

        incoming = self._incoming
        # Insert after every queued event with ``timestep <= stamp_step``:
        # the queue stays sorted and equal stamps keep arrival order. Events
        # usually arrive in time order, so scan from the newest end.
        insert_pos = len(incoming)
        while insert_pos > 0 and incoming[insert_pos - 1].timestep > stamp_step:
            insert_pos -= 1
        incoming.insert(insert_pos, spike_i)

        # Drop events too old to fall into any lag bin of this or later events.
        prune_edge = calib.tau_edge_steps + calib.min_delay_steps
        while incoming and stamp_step - incoming[0].timestep >= prune_edge:
            incoming.popleft()

        stamp_ms = float(stamp_step) * calib.dt_ms
        if not self._is_in_count_window(stamp_ms, calib.tstart_ms, calib.tstop_ms):
            return

        self._n_events[sender] += 1

        n_bins = self._covariance.shape[2]
        delta_steps = calib.delta_tau_steps
        half_bin = 0.5 * delta_steps
        for spike_j in incoming:
            other = spike_j.receptor_channel
            diff_steps = stamp_step - spike_j.timestep
            abs_diff_steps = abs(diff_steps)

            # NEST matrix layout: the later event's channel indexes the row.
            if stamp_step < spike_j.timestep:
                sender_ind, other_ind = other, sender
            else:
                sender_ind, other_ind = sender, other

            if sender_ind <= other_ind:
                bin_index = int(-math.floor((half_bin - abs_diff_steps) / delta_steps))
            else:
                bin_index = int(math.floor((half_bin + abs_diff_steps) / delta_steps))
            if bin_index >= n_bins:
                continue

            contribution = own_weight * spike_j.weight
            self._covariance[sender_ind, other_ind, bin_index] += contribution
            self._count_covariance[sender_ind, other_ind, bin_index] += multiplicity
            # Symmetric zero-lag bin: mirror into the transposed entry unless
            # this is a same-stamp, same-channel pair.
            if bin_index == 0 and (diff_steps != 0 or other != sender):
                self._covariance[other_ind, sender_ind, bin_index] += contribution
                self._count_covariance[other_ind, sender_ind, bin_index] += multiplicity

    def _ensure_calibrated_from_env_if_available(self):
        """Calibrate from the environment's ``dt``; do nothing if it is unset."""
        try:
            dt = brainstate.environ.get_dt()
        except KeyError:
            return
        self._ensure_calibrated(dt)

    def _ensure_calibrated(self, dt) -> _Calibration:
        """Recompute calibration for ``dt``; reset state if anything changed."""
        new_calib = self._compute_calibration(dt)
        if self._calib is None or self._calib.signature != new_calib.signature:
            self._calib = new_calib
            self._reset_state()
        return self._calib

    def _reset_state(self):
        """Clear the event queue and zero (and resize) all accumulators."""
        self._incoming = deque()
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        if self._calib is None:
            self._n_events = np.zeros((0,), dtype=ditype)
            self._covariance = np.zeros((0, 0, 0), dtype=dftype)
            self._count_covariance = np.zeros((0, 0, 0), dtype=ditype)
            return

        n_channels = int(self._calib.n_channels)
        n_bins = int(self._calib.n_bins)
        self._n_events = np.zeros((n_channels,), dtype=ditype)
        self._covariance = np.zeros((n_channels, n_channels, n_bins), dtype=dftype)
        self._count_covariance = np.zeros((n_channels, n_channels, n_bins), dtype=ditype)

    def _compute_calibration(self, dt) -> _Calibration:
        """Validate parameters against resolution ``dt`` and convert to steps.

        Raises
        ------
        ValueError
            If any parameter violates the constraints listed in the class
            documentation.
        """
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')

        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_value = math.inf if self.stop is None else self.stop
        stop_steps = self._to_step_count(stop_value, dt_ms, 'stop', allow_inf=True)
        if not math.isinf(stop_steps) and stop_steps < start_steps:
            raise ValueError('stop >= start required.')

        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps

        if self.delta_tau is None:
            delta_tau_ms = 5.0 * dt_ms
        else:
            delta_tau_ms = self._to_ms_scalar(self.delta_tau, name='delta_tau')
        if not math.isfinite(delta_tau_ms) or delta_tau_ms <= 0.0:
            raise ValueError('delta_tau must be positive and finite.')
        delta_tau_steps = self._to_step_count(delta_tau_ms, dt_ms, 'delta_tau')
        if delta_tau_steps % 2 != 1:
            raise ValueError('/delta_tau must be odd multiple of resolution.')

        if self.tau_max is None:
            tau_max_ms = 10.0 * delta_tau_ms
        else:
            tau_max_ms = self._to_ms_scalar(self.tau_max, name='tau_max')
        if not math.isfinite(tau_max_ms) or tau_max_ms < 0.0:
            raise ValueError('tau_max must be finite and non-negative.')
        tau_max_steps = self._to_step_count(tau_max_ms, dt_ms, 'tau_max')
        if tau_max_steps % delta_tau_steps != 0:
            raise ValueError('tau_max must be a multiple of delta_tau.')

        tstart_ms = self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        tstop_value = math.inf if self.Tstop is None else self.Tstop
        tstop_ms = self._to_ms_scalar(tstop_value, name='Tstop', allow_inf=True)

        n_channels = self._to_int_scalar(self.N_channels, name='N_channels')
        if n_channels < 1:
            raise ValueError('/N_channels can only be larger than zero.')

        n_bins = int(1 + tau_max_steps // delta_tau_steps)
        # Default NEST kernel minimum delay of one simulation step.
        min_delay_steps = 1

        signature = (
            float(dt_ms),
            int(start_steps),
            float(stop_steps),
            int(origin_steps),
            int(t_min_steps),
            float(t_max_steps),
            float(delta_tau_ms),
            int(delta_tau_steps),
            float(tau_max_ms),
            int(tau_max_steps),
            float(tstart_ms),
            float(tstop_ms),
            int(n_channels),
            int(n_bins),
            int(min_delay_steps),
        )

        return _Calibration(
            dt_ms=float(dt_ms),
            start_step=int(start_steps),
            stop_step=float(stop_steps),
            origin_step=int(origin_steps),
            t_min_steps=int(t_min_steps),
            t_max_steps=float(t_max_steps),
            delta_tau_ms=float(delta_tau_ms),
            delta_tau_steps=int(delta_tau_steps),
            tau_max_ms=float(tau_max_ms),
            tau_max_steps=int(tau_max_steps),
            tau_edge_steps=float(tau_max_steps) + 0.5 * float(delta_tau_steps),
            tstart_ms=float(tstart_ms),
            tstop_ms=float(tstop_ms),
            n_channels=int(n_channels),
            n_bins=int(n_bins),
            min_delay_steps=int(min_delay_steps),
            signature=signature,
        )

    @staticmethod
    def _is_in_count_window(stamp_ms: float, tstart_ms: float, tstop_ms: float) -> bool:
        """Return whether ``stamp_ms`` lies in ``[tstart_ms, tstop_ms]`` (1e-12 ms slack)."""
        return (stamp_ms >= tstart_ms - 1e-12) and (stamp_ms <= tstop_ms + 1e-12)

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar (plain number = ms, or a quantity) to a float in ms."""
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        dftype = brainstate.environ.dftype()
        arr = np.asarray(u.math.asarray(value), dtype=dftype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        return val

    @staticmethod
    def _to_int_scalar(value, name: str) -> int:
        """Convert a scalar-like value to a Python ``int``."""
        ditype = brainstate.environ.ditype()
        arr = np.asarray(u.math.asarray(value), dtype=ditype).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar integer value.')
        return int(arr[0])

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a time value to an integer number of ``dt_ms`` steps.

        Returns ``math.inf`` for ``None``/infinite input when ``allow_inf``;
        raises ``ValueError`` if the value is not grid-aligned.
        """
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')

        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        if math.isinf(ms):
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} must be finite.')

        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Convert the current simulation time to an integer step index."""
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """Return whether ``stamp_step`` lies in ``(t_min_steps, t_max_steps]``."""
        if stamp_step <= t_min_steps:
            return False
        if math.isinf(t_max_steps):
            return True
        return stamp_step <= t_max_steps

    @staticmethod
    def _to_float_array(
        x,
        name: str,
        default: float = None,
        size: int = None,
        unit=None,
    ) -> np.ndarray:
        """Convert ``x`` to a finite 1-D float array, optionally of length ``size``.

        ``None`` becomes ``[default]``. A size-1 array is broadcast to
        ``size``; any other size different from ``size`` (including an empty
        array when ``size > 0``) raises ``ValueError``.
        """
        dftype = brainstate.environ.dftype()
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=dftype)
        else:
            if unit is not None and isinstance(x, u.Quantity):
                x = x / unit
            elif isinstance(x, u.Quantity):
                x = u.get_mantissa(x)
            arr = np.asarray(u.math.asarray(x), dtype=dftype).reshape(-1)

        if not np.all(np.isfinite(arr)):
            raise ValueError(f'{name} must contain finite values.')
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), arr[0], dtype=dftype)
        # An empty array no longer slips through with the wrong length; it is
        # reported as a size mismatch like any other.
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr.astype(np.float64, copy=False)

    @staticmethod
    def _to_int_array(
        x,
        name: str,
        default: int = None,
        size: int = None,
    ) -> np.ndarray:
        """Convert ``x`` to a 1-D integer array, optionally of length ``size``.

        ``None`` becomes ``[default]``. A size-1 array is broadcast to
        ``size``; any other mismatching size raises ``ValueError``.
        """
        ditype = brainstate.environ.ditype()
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=ditype)
        else:
            arr = np.asarray(u.math.asarray(x), dtype=ditype).reshape(-1)

        if size is None:
            return arr.astype(np.int64, copy=False)
        if arr.size == 1 and size > 1:
            return np.full((size,), int(arr[0]), dtype=ditype)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr.astype(np.int64, copy=False)