brainstate.transform.checkpointed_for_loop


brainstate.transform.checkpointed_for_loop(f, *xs, length=None, base=16, pbar=None)

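Checkpointed version of for_loop(). It provides for-loop control flow over functions that read or write State, and it checkpoints intermediate results so that reverse-mode differentiation through the loop uses less memory, at the cost of recomputing those results during the backward pass.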
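The general idea behind a checkpointed loop can be sketched in plain JAX. Wrapping a loop body in `jax.checkpoint` tells reverse-mode differentiation to recompute the body's intermediates during the backward pass instead of storing them. This is a minimal sketch using only `jax.lax.scan` and `jax.checkpoint`; it is not brainstate's implementation, and the names `body` and `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def body(carry, x):
    # Hypothetical loop iteration: update the carry and emit a per-step output.
    new_carry = jnp.sin(carry) * x + 1.0
    return new_carry, new_carry


def loss(xs, checkpoint):
    # jax.checkpoint makes the backward pass recompute body's intermediates
    # rather than keeping them in memory for every iteration.
    step = jax.checkpoint(body) if checkpoint else body
    _, ys = jax.lax.scan(step, jnp.float32(0.0), xs)
    return jnp.sum(ys)


xs = jnp.linspace(0.1, 1.0, 10)
g_plain = jax.grad(loss)(xs, checkpoint=False)
g_ckpt = jax.grad(loss)(xs, checkpoint=True)
# Checkpointing changes memory use and compute, not the gradient values.
print(jnp.allclose(g_plain, g_ckpt))
```

Both gradients agree; the checkpointed variant trades extra forward computation for a smaller set of stored residuals.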

Parameters:
  • f (Callable[..., TypeVar(Y)]) – A Python function to be looped over that accepts variadic arguments corresponding to slices of xs along their leading axes, and returns the output for that iteration.

  • *xs (TypeVar(X)) – The values over which to loop along the leading axis, where each can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.

  • length (int | None) – Optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform loops where no input xs are needed).

  • base (int) – Optional integer specifying the base of the bounded loop used to implement checkpointing (default 16).

  • pbar (ProgressBar | int | None) – Optional ProgressBar instance (or an integer, per the type annotation) used to display the progress of the loop.
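One common way a base parameter is used in checkpointed loops is to split the iterations into nested scans: an outer scan runs over blocks of `base` steps, and each block is wrapped in `jax.checkpoint`, so only the carries at block boundaries are kept for the backward pass while each block's internals are recomputed. The sketch below is a hypothetical illustration (`nested_checkpointed_scan` is not a brainstate function). It assumes the number of steps is a multiple of `base` and uses a single level of nesting; brainstate's bounded loop may nest recursively and handle remainders differently.

```python
import jax
import jax.numpy as jnp


def nested_checkpointed_scan(body, init, xs, base):
    """Hypothetical two-level checkpointed scan; len(xs) must be a multiple of base."""
    length = xs.shape[0]
    assert length % base == 0, "this sketch assumes length is a multiple of base"
    # Reshape the inputs into (num_blocks, base, ...) blocks.
    blocks = xs.reshape((length // base, base) + xs.shape[1:])

    @jax.checkpoint
    def run_block(carry, block):
        # Inner scan over `base` steps; its intermediates are recomputed
        # on the backward pass instead of being stored.
        return jax.lax.scan(body, carry, block)

    carry, ys = jax.lax.scan(run_block, init, blocks)
    # Flatten (num_blocks, base, ...) outputs back to (length, ...).
    return carry, ys.reshape((length,) + ys.shape[2:])


def body(carry, x):
    # Hypothetical step: exponentially decaying accumulator.
    new_carry = carry * 0.9 + x
    return new_carry, new_carry


xs = jnp.arange(32.0)
_, ys_nested = nested_checkpointed_scan(body, jnp.float32(0.0), xs, base=8)
_, ys_flat = jax.lax.scan(body, jnp.float32(0.0), xs)
print(jnp.allclose(ys_nested, ys_flat))
```

A larger base means fewer stored block-boundary carries but more work recomputed per block on the backward pass; a smaller base shifts the balance the other way.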

Returns:

The stacked outputs of f when looped over the leading axis of the inputs.

Return type:

TypeVar(Y)
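"Stacked" presumably follows the same convention as `jax.lax.scan`: if f returns a pytree, each leaf gains a new leading axis whose size is the number of iterations. The plain-JAX sketch below shows that rule with a dict-valued output; the `step` function is hypothetical, and brainstate is assumed (not confirmed here) to stack outputs the same way.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical per-iteration output: a dict pytree with two leaves.
    return carry, {"double": 2 * x, "pair": jnp.stack([x, x + 1])}


xs = jnp.arange(4.0)
# No carry is needed here, so None (an empty pytree) is used as the carry.
_, out = jax.lax.scan(step, None, xs)
print(out["double"].shape)  # (4,)
print(out["pair"].shape)    # (4, 2)
```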

Examples

Basic checkpointed for-loop operation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def process_item(x, y):
...     return x * y + 1
>>>
>>> xs = jnp.array([1.0, 2.0, 3.0])
>>> ys = jnp.array([4.0, 5.0, 6.0])
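>>> # Each iteration receives one element of xs and the matching element of ys.
>>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
>>> # results stacks the per-step outputs x * y + 1: [5., 11., 19.]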

Using a custom base for checkpointing:

>>> results = brainstate.transform.checkpointed_for_loop(
...     process_item, xs, ys, base=8
... )