msgpack_load

braintools.file.msgpack_load(filename, target=None, parallel=True, mismatch='error', verbose=True)

Load a checkpoint from the given path using the msgpack library.

This function is adapted from the Flax APIs (google/flax).

Parameters:
  • filename (str) – checkpoint file or directory of checkpoints to restore from.

  • target (Any | None) – the object to restore the state into. If None, the state is returned as a dict.

  • parallel (bool) – whether to load seekable checkpoints in parallel, for speed.

  • mismatch (Literal['error', 'warn', 'ignore']) – how to handle mismatches between the target and the loaded state dict:
      ◦ 'error' (default): raise ValueError on mismatch.
      ◦ 'warn': issue a warning and skip mismatched keys.
      ◦ 'ignore': silently skip mismatched keys.

  • verbose (bool) – whether to print information while loading.
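The three mismatch modes can be illustrated with a small, self-contained sketch. This is not the braintools implementation: restore_into is a hypothetical helper that merges a nested state dict into a nested target dict. It assumes that a "mismatched key" is a key present in only one of the two dicts, and that skipping such a key leaves the target's own value in place.

```python
import warnings
from typing import Any, Dict, Literal


def restore_into(
    target: Dict[str, Any],
    state: Dict[str, Any],
    mismatch: Literal["error", "warn", "ignore"] = "error",
    path: str = "",
) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `target` updated from `state`."""
    if mismatch not in ("error", "warn", "ignore"):
        raise ValueError(f"unknown mismatch mode: {mismatch!r}")
    restored = dict(target)  # shallow copy; `target` itself is never mutated
    # Assumption: a "mismatched key" is one present in only one of the two dicts.
    for key in sorted(set(target) ^ set(state)):
        side = "state" if key in target else "target"
        msg = f"key {path + key!r} is missing from {side}"
        if mismatch == "error":
            raise ValueError(msg)
        if mismatch == "warn":
            warnings.warn(msg + "; skipping")
        # 'warn' and 'ignore' fall through: the key is skipped, so a key that
        # exists only in the target keeps its original value.
    # Keys present on both sides are restored, recursing into nested dicts.
    for key in target.keys() & state.keys():
        if isinstance(target[key], dict) and isinstance(state[key], dict):
            restored[key] = restore_into(target[key], state[key], mismatch, path + key + "/")
        else:
            restored[key] = state[key]
    return restored
```

For example, with target {"w": 0, "b": 0} and state {"w": 1, "extra": 2}, mismatch='ignore' yields {"w": 1, "b": 0}, mismatch='warn' yields the same result plus one warning per mismatched key, and mismatch='error' raises ValueError on the first mismatched key.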

Returns:

out – The target updated with the state restored from the checkpoint file. If no step is specified and no checkpoint files are present, the passed-in target is returned unchanged. Likewise, if a file path is specified but not found, the passed-in target is returned; this matches the behavior when a directory path is specified but the directory has not yet been created.

Return type:

PyTree
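The fallback behavior described under Returns can be sketched as follows. This is a hypothetical illustration, not the library's code: load_checkpoint_sketch uses JSON in place of msgpack so that it runs with the standard library alone, treats the lexicographically last file in a directory as the newest checkpoint (an assumption), and performs only a shallow, key-by-key update of the target.

```python
import json
import os
from typing import Any, Optional


def load_checkpoint_sketch(filename: str, target: Optional[dict] = None) -> Any:
    """Hypothetical stand-in for a checkpoint loader (JSON instead of msgpack)."""
    # A missing path (file not found, or directory not yet created)
    # returns the target unchanged.
    if not os.path.exists(filename):
        return target
    if os.path.isdir(filename):
        files = sorted(
            name for name in os.listdir(filename)
            if os.path.isfile(os.path.join(filename, name))
        )
        # An existing directory with no checkpoint files also returns the target unchanged.
        if not files:
            return target
        # Assumption: the lexicographically last file is the newest checkpoint.
        filename = os.path.join(filename, files[-1])
    with open(filename, "r", encoding="utf-8") as f:
        state = json.load(f)
    # With no target, return the raw state dict.
    if target is None:
        return state
    # Shallow update: only keys already present in the target are overwritten.
    return {**target, **{k: v for k, v in state.items() if k in target}}
```

Returning the target object itself for a missing path lets a caller run the same restore call before any checkpoint has been written, without special-casing the first run.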