brainstate.transform.checkpointed_for_loop

brainstate.transform.checkpointed_for_loop(f, *xs, length=None, base=16, pbar=None)
Run a for-loop with State-aware control flow, like for_loop(), but with the loop body checkpointed so that intermediate values are recomputed rather than stored during reverse-mode differentiation.
- Parameters:
  - f (Callable[..., TypeVar(Y)]) – A Python function to be looped over. It accepts variadic arguments corresponding to slices of xs along their leading axes and returns the output for that iteration.
  - *xs (TypeVar(X)) – The values over which to loop along the leading axis. Each can be an array or any pytree (nested Python tuple/list/dict) of arrays, and all must have the same leading-axis size.
  - length (int | None) – Optional integer specifying the number of loop iterations. It must agree with the leading-axis sizes of the arrays in xs, and it can also be used to run loops that need no input xs.
  - base (int) – Optional integer specifying the base for the bounded loop. Defaults to 16.
  - pbar (ProgressBar | int | None) – Optional ProgressBar instance used to display the progress of the loop.
- Returns:
  The stacked outputs of f when looped over the leading axis of the inputs.
- Return type:
  TypeVar(Y)
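To make the looping contract concrete, here is a minimal pure-Python sketch of the forward semantics described in the parameter list: slice every input at the same index along its leading axis, call f once per iteration, and collect the outputs in order. It rests on simplifying assumptions: each input is a flat Python sequence rather than a JAX array or pytree, there is no State handling or checkpointing, and a list stands in for the stacked array. The name reference_for_loop is hypothetical and is not part of brainstate.

```python
from typing import Any, Callable, List, Optional, Sequence


def reference_for_loop(
    f: Callable[..., Any],
    *xs: Sequence[Any],
    length: Optional[int] = None,
) -> List[Any]:
    """Hypothetical forward-only reference for looping f over leading axes.

    Assumes every element of xs is a flat sequence (no pytrees, no State,
    no checkpointing). Returns a list standing in for the stacked outputs.
    """
    sizes = {len(x) for x in xs}
    if len(sizes) > 1:
        raise ValueError(f"inconsistent leading-axis sizes: {sorted(sizes)}")
    if length is None:
        if not sizes:
            raise ValueError("length must be given when no xs are passed")
        length = sizes.pop()
    elif sizes and sizes.pop() != length:
        raise ValueError("length does not match the leading-axis size of xs")

    outputs = []
    for i in range(length):
        # Take slice i of every input along its leading axis and call f.
        outputs.append(f(*(x[i] for x in xs)))
    return outputs


def process_item(x, y):
    return x * y + 1


print(reference_for_loop(process_item, [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
# [5.0, 11.0, 19.0]
print(reference_for_loop(lambda: 0, length=3))  # no xs, only length
# [0, 0, 0]
```

The length-only call mirrors the documented use of length for loops that take no input xs; in that case f is called with no arguments in this sketch.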
Examples
Basic checkpointed for-loop operation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def process_item(x, y):
...     return x * y + 1
>>>
>>> xs = jnp.array([1.0, 2.0, 3.0])
>>> ys = jnp.array([4.0, 5.0, 6.0])
>>> # Each iteration receives one element of xs and the matching element of ys.
>>> results = brainstate.transform.checkpointed_for_loop(process_item, xs, ys)
Using a custom base for checkpointing:
>>> results = brainstate.transform.checkpointed_for_loop(
...     process_item, xs, ys, base=8
... )
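The docstring describes base only as "the base for the bounded loop". One common way to build such a loop (for instance, Equinox's bounded while loop) is to round the iteration count up to a power of base and run base-length scans nested depth levels deep, checkpointing each level above the innermost. Under that scheme, the backward pass keeps on the order of base × depth saved carries instead of one per iteration, at the cost of recomputing forward work inside each block. Whether brainstate follows exactly this scheme is an assumption here. The pure-Python sketch below (nesting_depth and bounded_loop are hypothetical names) shows only the nesting, the iteration order, and the padding: steps at or beyond length are skipped, as a lax.cond would skip them inside a real JAX scan.

```python
from typing import Callable, Tuple, TypeVar

S = TypeVar("S")


def nesting_depth(length: int, base: int) -> Tuple[int, int]:
    """Return (depth, capacity), the smallest depth with base**depth >= length."""
    if base < 2:
        raise ValueError("base must be at least 2")
    depth, capacity = 0, 1
    while capacity < length:
        capacity *= base
        depth += 1
    return depth, capacity


def bounded_loop(body: Callable[[int, S], S], state: S, length: int, base: int) -> S:
    """Hypothetical sketch: apply body(i, state) for i in range(length) via nesting.

    Each level runs `base` sub-blocks in order. In an assumed JAX version, each
    level would be a lax.scan of length `base` (checkpointed above the innermost
    level), and padded steps i >= length would be skipped with lax.cond.
    """
    depth, _ = nesting_depth(length, base)

    def run(level: int, state: S, start: int) -> S:
        if level == 0:
            # Innermost unit: one real iteration, or a no-op for a padding step.
            return body(start, state) if start < length else state
        stride = base ** (level - 1)  # iterations covered by each sub-block
        for k in range(base):
            state = run(level - 1, state, start + k * stride)
        return state

    return run(depth, state, 0)


visited = bounded_loop(lambda i, acc: acc + [i], [], length=37, base=4)
print(visited == list(range(37)))  # True: indices are visited once, in order
print(nesting_depth(37, 4))  # (3, 64): 3 levels of 4, padded to 64 steps
print(nesting_depth(3, 8))  # (1, 8)
```

Under this assumed scheme, the three-element example gives a single level with either the default base=16 or base=8. The choice of base only starts to matter for longer loops, where a smaller base means more levels (more recomputation) and a larger base means more carries saved per level.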