Delay
- class brainstate.nn.Delay(target_info, time=None, init=None, entries=None, interpolation=None, take_aware_unit=False, update_every=None, delay_method='rotation', interp_method='linear_interp')
Delay variable for storing short-term history data.
The data in this delay variable is arranged as:
    delay = 0          [ data
    delay = 1            data
    delay = 2            data
    ...                  ....
    ...                  ....
    delay = length-1     data
    delay = length       data ]
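The layout above can be made concrete with a minimal pure-Python sketch. It assumes a shift-based (rotating) buffer in which row 0 always holds the newest value and row k holds the value written k updates ago. The class `RotatingHistory` and its methods are hypothetical illustrations, not part of brainstate, and the library's actual ring buffer may be organized differently.

```python
from collections import deque


class RotatingHistory:
    """Hypothetical rotating buffer: row k holds the value written k updates ago."""

    def __init__(self, length, init=0.0):
        # length + 1 rows, covering delay = 0 ... delay = length
        self.rows = deque([init] * (length + 1), maxlen=length + 1)

    def update(self, value):
        # Rotate: the new value becomes delay = 0, and the oldest row
        # (delay = length) is dropped automatically because of maxlen.
        self.rows.appendleft(value)

    def retrieve(self, delay_step):
        # delay_step = 0 is the most recent value.
        if not 0 <= delay_step < len(self.rows):
            raise IndexError("delay_step out of range")
        return self.rows[delay_step]


hist = RotatingHistory(length=3)
for i in range(5):
    hist.update(float(i))
print([hist.retrieve(k) for k in range(4)])  # [4.0, 3.0, 2.0, 1.0]
```

After five updates with `length=3`, only the four most recent values survive, ordered from delay 0 (newest) to delay 3 (oldest).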
- Parameters:
time (int | float | Quantity | None) – The delay time.
init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) – The initial delay data. It can be a Python number (float, int, or bool), an array, a callable function, or an instance of Connector. Note that the initial delay data should be arranged as follows:
    delay = 1          [ data
    delay = 2            data
    ...                  ....
    ...                  ....
    delay = length-1     data
    delay = length       data ]
entries (Dict | None) – Optional. The delay access entries.
interpolation (str | Callable | None) – The interpolation method for continuous-time retrieval. Built-in methods: 'nearest', 'linear', 'cubic', 'hermite', 'polynomial2', 'polynomial3'. A custom callable following the InterpolationMethod protocol is also accepted.
take_aware_unit (bool) – Whether to track and preserve units from brainunit.
update_every (float | Quantity | None) – Optional. Time interval between buffer updates. If None (the default), the buffer is updated every time update() is called. Otherwise, the buffer is updated only when the time accumulated since the last update exceeds this threshold. brainunit quantities are supported (e.g., 5.0*u.ms). For example, update_every=5.0 updates the buffer every 5 time units.
update_strategy (str) – Strategy for handling updates between threshold crossings. Options:
  - 'hold' (default): Skip writes between thresholds and keep the last written value.
  - 'latest': Always cache the newest value and write it when the threshold is crossed.
  - 'aggregate': Accumulate all values between thresholds and write the aggregated result.
aggregate_fn (str | Callable, optional) – Aggregation function for the 'aggregate' strategy. Built-in options (strings): 'mean', 'sum', 'max', 'min', 'last'. A custom callable must take an array and an axis argument and return the aggregated value. Defaults to 'mean' when update_strategy='aggregate'; ignored for other strategies.
delay_method (str | None) – Deprecated; kept for backward compatibility. The unified ring buffer implementation now uses rotation for all delays.
interp_method (str) – Deprecated; kept for backward compatibility. Use the interpolation parameter instead.
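To illustrate what the 'linear' interpolation option computes, here is a minimal sketch assuming linear interpolation between the two stored rows that bracket a non-integer delay step `delay_time / dt`. The helper `linear_delay_lookup` is hypothetical: it is not brainstate's InterpolationMethod protocol and only shows the arithmetic.

```python
import math


def linear_delay_lookup(history, delay_time, dt):
    """Hypothetical linear interpolation over a delay history.

    history[k] is the value recorded k steps ago (history[0] is the newest),
    matching the delay = 0 ... delay = length layout.
    """
    steps = delay_time / dt
    lo = math.floor(steps)
    frac = steps - lo
    if lo < 0 or lo >= len(history):
        raise ValueError("delay_time is outside the stored history")
    if frac == 0.0:
        return history[lo]  # exactly on a stored step: no blending needed
    if lo + 1 >= len(history):
        raise ValueError("delay_time needs a row older than the stored history")
    # Blend the two bracketing rows by the fractional step distance.
    return (1.0 - frac) * history[lo] + frac * history[lo + 1]


history = [4.0, 3.0, 2.0, 1.0]  # newest first
print(linear_delay_lookup(history, delay_time=1.0, dt=0.5))   # 2.0
print(linear_delay_lookup(history, delay_time=1.25, dt=0.5))  # 1.5
```

A delay of 1.25 with dt = 0.5 falls halfway between steps 2 and 3, so the result is the average of those two rows.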
Examples
Basic delay with the default behavior (the buffer is updated on every call):
>>> import brainstate
>>> import jax.numpy as jnp
>>> delay = brainstate.nn.Delay(jnp.zeros((10,)), time=5.0)
>>> delay.init_state()
>>> for i in range(100):
...     delay.update(jnp.ones((10,)) * i)
Delay with update frequency control (hold strategy):
>>> import brainunit as u
>>> delay = brainstate.nn.Delay(
...     jnp.zeros((10,)),
...     time=10.0 * u.ms,
...     update_every=5.0 * u.ms,  # Update every 5 ms
... )
>>> delay.init_state()
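The difference between update strategies can be sketched in plain Python. This is a minimal model under stated assumptions, not brainstate's implementation: each value arrives one update() call (dt time units) after the previous one, and a write happens once the accumulated time reaches update_every. The function `throttled_writes` is hypothetical. It models only 'hold' and 'aggregate'; 'latest' is left out because its observable difference from 'hold' depends on details this documentation does not spell out. Also, unlike the documented aggregate_fn, which takes an array and an axis argument, the sketch's `aggregate_fn` takes a plain list.

```python
from statistics import mean


def throttled_writes(values, dt, update_every, strategy="hold", aggregate_fn=mean):
    """Hypothetical model of update_every for the 'hold' and 'aggregate' strategies.

    Each value arrives one update() call, i.e. dt time units, after the previous
    one. Returns the values that would actually be written to the delay buffer.
    """
    if strategy not in ("hold", "aggregate"):
        raise ValueError(f"strategy not modelled in this sketch: {strategy!r}")
    written, pending, elapsed = [], [], 0.0
    for value in values:
        elapsed += dt
        if strategy == "aggregate":
            pending.append(value)  # collect every value since the last write
        # Assumption: a write happens once the accumulated time reaches update_every.
        if elapsed >= update_every:
            if strategy == "hold":
                written.append(value)  # values between thresholds were skipped
            else:
                written.append(aggregate_fn(pending))
                pending = []
            elapsed -= update_every  # carry over any remainder
    return written


samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(throttled_writes(samples, dt=1.0, update_every=3.0))                        # [3.0, 6.0]
print(throttled_writes(samples, dt=1.0, update_every=3.0, strategy="aggregate"))  # [2.0, 5.0]
```

With 'hold', intermediate samples are discarded; with 'aggregate' (mean by default), every sample contributes to the value that is eventually written.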
- access(entry, *time_and_idx)
Create a DelayAccess object for a specific delay entry and delay time.
- Parameters:
entry (str) – The name of the delay entry to access.
time_and_idx (Sequence) – The delay time or parameters associated with the entry.
- Returns:
An object that provides access to the delay data for the specified entry and time.
- Return type:
DelayAccess
- register_delay(*time_and_idx)
Register delay times and update the maximum delay configuration.
This method processes one or more delay times, validates their format and consistency, and updates the delay buffer size if necessary. It handles both scalar and vector delay times, ensuring all vector delays have the same size.
- Parameters:
*time_and_idx – Variable number of delay time arguments. The first argument should be the primary delay time (float, int, or array-like). Additional arguments are treated as indices or secondary delay parameters. All delay times should be non-negative numbers or arrays of the same size.
- Returns:
- If time_and_idx[0] is None, returns None. Otherwise, returns a tuple (delay_step, *time_and_idx[1:]), where delay_step is the delay converted to an integer number of time steps and the remaining elements are the additional delay parameters passed in.
- Return type:
tuple or None
- Raises:
AssertionError – If no delay time is provided (time_and_idx is empty).
ValueError – If delay times have inconsistent sizes when using vector delays, or if delay times are not scalar or 1D arrays.
Note
The method updates self.max_time and self.max_length if the new delay requires a larger buffer size.
Delay steps are computed using the current environment time step (dt).
All delay indices (time_and_idx[1:]) must be integers.
Vector delays must all have the same size as the first delay time.
Example
>>> delay_obj.register_delay(5.0)  # Register a 5 ms delay
>>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1)  # Vector delay with indices
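The Note above says delay steps are derived from the environment time step and that the buffer grows only when a longer delay is registered. Here is a minimal sketch of that bookkeeping, assuming steps are obtained by rounding `time / dt` to the nearest integer (this documentation does not state the exact rounding rule). `DelayBookkeeping` and its `register` method are hypothetical; they take `dt` explicitly instead of reading it from the brainstate environment, and they omit the vector-size consistency checks.

```python
class DelayBookkeeping:
    """Hypothetical sketch: convert delay times to steps and grow the buffer size."""

    def __init__(self, dt):
        self.dt = dt
        self.max_time = 0.0
        self.max_length = 0

    def register(self, time, *indices):
        if time is None:
            return None
        # Accept a scalar delay or a 1-D sequence of per-element delays.
        is_vector = isinstance(time, (list, tuple))
        times = list(time) if is_vector else [time]
        if not times:
            raise ValueError("at least one delay time is required")
        if any(t < 0 for t in times):
            raise ValueError("delay times must be non-negative")
        if not all(isinstance(i, int) for i in indices):
            raise ValueError("delay indices must be integers")
        # Assumption: steps are time / dt rounded to the nearest integer.
        steps = [round(t / self.dt) for t in times]
        # Grow the buffer only if this delay needs more history than seen so far.
        if max(steps) > self.max_length:
            self.max_length = max(steps)
            self.max_time = max(times)
        delay_step = steps if is_vector else steps[0]
        return (delay_step, *indices)


book = DelayBookkeeping(dt=0.5)
print(book.register(5.0))              # (10,)
print(book.register([2.0, 3.0], 0, 1))  # ([4, 6], 0, 1)
print(book.max_length, book.max_time)   # 10 5.0
```

The second registration needs at most 6 steps, so it leaves the buffer size set by the 5.0 delay unchanged.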
- retrieve_at_step(delay_step, *indices)
Retrieve the delay data at the given delay step (an integer number of time steps).
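For vector delays, each element may sit at a different delay step. Here is a minimal sketch of such a lookup, assuming the history is a list of rows where `history[k]` was recorded k steps ago. `gather_delayed` is a hypothetical helper, and pairing per-element steps with element indices is an assumption about how indices are used. It does not describe brainstate's retrieve_at_step.

```python
def gather_delayed(history, delay_step, element_idx=None):
    """Hypothetical lookup in a history where history[k] is the row recorded k steps ago."""
    if isinstance(delay_step, int):
        # Scalar delay: return the whole row, or only the selected elements.
        row = history[delay_step]
        return row if element_idx is None else [row[i] for i in element_idx]
    # Vector delay: element j is read from its own delay row.
    if element_idx is None:
        element_idx = range(len(delay_step))
    if len(element_idx) != len(delay_step):
        raise ValueError("one delay step is needed per selected element")
    return [history[k][i] for k, i in zip(delay_step, element_idx)]


history = [
    [10, 11, 12],  # delay = 0 (newest)
    [20, 21, 22],  # delay = 1
    [30, 31, 32],  # delay = 2
]
print(gather_delayed(history, 1))               # [20, 21, 22]
print(gather_delayed(history, [2, 0, 1]))       # [30, 11, 22]
print(gather_delayed(history, [0, 2], [1, 2]))  # [11, 32]
```

A scalar step reads one row, while a list of steps reads a "diagonal" across rows, one delay per element.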