# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from collections import deque
from dataclasses import dataclass
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTDevice
# Public API of this module: only the detector class is exported.
__all__ = ['correlation_detector']
@dataclass
class _Spike:
timestep: int
weight: float
@dataclass
class _Calibration:
dt_ms: float
start_step: int
stop_step: float
origin_step: int
t_min_steps: int
t_max_steps: float
delta_tau_ms: float
delta_tau_steps: int
tau_max_ms: float
tau_max_steps: int
tau_edge_steps: float
tstart_ms: float
tstop_ms: float
n_bins: int
signature: tuple
class correlation_detector(NESTDevice):
    r"""NEST-compatible ``correlation_detector`` device.

    **1. Overview**

    ``correlation_detector`` receives spikes from two receptor ports
    (``0`` and ``1``). It accumulates lag histograms in two forms, following
    NEST event ordering:

    - a weighted (floating-point) histogram,
    - an unweighted (integer) count histogram.

    It mirrors the semantics of the NEST ``correlation_detector`` recording
    device, including:

    - dual-window filtering (activity window and counting window),
    - Kahan-compensated weighted histogram accumulation,
    - NEST-compatible bin-edge conventions.

    Weighted arrays use ``brainstate.environ.dftype()``. Count arrays use
    ``brainstate.environ.ditype()``.

    **2. Event Model and Histogram Equations**

    Let an accepted event be represented as :math:`e=(s, t, m, w)`, where:

    - :math:`s\in\{0,1\}` is the receptor port,
    - :math:`t` is the integer simulation step,
    - :math:`m` is the multiplicity,
    - :math:`w` is the scalar connection weight.

    Each stored queue entry keeps :math:`\hat{w}=m\cdot w`.

    For each new event, the detector correlates it against all queued events
    of the opposite port that survive lag-window pruning. The bin index is

    .. math::

        b = \left\lfloor
            \frac{\tau_{\mathrm{edge}} + \sigma_s (t - t_j)}
                 {\Delta_\tau}
            \right\rfloor,
        \qquad
        \sigma_s = 2s - 1,
        \qquad
        \tau_{\mathrm{edge}} = \tau_{\max} + \frac{\Delta_\tau}{2},

    with all times represented in integer steps for the index computation.

    :math:`\sigma_s` encodes causality direction:

    - ``+1`` for port-1 events (the event is the "post" spike),
    - ``-1`` for port-0 events (the event is the "pre" spike).

    Positive lags therefore correspond to port-1 spikes occurring after
    port-0 spikes.

    For each matched opposite event :math:`j`, the histograms are updated as

    .. math::

        H_b \leftarrow H_b + (m w)\,\hat{w}_j,
        \qquad
        C_b \leftarrow C_b + m,

    where :math:`H_b` is ``histogram`` and :math:`C_b` is
    ``count_histogram``. ``histogram`` uses Kahan summation per bin to
    reduce floating-point accumulation error. The compensation terms are
    exposed as ``histogram_correction``.

    The number of bins is

    .. math::

        N_{\mathrm{bins}} = 1 + 2 \cdot
        \frac{\tau_{\max,\mathrm{steps}}}{\Delta_{\tau,\mathrm{steps}}}.

    Bin intervals are left-closed/right-open in the internal index rule,
    which matches NEST edge handling in ``correlation_detector``. The centre
    bin (index :math:`N_{\mathrm{bins}}//2`) corresponds to zero lag.

    **3. Windowing, Assumptions, and Constraints**

    Two windows are applied exactly as in NEST:

    - **Activity window**:
      :math:`(\mathrm{origin}+\mathrm{start},\ \mathrm{origin}+\mathrm{stop}]`
      in simulation time. Events outside it are discarded and never
      buffered.
    - **Counting window**:
      :math:`[\mathrm{Tstart},\ \mathrm{Tstop}]`. Only events in this window
      increment ``n_events`` and update histograms. Events outside it are
      still buffered. They therefore contribute to the histograms when a
      later event that does fall inside the counting window is correlated
      against them.

    Grid-alignment constraints are strict:

    - ``start``, ``stop`` (if finite), ``origin``, ``delta_tau``, and
      ``tau_max`` must map to integer multiples of the simulation ``dt``.
    - ``delta_tau`` must span at least one step.
    - ``tau_max`` must be an exact multiple of ``delta_tau``.

    Violations raise ``ValueError`` at calibration time.

    **4. Computational Implications**

    Per accepted event, work is linear in queue lengths:

    - :math:`O(Q_{\mathrm{other}})` for pruning and correlation against the
      opposite-port queue,
    - at most :math:`O(Q_{\mathrm{self}})` for sorted insertion into the
      sender queue. This is typically constant when events arrive in time
      order.

    Memory usage is :math:`O(Q_0 + Q_1 + N_{\mathrm{bins}})`. Queue length
    depends on event rate and ``tau_max``.

    Calibration runs in the constructor when ``dt`` is already available in
    the environment, and otherwise lazily on first access. Subsequent calls
    reuse the cached state unless ``dt`` or a window parameter changes.

    Parameters
    ----------
    in_size : Size, optional
        Output size/shape metadata consumed by :class:`brainstate.nn.Dynamics`.
        This detector is event-driven and stores scalar histograms, so
        ``in_size`` is retained only for API consistency. It does not affect
        histogram shape. Default is ``1``.
    delta_tau : quantity (ms) or float or None, optional
        Bin width :math:`\Delta_\tau`. Unitful ``saiunit`` quantities are
        accepted and converted to ms; bare floats are interpreted as ms.
        Must be finite, strictly positive, and an integer multiple (>= 1) of
        the simulation ``dt``. ``None`` auto-selects ``5 * dt``.
        Default is ``None``.
    tau_max : quantity (ms) or float or None, optional
        One-sided lag limit :math:`\tau_{\max}`. Unitful quantities are
        accepted. Must be finite, non-negative, an integer multiple of
        ``dt``, and an exact integer multiple of ``delta_tau``. ``None``
        auto-selects ``10 * delta_tau``. Default is ``None``.
    Tstart : quantity (ms) or float, optional
        Inclusive lower bound of the counting window in ms. Must be
        scalar-convertible; unitful values are converted to ms.
        Default is ``0.0 * u.ms``.
    Tstop : quantity (ms) or float or None, optional
        Inclusive upper bound of the counting window in ms. ``None`` means
        :math:`+\infty` (no upper bound). Must be scalar-convertible when
        provided. Default is ``None``.
    start : quantity (ms) or float, optional
        Exclusive lower bound of the activity window relative to ``origin``,
        in ms. Must be scalar-convertible and aligned to the simulation
        ``dt``. Default is ``0.0 * u.ms``.
    stop : quantity (ms) or float or None, optional
        Inclusive upper bound of the activity window relative to ``origin``,
        in ms. Must be scalar-convertible and aligned to ``dt`` when finite.
        ``None`` means :math:`+\infty`. Default is ``None``.
    origin : quantity (ms) or float, optional
        Global time origin shift in ms for activity-window evaluation.
        The effective activity window becomes
        ``(origin + start, origin + stop]``. Must be scalar-convertible
        and aligned to ``dt``. Default is ``0.0 * u.ms``.
    name : str or None, optional
        Optional node name forwarded to :class:`brainstate.nn.Dynamics`.
        If ``None``, a name is auto-generated. Default is ``None``.

    Parameter Mapping
    -----------------

    .. list-table::
       :header-rows: 1
       :widths: 18 17 24 41

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``delta_tau``
         - ``None`` → ``5 * dt``
         - :math:`\Delta_\tau`
         - Lag-histogram bin width; auto-resolved when omitted.
       * - ``tau_max``
         - ``None`` → ``10 * delta_tau``
         - :math:`\tau_{\max}`
         - One-sided correlation horizon; auto-resolved when omitted.
       * - ``Tstart``
         - ``0.0 ms``
         - :math:`T_{\mathrm{start}}`
         - Inclusive start of histogram and event-count update window.
       * - ``Tstop``
         - ``None`` (:math:`+\infty`)
         - :math:`T_{\mathrm{stop}}`
         - Inclusive end of histogram and event-count update window.
       * - ``start``
         - ``0.0 ms``
         - :math:`t_{\mathrm{start,rel}}`
         - Relative exclusive lower bound of the activity window.
       * - ``stop``
         - ``None`` (:math:`+\infty`)
         - :math:`t_{\mathrm{stop,rel}}`
         - Relative inclusive upper bound of the activity window.
       * - ``origin``
         - ``0.0 ms``
         - :math:`t_0`
         - Global offset applied to ``start`` and ``stop`` boundaries.

    Raises
    ------
    ValueError
        At calibration time, raised when a time parameter:

        - is non-scalar,
        - is non-finite where a finite value is required,
        - is misaligned to the simulation resolution, or
        - violates a consistency constraint (e.g.
          ``tau_max % delta_tau != 0`` or ``stop < start``).

        At runtime, raised for invalid event arguments:

        - an unknown receptor port,
        - a negative or non-integer multiplicity,
        - non-integer stamp steps or ports,
        - non-finite ``weights``,
        - size mismatches.
    KeyError
        If runtime environment keys such as ``'t'`` or simulation ``dt`` are
        unavailable when calibration or update is attempted.
    RuntimeError
        If an internal lag-bin index falls outside the histogram range. This
        indicates an inconsistency between calibration and event processing,
        for example when stamp steps arrive far out of order.

    Notes
    -----
    - ``n_events`` can only be assigned ``[0, 0]``. Doing so resets all
      detector state and clears the histograms, matching NEST's reset
      semantics.
    - Runtime input events are provided through :meth:`update`. The
      arguments ``spikes``, ``receptor_ports``, ``weights``,
      ``multiplicities``, and ``stamp_steps`` are each scalar-broadcastable
      to a common 1-D event axis.
    - Receptor ports are restricted to the integer values ``0`` and ``1``.
      Ports of all events with positive multiplicity are validated before
      any event of an :meth:`update` call is applied.
    - Output arrays are copies; mutating them does not affect detector state.
    - Connection delays are ignored by design; only event time stamps are
      used for lag computation.
    - Calibration is cached and reused across steps. It is automatically
      invalidated if ``dt`` or any window parameter changes between calls.

    Examples
    --------
    Basic correlation of simultaneous spikes on opposite ports:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> import numpy as np
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     det = brainpy.state.correlation_detector(
        ...         delta_tau=0.5 * u.ms,
        ...         tau_max=5.0 * u.ms,
        ...     )
        ...     with brainstate.environ.context(t=1.0 * u.ms):
        ...         out = det.update(
        ...             spikes=np.array([1.0, 1.0]),
        ...             receptor_ports=np.array([0, 1]),
        ...             weights=np.array([1.0, 2.0]),
        ...             multiplicities=np.array([1, 1]),
        ...             stamp_steps=np.array([11, 11]),
        ...         )
        ...     _ = out['histogram'].shape  # (21,) for tau_max=5ms, delta_tau=0.5ms

    Default parameters with no input events and explicit state reset:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     det = brainpy.state.correlation_detector()
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         _ = det.update()  # no input events; returns current state
        ...     det.n_events = [0, 0]  # explicit reset, NEST-compatible

    References
    ----------
    .. [1] NEST Simulator, ``correlation_detector`` model.
           https://nest-simulator.readthedocs.io/en/stable/models/correlation_detector.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size = 1,
        delta_tau: ArrayLike = None,
        tau_max: ArrayLike = None,
        Tstart: ArrayLike = 0.0 * u.ms,
        Tstop: ArrayLike = None,
        start: ArrayLike = 0.0 * u.ms,
        stop: ArrayLike = None,
        origin: ArrayLike = 0.0 * u.ms,
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)
        # Raw user parameters. They are re-read on every calibration, so
        # later assignments to these attributes take effect.
        self.delta_tau = delta_tau
        self.tau_max = tau_max
        self.Tstart = Tstart
        self.Tstop = Tstop
        self.start = start
        self.stop = stop
        self.origin = origin
        self._calib: _Calibration | None = None
        # With no calibration yet, _reset_state() creates empty spike queues,
        # zeroed event counters and zero-length histograms.
        self._reset_state()
        # Size the histograms immediately when dt is already in the environment.
        self._ensure_calibrated_from_env_if_available()

    @property
    def n_events(self) -> np.ndarray:
        """Per-port counts of events inside the counting window, shape ``(2,)``.

        Returns a copy. Assigning ``[0, 0]`` resets all detector state; any
        other value raises ``ValueError``.
        """
        return np.array(self._n_events, dtype=brainstate.environ.ditype(), copy=True)

    @n_events.setter
    def n_events(self, value):
        # Parse as float so that fractional values such as 0.5 are rejected
        # instead of being truncated to 0 by an integer cast.
        arr = np.asarray(u.math.asarray(value), dtype=np.float64).reshape(-1)
        if arr.size != 2 or np.any(arr != 0):
            raise ValueError('/n_events can only be set to [0 0].')
        self._reset_state()

    @property
    def histogram(self) -> np.ndarray:
        """Copy of the weighted (Kahan-summed) lag histogram, shape ``(N_bins,)``."""
        self._ensure_calibrated_from_env_if_available()
        return np.array(self._histogram, dtype=brainstate.environ.dftype(), copy=True)

    @property
    def histogram_correction(self) -> np.ndarray:
        """Copy of the per-bin Kahan compensation terms, shape ``(N_bins,)``."""
        self._ensure_calibrated_from_env_if_available()
        return np.array(self._histogram_correction, dtype=brainstate.environ.dftype(), copy=True)

    @property
    def count_histogram(self) -> np.ndarray:
        """Copy of the unweighted (multiplicity-count) lag histogram, shape ``(N_bins,)``."""
        self._ensure_calibrated_from_env_if_available()
        return np.array(self._count_histogram, dtype=brainstate.environ.ditype(), copy=True)

    def get(self, key: str = 'histogram'):
        r"""Return one detector state variable or calibrated scalar parameter.

        Parameters
        ----------
        key : str, optional
            Query key. Supported values are ``'histogram'``,
            ``'histogram_correction'``, ``'count_histogram'``, ``'n_events'``,
            ``'delta_tau'``, ``'tau_max'``, ``'Tstart'``, ``'Tstop'``,
            ``'start'``, ``'stop'``, and ``'origin'``. Default is
            ``'histogram'``.

        Returns
        -------
        out : numpy.ndarray or float or None
            Histogram keys return copies with shape ``(N_bins,)``, using the
            environment float/int dtypes. ``'n_events'`` returns an array of
            shape ``(2,)``.

            ``'delta_tau'`` and ``'tau_max'`` return the calibrated values in
            ms, or ``None`` when no simulation ``dt`` is available yet.

            The remaining time keys return milliseconds as a Python
            ``float``; infinite bounds are returned as ``math.inf``.

        Raises
        ------
        KeyError
            If ``key`` is unsupported.
        ValueError
            If scalar conversion of configured time parameters fails.
        """
        if key == 'histogram':
            return self.histogram
        if key == 'histogram_correction':
            return self.histogram_correction
        if key == 'count_histogram':
            return self.count_histogram
        if key == 'n_events':
            return self.n_events
        if key == 'delta_tau':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.delta_tau_ms) if self._calib is not None else None
        if key == 'tau_max':
            self._ensure_calibrated_from_env_if_available()
            return float(self._calib.tau_max_ms) if self._calib is not None else None
        if key == 'Tstart':
            return self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        if key == 'Tstop':
            stop_val = math.inf if self.Tstop is None else self.Tstop
            return self._to_ms_scalar(stop_val, name='Tstop', allow_inf=True)
        if key == 'start':
            return self._to_ms_scalar(self.start, name='start')
        if key == 'stop':
            stop_val = math.inf if self.stop is None else self.stop
            return self._to_ms_scalar(stop_val, name='stop', allow_inf=True)
        if key == 'origin':
            return self._to_ms_scalar(self.origin, name='origin')
        raise KeyError(f'Unsupported key "{key}" for correlation_detector.get().')

    def connect(self):
        r"""Compatibility no-op for the NEST-like device interface."""
        return None

    def flush(self):
        r"""Return the current detector outputs without consuming internal state.

        Returns
        -------
        out : dict
            Copies of ``histogram``, ``histogram_correction``,
            ``count_histogram`` and ``n_events``.
        """
        self._ensure_calibrated_from_env_if_available()
        return self._snapshot()

    def init_state(self, batch_size: int = None, **kwargs):
        r"""Reset detector buffers and histogram state for the current calibration.

        Parameters
        ----------
        batch_size : int or None, optional
            Unused placeholder for :class:`brainstate.nn.Dynamics`
            compatibility.
        **kwargs
            Unused compatibility arguments.
        """
        del batch_size, kwargs
        self._ensure_calibrated_from_env_if_available()
        self._reset_state()

    def update(
        self,
        spikes: ArrayLike = None,
        receptor_ports: ArrayLike = None,
        receptor_types: ArrayLike = None,
        weights: ArrayLike = None,
        multiplicities: ArrayLike = None,
        stamp_steps: ArrayLike = None,
    ):
        r"""Process one simulation step of incoming events and return outputs.

        Parameters
        ----------
        spikes : ArrayLike or None, optional
            Event-presence/multiplicity proxy with shape ``(N,)`` after
            flattening (scalars are broadcast). If ``None`` or empty, no
            events are processed and the current state is returned.

            When ``multiplicities`` is ``None``:

            - if every value is integer-like, values are rounded and clipped
              to non-negative multiplicities;
            - otherwise all values are interpreted as binary ``spike > 0``
              flags.
        receptor_ports : ArrayLike or None, optional
            Integer receptor port indices with shape ``(N,)`` (or a scalar,
            broadcast). Valid values are ``0`` and ``1`` only. If ``None``,
            defaults to ``0`` for all events unless ``receptor_types`` is
            provided.
        receptor_types : ArrayLike or None, optional
            Alias for ``receptor_ports``, kept for NEST API compatibility.
            Used only when ``receptor_ports`` is ``None``.
        weights : ArrayLike or None, optional
            Per-event connection weights with shape ``(N,)`` (or a scalar,
            broadcast). Must be finite. Default is ``1.0`` when omitted.
        multiplicities : ArrayLike or None, optional
            Explicit non-negative integer multiplicities with shape ``(N,)``
            (or a scalar, broadcast). The effective multiplicity is forced to
            zero wherever the corresponding ``spikes <= 0``.
        stamp_steps : ArrayLike or None, optional
            Integer event time stamps in simulation steps, with shape
            ``(N,)`` (or a scalar, broadcast). If ``None``, all events are
            stamped at ``current_step + 1``.

        Returns
        -------
        out : dict
            Same dictionary as :meth:`flush`. It contains copies of
            ``histogram``, ``histogram_correction``, ``count_histogram``, and
            ``n_events`` after processing this call.

        Raises
        ------
        KeyError
            If required environment values (``'t'`` or ``dt``) are missing.
        ValueError
            Raised when:

            - argument sizes are inconsistent,
            - receptor ports are outside ``{0, 1}``,
            - multiplicities are negative,
            - integer arguments are non-integral,
            - times are not grid-aligned, or
            - calibration constraints are violated.

            Port validation happens before any event of the call is applied.
        RuntimeError
            If a computed bin index is outside the histogram bounds.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        calib = self._ensure_calibrated(dt)
        step_now = self._time_to_step(t, calib.dt_ms)

        if spikes is None:
            return self._snapshot()
        spike_arr = self._to_float_array(spikes, name='spikes')
        n_items = spike_arr.size
        if n_items == 0:
            return self._snapshot()

        if receptor_ports is None and receptor_types is not None:
            receptor_ports = receptor_types
        port_arr = self._to_int_array(receptor_ports, name='receptor_ports', default=0, size=n_items)
        weight_arr = self._to_float_array(weights, name='weights', default=1.0, size=n_items)
        counts = self._resolve_multiplicities(spike_arr, multiplicities, n_items)
        if stamp_steps is None:
            stamp_arr = np.full((n_items,), step_now + 1, dtype=np.int64)
        else:
            stamp_arr = self._to_int_array(stamp_steps, name='stamp_steps', size=n_items)

        # Validate ports of every event that will actually be processed
        # before touching any state, so a bad port later in the batch cannot
        # leave the detector partially updated.
        live_ports = port_arr[counts > 0]
        bad_ports = live_ports[(live_ports < 0) | (live_ports > 1)]
        if bad_ports.size > 0:
            raise ValueError(f'Unknown receptor_type {int(bad_ports[0])} for correlation_detector.')

        for sender, stamp_step, weight, multiplicity in zip(
            port_arr.tolist(), stamp_arr.tolist(), weight_arr.tolist(), counts.tolist()
        ):
            if multiplicity <= 0:
                continue
            # Events outside the activity window are dropped entirely (not buffered).
            if not self._is_active(stamp_step, calib.t_min_steps, calib.t_max_steps):
                continue
            self._handle_event(
                sender=sender,
                stamp_step=stamp_step,
                weight=weight,
                multiplicity=multiplicity,
                calib=calib,
            )
        return self._snapshot()

    def _resolve_multiplicities(self, spike_arr: np.ndarray, multiplicities, n_items: int) -> np.ndarray:
        """Return non-negative int64 multiplicities, one per event."""
        if multiplicities is None:
            rounded = np.rint(spike_arr)
            # The integer-like test is global: one fractional value switches
            # the whole batch to binary "spike > 0" interpretation.
            if np.allclose(spike_arr, rounded, atol=1e-12, rtol=1e-12):
                return np.maximum(rounded.astype(np.int64), 0)
            return (spike_arr > 0.0).astype(np.int64)
        mult_arr = self._to_int_array(multiplicities, name='multiplicities', size=n_items)
        if np.any(mult_arr < 0):
            raise ValueError('multiplicities must be non-negative.')
        return np.where(spike_arr > 0.0, mult_arr, 0)

    def _snapshot(self) -> dict:
        """Copy the current outputs without re-running calibration (caller has done it)."""
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        return {
            'histogram': np.array(self._histogram, dtype=dftype, copy=True),
            'histogram_correction': np.array(self._histogram_correction, dtype=dftype, copy=True),
            'count_histogram': np.array(self._count_histogram, dtype=ditype, copy=True),
            'n_events': np.array(self._n_events, dtype=ditype, copy=True),
        }

    def _handle_event(
        self,
        sender: int,
        stamp_step: int,
        weight: float,
        multiplicity: int,
        calib: _Calibration,
    ):
        """Correlate one accepted event against the other port, then buffer it."""
        other = 1 - sender
        other_spikes = self._incoming[other]

        # Drop opposite-port spikes that are too old to fall in any bin.
        # For a port-1 event, lag >= tau_edge would map past the last bin, so
        # such spikes are pruned. For a port-0 event, a port-1 spike at lag
        # exactly tau_edge still maps to bin 0 (left-closed edge). The
        # half-step shift keeps it; this matters only when delta_tau_steps
        # is even.
        while other_spikes and (stamp_step - other_spikes[0].timestep) - 0.5 * other >= calib.tau_edge_steps:
            other_spikes.popleft()

        own_weight = float(multiplicity) * float(weight)
        stamp_ms = float(stamp_step) * calib.dt_ms
        if self._is_in_count_window(stamp_ms, calib.tstart_ms, calib.tstop_ms):
            self._n_events[sender] += 1
            sign = 2 * sender - 1
            hist = self._histogram
            corr = self._histogram_correction
            count_hist = self._count_histogram
            n_bins = hist.size
            for spike_j in other_spikes:
                lag = stamp_step - spike_j.timestep
                bin_index = int(math.floor((calib.tau_edge_steps + sign * lag) / calib.delta_tau_steps))
                if bin_index < 0 or bin_index >= n_bins:
                    raise RuntimeError('Internal bin index out of range in correlation_detector.')
                # Kahan-compensated accumulation of own_weight * w_j.
                y = own_weight * spike_j.weight - corr[bin_index]
                total = hist[bin_index] + y
                corr[bin_index] = (total - hist[bin_index]) - y
                hist[bin_index] = total
                count_hist[bin_index] += multiplicity

        # Insert into the sender queue after every spike with timestep <=
        # stamp_step, keeping it sorted. Scanning from the tail makes
        # in-order arrivals O(1).
        queue = self._incoming[sender]
        pos = len(queue)
        while pos > 0 and queue[pos - 1].timestep > stamp_step:
            pos -= 1
        queue.insert(pos, _Spike(timestep=stamp_step, weight=own_weight))

    def _ensure_calibrated_from_env_if_available(self):
        """Calibrate when the environment defines ``dt``; otherwise do nothing."""
        try:
            dt = brainstate.environ.get_dt()
        except KeyError:
            return
        self._ensure_calibrated(dt)

    def _ensure_calibrated(self, dt) -> _Calibration:
        """Recompute calibration for ``dt``; reset state if the signature changed."""
        new_calib = self._compute_calibration(dt)
        if self._calib is None or self._calib.signature != new_calib.signature:
            self._calib = new_calib
            self._reset_state()
        return self._calib

    def _reset_state(self):
        """Clear queues and counters; size histograms from the current calibration."""
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        self._n_events = np.zeros((2,), dtype=ditype)
        self._incoming = [deque(), deque()]
        n_bins = 0 if self._calib is None else int(self._calib.n_bins)
        self._histogram = np.zeros((n_bins,), dtype=dftype)
        self._histogram_correction = np.zeros((n_bins,), dtype=dftype)
        self._count_histogram = np.zeros((n_bins,), dtype=ditype)

    def _compute_calibration(self, dt) -> _Calibration:
        """Convert all time parameters to steps/ms for resolution ``dt`` and validate them."""
        dt_ms = self._to_ms_scalar(dt, name='dt')
        if dt_ms <= 0.0:
            raise ValueError('Simulation resolution dt must be positive.')

        # Activity window (origin + start, origin + stop], in steps.
        start_steps = self._to_step_count(self.start, dt_ms, 'start')
        stop_value = math.inf if self.stop is None else self.stop
        stop_steps = self._to_step_count(stop_value, dt_ms, 'stop', allow_inf=True)
        if not math.isinf(stop_steps) and stop_steps < start_steps:
            raise ValueError('stop >= start required.')
        origin_steps = self._to_step_count(self.origin, dt_ms, 'origin')
        t_min_steps = origin_steps + start_steps
        t_max_steps = math.inf if math.isinf(stop_steps) else origin_steps + stop_steps

        # Histogram geometry.
        if self.delta_tau is None:
            delta_tau_ms = 5.0 * dt_ms
        else:
            delta_tau_ms = self._to_ms_scalar(self.delta_tau, name='delta_tau')
        if not math.isfinite(delta_tau_ms) or delta_tau_ms <= 0.0:
            raise ValueError('delta_tau must be positive and finite.')
        delta_tau_steps = self._to_step_count(delta_tau_ms, dt_ms, 'delta_tau')
        if delta_tau_steps <= 0:
            # A positive delta_tau far below dt can round to 0 steps within
            # tolerance; it would otherwise divide by zero below.
            raise ValueError('delta_tau must be at least one simulation step.')
        if self.tau_max is None:
            tau_max_ms = 10.0 * delta_tau_ms
        else:
            tau_max_ms = self._to_ms_scalar(self.tau_max, name='tau_max')
        if not math.isfinite(tau_max_ms) or tau_max_ms < 0.0:
            raise ValueError('tau_max must be finite and non-negative.')
        tau_max_steps = self._to_step_count(tau_max_ms, dt_ms, 'tau_max')
        if tau_max_steps % delta_tau_steps != 0:
            raise ValueError('tau_max must be a multiple of delta_tau.')

        # Counting window [Tstart, Tstop] in ms.
        tstart_ms = self._to_ms_scalar(self.Tstart, name='Tstart', allow_inf=True)
        tstop_value = math.inf if self.Tstop is None else self.Tstop
        tstop_ms = self._to_ms_scalar(tstop_value, name='Tstop', allow_inf=True)

        n_bins = int(1 + 2 * (tau_max_steps // delta_tau_steps))
        signature = (
            float(dt_ms),
            int(start_steps),
            float(stop_steps),
            int(origin_steps),
            int(t_min_steps),
            float(t_max_steps),
            float(delta_tau_ms),
            int(delta_tau_steps),
            float(tau_max_ms),
            int(tau_max_steps),
            float(tstart_ms),
            float(tstop_ms),
            int(n_bins),
        )
        return _Calibration(
            dt_ms=float(dt_ms),
            start_step=int(start_steps),
            stop_step=float(stop_steps),
            origin_step=int(origin_steps),
            t_min_steps=int(t_min_steps),
            t_max_steps=float(t_max_steps),
            delta_tau_ms=float(delta_tau_ms),
            delta_tau_steps=int(delta_tau_steps),
            tau_max_ms=float(tau_max_ms),
            tau_max_steps=int(tau_max_steps),
            tau_edge_steps=float(tau_max_steps) + 0.5 * float(delta_tau_steps),
            tstart_ms=float(tstart_ms),
            tstop_ms=float(tstop_ms),
            n_bins=int(n_bins),
            signature=signature,
        )

    @staticmethod
    def _is_in_count_window(stamp_ms: float, tstart_ms: float, tstop_ms: float) -> bool:
        """Return whether ``stamp_ms`` lies in ``[tstart_ms, tstop_ms]`` (1e-12 ms slack)."""
        return (stamp_ms >= tstart_ms - 1e-12) and (stamp_ms <= tstop_ms + 1e-12)

    @staticmethod
    def _to_ms_scalar(value, name: str, allow_inf: bool = False) -> float:
        """Convert a scalar time (Quantity or plain number, in ms) to a Python float.

        Values are parsed at float64 regardless of the environment dtype,
        because the grid-alignment checks use a 1e-12 tolerance.
        """
        if isinstance(value, u.Quantity):
            value = u.get_mantissa(value / u.ms)
        arr = np.asarray(u.math.asarray(value), dtype=np.float64).reshape(-1)
        if arr.size != 1:
            raise ValueError(f'{name} must be a scalar time value.')
        val = float(arr[0])
        if (not allow_inf) and (not math.isfinite(val)):
            raise ValueError(f'{name} must be finite.')
        return val

    @classmethod
    def _to_step_count(
        cls,
        value,
        dt_ms: float,
        name: str,
        allow_inf: bool = False,
    ):
        """Convert a time value to an integer step count, or ``math.inf`` if allowed."""
        if value is None:
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} cannot be None.')
        ms = cls._to_ms_scalar(value, name=name, allow_inf=allow_inf)
        if math.isinf(ms):
            if allow_inf:
                return math.inf
            raise ValueError(f'{name} must be finite.')
        steps_f = ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError(f'{name} must be a multiple of the simulation resolution.')
        return steps_i

    def _time_to_step(self, t, dt_ms: float) -> int:
        """Convert the current simulation time ``t`` to an integer step index."""
        t_ms = self._to_ms_scalar(t, name='t')
        steps_f = t_ms / dt_ms
        steps_i = int(np.rint(steps_f))
        if not np.isclose(steps_f, steps_i, atol=1e-12, rtol=1e-12):
            raise ValueError('Current simulation time t must be aligned to the simulation grid.')
        return steps_i

    @staticmethod
    def _is_active(stamp_step: int, t_min_steps: int, t_max_steps: float) -> bool:
        """Return whether ``stamp_step`` lies in the activity window ``(t_min, t_max]``."""
        if stamp_step <= t_min_steps:
            return False
        if math.isinf(t_max_steps):
            return True
        return stamp_step <= t_max_steps

    @staticmethod
    def _to_float_array(
        x,
        name: str,
        default: float = None,
        size: int = None,
        unit=None,
    ) -> np.ndarray:
        """Flatten ``x`` to a finite float64 array, broadcasting a scalar to ``size``."""
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.float64)
        else:
            if isinstance(x, u.Quantity):
                x = x / unit if unit is not None else u.get_mantissa(x)
            arr = np.asarray(u.math.asarray(x), dtype=np.float64).reshape(-1)
        if not np.all(np.isfinite(arr)):
            raise ValueError(f'{name} must contain finite values.')
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), arr[0], dtype=np.float64)
        # An empty array with size > 0 is a mismatch as well.
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr

    @staticmethod
    def _to_int_array(
        x,
        name: str,
        default: int = None,
        size: int = None,
    ) -> np.ndarray:
        """Flatten ``x`` to an int64 array, broadcasting a scalar to ``size``.

        Floating inputs must hold finite integral values; they are never
        silently truncated.
        """
        if x is None:
            if default is None:
                raise ValueError(f'{name} cannot be None.')
            arr = np.asarray([default], dtype=np.int64)
        else:
            raw = np.asarray(u.math.asarray(x)).reshape(-1)
            if np.issubdtype(raw.dtype, np.floating):
                if not (np.all(np.isfinite(raw)) and np.all(raw == np.rint(raw))):
                    raise ValueError(f'{name} must contain integer values.')
            arr = raw.astype(np.int64)
        if size is None:
            return arr
        if arr.size == 1 and size > 1:
            return np.full((size,), int(arr[0]), dtype=np.int64)
        if arr.size != size:
            raise ValueError(f'{name} size ({arr.size}) does not match spikes size ({size}).')
        return arr