msgpack_save

class braintools.file.msgpack_save(filename, target, overwrite=True, async_manager=None, verbose=True)

Save a checkpoint of the model using the msgpack library. Suitable for single-host use.

This function is rewritten from the Flax APIs (google/flax).

In this method, every JAX process saves the checkpoint on its own. Do not use it if you have multiple processes and you intend for them to save data to a common directory (e.g., a GCloud bucket). To save multi-process checkpoints to shared storage, or to save GlobalDeviceArrays, use multiprocess_save() instead.

The save is pre-emption safe: data is written to a temporary file before a final rename and cleanup of past files. However, if async_manager is used, the final commit happens inside an async callback, which can be waited on explicitly by calling async_manager.wait_previous_save().
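The write-to-temporary-then-rename pattern can be illustrated with a minimal sketch. This is not the braintools implementation: the helper name atomic_save is hypothetical, and pickle stands in for msgpack so that the example needs only the standard library.

```python
import os
import pickle
import tempfile


def atomic_save(filename, target, overwrite=True):
    """Hypothetical helper: write `target` to `filename` via a temp file and a rename."""
    filename = os.fspath(filename)
    if os.path.exists(filename) and not overwrite:
        raise FileExistsError(f"{filename} already exists and overwrite=False")

    directory = os.path.dirname(os.path.abspath(filename))
    os.makedirs(directory, exist_ok=True)

    # Create the temporary file in the destination directory so that the
    # final rename never has to cross a filesystem boundary.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(target, f)  # stand-in for msgpack serialization
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before committing
        # os.replace swaps the file in atomically: readers see either the
        # old checkpoint or the complete new one, never a partial write.
        os.replace(tmp_path, filename)
    except BaseException:
        # Clean up the temporary file if the save was interrupted.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return filename
```

If the process is pre-empted before os.replace runs, any existing checkpoint at filename is left untouched, which is the property the paragraph above describes.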

Parameters:
  • filename (str) – str or pathlib-like path to store checkpoint files in.

  • target (PyTree) – serializable object.

  • overwrite (bool) – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: True).

  • async_manager (AsyncManager | None) – if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly.

  • verbose (bool) – whether to print progress information (default: True).

Returns:

out – Filename of saved checkpoint.

Return type:

None
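The blocking behaviour described for async_manager (a save runs in the background, but an ongoing save blocks the next one) can be sketched with a single-worker thread pool. The class SketchAsyncManager and its save_async method are hypothetical names used for illustration. Only wait_previous_save() is named in the documentation above, and the real AsyncManager may be implemented differently.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class SketchAsyncManager:
    """Hypothetical stand-in that runs one save at a time on a worker thread."""

    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = None

    def wait_previous_save(self):
        """Block until the previous save, if any, finishes; re-raise its error."""
        if self._future is not None:
            self._future.result()
            self._future = None

    def save_async(self, task):
        """Wait for any ongoing save, then schedule `task` in the background."""
        self.wait_previous_save()
        self._future = self._executor.submit(task)


def demo():
    events = []

    def slow_save(tag):
        time.sleep(0.05)  # pretend to write a checkpoint
        events.append(tag)

    manager = SketchAsyncManager()
    manager.save_async(lambda: slow_save("step-1"))  # returns immediately
    manager.save_async(lambda: slow_save("step-2"))  # blocks until step-1 is done
    manager.wait_previous_save()  # make sure the last commit has finished
    return events
```

Calling wait_previous_save() before the program exits, or before reading the checkpoint back, ensures that the final commit has completed.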