brainstate.transform.bounded_while_loop

brainstate.transform.bounded_while_loop(cond_fun, body_fun, init_val, *, max_steps, base=16)

While loop with a bound on the maximum number of steps.

This function is adapted from while_loop in equinox.

Use this function when a while loop must terminate even if cond_fun never returns False. It is implemented with scan operations rather than jax.lax.while_loop, so it supports reverse-mode differentiation, which jax.lax.while_loop does not.
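To illustrate why a scan-based loop can be differentiated in reverse mode, here is a minimal single-level sketch in JAX. It is not brainstate's implementation: the name flat_bounded_while_loop is hypothetical, it runs one lax.scan of length max_steps with a lax.cond guard on each step, and it does not model the base parameter.

```python
import jax
import jax.numpy as jnp
from jax import lax


def flat_bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical single-level sketch of a scan-based bounded while loop.

    Not brainstate's implementation: one scan of length max_steps,
    no `base` parameter.
    """

    def step(carry, _):
        # Apply body_fun only while cond_fun holds; afterwards the carry
        # is passed through unchanged for the remaining iterations.
        new_carry = lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)
        return new_carry, None

    final, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final


def doubled_until_1000(x):
    return flat_bounded_while_loop(
        lambda v: v < 1000.0, lambda v: v * 2.0, x, max_steps=10
    )


x0 = jnp.float32(1.0)
print(doubled_until_1000(x0))            # 1024.0
print(jax.grad(doubled_until_1000)(x0))  # 1024.0, since the result is x * 2**10 here
```

Because the scan has a fixed length, every call executes max_steps iterations; once the condition fails, the remaining iterations reduce to a cond that passes the carry through.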

Parameters:
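  • cond_fun (Callable[[TypeVar(T)], Any]) – A function of type T -> bool that returns a boolean scalar. The loop continues while it returns True.

  • body_fun (Callable[[TypeVar(T)], TypeVar(T)]) – A function of type T -> T. Its output must have the same structure, shapes, and dtypes as its input.

  • init_val (TypeVar(T)) – The initial value, of type T.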

  • max_steps (int) – A bound on the maximum number of steps, after which the loop terminates unconditionally.

  • base (int) – Trades run time against compilation time. Increasing base slightly increases run time but substantially decreases compilation time, because compilation time grows with math.ceil(math.log(max_steps, base)), which shrinks as base increases. Defaults to 16.
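The quantity math.ceil(math.log(max_steps, base)) from the base description can be computed directly. This small sketch (the helper name log_factor is hypothetical) tabulates it for max_steps=100, the value used with base=8 in the examples.

```python
import math


def log_factor(max_steps: int, base: int) -> int:
    # The quantity that, per the `base` description, drives compilation time.
    return math.ceil(math.log(max_steps, base))


for base in (2, 8, 16):
    print(f"base={base:2d}: {log_factor(100, base)}")
```

For max_steps=100, base=8 gives 3 against 2 for the default base=16, so per the trade-off above it should run slightly faster but compile more slowly.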

Returns:

The final value, as if computed by jax.lax.while_loop, except that body_fun is applied at most max_steps times.

Return type:

T
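The returned value can be described with a plain-Python reference loop. This is a hypothetical sketch of the semantics only (the name reference_bounded_while_loop is made up here); unlike the real function, it is not traced, compiled, or differentiable.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def reference_bounded_while_loop(
    cond_fun: Callable[[T], bool],
    body_fun: Callable[[T], T],
    init_val: T,
    max_steps: int,
) -> T:
    """Plain-Python model of the documented semantics (hypothetical helper)."""
    val = init_val
    for _ in range(max_steps):
        # Stop early, like an ordinary while loop, once the condition fails.
        if not cond_fun(val):
            break
        val = body_fun(val)
    # If the condition never failed, the step bound ends the loop here.
    return val


# The condition fails on its own after 10 doublings (1 -> 1024).
print(reference_bounded_while_loop(lambda v: v < 1000, lambda v: v * 2, 1, 20))
# The condition never fails; the bound stops the loop after 5 steps.
print(reference_bounded_while_loop(lambda v: True, lambda v: v + 1, 0, 5))
```

Applied to the largest element of the array-state example, the same logic stops after 34 steps (at 102), well within max_steps=50.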

Examples

Basic bounded while loop:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def cond_fn(val):
...     return val < 1000  # Becomes False after 10 doublings (1 -> 1024)
>>>
>>> def body_fn(val):
...     return val * 2
>>>
>>> # Loop will terminate after at most 10 steps
>>> result = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, 1, max_steps=10
... )

Bounded while loop with custom base:

>>> # A smaller base may run slightly faster but compiles more slowly
>>> result = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, 1, max_steps=100, base=8
... )

Bounded while loop with array state:

>>> def cond_fn(state):
...     return jnp.max(state) < 100
>>>
>>> def body_fn(state):
...     return state + jnp.array([1.0, 2.0, 3.0])
>>>
>>> init_state = jnp.array([0.0, 0.0, 0.0])
>>> final_state = brainstate.transform.bounded_while_loop(
...     cond_fn, body_fn, init_state, max_steps=50
... )