{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f0b8c70a",
   "metadata": {},
   "source": [
    "# Save and Load Model States with `msgpack` Checkpointing\n",
    "\n",
    "This tutorial demonstrates how to use BrainTools' checkpointing system to save and restore model states using the msgpack format. The checkpointing system provides efficient serialization of complex neural network states, including weights, biases, and custom objects.\n",
    "\n",
    "## Core Functions\n",
    "\n",
    "- `msgpack_save(filename, data)` - Save data to msgpack file\n",
    "- `msgpack_load(filename, target=None)` - Load data from msgpack file\n",
    "- `msgpack_register_serialization(cls, to_dict, from_dict)` - Register custom serialization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b8c2b6e",
   "metadata": {},
   "source": [
    "## Basic Checkpointing\n",
    "\n",
    "### Simple Data Structures\n",
    "\n",
    "The most basic use case involves saving and loading simple data structures:"
   ]
  },
  {
   "cell_type": "code",
   "id": "1d6e1f0b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:31.158769Z",
     "start_time": "2025-09-24T03:21:31.107634Z"
    }
   },
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import braintools\n",
    "\n",
    "# Create some model data\n",
    "model_data = {\n",
    "    'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),\n",
    "    'bias': jnp.array([0.1, 0.2]),\n",
    "    'config': {\n",
    "        'learning_rate': 0.001,\n",
    "        'batch_size': 32\n",
    "    }\n",
    "}\n",
    "\n",
    "# Save to checkpoint\n",
    "checkpoint_path = \"checkpoints/model_checkpoint.msgpack\"\n",
    "braintools.file.msgpack_save(checkpoint_path, model_data)\n",
    "print(f\"Model saved to {checkpoint_path}\")\n",
    "\n",
    "# Load from checkpoint\n",
    "loaded_data = braintools.file.msgpack_load(checkpoint_path)\n",
    "print(\"Model loaded successfully!\")\n",
    "print(f\"Weights shape: {loaded_data['weights'].shape}\")\n",
    "print(f\"Config: {loaded_data['config']}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into model_checkpoint.msgpack\n",
      "Model saved to model_checkpoint.msgpack\n",
      "Loading checkpoint from model_checkpoint.msgpack\n",
      "Model loaded successfully!\n",
      "Weights shape: (2, 2)\n",
      "Config: {'learning_rate': 0.001, 'batch_size': 32}\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "cell_type": "markdown",
   "id": "7a5c4f8d",
   "metadata": {},
   "source": [
    "### Working with Templates\n",
    "\n",
    "For structured restoration, you can provide a template that defines the expected structure:"
   ]
  },
  {
   "cell_type": "code",
   "id": "2e4f9c0f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:31.242413Z",
     "start_time": "2025-09-24T03:21:31.195360Z"
    }
   },
   "source": [
    "# Create a template with the expected structure\n",
    "template = {\n",
    "    'weights': jnp.zeros((2, 2)),  # Shape and dtype information\n",
    "    'bias': jnp.zeros(2),\n",
    "    'config': {'learning_rate': 0.0, 'batch_size': 0}\n",
    "}\n",
    "\n",
    "# Load with template to ensure type safety\n",
    "loaded_data = braintools.file.msgpack_load(checkpoint_path, target=template)\n",
    "print(\"Loaded with template validation\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint from model_checkpoint.msgpack\n",
      "Loaded with template validation\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "id": "c9b5d8e7",
   "metadata": {},
   "source": [
    "## Working with BrainState Objects\n",
    "\n",
    "BrainTools provides special support for BrainState objects, which are commonly used in neural network implementations:"
   ]
  },
  {
   "cell_type": "code",
   "id": "4a1e2d3f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:31.310362Z",
     "start_time": "2025-09-24T03:21:31.269020Z"
    }
   },
   "source": [
    "import brainstate\n",
    "\n",
    "\n",
    "# Create BrainState objects\n",
    "class SimpleModel(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        self.weight = brainstate.ParamState(jnp.array([[1.0, 2.0], [3.0, 4.0]]))\n",
    "        self.bias = brainstate.ParamState(jnp.array([0.1, 0.2]))\n",
    "        self.running_mean = brainstate.State(jnp.array([0.0, 0.0]))\n",
    "\n",
    "\n",
    "# Initialize model\n",
    "model = SimpleModel()\n",
    "\n",
    "# Create checkpoint data\n",
    "checkpoint_data = {\n",
    "    'model_state': model.states(),\n",
    "    'training_step': 1000,\n",
    "    'epoch': 10\n",
    "}\n",
    "\n",
    "# Save checkpoint\n",
    "braintools.file.msgpack_save(\"checkpoints/model_state.msgpack\", checkpoint_data)\n",
    "\n",
    "# Create new model instance for loading\n",
    "new_model = SimpleModel()\n",
    "template = {\n",
    "    'model_state': new_model.states(),\n",
    "    'training_step': 0,\n",
    "    'epoch': 0\n",
    "}\n",
    "\n",
    "# Load and restore state\n",
    "restored_data = braintools.file.msgpack_load(\"checkpoints/model_state.msgpack\", target=template)\n",
    "print(f\"Restored to epoch {restored_data['epoch']}, step {restored_data['training_step']}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into model_state.msgpack\n",
      "Loading checkpoint from model_state.msgpack\n",
      "Restored to epoch 10, step 1000\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "cell_type": "markdown",
   "id": "8e5f6c9a",
   "metadata": {},
   "source": [
    "## Custom Object Serialization\n",
    "\n",
    "For custom objects, you can register serialization handlers:"
   ]
  },
  {
   "cell_type": "code",
   "id": "5b7d8e9f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:31.363373Z",
     "start_time": "2025-09-24T03:21:31.323906Z"
    }
   },
   "source": [
    "from typing import Dict, Any\n",
    "\n",
    "\n",
    "# Define a custom class\n",
    "class CustomLayer:\n",
    "    def __init__(self, input_size: int, output_size: int):\n",
    "        self.input_size = input_size\n",
    "        self.output_size = output_size\n",
    "        self.weights = brainstate.random.randn(input_size, output_size)\n",
    "        self.bias = jnp.zeros(output_size)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return jnp.dot(x, self.weights) + self.bias\n",
    "\n",
    "\n",
    "# Register serialization for CustomLayer\n",
    "def layer_to_state_dict(layer: CustomLayer) -> Dict[str, Any]:\n",
    "    return {\n",
    "        'input_size': layer.input_size,\n",
    "        'output_size': layer.output_size,\n",
    "        'weights': layer.weights,\n",
    "        'bias': layer.bias\n",
    "    }\n",
    "\n",
    "\n",
    "def layer_from_state_dict(\n",
    "    layer: CustomLayer,\n",
    "    state_dict: Dict[str, Any],\n",
    "    mismatch: str = 'error'\n",
    ") -> CustomLayer:\n",
    "    # Create new layer with restored parameters\n",
    "    new_layer = CustomLayer(state_dict['input_size'], state_dict['output_size'])\n",
    "    new_layer.weights = state_dict['weights']\n",
    "    new_layer.bias = state_dict['bias']\n",
    "    return new_layer\n",
    "\n",
    "\n",
    "# Register the serialization\n",
    "braintools.file.msgpack_register_serialization(\n",
    "    CustomLayer,\n",
    "    layer_to_state_dict,\n",
    "    layer_from_state_dict\n",
    ")\n",
    "\n",
    "# Now you can save and load CustomLayer objects\n",
    "layer = CustomLayer(10, 5)\n",
    "data = {'my_layer': layer, 'metadata': {'version': '1.0'}}\n",
    "\n",
    "braintools.file.msgpack_save(\"checkpoints/custom_layer.msgpack\", data)\n",
    "\n",
    "# Load with template\n",
    "template = {'my_layer': CustomLayer(10, 5), 'metadata': {'version': ''}}\n",
    "loaded = braintools.file.msgpack_load(\"checkpoints/custom_layer.msgpack\", target=template)\n",
    "print(f\"Loaded layer with shape: {loaded['my_layer'].weights.shape}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into custom_layer.msgpack\n",
      "Loading checkpoint from custom_layer.msgpack\n",
      "Loaded layer with shape: (10, 5)\n"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "id": "9a6e8f0b",
   "metadata": {},
   "source": [
    "## Mismatch Handling\n",
    "\n",
    "The checkpointing system provides flexible mismatch handling for cases where the saved state doesn't exactly match the target structure:"
   ]
  },
  {
   "cell_type": "code",
   "id": "3c4e5f6a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:31.474285Z",
     "start_time": "2025-09-24T03:21:31.383275Z"
    }
   },
   "source": [
    "# Create model with different structure\n",
    "original_model = {\n",
    "    'layer1': {'weights': jnp.ones((5, 3)), 'bias': jnp.zeros(3)},\n",
    "    'layer2': {'weights': jnp.ones((3, 2)), 'bias': jnp.zeros(2)},\n",
    "    'config': {'lr': 0.01}\n",
    "}\n",
    "\n",
    "braintools.file.msgpack_save(\"checkpoints/original.msgpack\", original_model)\n",
    "\n",
    "# New model with additional components\n",
    "new_model = {\n",
    "    'layer1': {'weights': jnp.zeros((5, 3)), 'bias': jnp.zeros(3)},\n",
    "    'layer2': {'weights': jnp.zeros((3, 2)), 'bias': jnp.zeros(2)},\n",
    "    'layer3': {'weights': jnp.zeros((2, 1)), 'bias': jnp.zeros(1)},  # New layer\n",
    "    'config': {'lr': 0.001, 'momentum': 0.9}  # New parameter\n",
    "}\n",
    "\n",
    "# Different mismatch handling strategies:\n",
    "\n",
    "# 1. Error on mismatch (default)\n",
    "try:\n",
    "    loaded = braintools.file.msgpack_load(\"checkpoints/original.msgpack\", target=new_model, mismatch='error')\n",
    "except ValueError as e:\n",
    "    print(f\"Error mode caught mismatch: {e}\")\n",
    "\n",
    "# 2. Warn on mismatch but continue\n",
    "import warnings\n",
    "\n",
    "with warnings.catch_warnings(record=True) as w:\n",
    "    warnings.simplefilter(\"always\")\n",
    "    loaded = braintools.file.msgpack_load(\"checkpoints/original.msgpack\", target=new_model, mismatch='warn')\n",
    "    if w:\n",
    "        print(f\"Warning: {w[0].message}\")\n",
    "    print(\"Loaded with warnings - missing components kept from target\")\n",
    "    print(f\"layer3 weights preserved: {jnp.allclose(loaded['layer3']['weights'], new_model['layer3']['weights'])}\")\n",
    "\n",
    "# 3. Ignore mismatches silently\n",
    "loaded = braintools.file.msgpack_load(\"checkpoints/original.msgpack\", target=new_model, mismatch='ignore')\n",
    "print(\"Loaded silently - new components preserved from target\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into original.msgpack\n",
      "Loading checkpoint from original.msgpack\n",
      "Error mode caught mismatch: The target dict keys and state dict keys do not match, target dict contains keys {'layer3'} which are not present in state dict at path .\n",
      "Loading checkpoint from original.msgpack\n",
      "Warning: The target dict keys and state dict keys do not match, target dict contains keys {'layer3'} which are not present in state dict at path .\n",
      "Loaded with warnings - missing components kept from target\n",
      "layer3 weights preserved: True\n",
      "Loading checkpoint from original.msgpack\n",
      "Loaded silently - new components preserved from target\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "cell_type": "markdown",
   "id": "a5d7e8c9",
   "metadata": {},
   "source": [
    "## Advanced Usage\n",
    "\n",
    "### Async Checkpointing\n",
    "\n",
    "For large models, you can use asynchronous checkpointing to avoid blocking training:"
   ]
  },
  {
   "cell_type": "code",
   "id": "b8f9c0ad",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:21:32.463061Z",
     "start_time": "2025-09-24T03:21:31.489383Z"
    }
   },
   "source": [
    "# Create async manager\n",
    "async_manager = braintools.file.AsyncManager(max_workers=2)\n",
    "\n",
    "# Large model data\n",
    "large_model = {\n",
    "    'embeddings': brainstate.random.randn(100000, 512),\n",
    "    'weights': [brainstate.random.normal(512, 512) for i in range(10)],\n",
    "    'step': 5000\n",
    "}\n",
    "\n",
    "# Save asynchronously\n",
    "braintools.file.msgpack_save(\"checkpoints/large_model.msgpack\", large_model, async_manager=async_manager)\n",
    "print(\"Checkpoint initiated asynchronously\")\n",
    "\n",
    "# Continue with other work...\n",
    "print(\"Doing other work while checkpoint saves...\")\n",
    "\n",
    "# Wait for completion when needed\n",
    "async_manager.wait_previous_save()\n",
    "print(\"Checkpoint completed!\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into large_model.msgpack\n",
      "Checkpoint initiated asynchronously\n",
      "Doing other work while checkpoint saves...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\codes\\projects\\braintools\\braintools\\file\\msg_checkpoint.py:650: UserWarning: The previous async brainpy.checkpoints.save has not finished yet. Waiting for it to complete before the next save.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint completed!\n"
     ]
    }
   ],
   "execution_count": 13
  },
  {
   "cell_type": "markdown",
   "id": "c7d8e9fa",
   "metadata": {},
   "source": [
    "### Working with BrainUnit Quantities\n",
    "\n",
    "BrainTools automatically handles BrainUnit quantities:"
   ]
  },
  {
   "cell_type": "code",
   "id": "e8b9c0db",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:22:29.462305Z",
     "start_time": "2025-09-24T03:22:29.429829Z"
    }
   },
   "source": [
    "import brainunit as u\n",
    "\n",
    "# Model with physical units\n",
    "physics_model = {\n",
    "    'time_constant': 10.0 * u.ms,\n",
    "    'resistance': 100.0 * u.mohm,\n",
    "    'capacitance': 200.0 * u.pF,\n",
    "    'voltage_threshold': -50.0 * u.mV\n",
    "}\n",
    "loaded_physics = physics_model.copy()\n",
    "\n",
    "braintools.file.msgpack_save(\"checkpoints/physics_model.msgpack\", physics_model)\n",
    "loaded_physics = braintools.file.msgpack_load(\"checkpoints/physics_model.msgpack\", target=loaded_physics)\n",
    "\n",
    "print(f\"Time constant: {loaded_physics['time_constant']}\")\n",
    "print(f\"Units preserved: {loaded_physics['time_constant'].unit}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into physics_model.msgpack\n",
      "Loading checkpoint from physics_model.msgpack\n",
      "Time constant: 10.0 * second\n",
      "Units preserved: 10.0^-3 * s\n"
     ]
    }
   ],
   "execution_count": 16
  },
  {
   "cell_type": "markdown",
   "id": "f8c9d0eb",
   "metadata": {},
   "source": [
    "### Handling Complex Nested Structures"
   ]
  },
  {
   "cell_type": "code",
   "id": "a9f8c0df",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:25:07.997570Z",
     "start_time": "2025-09-24T03:25:06.711177Z"
    }
   },
   "source": [
    "# Complex nested model\n",
    "complex_model = {\n",
    "    'encoder': {\n",
    "        'layers': [\n",
    "            {'weights': brainstate.random.randn(100, 64), 'bias': jnp.zeros(64)},\n",
    "            {'weights': brainstate.random.randn(4, 32), 'bias': jnp.zeros(32)}\n",
    "        ],\n",
    "        'config': {'activation': 'relu', 'dropout': 0.1}\n",
    "    },\n",
    "    'decoder': {\n",
    "        'layers': [\n",
    "            {'weights': brainstate.random.randn(32, 64), 'bias': jnp.zeros(64)},\n",
    "            {'weights': brainstate.random.randn(64, 100), 'bias': jnp.zeros(100)}\n",
    "        ],\n",
    "        'config': {'activation': 'sigmoid'}\n",
    "    },\n",
    "    'optimizer_state': {\n",
    "        'momentum': [jnp.zeros_like(w) for w in [\n",
    "            brainstate.random.randn(100, 64),\n",
    "            brainstate.random.randn(64, 32),\n",
    "            brainstate.random.randn(32, 64),\n",
    "            brainstate.random.randn(64, 100)\n",
    "        ]],\n",
    "        'learning_rate': 0.001,\n",
    "        'step': 1000\n",
    "    }\n",
    "}\n",
    "\n",
    "# Save complex model\n",
    "braintools.file.msgpack_save(\"checkpoints/complex_model.msgpack\", complex_model)\n",
    "\n",
    "# Load and verify structure\n",
    "loaded_complex = braintools.file.msgpack_load(\"checkpoints/complex_model.msgpack\")\n",
    "print(f\"Encoder layers: {len(loaded_complex['encoder']['layers'])}\")\n",
    "print(f\"Optimizer step: {loaded_complex['optimizer_state']['step']}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into complex_model.msgpack\n",
      "Loading checkpoint from complex_model.msgpack\n",
      "Encoder layers: 2\n",
      "Optimizer step: 1000\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "cell_type": "markdown",
   "id": "10a7d8ef",
   "metadata": {},
   "source": [
    "## Best Practices\n",
    "\n",
    "### 1. Version Compatibility\n",
    "\n",
    "Include version information in your checkpoints:"
   ]
  },
  {
   "cell_type": "code",
   "id": "11b6c9e8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:25:09.331495Z",
     "start_time": "2025-09-24T03:25:09.276387Z"
    }
   },
   "source": [
    "import braintools\n",
    "import time\n",
    "\n",
    "# Example training config (you would define this based on your needs)\n",
    "training_config = {'lr': 0.001, 'batch_size': 32}\n",
    "your_model_data = {'weights': jnp.ones((10, 10))}  # Example model data\n",
    "\n",
    "checkpoint_data = {\n",
    "    'model': your_model_data,\n",
    "    'metadata': {\n",
    "        'braintools_version': braintools.__version__,\n",
    "        'model_version': '1.2.0',\n",
    "        'timestamp': time.time(),\n",
    "        'training_config': training_config\n",
    "    }\n",
    "}\n",
    "\n",
    "braintools.file.msgpack_save(\"checkpoints/versioned_checkpoint.msgpack\", checkpoint_data)\n",
    "print(f\"Checkpoint saved with BrainTools version {braintools.__version__}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into versioned_checkpoint.msgpack\n",
      "Checkpoint saved with BrainTools version 0.0.12\n"
     ]
    }
   ],
   "execution_count": 18
  },
  {
   "cell_type": "markdown",
   "id": "12c7d8e9",
   "metadata": {},
   "source": [
    "### 2. Error Handling\n",
    "\n",
    "Always handle potential loading errors:"
   ]
  },
  {
   "cell_type": "code",
   "id": "13d6e8fa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:25:11.728556Z",
     "start_time": "2025-09-24T03:25:11.700626Z"
    }
   },
   "source": [
    "def safe_load_checkpoint(checkpoint_path, template=None):\n",
    "    try:\n",
    "        if template is not None:\n",
    "            return braintools.file.msgpack_load(checkpoint_path, target=template, mismatch='warn')\n",
    "        else:\n",
    "            return braintools.file.msgpack_load(checkpoint_path)\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Checkpoint {checkpoint_path} not found\")\n",
    "        return None\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading checkpoint: {e}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "# Test the safe loading function\n",
    "result = safe_load_checkpoint(\"checkpoints/nonexistent_checkpoint.msgpack\")\n",
    "print(f\"Safe load result for missing file: {result}\")\n",
    "\n",
    "# Test with existing file\n",
    "result = safe_load_checkpoint(\"checkpoints/model_checkpoint.msgpack\")\n",
    "if result:\n",
    "    print(\"Successfully loaded existing checkpoint\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error loading checkpoint: Checkpoint not found: nonexistent_checkpoint.msgpack\n",
      "Safe load result for missing file: None\n",
      "Loading checkpoint from model_checkpoint.msgpack\n",
      "Successfully loaded existing checkpoint\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "cell_type": "markdown",
   "id": "14f5g6h7",
   "metadata": {},
   "source": [
    "### 3. Checkpoint Validation\n",
    "\n",
    "Validate critical components after loading:"
   ]
  },
  {
   "cell_type": "code",
   "id": "15h8i9j0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:25:13.429979Z",
     "start_time": "2025-09-24T03:25:13.400080Z"
    }
   },
   "source": [
    "def validate_checkpoint(loaded_data, expected_shapes):\n",
    "    for key, expected_shape in expected_shapes.items():\n",
    "        if key in loaded_data:\n",
    "            actual_shape = loaded_data[key].shape\n",
    "            if actual_shape != expected_shape:\n",
    "                raise ValueError(f\"Shape mismatch for {key}: expected {expected_shape}, got {actual_shape}\")\n",
    "        else:\n",
    "            raise ValueError(f\"Missing key in checkpoint: {key}\")\n",
    "    return True\n",
    "\n",
    "\n",
    "# Example validation\n",
    "expected_shapes = {\n",
    "    'weights': (2, 2),\n",
    "    'bias': (2,)\n",
    "}\n",
    "\n",
    "try:\n",
    "    loaded_data = braintools.file.msgpack_load(\"checkpoints/model_checkpoint.msgpack\")\n",
    "    validate_checkpoint(loaded_data, expected_shapes)\n",
    "    print(\"Checkpoint validation passed!\")\n",
    "except ValueError as e:\n",
    "    print(f\"Validation failed: {e}\")"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading checkpoint from model_checkpoint.msgpack\n",
      "Checkpoint validation passed!\n"
     ]
    }
   ],
   "execution_count": 20
  },
  {
   "cell_type": "markdown",
   "id": "5f1682910a65957d",
   "metadata": {},
   "source": [
    "### 4. Regular Checkpointing During Training\n",
    "\n",
    "Save a checkpoint at a fixed interval during training so that a long run can be resumed after an interruption:"
   ]
  },
  {
   "cell_type": "code",
   "id": "17n4o5p6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-24T03:25:14.748484Z",
     "start_time": "2025-09-24T03:25:14.639119Z"
    }
   },
   "source": [
    "import os\n",
    "import jax\n",
    "\n",
    "\n",
    "def training_loop_with_checkpointing():\n",
    "    checkpoint_dir = \"checkpoints\"\n",
    "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "    # Example training parameters\n",
    "    num_epochs = 100\n",
    "    checkpoint_interval = 10\n",
    "\n",
    "    # Mock model state and optimizer state\n",
    "    model_state = {'weights': jnp.ones((10, 10)), 'bias': jnp.zeros(10)}\n",
    "    optimizer_state = {'momentum': jnp.zeros((10, 10)), 'step': 0}\n",
    "\n",
    "    def train_step(model_state, batch):\n",
    "        # Mock training step: returns the state unchanged\n",
    "        return model_state\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        # Mock batch\n",
    "        batch = jnp.ones((32, 10))\n",
    "\n",
    "        # Training step\n",
    "        model_state = train_step(model_state, batch)\n",
    "\n",
    "        # Checkpoint every N epochs\n",
    "        if epoch % checkpoint_interval == 0:\n",
    "            checkpoint_path = f\"{checkpoint_dir}/checkpoint_epoch_{epoch}.msgpack\"\n",
    "            checkpoint_data = {\n",
    "                'model_state': model_state,\n",
    "                'epoch': epoch,\n",
    "                'optimizer_state': optimizer_state,\n",
    "                'rng_state': jax.random.PRNGKey(42)  # Placeholder; in practice, save the current RNG state\n",
    "            }\n",
    "            braintools.file.msgpack_save(checkpoint_path, checkpoint_data)\n",
    "            print(f\"Checkpoint saved at epoch {epoch}\")\n",
    "\n",
    "        # Break early for demo\n",
    "        if epoch >= 20:\n",
    "            break\n",
    "\n",
    "\n",
    "# Run the training loop\n",
    "training_loop_with_checkpointing()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into checkpoints/checkpoint_epoch_0.msgpack\n",
      "Checkpoint saved at epoch 0\n",
      "Saving checkpoint into checkpoints/checkpoint_epoch_10.msgpack\n",
      "Checkpoint saved at epoch 10\n",
      "Saving checkpoint into checkpoints/checkpoint_epoch_20.msgpack\n",
      "Checkpoint saved at epoch 20\n"
     ]
    }
   ],
   "execution_count": 21
  },
  {
   "cell_type": "markdown",
   "id": "18q7r8s9",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "BrainTools' checkpointing system provides a robust and flexible way to save and restore model states. Key features include:\n",
    "\n",
    "- **Automatic serialization** of JAX arrays, BrainState objects, and BrainUnit quantities\n",
    "- **Custom object support** through registration system\n",
    "- **Flexible mismatch handling** for evolving model architectures\n",
    "- **Asynchronous saving** for large models\n",
    "- **Type safety** through template-based loading\n",
    "\n",
    "This system enables reliable model persistence for long training runs, model deployment, and collaborative research workflows."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
