brainstate.transform.for_loop
brainstate.transform.for_loop(f, *xs, length=None, reverse=False, unroll=1, pbar=None)
For-loop control flow with State.

- Parameters:
  - f (Callable[..., TypeVar(Y)]) – A Python function to be looped over. It accepts variadic arguments corresponding to slices of xs along their leading axes and returns the output for that iteration.
  - *xs – The values over which to loop along the leading axis. Each can be an array or any pytree (nested Python tuple/list/dict) thereof, and all of them must have the same leading-axis size.
  - length (int | None) – Optional integer specifying the number of loop iterations. It must agree with the sizes of the leading axes of the arrays in xs, but it can also be used to run loops that need no input xs at all.
  - reverse (bool) – Optional boolean specifying whether to run the loop iterations forward (the default) or in reverse. Running in reverse is equivalent to reversing the leading axes of the arrays in both xs and the stacked outputs.
  - unroll (int | bool) – Optional positive int or bool specifying how many loop iterations to unroll within a single iteration of the underlying scan primitive. If an integer is provided, it determines how many unrolled loop iterations run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll=True) or left completely rolled (unroll=False).
  - pbar (ProgressBar | int | None) – Optional ProgressBar instance to display the progress of the loop operation.
- Returns:
  The stacked outputs of f when looped over the leading axis of the inputs.
- Return type:
  TypeVar(Y)
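To make the parameter and return semantics concrete, here is a minimal pure-Python model. It assumes Python lists stand in for arrays (list index = leading axis) and that f returns a tuple, standing in for a pytree output. `reference_for_loop` is a hypothetical helper for illustration only, not part of brainstate; it ignores State handling, reverse, unroll, and pbar.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def reference_for_loop(
    f: Callable[..., Tuple[Any, ...]],
    *xs: Sequence[Any],
    length: Optional[int] = None,
) -> Tuple[list, ...]:
    """Hypothetical pure-Python model of the documented semantics.

    Not brainstate's implementation: lists play the role of arrays and a
    tuple returned by ``f`` plays the role of a pytree output.
    """
    # All inputs (and ``length``, if given) must agree on the leading-axis size.
    sizes = {len(x) for x in xs}
    if length is not None:
        sizes.add(length)
    if not sizes:
        raise ValueError("length is required when no xs are given")
    if len(sizes) != 1:
        raise ValueError(f"inconsistent leading-axis sizes: {sorted(sizes)}")
    n = sizes.pop()

    # One call per index; with no xs, f is called with no arguments.
    steps = [f(*(x[i] for x in xs)) for i in range(n)]

    # Stack leaf-wise: k outputs per step become k lists of length n, so the
    # output structure is kept and each leaf gains a leading axis.
    return tuple(list(leaf) for leaf in zip(*steps))


sums, products = reference_for_loop(lambda a, b: (a + b, a * b), [1, 2, 3], [4, 5, 6])
print(sums, products)  # [5, 7, 9] [4, 10, 18]
```

Using length with no xs, for example `reference_for_loop(lambda: (1,), length=2)`, runs the body twice and returns `([1, 1],)`.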
Examples
Basic for-loop operation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def process_item(x, y):
...     return x * y + 1
>>>
>>> xs = jnp.array([1.0, 2.0, 3.0])
>>> ys = jnp.array([4.0, 5.0, 6.0])
>>> results = brainstate.transform.for_loop(process_item, xs, ys)
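In the basic example, process_item is applied once per index of the shared leading axis, so results should hold one value per input pair. The sketch below reproduces that per-iteration arithmetic with a plain Python loop instead of calling brainstate; using lists in place of jnp arrays is an assumption made so the check runs without JAX.

```python
def process_item(x, y):
    return x * y + 1


# Lists stand in for the jnp arrays of the example (illustrative assumption).
xs = [1.0, 2.0, 3.0]
ys = [4.0, 5.0, 6.0]

# Walk the shared leading axis and collect one output per step; for_loop
# stacks these per-step outputs along a new leading axis.
expected = [process_item(x, y) for x, y in zip(xs, ys)]
print(expected)  # [5.0, 11.0, 19.0]
```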
For-loop with progress bar:
>>> pbar = brainstate.transform.ProgressBar(freq=10)
>>> results = brainstate.transform.for_loop(process_item, xs, ys, pbar=pbar)
For-loop with reverse iteration:
>>> results = brainstate.transform.for_loop(process_item, xs, ys, reverse=True)
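Under the definition of reverse given in the parameter list (reversing the leading axes of both the inputs and the stacked outputs), each output stays aligned with its input. For a stateless body like process_item, the reversed call should therefore return the same values as the forward one; the order only becomes visible when the body carries state across iterations. The sketch below models this in pure Python. `run` and `make_step` are hypothetical helpers, and the dict `counter` is an assumed stand-in for a brainstate State.

```python
from typing import Any, Callable, List, Sequence


def run(f: Callable[[Any], Any], xs: Sequence[Any], reverse: bool = False) -> List[Any]:
    # Hypothetical model of the documented rule: reverse=True visits the
    # leading axis from the end, but each output is stored at the index of
    # the input it was computed from.
    n = len(xs)
    order = range(n - 1, -1, -1) if reverse else range(n)
    out: List[Any] = [None] * n
    for i in order:
        out[i] = f(xs[i])
    return out


def make_step():
    # ``counter`` is a plain dict standing in for a State that the loop body
    # reads and updates (an assumption made for illustration).
    counter = {"value": 0}

    def step(x):
        counter["value"] += 1
        return (x, counter["value"])

    return step


xs = [10, 20, 30]
print(run(make_step(), xs))                # [(10, 1), (20, 2), (30, 3)]
print(run(make_step(), xs, reverse=True))  # [(10, 3), (20, 2), (30, 1)]
```

The inputs and outputs line up the same way in both runs; only the step at which each input was visited, recorded by the counter, changes.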