brainstate.transform.checkpointed_scan


brainstate.transform.checkpointed_scan(f, init, xs, length=None, base=16, pbar=None)

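Scan a function over leading array axes while carrying along state. This function behaves like scan(), but checkpoints the loop. When the scan is differentiated in reverse mode, it recomputes intermediate values instead of storing all of them, which reduces memory use.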
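To make the general memory/recompute trade-off of checkpointing concrete, here is a minimal single-level sketch in plain Python for a scalar carry. It is not brainstate's implementation. `step`, `dstep_dc`, `grad_full_storage` and `grad_checkpointed` are hypothetical helpers. The sketch computes d(final carry)/d(init) in two ways: once by storing every carry, and once by storing only one carry per block and rematerializing the rest during the reverse pass.

```python
import math
from typing import List, Tuple


def step(c: float, x: float) -> float:
    # Hypothetical toy carry update: c_new = c * x + sin(c)
    return c * x + math.sin(c)


def dstep_dc(c: float, x: float) -> float:
    # Partial derivative of `step` with respect to the carry c.
    return x + math.cos(c)


def grad_full_storage(init: float, xs: List[float]) -> Tuple[float, int]:
    """d(final carry)/d(init), keeping every intermediate carry in memory."""
    carries = [init]
    for x in xs:
        carries.append(step(carries[-1], x))
    grad = 1.0
    for i in reversed(range(len(xs))):  # reverse pass (chain rule)
        grad *= dstep_dc(carries[i], xs[i])
    return grad, len(carries)  # gradient, number of stored carries


def grad_checkpointed(init: float, xs: List[float], block: int) -> Tuple[float, int]:
    """Same gradient, storing only the carry at the start of each block and
    recomputing the carries inside a block during the reverse pass."""
    checkpoints: List[float] = []
    c = init
    for i, x in enumerate(xs):  # forward pass: save one carry per block
        if i % block == 0:
            checkpoints.append(c)
        c = step(c, x)
    grad = 1.0
    peak = len(checkpoints)
    for b in reversed(range(len(checkpoints))):
        start, stop = b * block, min((b + 1) * block, len(xs))
        local = [checkpoints[b]]  # rematerialize this block's carries
        for i in range(start, stop - 1):
            local.append(step(local[-1], xs[i]))
        peak = max(peak, len(checkpoints) + len(local))
        for i in reversed(range(start, stop)):
            grad *= dstep_dc(local[i - start], xs[i])
    return grad, peak  # gradient, peak number of carries held at once
```

For 50 steps with `block=10`, both functions return the same gradient. The full-storage version keeps 51 carries. The checkpointed version holds at most 15 carries at once, at the cost of re-running `step` inside each block.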

Parameters:
  • f (Callable[[TypeVar(Carry), TypeVar(X)], Tuple[TypeVar(Carry), TypeVar(Y)]]) – The Python function to scan, of type c -> a -> (c, b). It takes two arguments: the current loop carry and a slice of xs along its leading axis. It returns a pair whose first element is the new loop carry and whose second element is a slice of the output.

  • init (TypeVar(Carry)) – The initial loop carry, of type c. It can be a scalar, an array, or any pytree (nested Python tuple/list/dict) thereof, and must have the same structure as the first element of the pair returned by f.

  • xs (TypeVar(X)) – The value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.

  • length (int | None) – Optional integer specifying the number of loop iterations. If xs is given, it must agree with the sizes of the leading axes of the arrays in xs. It also makes it possible to run a scan that needs no input xs, in which case it sets the number of iterations.

  • base (int) – Optional integer specifying the base of the bounded loop used for checkpointing (default 16).

  • pbar (ProgressBar | int | None) – Optional progress bar that displays the progress of the scan. Pass either a ProgressBar instance or an integer.
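The role of base is easiest to see in a plain-Python model of a bounded, nested loop. This is a sketch under an assumption: that checkpointed_scan follows a nested-checkpointing scheme like Equinox's bounded while loop. In that scheme, `length` steps run as `depth` nested loops of `base` iterations each, where `depth` is the smallest integer with `base**depth >= length`, and each inner loop is a rematerialization boundary. `nesting_depth` and `nested_bounded_scan` are hypothetical helpers. The sketch reproduces only the loop structure, not the checkpointing itself.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

C = TypeVar("C")
X = TypeVar("X")
Y = TypeVar("Y")


def nesting_depth(length: int, base: int) -> int:
    """Smallest depth d such that base**d >= length (hypothetical helper)."""
    if base < 2:
        raise ValueError("base must be at least 2")
    depth = 0
    while base ** depth < length:
        depth += 1
    return depth


def nested_bounded_scan(
    f: Callable[[C, X], Tuple[C, Y]], init: C, xs: Sequence[X], base: int = 16
) -> Tuple[C, List[Y]]:
    """Hypothetical sketch: run f over xs through nested loops of `base` steps."""
    length = len(xs)
    i = 0
    carry = init
    ys: List[Y] = []

    def run(level: int) -> None:
        nonlocal i, carry
        if level == 0:
            # One step of f. Steps past `length` are masked out, because the
            # nested loop's capacity base**depth can exceed the real length.
            if i < length:
                carry, y = f(carry, xs[i])
                ys.append(y)
                i += 1
            return
        # Under the stated assumption, each call below would be a checkpointed
        # inner loop in a JAX implementation; here it is a plain Python call.
        for _ in range(base):
            run(level - 1)

    run(nesting_depth(length, base))
    return carry, ys
```

Under this model, the default base=16 runs a 1000-step scan as 3 nested levels, because 16**2 < 1000 <= 16**3. A smaller base gives more levels with fewer iterations per level.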

Returns:

A pair of type (c, [b]). The first element is the final loop carry. The second element stacks the second outputs of f along a new leading axis, with one entry per slice of the scanned inputs.
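The (c, [b]) contract can be modelled in plain Python, ignoring checkpointing. This is a minimal sketch: `reference_scan` is a hypothetical helper, and it handles only flat Python sequences rather than general pytrees. When xs is None, it assumes f receives None as its second argument, as in jax.lax.scan.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

C = TypeVar("C")
X = TypeVar("X")
Y = TypeVar("Y")


def reference_scan(
    f: Callable[[C, Optional[X]], Tuple[C, Y]],
    init: C,
    xs: Optional[Sequence[X]],
    length: Optional[int] = None,
) -> Tuple[C, List[Y]]:
    """Plain-Python model of the scan contract (no checkpointing, no pytrees)."""
    if xs is None:
        if length is None:
            raise ValueError("length is required when xs is None")
        inputs: List[Optional[X]] = [None] * length  # assumption: f gets None
    else:
        if length is not None and length != len(xs):
            raise ValueError("length must match the leading axis of xs")
        inputs = list(xs)
    carry = init
    ys: List[Y] = []
    for x in inputs:
        carry, y = f(carry, x)  # new carry and one output slice
        ys.append(y)
    return carry, ys  # the list stands in for the stacked output array
```

With the step function from the examples, `reference_scan(lambda c, x: (c + x, c * x), 0.0, [1.0, 2.0, 3.0])` returns `(6.0, [0.0, 2.0, 9.0])`.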

Return type:

Tuple[TypeVar(Carry), TypeVar(Y)]

Examples

Basic checkpointed scan operation:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def step_fn(carry, x):
...     return carry + x, carry * x
>>>
>>> init = 0.0
>>> xs = jnp.array([1.0, 2.0, 3.0])
>>> final_carry, ys = brainstate.transform.checkpointed_scan(step_fn, init, xs)

Using custom base for checkpointing:

>>> final_carry, ys = brainstate.transform.checkpointed_scan(
...     step_fn, init, xs, base=8
... )