{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dbdef1f7bce3a135",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Save and Load Checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43f961f5",
   "metadata": {},
   "source": [
    "In this tutorial, we show how to save and load checkpoints of `brainstate` models with two libraries: `orbax` and the more lightweight `braintools`. Checkpointing is particularly useful during training: you can save the state of your model, resume training later from where you left off, or reuse the trained model for inference. The examples below demonstrate both `orbax` and `braintools` checkpointing on a simple multi-layer perceptron (MLP)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "343e09cf",
   "metadata": {},
   "source": [
    "Install the `orbax` library with:\n",
    "\n",
    "`pip install orbax-checkpoint`\n",
    "\n",
    "To obtain the most recent version of Orbax, you can also install it directly from GitHub:\n",
    "\n",
    "`pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'`\n",
    "\n",
    "Install the `braintools` library with:\n",
    "\n",
    "`pip install braintools`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee756112",
   "metadata": {},
   "source": [
    "Next, import the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b7741c32",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:19.550147Z",
     "start_time": "2025-10-11T10:05:17.642468Z"
    }
   },
   "outputs": [],
   "source": [
    "import tempfile\n",
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import orbax.checkpoint as orbax\n",
    "import braintools\n",
    "import brainstate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6eb2d76",
   "metadata": {},
   "source": [
    "## Define the Model\n",
    "We define a simple Multi-Layer Perceptron (MLP) model using `brainstate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7e020098",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:19.557867Z",
     "start_time": "2025-10-11T10:05:19.550147Z"
    }
   },
   "outputs": [],
   "source": [
    "class MLP(brainstate.nn.Module):\n",
    "    def __init__(self, din: int, dmid: int, dout: int):\n",
    "        super().__init__()\n",
    "        self.dense1 = brainstate.nn.Linear(din, dmid)\n",
    "        self.dense2 = brainstate.nn.Linear(dmid, dout)\n",
    "\n",
    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
    "        x = self.dense1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.dense2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bf157c",
   "metadata": {},
   "source": [
    "## Create the Model\n",
    "We fix the random seed for reproducibility and then create an instance of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "39619169",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:20.454441Z",
     "start_time": "2025-10-11T10:05:19.562541Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  dense1=Linear(\n",
       "    in_size=(10,),\n",
       "    out_size=(20,),\n",
       "    w_mask=None,\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[20]),\n",
       "        'weight': ShapedArray(float32[10,20])\n",
       "      }\n",
       "    )\n",
       "  ),\n",
       "  dense2=Linear(\n",
       "    in_size=(20,),\n",
       "    out_size=(30,),\n",
       "    w_mask=None,\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[30]),\n",
       "        'weight': ShapedArray(float32[20,30])\n",
       "      }\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SEED = 42\n",
    "brainstate.random.seed(SEED)   # set seed in brainstate\n",
    "model1 = MLP(10, 20, 30)    # create model\n",
    "model1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ded981",
   "metadata": {},
   "source": [
    "## Save the Model State\n",
    "\n",
    "### Save the Model State Using `orbax`\n",
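    "For compatibility with `orbax`, we first convert the model's `State` objects into a nested dictionary of plain values (arrays) and then save that dictionary to a checkpoint directory."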
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "14d1d552",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:22.597376Z",
     "start_time": "2025-10-11T10:05:22.428843Z"
    }
   },
   "outputs": [],
   "source": [
    "tmpdir = tempfile.mkdtemp()    # create temporary directory\n",
    "\n",
    "# Helper function to convert State objects to plain dictionaries for orbax\n",
    "def to_plain_dict(obj):\n",
    "    # Check if it's a dict-like object first\n",
    "    if isinstance(obj, dict):\n",
    "        return {k: to_plain_dict(v) for k, v in obj.items()}\n",
    "    # Try to access 'value' attribute safely\n",
    "    try:\n",
    "        if 'value' in dir(obj):\n",
    "            return to_plain_dict(obj.value)\n",
    "    except (TypeError, AttributeError):\n",
    "        pass\n",
    "    # Return as-is if it's a leaf value (array, number, etc.)\n",
    "    return obj\n",
    "\n",
    "# Save using orbax - convert to plain dict for compatibility\n",
    "state_nest = brainstate.graph.states(model1).to_nest()\n",
    "state_plain = to_plain_dict(state_nest)\n",
    "checkpointer = orbax.PyTreeCheckpointer()   # create checkpointer\n",
    "checkpointer.save(os.path.join(tmpdir, 'state'), state_plain)    # save state"
   ]
  },
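  {
   "cell_type": "markdown",
   "id": "5f2c9a1e",
   "metadata": {},
   "source": [
    "As the comment in the cell above notes, the `State` objects are unwrapped into plain dictionaries for compatibility with `orbax`. The following self-contained sketch illustrates the idea on a toy tree. It uses Python's `json` module only as a stand-in serializer (this is not how `orbax` works internally), and `ToyState` and `unwrap` are hypothetical names: `ToyState` merely mimics the `.value` attribute of a real `brainstate` state, and `unwrap` follows the same logic as `to_plain_dict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b3d7e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for a brainstate State: it only holds a `.value`.\n",
    "class ToyState:\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "\n",
    "# Recursively replace every object that has a `.value` attribute by that value\n",
    "# (the same idea as `to_plain_dict` above).\n",
    "def unwrap(tree):\n",
    "    if isinstance(tree, dict):\n",
    "        return {k: unwrap(v) for k, v in tree.items()}\n",
    "    if hasattr(tree, 'value'):\n",
    "        return unwrap(tree.value)\n",
    "    return tree\n",
    "\n",
    "\n",
    "nest = {'dense1': {'weight': ToyState({'bias': [0.0, 0.0], 'weight': [[1.0, 2.0]]})}}\n",
    "\n",
    "# The wrapper object is opaque to a generic serializer such as json ...\n",
    "try:\n",
    "    json.dumps(nest)\n",
    "except TypeError as err:\n",
    "    print('cannot serialize the wrapped tree:', err)\n",
    "\n",
    "# ... but after unwrapping, only dicts and lists remain.\n",
    "plain = unwrap(nest)\n",
    "print(json.dumps(plain))"
   ]
  },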
  {
   "cell_type": "markdown",
   "id": "27209868",
   "metadata": {},
   "source": [
    "The model's parameters are now saved by `orbax` in the checkpoint directory `tmpdir/state`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb36ffc3",
   "metadata": {},
   "source": [
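    "### Save the Model State Using `braintools`\n",
    "`braintools.file.msgpack_save` accepts the nested dictionary of `State` objects directly, so no conversion helper is needed. It writes the checkpoint to a single msgpack file."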
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2b03606b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:22.608235Z",
     "start_time": "2025-10-11T10:05:22.602453Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into C:\\Users\\Administrator\\AppData\\Local\\Temp\\tmpnjecqtgi\\state.msgpack\n"
     ]
    }
   ],
   "source": [
    "checkpoint = brainstate.graph.states(model1).to_nest()   # collect the model's states as a nested dict\n",
    "braintools.file.msgpack_save(os.path.join(tmpdir, 'state.msgpack'), checkpoint)    # save checkpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76030ac1",
   "metadata": {},
   "source": [
    "The model's parameters are now saved by `braintools` in the file `tmpdir/state.msgpack`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6faf01ec",
   "metadata": {},
   "source": [
    "## Load the Model State\n",
    "\n",
    "### Load the Model State Using `orbax`\n",
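    "To load the `orbax` checkpoint, we create a new model with the same structure (initialized with a different seed, so its parameters start out different), restore the saved tree, and copy each restored value into the corresponding `State` of the new model."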
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "26ba3c3e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:22.689503Z",
     "start_time": "2025-10-11T10:05:22.624261Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create a new model with the same structure but different initial parameters\n",
    "brainstate.random.seed(0)\n",
    "model2 = MLP(10, 20, 30)\n",
    "\n",
    "# Load the parameters from checkpoint files using orbax\n",
    "checkpointer = orbax.PyTreeCheckpointer()\n",
    "restored_state = checkpointer.restore(os.path.join(tmpdir, 'state'))\n",
    "\n",
    "# Helper function to update model states from loaded dictionary\n",
    "def update_from_dict(model_dict, loaded_dict):\n",
    "    for key in model_dict:\n",
    "        if isinstance(model_dict[key], dict) and isinstance(loaded_dict.get(key), dict):\n",
    "            update_from_dict(model_dict[key], loaded_dict[key])\n",
    "        elif hasattr(model_dict[key], 'value'):\n",
    "            model_dict[key].value = loaded_dict[key]\n",
    "\n",
    "# Update the model with the loaded state\n",
    "model2_states = brainstate.graph.states(model2).to_nest()\n",
    "update_from_dict(model2_states, restored_state)"
   ]
  },
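  {
   "cell_type": "markdown",
   "id": "a91c6f0d",
   "metadata": {},
   "source": [
    "Note that `update_from_dict` silently skips model entries that have no matching subtree in the checkpoint, and it ignores extra entries in the checkpoint. If you prefer loading to fail loudly on an architecture mismatch, a stricter variant could look like the sketch below. `strict_update` and `ToyState` are hypothetical names used only for illustration: `ToyState` mimics a `brainstate` state by holding a `.value` attribute, and with a real model you would pass the `to_nest()` output and the restored tree instead of the toy dictionaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d47e2b85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-in for a brainstate State: it only holds a `.value`.\n",
    "class ToyState:\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "\n",
    "# Copy values from `loaded` into the state leaves of `target`, raising\n",
    "# ValueError if the two trees do not have the same keys at every level.\n",
    "def strict_update(target, loaded, path=()):\n",
    "    where = '/'.join(map(str, path)) or '<root>'\n",
    "    if isinstance(target, dict):\n",
    "        if not isinstance(loaded, dict) or set(target) != set(loaded):\n",
    "            raise ValueError(f'structure mismatch at {where}')\n",
    "        for key in target:\n",
    "            strict_update(target[key], loaded[key], path + (key,))\n",
    "    elif hasattr(target, 'value'):\n",
    "        target.value = loaded  # write the restored value into the state\n",
    "    else:\n",
    "        raise TypeError(f'unexpected leaf at {where}: {target!r}')\n",
    "\n",
    "\n",
    "states = {'dense1': {'weight': ToyState({'bias': [0.0]})}}\n",
    "strict_update(states, {'dense1': {'weight': {'bias': [1.0]}}})\n",
    "print(states['dense1']['weight'].value)\n",
    "\n",
    "# A checkpoint with a different layout is rejected instead of silently skipped.\n",
    "try:\n",
    "    strict_update(states, {'dense2': {'weight': {'bias': [1.0]}}})\n",
    "except ValueError as err:\n",
    "    print('rejected:', err)"
   ]
  },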
  {
   "cell_type": "markdown",
   "id": "79929f4a",
   "metadata": {},
   "source": [
    "### Load the Model State Using `braintools`\n",
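    "`braintools.file.msgpack_load` takes the file path and a target nest of `State` objects with the same structure, and restores the saved values into those states, so no manual update helper is needed."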
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7a6d1de0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:22.776907Z",
     "start_time": "2025-10-11T10:05:22.693957Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint from C:\\Users\\Administrator\\AppData\\Local\\Temp\\tmpnjecqtgi\\state.msgpack\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'dense1': {'weight': ParamState(\n",
       "    value={\n",
       "      'bias': ShapedArray(float32[20]),\n",
       "      'weight': ShapedArray(float32[10,20])\n",
       "    }\n",
       "  )},\n",
       " 'dense2': {'weight': ParamState(\n",
       "    value={\n",
       "      'bias': ShapedArray(float32[30]),\n",
       "      'weight': ShapedArray(float32[20,30])\n",
       "    }\n",
       "  )}}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create a model with the same structure but different initial parameters\n",
    "brainstate.random.seed(0)\n",
    "model3 = MLP(10, 20, 30)\n",
    "checkpoint = brainstate.graph.states(model3).to_nest()\n",
    "\n",
    "# Restore the saved parameters from the msgpack file into model3's states\n",
    "braintools.file.msgpack_load(os.path.join(tmpdir, 'state.msgpack'), checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29dc37c9",
   "metadata": {},
   "source": [
    "## Demonstrate the Loaded Model\n",
    "Let's run all three models on the same input and check that both restored models produce the same output as the original one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "dfe032ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:05:23.101307Z",
     "start_time": "2025-10-11T10:05:22.781916Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "x = jnp.ones((1, 10))          # the same input for all models\n",
    "y1 = model1(x)\n",
    "y2 = model2(x)\n",
    "y3 = model3(x)\n",
    "print(jnp.allclose(y1, y2))    # True\n",
    "print(jnp.allclose(y1, y3))    # True"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
