brainstate.transform.bounded_while_loop
- brainstate.transform.bounded_while_loop(cond_fun, body_fun, init_val, *, max_steps, base=16)
While loop with a bound on the maximum number of steps.
This function is adapted from while_loop in equinox. It is useful when you want to guarantee that a while loop terminates even if the condition function never becomes false. Because it is implemented with scan operations, it is reverse-mode differentiable (unlike jax.lax.while_loop, which supports only forward-mode differentiation).
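To illustrate why a scan-based loop can be differentiated in reverse mode, here is a minimal sketch in plain JAX. It assumes a nested-scan design in the spirit of equinox's bounded loops; this is an assumption for illustration, not a description of brainstate's actual code, and the helper nested_scan_while is hypothetical. Each level runs lax.scan over base sub-blocks, and lax.cond skips the remaining work once the condition is false or max_steps iterations have run.

```python
import jax
import jax.numpy as jnp
from jax import lax


def nested_scan_while(cond_fun, body_fun, init_val, max_steps, base=16):
    """Hypothetical bounded while loop built from nested lax.scan calls.

    Applies body_fun while cond_fun(val) is True, at most max_steps times.
    """
    # Smallest depth with base ** depth >= max_steps (exact integer arithmetic).
    depth, capacity = 0, 1
    while capacity < max_steps:
        depth, capacity = depth + 1, capacity * base

    def active(carry):
        val, step = carry
        # Keep iterating only while the user condition holds and the bound is not hit.
        return jnp.logical_and(cond_fun(val), step < max_steps)

    def run(carry, level):
        def inner(c):
            if level == 0:
                # Leaf: a single loop iteration.
                val, step = c
                return body_fun(val), step + 1

            # Internal node: scan over `base` children, each one level shallower.
            def scan_fn(c_, _):
                return run(c_, level - 1), None

            c, _ = lax.scan(scan_fn, c, xs=None, length=base)
            return c

        # Skip the whole subtree (up to base ** level iterations) once inactive.
        return lax.cond(active(carry), inner, lambda c: c, carry)

    init = (jnp.asarray(init_val), jnp.asarray(0, dtype=jnp.int32))
    final_val, _ = run(init, depth)
    return final_val


cond = lambda v: v < 10.0  # stop once the value reaches 10
body = lambda v: v * 1.5   # grow by 50% per iteration

x0 = jnp.float32(1.0)
print(nested_scan_while(cond, body, x0, max_steps=20, base=4))  # 1.5**6 = 11.390625
print(nested_scan_while(cond, body, x0, max_steps=3, base=4))   # bound hit: 1.5**3 = 3.375
# Reverse-mode gradient through the loop: d(1.5**6 * x)/dx = 1.5**6
print(jax.grad(lambda x: nested_scan_while(cond, body, x, max_steps=20, base=4))(x0))
```

Because every iteration is an ordinary lax.scan / lax.cond step, jax.grad can trace through it, whereas reverse-mode differentiation of a lax.while_loop with a data-dependent trip count raises an error.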
- Parameters:
  - cond_fun (Callable[[TypeVar(T)], Any]) – A function of type T -> bool, where T is the type of the loop value; the loop continues while it returns True.
  - body_fun (Callable[[TypeVar(T)], TypeVar(T)]) – A function of type T -> T that computes the next loop value.
  - init_val (TypeVar(T)) – The initial value, of type T.
  - max_steps (int) – A bound on the maximum number of steps, after which the loop terminates unconditionally.
  - base (int) – Trades run time against compilation time. Run time increases slightly as base increases, while compilation time decreases substantially as math.ceil(math.log(max_steps, base)) decreases, which happens as base increases.
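The quantity math.ceil(math.log(max_steps, base)) from the base description can be worked out for concrete settings. The helper scan_depth below is hypothetical (not part of brainstate); it computes the same value with integer arithmetic, which avoids floating-point rounding surprises when max_steps is at or near an exact power of base. For max_steps=100, the default base=16 yields 2 while base=8 yields 3.

```python
import math


def scan_depth(max_steps: int, base: int) -> int:
    """Smallest k with base ** k >= max_steps, i.e. ceil(log_base(max_steps)).

    Hypothetical helper; assumes max_steps >= 1 and base >= 2.
    """
    depth, capacity = 0, 1
    while capacity < max_steps:
        depth += 1
        capacity *= base
    return depth


for base in (2, 8, 16):
    d100, d1000 = scan_depth(100, base), scan_depth(1000, base)
    # For these values (not exact powers of base) the float formula agrees.
    assert d100 == math.ceil(math.log(100, base))
    assert d1000 == math.ceil(math.log(1000, base))
    print(f"base={base:2d}: max_steps=100 -> {d100}, max_steps=1000 -> {d1000}")
# base= 2: max_steps=100 -> 7, max_steps=1000 -> 10
# base= 8: max_steps=100 -> 3, max_steps=1000 -> 4
# base=16: max_steps=100 -> 2, max_steps=1000 -> 3
```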
- Returns:
  The final loop value, as if computed by jax.lax.while_loop, except that body_fun is applied at most max_steps times.
- Return type:
  TypeVar(T)
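The documented return semantics can be stated as a plain-Python reference model. The function reference_bounded_while_loop below is hypothetical and is not brainstate's implementation (which traces the loop with JAX); it only mirrors the behaviour described above: iterate like a while loop, but apply body_fun at most max_steps times.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def reference_bounded_while_loop(
    cond_fun: Callable[[T], Any],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
) -> T:
    """Plain-Python model of the documented semantics (hypothetical helper)."""
    val = init_val
    for _ in range(max_steps):
        if not cond_fun(val):
            break  # condition became False before the bound was reached
        val = body_fun(val)
    return val  # either cond_fun(val) is False or max_steps steps have run


# Bound not reached: same as an unbounded while loop (1 -> 1024 in 10 doublings).
print(reference_bounded_while_loop(lambda v: v < 1000, lambda v: v * 2, 1, max_steps=10))  # 1024
# Bound reached: stops after 5 doublings even though v < 1000 still holds.
print(reference_bounded_while_loop(lambda v: v < 1000, lambda v: v * 2, 1, max_steps=5))   # 32
```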
Examples
Basic bounded while loop:
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def cond_fn(val):
...     return val < 1000  # becomes False once val reaches 1024
>>>
>>> def body_fn(val):
...     return val * 2
>>>
>>> # The loop terminates after at most 10 steps
>>> result = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, 1, max_steps=10
... )
Bounded while loop with custom base:
>>> # A smaller base slightly reduces run time but increases compilation time
>>> result = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, 1, max_steps=100, base=8
... )
Bounded while loop with array state:
>>> def cond_fn(state):
...     return jnp.max(state) < 100
>>>
>>> def body_fn(state):
...     return state + jnp.array([1.0, 2.0, 3.0])
>>>
>>> init_state = jnp.array([0.0, 0.0, 0.0])
>>> final_state = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, init_state, max_steps=50
... )
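In the array example the last entry grows by 3.0 per step, so the condition fails after 34 steps, well within max_steps=50. Under the documented "as if computed by jax.lax.while_loop" semantics, the result should therefore match an unbounded loop. The cross-check below uses only JAX (not brainstate) and assumes that documented equivalence.

```python
import jax.numpy as jnp
from jax import lax


def cond_fn(state):
    # Continue while every entry is below 100.
    return jnp.max(state) < 100


def body_fn(state):
    return state + jnp.array([1.0, 2.0, 3.0])


init_state = jnp.array([0.0, 0.0, 0.0])
# Unbounded reference loop: 34 iterations, since 3.0 * 34 = 102 >= 100.
final_state = lax.while_loop(cond_fn, body_fn, init_state)
print(final_state.tolist())  # [34.0, 68.0, 102.0]
```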