Delay

class brainstate.nn.Delay(target_info, time=None, init=None, entries=None, interpolation=None, take_aware_unit=False, update_every=None, delay_method='rotation', interp_method='linear_interp')

Delay variable for storing short-term history data.

The data in this delay variable is arranged as:

delay = 0             [ data
delay = 1               data
delay = 2               data
...                     ....
...                     ....
delay = length-1        data
delay = length          data ]
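
The layout above can be pictured with a small pure-Python sketch. It is not brainstate's implementation: the HistoryBuffer class, its fill default, and the list-shifting update are assumptions made for illustration. Row k always holds the value written k updates ago, so row 0 is the newest entry and row length is the oldest.

```python
class HistoryBuffer:
    """Hypothetical history buffer: row k holds the value from k updates ago."""

    def __init__(self, length, fill=0.0):
        # Rows 0..length, mirroring "delay = 0" .. "delay = length" above.
        self.rows = [fill] * (length + 1)

    def update(self, current):
        # Every row ages by one step, the oldest row is dropped,
        # and the newest value becomes delay = 0.
        self.rows = [current] + self.rows[:-1]

    def retrieve_at_step(self, delay_step):
        # Integer delay step -> stored value (0 = most recent write).
        return self.rows[delay_step]


buf = HistoryBuffer(length=3)
for i in range(5):
    buf.update(float(i))
print(buf.rows)  # [4.0, 3.0, 2.0, 1.0]
```

Shifting a list copies every row on each update; a production ring buffer would typically advance a write index instead.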

Parameters:
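  • target_info – The target data to be delayed, for example an array such as jnp.zeros((10,)) as in the examples below. Its shape and dtype determine those of each delay slot.

  • time (int | float | Quantity | None) – The delay time. It sets the initial buffer length; registering a longer delay later (see register_delay()) enlarges the buffer.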

  • init (Callable | Array | ndarray | bool | number | int | float | complex | Quantity | None) –

    The initial delay data. It can be a Python number (float, int, or bool), an array, a callable function, or an instance of Connector. The initial data must be arranged as follows, starting from delay = 1:

    delay = 1             [ data
    delay = 2               data
    ...                     ....
    ...                     ....
    delay = length-1        data
    delay = length          data ]
    

  • entries (Dict | None) – Optional. The delay access entries to register, keyed by entry name (see register_entry()).

  • interpolation (str | Callable | None) – The interpolation method used for continuous-time retrieval (see retrieve_at_time()). Built-in methods: ‘nearest’, ‘linear’, ‘cubic’, ‘hermite’, ‘polynomial2’, ‘polynomial3’. A custom callable following the InterpolationMethod protocol is also accepted.

  • take_aware_unit (bool) – Whether to track and preserve brainunit units.

  • update_every (float | Quantity | None) – Optional. The time interval between buffer updates. If None (the default), the buffer is updated every time update() is called. Otherwise, the buffer is updated only when the time accumulated since the last update exceeds this threshold. Supports brainunit quantities (e.g., 5.0*u.ms). For example, update_every=5.0 means the buffer is updated every 5 time units.

  • update_strategy – str. The strategy for handling calls to update() that fall between threshold crossings. Options:

      - ‘hold’ (default): skip writes between thresholds and keep the last written value.
      - ‘latest’: always cache the newest value and write it when the threshold is crossed.
      - ‘aggregate’: accumulate all values between thresholds and write the aggregated result.

  • aggregate_fn – optional, str or Callable. The aggregation function for the ‘aggregate’ strategy. Built-in options (strings): ‘mean’, ‘sum’, ‘max’, ‘min’, ‘last’. It may also be any custom callable that takes an array and an axis argument and returns the aggregated value. Defaults to ‘mean’ when update_strategy=‘aggregate’; ignored for other strategies.

  • delay_method (str | None) – Deprecated; kept for backward compatibility. The unified ring buffer implementation now uses rotation for all delays.

  • interp_method (str) – Deprecated; kept for backward compatibility. Use the interpolation parameter instead.
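
The interplay of update_every, update_strategy, and aggregate_fn can be sketched in plain Python. This is a hypothetical model, not the library's code: the ThrottledWriter class, the fixed dt per call, and the rule that a write happens once the accumulated time reaches update_every are all assumptions. It covers only the ‘hold’ and ‘aggregate’ strategies, and its aggregators take a Python list rather than the (array, axis) signature described above.

```python
# Hypothetical aggregators operating on a plain list of buffered values.
AGGREGATORS = {
    "mean": lambda xs: sum(xs) / len(xs),
    "sum": sum,
    "max": max,
    "min": min,
    "last": lambda xs: xs[-1],
}


class ThrottledWriter:
    """Hypothetical model of throttled writes into a delay buffer."""

    def __init__(self, dt, update_every, strategy="hold", aggregate_fn="mean"):
        if strategy not in ("hold", "aggregate"):
            raise ValueError("this sketch only models 'hold' and 'aggregate'")
        self.dt = dt
        self.update_every = update_every
        self.strategy = strategy
        # Accept a built-in name or any callable taking a list of values.
        self.aggregate_fn = (AGGREGATORS[aggregate_fn]
                             if isinstance(aggregate_fn, str) else aggregate_fn)
        self.elapsed = 0.0   # time accumulated since the last write
        self.pending = []    # values seen since the last write ('aggregate')
        self.written = []    # values actually written to the buffer

    def update(self, value):
        self.elapsed += self.dt
        if self.strategy == "aggregate":
            self.pending.append(value)
        if self.elapsed < self.update_every:
            return  # between thresholds: nothing is written
        self.elapsed -= self.update_every
        if self.strategy == "aggregate":
            self.written.append(self.aggregate_fn(self.pending))
            self.pending = []
        else:
            # 'hold': write the value passed at the crossing; values in
            # between were skipped, so the buffer keeps the last write.
            self.written.append(value)


writer = ThrottledWriter(dt=1.0, update_every=3.0, strategy="aggregate")
for v in [1, 2, 3, 4, 5, 6]:
    writer.update(v)
print(writer.written)  # [2.0, 5.0]
```

With ‘hold’ and the same inputs, the sketch would write only the values passed at the crossings (3 and 6).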

Examples

Basic delay with default behavior (update every call):

>>> import brainstate
>>> import jax.numpy as jnp
>>> brainstate.environ.set(dt=0.1)  # delay steps are computed from dt
>>> delay = brainstate.nn.Delay(jnp.zeros((10,)), time=5.0)
>>> delay.init_state()
>>> for i in range(100):
...     delay.update(jnp.ones((10,)) * i)

Delay with update frequency control (hold strategy):

>>> import brainunit as u
>>> brainstate.environ.set(dt=0.1 * u.ms)  # dt in the same unit as time
>>> delay = brainstate.nn.Delay(
...     jnp.zeros((10,)),
...     time=10.0 * u.ms,
...     update_every=5.0 * u.ms,  # Update every 5ms
... )
>>> delay.init_state()

access(entry, *time_and_idx)

Create a DelayAccess object for a specific delay entry and delay time.

Parameters:
  • entry (str) – The name of the delay entry to access.

  • time_and_idx (Sequence) – The delay time associated with the entry, optionally followed by indices.

Returns:

An object that provides access to the delay data for the specified entry and time.

Return type:

DelayAccess

at(entry)

Get the data at the given entry.

Parameters:

entry (str) – The name of the entry whose data to retrieve.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

Returns:

The delayed data for the entry.

init_state(batch_size=None, **kwargs)

Initialize the state of the delay buffer.

register_delay(*time_and_idx)

Register delay times and update the maximum delay configuration.

This method processes one or more delay times, validates their format and consistency, and updates the delay buffer size if necessary. It handles both scalar and vector delay times, ensuring all vector delays have the same size.

Parameters:

*time_and_idx – Variable number of delay time arguments. The first argument should be the primary delay time (float, int, or array-like). Additional arguments are treated as indices or secondary delay parameters. All delay times should be non-negative numbers or arrays of the same size.

Returns:

If time_and_idx[0] is None, returns None. Otherwise, returns a tuple (delay_step, *time_and_idx[1:]), where delay_step is the computed delay step in integer time units and the remaining elements are the additional delay parameters passed in.

Return type:

tuple or None

Raises:
  • AssertionError – If no delay time is provided (time_and_idx is empty).

  • ValueError – If delay times have inconsistent sizes when using vector delays, or if delay times are not scalar or 1D arrays.

Note

  • The method updates self.max_time and self.max_length if the new delay requires a larger buffer size.

  • Delay steps are computed using the current environment time step (dt).

  • All delay indices (time_and_idx[1:]) must be integers.

  • Vector delays must all have the same size as the first delay time.

Example

>>> delay_obj.register_delay(5.0)  # Register a delay time of 5.0 (in the units of dt)
>>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1)  # Vector delay with indices
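
The bookkeeping that register_delay performs, as described in the notes above, can be approximated in plain Python. The DelayRegistry class is hypothetical, vector delays are given as Python lists instead of arrays, and converting a time to a step by rounding time / dt to the nearest integer is an assumption; the library may use a different conversion rule or raise different errors.

```python
class DelayRegistry:
    """Hypothetical bookkeeping modeled on the notes for register_delay()."""

    def __init__(self, dt):
        self.dt = dt
        self.max_length = 0       # largest delay step registered so far
        self.vector_size = None   # size shared by all vector delays

    def register_delay(self, *time_and_idx):
        assert len(time_and_idx) > 0, "a delay time is required"
        time, *indices = time_and_idx
        if time is None:
            return None
        # This sketch raises ValueError for non-integer indices.
        if not all(isinstance(i, int) for i in indices):
            raise ValueError("delay indices must be integers")
        is_vector = isinstance(time, (list, tuple))
        times = list(time) if is_vector else [time]
        if any(t < 0 for t in times):
            raise ValueError("delay times must be non-negative")
        if is_vector:
            # Every vector delay must match the size of the first one seen.
            if self.vector_size is None:
                self.vector_size = len(times)
            elif len(times) != self.vector_size:
                raise ValueError("vector delays must all have the same size")
        # Assumed conversion: round time / dt to the nearest whole step.
        steps = [round(t / self.dt) for t in times]
        # Grow the buffer if this delay needs more history than before.
        self.max_length = max([self.max_length, *steps])
        delay_step = steps if is_vector else steps[0]
        return (delay_step, *indices)


reg = DelayRegistry(dt=0.5)
print(reg.register_delay(2.0))               # (4,)
print(reg.register_delay([1.0, 3.0], 0, 1))  # ([2, 6], 0, 1)
print(reg.max_length)                        # 6
```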

register_entry(entry, *time_and_idx)

Register an entry to access the delay data.

Parameters:
  • entry (str) – The name under which the delay data is accessed.

  • time_and_idx – The delay time of the entry, optionally followed by indices: the first element is the delay time, and any subsequent elements are indices.

Return type:

Delay
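
The entry mechanism (register_entry() records a named delay and at() reads it back) can be sketched in plain Python. EntryDelay is a hypothetical class: it stores plain lists, converts delay times to steps by rounding time / dt (an assumption), and treats each index as one level of list indexing, standing in for the more general indices of the real methods.

```python
class EntryDelay:
    """Hypothetical delay with named access entries."""

    def __init__(self, length, dt):
        self.dt = dt
        self.history = [None] * (length + 1)  # row k = value from k updates ago
        self.entries = {}

    def register_entry(self, entry, delay_time, *indices):
        # Store the integer step (and optional indices) under the entry name.
        self.entries[entry] = (round(delay_time / self.dt), indices)
        return self  # mirrors the documented return type (Delay)

    def update(self, current):
        # Newest value becomes delay = 0; the oldest row is dropped.
        self.history = [current] + self.history[:-1]

    def at(self, entry):
        delay_step, indices = self.entries[entry]
        data = self.history[delay_step]
        for i in indices:  # index into the stored value, one level per index
            data = data[i]
        return data


d = EntryDelay(length=3, dt=1.0)
d.register_entry("slow", 2.0).register_entry("fast_x", 1.0, 0)
for t in range(5):
    d.update([t, 10 * t])
print(d.at("slow"))    # [2, 20]
print(d.at("fast_x"))  # 3
```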

reset_state(batch_size=None, **kwargs)

Reset the state of the delay buffer.

retrieve_at_step(delay_step, *indices)

Retrieve the delay data at the given integer delay step.

Parameters:
  • delay_step (int_like) – The integer delay step at which to retrieve the data.

  • indices (tuple) – The indices to slice the data.

Returns:

delay_data

Return type:

PyTree

retrieve_at_time(delay_time, *indices)

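Retrieve the delay data at the given delay time. Unlike retrieve_at_step(), delay_time is a continuous time rather than an integer step; values between steps are resolved with the configured interpolation method (see the interpolation parameter).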

Parameters:
  • delay_time (float) – The delay time at which to retrieve the data.

  • indices (tuple) – The indices to slice the data.

Returns:

delay_data

Return type:

PyTree
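
To see what retrieving at a non-integer delay involves, here is a hypothetical sketch of two of the built-in interpolation names, ‘nearest’ and ‘linear’, over a history stored newest first (row k = k steps ago). The function and its clamping at the oldest row are assumptions; the library's interpolation code and edge handling may differ.

```python
import math


def retrieve_at_time(history, delay_time, dt, interpolation="linear"):
    """Hypothetical lookup of history at the fractional step delay_time / dt."""
    pos = delay_time / dt          # fractional number of steps back
    last = len(history) - 1        # oldest available row
    if interpolation == "nearest":
        return history[min(round(pos), last)]
    if interpolation == "linear":
        k = min(math.floor(pos), last)
        frac = pos - k
        if k == last or frac == 0.0:
            return history[k]      # exact step, or clamped to the oldest row
        # Blend the two stored samples that bracket the requested delay.
        return (1.0 - frac) * history[k] + frac * history[k + 1]
    raise ValueError(f"unsupported interpolation: {interpolation!r}")


history = [4.0, 3.0, 2.0, 1.0, 0.0]
print(retrieve_at_time(history, 1.25, dt=1.0))                           # 2.75
print(retrieve_at_time(history, 1.25, dt=1.0, interpolation="nearest"))  # 3.0
```

Higher-order methods such as ‘cubic’ or ‘hermite’ would draw on more than two neighboring rows, but the idea of mapping a time to a fractional step is the same.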

update(current)

Write the new data (current) into the delay buffer. When update_every is set, writes follow the configured update_strategy.

Return type:

None