brainstate.transform.while_loop
- brainstate.transform.while_loop(cond_fun, body_fun, init_val)
Call body_fun repeatedly in a loop while cond_fun is True.

The Haskell-like type signature in brief is
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
The semantics of while_loop are given by this Python implementation:

>>> def while_loop(cond_fun, body_fun, init_val):
...     val = init_val
...     while cond_fun(val):
...         val = body_fun(val)
...     return val
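Because the reference implementation above is plain Python, it can be run directly to check the semantics. The sketch below reuses it to compute a greatest common divisor with a tuple carry (a, b), showing that body_fun receives and returns the same structure on every iteration. The helper name gcd is hypothetical, and Euclid's algorithm is only a convenient loop whose trip count depends on the data.

```python
def while_loop(cond_fun, body_fun, init_val):
    # Pure-Python reference semantics, as given in the text above.
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def gcd(a, b):
    # Hypothetical helper: the carry is a tuple (a, b) whose structure never changes.
    def cond_fun(carry):
        _, b = carry
        return b != 0

    def body_fun(carry):
        a, b = carry
        return b, a % b

    a, _ = while_loop(cond_fun, body_fun, (a, b))
    return a


print(gcd(48, 18))  # 6
```

If cond_fun is False for init_val (for example gcd(7, 0)), body_fun never runs and the initial carry is returned unchanged.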
Unlike that Python version, while_loop is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container with a fixed structure and arrays of fixed shape and dtype at the leaves).

Another difference from Python-native loop constructs is that while_loop is not reverse-mode differentiable (for example with jax.grad): reverse mode would have to store intermediate values from every iteration, and XLA computations require static bounds on memory, which a data-dependent number of iterations cannot provide. Forward-mode differentiation (for example jax.jvp) is supported.

- Parameters:
  - cond_fun (Callable[[TypeVar(T)], Any]) – Function of type a -> Bool; it must return a boolean scalar.
  - body_fun (Callable[[TypeVar(T)], TypeVar(T)]) – Function of type a -> a.
  - init_val (TypeVar(T)) – Value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.
- Returns:
  The output from the final iteration of body_fun, of type a (or init_val itself, if cond_fun(init_val) is False and body_fun never runs).
- Return type:
  TypeVar(T)
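As a sketch of a pytree init_val with a fixed structure, the example below runs Newton's method for a square root with a dict carry. It uses jax.lax.while_loop as a stand-in, assuming (from the identical signature above) that brainstate.transform.while_loop behaves the same for pure functions that touch no brainstate State. The names MAX_STEPS and newton_sqrt are hypothetical. To keep shapes fixed, the history of estimates is written into a preallocated buffer instead of being appended to a growing list; the last few lines show that a body changing the carry's shape is rejected.

```python
import jax.numpy as jnp
from jax import lax

MAX_STEPS = 8  # hypothetical iteration cap, also the length of the history buffer


def newton_sqrt(a):
    # The carry is a dict pytree; its keys, shapes and dtypes never change.
    init_val = {
        "x": jnp.float32(a),  # current estimate, starting from a itself
        "step": jnp.int32(0),  # number of completed iterations
        "history": jnp.zeros(MAX_STEPS, dtype=jnp.float32),  # preallocated log
    }

    def cond_fun(carry):
        # cond_fun must return a boolean scalar.
        not_converged = jnp.abs(carry["x"] * carry["x"] - a) > 1e-6
        return not_converged & (carry["step"] < MAX_STEPS)

    def body_fun(carry):
        x = 0.5 * (carry["x"] + a / carry["x"])  # one Newton step for x**2 = a
        return {
            "x": x,
            "step": carry["step"] + 1,
            # Write into the existing buffer instead of appending,
            # so its shape stays (MAX_STEPS,) on every iteration.
            "history": carry["history"].at[carry["step"]].set(x),
        }

    return lax.while_loop(cond_fun, body_fun, init_val)


result = newton_sqrt(2.0)

# A body that changes the carry's shape is rejected when the loop is traced.
try:
    lax.while_loop(lambda v: v.sum() < 8.0,
                   lambda v: jnp.concatenate([v, v]),
                   jnp.ones(2))
    shape_change_allowed = True
except TypeError:
    shape_change_allowed = False
```

The step counter in the carry doubles as an iteration cap, a common way to guarantee termination when convergence is not assured.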
Examples
Basic while loop operation:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def cond_fn(val):
...     return val < 10
>>>
>>> def body_fn(val):
...     return val + 1
>>>
>>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
>>> # result will be 10
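To make the single-WhileOp lowering concrete, the counting loop from this example can be traced with jax.make_jaxpr and compared against a Python loop. This sketch uses jax.lax.while_loop as a stand-in for brainstate.transform.while_loop (an assumption, based on the matching signature); the function names are hypothetical. A Python while v < 10 cannot run on a traced value (it raises a concretization error), so the unrolled version uses a fixed range(10).

```python
import jax
from jax import lax


def count_with_while_loop(x):
    # The loop from the basic example: increment until the value reaches 10.
    return lax.while_loop(lambda v: v < 10, lambda v: v + 1, x)


def count_with_python_loop(x):
    # A Python loop over a fixed range is unrolled while tracing.
    for _ in range(10):
        x = x + 1
    return x


loop_jaxpr = jax.make_jaxpr(count_with_while_loop)(0)
unrolled_jaxpr = jax.make_jaxpr(count_with_python_loop)(0)

loop_prims = [eqn.primitive.name for eqn in loop_jaxpr.jaxpr.eqns]
unrolled_prims = [eqn.primitive.name for eqn in unrolled_jaxpr.jaxpr.eqns]

print(loop_prims)      # one "while" equation wrapping the traced cond/body
print(unrolled_prims)  # one "add" equation per unrolled iteration
```

The while_loop program stays the same size whatever the trip count, while the unrolled program grows linearly with it, which is where the compile-time savings come from.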
While loop with array state:
>>> def cond_fn(state):
...     return jnp.sum(state) < 100
>>>
>>> def body_fn(state):
...     return state * 1.1
>>>
>>> init_state = jnp.array([1.0, 2.0, 3.0])
>>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)
>>> # 30 iterations: the sum grows from 6 to about 104.7
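The differentiation limits stated above can be checked on this array-state loop. The sketch assumes jax.lax.while_loop behaves like brainstate.transform.while_loop for this stateless loop, and grow is a hypothetical helper name. Forward mode via jax.jvp pushes a tangent through every iteration, while jax.grad raises a ValueError because the number of iterations is only known at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax


def grow(state):
    # Same loop as the array-state example: scale by 1.1 until the sum reaches 100.
    return lax.while_loop(lambda s: jnp.sum(s) < 100.0, lambda s: s * 1.1, state)


init_state = jnp.array([1.0, 2.0, 3.0])

# Forward mode works: the tangent is scaled by 1.1 on each of the 30 iterations,
# so every entry of tangent_out is approximately 1.1**30 (about 17.45).
final_state, tangent_out = jax.jvp(grow, (init_state,), (jnp.ones(3),))

# Reverse mode does not: jax.grad fails because the trip count is data-dependent.
try:
    jax.grad(lambda s: jnp.sum(grow(s)))(init_state)
    reverse_mode_ok = True
except ValueError:
    reverse_mode_ok = False
```

When gradients are required and an upper bound on the iterations is known, a loop with a static trip count (such as jax.lax.scan) is the usual workaround.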