Save and Load Checkpoints#

In this tutorial, we will explore how to save and load checkpoints in brainstate using either the orbax library or the braintools library, which provides a more lightweight approach. Checkpointing is particularly useful for saving the state of your model during training, so that you can resume training from where you left off or use the trained model for inference later. The following example demonstrates the checkpointing functionality of both orbax and braintools with a simple MLP model.

First, install the orbax library by running the following command:

pip install orbax-checkpoint

You may also install directly from GitHub, using the following command. This can be used to obtain the most recent version of Orbax.

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

You can install the braintools library by running the following command:

pip install braintools

Next, let’s import the necessary libraries.

import tempfile
import os

import jax
import jax.numpy as jnp
import orbax.checkpoint as orbax
import braintools
import brainstate 

Define the Model#

We define a simple Multi-Layer Perceptron (MLP) model using brainstate.

class MLP(brainstate.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, din: int, dmid: int, dout: int):
        """Build the two linear layers.

        Args:
            din: Size of the input features.
            dmid: Size of the hidden layer.
            dout: Size of the output features.
        """
        super().__init__()
        self.dense1 = brainstate.nn.Linear(din, dmid)
        self.dense2 = brainstate.nn.Linear(dmid, dout)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply ``dense1``, a ReLU, then ``dense2`` to ``x`` and return the result."""
        hidden = jax.nn.relu(self.dense1(x))
        return self.dense2(hidden)

Create the Model#

We create an instance of the model with a given seed for reproducibility.

# Fixed seed so that the model created below is reproducible.
SEED = 42
brainstate.random.seed(SEED)   # seed brainstate's random number generator
model1 = MLP(10, 20, 30)    # create the model: din=10, dmid=20, dout=30
model1    # bare expression: displays the model's repr when run as a notebook cell
MLP(
  dense1=Linear(
    in_size=(10,),
    out_size=(20,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[20]),
        'weight': ShapedArray(float32[10,20])
      }
    )
  ),
  dense2=Linear(
    in_size=(20,),
    out_size=(30,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[30]),
        'weight': ShapedArray(float32[20,30])
      }
    )
  )
)

Save the Model State#

Save the Model State Using orbax#

We save the model’s parameters to a checkpoint file.

# NOTE: mkdtemp() does not delete the directory automatically; remove it
# yourself (e.g. with shutil.rmtree) once the checkpoints are no longer needed.
tmpdir = tempfile.mkdtemp()    # create temporary directory

# Helper function to convert State objects to plain dictionaries for orbax
def to_plain_dict(obj):
    """Recursively replace value-holding wrappers with the plain data they hold.

    Dictionaries are rebuilt with every value converted. Any other object
    that exposes a readable ``value`` attribute is replaced by the converted
    contents of that attribute. Everything else is treated as a leaf
    (array, number, etc.) and returned unchanged.

    Args:
        obj: A (possibly nested) dict, an object exposing ``value``, or a leaf.

    Returns:
        The converted structure: new dicts for dict inputs, the unwrapped
        contents for ``value`` holders, and the same object for leaves.
    """
    if isinstance(obj, dict):
        return {key: to_plain_dict(item) for key, item in obj.items()}
    # EAFP: read the attribute directly rather than scanning dir(obj), which
    # is intended for interactive use and builds the full attribute list.
    # Only the attribute read is guarded, so an error raised while converting
    # the nested contents is not silently swallowed.
    try:
        inner = obj.value
    except (TypeError, AttributeError):
        # No usable 'value' attribute: treat the object as a leaf.
        return obj
    return to_plain_dict(inner)

# Save using orbax - convert to plain dict for compatibility
state_nest = brainstate.graph.states(model1).to_nest()   # collect model1's states as a nested structure
state_plain = to_plain_dict(state_nest)   # unwrap the State objects into the plain values they hold
checkpointer = orbax.PyTreeCheckpointer()   # create checkpointer
checkpointer.save(os.path.join(tmpdir, 'state'), state_plain)    # save state under <tmpdir>/state

Now, we’ve saved the model’s parameters to the checkpoint files in tmpdir/state by using the orbax library.

Save the Model State Using braintools#

# The nested states are passed to braintools as-is; unlike the orbax path,
# no conversion to a plain dict is performed here.
checkpoint = brainstate.graph.states(model1).to_nest()   # convert model to nest
braintools.file.msgpack_save(os.path.join(tmpdir, 'state.msgpack'), checkpoint)    # save checkpoint
Saving checkpoint into C:\Users\Administrator\AppData\Local\Temp\tmpnjecqtgi\state.msgpack

Now, we’ve saved the model’s parameters to the checkpoint file tmpdir/state.msgpack by using the braintools library.

Load the Model State#

Load the Model State Using orbax#

Let’s load the model’s parameters from the checkpoint files.

# create a new model with the same structure
# Seeded with 0 rather than SEED (42) used for model1 -- presumably so that
# model2 starts from different parameters before the checkpoint is applied.
brainstate.random.seed(0)
model2 = MLP(10, 20, 30)

# Load the parameters from checkpoint files using orbax
checkpointer = orbax.PyTreeCheckpointer()
restored_state = checkpointer.restore(os.path.join(tmpdir, 'state'))   # same path that was passed to save()

# Helper function to update model states from loaded dictionary
def update_from_dict(model_dict, loaded_dict):
    """Copy loaded values into the value-holding entries of ``model_dict``.

    Walks ``model_dict`` key by key. When both the model entry and the loaded
    entry are dicts, the function recurses into them. Otherwise, if the model
    entry has a ``value`` attribute, the loaded entry for the same key is
    assigned to it. Entries matching neither case are left untouched.

    Args:
        model_dict: Nested dict whose leaves may expose a ``value`` attribute.
        loaded_dict: Nested dict holding the values to assign.

    Raises:
        KeyError: If a value-holding entry has no matching key in
            ``loaded_dict``.
    """
    for name in model_dict:
        target = model_dict[name]
        both_nested = isinstance(target, dict) and isinstance(loaded_dict.get(name), dict)
        if both_nested:
            update_from_dict(target, loaded_dict[name])
            continue
        if not hasattr(target, 'value'):
            continue
        target.value = loaded_dict[name]

# Update the model with the loaded state
# NOTE(review): this relies on to_nest() returning model2's own State objects
# (not copies), so that assigning `.value` changes model2 itself -- confirm.
model2_states = brainstate.graph.states(model2).to_nest()
update_from_dict(model2_states, restored_state)   # assigns each restored entry to the matching state's `.value`

Load the Model State Using braintools#

Let’s load the model’s parameters from the checkpoint files.

# Create a model with the same structure.
brainstate.random.seed(0)
model3 = MLP(10, 20, 30)
checkpoint = brainstate.graph.states(model3).to_nest()   # model3's states, passed below as the load target

# Read the model parameters from the msgpack file
# NOTE(review): the return value is not assigned, so this relies on
# msgpack_load updating the states in `checkpoint` in place -- confirm.
braintools.file.msgpack_load(os.path.join(tmpdir, 'state.msgpack'), checkpoint)
Loading checkpoint from C:\Users\Administrator\AppData\Local\Temp\tmpnjecqtgi\state.msgpack
{'dense1': {'weight': ParamState(
    value={
      'bias': ShapedArray(float32[20]),
      'weight': ShapedArray(float32[10,20])
    }
  )},
 'dense2': {'weight': ParamState(
    value={
      'bias': ShapedArray(float32[30]),
      'weight': ShapedArray(float32[20,30])
    }
  )}}

Demonstrate the Loaded Model#

Let’s run the two loaded models and check that they produce the same output as the original model.

# Run the original model and both restored models on an all-ones batch of shape (1, 10).
y1, y2, y3 = (net(jnp.ones((1, 10))) for net in (model1, model2, model3))
# Each restored model should reproduce the original model's output.
print(jnp.allclose(y1, y2))    # True
print(jnp.allclose(y1, y3))    # True
True
True