brainstate.transform.while_loop



brainstate.transform.while_loop(cond_fun, body_fun, init_val)

Call body_fun repeatedly in a loop while cond_fun is True.

The Haskell-like type signature in brief is

while_loop :: (a -> Bool) -> (a -> a) -> a -> a

The semantics of while_loop are given by this Python implementation:

>>> def while_loop(cond_fun, body_fun, init_val):
...     val = init_val
...     while cond_fun(val):
...         val = body_fun(val)
...     return val

Unlike that Python version, while_loop is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.

Also unlike the Python analogue, the loop-carried value val must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type a in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Another difference from using Python-native loop constructs is that while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements.

Parameters:
  • cond_fun (Callable[[TypeVar(T)], Any]) – Function of type a -> Bool.

  • body_fun (Callable[[TypeVar(T)], TypeVar(T)]) – Function of type a -> a.

  • init_val (TypeVar(T)) – Value of type a, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value.

Returns:

The output from the final iteration of body_fun, of type a. If cond_fun(init_val) is already False, body_fun never runs and init_val is returned unchanged.

Return type:

TypeVar(T)

Examples

Basic while loop operation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def cond_fn(val):
...     return val < 10
>>>
>>> def body_fn(val):
...     return val + 1
>>>
>>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
>>> # result will be 10

While loop with array state:

>>> def cond_fn(state):
...     return jnp.sum(state) < 100
>>>
>>> def body_fn(state):
...     return state * 1.1
>>>
>>> init_state = jnp.array([1.0, 2.0, 3.0])
>>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)
