brainstate.transform.ProgressBar


class brainstate.transform.ProgressBar(freq=None, count=None, desc=None, **kwargs)

A progress bar for tracking the progress of a jitted for-loop computation.

It can be used with for_loop(), checkpointed_for_loop(), scan(), and checkpointed_scan(), or with any other jitted function that runs a for-loop.

The message displayed in the progress bar can be customized in one of two ways:

  1. By passing a string to the desc argument.

  2. By passing a tuple of a format string and a callable to the desc argument. The callable takes a dictionary as input and returns a dictionary, whose entries are used to fill the placeholders in the format string.

In the second case, the key "i" in the input dictionary holds the iteration number; other keys can be computed from the loop outputs and carry values.
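The formatting step can be pictured as ordinary Python string formatting. The sketch below is an assumption about the mechanism, not brainstate's code: `render_desc` is a hypothetical helper, and the input keys "i" and "y" simply mirror those used in the dynamic-description example further down.

```python
from typing import Any, Callable, Dict, Tuple, Union

FormatFn = Callable[[Dict[str, Any]], Dict[str, Any]]


def render_desc(desc: Union[str, Tuple[str, FormatFn]], data: Dict[str, Any]) -> str:
    """Hypothetical helper: builds the progress-bar message for one update."""
    if isinstance(desc, str):
        # Method 1: a plain string is shown as-is
        return desc
    # Method 2: the callable maps the loop data to the format-string fields
    fmt, fn = desc
    return fmt.format(**fn(data))


def format_desc(data: Dict[str, Any]) -> Dict[str, Any]:
    # "i" is the iteration number; "y" stands in for the loop output
    return {"i": data["i"], "loss": data["y"]}


desc = ("Iteration {i}, loss = {loss:.4f}", format_desc)
print(render_desc(desc, {"i": 10, "y": 0.25}))  # Iteration 10, loss = 0.2500
```

Any key referenced in the format string must be present in the dictionary returned by the callable; otherwise `str.format` raises a `KeyError`.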

Parameters:
  • freq (int | None) – The frequency at which to print the progress bar. If not specified, the progress bar will be printed every 5% of the total iterations.

  • count (int | None) – The number of times to print the progress bar. If not specified, the progress bar will be printed every 5% of the total iterations. Cannot be used together with freq.

  • desc (Tuple[str, Callable[[Dict], Dict]] | str | None) – A description of the progress bar. If not specified, a default message will be displayed. Can be either a string or a tuple of (format_string, format_function).

  • **kwargs – Additional keyword arguments to pass to the progress bar.
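As a rough illustration of how freq, count, and the 5% default translate into an update interval, here is a hypothetical helper `update_interval`. It encodes only the rules stated above; the integer rounding and the minimum interval of one iteration are assumptions, not brainstate's actual behavior.

```python
from typing import Optional


def update_interval(n: int, freq: Optional[int] = None, count: Optional[int] = None) -> int:
    """Hypothetical helper: number of iterations between two progress-bar updates."""
    if freq is not None and count is not None:
        raise ValueError("`freq` and `count` cannot be used together.")
    if freq is not None:
        interval = freq
    elif count is not None:
        interval = n // count  # roughly `count` updates over n iterations
    else:
        interval = n // 20  # default: every 5% of the total iterations
    return max(1, interval)  # short loops still update at least once per iteration


# 100 iterations, as in the examples below
print(update_interval(100))            # 5: default, every 5 iterations
print(update_interval(100, freq=10))   # 10: every 10 iterations
print(update_interval(100, count=20))  # 5: 20 updates in total
```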

Examples

Basic usage with default description:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def loop_fn(x):
...     return x ** 2
>>>
>>> xs = jnp.arange(100)
>>> pbar = brainstate.transform.ProgressBar()
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)

With custom description string:

>>> pbar = brainstate.transform.ProgressBar(desc="Running 100 iterations")
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)

With frequency control:

>>> # Update every 10 iterations
>>> pbar = brainstate.transform.ProgressBar(freq=10)
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
>>>
>>> # Update exactly 20 times during execution
>>> pbar = brainstate.transform.ProgressBar(count=20)
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)

With dynamic description based on loop variables:

>>> state = brainstate.State(1.0)
>>>
>>> def loop_fn(x):
...     state.value += x
...     loss = jnp.sum(x ** 2)
...     return loss
>>>
>>> def format_desc(data):
...     return {"i": data["i"], "loss": data["y"], "state": data["carry"]}
>>>
>>> pbar = brainstate.transform.ProgressBar(
...     desc=("Iteration {i}, loss = {loss:.4f}, state = {state:.2f}", format_desc)
... )
>>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)

With scan():

>>> def scan_fn(carry, x):
...     new_carry = carry + x
...     return new_carry, new_carry ** 2
>>>
>>> init_carry = 0.0
>>> pbar = brainstate.transform.ProgressBar(freq=5)
>>> final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs, pbar=pbar)

__init__(freq=None, count=None, desc=None, **kwargs)

Methods

__init__([freq, count, desc])

init(n)