{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff7c015e07e0c6c9",
   "metadata": {},
   "source": [
    "# Graph and Node System\n",
    "\n",
    "JAX's function transformations operate on pytrees: nested containers such as dictionaries, lists, and tuples whose leaves are arrays or other values. Through pytrees, JAX provides automatic differentiation, vectorization, and parallelization. Many practical applications, however, particularly complex neural network models and physical systems, contain shared references or cycles that a tree cannot express, and they are more naturally represented as graphs. To address this need, the `brainstate` library adopts the graph-based representation used in [flax](https://flax.readthedocs.io/) to represent complex model architectures.\n",
    "\n",
    "\n",
    "\n",
    "## What is `pygraph`?\n",
    "\n",
    "`pygraph` is the data structure in `brainstate` that makes graph-based models compatible with JAX transformations. Unlike trees, graphs can express shared references (one object reachable through several parents) and cycles. This makes them a better fit for intricate model architectures, for example when a model's state depends on interactions among multiple nodes.\n",
    "\n",
    "`pygraph` uses **`brainstate.graph.Node`** as its foundational element. Nodes are the vertices of the graph, and attribute references from one node to another are its edges. Each node can hold arbitrary pytree data (such as arrays) as well as other nodes, so graphs can be composed in a modular way.\n",
    "\n",
    "We can think of `brainstate.graph.Node` as a **container class** whose attributes hold its children. A child is either a leaf (an array, a `State` object, or other data) or another node. Because these containers can reference one another, arbitrarily complex graph structures can be built from them."
   ]
  },
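  {
   "cell_type": "markdown",
   "id": "5c1e8a2f9d4b7e01",
   "metadata": {},
   "source": [
    "For comparison, the sketch below uses only `jax.tree_util` (no `brainstate`) to show what a plain pytree cannot record. `Handle` is a hypothetical placeholder class; instances of unregistered classes are treated as pytree leaves. A dictionary whose two entries refer to the same object is flattened into two independent leaf slots, so the sharing is lost when the tree is rebuilt from its structure:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "class Handle:\n",
    "    # Hypothetical placeholder; unregistered classes become pytree leaves.\n",
    "    pass\n",
    "\n",
    "shared = Handle()\n",
    "tree = {'a': shared, 'b': shared}  # two entries, one object\n",
    "\n",
    "leaves, treedef = jax.tree_util.tree_flatten(tree)\n",
    "print(len(leaves), treedef.num_leaves)  # 2 2\n",
    "print(leaves[0] is leaves[1])           # True: same object, but treedef has no record of it\n",
    "\n",
    "# Rebuilding from new leaves fills the two slots independently,\n",
    "# so the original sharing cannot be recovered from the treedef alone.\n",
    "rebuilt = jax.tree_util.tree_unflatten(treedef, [1, 2])\n",
    "print(rebuilt)  # {'a': 1, 'b': 2}\n",
    "```\n",
    "\n",
    "A graph representation instead gives each distinct object a single identity and records every further reference to it as an edge, which is also what makes cycles representable."
   ]
  },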
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "21085866d34afe68",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.229813Z",
     "start_time": "2025-10-10T15:54:05.432653Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:00:58.101854Z",
     "iopub.status.busy": "2026-05-30T17:00:58.101557Z",
     "iopub.status.idle": "2026-05-30T17:01:00.441148Z",
     "shell.execute_reply": "2026-05-30T17:01:00.440473Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import brainstate\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "brainstate.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a69bd0578a8467",
   "metadata": {},
   "source": [
    "## Basic `brainstate.graph.Node` Example\n",
    "\n",
    "Let's start by defining a simple class that inherits from `brainstate.graph.Node`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77f9ae362b940fe1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.506364Z",
     "start_time": "2025-10-10T15:54:07.238063Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:00.443755Z",
     "iopub.status.busy": "2026-05-30T17:01:00.443353Z",
     "iopub.status.idle": "2026-05-30T17:01:00.766768Z",
     "shell.execute_reply": "2026-05-30T17:01:00.765928Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node A created:\n",
      "  w shape: (2, 3)\n",
      "  b shape: (3,)\n",
      "  b type: ShortTermState\n"
     ]
    }
   ],
   "source": [
    "class A(brainstate.graph.Node):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.random.rand(2, 3)\n",
    "        self.b = brainstate.ShortTermState(brainstate.random.rand(3))\n",
    "\n",
    "# Create an instance\n",
    "a = A()\n",
    "\n",
    "print(\"Node A created:\")\n",
    "print(f\"  w shape: {a.w.shape}\")\n",
    "print(f\"  b shape: {a.b.value.shape}\")\n",
    "print(f\"  b type: {type(a.b).__name__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b58124d7b328ac",
   "metadata": {},
   "source": [
    "In the code above, class `A` inherits from `brainstate.graph.Node` and defines two attributes in its initializer: `w`, a randomly generated 2×3 JAX array, and `b`, a randomly generated array of length 3 wrapped in a `brainstate.ShortTermState`. The wrapper matters. As the graph definition printed below shows, only `State` instances such as `b` are treated as state, while a plain array such as `w` is stored as a static value (`StaticEdge`) inside the graph definition.\n",
    "\n",
    "### Cyclic References\n",
    "\n",
    "Unlike pytrees, **`brainstate.graph.Node` objects may contain cyclic references**. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9270a4ee302d629e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.540157Z",
     "start_time": "2025-10-10T15:54:07.537307Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:00.768193Z",
     "iopub.status.busy": "2026-05-30T17:01:00.768043Z",
     "iopub.status.idle": "2026-05-30T17:01:00.771280Z",
     "shell.execute_reply": "2026-05-30T17:01:00.770785Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created cyclic reference: a.self = a\n",
      "Are they the same object? True\n"
     ]
    }
   ],
   "source": [
    "# Create a cyclic reference\n",
    "a.self = a\n",
    "\n",
    "print(\"Created cyclic reference: a.self = a\")\n",
    "print(f\"Are they the same object? {a.self is a}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e66524c99968322",
   "metadata": {},
   "source": [
    "We set the `self` attribute of `a` to `a` itself, creating a cyclic reference. Such a reference **cannot exist in a tree structure**, but it is valid in a graph. Cyclic and shared references let the object structure follow the model architecture directly, for example components of a recurrent neural network that refer to one another.\n",
    "\n",
    "### Inspecting the Graph Structure\n",
    "\n",
    "We can inspect the graph definition using `brainstate.graph.graphdef()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c572e954c97bd3da",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.575041Z",
     "start_time": "2025-10-10T15:54:07.568785Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:00.772857Z",
     "iopub.status.busy": "2026-05-30T17:01:00.772726Z",
     "iopub.status.idle": "2026-05-30T17:01:00.778862Z",
     "shell.execute_reply": "2026-05-30T17:01:00.778233Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GraphDef(\n",
       "  root=NodeEdge(index=0),\n",
       "  node_specs=(NodeSpec(type=<class '__main__.A'>, index=0, metadata=(<class '__main__.A'>,), fields=(('b', StateEdge(index=1, path=('b',), type=<class 'brainstate.ShortTermState'>)), ('self', NodeEdge(index=0)), ('w', StaticEdge(value=Array([[0.72766423, 0.78786755, 0.18169427],\n",
       "         [0.26263022, 0.11072934, 0.20263076]], dtype=float32))))),)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the graphdef of a\n",
    "graphdef = brainstate.graph.graphdef(a)\n",
    "graphdef"
   ]
  },
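  {
   "cell_type": "markdown",
   "id": "9a3d6b0e2c7f4d02",
   "metadata": {},
   "source": [
    "In the output above, the `self` field is stored as `NodeEdge(index=0)`: a reference to node 0, which is `a` itself, instead of a copy. One way a flattener can produce such back-references is to assign every distinct object an index on its first visit and to emit a reference whenever that object is reached again. The sketch below illustrates the idea in plain Python. `MiniNode` and `flatten_graph` are hypothetical names, and this is not `brainstate`'s implementation:\n",
    "\n",
    "```python\n",
    "class MiniNode:\n",
    "    # Hypothetical container: attributes are either leaves or other MiniNodes.\n",
    "    pass\n",
    "\n",
    "def flatten_graph(root):\n",
    "    index_of = {}  # id(node) -> index assigned on first visit\n",
    "    specs = []     # specs[i] is the list of (name, edge) fields of node i\n",
    "\n",
    "    def visit(node):\n",
    "        if id(node) in index_of:  # already seen: emit a back-reference\n",
    "            return ('NodeRef', index_of[id(node)])\n",
    "        idx = len(specs)\n",
    "        index_of[id(node)] = idx\n",
    "        fields = []\n",
    "        specs.append(fields)  # register before recursing so cycles terminate\n",
    "        for name in sorted(vars(node)):\n",
    "            value = getattr(node, name)\n",
    "            if isinstance(value, MiniNode):\n",
    "                fields.append((name, visit(value)))\n",
    "            else:\n",
    "                fields.append((name, ('Leaf', value)))\n",
    "        return ('NodeRef', idx)\n",
    "\n",
    "    return visit(root), specs\n",
    "\n",
    "a = MiniNode()\n",
    "a.w = 1.0\n",
    "a.self = a  # cycle back to the root\n",
    "root, specs = flatten_graph(a)\n",
    "print(root)   # ('NodeRef', 0)\n",
    "print(specs)  # [[('self', ('NodeRef', 0)), ('w', ('Leaf', 1.0))]]\n",
    "\n",
    "# A child shared by two attributes is stored once and referenced twice.\n",
    "parent, child = MiniNode(), MiniNode()\n",
    "parent.left = child\n",
    "parent.right = child\n",
    "_, shared_specs = flatten_graph(parent)\n",
    "print(len(shared_specs))  # 2 distinct nodes\n",
    "```"
   ]
  },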
  {
   "cell_type": "markdown",
   "id": "3f60e6d557d37001",
   "metadata": {},
   "source": [
    "## Building Hierarchical Neural Networks\n",
    "\n",
    "`brainstate.graph.Node` is the foundational class of `pygraph` and defines the structure and behavior of each node in the graph. A node's **attributes** hold its data (JAX arrays, `State` objects, or other pytree data) together with references to child nodes, and its **methods** define the operations that read, update, and transform that data.\n",
    "\n",
    "In `brainstate`, every neural network module, such as `brainstate.nn.Linear`, is a subclass of `brainstate.graph.Node`. Modules can therefore be nested inside one another to build larger architectures, and the resulting graph can still be used with JAX transformations. For instance, we can define a multi-layer perceptron (MLP) module from three linear layers plus one short-term state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a41868f76bec2452",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:08.564682Z",
     "start_time": "2025-10-10T15:54:07.598547Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:00.780670Z",
     "iopub.status.busy": "2026-05-30T17:01:00.780431Z",
     "iopub.status.idle": "2026-05-30T17:01:04.462987Z",
     "shell.execute_reply": "2026-05-30T17:01:04.461977Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP structure:\n",
      "MLP(\n",
      "  l1=Linear(\n",
      "    in_size=(2,),\n",
      "    out_size=(3,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[3]),\n",
      "        'weight': ShapedArray(float32[2,3])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l2=Linear(\n",
      "    in_size=(3,),\n",
      "    out_size=(4,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[4]),\n",
      "        'weight': ShapedArray(float32[3,4])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l3=Linear(\n",
      "    in_size=(4,),\n",
      "    out_size=(5,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[5]),\n",
      "        'weight': ShapedArray(float32[4,5])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  st=ShortTermState(\n",
      "    value=ShapedArray(float32[5])\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MLP(brainstate.graph.Node):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.l1 = brainstate.nn.Linear(2, 3)\n",
    "        self.l2 = brainstate.nn.Linear(3, 4)\n",
    "        self.l3 = brainstate.nn.Linear(4, 5)\n",
    "        self.st = brainstate.ShortTermState(brainstate.random.rand(5))\n",
    "\n",
    "# Create MLP instance\n",
    "mlp = MLP()\n",
    "\n",
    "print(\"MLP structure:\")\n",
    "print(mlp)"
   ]
  },
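  {
   "cell_type": "markdown",
   "id": "e41b7c9a5f2d8e03",
   "metadata": {},
   "source": [
    "`Linear` keeps its parameters in a `ParamState`, whereas in class `A` above the plain array `w` ended up as a `StaticEdge` in the graph definition. The sketch below checks that `brainstate.graph.states` collects only attributes wrapped in a `State`. It assumes the same `brainstate` version as this tutorial, and the class `Mixed` is a hypothetical example:\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "class Mixed(brainstate.graph.Node):\n",
    "    # Hypothetical class used only for this illustration.\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.raw = jnp.ones(3)                            # plain array: kept as static structure\n",
    "        self.param = brainstate.ParamState(jnp.zeros(3))  # State: collected as state\n",
    "\n",
    "collected = brainstate.graph.states(Mixed())\n",
    "paths = [path for path, _ in collected.items()]\n",
    "for path, state in collected.items():\n",
    "    print(path, type(state).__name__)\n",
    "```\n",
    "\n",
    "Under that assumption, only the `('param',)` path should be printed. Values that should be updated or differentiated therefore belong inside a `State`."
   ]
  },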
  {
   "cell_type": "markdown",
   "id": "8757f357fcd874b4",
   "metadata": {},
   "source": [
    "## Traversing the Graph Structure\n",
    "\n",
    "A graph built from `brainstate.graph.Node` objects has the following ingredients:\n",
    "\n",
    "- **Data Storage**: Each `brainstate.graph.Node` can store JAX arrays, `State` objects, or other transformable data in its *attributes*.\n",
    "\n",
    "- **Node Connections**: The *attributes* of a `brainstate.graph.Node` can reference other `brainstate.graph.Node` instances, establishing complex dependency graphs, as illustrated by the three linear layers inside the `MLP` class above.\n",
    "\n",
    "- **Attributes and Their Paths**: Every node and leaf in the graph is identified by a path: the tuple of attribute names (and container indices) that leads to it from the root, such as `('l1', 'weight')`. Paths are used to locate, retrieve, and transform elements within nested structures.\n",
    "\n",
    "### Iterating Over Leaf Nodes\n",
    "\n",
    "`brainstate.graph.iter_leaf` yields a `(path, leaf)` pair for every leaf in the `MLP` graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dc8d3b5e89e9863c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.396842Z",
     "start_time": "2025-10-10T15:54:10.392984Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.465648Z",
     "iopub.status.busy": "2026-05-30T17:01:04.465132Z",
     "iopub.status.idle": "2026-05-30T17:01:04.469830Z",
     "shell.execute_reply": "2026-05-30T17:01:04.468981Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Leaf nodes in MLP:\n",
      "============================================================\n",
      "('l1', '_in_size', 0)                    2\n",
      "('l1', '_out_size', 0)                   3\n",
      "('l1', 'weight')                         ParamState\n",
      "('l2', '_in_size', 0)                    3\n",
      "('l2', '_out_size', 0)                   4\n",
      "('l2', 'weight')                         ParamState\n",
      "('l3', '_in_size', 0)                    4\n",
      "('l3', '_out_size', 0)                   5\n",
      "('l3', 'weight')                         ParamState\n",
      "('st',)                                  ShortTermState\n"
     ]
    }
   ],
   "source": [
    "print(\"Leaf nodes in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "for path, leaf in brainstate.graph.iter_leaf(mlp):\n",
    "    leaf_info = leaf.__class__.__name__ if isinstance(leaf, brainstate.State) else leaf\n",
    "    print(f\"{str(path):<40} {leaf_info}\")"
   ]
  },
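  {
   "cell_type": "markdown",
   "id": "7b2f0d4c8e6a1f04",
   "metadata": {},
   "source": [
    "Entries such as `('l1', '_in_size', 0)` show that paths continue into ordinary pytree containers: `_in_size` is a tuple, so its element is addressed by the integer index `0`. JAX's own pytree utilities describe positions in the same way. The sketch below uses plain JAX with a hypothetical dictionary standing in for the size attributes of a layer, and converts JAX key paths into tuples of the same form:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "\n",
    "# Hypothetical pytree mimicking the size attributes of one Linear layer.\n",
    "tree = {'l1': {'in_size': (2,), 'out_size': (3,)}}\n",
    "\n",
    "def key_to_python(key):\n",
    "    # DictKey stores .key, SequenceKey stores .idx, GetAttrKey stores .name\n",
    "    for attr in ('key', 'idx', 'name'):\n",
    "        if hasattr(key, attr):\n",
    "            return getattr(key, attr)\n",
    "    return key\n",
    "\n",
    "paths = [\n",
    "    (tuple(key_to_python(k) for k in key_path), leaf)\n",
    "    for key_path, leaf in jax.tree_util.tree_leaves_with_path(tree)\n",
    "]\n",
    "for path, leaf in paths:\n",
    "    print(path, leaf)  # ('l1', 'in_size', 0) 2, then ('l1', 'out_size', 0) 3\n",
    "```"
   ]
  },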
  {
   "cell_type": "markdown",
   "id": "d9b026896800605e",
   "metadata": {},
   "source": [
    "### Iterating Over All Nodes\n",
    "\n",
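    "`brainstate.graph.iter_node` yields a `(path, node)` pair for every node in the `MLP` graph, including the root. Child nodes are yielded before their parent, so the root `MLP` appears last with the empty path `()`:"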
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "417155450f649750",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.424489Z",
     "start_time": "2025-10-10T15:54:10.418566Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.471467Z",
     "iopub.status.busy": "2026-05-30T17:01:04.471338Z",
     "iopub.status.idle": "2026-05-30T17:01:04.474662Z",
     "shell.execute_reply": "2026-05-30T17:01:04.473969Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "All nodes in MLP:\n",
      "============================================================\n",
      "('l1',)                        Linear\n",
      "('l2',)                        Linear\n",
      "('l3',)                        Linear\n",
      "()                             MLP\n"
     ]
    }
   ],
   "source": [
    "print(\"\\nAll nodes in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "for path, node in brainstate.graph.iter_node(mlp):\n",
    "    print(f\"{str(path):<30} {node.__class__.__name__}\")"
   ]
  },
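  {
   "cell_type": "markdown",
   "id": "c8e5a1f3b9d20e05",
   "metadata": {},
   "source": [
    "The children-before-parent order above is a post-order traversal. A minimal plain-Python sketch of such a traversal follows. `Box` and `iter_nodes` are hypothetical names, this is not `brainstate`'s implementation, and only attributes that are themselves `Box` instances count as child nodes. Visited nodes are tracked by `id`, so shared or cyclic references are yielded once:\n",
    "\n",
    "```python\n",
    "class Box:\n",
    "    # Hypothetical node type for this sketch.\n",
    "    def __init__(self, **children):\n",
    "        for name, value in children.items():\n",
    "            setattr(self, name, value)\n",
    "\n",
    "def iter_nodes(root):\n",
    "    seen = set()\n",
    "\n",
    "    def visit(node, path):\n",
    "        if id(node) in seen:  # shared or cyclic reference: skip repeats\n",
    "            return\n",
    "        seen.add(id(node))\n",
    "        for name in sorted(vars(node)):\n",
    "            child = getattr(node, name)\n",
    "            if isinstance(child, Box):\n",
    "                yield from visit(child, path + (name,))\n",
    "        yield path, node  # parent after all of its children\n",
    "\n",
    "    yield from visit(root, ())\n",
    "\n",
    "model = Box(l1=Box(), l2=Box(), l3=Box())\n",
    "model.loop = model  # cycle back to the root\n",
    "order = [path for path, _ in iter_nodes(model)]\n",
    "print(order)  # [('l1',), ('l2',), ('l3',), ()]\n",
    "```"
   ]
  },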
  {
   "cell_type": "markdown",
   "id": "cefa4a93599dbcb0",
   "metadata": {},
   "source": [
    "## Common Functions in `pygraph`\n",
    "\n",
    "`brainstate` provides many utilities for operating on `pygraph`; they live in the `brainstate.graph` module.\n",
    "\n",
    "The commonly used functions fall into the following groups:\n",
    "\n",
    "### Graph Structure Operations\n",
    "Functions for constructing and retrieving graph structures:\n",
    "- `brainstate.graph.graphdef`: View the graph structure\n",
    "- `brainstate.graph.iter_node`: Iterate through all child nodes in the graph structure\n",
    "- `brainstate.graph.iter_leaf`: Traverse all data points in the graph structure\n",
    "- `brainstate.graph.nodes`: View all nodes in the graph structure\n",
    "- `brainstate.graph.states`: View all `State` instances in the graph structure\n",
    "\n",
    "### Graph Structure Transformations\n",
    "Functions for transforming and manipulating graph structures:\n",
    "- `brainstate.graph.treefy_states`: Convert all `State` instances in the graph structure to pytree\n",
    "- `brainstate.graph.clone`: Copy the graph structure\n",
    "- `brainstate.graph.treefy_split`: Split the graph structure into a `graphdef` and a pytree representation of `State`\n",
    "- `brainstate.graph.treefy_merge`: Merge the `graphdef` and pytree representation of `State` into a graph structure\n",
    "- `brainstate.graph.flatten`: Flatten the graph structure into collections of `graphdef` and `State`\n",
    "- `brainstate.graph.unflatten`: Restore the flattened graph structure to its original form\n",
    "\n",
    "### Graph Structure Modifications\n",
    "Functions for modifying and updating graph structures:\n",
    "- `brainstate.graph.pop_states`: Remove `State` instances from the graph structure that meet certain conditions\n",
    "- `brainstate.graph.update_states`: Update `State` instances in the graph structure that meet certain conditions\n",
    "\n",
    "### Graph Structure Conversions\n",
    "Functions for converting between pygraph and pytree data structures:\n",
    "- `brainstate.graph.graph_to_tree`: Convert a graph structure to a pytree\n",
    "- `brainstate.graph.tree_to_graph`: Convert a pytree to a graph structure\n",
    "\n",
    "Together, these functions cover constructing, inspecting, transforming, and converting graph structures."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c60fb47f5a07efd",
   "metadata": {},
   "source": [
    "## Splitting and Merging Graphs\n",
    "\n",
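    "A fundamental operation is splitting a graph into its structure definition (`graphdef`) and its state values (`tree_states`), a nested dictionary of `TreefyState` objects. JAX transformations operate on pytrees rather than on arbitrary Python object graphs, so this split lets the static structure stay outside a transformation while only the state values are traced."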
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d0fbdc61e856ba1b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.448799Z",
     "start_time": "2025-10-10T15:54:10.442142Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.476564Z",
     "iopub.status.busy": "2026-05-30T17:01:04.476372Z",
     "iopub.status.idle": "2026-05-30T17:01:04.479947Z",
     "shell.execute_reply": "2026-05-30T17:01:04.479212Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph split into:\n",
      "  1. graphdef (structure)\n",
      "  2. tree_states (state values)\n",
      "\n",
      "Number of states: 4\n"
     ]
    }
   ],
   "source": [
    "# Split the graph structure into graphdef and treefy_states\n",
    "graphdef, tree_states = brainstate.graph.treefy_split(mlp)\n",
    "\n",
    "print(\"Graph split into:\")\n",
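    "print(\"  1. graphdef (structure)\")\n",
    "print(\"  2. tree_states (state values)\")\n",
    "# tree_states is nested by attribute name, so count the State instances directly\n",
    "print(f\"\\nNumber of states: {len(brainstate.graph.states(mlp))}\")"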
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2ebfd319743f6d58",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.473793Z",
     "start_time": "2025-10-10T15:54:10.466509Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.482380Z",
     "iopub.status.busy": "2026-05-30T17:01:04.482128Z",
     "iopub.status.idle": "2026-05-30T17:01:04.486826Z",
     "shell.execute_reply": "2026-05-30T17:01:04.486163Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph definition:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GraphDef(\n",
       "  root=NodeEdge(index=0),\n",
       "  node_specs=(NodeSpec(type=<class '__main__.MLP'>, index=0, metadata=(<class '__main__.MLP'>,), fields=(('l1', NodeEdge(index=1)), ('l2', NodeEdge(index=3)), ('l3', NodeEdge(index=5)), ('st', StateEdge(index=7, path=('st',), type=<class 'brainstate.ShortTermState'>)))), NodeSpec(type=<class 'brainstate.nn.Linear'>, index=1, metadata=(<class 'brainstate.nn.Linear'>,), fields=(('_in_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=2)),))), ('_name', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('_out_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=3)),))), ('w_mask', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('weight', StateEdge(index=2, path=('l1', 'weight'), type=<class 'brainstate.ParamState'>)))), NodeSpec(type=<class 'brainstate.nn.Linear'>, index=3, metadata=(<class 'brainstate.nn.Linear'>,), fields=(('_in_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=3)),))), ('_name', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('_out_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=4)),))), ('w_mask', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('weight', StateEdge(index=4, path=('l2', 'weight'), type=<class 'brainstate.ParamState'>)))), NodeSpec(type=<class 'brainstate.nn.Linear'>, index=5, metadata=(<class 'brainstate.nn.Linear'>,), fields=(('_in_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=4)),))), ('_name', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('_out_size', PytreeEdge(metadata=PyTreeDef((*,)), fields=((0, StaticEdge(value=5)),))), ('w_mask', PytreeEdge(metadata=PyTreeDef(None), fields=())), ('weight', StateEdge(index=6, path=('l3', 'weight'), type=<class 'brainstate.ParamState'>)))))\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the graphdef\n",
    "print(\"Graph definition:\")\n",
    "print(\"=\" * 60)\n",
    "graphdef"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4126740c3a85f543",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.501202Z",
     "start_time": "2025-10-10T15:54:10.493332Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.488824Z",
     "iopub.status.busy": "2026-05-30T17:01:04.488592Z",
     "iopub.status.idle": "2026-05-30T17:01:04.494736Z",
     "shell.execute_reply": "2026-05-30T17:01:04.494228Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Tree states:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{\n",
       "  'l1': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[-0.42749512, -0.46094188, -1.1391693 ],\n",
       "               [-1.5929017 ,  2.1525893 ,  0.7923302 ]], dtype=float32),\n",
       "        'bias': Array([0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x75bd922bd950>,\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'l2': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[ 0.5922906 , -0.44910267, -0.2513227 ,  1.5940351 ],\n",
       "               [ 0.5387001 , -1.538589  , -1.8159095 , -1.4499966 ],\n",
       "               [-1.0424379 , -1.6111814 ,  1.5264777 ,  1.1270016 ]],      dtype=float32),\n",
       "        'bias': Array([0., 0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x75bd9093d5b0>,\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'l3': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[ 0.90832496, -0.60686237,  0.29780853, -0.90121883,  0.36897153],\n",
       "               [ 0.19299707,  0.06321475, -0.4093656 , -0.25156304,  0.58528906],\n",
       "               [ 0.08484001,  0.17979762,  0.39775968,  1.2874871 ,  0.64356524],\n",
       "               [-0.87185615, -0.48536623,  1.1163356 ,  1.2303296 ,  0.23622379]],      dtype=float32),\n",
       "        'bias': Array([0., 0., 0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x75bd9093cb00>,\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'st': TreefyState(\n",
       "    type=<class 'brainstate.ShortTermState'>,\n",
       "    value=Array([0.7928349 , 0.14884543, 0.46262634, 0.9012923 , 0.20398748],      dtype=float32),\n",
       "    _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x75bd9092b0b0>,\n",
       "    tag=None\n",
       "  )\n",
       "}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the tree states\n",
    "print(\"\\nTree states:\")\n",
    "print(\"=\" * 60)\n",
    "tree_states"
   ]
  },
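  {
   "cell_type": "markdown",
   "id": "2d9f4e7a0c6b3a06",
   "metadata": {},
   "source": [
    "Because `tree_states` is a nested dictionary, it can be processed with ordinary pytree utilities. The sketch below uses a hypothetical plain dictionary of zero arrays with the same shapes as the MLP's states, rather than the `TreefyState` objects themselves, and counts the scalar values per top-level entry with `jax.tree_util`:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical stand-in with the same shapes as the MLP's state values above.\n",
    "state_values = {\n",
    "    'l1': {'weight': {'weight': jnp.zeros((2, 3)), 'bias': jnp.zeros(3)}},\n",
    "    'l2': {'weight': {'weight': jnp.zeros((3, 4)), 'bias': jnp.zeros(4)}},\n",
    "    'l3': {'weight': {'weight': jnp.zeros((4, 5)), 'bias': jnp.zeros(5)}},\n",
    "    'st': jnp.zeros(5),\n",
    "}\n",
    "\n",
    "# Sum the number of elements of every array leaf under each top-level key.\n",
    "sizes = {name: sum(leaf.size for leaf in jax.tree_util.tree_leaves(sub))\n",
    "         for name, sub in state_values.items()}\n",
    "print(sizes)                       # {'l1': 9, 'l2': 16, 'l3': 25, 'st': 5}\n",
    "print('total:', sum(sizes.values()))  # total: 55\n",
    "```"
   ]
  },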
  {
   "cell_type": "markdown",
   "id": "6711263fa16c3034",
   "metadata": {},
   "source": [
    "### Merging Back\n",
    "\n",
    "We can merge the `graphdef` and `tree_states` back into a complete graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "916d6b14d5004a24",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.542506Z",
     "start_time": "2025-10-10T15:54:10.537923Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.496339Z",
     "iopub.status.busy": "2026-05-30T17:01:04.496197Z",
     "iopub.status.idle": "2026-05-30T17:01:04.499564Z",
     "shell.execute_reply": "2026-05-30T17:01:04.498959Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Merged MLP:\n",
      "MLP(\n",
      "  l1=Linear(\n",
      "    in_size=(2,),\n",
      "    out_size=(3,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[3]),\n",
      "        'weight': ShapedArray(float32[2,3])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l2=Linear(\n",
      "    in_size=(3,),\n",
      "    out_size=(4,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[4]),\n",
      "        'weight': ShapedArray(float32[3,4])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l3=Linear(\n",
      "    in_size=(4,),\n",
      "    out_size=(5,),\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[5]),\n",
      "        'weight': ShapedArray(float32[4,5])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  st=ShortTermState(\n",
      "    value=ShapedArray(float32[5])\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Merge the graphdef structure and treefy_states\n",
    "mlp2 = brainstate.graph.treefy_merge(graphdef, tree_states)\n",
    "\n",
    "print(\"Merged MLP:\")\n",
    "print(mlp2)"
   ]
  },
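  {
   "cell_type": "markdown",
   "id": "f6a0c3e8d1b54b07",
   "metadata": {},
   "source": [
    "To see why the split matters for JAX transformations, here is a minimal sketch of the same pattern without `brainstate`. `TinyLinear`, `split`, and `merge` are hypothetical stand-ins for a node, `treefy_split`, and `treefy_merge`. The structure is a hashable tuple that can be passed to `jax.jit` as a static argument, and only the dictionary of arrays is traced and differentiated:\n",
    "\n",
    "```python\n",
    "import functools\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "class TinyLinear:\n",
    "    # Hypothetical node holding two arrays as attributes.\n",
    "    def __init__(self, w, b):\n",
    "        self.w = w\n",
    "        self.b = b\n",
    "\n",
    "def split(node):\n",
    "    # Structure: the class plus attribute names (hashable, so usable as a static arg).\n",
    "    graphdef = (type(node), tuple(sorted(vars(node))))\n",
    "    state = {name: getattr(node, name) for name in graphdef[1]}\n",
    "    return graphdef, state\n",
    "\n",
    "def merge(graphdef, state):\n",
    "    cls, names = graphdef\n",
    "    node = cls.__new__(cls)  # rebuild without calling __init__\n",
    "    for name in names:\n",
    "        setattr(node, name, state[name])\n",
    "    return node\n",
    "\n",
    "@functools.partial(jax.jit, static_argnums=0)\n",
    "def forward(graphdef, state, x):\n",
    "    model = merge(graphdef, state)  # rebuild the object inside the traced function\n",
    "    return x @ model.w + model.b\n",
    "\n",
    "graphdef, state = split(TinyLinear(jnp.ones((2, 3)), jnp.zeros(3)))\n",
    "x = jnp.ones(2)\n",
    "y = forward(graphdef, state, x)\n",
    "\n",
    "# Gradients come back with the same dictionary structure as `state`.\n",
    "grads = jax.grad(lambda s: forward(graphdef, s, x).sum())(state)\n",
    "print(y)  # [2. 2. 2.]\n",
    "print(jax.tree_util.tree_map(lambda g: g.shape, grads))  # {'b': (3,), 'w': (2, 3)}\n",
    "```"
   ]
  },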
  {
   "cell_type": "markdown",
   "id": "57c8f14de00956ed",
   "metadata": {},
   "source": [
    "## Collecting States\n",
    "\n",
    "`brainstate.graph.states` collects the `State` instances of a graph into a flat mapping keyed by path. Called without filters, it returns every state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d28ce6624733022a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.568611Z",
     "start_time": "2025-10-10T15:54:10.564664Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.501345Z",
     "iopub.status.busy": "2026-05-30T17:01:04.501136Z",
     "iopub.status.idle": "2026-05-30T17:01:04.504843Z",
     "shell.execute_reply": "2026-05-30T17:01:04.504067Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All states in MLP:\n",
      "============================================================\n",
      "('l1', 'weight')               ParamState\n",
      "('l2', 'weight')               ParamState\n",
      "('l3', 'weight')               ParamState\n",
      "('st',)                        ShortTermState\n",
      "\n",
      "Total states: 4\n"
     ]
    }
   ],
   "source": [
    "# View all states in the graph structure\n",
    "states = brainstate.graph.states(mlp2)\n",
    "\n",
    "print(\"All states in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in states.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")\n",
    "\n",
    "print(f\"\\nTotal states: {len(states)}\")"
   ]
  },
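  {
   "cell_type": "markdown",
   "id": "0e7c2b5f9a3d6c08",
   "metadata": {},
   "source": [
    "`states` returns a flat mapping keyed by path tuples, while `tree_states` from `treefy_split` is nested by attribute name. The two layouts carry the same information. The hypothetical helpers `nest` and `unnest` below (plain Python, not part of `brainstate`) convert between them, using strings as placeholder values:\n",
    "\n",
    "```python\n",
    "def nest(flat):\n",
    "    # {('l1', 'weight'): v} -> {'l1': {'weight': v}}; assumes non-empty paths.\n",
    "    nested = {}\n",
    "    for path, value in flat.items():\n",
    "        node = nested\n",
    "        for key in path[:-1]:\n",
    "            node = node.setdefault(key, {})\n",
    "        node[path[-1]] = value\n",
    "    return nested\n",
    "\n",
    "def unnest(nested, prefix=()):\n",
    "    # Inverse of nest, assuming the leaf values themselves are not dicts.\n",
    "    flat = {}\n",
    "    for key, value in nested.items():\n",
    "        path = prefix + (key,)\n",
    "        if isinstance(value, dict):\n",
    "            flat.update(unnest(value, path))\n",
    "        else:\n",
    "            flat[path] = value\n",
    "    return flat\n",
    "\n",
    "flat = {('l1', 'weight'): 'P1', ('l2', 'weight'): 'P2', ('st',): 'S'}\n",
    "nested = nest(flat)\n",
    "print(nested)                  # {'l1': {'weight': 'P1'}, 'l2': {'weight': 'P2'}, 'st': 'S'}\n",
    "print(unnest(nested) == flat)  # True\n",
    "```"
   ]
  },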
  {
   "cell_type": "markdown",
   "id": "17d6f5468409d56d",
   "metadata": {},
   "source": [
    "## Filter Syntax\n",
    "\n",
    "Most graph functions that retrieve `State` instances accept one or more **filters** that select the states meeting certain conditions. For example, passing `brainstate.ShortTermState` selects only the `ShortTermState` instances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8d40272b2916a52",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.595632Z",
     "start_time": "2025-10-10T15:54:10.591320Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.506480Z",
     "iopub.status.busy": "2026-05-30T17:01:04.506344Z",
     "iopub.status.idle": "2026-05-30T17:01:04.510525Z",
     "shell.execute_reply": "2026-05-30T17:01:04.509702Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ShortTermState instances:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{\n",
       "  ('st',): ShortTermState(\n",
       "    value=ShapedArray(float32[5])\n",
       "  )\n",
       "}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Filter only ShortTermState instances\n",
    "short_term_states = brainstate.graph.states(mlp2, brainstate.ShortTermState)\n",
    "\n",
    "print(\"ShortTermState instances:\")\n",
    "print(\"=\" * 60)\n",
    "short_term_states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9572255d91de1f2",
   "metadata": {},
   "source": [
    "### Filter DSL (Domain-Specific Language)\n",
    "\n",
    "In general, a `filter` has the following function form:\n",
    "\n",
    "```python\n",
    "def predicate(path: tuple[Key, ...], value: Any) -> bool:\n",
    "    ...\n",
    "```\n",
    "\n",
    "Here, `Key` is a hashable and comparable data type (often a string), `path` is a tuple of `Key` representing the nested structure corresponding to `value`, and `value` is the value at that path. If the value should be included in the `filter`, the function returns `True`; otherwise, it returns `False`.\n",
    "\n",
    "To simplify the creation of filters, `brainstate` provides a small domain-specific language (DSL). Instead of writing a predicate, users can pass a type, a boolean, an Ellipsis (`...`), a string, or a tuple/list of filters, and each is converted internally into the corresponding predicate.\n",
    "\n",
    "| Literal | Description |\n",
    "|---------|-------------|\n",
    "| `...` or `True` | Matches all values |\n",
    "| `None` or `False` | Matches no values |\n",
    "| `type` (a class) | Matches instances of that class, or values whose `type` attribute is that class (such as the `TreefyState` objects shown earlier) |\n",
    "| `'{filter}'` (str) | Matches values with a string `tag` attribute equal to `'{filter}'` |\n",
    "| `(*filters)` (tuple) or `[*filters]` (list) | Matches values satisfying any of the internal `filters` |\n",
    "\n",
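    "When several filters are passed, `states` returns one collection per filter, and each state is assigned to the first filter it matches. This is why `...` below collects only the states that `brainstate.ParamState` has not already claimed:"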
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e5e14c2fa1dc9a6c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.622972Z",
     "start_time": "2025-10-10T15:54:10.612465Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.512363Z",
     "iopub.status.busy": "2026-05-30T17:01:04.512154Z",
     "iopub.status.idle": "2026-05-30T17:01:04.516276Z",
     "shell.execute_reply": "2026-05-30T17:01:04.515579Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ParamState instances:\n",
      "============================================================\n",
      "('l1', 'weight')               ParamState\n",
      "('l2', 'weight')               ParamState\n",
      "('l3', 'weight')               ParamState\n",
      "\n",
      "Other states:\n",
      "============================================================\n",
      "('st',)                        ShortTermState\n"
     ]
    }
   ],
   "source": [
    "# Split states into params and others\n",
    "params, others = brainstate.graph.states(mlp, brainstate.ParamState, ...)\n",
    "\n",
    "print(\"ParamState instances:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in params.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")\n",
    "\n",
    "print(\"\\nOther states:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in others.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")"
   ]
  },
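  {
   "cell_type": "markdown",
   "id": "7d41c9e2a85f03b6",
   "metadata": {},
   "source": [
    "The DSL table above can be made concrete with a minimal pure-Python sketch of how such literals might be turned into `(path, value) -> bool` predicates. This is an illustration under assumptions, not `brainstate`'s implementation: `to_predicate`, `split`, and the toy classes `Param`, `Short`, and `Tagged` are hypothetical names introduced here, and `split` assumes that each value goes to the first filter it matches (which is consistent with the `ParamState`/`...` output above).\n",
    "\n",
    "```python\n",
    "from typing import Any, Callable, Tuple\n",
    "\n",
    "Predicate = Callable[[Tuple[Any, ...], Any], bool]\n",
    "\n",
    "\n",
    "def to_predicate(f: Any) -> Predicate:\n",
    "    # Hypothetical converter that mirrors the DSL table; not brainstate's code.\n",
    "    if f is ... or f is True:\n",
    "        return lambda path, value: True\n",
    "    if f is None or f is False:\n",
    "        return lambda path, value: False\n",
    "    if isinstance(f, type):\n",
    "        # An instance of the type, or a `type` attribute equal to it (assumed semantics).\n",
    "        return lambda path, value: isinstance(value, f) or getattr(value, 'type', None) is f\n",
    "    if isinstance(f, str):\n",
    "        return lambda path, value: getattr(value, 'tag', None) == f\n",
    "    if isinstance(f, (tuple, list)):\n",
    "        preds = [to_predicate(g) for g in f]\n",
    "        return lambda path, value: any(p(path, value) for p in preds)\n",
    "    if callable(f):\n",
    "        # Already a predicate of the form (path, value) -> bool.\n",
    "        return f\n",
    "    raise TypeError(f'Unsupported filter: {f!r}')\n",
    "\n",
    "\n",
    "class Param:  # hypothetical stand-in for a parameter state\n",
    "    pass\n",
    "\n",
    "\n",
    "class Short:  # hypothetical stand-in for a short-term state\n",
    "    pass\n",
    "\n",
    "\n",
    "class Tagged:  # hypothetical state carrying a string tag\n",
    "    def __init__(self, tag: str):\n",
    "        self.tag = tag\n",
    "\n",
    "\n",
    "def split(states: dict, *filters) -> list:\n",
    "    # Assign each (path, value) pair to the first filter that matches it (assumed rule).\n",
    "    preds = [to_predicate(f) for f in filters]\n",
    "    groups = [{} for _ in filters]\n",
    "    for path, value in states.items():\n",
    "        for group, pred in zip(groups, preds):\n",
    "            if pred(path, value):\n",
    "                group[path] = value\n",
    "                break\n",
    "    return groups\n",
    "\n",
    "\n",
    "states = {('l1', 'weight'): Param(), ('st',): Short(), ('l2', 'bias'): Tagged('frozen')}\n",
    "params, others = split(states, Param, ...)\n",
    "print(sorted(params))  # [('l1', 'weight')]\n",
    "print(sorted(others))  # [('l2', 'bias'), ('st',)]\n",
    "print(to_predicate([Short, 'frozen'])(('l2', 'bias'), states[('l2', 'bias')]))  # True\n",
    "```\n",
    "\n",
    "In this sketch, passing `...` last acts as a catch-all: every value not claimed by an earlier filter ends up in the final group."
   ]
  },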
  {
   "cell_type": "markdown",
   "id": "8e9e8b84efaa6e4",
   "metadata": {},
   "source": [
    "## `pygraph` and JAX Transformations\n",
    "\n",
    "Once a `pygraph` has been converted to a pytree, JAX's function transformations apply to it directly: `jax.jit` compiles the model, `jax.grad` computes gradients through automatic differentiation, and `jax.vmap` vectorizes computations over a batch dimension.\n",
    "\n",
    "Let's demonstrate this with a complete training example. We'll define a simple MLP model with a counter to track how many times it's been called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c168978ecb23519f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.932142Z",
     "start_time": "2025-10-10T15:54:10.660584Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.518131Z",
     "iopub.status.busy": "2026-05-30T17:01:04.517907Z",
     "iopub.status.idle": "2026-05-30T17:01:04.856761Z",
     "shell.execute_reply": "2026-05-30T17:01:04.855902Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model created and split into:\n",
      "  - graphdef (structure)\n",
      "  - params_ (trainable parameters)\n",
      "  - counts_ (counter state)\n"
     ]
    }
   ],
   "source": [
    "# Define Linear layer\n",
    "class Linear(brainstate.nn.Module):\n",
    "    def __init__(self, din: int, dout: int):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(brainstate.random.randn(din, dout) * 0.1)\n",
    "        self.b = brainstate.ParamState(jnp.zeros((dout,)))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.w.value + self.b.value\n",
    "\n",
    "\n",
    "# Define custom Count state\n",
    "class Count(brainstate.State):\n",
    "    pass\n",
    "\n",
    "\n",
    "# Define MLP with counter\n",
    "class TrainableMLP(brainstate.graph.Node):\n",
    "    def __init__(self, din, dhidden, dout):\n",
    "        super().__init__()\n",
    "        self.count = Count(jnp.array(0))\n",
    "        self.linear1 = Linear(din, dhidden)\n",
    "        self.linear2 = Linear(dhidden, dout)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        self.count.value += 1\n",
    "        x = self.linear1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Create model and split it\n",
    "model = TrainableMLP(din=1, dhidden=32, dout=1)\n",
    "graphdef, params_, counts_ = brainstate.graph.treefy_split(\n",
    "    model, brainstate.ParamState, Count\n",
    ")\n",
    "\n",
    "print(\"Model created and split into:\")\n",
    "print(\"  - graphdef (structure)\")\n",
    "print(\"  - params_ (trainable parameters)\")\n",
    "print(\"  - counts_ (counter state)\")"
   ]
  },
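  {
   "cell_type": "markdown",
   "id": "b92e6f1d4c0a7358",
   "metadata": {},
   "source": [
    "The split above returns the trainable parameters as an ordinary pytree, which is exactly what JAX transformations consume. As a minimal sketch using plain JAX only (no `brainstate`), the snippet below builds a hypothetical nested dict shaped loosely like the two `Linear` layers and applies `jax.jit`, `jax.grad`, and `jax.vmap` to a toy loss. The names `init_params`, `forward`, and `loss_fn` are illustrative and not part of either library, and the actual structure of `params_` may differ.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def init_params(key, din=1, dhidden=32, dout=1):\n",
    "    # Nested dict pytree, loosely mirroring linear1/linear2 with weights w and biases b.\n",
    "    k1, k2 = jax.random.split(key)\n",
    "    return {\n",
    "        'linear1': {'w': jax.random.normal(k1, (din, dhidden)) * 0.1, 'b': jnp.zeros((dhidden,))},\n",
    "        'linear2': {'w': jax.random.normal(k2, (dhidden, dout)) * 0.1, 'b': jnp.zeros((dout,))},\n",
    "    }\n",
    "\n",
    "\n",
    "def forward(params, x):\n",
    "    # x has shape (din,): a single, unbatched example.\n",
    "    h = jax.nn.relu(x @ params['linear1']['w'] + params['linear1']['b'])\n",
    "    return h @ params['linear2']['w'] + params['linear2']['b']\n",
    "\n",
    "\n",
    "def loss_fn(params, xs, ys):\n",
    "    # vmap maps forward over the leading batch axis of xs while sharing params.\n",
    "    preds = jax.vmap(forward, in_axes=(None, 0))(params, xs)\n",
    "    return jnp.mean((preds - ys) ** 2)\n",
    "\n",
    "\n",
    "# grad differentiates with respect to the params pytree; jit compiles the result.\n",
    "grad_fn = jax.jit(jax.grad(loss_fn))\n",
    "\n",
    "params = init_params(jax.random.PRNGKey(0))\n",
    "xs = jnp.linspace(-1.0, 1.0, 8).reshape(8, 1)\n",
    "ys = 0.8 * xs ** 2 + 0.1  # toy targets\n",
    "grads = grad_fn(params, xs, ys)\n",
    "\n",
    "# One plain SGD step, applied leaf by leaf.\n",
    "params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)\n",
    "print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(params))  # True\n",
    "```\n",
    "\n",
    "Because `jax.grad` returns a pytree with the same structure as its input, the update can be applied leaf by leaf with `jax.tree_util.tree_map`. This structural correspondence is what lets transformations operate on the state obtained from a split graph."
   ]
  },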
  {
   "cell_type": "markdown",
   "id": "ad9670aed54d1882",
   "metadata": {},
   "source": [
    "### Creating a Dataset\n",
    "\n",
    "We'll create a simple regression dataset: $y = 0.8x^2 + 0.1 + \\text{noise}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ccfaad421f26c87a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:11.539566Z",
     "start_time": "2025-10-10T15:54:10.938175Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:04.858978Z",
     "iopub.status.busy": "2026-05-30T17:01:04.858720Z",
     "iopub.status.idle": "2026-05-30T17:01:05.597128Z",
     "shell.execute_reply": "2026-05-30T17:01:05.596113Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAArMAAAHaCAYAAAAT9MEUAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAW7NJREFUeJzt3Xl8VPW9//H3JJNZMiQkhKwQCaAWUkVrUIrW6wYXtdDr9WrdCkhd6kJbpXWhLqhUsdVaeyuVuiDtvW1pa63XItoiSntVfiKboIhKIaIGQoCECZPJTDLz/f3BzciQBGaS2U7yej4ePHzkcM6Z75xvRj7zOZ/v59iMMUYAAACABWWlewAAAABATxHMAgAAwLIIZgEAAGBZBLMAAACwLIJZAAAAWBbBLAAAACyLYBYAAACWRTALAAAAyyKYBQAAgGURzAIAAMCyCGYBABkhEAjom9/8po466ijl5+fry1/+slauXJnuYQHIcASzAICM0N7erqqqKr3++utqamrSTTfdpClTpmj//v3pHhqADEYwC2SwRYsWyWazqba2NqXHAung8Xh0991366ijjlJWVpYuvfRSORwOffDBB+keWr+yf/9+ZWVl6ac//Wm6hwLEhGAWiIPNZovpz4oVK9I91LTqCKQ7/rhcLlVUVGjSpEn6z//8TzU3N/f43G+++abuueceNTU1JW7AfWAsPRUIBHTbbbepoqJCbrdb48aN07Jly2I69qOPPtKll16qoUOHKjc3V6NGjdJ9992nlpaWhIzto48+0t69e3X00Ucn5Hzd6c012L9/v+bMmaNzzz1XgwYNks1m06JFi5I63mR79913ZYzRmDFj0j0UICY2Y4xJ9yAAq/jv//7vqJ9//etfa9myZfqv//qvqO0TJ05UaWlpr18vFAqpra1NTqdTNpstZcf21qJFizRjxgzdd999Gj58uNra2rRz506tWLFCy5Yt01FHHaUXXnihR/9YPvzww7rlllu0bds2VVVVJX7wFh1LT1122WV69tlnddNNN+mYY47RokWL9Pbbb+u1117TV77ylW6P++STTzRmzBgNHDhQ1113nQYNGqSVK1dq0aJF+trXvqb/+Z//6dW4/H6/zjzzTJ1//vmaM2dOr851JD29BpJUW1ur4cOH66ijjtKIESO0YsUKPfPMM7ryyiuTOuZkCofDCgaDafl/B9AjBkCP3XjjjSaej9H+/fuTOJrM8cwzzxhJ5u233+70d8uXLzdut9sMGzbMtLS0xH3uhx56yEgy27ZtS8BIeyeTxtITb731lpFkHnroocg2v99vRo4cacaPH3/YY++//34jybz77rtR26dNm2Ykmb179/Z4XMFg0Hz1q181l19+uQmHw3Ede8YZZ5jp06fHvH9vroExxrS2tpodO3YYY4x5++23jSTzzDPPxDXmDvGOHcABlBkASXLPPffIZrNp06ZNuvzyy1VYWKivfOUr+vjjj3XDDTfoC1/4gtxut4qKinTxxRd3Wdt6aN1rxzm3bNmiK6+8UgUFBRo4cKBmzJjR6dZub46VpBUrVmjs2LFyuVwaOXKkfvnLX0bO0Rtnn3227rrrLn388cdRme5Yrss999yjW265RZI0fPjwSBlDbW1tXNe1ublZN910k6qqquR0OlVSUqKJEydq7dq1kX0+++wzffOb31RpaamcTqe++MUvauHChTGNpcPmzZu1ffv2I16T888/v8vMrjFGJ510kk4//fQjnqMnnn32WWVnZ+vaa6+NbHO5XLrqqqu0cuVKffLJJ90e6/V6JanTHYjy8nJlZWXJ4XBEtn322WdyuVz65je/GbXvK6+8opycHN18882RbeFwWFOnTpXNZtOvfvWrpGcGe3MNJMnpdKqsrCypYzycs846S//yL/+itWvX6rzzzlNeXp6GDBmin/3sZ532Xbp0qc4880zl5+eroKBAV1xxhXbv3t1pv4kTJ+q0006L/Lx9+3Zd
ffXVGjFihFwul8rKynT++edry5YtUcetWLFC559/vgoKCjRo0CBNnjxZ//znPxP/poFD2NM9AKCvu/jii3XMMcfogQcekDFGb7/9tt58881IrWFtba0ef/xxnXnmmdq0aZNyc3OPeM6vf/3rGj58uObNm6e1a9fqqaeeUklJiX70ox8l5Nh169bp3HPPVXl5ue69916FQiHdd999Ki4u7tW16DB16lT94Ac/0N/+9jddc801khTTdbnwwgv14Ycf6ne/+51++tOfavDgwZKk4uJivfTSSzFf1+uuu07PPvusZs6cqerqau3Zs0evv/663n//fZ100kmqr6/Xl7/8ZdlsNs2cOTNy/quuukper1c33XTTYcfSYfTo0TrjjDOOWEN98skn66WXXlJjY6MKCwsj2xcvXqx169bp9ddf73RMW1ub9u3bF9P1HjRokLKyOucu1q1bp2OPPVb5+flR20855RRJ0vr161VZWdnlOc8880z96Ec/0lVXXaV7771XRUVFevPNN/X444/rO9/5jjweT2TfIUOG6Oqrr9YTTzyhOXPmaNiwYdq8ebMuvvhinXfeefrJT34S2fdb3/qWduzYob/+9a+y25P/T1RvrkEm2LhxoyoqKjRlyhTNmDFDF1xwgZ588kndfPPNOvvss3X88cdL+rwk5oILLtBDDz2kTz/9VD/96U+1Y8cOvfrqq1Hn3LBhgy688EJJUn19vcaOHauioiJdc801Kikp0SeffKI///nPUfOzaNEiXXXVVZo4caJ++MMfqqWlRT//+c81YcIEbdq0SW63O3UXBf1PulPDgJUdrsxgzpw5RpK57LLLorZ3dWt95cqVRpL59a9/HbW943Z9x23sjnN+85vfjNrv3//9301RUVHCjp0yZYrJzc01n332WWTbRx99ZOx2e0xlFYcrM+gwcOBA86UvfSnyc6zXpbtb+/Fc14EDB5obb7yx27FdddVVpry83OzevTtq+6WXXmoGDhwYea0jlRlIMmeccUa3r9PhhRdeMJLM8uXLI9uCwaAZOXKkmTJlSpfHvPbaa0ZSTH+6G98Xv/hFc/bZZ3fa/t577xlJZsGCBYcd99y5c43b7Y56rTvuuKPLfT/99FPjdDrN9ddfb3bv3m1GjhxpTjzxxKjSm9raWiPJuFwu4/F4In/+8Y9/HHYcB4v3Vn1vr8HBUl1mUFdXZySZ4uJi88knn0S2b9q0yUgyv/rVr4wxxvzjH/8wNpvN3H333VHHz58/30gyq1atimyrr683kswvfvELY4wxDz/8sHG5XGbfvn3djmPjxo3G4XCYuXPnRm1/5513jCTzpz/9Keb3BPQEmVkgya677rqonw/OULS1tcnr9eroo49WQUGB1q5dq6lTp8Z9ztNPP11//vOf5fV6O2WY4j02FArplVde0b//+7+roqIist/RRx+t8847T3/5y1+OOL5YDBgwIKqrQW+vSzzHFxQU6K233lJdXV3Ue5QO3Nr/05/+pK9//esyxkTdhp00aZIWL16stWvXRt2G7Y6JcX3tySefLElau3atzj77bEnSE088oW3btun555/v8pgTTjgh5hX33d0G9/v9cjqdnba7XK7I3x9OVVWV/uVf/kX/8R//oaKiIr344ot64IEHVFZWppkzZ0btO2TIEF1zzTV68skntXbtWvn9fv3973+PyuAOGzYs5msmdZ2dbmtrUyAQ6HT7vLvsdG+vQU8lYuwbN26UJM2ZM0dDhw6NbM/JyZGkSKnH/fffryFDhujuu++OOr7jd3jz5s2R38ENGzZIUiSj29TUpLa2Nq1du1Znnnlml+/lvvvu01FHHaVvfetbUWOvqKhQTk6Otm7depgrAfQewSyQZMOHD4/62e/3a968eXrmmWf02WefRf3jHett46OOOirq545b042NjUcMZo907K5du+T3+7tsh5TIFkn79+9XSUlJ5OfeXpd4jv/xj3+s6dOnq7KyUjU1NTr//PM1bdo0jRgx
Qg0NDWpqatITTzyhJ554osvX2rVrV7xv97DKyso0ZMgQrVu3TpLk8/k0d+5cfeMb39Bxxx3X5TGFhYWaMGFCr17X7XYrEAh02t7a2hr5++4sXrxY1157rT788MNIIHXhhRcqHA7rtttu02WXXaaioqKoY77//e/rscce04YNG/S///u/GjJkSK/G/8Ybb+iss87qtP3NN9/U4sWLo7Z113GiN9egNxIx9o5g9oILLojavnnzZknSF77wBbW0tGj58uWaNWuWsrOzo/bz+XySFFXf3HHOjmB22rRpevrpp3XWWWfppJNO0qWXXqpvfOMbKi8vl3SgrdmLL76olpaWqM/zwfLy8rrcDiQKwSyQZIf+Y/jtb39bzzzzjG666SaNHz9eAwcOlM1m06WXXqpwOBzTOQ/9R6lDLFmt3hybKJ9++qn27dsXFRz39rrEc/zXv/71SEb6b3/7mx566CH96Ec/0nPPPacvfelLkqRvfOMbmj59epevlYz+myeffHIkmH3kkUfU2Nio++67r9v9g8Gg9u7dG9O5i4uLu5z38vJyffbZZ52279ixQ5I6Za0P9otf/EJf+tKXojKCkvS1r31NixYt0rp16zoF2/fff7+kA0/6GjRoUExjP5yustPf+973VFZWFlmc16G77HRvrkFvJGLsGzZsiHwROtg777wju92u6upqbd68We3t7frCF77Q6fiOBVzV1dVR5zzqqKM0cOBASdIxxxyjDz/8UM8995z+8pe/6K677tJ9992nl19+Waeddpq2bt2qlpYWzZ07V1/+8pe7fa9AMhHMAin27LPPavr06VGLXlpbWzOm8X5JSYlcLlenlcqSutzWEx19eSdNmhTZFut16W51e7zXtby8XDfccINuuOEG7dq1SyeddJLuv/9+/f3vf1deXp5CodARM5+JXGl/8skn64UXXtD27dv18MMP6/rrr9ewYcO63f/NN9/sMrPXle4yeyeeeKJee+21TuUpb731VuTvu1NfXx+1WK1DW1ubpAMB68EeeughPfXUU3rsscd0yy236P7779dTTz0V0/i701V2urCwUOXl5TFnrXtzDXojEWPfuHFjl4Hihg0bdOyxx0b1ie2qlOLpp5/WsGHDorL/GzZs6PRlbcCAAZo2bZqmTZumDz74QMcff7z+8Ic/6LTTTouUCo0ePbrXdwqAnqI1F5Bi2dnZnbKgP//5zxUKhdI0omjZ2dmaMGGCnn/+edXV1UW2b9myRS+99FKvz//qq69q7ty5Gj58uK644oqo143lunTUWB4apMZ6fCgU6lR2UFJSooqKCgUCAWVnZ+s//uM/9Kc//Unvvvtup/E3NDQccSwdYm3NJUljx45VOBzW5ZdfLmOM7rjjjsPu35HZi+VPd5m9iy66SKFQKKqcIhAI6JlnntG4ceOiVvG3tLRo8+bNkZrIY489VuvWrdOHH34Ydc7f/e53ysrKigqInn/+ed1+++2aO3eubrzxRl177bX69a9/rW3btsV0bZKpN9cgnUKhkN5///0ug9l33nkncv1Hjhwpu92uN954I2qfP/zhD/rHP/6h2267LRLwhkIhbdq0KXJsV+/T7XYrFApFMtZVVVWy2Wz605/+1Gnf9vZ2NTY29u6NAjEgMwuk2OTJk/Vf//VfGjhwoKqrq7Vy5Uq98sorneoL0+mee+7R3/72N5122mm6/vrrFQqF9Nhjj+m4447T+vXrYz7PSy+9FLnNWV9fr1dffVXLli3TsGHD9MILL0QW2UixX5eamhpJ0h133KFLL71UOTk5mjJlSszHNzc3a+jQobrooot0wgknaMCAAXrllVf09ttvR7K6Dz74oF577TWNGzdO11xzjaqrq7V3716tXbtWr7zySuT2fndj6QhyY23NJR0IZqUDtZT33HPPEdugJaJmdty4cbr44os1e/Zs7dq1S0cffbR+9atfqba2Vk8//XTUvqtWrdJZZ52lOXPmRHrsvvTSSzr99NM1c+ZM
FRUVacmSJXrppZd09dVXR4KdNWvW6IorrtAVV1wRCdBvvfVWLViwICHZ2d7qzTXo8Nhjj6mpqSny5e8vf/mLPv30U0kHyl86btkn0kcffaTW1tZOwazf79eWLVsiJTIej0ff+ta3tGDBAtntdo0ZM0ZvvfWWFi5cqKlTp+r666/vdM6Oetlvf/vbevfdd/W1r31Nw4cP144dO/TEE09o6NChkZZ6JSUluuyyy/Tb3/5WXq9X5513nkKhkLZs2aLnnntOixcvPuJT1IBeS0sPBaCPiKU1V0NDQ9T2xsZGM2PGDDN48GAzYMAAM2nSJLN582YzbNiwTm15umuvdeg5D92vt8cac+BJXV/60peMw+EwI0eONE899ZT53ve+Z1wu1xGvS8c5O/44HA5TVlZmJk6caH72s58Zr9fb6Zh4rsvcuXPNkCFDTFZWVmTssR4fCATMLbfcYk444QSTl5dnPB6POeGEEyKtiDrU19ebG2+80VRWVpqcnBxTVlZmzjnnHPPEE08ccSwdFGNrrg5VVVWmuLjYNDc3x3xMb/n9fvP973/flJWVGafTaU4++WTz8ssvd9qvoxXYnDlzItveeustc95555mysjKTk5Njjj32WHP//febtrY2Y4wxn3zyiSkvLzennXaaaW1tjTrf9ddfb3JycszWrVsT+n568hSt3lwDY4wZNmxY3G3Rejv2P/zhD10+gW3VqlVGklmyZElkW0tLi5k5c6YZPHiwcbvd5oQTTjCPP/54p6erdZzzvffeM8YY8/TTT5uzzz7bFBcXG6fTaY4++mjzne98x+zatSvquNbWVvPAAw+Y6upq43a7TVFRkTn55JPNnDlzjM/ni/n9Az1lMyaFqz4AWNoFF1yg9957Tx999FG6h9LnbN26Vccee6weeeQRfec730n3cADAMqiZBdClQ/trfvTRR5HHYSLxZs+eraqqqk59gAEAh0fNLIAujRgxQldeeaVGjBihjz/+WI8//rgcDoduvfXWdA+tz2hqatJLL72kFStW6I9//KNeeumlqJ6fAIAjI5gF0KVzzz1Xv/vd77Rz5045nU6NHz9eDzzwgI455ph0D63PWL58uS6//HINHTpUv/zlL6NalQEAYkPNLAAAACyLmlkAAABYFsEsAAAALItgFgAAAJbV7xaAhcNh1dXVKS8vL6HPVQcAAEBiGGPU3NysiooKZWUdPvfa74LZurq6qGdtAwAAIDN98sknGjp06GH36XfBbF5enqQDFyc/Pz/prxcOh9XQ0KDi4uIjfrNAZmIOrY85tD7m0NqYP+tL9Rx6vV5VVlZG4rbD6XfBbEdpQX5+fsqC2dbWVuXn5/MBtijm0PqYQ+tjDq2N+bO+dM1hLCWh/EYBAADAsghmAQAAYFkEswAAALAsglkAAABYFsEsAAAALItgFgAAAJZFMAsAAADLIpgFAACAZRHMAgAAwLIIZgEAAGBZ/e5xtgAAAJmm0RdUk79NBe4cFXoc6R6OpRDMAgAApElrW0hLNtRpdW2jWoLtynXYNbaqUJPHVMiVk53u4VkCZQYAAABpsmRDnZZtqleWzaaKAreybDYt21SvJRvq0j00yyCYBQAASINGX1CraxtV5HGqOM8ppz1bxXlOFXmcWlPbqEZfMN1DtASCWQAAgDRo8repJdiufHd01We+2y5fsF1N/rY0jcxaCGYBAADSoMCdo1yHXV5/e9R2r79dHoddBe6cNI3MWghmAQAA0qDQ49DYqkLt8QXU0BxQoD2khuaA9vgCqqkqpKtBjOhmAAAAkCaTx1RIktbUNqquyS+Pw66J1aWR7TgyglkAAIA0ceVk66KaSp0zqpQ+sz1EMAsAAJBmhR4HQWwPUTMLAAAAyyKYBQAAgGWlNZj9xz/+oSlTpqiiokI2m03PP//8EY9ZsWKFTjrpJDmdTh199NFatGhR0scJAACAzJTWYNbn8+mEE07Q/PnzY9p/27Zt+upXv6qzzjpL69ev10033aSr
r75af/3rX5M8UgAAAGtq9AW1bbevzz5RLK0LwM477zydd955Me+/YMECDR8+XD/5yU8kSaNHj9brr7+un/70p5o0aVKXxwQCAQUCgcjPXq9XkhQOhxUOh3sx+tiEw2EZY1LyWkgO5tD6mEPrYw6tjflLj9a2kJZu3KHVHzfKH2yX22HX2GGFOv/4crlysuM6V6rnMJ7XsVQ3g5UrV2rChAlR2yZNmqSbbrqp22PmzZune++9t9P2hoYGtba2JnqInYTDYe3bt0/GGGVlUaJsRcyh9TGH1sccWhvzlx5vbGnQO580aZAzR7n52WoJ+PXOR17ZWvfptKOL4zpXquewubk55n0tFczu3LlTpaWlUdtKS0vl9Xrl9/vldrs7HTN79mzNmjUr8rPX61VlZaWKi4uVn5+f9DGHw2HZbDYVFxfzAbYo5tD6mEPrYw6tjflLvSZfUG/X1yvLOVA5A5xqk5STI4UU0Or6sM46oUAFcbQCS/UculyumPe1VDDbE06nU06ns9P2rKyslH2gbDZbSl8PicccWh9zaH3MobUxf6m1LxBSSzCkigK3ZLNFtue7c1TX5Ne+QEiD8uKbi1TOYTyvYanfqLKyMtXX10dtq6+vV35+fpdZWQAAgP6owJ2jXIddXn971Havv10eh10F7pw0jSzxLBXMjh8/XsuXL4/atmzZMo0fPz5NIwIAAMg8hR6HxlYVao8voIbmgALtITU0B7THF1BNVWGfetpYWoPZ/fv3a/369Vq/fr2kA6231q9fr+3bt0s6UO86bdq0yP7XXXedtm7dqltvvVWbN2/WL37xC/3hD3/QzTffnI7hAwAAZKzJYyo0sbpUxhjVNflljNHE6lJNHlOR7qElVFprZlevXq2zzjor8nPHQq3p06dr0aJF2rFjRySwlaThw4frxRdf1M0336yf/exnGjp0qJ566qlu23IBAAD0V66cbF1UU6lzRpWqyd+mAndOn8rIdkhrMHvmmWfKGNPt33f1dK8zzzxT69atS+KoAAAA+o5Cj6NPBrEdLFUzCwAAAByMYBYAAACWRTALAAAAyyKYBQAAgGURzAIAAMCyCGYBAABgWQSzAAAAsCyCWQAAAFgWwSwAAAAsi2AWAAAAlkUwCwAAAMsimAUAAIBlEcwCAADAsghmAQAAYFkEswAAALAsglkAAABYFsEsAAAALItgFgAAAJZFMAsAAIDDavIFVe9tVZMvmO6hdGJP9wAAAACQmVrbQlqyoU6ra/cqN7RfLRv3aWzVIE0eUyFXTna6hyeJzCwAAAC6sWRDnZZtqleWzaZBHoeybDYt21SvJRvq0j20CIJZAAAAdNLoC2p1baOKPE4VD3AqJytLxQOcKvI4taa2UY0ZUnJAMAsAAIBOmvxtagm2K98dXZWa77bLF2xXk78tTSOLRjALAACATgrcOcp12OX1t0dt9/rb5XHYVeDOSdPIohHMAgAAoJNCj0Njqwq1xxdQw/6A2sJhNewPaI8voJqqQhV6HOkeoiS6GQAAAKAbk8dUSJLW1O7VXl9QJtuhidWlke2ZgGAWAAAAXXLlZOuimkqdfWyxPtmxU5XlZRqU50r3sKIQzAIAAOCwCjwOBfNdKsiQ0oKDUTMLAAAAyyKYBQAAgGVRZgAAAPqtRl9QTf42FbhzEr46P5nnxucIZgEAQL/T2hbSkg11Wl3bqJZgu3Iddo2tKtTkMRVy5WRn7LnRGWUGAACg31myoU7LNtUry2ZTRYFbWTablm2q15INdRl9bnRGMAsAAPqVRl9Qq2sbVeRxqjjPKac9W8V5ThV5nFpT26hGXzAjz42uEcwCAIB+pcnfppZgu/Ld0dWW+W67fMF2NfnbMvLc6BrBLAAA6FcK3DnKddjl9bdHbff62+Vx2FXgzsnIc6NrBLMAAKBfKfQ4NLaqUHt8ATU0BxRoD6mhOaA9voBqqgp71XkgmedG1+hmAAAA+p3JYyokSWtqG1XX5JfHYdfE6tLI9kw9Nzoj
mAUAAP2OKydbF9VU6pxRpTH3go21b2xPzo2eI5gFAACWlIiHEhR6HEc8tqd9Y2M5N3qPYBYAAFhKqh9K0NE3tsjjVEWBW15/u5ZtqpckXVRTmfDXQ3xYAAYAACwllQ8loG9s5iOYBQAAlpHq4JK+sZmPYBYAAFhGqoNL+sZmPoJZAABgGakOLukbm/kIZgEAgGWkI7icPKZCE6tLZYxRXZNfxhj6xmYQuhkAAABLSfVDCegbm9kIZgEAgKWkK7ikb2xmIpgFAACWRHAJiZpZAAAAWBjBLAAAACyLYBYAAACWlfZgdv78+aqqqpLL5dK4ceO0atWqw+7/6KOP6gtf+ILcbrcqKyt18803q7W1NUWjBQAA6BsafUFt2+2LempaV9syXVoXgP3+97/XrFmztGDBAo0bN06PPvqoJk2apA8++EAlJSWd9v/tb3+r22+/XQsXLtSpp56qDz/8UFdeeaVsNpseeeSRNLwDAAAAa2ltC2nJhjqtrm1US7BduQ67TqgcKMmmdz5pimwbW1WoyWMq5MrJTveQDyutmdlHHnlE11xzjWbMmKHq6motWLBAubm5WrhwYZf7v/nmmzrttNN0+eWXq6qqSv/6r/+qyy677IjZXAAAABywZEOdlm2qV5bNpooCt7JsNv3mre36zf/7OGrbsk31WrKhLt3DPaK0ZWaDwaDWrFmj2bNnR7ZlZWVpwoQJWrlyZZfHnHrqqfrv//5vrVq1Sqeccoq2bt2qpUuXaurUqd2+TiAQUCAQiPzs9XolSeFwWOFwOEHvpnvhcFjGmJS8FpKDObQ+5tD6mENrY/4yR5MvqNW1e1Xkcah4wIG2ZgOc2WoNtMtmk/Kc2XJmZ/3f3xmtqd2rs48tVr7bntI5jOd10hbM7t69W6FQSKWlpVHbS0tLtXnz5i6Pufzyy7V792595StfkTFG7e3tuu666/SDH/yg29eZN2+e7r333k7bGxoaUlJrGw6HtW/fPhljlJWV9hJl9ABzaH3MofUxh9bG/GWOem+rckP7NcjjUE6oTZIUbm9XladdkuRsb1Zu1oHwsMIZ1l5fUJ/s2KniAY6UzmFzc3PM+1rqoQkrVqzQAw88oF/84hcaN26ctmzZou9+97uaO3eu7rrrri6PmT17tmbNmhX52ev1qrKyUsXFxcrPz0/6mMPhsGw2m4qLi/kAWxRzaH3MofUxh9bG/CVeky+optY2FbhyVBDHgyMcnqBaNu5Ta8Cm4gFOSVLAHlKtr0U2mzTSnidb9oEa2QZ/QCbbocryMuW77SmdQ5fLFfO+aQtmBw8erOzsbNXX10dtr6+vV1lZWZfH3HXXXZo6daquvvpqSdLxxx8vn8+na6+9VnfccUeXF9fpdMrpdHbanpWVlbIPlM1mS+nrIfGYQ+tjDq2PObQ25i8xulq8Fc9CrUF5Lo2tGqRlm+ol2ZTvtmt/ICSX0y4ZqTkQki3LJq+/XXt8QU2sLtWgPFfkC0mq5jCe10jbb5TD4VBNTY2WL18e2RYOh7V8+XKNHz++y2NaWlo6vbns//v2YIxJ3mABAAAyQFeLt+JdqDV5TIUmVpfKGKO6Jr+MMbpi3FG64svDorZNrC7V5DEVSXw3iZHWMoNZs2Zp+vTpGjt2rE455RQ9+uij8vl8mjFjhiRp2rRpGjJkiObNmydJmjJlih555BF96UtfipQZ3HXXXZoyZUokqAUAAOiLGn1Bra5tVJHHqeK8A3edi/MOxD9raht1zqhSFcZQcuDKydZFNZU6Z1SpmvxtKnDnRI4794tlnbZlurQGs5dccokaGhp09913a+fOnTrxxBP18ssvRxaFbd++PSoTe+edd8pms+nOO+/UZ599puLiYk2ZMkX3339/ut4CAABASjT529QSbFdFgTtqe77brromv5r8bXEFoIUeR6f9u9qW6dK+AGzmzJmaOXNml3+3YsWKqJ/tdrvmzJmjOXPmpGBkAAAAmaPAnaNch11ef3skIytJXn+7PA67
Ctw5aRxd+lCFDQAAYAGFHofGVhVqjy+ghuaAAu0hNTQHtMcXUE1VoeUyqomS9swsAAAAYtOxIGtNbaPqmvzyOOyWWaiVLASzAAAAFnG4xVv9FcEsAACAxVhxoVayEMwCAACkUKMvSFY1gQhmAQAAUqC3T+9C1+hmAAAAkAKJeHoXOiOYBQAASLJDn97ltGerOM+pIo9Ta2ob1egLJuQ1tu32JeRcVkKZAQAAQJIl+uldB+vv5QtkZgEAAJLs4Kd3HSwRT+/q7+ULBLMAAABJlqynd6WifCHTEcwCAACkwOQxFZpYXSpjjOqa/DLG9PrpXR3lC/nu6MrRfLddvmC7mvxtvR12xqNmFgAAIAWS8fSug8sXivM+r49NRPmCVZCZBQAAfV68K/3j2T/ecxd6HBo+2JOQByYkq3zBSsjMAgCAPivelf7x7J8pXQQ6yhTW1Daqrskvj8Pe6/IFKyGYBQAAfVbHSv8ij1MVBW55/e1atqleknRRTWWv9o/33MmSjPIFK6HMAAAA9EnxrvSPZ/9UdRGIp4QhkeULVkJmFgAA9EnxPqggnv2T+RAEKXNKGKyAzCwAAOiT4n1QQTz7J/MhCBIPQogHwSwAAEiaRl9QtXt82t+a+n6n8a70j2f/ZHYR4EEI8aHMAAAAJNzBt8n9wTZVONs0crfR5BOGZPRK/3j2T1YXgWSXMPQ1BLMAACDhDl7pXz7QLVtrm155v16y2TJ6pX88+yeriwAPQogPwSwAAEioQ2+TyxjlunM0SDlaU9uoc0aVpjyzWOhxxPWa8ewf77ljOd/YqsJIm69894HAdo8voInVqb92mY6aWQAAkFAdt8nz3dE5s3yXXb5gu5r8qa+ftZrJYyo0sbpUxhjVNflljOlXD0KIB5lZAACQUN3eJm/lNnms+vuDEOJBZhYAACRUVyv99/nbtDcBK/37m/76IIR4kJkFAAAJd/BK/x37/KpwShNGc5sciUcwCwAAEu7g2+SNLQGFfU0acdQQZWVxUxiJRTALAACSptDj0EC3XbtCvnQPBX0UX48AAABgWQSzAAAAsCyCWQAAAFgWNbMAACAhGn3BXvdETcQ50L8QzAIAgLgcGnC2toW0ZEOdVtc2qiXYrlyHXWOrCjV5TIVcOdlHPqGUkHN0Nz70bQSzAAAgJt0FnG0hoxUf7FKRx6mKAre8/nYt21QvSbqopjKmcy/ZUKdlm+p7dY5EBsSwDmpmAQBAlxp9QW3b7VOjLyjp84Azy2ZTRYFbWTabXtywQ8+v+1RFHqeK85xy2rNVnOdUkcepNbWNkWOP9Dqraxt7dY7uxrdsU72WbKjr1XVAZiMzCwAAonSV4Rxdnqd3P/NGAk5JKs7Llre1Tf9s2K/R5QOjzpHvtquuya8mf5sGug8fbjT529QSbFdFgbvbcxypXODQgLhjfNKBp5CdM6qUkoM+iswsAACI0l2G86Ndzco/JDAdPOBAgLhnfyBqu9ffLo/DrgJ3zhFfr8Cdo1yHXV5/e4/P0REQHzq+fLddvmC7mvxtRzwHrIlgFgAARHR3y780zyWvv127m6OD1mC7UXm+W/sD7WpoDijQHlJDc0B7fAHVVBXGlA0t9Dg0tqpQe3yBHp8jEQExrIlgFgAARHSX4SzOdyrfbddOb+eA84KTKvTVMeUyxqiuyS9jjCZWl2rymIqYX3fymApNrC7t8TkSERDDmqiZBQAAEQdnODtqTqUDGc5jSgbo+CEFen+HV3VNfnkc9kjA6crJ1jmjSnvcEsuVk62Laip7dY6OwHdNbWOn8aHvIpgFAAARHRnOjrZY+e4Dge0eX0ATq0t1UU1lt31cCz2OXmdAe3OORATEsB6CWQAAEOVIGc5EBK3JlOnjQ2IRzAIAgChkOGElBLMAAKBLZDhhBQSzAAAkUHf1pACSg2AWAIAE6OqpWWOrCiMr/QEkB31mAQBIgO6emrVkQ126hwb0aQSzAAD0UndPzSryOLWmtlGNvmC6hwj0WQSz
AAD0UndPzcp32+ULtqvJ35amkQF9H8EsAAC9dPBTsw7m9bfL47CrwJ2TppEBfR/BLAAAvdTx1Kw9voAamgMKtIfU0BzQHl9ANVWFdDUAkijtwez8+fNVVVUll8ulcePGadWqVYfdv6mpSTfeeKPKy8vldDp17LHHaunSpSkaLQAAXZs8pkITq0tljFFdk1/GmKinZgFIjrS25vr973+vWbNmacGCBRo3bpweffRRTZo0SR988IFKSko67R8MBjVx4kSVlJTo2Wef1ZAhQ/Txxx+roKAg9YMHAOAgPDUreejdi8NJazD7yCOP6JprrtGMGTMkSQsWLNCLL76ohQsX6vbbb++0/8KFC7V37169+eabysk5UH9UVVWVyiEDAHBYmf7ULCsFhvTuRSzSFswGg0GtWbNGs2fPjmzLysrShAkTtHLlyi6PeeGFFzR+/HjdeOON+p//+R8VFxfr8ssv12233abs7K5/qQOBgAKBQORnr9crSQqHwwqHwwl8R10Lh8MyxqTktZAczKH1MYfWxxz2XmtbSEs37tDqjxvlD7bL7bBr7LBCnX98edIDw57O35J3PtMr79drkMepioEueVvb9cqmnZIxuvCkoUkaLbqS6s9gPK+TtmB29+7dCoVCKi0tjdpeWlqqzZs3d3nM1q1b9eqrr+qKK67Q0qVLtWXLFt1www1qa2vTnDlzujxm3rx5uvfeezttb2hoUGtra+/fyBGEw2Ht27dPxhhlZaW9RBk9wBxaH3Nofcxh772xpUHvfNKkQc4c5eZnqyXg1zsfeWVr3afTji5O6mv3ZP72t7bpn9vrNMIjDXS3SWpToVsqUEj/3F6nrYNtGuCiS0SqpPoz2NzcHPO+lnqcbTgcVklJiZ544gllZ2erpqZGn332mR566KFug9nZs2dr1qxZkZ+9Xq8qKytVXFys/Pz8lIzZZrOpuLiY/wFbFHNofcyh9TGHvdPkC+rt+nplOQcqZ4BTbZJycqSQAlpdH9ZZJxSoIIklBz2Zv5Y9PtUFdql8oFstB919DbtC2rHPryxPgUqKPMkaMg6R6s+gy+WKed+0BbODBw9Wdna26uvro7bX19errKysy2PKy8uVk5MTVVIwevRo7dy5U8FgUA5H5w+i0+mU0+nstD0rKytl/0O02WwpfT0kHnNofcyh9TGHPbcvEFJLMKSKArdks0W257tzVNfk175ASIPykntd452/wlyn3I4ceVtDKs77PFzxtoaU68hRYa6T34UUS+VnMJ7XSNtvgcPhUE1NjZYvXx7ZFg6HtXz5co0fP77LY0477TRt2bIlqo7iww8/VHl5eZeBLAAAyMyHOjT6gtq229fto37p3YtYpbXMYNasWZo+fbrGjh2rU045RY8++qh8Pl+ku8G0adM0ZMgQzZs3T5J0/fXX67HHHtN3v/tdffvb39ZHH32kBx54QN/5znfS+TYAAMhoHYHhsk0H7obmuw8Etnt8AU2sLk1pYBhPh4KOHr1rahtV1+SXx2Gndy86SWswe8kll6ihoUF33323du7cqRNPPFEvv/xyZFHY9u3bo9LMlZWV+utf/6qbb75ZY8aM0ZAhQ/Td735Xt912W7reAgD0OVZq3YTYZUpguGRDnZZtqleRx6mKAre8/vZIkH1RTWXUvvTuRSxsxhiT7kGkktfr1cCBA7Vv376ULQDbtWuXSkpKqO2xKObQ+pjD2GRyT0/mMHHS8WWlY/4cngL9+G8fKstmU3He5+tZGpoDMsbo1nNHEaxmqFR/BuOJ1/g/AgBA0ucZsyybTRUFbmXZbFq2qV5LNtSle2hIoEKPQ8MHe9ISNDa1tqkl2K58d/SN4Xy3Xb5gu5r8bSkfE6yPYBYAoEZfUKtrG1Xkcao4zymnPVvFeU4VeZxaU9vY7SKd/uxIC5jQWYEr8xaiwfos1WcWAJAcTf4DGbOKAnfU9ny3XXVNfjX527j9+38yuRwj0xVk0EI09B1kZgEAGdm6KVNRjtE7k8dU
aGJ1qYwxqmvyyxhDhwL0CplZAEBGtW7KZIeWY0hScd6BbOya2kadM4prdSR0KECikZkFAEgiYxaLjnIMFjD1XjoXoqFvITMLAJBExiwWB5djdGRkJcoxgHQiMwsAiELGrHs8YhXIPGRmAQCIQ6Y8SQvAAQSzAADEgXIMILMQzAIA0AOFHgdBLJABqJkFAACAZRHMAgAAwLIoMwAAwCIafUHqdIFDEMwCAJDhWttCWrKhTqtrG9USbFeuw66xVYWaPKZCrpzsI58A6MMoMwAAIMMt2VCnZZvqlWWzqaLArSybTcs21WvJhrp0Dw1IO4JZAABSoNEX1LbdPjX6gnEft7q2UUUep4rznHLas1Wc51SRx6k1tY1xnw/oaygzAAAgiXpbItDkb1NLsF0VBe6o7fluu+qa/Gryt1E/i36NzCwAAEnU2xKBAneOch12ef3tUdu9/nZ5HHYVuHOSMWzAMghmAQBp19Nb8JkuESUChR6HxlYVao8voIbmgALtITU0B7THF1BNVSFZWfR7lBkAANKmr6/ST1SJwOQxFZKkNbWNqmvyy+Owa2J1aWQ70J8RzAJAAtD/s2c6bsEXeZyqKHDL62/Xsk31kqSLairTPLreO7hEoDjv8+A83hIBV062Lqqp1DmjSvk9Aw5BMAsAvdDXM4vJdOgteEmRgG9NbaPOGVVq+YCto0SgI0DPdx8IbPf4AppYHf/7K/Q4LH9NgESjZhYAeoH+nz3XcQs+3x2dV8l32+ULtqvJ35amkSXW5DEVmlhdKmOM6pr8MsZQIgAkEJlZAOih/pBZTKZE3YLPdJQIAMlFZhYAeqi/ZBaTpb+t0i/0ODR8sKfPvS8g3WIOZuvquGUGAAej/2fvcQseQG/FXGbwxS9+UfPnz9fll1+ezPEAgGUkenFPf8QteAC9FXNm9v7779e3vvUtXXzxxdq7d28yxwQAlkFmMTG4BQ+gp2LOzN5www0677zzdNVVV6m6ulpPPvmkpkyZksyxAUDGI7MIAOkVVzeD4cOH69VXX9Vjjz2mCy+8UKNHj5bdHn2KtWvXJnSAAGAF9P8EgPSIuzXXxx9/rOeee06FhYX6t3/7t07BLAAAAJAqcUWiTz75pL73ve9pwoQJeu+991RcXJyscQEA0CUeHQzgYDEHs+eee65WrVqlxx57TNOmTUvmmAAA6CTYHtZzaz/V6o+beHQwgIiYg9lQKKQNGzZo6NChyRwPAABdert2j17Z1qpBHpcqCtzy+tsjbdEuqqlM8+gApEvMweyyZcuSOQ4AALrV5Atqyy6fBnlyeXQwgCg8zhYAkPGaWtsUaA8p38WjgwFEI5gFAGS8AleOnPZseVt5dDCAaASzAICMV+Bx6OgSj/b6AmpoDijQHlJDc0B7fAHVVBX2uRKDRl9Q23b71OgLpnsoQMajSSwAwBJOriqScbVrzcdNqmvyy+Ow97lHB7e2hbRkQ51W1zbSsQGIEcEsAMASHPYsXXjSUJ0zuqzP9pldsqFOyzbVq8jjpGMDECPKDAAAllLocWj4YE+fC2QbfUGtrm1Ukcep4jynnPZsFec5VeRxak1tIyUHQDcIZgEA/U4m1qQ2+dvUEmxXvpuODUA8KDMAAPQbmVyTWuDOUa7DLq+/PdJDV6JjA3AkZGYBAP1GR01qls2migK3smw2LdtUryUb6tI9NBV6HBpbVag9/aRjA5AoBLMAgH7BCjWpk8dUaGJ1qYwxqmvyyxjT5zo2AIlGmQEAoF/oqEmtKHBHbc9321XX5FeTvy3t2U9XTrYuqqnUOaNKU9qxodEX7LMdItD3EcwCAPoFK9WkFnocKQkqM7mGGIgVZQYAgH6BmtTOMrmGGIgVwSwAIGXS3RKLmtTPWaGGGIgFZQYAgKTLlNvZ6apJzURWqCEGYkFmFgCQdJl2O7uvPkUsHgfXEB8sE2uIgcPJiGB2/vz5qqqqksvl0rhx47Rq1aqYjlu8eLFsNpsuuOCC5A4QABIo3bfaU43b
2elxpN8zaojRV6S9zOD3v/+9Zs2apQULFmjcuHF69NFHNWnSJH3wwQcqKSnp9rja2lp9//vf1+mnn57C0QJAz2XKrfZU43Z2asXze9ZRK7ymtlF1TX55HPZ+W0MM60p7ZvaRRx7RNddcoxkzZqi6uloLFixQbm6uFi5c2O0xoVBIV1xxhe69916NGDEihaMFgJ7LtFvt8eppRpnb2d1LRpY+nt+zjhriW88dpZsnfkG3njtKF9VU9ukvV+h70pqZDQaDWrNmjWbPnh3ZlpWVpQkTJmjlypXdHnffffeppKREV111lf73f//3sK8RCAQUCAQiP3u9XklSOBxWOBzu5Ts4snA4LGNMSl4LycEcWl8mzGGTL6jVtXtV5HGoeMCBLOSB/xqtqd2rs48tVkGGZidb20JaunGHVn/cKH+wXW6HXWOHFer848tjCnoGuu0aO6xAr7xfL8ko32WXt7Vde30BTRhdqoFu+xHnJt1z2OQLqqm1TQWunITMU2+v6eHG2ZPfs4Fuuwa6D4QEybjG6Z4/9F6q5zCe10lrMLt7926FQiGVlpZGbS8tLdXmzZu7POb111/X008/rfXr18f0GvPmzdO9997baXtDQ4NaW1vjHnO8wuGw9u3bJ2OMsrLSnghHDzCH1pcJc1jvbVVuaL8GeRzKCbVFtlc4w9rrC+qTHTsVzHelZWxH8saWBr3zSZMGOXOUm5+tloBf73zkla11n047ujimc5xSZpet1aUtu3wK7A9poD1bNcM9OrnMrl27dh3x+HTNYbA9rLdr9xwYd3tITnu2ji7x6OSqIjnsPR9HIq5pVzL19ywTPoPonVTPYXNzc8z7pr1mNh7Nzc2aOnWqnnzySQ0ePDimY2bPnq1Zs2ZFfvZ6vaqsrFRxcbHy8/OTNdSIcDgsm82m4uJiPsAWxRxaXybMocMTVMvGfWoN2FQ8wBnZ3uAPyGQ7VFlelpGZ2SZfUG/X1yvLOVA5A5xqk5STI4UU0Or6sM46oSDmcQ+tKOtxhjNdc/jc2k/1yrZWDfLkKn+AXfta27VsW6uMq10XnjS0R+dM5DU9VKb+nmXCZxC9k+o5dLli/9KV1mB28ODBys7OVn19fdT2+vp6lZWVddr/n//8p2prazVlypTIto40tN1u1wcffKCRI0dGHeN0OuV0OnWorKyslH2gbDZbSl8PicccWl9P5jCRz6sflOfS2KpBWrapXpJN+e4DNaR7fEFNrC7VoLzMzMruC4TUEgwdWLxls0W257tzVNfk175ASIPyYr+mg/JcPX6vqf4cNvqCWv1xkwZ5XCrOO/DvSHGOXZJNaz5u0jmjy3r0e5Hoa3qwTP494/+j1pfKOYznNdIazDocDtXU1Gj58uWR9lrhcFjLly/XzJkzO+0/atQobdy4MWrbnXfeqebmZv3sZz9TZWVlKoYNoI9LVtcBK64cP3jxVnHe5+89FYu3Dv4y0VHPmUrJ6sKQ7Gtqxd8zoDfSXmYwa9YsTZ8+XWPHjtUpp5yiRx99VD6fTzNmzJAkTZs2TUOGDNG8efPkcrl03HHHRR1fUFAgSZ22A0BPdawGL/I4VVHgltff/n+ZLumimp5/abbi06c6epF2vP/PM30BTawu7Xb8vclqd/llYliBTilL7T9ZyQo6e3pNY2XF3zOgN9IezF5yySVqaGjQ3XffrZ07d+rEE0/Uyy+/HFkUtn37dm5JAEiZQxv8S4oEMmtqG3XOqN4HG4UeR6/PkcgSiCOJJ9OXiKx2V18mXnm/XrZWl4ZWdC5BS5ZkBp2pyJ4m4vcMsIK0B7OSNHPmzC7LCiRpxYoVhz120aJFiR8QgH4r0xv8p+PBC/Fk+nqb1e7+y4TRll0+NfmCKa37TFbQSfYUSJyMCGYBIFOks0Y0FskqgYjFkTJ9ichqd/tlwmVXYH9ITa1tKQ1mkx10kj0Feo/79wBwkEx+Xv2hwaLTnq3iPKeKPE6tqW1M6FOkeqIjEM0/ZLFWvtsuX7BdTf62
bo78XLdPC2ttl9OerQJXer5MFHocGj7YQ+AJZCCCWQA4xOQxFZpYXSpjjOqa/DLGZMRq8FiCxWQ8HjVWiXhsbXdfJvb6Ajq6xJORvXgBpBdlBgBwiEytZzxcCYTTnqX//bBBm3Z4U1ZLe6hELZjqqk51wuhSnZzibgYArIH/MwBAN1Jdz3ikDgWHCxYLc3P0xj93p6WW9mCJWDDV1ZeJge7YHnsLoP8hmAWANIunQ0FXweKpI4v07mfepLYTi1Uis9oHf5noeNojAByKYBYA0iyeDgVdBYtN/jat+bhRRQOig8Z0thNjlT6AVGEBGACkUU87FBy8uj4RC68AwKoIZgEgjRLRziqT24lZTTq7QQDoGcoMACCNEvWQhlQ8HrUvS8eT1QAkBsEs0A8dadU8UidR7awytZ2YVaTzyWq9xecZ/R3BLNCPkH3KTInMqrLwKn6JeAxvOvB5Bg4gmAX6EStnn/oysqrJE0vWsqNuuaLAHbU9nd0gYsHnGTiAYBboJ6yafepPyKomTjxZy0TVLacSn2fgc3QzAPqJRKyaB6yiI2uZZbOposCtLJtNyzbVa8mGuk77WrEbBJ9n4HMEs0A/YfVepLRMQqx60rt38pgKTawulTFGdU1+GWMyuhuE1T/PQCJRZgAkQSauLk7UqvlUY5EL4tWTGlir1S1b9fMMJAPBLJBAmR54WbEXKYtcEK/e1MBaqW7Zip9nIBkIZoEEyvTAy2rZJxa5oCf6S9bSap9nIFmomUWXqE+MX0/q9NKl0OPQ8MGejP+Hj0Uu6Cmr1cD2hlU+z0CykJlFlEy/TZ7JrNqrMpNZsWUSMgNZS6D/IDOLKPG0s0E0VhcnnhVbJiF9urqjRNYS6PvIzCKC+sTe6S91eqnGIhccCXeUgP6NYBYR3CbvPQKvxON2MY4k0xdeAkgugllEUJ/YewReyWOllklIHe4oAaBmFhHUJyYOdXpIJLqLdI+OFwDIzCIKt8mBzEEt6JFxRwkAwSyicJscVpOJjw5OFGpBj4yFlwAIZtEl6hOR6fp61pJa0NhxRwno3whmAVhSX89a0l0kdtxRAvo3FoABsBwrPTq4p3gIR/xYeAn0TwSzAHoknSvs+8MKdrqLAEBsKDMAEJdMqFXtLyvYqQUFgCMjmAUQl0yoVe0vK9itUgvalztKAMh8BLMAYpZJK+z7U9YyU7uLZEKWHgAIZgHELJNW2Fsla9mXZUKWHgBYAAYgZpm4wp4V7OnRHzpKALAGglkgQ6WzW0B3WGGPDv2howQAa6DMAMgwmV6H2J9qVfuDni7e6i8dJQBkPoJZIMNkeh0itarWdGjQ2tsvTf2lowSAzEcwC2SQTOoWcCSZusIe0boLWttCRis+2NWrL01k6QFkAoJZIINkUrcA9A1dZfpf3LBDLcF2jSzO69WXJrL0ADIBC8CADJKJ3QJgXd11HPA47dqxr1UOe/Q/AT1dvEVHCQDpRDALZBC6BSCRuus4MHjAgd+jPfsDUdv50gTAighmgQwzeUyFJlaXyhijuia/jDHUIaJHusv0B9uNyvPd2h9o50sTAMujZhbIMNQhIlEO13HggpMqlJOdxeItAJZHMAtkKLoFIBEO13HAlZPNlyYAlkcwCwB92JEy/XxpAmB1BLMA0A8QtALoq1gABgAAAMsimAUAAIBlEcwCAADAsghmAQAAYFkZEczOnz9fVVVVcrlcGjdunFatWtXtvk8++aROP/10FRYWqrCwUBMmTDjs/gAAAOi70h7M/v73v9esWbM0Z84crV27VieccIImTZqkXbt2dbn/ihUrdNlll+m1117TypUrVVlZqX/913/VZ599luKRAwAAIN1sxhiTzgGMGzdOJ598sh577DFJUjgcVmVlpb797W/r9ttvP+LxoVBIhYWFeuyxxzRt2rROfx8IBBQIfP78ca/Xq8rKSjU2Nio/Pz9xb6Qb4XBYDQ0NKi4uVlZW2r87oAeYQ+tjDq2PObQ25s/6Uj2HXq9X
hYWF2rdv3xHjtbT2mQ0Gg1qzZo1mz54d2ZaVlaUJEyZo5cqVMZ2jpaVFbW1tGjRoUJd/P2/ePN17772dtjc0NKi1tbVnA49DOBzWvn37ZIzhA2xRzKH1MYfWxxxaG/Nnfamew+bm5pj3TWswu3v3boVCIZWWlkZtLy0t1ebNm2M6x2233aaKigpNmDChy7+fPXu2Zs2aFfm5IzNbXFycssyszWbj26iFMYfWxxxaH3Nobcyf9aV6Dl0uV8z7WvoJYA8++KAWL16sFStWdPumnU6nnE5np+1ZWVkp+0DZbLaUvl4yNfqC/fI57n1pDvsr5tD6mENrY/6sL5VzGM9rpDWYHTx4sLKzs1VfXx+1vb6+XmVlZYc99uGHH9aDDz6oV155RWPGjEnmMCGptS2kJRvqtLq2US3BduU67BpbVajJYyrkyslO9/ASqr8G7AAAWFFag1mHw6GamhotX75cF1xwgaQDaezly5dr5syZ3R734x//WPfff7/++te/auzYsSkabf+2ZEOdlm2qV5HHqYoCt7z+di3bdOBLyEU1lWkeXWJ0F7Cff9zhv1gBAID0SXuuf9asWXryySf1q1/9Su+//76uv/56+Xw+zZgxQ5I0bdq0qAViP/rRj3TXXXdp4cKFqqqq0s6dO7Vz507t378/XW+hz2v0BbW6tlFFHqeK85xy2rNVnOdUkcepNbWNavQF0z3EhOgI2LNsNlUUuJVls2nZpnot3bgj3UMDAADdSHvN7CWXXKKGhgbdfffd2rlzp0488US9/PLLkUVh27dvj6qbePzxxxUMBnXRRRdFnWfOnDm65557Ujl0y+np7fMmf5tagu2qKHBHbc9321XX5FeTv83yt+MPDdglqTjvQPnEmo8bdeLgQSpJ5wABAECX0h7MStLMmTO7LStYsWJF1M+1tbXJH1Af09t61wJ3jnIddnn97ZEAT5K8/nZ5HHYVuHOSOfykOTi4P1zAvqOpRb5gKE2jBAAAh5MRwWxf1uQLqt7bKocnqEF5sbeZSKTe1rsWehwaW1UYOSbffSCw3eMLaGJ1qeWysl0F96PL8+S0Z3cZsOc67PI4+tYiNwAA+gqC2ST5PGDaq9zQfrVs3KexVYNSvvr/sLfPaxt1zqjYgtHJYyoix9Q1+eVx2DWxujSyPZMdWl7RVXD/5j/3qDA3R3t8B54WFxWwjy7RAFfyss90TwAAoOcIZpPk84DJoUEeh1oDtrSs/k9UvasrJ1sX1VTqnFGlSQm8khHQdZeBffczb5fBfVsopNNGDtb7O7xRAfv5x5XJ27gnIWM60vj6arszAACShWA2CaKyoQMcygm1qXiAU5ItrmxoIiS63rXQ40jo2JMZ0HVXXrE/0K4vjyiK2vdAcN+m048t1pQTKqIC63A4LG+vRhLf+KS+0+4MAIBkS3trrr6oIxua747+rpDvtssXbFeTvy1lY+mod93jC6ihOaBAe0gNzQHt8QVUU1WY9tva3bXDWrKhrlfn7a6dWGmeS15/u3Y3B6L2Pzi4L/Q4NHywJ6nXpr+0O+uJRl9Q23b7+vU1AADEjsxsEkRlQwd8HhCla/V/pta7JqqetyvdlVcU5zuV77Zrpzcghz07bYvZ+kO7s3hRdgEA6AmC2SSIXv1vVOEMq8Ef0B5fMC2r/5Nd79pTyQjoOmpvZUy35RXHlAzQ8UMKOtXGpjK476vtznqDsgsAQE8QzCbJ59nQvdrrC8pkO9KeDU10vWtvJTKg6yqrJxntau6iO0F1qS6qqUxrF4G+1u6st5KZpQcA9G0Es0nSkQ09+9hifbJjpyrLy9LWZzZTJTKg6yqrt6s5oMEDHDLGdJmBTXdwn6nlH6nU8YViX0uQsgsAQI8QzCZZgcehYL5LBfxD3KVEBHSHy+oZY3Ttv4yQbLaMKa/okKnlH6lwaCY9O8umem+r3I5sleV/HtD257ILAEBsCGb7ICs14U9EQHek2lvZbBo+2JPIYSdUujPE6dBVJt0XDOm9
z7zKtmX1+7ILAEDsCGb7ECuvBu9NQNefFlNZ6YtKd7rLpFeX5+vj3S1qbQtpf6CtX5ZdAADiRzDbh/TX1eD9YTGVVb6oxBJsd5dJH+RxqLUtpGmnVmmgO8fSATsAIHUIZvuI/r4avK8vpsr0LyrxBNtHyqQPG5SbtN/VJl9Q+wIhAmUA6EMIZvuI/t6Evy8vporli4qktL7veILtdGTSW9tCemNLg96ur1dLMJSxmW0AQPwIZvuI/lQ3ejh9cTHV4b6ofLK3Rb9/e7tq97SkrfygJ3cFUp1JX7pxh975pElZzoEZmdkGAPQcwWyG6ekCn/5QN9pfHe6Lyu7moN7atlflA91pC9J6clcglZn0Rl9Qqz9u1CBnjnIGOCWbrVclOH1hER4A9CUEsxkiEQt8+nrdaH/V3ReVHfv8ks2ofKA7rXXSvbkrkIpMepO/Tf5gu3Lzs9V20PZ4S3CssggPAPobgtkMkYgFPn25brS/6+qLyinDB+ndz/Yp3x39MU51nXQm3hU4OHta4M6R22FXS8CvnIPi6nhLcDJ9ER4A9FcEsxkg0Z0I+mLdaH/X1RcVSfrRns0ZUSedKXcFusuenjC0QJu2ehVSQPnunLiD7f7eLQQAMhnBbAbo750IELtDv6hkSkY0U+4KdJc9PfPYwTqxskCr68M9Crb5jAJA5iKYzQB0IkBPZUpGtEM67wocLnu64dN9uqpmkM46obhHfWb5jAJA5iKYzQCZWHMIa8iUjGgmOFz2dEdTi3zBkEZ4HBqUlxX3ufmMAkDmIpjNEJmWYYO1UCd9+OxprsMuj6N3HQf4jAJAZiKYzRBk2DLb/tY2tezxqTDX2e/mxSp9VQ+bPR1dogGu3pUC8BkFgMxEMJthyLBllta2kJa885n+ub1OdYFdcjty+k1vUSv2Ve0ue3r+cWXyNu5JyGvwGQWAzEIwa2FWyZhZ2ZINdXrl/XqN8EjlA93ytob6TW9RK/ZV7S57Gg6H5U334AAASUEwa0FWzJhZUcfq+EEepwa629SSna3ivAMfmb7eW9TqfVXJngJA/xH/sl6kXUfGLMtmU0WBW1k2m5ZtqteSDXXpHlpMGn1BbdvtU6MvmO6hHFbH6vh8V+cnbPmC7Wryt3VzpPVF3nsXTxfr6+8dAGAtZGYtxsoZM6tllCOr41vbVXhQt6f+0FuUvqoAAKsgM2sxVs6YWS2j3LE6fq8voH3+NgXaQ2poDmiPL6CaqsKM/dKQCB3vfY8voIbmQL967wAAayGYtZiDM2YHy/SM2aEZZac9W8V5ThV5nFpT25ixJQeTx1RowuhSGUk79vlljOk3vUUnj6nQxOpSGWNU19S/3nsiWaWsBgCsijIDi7Hqk4is+mx7V062LjxpqLYOtinLU9Cpz2xf7ihBX9XesVpZDQBYFcGsBVnxSURWr8Ec4MpRSZFHWVkHbmb0p0CFzgA9Y8XWZgBgRQSzFmTFjJlVM8rdIVDB4Vh5oSYAWA01sxZW6HFo+GCPZf5R7Cs1mFat/0XqWHmhJgBYDZlZpIwVM8pdsWr9L1LH6mU1AGAlZGbTpD+vcLZaRvlQVu0ogdShtRkApA6Z2RTrTwuH+qq+Vv+L5LDiQk0AsCKC2RRj4VDfQKDStySjxVpfKasBgExHMJtCrHDuOwhU+oZU3CmhtRkAJBc1synECue+x+r1v/2d1R6xDADojGA2hVg4BGQOWqwBQN9AMJtCrHCOX3/u+oDk4k4JAPQN1MymGAuHYkPXByQbvWABoG8gmE2xTFs4lIxV3IlA1wckGy3WAKBvIJhNk3SvcM7kzCddH5Aq3CkBAOsjmO2nMjnzyeNikSqZdqcEABA/FoD1Q5m+ipuuD0g1WqwBgHURzPZDmb6Km64PAAAgVgSz/ZAVMp+Tx1RoYnWpjDGqa/LLGEMtIwAA6ISa2X7ICqu4qWUEAACxyIjM7Pz581VVVSWXy6Vx48Zp1apVh93/j3/8o0aNGiWXy6Xjjz9eS5cuTdFI
+w6rZD67q2XkYQoAAEDKgMzs73//e82aNUsLFizQuHHj9Oijj2rSpEn64IMPVFJS0mn/N998U5dddpnmzZunyZMn67e//a0uuOACrV27Vscdd1wa3oE1WTXzmcktxQAAQOqlPTP7yCOP6JprrtGMGTNUXV2tBQsWKDc3VwsXLuxy/5/97Gc699xzdcstt2j06NGaO3euTjrpJD322GMpHnnfYLVV3B0txbJsNlUUuJVls2nZpnot2VCX7qEBAIA0SGtmNhgMas2aNZo9e3ZkW1ZWliZMmKCVK1d2eczKlSs1a9asqG2TJk3S888/3+X+gUBAgUAg8rPX65UkhcNhhcPhXr6DIwuHwzLGpOS1+romX1Cra/eqyONQ8YADwfeB/xqtqd2rs48tVkESgnLm0PqYQ+tjDq2N+bO+VM9hPK+T1mB29+7dCoVCKi0tjdpeWlqqzZs3d3nMzp07u9x/586dXe4/b9483XvvvZ22NzQ0qLW1tYcjj104HNa+fftkjFFWVtoT4ZZW721Vbmi/Bnkcygl93j6swhnWXl9Qn+zYqWC+K+GvyxxaH3NofcyhtTF/1pfqOWxubo5537TXzCbb7NmzozK5Xq9XlZWVKi4uVn5+ftJfPxwOy2azqbi4mA9wLzk8QbVs3KfWgE3FA5yR7Q3+gEy2Q5XlZUnLzDKH1sYcWh9zaG3Mn/Wleg5drtiTU2kNZgcPHqzs7GzV19dHba+vr1dZWVmXx5SVlcW1v9PplNPp7LQ9KysrZR8om82W0tfrqwbluTS2atD/tRSzHdRSLKiJ1aUalJf4rGwH5tD6mEPrYw6tjfmzvlTOYTyvkdbfKIfDoZqaGi1fvjyyLRwOa/ny5Ro/fnyXx4wfPz5qf0latmxZt/sjfZLRPssqLcUAAEBqpL3MYNasWZo+fbrGjh2rU045RY8++qh8Pp9mzJghSZo2bZqGDBmiefPmSZK++93v6owzztBPfvITffWrX9XixYu1evVqPfHEE+l8GzhIMttnWbWlGAAASI60B7OXXHKJGhoadPfdd2vnzp068cQT9fLLL0cWeW3fvj0q1Xzqqafqt7/9re6880794Ac/0DHHHKPnn3+eHrMZpKN9VpHHqYoCt7z+9sjTxi6qqUzIaxR6HASxAABANmOMSfcgUsnr9WrgwIHat29fyhaA7dq1SyUlJf2iTqjRF9SPXt6sLJtNxXkHLdJqDsgYo1vPHWW5ILS/zGGjL9hns939ZQ77MubQ2pg/60v1HMYTr6U9M4u+pcnfppZguyoK3FHb89121TX51eRv63OBktXxVDUAgJXx9QgJVeDOUa7jQJeBg3n97fI47Cpw56RpZOgOT1UDAFgZwSwSqtDj0NiqQu3xBdTQHFCgPaSG5oD2+AKqqSokK5thGn1Bra5tVJHHqeI8p5z2bBXnOVXkcWpNbWNCO1EAAJAMBLNIONpnWUdHWUi+O7riKN9tly/YriZ/WzdHAgCQGaiZRcLRPss6Di4LKc77vD420WUhfXlxGQAgvQhmkTS0z4qWiQFdR1lIR+u0z5+qFtDE6tJej5PFZQCAZCOYBZIsnQFdLAF0R/nHmtpG1TX55XHYE1YWkoqewwCA/o1g1iIyMauH2KQjoIsngE5WWcihi8skRUoZ1tQ26pxRvc/8AgBAMJvhuE1rbekK6HoSQCe6LISewwCAVKCbQYajB6i1paNbQKa026LnMAAgFQhmM1imBCXouXQEdJnSbouewwCAVCCYzWCZEpQcSaMvqG27fQTXXUhHQJdJGVF6DgMAko2a2QyWqh6gPUU9b2yS2S2gK8lutxUPeg4DAJKNYDaDZVJQ0hXaLsUmHQFdqgPoI6HnMAAgWQhmM1ymBSUdaLsUv1QGdGREAQD9BcFshsvUoIS2S9ZARhQA0NcRzFpEpgUlmV7PCwAA+ge6GaBHaLsEAAAyAZlZ9Fim1vMCAID+g2AWPZap9bwAAKD/IJhFr2VaPS8AAOg/qJkFAACA
ZRHMAgAAwLIIZgEAAGBZBLMAAACwLIJZAAAAWBbBLAAAACyLYBYAAACWRTALAAAAyyKYBQAAgGURzAIAAMCyCGYBAABgWfZ0DyDVjDGSJK/Xm5LXC4fDam5ulsvlUlYW3x2siDm0PubQ+phDa2P+rC/Vc9gRp3XEbYfT74LZ5uZmSVJlZWWaRwIAAIDDaW5u1sCBAw+7j83EEvL2IeFwWHV1dcrLy5PNZkv663m9XlVWVuqTTz5Rfn5+0l8PicccWh9zaH3MobUxf9aX6jk0xqi5uVkVFRVHzAT3u8xsVlaWhg4dmvLXzc/P5wNsccyh9TGH1sccWhvzZ32pnMMjZWQ7ULgCAAAAyyKYBQAAgGURzCaZ0+nUnDlz5HQ60z0U9BBzaH3MofUxh9bG/FlfJs9hv1sABgAAgL6DzCwAAAAsi2AWAAAAlkUwCwAAAMsimAUAAIBlEcwmwPz581VVVSWXy6Vx48Zp1apVh93/j3/8o0aNGiWXy6Xjjz9eS5cuTdFI0Z145vDJJ5/U6aefrsLCQhUWFmrChAlHnHMkX7yfww6LFy+WzWbTBRdckNwB4ojincOmpibdeOONKi8vl9Pp1LHHHsv/T9Mo3vl79NFH9YUvfEFut1uVlZW6+eab1dramqLR4lD/+Mc/NGXKFFVUVMhms+n5558/4jErVqzQSSedJKfTqaOPPlqLFi1K+ji7ZNArixcvNg6HwyxcuNC899575pprrjEFBQWmvr6+y/3feOMNk52dbX784x+bTZs2mTvvvNPk5OSYjRs3pnjk6BDvHF5++eVm/vz5Zt26deb99983V155pRk4cKD59NNPUzxydIh3Djts27bNDBkyxJx++unm3/7t31IzWHQp3jkMBAJm7Nix5vzzzzevv/662bZtm1mxYoVZv359ikcOY+Kfv9/85jfG6XSa3/zmN2bbtm3mr3/9qykvLzc333xzikeODkuXLjV33HGHee6554wk8+c///mw+2/dutXk5uaaWbNmmU2bNpmf//znJjs727z88supGfBBCGZ76ZRTTjE33nhj5OdQKGQqKirMvHnzutz/61//uvnqV78atW3cuHHmW9/6VlLHie7FO4eHam9vN3l5eeZXv/pVsoaII+jJHLa3t5tTTz3VPPXUU2b69OkEs2kW7xw+/vjjZsSIESYYDKZqiDiMeOfvxhtvNGeffXbUtlmzZpnTTjstqeNEbGIJZm+99VbzxS9+MWrbJZdcYiZNmpTEkXWNMoNeCAaDWrNmjSZMmBDZlpWVpQkTJmjlypVdHrNy5cqo/SVp0qRJ3e6P5OrJHB6qpaVFbW1tGjRoULKGicPo6Rzed999Kikp0VVXXZWKYeIwejKHL7zwgsaPH68bb7xRpaWlOu644/TAAw8oFAqlatj4Pz2Zv1NPPVVr1qyJlCJs3bpVS5cu1fnnn5+SMaP3Mimesaf8FfuQ3bt3KxQKqbS0NGp7aWmpNm/e3OUxO3fu7HL/nTt3Jm2c6F5P5vBQt912myoqKjp9qJEaPZnD119/XU8//bTWr1+fghHiSHoyh1u3btWrr76qK664QkuXLtWWLVt0ww03qK2tTXPmzEnFsPF/ejJ/l19+uXbv3q2vfOUrMsaovb1d1113nX7wgx+kYshIgO7iGa/XK7/fL7fbnbKxkJkFeuHBBx/U4sWL9ec//1kulyvdw0EMmpubNXXqVD355JMaPHhwuoeDHgqHwyopKdETTzyhmpoaXXLJJbrjjju0YMGCdA8NMVixYoUeeOAB/eIXv9DatWv13HPP6cUXX9TcuXPTPTRYEJnZXhg8eLCys7NVX18ftb2+vl5lZWVdHlNWVhbX/kiunsxhh4cfflgPPvigXnnlFY0ZMyaZw8RhxDuH//znP1VbW6spU6ZEtoXDYUmS3W7XBx98oJEjRyZ30IjSk89heXm5cnJylJ2dHdk2evRo7dy5U8FgUA6HI6ljxud6Mn933XWXpk6dqquvvlqSdPzx
x8vn8+naa6/VHXfcoawscm2Zrrt4Jj8/P6VZWYnMbK84HA7V1NRo+fLlkW3hcFjLly/X+PHjuzxm/PjxUftL0rJly7rdH8nVkzmUpB//+MeaO3euXn75ZY0dOzYVQ0U34p3DUaNGaePGjVq/fn3kz9e+9jWdddZZWr9+vSorK1M5fKhnn8PTTjtNW7ZsiXwRkaQPP/xQ5eXlBLIp1pP5a2lp6RSwdnwxMcYkb7BImIyKZ1K+5KyPWbx4sXE6nWbRokVm06ZN5tprrzUFBQVm586dxhhjpk6dam6//fbI/m+88Yax2+3m4YcfNu+//76ZM2cOrbnSLN45fPDBB43D4TDPPvus2bFjR+RPc3Nzut5CvxfvHB6KbgbpF+8cbt++3eTl5ZmZM2eaDz74wCxZssSUlJSYH/7wh+l6C/1avPM3Z84ck5eXZ373u9+ZrVu3mr/97W9m5MiR5utf/3q63kK/19zcbNatW2fWrVtnJJlHHnnErFu3znz88cfGGGNuv/12M3Xq1Mj+Ha25brnlFvP++++b+fPn05rLyn7+85+bo446yjgcDnPKKaeY//f//l/k78444wwzffr0qP3/8Ic/mGOPPdY4HA7zxS9+0bz44ospHjEOFc8cDhs2zEjq9GfOnDmpHzgi4v0cHoxgNjPEO4dvvvmmGTdunHE6nWbEiBHm/vvvN+3t7SkeNTrEM39tbW3mnnvuMSNHjjQul8tUVlaaG264wTQ2NqZ+4DDGGPPaa691+W9bx7xNnz7dnHHGGZ2OOfHEE43D4TAjRowwzzzzTMrHbYwxNmPI5wMAAMCaqJkFAACAZRHMAgAAwLIIZgEAAGBZBLMAAACwLIJZAAAAWBbBLAAAACyLYBYAAACWRTALAAAAyyKYBQAAgGURzAKABYVCIZ166qm68MILo7bv27dPlZWVuuOOO9I0MgBILR5nCwAW9eGHH+rEE0/Uk08+qSuuuEKSNG3aNL3zzjt6++235XA40jxCAEg+glkAsLD//M//1D333KP33ntPq1at0sUXX6y3335bJ5xwQrqHBgApQTALABZmjNHZZ5+t7Oxsbdy4Ud/+9rd15513pntYAJAyBLMAYHGbN2/W6NGjdfzxx2vt2rWy2+3pHhIApAwLwADA4hYuXKjc3Fxt27ZNn376abqHAwApRWYWACzszTff1BlnnKG//e1v+uEPfyhJeuWVV2Sz2dI8MgBIDTKzAGBRLS0tuvLKK3X99dfrrLPO0tNPP61Vq1ZpwYIF6R4aAKQMmVkAsKjvfve7Wrp0qd555x3l5uZKkn75y1/q+9//vjZu3Kiqqqr0DhAAUoBgFgAs6O9//7vOOeccrVixQl/5ylei/m7SpElqb2+n3ABAv0AwCwAAAMuiZhYAAACWRTALAAAAyyKYBQAAgGURzAIAAMCyCGYBAABgWQSzAAAAsCyCWQAAAFgWwSwAAAAsi2AWAAAAlkUwCwAAAMsimAUAAIBl/X+Kf5OJ2I9r9QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset created: 100 samples\n"
     ]
    }
   ],
   "source": [
    "# Create dataset\n",
    "X = jnp.linspace(0, 1, 100)[:, None]\n",
    "Y = 0.8 * X ** 2 + 0.1 + brainstate.random.normal(0, 0.1, size=X.shape)\n",
    "\n",
    "\n",
    "def dataset(batch_size):\n",
    "    \"\"\"Yield an endless stream of random batches, sampled with replacement.\"\"\"\n",
    "    while True:\n",
    "        idx = brainstate.random.choice(len(X), size=batch_size)\n",
    "        yield X[idx], Y[idx]\n",
    "\n",
    "\n",
    "# Visualize dataset\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.scatter(X, Y, alpha=0.5, s=20)\n",
    "plt.xlabel('X')\n",
    "plt.ylabel('Y')\n",
    "plt.title('Training Dataset: $y = 0.8x^2 + 0.1 + noise$')\n",
    "plt.grid(alpha=0.3)\n",
    "plt.show()\n",
    "\n",
    "print(f\"Dataset created: {len(X)} samples\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a077a6cd13beb75",
   "metadata": {},
   "source": [
    "### Defining Training and Test Steps\n",
    "\n",
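    "Next, we define the training and test steps and wrap them with `jax.jit`. JAX traces and compiles each function on its first call. Inside each step, `brainstate.graph.treefy_merge` rebuilds the model from `graphdef` and the state pytrees, so only pytrees of arrays are passed across the `jit` boundary:"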
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7554f9042667b84a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:11.557478Z",
     "start_time": "2025-10-10T15:54:11.544113Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:05.599432Z",
     "iopub.status.busy": "2026-05-30T17:01:05.599255Z",
     "iopub.status.idle": "2026-05-30T17:01:05.606755Z",
     "shell.execute_reply": "2026-05-30T17:01:05.605984Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training and test functions defined and JIT-compiled\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def train_step(params, counts, batch):\n",
    "    \"\"\"Single training step with gradient descent.\"\"\"\n",
    "    x, y = batch\n",
    "\n",
    "    def loss_fn(params):\n",
    "        # Rebuild the model from the graph definition and current states\n",
    "        model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "        y_pred = model(x)\n",
    "        # Extract the Count states updated by this forward pass\n",
    "        new_counts = brainstate.graph.treefy_states(model, Count)\n",
    "        loss = jnp.mean((y - y_pred) ** 2)\n",
    "        return loss, new_counts\n",
    "\n",
    "    # Differentiate w.r.t. params only; the updated counts come back as auxiliary output\n",
    "    grad, counts = jax.grad(loss_fn, has_aux=True)(params)\n",
    "\n",
    "    # Update parameters with plain SGD (learning rate 0.1)\n",
    "    params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad)\n",
    "\n",
    "    return params, counts\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def test_step(params, counts, batch):\n",
    "    \"\"\"Evaluate model on test batch.\"\"\"\n",
    "    x, y = batch\n",
    "    model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "    y_pred = model(x)\n",
    "    loss = jnp.mean((y - y_pred) ** 2)\n",
    "    return {'loss': loss}\n",
    "\n",
    "\n",
    "print(\"Training and test functions defined and JIT-compiled\")"
   ]
  },
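  {
   "cell_type": "markdown",
   "id": "3f9c1a7e2b4d6c80",
   "metadata": {},
   "source": [
    "The `has_aux=True` pattern in `train_step` can be hard to read at first. The cell below is a minimal standalone sketch, independent of `brainstate` and built on a hypothetical toy linear model, showing that `jax.grad(f, has_aux=True)` expects `f` to return `(loss, aux)`, differentiates only the loss with respect to the first argument, and passes `aux` through unchanged. In `train_step`, the auxiliary output plays the role of the updated `Count` states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9c1a7e2b4d6c81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Toy parameters as a pytree (a dict); a hypothetical stand-in for `params_`\n",
    "params = {'w': jnp.array(2.0), 'b': jnp.array(0.5)}\n",
    "x = jnp.array([0.0, 1.0, 2.0])\n",
    "y = jnp.array([1.0, 3.0, 5.0])\n",
    "\n",
    "\n",
    "def loss_fn(p):\n",
    "    y_pred = p['w'] * x + p['b']\n",
    "    loss = jnp.mean((y - y_pred) ** 2)\n",
    "    # Auxiliary output: returned alongside the loss but not differentiated\n",
    "    aux = {'n_calls': jnp.array(1)}\n",
    "    return loss, aux\n",
    "\n",
    "\n",
    "# Returns (gradient pytree with the same structure as params, aux)\n",
    "grads, aux = jax.grad(loss_fn, has_aux=True)(params)\n",
    "print(grads)\n",
    "print(aux)"
   ]
  },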
  {
   "cell_type": "markdown",
   "id": "2530077e5dba7459",
   "metadata": {},
   "source": [
    "### Training the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c63154cbf63b9f43",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:21.259193Z",
     "start_time": "2025-10-10T15:54:11.579934Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:05.608600Z",
     "iopub.status.busy": "2026-05-30T17:01:05.608406Z",
     "iopub.status.idle": "2026-05-30T17:01:16.325030Z",
     "shell.execute_reply": "2026-05-30T17:01:16.324153Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for 10000 steps...\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step     0: loss = 0.168351\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  1000: loss = 0.011733\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  2000: loss = 0.011768\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  3000: loss = 0.011872\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  4000: loss = 0.011736\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  5000: loss = 0.011749\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  6000: loss = 0.011963\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  7000: loss = 0.011719\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  8000: loss = 0.011735\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step  9000: loss = 0.011739\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training complete!\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "total_steps = 10_000\n",
    "print(f\"Training for {total_steps} steps...\\n\")\n",
    "\n",
    "for step, batch in enumerate(dataset(32)):\n",
    "    params_, counts_ = train_step(params_, counts_, batch)\n",
    "\n",
    "    if step % 1000 == 0:\n",
    "        logs = test_step(params_, counts_, (X, Y))\n",
    "        print(f\"Step {step:5d}: loss = {logs['loss']:.6f}\")\n",
    "\n",
    "    if step >= total_steps - 1:\n",
    "        break\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "694aba2f960c2b4f",
   "metadata": {},
   "source": [
    "### Verifying the Trained Model\n",
    "\n",
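    "Let's restore the model from the trained states and check how many times it was called. Each training step calls the model once, and `test_step` does not return its updated counts, so the count should match the number of training steps:"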
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5d495f172fcca5e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:21.563382Z",
     "start_time": "2025-10-10T15:54:21.299626Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:16.327177Z",
     "iopub.status.busy": "2026-05-30T17:01:16.327003Z",
     "iopub.status.idle": "2026-05-30T17:01:16.686225Z",
     "shell.execute_reply": "2026-05-30T17:01:16.685302Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model called 10000 times during training\n",
      "Expected: 10000 times\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIoCAYAAABj6NoUAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAhYxJREFUeJzt3Xl8U1X+//F3GrpBaaGlLYtFRFFQEASFAVR0RFEBwWVAVKiIK6AI4oIbKgqIiLigfHVGcWbc0FHxpwwuKC4FRVFchkVUEDdaaKEtSylt7u+PY5KGrmmb5CZ5PR+PPJqc3Jt7Ug81n3zO+RyHZVmWAAAAAAB1FhPqDgAAAABAuCGQAgAAAAA/EUgBAAAAgJ8IpAAAAADATwRSAAAAAOAnAikAAAAA8BOBFAAAAAD4iUAKAAAAAPxEIAUAAAAAfiKQAoAwsGLFCjkcDs9ty5Yttnq9mlx66aWe65xyyikBu059bNu2TWPHjlW7du3UpEkTTz9ff/31UHctLGzZssVnHK1YsaLBr2nn8QIAFRFIAYAqBxYOh0PnnHNOlce+/fbblY699NJLg9vhEDnllFMqvfeqbnUNzBrzQ/Mrr7xSqR+PPfZYtcdblqULLrhAixYt0u+//67y8vJKx3To0MHzWnfddVeD+tdYKv7O6nqzS98BIJI0CXUHAMCu3nrrLf3000/q2LGjT/vDDz8coh6FpwsvvFBdu3aVJGVlZQXsOs8880yltkWLFmnixIlVHr9161bl5OR4Hg8ZMkQnnXSSYmJiPP1FzVJTU/XAAw94Hh9++OENfs1gjRcAaCgCKQCohsvl0mOPPaZ58+Z52r7//nstW7YshL2yj5YtW+rWW2+t8rnU1FTP/TPPPFNnnnlmQPuybds2vf3225Xa16xZo++++67KwOjnn3/2eTx//vxGCQQaQ3FxsZo3b17lcxUDDbeZM2dq586dkqSOHTvqmmuu8Xm+X79+1V5rz549SkxMVEyM/5NUkpOTNXXqVL/Pq0kwxgsANAoLAGB98MEHliTPLSYmxpJkpaSkWLt37/YcN3HiRM8xTqfTcz87O7vSa/7666/W1KlTra5du1rNmjWz4uPjrUMPPdS6+OKLrc8++6zKfuzYscO66qqrrIyMDCshIcHq1auX9eKLL1bq3+bNm33OKy8vt/75z39ap59+upWenm7FxsZarVq1ss4++2zrrbfeqvX9Hvx61RkwYIDnnEMPPbRO52RnZ3vOGTBggGVZlvXMM8/4XL+q2wcffFCn17csy5ozZ47nvKSkJKtt27aexzfccEOl42u7dsU+V3erqLCw0Jo5c6bVu3dvKzk52YqNjbWysrKs7Oxs67vvvqt0/enTp/v8Hnfs2GGNHz/eateunRUTE2M99NBDdX7vlmVZhx56aKXfcXXPT58+3fr444+t0047zUpOTrYkWTt37rQOHDhg3X777dZZZ51ldezY0UpJSbGaNGlipaamWieeeKL1yCOPWKWlpT6vu3nz5mr/mx38Hnft2mVNnTrVat++vRUbG2sddthh1n333We5XC6f16xqvFT13+2ZZ56x3nnnHeuUU06xmjVrZiUlJVlnnnlmlb9vy7Ksp556yuratasVHx9vHXLIIdYNN9xg7d69u9LvBgDqikAKAKzKgcXw4cM99xcsWGBZlvmw3Lx5c0uSddxxx/l8ADs4kPrwww+tli1bVvshPCYmxnrwwQd9ztm5c6fVuXPnKo8fPHhwtYHP3r17rYEDB9b4oX/KlCk1vt9wD6SOPvpoz3kXXXSRNXnyZM/jzMxM68CBAz7HN2Yg9f3331sdOnSo9rj4+Hhr8eLFPtevGGS0atWq0n/3QAZSffv29fkSwB1IFRcX1/qeBw4caJWVlXlet66BVFpamtWlS5cqX/OOO+7w6WtdA6n+/ftbDoej0uulpaVZeXl5PufdcsstVV67d+/eVmZmJoEUgHphah8AVOHiiy/W
J598oh07duixxx7T+PHj9cwzz6i4uFiSdN1111W7gH/Xrl0677zzPFOtEhMTNXbsWCUnJ+uFF17Qzz//LJfLpalTp6pXr14aMGCAJOn222/Xhg0bPK8zYMAADRgwQDk5OXrrrbeq7evkyZP13nvvSZLi4uJ04YUXqlOnTvr222/18ssvy7IszZs3T7169dJFF13UGL8eSVJRUZHmzp1bqT0rK0sjR46s8dwTTjhBDzzwgF566SV98cUXkipPSavrNLvVq1dr3bp1nscXXnihMjMz9dBDD0mScnNz9d///ldDhw71HPPAAw/oxx9/1MKFCz1tt956q1q2bClJ6tq1q7p27eozZe7000/XGWec4XPt8vJynXvuuZ7iGunp6brooouUmpqqt99+WytXrtT+/fs1ZswY9erVq9J6O0nasWOHduzYoYEDB6p///7avn27MjMz6/Te62PVqlVq2rSpLrnkErVr105fffWVnE6nHA6HOnbsqL/85S9q166dWrZsqQMHDmjDhg16+eWXVVZWpvfee0//+c9/NGLECL+umZ+fr507d2rMmDFq27at/v73v2vHjh2SzJrD22+/XXFxcX69Zk5Ojjp37qzzzjtPa9eu1dKlSz3X+sc//qFbbrlFkvT555/r/vvv95yXkZGh7OxsFRcX6+mnn1Zpaalf1wUAj1BHcgBgBwdnaP7f//t/1q233up5vGzZMuuII46wJFnp6elWSUlJtRmphx56yOe1li5d6nkuNzfXSkpK8jw3bNgwy7Is68CBAz7tJ598slVeXm5ZlmW5XC7rjDPOqDKDlJ+fbzVp0sTT/vTTT/u8r/Hjx3ueO+6446p9v/XJSFV3OziLUFOGoabn6uqaa67xvEbLli2t/fv3W5ZlWYcffrin/bzzzqt0Xl1+B7VN+1qyZInneafTaX3//fee58rKyqxu3bp5np88ebLnuYrZGknW9ddfX6/3XlU/a8tIOZ1Oa82aNdW+Vm5urrVkyRLr8ccft+bOnWs98MADVteuXT3nX3bZZZ5j65qRkmTNnz/f89zrr7/u89w333zjea6uGamsrCyrqKjI89xxxx1X5X/vq666ytMeExPjM/Xv4MwoGSkA/qD8OQBUY/z48WrSxCTux40bpx9++EGSdOWVVyo+Pr7a81atWuW5n56errPOOsvzOCMjw+ex+9gNGzZo9+7dnvZRo0Z5Fv87HA5dfPHFVV7rs88+U1lZmefxZZdd5lP2+vHHH/c8t3btWu3du7f2Nx5G9u/frxdffNHz+LzzzvNkNipmxd58803l5+c3+vUrVv0rLy/XkUce6fndN2nSRN9++63n+ZUrV1b7Orfffnuj9606Z511lnr27Fmpfd++fRo7dqzatGmjYcOGafz48Zo6dapuvPFGfffdd57jfv31V7+v6XQ6ddVVV3keH3XUUT7Pu7N+/hg9erRPQY4jjzyyytdzZzwlqVevXjrmmGM8jy+55BLPv3EA8BeBFABUo127djr//PMlSb/99pskKTY2VuPHj6/xvIKCAs/9qqZoVWxzf+DbtWuXzzEZGRnVnlPdtWpjWVajBhOHHnqoLLPW1ufWGJuy1tXrr7/u86H5wgsv9NwfNWqU535paamee+65Rr++P7//7du3V9neqlUrpaWlNVaXatW5c+cq26dNm6ZFixbJ5XLVeP7+/fv9vmZmZqYSEhI8jw/+IqK2a1alQ4cOPo8rvmbF16v4b6t169Y+5zRp0kStWrXy+9oAIFH+HABqNGnSJL300kuex+eff77atm1b4zkVS3/n5uZWer5im3tNTosWLXyOycvLq/ac6q4lmfVSNfUvJSWl2ufC0aJFi3wen3766TUee9111zXq9Sv+/hMSEjRjxoxqj63ud9+sWbNG7VNtqrtexXHerVs3vfDCCzrqqKPUpEkTjRgxQi+//HK9rxkbG+vz2OFw1Pu1/H3Niv+2Dv53VVZW5lmrBQD+IpACgBr07dtXJ5xwgj7/
/HNJqtMH8X79+mnx4sWSTBbiv//9r2c6X15env773//6HCuZLEFSUpJnet8LL7ygK6+8UjExMbIsq9psSp8+feR0OlVeXi7JfLisal+fLVu2aOPGjUpOTq7rWw+Kih+G/Z12+Pvvv+vdd9+t8/FfffWVvvnmGx177LGN1r+K+zOVlJTomGOO8Zm66fbZZ5/VOB3UDipmK0899VTPFLjt27cHNcvY2I4//nitWbNGkpnm98MPP+iII46QJP373//2mRoLAP4gkAKAWvzzn//Uhg0bFBsbq759+9Z6fHZ2tmbMmOH5YHr++efrsssuU3Jysp5//nlPsORwOHT99ddLMlOMxowZ41nT9NFHH+mvf/2rp2rf8uXLq7xWamqqLrvsMj311FOSpDlz5uiLL75Qv379lJCQoN9++02ffvqpvvrqK2VnZ2vQoEEN/XU0qnbt2nnur1mzRpMmTVJWVpbi4uJqDVr/+c9/egJISRo6dKiaNm3qc4zL5fLJpDzzzDOean517Z97bdyiRYuUmJio5s2b6/DDD9e5556rwYMHq0uXLlq/fr0kafjw4TrvvPN09NFHy+Vy6ccff9RHH32kn3/+Wc8884x69OhR52sH21FHHeVZC/XUU08pJiZGTZs21b/+9a9qpyWGg3HjxunJJ5+UZVkqLy/XySefrDFjxqioqEj/+Mc/Qt09AGGMQAoAatG5c+dq15VUpUWLFnr11Vc1bNgw7dq1S/v27dOCBQt8jomJidGcOXM8pc8l6d5779V7772n77//XpL04Ycf6sMPP5QknXLKKdVmBebPn6/Nmzd7SqC///77ev/99/15iyEzfPhwzZgxQy6XSy6XS4888ogkM/2stkDq2Wef9dzv1KmT3njjjSqPO/nkk/Xxxx9Lkp5//nk98MADdS4wcN5553n+G2zfvl333HOPJGnw4ME699xz1aRJE73++usaNGiQtmzZotLSUp/iF+Hktttu86wr27dvn+bPny9JatOmjU4//XS/sn92csIJJ+jmm2/W7NmzJUl//PGHpxx6z5499dtvv3mmzroLvABAXfAXAwAC4OSTT9Z3332nG264Qcccc4yaNm2quLg4tW/fXhdffLFWrlypG264weecli1b6pNPPtEVV1yh9PR0xcfHq3v37nrmmWc0ffr0aq/VtGlTvf3223r++ed19tlnKzMzU02aNFFiYqIOP/xwXXDBBXryySc1b968QL9tv/Xo0UMvvPCCevbs6VOMoDaffvqpz55bY8eOrfbYis/l5eXVuCfXwSZMmKC77rpLHTt2rDb4OvLII/XNN99ozpw56tevn1q2bCmn06nmzZvr2GOP1eWXX67XXnutUffwCoQLL7xQixcvVvfu3RUbG6u0tDSNHDlSn376aa3rAu1u1qxZevLJJ3XMMccoLi5Obdq00cSJE7V8+XIVFRV5jjt4rSIA1MRhWZYV6k4AAAAEyr59+5SYmFip/c033/TZqDknJ8dn3RsA1ISpfQAAIKLdeuutWrt2rYYOHarDDjtMZWVl+uKLL3z2WTv++OPrtAYSANwIpAAAQERz729W3TrDI444Qi+//HKjlGUHED0IpAAAQEQbPny4cnNz9dlnn2n79u0qKSlRixYt1LVrV5177rm6/PLLK1V8BIDasEYKAAAAAPxE1T4AAAAA8BOBFAAAAAD4KerXSLlcLv3+++9q3rw5i0wBAACAKGZZloqLi9W2bdtaN+mO+kDq999/V1ZWVqi7AQAAAMAmfvnlFx1yyCE1HhP1gVTz5s0lmV9WcnJyiHtjMmTbt29Xenp6rVEwwHiBvxgz8BdjBv5izMBfdhozRUVFysrK8sQINYn6QMo9nS85Odk2gVRJSYmSk5NDPpBgf4wX+IsxA38xZuAvxgz8ZccxU5clP/boKQAAAACEEQIpAAAAAPATgRQAAAAA+Cnq10jVhWVZKisrU3l5ecCv5XK5dODAAZWUlNhmjijsw+l0qkmTJpTqBwAA
CDECqVqUlpbqjz/+0N69e4NyPcuy5HK5VFxczIdlVKlp06Zq06aN4uLiQt0VAACAqEUgVQOXy6XNmzfL6XSqbdu2iouLC3hw485+kXXAwSzLUmlpqbZv367NmzerU6dOoe4SAABA1CKQqkFpaalcLpeysrLUtGnToFyTQAo1SUxMVGxsrH7++WeVlpaSlQIAAAgRFuHUAWuVYCeMRwAAgNDjExkAAAAA+IlACgAAAAD8RCCFOuvQoYPmz59f5+NXrFghh8OhXbt2BaxP1Vm0aJFatGgR9OsCAAAgOhBIRSCHw1Hj7a677qrX637++ee68sor63x8v3799McffyglJaVe1ws2fwNFAAAARC+q9kWgP/74w3P/pZde0p133qmNGzd62pKSkjz3LctSeXm5mjSpfSikp6f71Y+4uDi1bt3ar3MAAACAcEBGKojy86VNm8zPQGrdurXnlpKSIofD4Xm8YcMGNW/eXP/973/Vq1cvxcfH65NPPtGPP/6oYcOGKTMzU0lJSTrhhBP03nvv+bzuwRkbh8Ohv//97zr33HPVtGlTderUSW+88Ybn+YOn9rmn27399tvq0qWLkpKSdOaZZ/oEfmVlZbruuuvUokULpaWl6eabb1Z2draGDx9e43tetGiR2rdvr6ZNm+rcc89V/kG/5Nre3ymnnKKff/5ZkydP9mTuJCk/P1+jRo1Su3bt1LRpU3Xr1k0vvPCCP/85AAAAEIEIpIJg3z7p2WeladOk6dPNz2efNe2hcsstt2j27Nlav369jj32WO3evVtnn322li9frq+++kpnnnmmhg4dqq1bt9b4OnfffbdGjBihb775RmeffbYuvvhiFRQUVHv83r17NXfuXP3rX//SRx99pK1bt2rq1Kme5++//34999xzeuaZZ5STk6OioiK9/vrrNfbhs88+07hx4zRx4kStXbtWp556qu69916fY2p7f6+++qoOOeQQ3XPPPfrjjz88wV1JSYl69eqlt956S999952uvPJKjR49WqtXr66xTwAAAIEWrC/pUTWm9gXB4sXSkiVSZqbUvr1UWGgeS1J2dmj6dM899+j000/3PE5NTVX37t09j2fMmKHXXntNb7zxhiZOnFjt61x66aUaNWqUJGnmzJl65JFHtHr1ap155plVHn/gwAEtXLhQhx9+uCRp4sSJuueeezzPP/roo5o2bZrOPfdcSdJjjz2mpUuX1vheHn74YZ155pm66aabJElHHnmkVq5cqWXLlnmO6d69e43vLzU1VU6nU82bN/eZjtiuXTufQO/aa6/V22+/rcWLF6t379419gsAACAQ9u0zny9zcqTdu6WkJKl/f2nECCkxMdS9ix5kpAIsP98M8sxMc0tI8N7PyQndNwjHH3+8z+Pdu3dr6tSp6tKli1q0aKGkpCStX7++1ozUscce67nfrFkzJScnKy8vr9rjmzZt6gmiJKlNmzae4wsLC5Wbm+sToDidTvXq1avGPqxfv159+vTxaevbt2+jvL/y8nLNmDFD3bp1U2pqqpKSkvT222/Xeh4AAECguL+kdzrNl/ROp3m8eHGoexZdyEgFWEGB+aagfXvf9pQUaetW83xaWvD71axZM5/HU6dO1bvvvqu5c+fqiCOOUGJioi644AKVlpbW+DqxsbE+jx0Oh1wul1/HW5blZ+/9V9/398ADD+jhhx/W/Pnz1a1bNzVr1kzXX399recBAAAEwsFf0kvmi3rJtA8ZEprPltGIQCrAUlNNurWw0DvIJfM4Kck8bwc5OTm69NJLPVPqdu/erS1btgS1DykpKcrMzNTnn3+uk08+WZLJCH355Zfq0aNHted16dJFn332mU/bp59+6vO4Lu8vLi5O5eXllc4bNmyYLrnkEkmSy+XS999/r6OPPro+bxEAAKBB7PolfTRial+ApaWZOau5ueZWUuK937+/fQZ6p06d9Oqrr2rt2rX6+uuvddFFF9WYWQqUa6+9VrNmzdKSJUu0ceNG
TZo0STt37vRU0avKddddp2XLlmnu3LnatGmTHnvsMZ/1UVLd3l+HDh300Ucf6bffftOOHTs857377rtauXKl1q9fr6uuukq5ubmN/8YBAADqoOKX9BXZ7Uv6aEAgFQQjRkjDhknl5eabgvJy83jEiFD3zGvevHlq2bKl+vXrp6FDh2rQoEHq2bNn0Ptx8803a9SoURozZoz69u2rpKQkDRo0SAkV03kH+ctf/qKnnnpKDz/8sLp376533nlHt99+u88xdXl/99xzj7Zs2aLDDz/cs2fW7bffrp49e2rQoEE65ZRT1Lp161pLsQMAAARKuHxJHw0cVjAWqNhYUVGRUlJSVFhYqOTkZJ/nSkpKtHnzZh122GE1fpCvq/x8k25NTa1+kFuWpbKyMjVp0qTGLEy0cLlc6tKli0aMGKEZM2aEuju2UHFcxsXFKS8vTxkZGYqJ4XsR1M7lcjFm4BfGDPzFmAm8SKvaZ6cxU1NscDDWSAVRWhrfEtTm559/1jvvvKMBAwZo//79euyxx7R582ZddNFFoe4aAACALSQmmi10hgyp/Ut6BA6BFGwlJiZGixYt0tSpU2VZlrp27ar33ntPXbp0CXXXAAAAbIUv6UOLQAq2kpWVpZycnFB3AwAAAKgRE1cBAAAAwE8EUgAAAADgJ1sFUh999JGGDh2qtm3byuFw6PXXX6/1nBUrVqhnz56Kj4/XEUccoUWLFgW8nwAAAACim60CqT179qh79+5asGBBnY7fvHmzBg8erFNPPVVr167V9ddfr8svv1xvv/12gHsKAAAAIJrZqtjEWWedpbPOOqvOxy9cuFCHHXaYHnzwQUlSly5d9Mknn+ihhx7SoEGDAtVNAAAAICrVZV/UaGGrQMpfq1at0sCBA33aBg0apOuvv77ac/bv36/9+/d7HhcVFUkyG4G5XC6fY10ulyzL8tyCxX2tKN8rGdVwj0f3mHXfB+qCMQN/MWbgL8ZMZNq3T3rlFWnlSu8mwP36SRdc0PBNgO00ZvzpQ1gHUtu2bVNmZqZPW2ZmpoqKirRv3z4lVvFfddasWbr77rsrtW/fvl0lJSU+bQcOHJDL5VJZWZnKysoat/PVsCxL5eXlkiSHwxGUa9bXhx9+qNNPP115eXlq0aJFnc7p1KmTrr32Wl133XWB7VwdDRw4UN27d/dkNRujf4F+j2VlZXK5XMrPz5fT6VRhYaEsywr5TuAIDy6XizEDvzBm4C/GTGRavlxas0bKzJQ6dpT27DGPmzSRTjutYa9tpzFTXFxc52PDOpCqj2nTpmnKlCmex0VFRcrKylJ6erqSk5N9ji0pKVFxcbGaNGmiJk2C+6uKjY1t0Pljx47Vs88+qyuvvFILFy70eW7ChAl64oknlJ2drWeeeabe13A6nZLk9+8nJiYm6L/P6jgcDjkcDk9/Vq9erWbNmtWpf4sWLdLkyZO1c+dOn3Z/XqM+mjRpopiYGKWlpSkuLk4Oh0Pp6ekh/8OD8OByuRgz8AtjBv5izESeggLpo48kp1OKiTHZqZgYqazMtA8aZKb61ZedxkxCQkKdj7XHp9l6at26tXJzc33acnNzlZycXGU2SpLi4+MVHx9fqT0mJqbSf7iYmBjPB+1gZYcsy/Jcq6HXzMrK0ksvvaT58+d7fh8lJSV64YUX1L59+wZfo2I//Xmdxv59lpaWKi4urt7nV+xPRkaGX+dV/Onmz2vUh7u/7jFb8T5QF4wZ+IsxA38xZiLLzp1ScbH058dHj+RkaetW83yrVg27hl3GjD/XD+vR3bdvXy1fvtyn7d1331Xfvn1D1CN76dmzp7KysvTqq6962l599VW1b99exx13nM+x+/fv13XXXaeMjAwlJCToxBNP1Oeff+5zzNKlS3XkkUcqMTFRp556qrZs2VLpmp988olOOukkJSYmKisrS9ddd5327NlT5z5feumlGj58uO6++25PlvDqq69WaWmp55hTTjlFEydO1PXXX69WrVp5Cot8
9913Ouuss5SUlKTMzEyNHj1aO3bs8Jy3Z88ejRkzRklJSWrTpo1nOl9FHTp00Pz58z2Pd+3apauuukqZmZlKSEhQ165d9eabb2rFihUaO3asCgsLPYHNXXfdVeVrbN26VcOGDVNSUpKSk5M1YsQIny8A7rrrLvXo0UP/+te/1KFDB6WkpOjCCy/0K7UMAAAQKKmpZk1UYaFve2GhaW9INiqc2SqQ2r17t9auXau1a9dKMuXN165dq61bt0oy0/LGjBnjOf7qq6/WTz/9pJtuukkbNmzQ448/rsWLF2vy5MmB7ejxx0uHHBKYW1aWmhx2mJSVVfm544/3u6uXXXaZz/S9p59+WmPHjq103E033aT//Oc/evbZZ/Xll1/qiCOO0KBBg1RQUCBJ+uWXX3Teeedp6NChWrt2rS6//HLdcsstPq/x448/6swzz9T555+vb775Ri+99JI++eQTTZw40a8+L1++XOvXr9eKFSv0wgsv6NVXX620ru3ZZ59VXFyccnJytHDhQu3atUt//etfddxxx+mLL77QsmXLlJubqxEjRnjOufHGG/Xhhx9qyZIleuedd7RixQp9+eWX1fbD5XLprLPOUk5Ojv79739r3bp1mj17tpxOp/r166f58+crOTlZf/zxh/744w9NnTq1ytcYNmyYCgoK9OGHH+rdd9/VTz/9pJEjR1b63b3++ut688039eabb+rDDz/U7Nmz/fq9AQAABEJamtS/v5Sba24lJd77/ftHcfU+y0Y++OADS1KlW3Z2tmVZlpWdnW0NGDCg0jk9evSw4uLirI4dO1rPPPOMX9csLCy0JFmFhYWVntu3b5+1bt06a9++fb5PtGtnWVLwb+3a1fl9ZWdnW8OGDbPy8vKs+Ph4a8uWLdaWLVushIQEa/v27dawYcM8v9fdu3dbsbGx1nPPPec5v7S01Grbtq01Z84cy7Isa9q0adbRRx/tc42bb77ZkmTt3LnTsizLGjdunHXllVf6HPPxxx9bMTExnt/hoYceaj300EM19js1NdXas2ePp+2JJ56wkpKSrPLycsuyLGvAgAHWcccd53PejBkzrDPOOMOn7ZdffrEkWRs3brSKi4utuLg4a/HixZ7n8/PzrcTERGvSpEmetor9e/vtt62YmBhr48aNVfb1mWeesVJSUiq1V3yNd955x3I6ndbWrVs9z//vf/+zJFmrV6+2LMuypk+fbjVt2tQqKiryHHPjjTdaffr0qfK6FcdleXm59ccff3h+N0BtGDPwF2MG/mLMRKa9ey1r0SLLuuIKyxo1yvxctMi0N5SdxkxNscHBbLVG6pRTTqmx5PeiRYuqPOerr74KYK+q0Lp1wF664ruvtIqoHtdNT0/X4MGDtWjRIlmWpcGDB6vVQZNYf/zxRx04cED9+/f3tMXGxqp3795av369JGn9+vXq06ePz3kHT6H8+uuv9c033+i5557zvp8/S1lu3rxZXbp0qVOfu3fvrqZNm/pcZ/fu3frll1906KGHSpJ69epV6doffPCBkpKSKr3ejz/+qH379qm0tNTnPaSmpuqoo46qth9r167VIYccoiOPPLJO/a7K+vXrlZWVpaysLE/b0UcfrRYtWmj9+vU64YQTJJnpgM2bN/cc06ZNG+Xl5dX7ugAAAI0pMVHKzpaGDGEfKTdbBVJh44svAvfalqWysjJT8a2RCjJcdtllnul1CxYsaJTXrMru3bt11VVXVVn2u/3BqxMbqFmzZpWuPXToUN1///2Vjm3Tpo1++OEHv69RXcGSQDi4SqPD4bDFXgoAAAAVpaURQLnZao0UAuPMM89UaWmpDhw44CnMUNHhhx/uWW/kduDAAX3++ec6+uijJUldunTR6tWrfc779NNPfR737NlT69at0xFHHFHp5k9Vva+//lr79u3zuU5SUpJPVudgPXv21P/+9z916NCh0rWbNWumww8/XLGxsfrss8885+zcuVPff/99ta957LHH6tdff632mLi4
OM+eX9Xp0qWLfvnlF/3yyy+etnXr1mnXrl2e3y0AAADCD4FUFHA6nVq/fr3WrVvn2fupombNmumaa67RjTfeqGXLlmndunW64oortHfvXo0bN06SKeyxadMm3Xjjjdq4caOef/75SlMtb775Zq1cuVITJ07U2rVrtWnTJi1ZssTvYhOlpaUaN26c1q1bp6VLl2r69OmaOHFijeUoJ0yYoIKCAo0aNUqff/65fvzxR7399tsaO3asysvLlZSUpHHjxunGG2/U+++/r++++06XXnppja85YMAAnXzyyTr//PP17rvvavPmzfrvf/+rZcuWSTLT8Xbv3q3ly5drx44d2rt3b6XXGDhwoLp166aLL75YX375pVavXq0xY8ZowIABOr4exUMAAABgDwRSUSI5ObnShsMVzZ49W+eff75Gjx6tnj176ocfftDbb7+tli1bSjJT8/7zn//o9ddfV/fu3bVw4ULNnDnT5zWOPfZYffjhh/r+++910kkn6bjjjtOdd96ptm3b+tXX0047TZ06ddLJJ5+skSNH6pxzzvGUFq9O27ZtlZOTo/Lycp1xxhnq1q2brr/+erVo0cITLD3wwAM66aSTNHToUA0cOFAnnnhipbVWB/vPf/6jE044QaNGjdLRRx+tm266yZOF6tevn66++mqNHDlS6enpmjNnTqXzHQ6HlixZopYtW+rkk0/WwIED1bFjR7300kt+/U4AAABgLw6rpuoOUaCoqEgpKSkqLCysFGiUlJRo8+bNOuyww/za5bghrAprpIK1CbCdXHrppdq1a5def/31UHfFtiqOy7i4OOXl5SkjIyPkG9ghPLhcLsYM/MKYgb8YM/CXncZMTbHBwRjdAAAAAOAnAikAAAAA8BPlz2ErVe0VBgAAANgNGSkAAAAA8BOBVB1EeT0O2AzjEQAAIPQIpGoQGxsrSVXuDwSEins8uscnAAAAgo81UjVwOp1q0aKF8vLyJElNmzYNeEnyaC9/jupZlqW9e/cqLy9PLVq0kNPplMvlCnW3AAAAohKBVC1at24tSZ5gKtAsy5LL5VJMTAyBFKrUokULz7gEAABAaBBI1cLhcKhNmzbKyMjQgQMHAn49l8ul/Px8paWlhXxDMthPbGysnE5nqLsBAAAQ9Qik6sjpdAblA6zL5VJsbKwSEhIIpAAAAACb4pM6AAAAAPiJQAoAAAAA/EQgBQAAAAB+IpACAAAAAD8RSAEAAACAnwikAAAAAMBPBFIAAAAA4CcCKQAAAADwE4EUAAAAAPiJQAoAAAAA/EQgBQAAACBkCgqk3383P8NJk1B3AAAAAED02bdPWrxYWrlSat5cKi6W+vWTRoyQEhND3bvakZECAAAAEHSLF0tLlkhOp5Sebn4uWWLawwGBFAAAAICgys+XcnKkzEwpI0OKjTU/MzNNe35+qHtYOwIpAAAAAEFVUCDt3i2lpPi2p6SY9nBYL0UgBQAAACCoUlOlpCSpsNC3vbDQtKemhqZf/iCQAgAAABBUaWlS//5Sbq6UlycdOGB+5uaa9rS0UPewdlTtAwAAABB0I0aYnytXStu3S+Xl0rBh3na7I5ACAAAAEHSJiVJ2tjR4sLR1i0vtO0itWoW6V3XH1D4AAAAAoVFaqtS/z1G3y85UarP9oe6NXwikAAAAAATfRx9Jxx2nmGnTFPvtt9KDD4a6R34hkAIAAAAQPNu3S2PHSgMGSOvWSZKsmBg5du8Occf8QyAFAAAAIPBcLumpp6SjjpIWLfI0W717K3/ZMlkzZ4aub/VAIAUAAACEUH6+tGmT+RmxvvlGOvFE6corpZ07TVuLFtITT8j65BOVdesW0u7VB1X7AAAAgAry86WCArMpbCD3M9q3T1q8WMrJkXbvNhvR9u9vyn8nJgbuukFVXCzddZf08MOmvrnbJZdIc+dKmZkmUxWGCKQAAAAABT+wWbxYWrLExBLt20uFheaxZMqChzXLkl59VZo0SfrtN2/7UUdJTzwhnXpq
6PrWSJjaBwAAAMgb2DidJrBxOs3jxYsb/1r5+SZgy8w0t4QE7/2cnDCf5vfTT9KQIdIFF3iDqIQE6d57pa+/joggSiKQAgAAAIIe2BQUmKxXSopve0qKaS8oaNzrBcX+/dJ990nHHCMtXeptP+ss6X//k267TYqPD13/GhmBFAAAAKJesAOb1FQzdbCw0Le9sNC0p6Y27vUC7oMPpB49pNtvl0pKTFu7dtIrr0hvvSV17BjS7gUCgRQAAACiXrADm7Q0s/4qN9fcSkq89/v3D2yRi0aVmyuNHi399a/Shg2mzemUJk+W1q+Xzj9fcjhC28cAIZACAABA1AtFYDNihDRsmClmt3Wr+TlsmGm3PZdLWrhQ6txZ+ve/ve1/+Yu0Zo00b57UvHno+hcEVO0DAAAA5A1gcnJMYJOUFNjAJjHRVOcbMiQ45dYbzVdfSVdfLa1e7W1r2VKaPVu6/HIpJjpyNQRSAAAAgEIX2KSlhUkAVVQk3Xmn9Oijvns/ZWdLc+ZIGRmh61sIEEgBAAAAFYRNYBMsliW9/LJZ9/T77972Ll3MnlADBoSubyEUHXk3AAAARJz8fGnTJvvsuWS3/jSKH3805ctHjvQGUQkJ0syZ0tq1URtESWSkAAAAEGb27TOb5ObkmNLkSUmmIMSIEWZ6XrT3p1Hs32+m6913n7nvNniwmdp32GGh65tNkJECAABAWFm8WFqyxFTZbt/e/FyyxLTTn0awfLl07LFmPZQ7iDrkEOnVV6X/9/8Iov5EIAUAAICwkZ9vMj+ZmeaWkOC9n5MT/Gl1dutPg2zbJl18sTRwoPT996bN6ZRuuMHsCXXuuRG7J1R9EEgBAAAgbBQUmOlzKSm+7Skppr2gILr7Uy/l5dKCBWZPqOef97b37St9+aU0d66ZrwgfBFIAAAAIG6mp5jN9YaFve2GhaU9Nje7++G3NGrOJ7sSJ3jfRsqX05JPSJ5+YKX6oEoEUAAAAwkZaminkkJtrbiUl3vv9+we/bLnd+lNnhYXSdddJvXtLX3zhbb/0UmnjRumKK6JmY936omofAAAAwsqIEeZnTo60davJ/Awb5m2P9v7UyLJMFYzrrzdrotyOPtrsCXXyySHrWrghkAIAAEBYSUyUsrOlIUPMGqTU1NBmfuzWn2pt2iRNmCC9+663LTFRmj7dbLYbFxe6voUhAikAAACEpbQ0ewUsduuPR0mJdP/90qxZvntCDR0qPfKI1KFDyLoWzgikAAAAgEj17rvS+PHSDz9427KyzKa6w4aFrl8RgBVkAAAAQKT54w9p1CjpjDO8QVSTJtJNN5k9oQiiGoyMFAAAABApystN0YjbbpOKirztJ55o2rt2DV3fIgyBFAAAABAJvvhCuvpqszeUW1qa9MADphoG5cwbFb9NAAAAIJzt2mWq8fXu7RtEjRtn9oQaO5YgKgDISAEAAADhyLKkF16QpkwxOwC7detmpvH17x+6vkUB24WmCxYsUIcOHZSQkKA+ffpo9erVNR4/f/58HXXUUUpMTFRWVpYmT56skpKSIPUWAAAACIGNG6XTT5cuvtgbRDVrJs2da7JSIQii8vPNVlX5+f49F65slZF66aWXNGXKFC1cuFB9+vTR/PnzNWjQIG3cuFEZGRmVjn/++ed1yy236Omnn1a/fv30/fff69JLL5XD4dC8efNC8A4AAACAANq3z+wHdf/9Ummpt/3cc6WHHzalzUPQpcWLpZwcafduKSnJxHEjRpjnq3suMTHoXW1Utgqk5s2bpyuuuEJjx46VJC1cuFBvvfWWnn76ad1yyy2Vjl+5cqX69++viy66SJLUoUMHjRo1Sp999llQ+w0AAAAE3LJl0sSJ0o8/etsOPVR67DFpyJCQdWvxYmnJEikzU2rfXiosNI/dqnsuOzs0/W0stgmkSktLtWbNGk2bNs3TFhMTo4EDB2rVqlVVntOvXz/9+9//1urVq9W7d2/99NNPWrp0qUaPHl3tdfbv36/9FXZ0
LvqzLKTL5ZLL5Wqkd1N/LpdLlmXZoi+wP8YL/MWYgb8YM/AXYyYAfvtNjilT5HjlFU+T1aSJNHWqrNtuk5o2lUL0+y4okFaulFq3ltwTyBISJIdDev9987iq51aulAYPllJT7TVm/OmDbQKpHTt2qLy8XJmZmT7tmZmZ2rBhQ5XnXHTRRdqxY4dOPPFEWZalsrIyXX311br11lurvc6sWbN09913V2rfvn27LdZWuVwuFRYWyrIsxVBdBbVgvMBfjBn4izEDfzFmGlFZmZo+84yS5syRY/duT3PpX/6iwtmzVX7UUWa+XIXngu3336XmzaX0dCk21tuemCht3Wrut29f+bnt283zZWX2GjPFxcV1PtY2gVR9rFixQjNnztTjjz+uPn366IcfftCkSZM0Y8YM3XHHHVWeM23aNE2ZMsXzuKioSFlZWUpPT1dycnKwul4tl8slh8Oh9PT0kA8k2B/jBf5izMBfjBn4izHTSD77TI7x4+VYu9bTZLVqJWvOHDUZM0ZpDkfo+lZBkyZScbG0d6836yRJeXnSnj3m/s8/V36uvNwEWO6MlF3GTEJCQp2PtU0g1apVKzmdTuVWLN0oKTc3V61bt67ynDvuuEOjR4/W5ZdfLknq1q2b9uzZoyuvvFK33XZblf8h4uPjFR8fX6k9JiYm5P/h3BwOh636A3tjvMBfjBn4izEDfzFmGmDnTunWW6X/+z9T3tztyivlmDVLjtTUBl8iP99MyUtNNfv1NkSrVlK/fmbdk2VJKSlmHVRurjRsmDmmuudatfK+jl3GjD/Xt00gFRcXp169emn58uUaPny4JBOdLl++XBMnTqzynL1791Z6s06nU5JkVRx4AAAAgJ1ZlvTcc9INN5iUjduxx0oLF0p9+zb4EjVV12tIBT13db6cHDNdLynJBEru9tqeC1e2CaQkacqUKcrOztbxxx+v3r17a/78+dqzZ4+nit+YMWPUrl07zZo1S5I0dOhQzZs3T8cdd5xnat8dd9yhoUOHegIqAAAAwNY2bJDGj5c++MDblpQk3XOPdO21Zv5cI6ipul5DKuglJprzhwypOtNV03PhzFaB1MiRI7V9+3bdeeed2rZtm3r06KFly5Z5ClBs3brVJwN1++23y+Fw6Pbbb9dvv/2m9PR0DR06VPfdd1+o3gIAAABQN/v2SffdJ82ZIx044G0//3xp/nzpkEMa7VL5+SYrlJlpbpKpoCeZ9iFDGh7gpKVV/xo1PReubBVISdLEiROrncq3YsUKn8dNmjTR9OnTNX369CD0DAAAAGgkS5eaPaE2b/a2HXaY2RPq7LMb/XIFBWY6X/v2vu0pKWbKXUFB5AU6gcYKQAAAACBYfv1VuuACs4mSO4iKjTUFJr77LiBBlGSm1CUlmel8FRUWmvZGqGERdQikAAAAgEArK5Meekjq0kX6z3+87aecIn39tZni17RpwC6flmYKS+TmmltJifd+//5ko+rDdlP7AAAAgIiyapV0zTUmYHJLT5cefFC65BIpSHtC1aW6HuqOQAoAAAAIhIICado06cknvW0Oh3TlldKsWVLLlkHtTm3V9eAfAikAAACgMVmW9M9/SjfeKG3f7m3v0cPsCdWnT8i6JkVmBb1QYI0UAAAA0FjWrZNOPVW69FJvEJWUZNZHff55yIMoNB4yUgAAAEBD7d0r3Xuv9MADprCE29/+ZoKodu1C1zcEBIEUAAAA0BBvvmn2hPr5Z2/b4YebPaHOPDN0/UJAMbUPAAAAqI+tW6Vzz5WGDvUGUXFx0h13SN9+2yhBVH6+tGmT+Ql7ISMFAAAA+OPAAenhh6W77pL27PG2n3aatGCBdNRRDb7Evn3S4sWmVPnu3WaZVf/+plR5YmKDXx6NgIwUAAAAUFc5OVLPnqYinzuIysyUnntOevfdRgmiJBNELVkiOZ1S+/bm55Ilph32QCAFAAAA1CY/X7r8cunEE6XvvjNtDoc0YYK0YYN00UWNtrFufr6J1zIz
zS0hwXs/Jyfw0/yYTlg3TO0DAAAAquNySYsWSTfd5BtZ9Oxp9oQ64YRGv2RBgZnO1769b3tKilmWVVAQmH2gmE7oHzJSAAAAQFW++04aMEAaN84bRCUnS488Iq1eHZAgSpJSU00QU1jo215YaNpTUwNyWaYT+olACgAAAKhozx7p5pul446TPvnE237hhWYa37XXmigjQNLSTCYoN9fcSkq89/v3D0w2KtTTCcMRgRQAAADg9sYb0tFHS3PmeDfWPeII6Z13pBdekNq0CUo3RoyQhg2TysvNdL7ycvN4xIjAXM89nTAlxbc9JcW0FxQE5rrhjDVSAAAAwM8/S9ddZwIpt7g46dZbTXYqISGo3UlMlLKzpSFDTBCTmhqYTJRbxemEFd9qoKcThjMCKQAAAESvAwekhx6S7r5b2rvX23766WZPqE6dQtc3meApkAFUxev072/WREkmE1VYaKYTDhsWnD6EGwIpAAAARJX8fJPlydj4sVJuuUb63/+8T7ZuLc2fb+bQVVHO3H1uoDNEoeCeNpiTY6YTJiUFdjphuCOQAgAAQFRwl/f+evkOnfPJTeq0+RnvkzExZk+oGTMqLxRSw0uDh0MAFuzphOGOQAoAAABRYfGLLu159Gnds+5mJe33Vk/YcdjxavXyQqlXr+rP/bM0eGamKQ1eWOidBpedXf01w3FvpmBNJwx3VO0DAABAxNv18bfqP+0kjf/qCk8QVRKfrBdOXKDb/vqp8jtUH0Q1pDQ4ezNFLgIpAAAARK7du6WpU5Vy6nE6Inelp/mbbhfpsYkb9XX/8Sre66yxvHd9S4OHcm+m/Hxp0yb2fwokpvYBAAAg8liW9PrrpqT5r7/KXTYiN+VILTvncW3ueJokqTC39vLe9S0N7g7A2rf3bU9JMcUcCgoafwpdOE4lDFdkpAAAABBZtmyRzjlHOu886ddfTVt8vL469x5dO+AbfdrsNJWUmNLeubkm0KgpoHGXBncfX9dzKwZgFQVybyamEgYPgRQAAAAaRX6+9MMPUlFRiDpQWirNmiUdfbT05pve9kGDpO++U+fn7tDg8+JVXm4yQuXldS/vPWKEOdafc+sbgNVXKKcSRiOm9gEAAKBBKk4n27NHOuwws49tUKeTffihdM010vr13rY2bcyeUH/7m+RwKFH1L+9d39LgwdybKRRTCaMZgRQAAAAapGJp8KwssyXTG2+Y52oqDd4o8vKkG2+U/vlPb1tMjHTttdI990jJyZVOaUh5b3/PDebeTPVdy4X6YWofAAAA6q2q6WQtWgRhOpnLJT35pNS5s28Q1bu39MUXJhNVRRAVKmlpJksXyIxQsKcSRjsCKQAAANRbdaXBk5NrLg3eIF9/bSKDq66Sdu40bS1aSE88Ia1cKR13XAAuGh7qs5YL9cPUPgAAANRbddPJiooCMJ2suFiaPl165BETIbiNHi098IBJg0W5YE4ljHYEUgAAAKg393SyJUvM45QUM+suN9dUIG+UD/GWJf3nP9L110u//eZt79xZevxx6dRTG+EikaUh68BQN0ztAwAAQINUnE72yy8mkDrnnEaaTvbjj9LZZ5vKe+4gKiFBuu8+M8WPIAohQkYKAAAADVJxOpm7uMQRR5jiefW2f7+ZrnfffaZqgtvZZ0uPPip17NigPgMNRSAFAACARpGWJrVsaSqSN8j770vjx0sbN3rb2rUza6POPVdyOBp4AaDhmNoHAAAAe8jNlS65RDrtNG8Q5XRKkyebjXbPO48gCrZBRgoAAAChVV5u9oSaNs2U/3P7y1+khQul7t1D1zegGgRSAAAACJ0vv5SuuUZavdrb1rKldP/90rhxDVxoBQQOIxMAAADBV1QkTZoknXCCbxCVnS1t2CBdcQVBFGyNjBQAAACCx7Kkl182e0L98Ye3vUsX6YknpAEDQtY1wB+E+QAAAAiOH36QzjxTGjnSG0QlJkqzZklr1xJEIayQkQIAAECd5edLBQVSaqopd14n+/eb
NU8zZ5r7bkOGSI8+qvzmHVTws5+vCYQYgRQAAABqtW+ftHixlJMj7d4tJSVJ/ftLI0aYpFK13nvP7Am1aZO37ZBDpEcf1b4zhmnxyw7/X7MO6hXwAX4gkAIAAIBHdQHI4sXSkiVSZqbUvr2pUr5kiXkuO7uKF9q2TZo6VXrhBW+be0+o6dOlpCQtftbP16yDegd8gJ8IpAAAAFBjALJ3r2nPzDQ3SUpIMD9zcswMPU/QVV6ups88I8f99/vuCdWvnykmceyxkkzAVufX9IPfAR9QTxSbAAAAiDL5+WamXX6+t80dgDidJgBxOs3jxYtNhmr3biklxfd1UlJMe0HBnw1r1sjRr5+Sb71VDncQlZoq/f3v0scfe4IoyY/X9PN9VQzOEhK893NyfN8v0FBkpAAAAKJEdVmn006rOTvUr585trDQ2y6Zx0lJUlqTQum6O6QFC+RwubwHjB0rzZkjtWpVqS+pqTW/Zmqq/+/PHZy1b+/bnpIibd1qnme9FBoLGSkAAIAoUV3W6bnnas4OSSbgys01t5KSP+9vs3Sx80Wl9ussPfqo9GcQdeCoo+RasUJ6+ukqgyjJBDRVvmauaa9PwFMxOKuoIcEZUB0yUgAAAFGgpjVJ//uf1KRJzdmhESNMW06Oye50LN+kOzaNV9vX3/Oe0LSpXHfeqfyLLlJGu3a19ung10xKkoYN87b7yx2cuddEpaSY95Cba16XbBQaE4EUAABAFKhp2lthodSjh/Tpp75tBwcg2dnSkIEl0uzZSn1qthwV94QaNkx6+GEpK0vKy6tTnxIT/3zNIY1XqryxgzOgOgRSAAAAUaC2NUkXXeQtylBtAPLuu0obP1764QdvW/v2ZlrfOeeYxxXXSNVRWlrjZYsCEZwBVSGQAgAAiAK1TXs75JAaApDff5emTJFeesn7gk2aSDfcIN1xh9SsWdDfT20aMzgDqkIgBQAAECXqMu3NJwApL5cWLJBuv10qLvYedNJJZk+oY44JWt8BuyGQAgAAiBJ+TXv7/HPp6qulL7/0trVqJT3wgHkRhyMofQbsikAKAAAgytQ47W3XLunWW6WFCyXL8rZffrk0ezbz5YA/EUgBAADABE3PP2/WQlWsutetmwmq+vULXd8AG2JDXgAAgGi3caM0cKB0ySXeIKpZM2nuXGnNGoIooApkpAAAQNTJz6c0tiRp3z5p5kxpzhyptNTbfu653j2hAFSJQAoAAESNffukxYtN1brdu03Vuv79TdW6xMRQ9y7Ili2TJkyQfvrJ29ahg/TYY9LgwSHrFhAumNoHAACixuLFZh8lp9PsI+t0mseLF4e6Z0H022/S3/4mnXWWN4iKjTUFJv73P4IooI4IpAAAQFTIzzeZqMxMc0tI8N7PyTHPR7SyMmn+fKlzZ+mVV7ztAwZIX38t3Xef1LRpyLoHhBsCKQAAEBUKCsx0vpQU3/aUFNNeUBCafgXFp59KJ5wgTZ5s3qwkpadLzz4rffCB1KVLaPsHhCECKQAAEBVSU82aqMJC3/bCQtOemhqafgXUzp1mU91+/aS1a73tV10lbdggjRnDxrpAPRFIAQCAqJCWZgpL5OaaW0mJ937//hFWvc+ypH/9SzrqKOn//s+7sW737tKqVWZfqIiMHIHgoWofAACIGiNGmJ85OdLWrSYTNWyYtz0irF8vjR8vrVjhbUtKku65R7r2WqkJH/+AxmC7jNSCBQvUoUMHJSQkqE+fPlq9enWNx+/atUsTJkxQmzZtFB8fryOPPFJLly4NUm8BAEA4SUyUsrOlWbOku+82P7OzI6T0+d69pvJe9+6+QdT555vgavJkgiigEdnqX9NLL72kKVOmaOHCherTp4/mz5+vQYMGaePGjcrIyKh0fGlpqU4//XRlZGTolVdeUbt27fTzzz+rRYsWwe88AAAIG2lpETaV7623pIkTpS1bvG2HHWb2hDr77JB1qzGweTLsylaB1Lx583TFFVdo7NixkqSFCxfqrbfe
0tNPP61bbrml0vFPP/20CgoKtHLlSsXGxkqSOnToEMwuAwAAhM6vv0qTJkmvvupti42VbrrJZKfCuJw5myfD7mwTSJWWlmrNmjWaNm2apy0mJkYDBw7UqlWrqjznjTfeUN++fTVhwgQtWbJE6enpuuiii3TzzTfL6XRWec7+/fu1f/9+z+OioiJJksvlksvlasR3VD8ul0uWZdmiL7A/xgv8xZiBvxgzNlVWJj36qBx3Tpdj7x5Ps3XKKbIWLDB7RUlSCP67NdaYWbxYeuMNs89X+/ZSUZF5LEmjRzdCR2Ebdvo7408fbBNI7dixQ+Xl5crMzPRpz8zM1IYNG6o856efftL777+viy++WEuXLtUPP/yg8ePH68CBA5o+fXqV58yaNUt33313pfbt27erpKSk4W+kgVwulwoLC2VZlmJibLeEDTbDeIG/GDPwF2PGfmK/+ELNb7xZcRvWedp2N2ulr0dP1yE3n6/4BIeUlxey/jXGmCkqkjZtko45RnKv2MjIkFq1Mu0//CAlJzdenxFadvo7U1xcXOdjbRNI1YfL5VJGRoaefPJJOZ1O9erVS7/99pseeOCBagOpadOmacqUKZ7HRUVFysrKUnp6upJt8C/S5XLJ4XAoPT095AMJ9sd4gb8YM/AXY8ZGCgrkmDZNjr//3dPkkkOre1ypl3vepy15LXXOh6HP1jTGmCkqkjZvlrKyzH230lLpl1/M/SqWzyNM2envTEJCQp2PtU0g1apVKzmdTuXm5vq05+bmqnXr1lWe06ZNG8XGxvpM4+vSpYu2bdum0tJSxcXFVTonPj5e8fHxldpjYmJC/h/OzeFw2Ko/sDfGC/zFmIG/GDMhZlnSP/8pTZ0q7djhad7a6ji9PewJ/XZIHzWXlBFn1hMNGRL6ogwNHTNpaVKzZmaz5IqfawsLTXtamsRwjCx2+Tvjz/VtMwTj4uLUq1cvLV++3NPmcrm0fPly9e3bt8pz+vfvrx9++MFnLuP333+vNm3aVBlEAQAAhJV166RTTpEuvdQTRJU3a65nez6shWNX67dD+ngOTUkxRRkKCkLT1cYUVZsnI2zZJpCSpClTpuipp57Ss88+q/Xr1+uaa67Rnj17PFX8xowZ41OM4pprrlFBQYEmTZqk77//Xm+99ZZmzpypCRMmhOotAAAANNzevdK0aWZPqI8+8raPGKHCTzcop9d12lnsO7GosNBUtktNDXJfA2TECLNZcnm52Ty5vDwCN09GWLPN1D5JGjlypLZv364777xT27ZtU48ePbRs2TJPAYqtW7f6pNuysrL09ttva/LkyTr22GPVrl07TZo0STfffHOo3gIAAEDD/L//J117rfTzz962ww+XHn9cOuMMpcpkZZYsMU+lpJggKjfXBBp2zdb4ux+Ue/PkIUPYRwr25LAsywp1J0KpqKhIKSkpKiwstE2xiby8PGVkZIR8jijsj/ECfzFm4C/GTBBt3Spdd503QpKkuDjpllvMrcLmSXbeY+ngMWPnvsIe7PR3xp/YwFYZKQAAIpm/38gjShw4IM2fL911l5nS5zZwoLRggXTkkZVOCadszeLFJjZ07wdVWOiNFbOzQ9s3oCEIpAAACDC+kUe1PvlEuuYa6bvvvG2tW0vz5kkXXig5HDWenpZm3wBKMl8e5OSYIMq9Vai7Cp9dKgwC9UWOHgCAAHN/I+90mm/knU7zePHiUPcMIbNjhzRunHTSSd4gyuGQJk6UNmyQRo2qNYgKBwUF5suDlBTf9kiqMIjoRSAFAEAAHfyNfEKC935OjnkegZGfL23aZLPfscsl/eMfUufO0tNPe9t79ZJWr5YefbRy1BHGUlNNBraw0Lc90ioMIjoxtQ8AgAByfyPfvr1ve0qKqS1QUMDUpsZm26mU335rpvHl5HjbkpOlmTOlq682qcoI494PKtwqDAJ1QUYKAIAA4hv54LPdVMrdu6WbbpJ69vQNokaNMtP4JkyIyCDKjf2gEKnISAEAEEB8
Ix9ctitusGSJ2RPql1+8bZ06mT2hBg4MYkdCJ5wqDAL+ICMFAECA8Y188NimuMHPP0vnnCMNH+4NouLjTYnzb76JmiCqorQ0E0MSRCFSkJECACDA+EY+eCpOpXRnoqQgTqUsLZUeeki6+26zWMvtjDPMnlBHHBHgDgAIFgIpAACCxO57/kSCkE6l/PhjUzRi3TpvW5s2ZrPdv/0tIsqZA/Biah8AAIgoQZ9KuX27NHasdPLJ3iAqJka67jpTTGLECIIoIAKRkQIAABElaFMpXS6zF9RNN0k7d3rbTzhBWrjQVOkDELEIpAAAQEQK6FTKb74x0/hWrfK2paSYPaGuuiqiy5kDMJjaBwAAUFe7d0tTp5psU8Ug6uKLzTS+8eMJooAoQUYKAACgNpYlvfaaNGmS9Ouv3vYjjzR7Qp12Wuj6BiAkyEgBAADUZPNmaehQ6fzzvUFUfLw0Y4aZ4kcQBUQlMlIAAABVKS2V5s6V7r3Xd0+oM8+UHntMOvzw0PUNQMgRSAEAABxsxQqz3mn9em9b27bSww+bzBTlzIGox9Q+AAAAt7w8Uzv91FO9QVRMjHT99ebxBRfUO4jKz5c2bTI/AYQ/MlIAAAAul/TUU9Itt0i7dnnb+/Qxe0L16FHvl963T1q8WMrJMUX/kpKk/v3NPr2JiQ3uOYAQISMFAACi29q1Ur9+Zl8odxDVooUJoFaubFAQJZkgaskSUxW9fXvzc8kS0w4gfBFIAQCA6FRcLE2eLPXqJX32mae5ZOQY/fTfjcq/4Cozra8B8vNNJioz09wSErz3c3KY5geEM6b2AQCA6GJZ0iuvmHVPv//uaXYd1VnvDH9Crxacot2PNM4UvIICM52vfXvf9pQUaetW83xaWv3fCoDQISMFAACix48/SmefbaIjdxCVkCDNnKnnbvxaT35/SqNOwUtNNQFZYaFve2GhaU9Nrf9rAwgtAikAABD59u83+0F17SotW+ZtP/tsad065V85TR9/FtfoU/DS0kxWKzfX3EpKvPf79ycbBYQzAikAAKJQVJXifv99qXt36Y47TCQjSYccIv3nP9Kbb0qHHeaZgpeS4ntqSoppLyio/+VHjJCGDZPKy810vvJy83jEiPq/JoDQY40UAABRJKpKcefmSjfcID33nLfN6ZQmTZLuuktq3tzTXHEKXkKC9/DGmIKXmGi2phoyxARkqalkooBIQEYKABA0UZUFsamoKMVdXi498YR01FG+QVTfvtKXX0oPPugTREnBmYKXliZ16kQQBUQKMlIAgICLqiyIjR1cilvyZl9yckzGJOw/5H/5pdkP6vPPvW0tW0r33y+NG1djOXP3VLucHDMFLymJKXgAqkcgBQAIOHcWJDPTZEEKC81jyUx5QnBEdCnuoiKzBuqxxySXy9uenS098ICUnl7rSzAFD4A/CKQAAAEVFVmQMBHIdUAhY1kmUp88WfrjD2/70Ueb6X0nn+z3S6alMSYB1I41UgCAgApkNTT4J+JKcf/wg3TmmdKFF3qDqMREafZs6auv6hVEAUBdkZECAARURGZBwlhErAPav9+seZo509x3GzpUeuQRqUOHkHUNQPSocyDVt29fPfXUU+ratWsg+wMAiDDuLIh7TVRKigmicnPNB/iwy4KEubBfB/Tee9L48ab8o1tWlvToo2ZAAUCQ1Hlq35YtW9SrVy/deuutKnFvZgcAQB2wIan9hF0p7m3bpIsukk4/3RtENWki3XijtG4dQRSAoKtzRmrjxo2aNm2a5syZo5dffllPPPGEBg4cGMi+AQAiRNhnQRA65eXSwoXSrbeaynxu/fubYhLduoWubwCiWp0zUsnJyVqwYIFWrVql5ORkDRo0SKNHj9b27dsD2T8AQAQJuywIQuuLL6Q+faSJE71BVGqq9I9/SB99RBAFIKT8LjZxwgkn6PPPP9ejjz6qO+64Q2+++aaysrIqHedwOPT11183SicBAEAUKSyUbrtNevxxU97c7bLLTJGJ
Vq1C1zcA+FO9qvaVlZVp+/bt2r9/v9LS0pTGV4sAAKChLEt68UVpyhSzJsqta1czje/EE0PXNwA4iN+B1Hvvvafx48frp59+0vjx43XfffepefPmgegbAACIFt9/L02YYKryuTVtKt11l3T99VJsbKh6BgBVqnMgtX37dk2ePFkvvPCCunXrppUrV6p3796B7BsAAIh0JSXSrFlmE93SUm/78OHSww9L7duHrGsAUJM6B1JHHXWUSktLNXv2bE2ZMkVOpzOQ/QIAADaVn99I1RffecdkoX74wdt26KFmT6ihQxvcTwAIpDoHUn/5y1/0+OOPqwO7hQMAEJX27ZMWL5ZycqTdu6WkJFOFfMQIU+K+zn7/XZo82byYW5Mm0tSp0u23S82aNXrfAaCx1TmQWrp0aSD7AQAAbG7xYmnJEikz08y4Kyw0jyWzT1itysqkBQukO+6Qiou97SefbCr0HXNMQPoNAIFQ532kAABA9CooMJmozExzS0jw3s/JMdP9arR6tdS7tykc4Q6iWrWSFi2SVqwgiAIQdgikAABArQoKzHS+lBTf9pQU015QUM2JO3dK11wj/eUv0ldfeduvuELauNGkshyOgPUbAAKlXvtIAQCA6JKaatZEFRaabJRbYaFpT0096ATLkp57TrrhBikvz9t+7LFmT6h+/YLSbwAIFDJSAACgVqmpprBEbq65lZR47/fvf1D1vg0bpNNOk0aP9gZRzZpJDz4orVlDEAUgIpCRAgAAdTJihPmZkyNt3WoyUcOGedu1b590333SnDnSgQPeE887T5o/X8rKCnaXg67RSsMDsD0CKQAAUCeJiWZJ05AhVQQLS5dKEydKmzd7T+jQQXrsMWnw4FB0N6garTQ8gLDB1D4AAOCXtDSpU6c/g6hff5UuuMAES+4gKjZWuvVW6X//i4ogSvKWhnc6TWl4p9M8rrhVFoDIQiAFAAD8V1YmPfSQ1KWL9J//eNsHDJC+/tpM8WvaNHT9C6L8/AaWhgcQlgikAACAfz79VDr+eGnKFDOPTZLS06V//lP64AMTXEWRepeGBxDWCKQAAEDdFBRIV11lqu59/bVpczhM24YNpkpfPfeEys+XNm0Kz+xNxdLwFVVbGh5ARKDYBAAAqJllmWzTTTdJ27d727t3lxYuNJvt1lMkFGlISzN9XrLEPE5JMUFUbq6pakj1PiAykZECAADVW79eqeefr5ixY71BVFKSNG+e9MUXDQqipMgp0jBihAmaystNafjy8oNKwwOIOGSkAABAZXv3SvfeK8fcuYqruCfUBReYPaHatWvwJQ4u0iCZQg2SaR8yJHyyOTWWhgcQkQikAACAr7feMntCbdki94onq2NHOR57TDrrrEa7jLtIQ/v2vu0pKSarU1AQfsFIWpp9+szmwEBgEUgBAADjl1+kSZOk117zNFmxsdozYYKa3nuvHM2aNerlKhZpcGeiJIo0NFQkrDsDwgFrpAAAiHYHDkgPPmjKllcIonTqqbLWrtXum28OyCdwd5GG3FxzKynx3u/fnyxKfUXKujPA7gikAACIZitXmj2hpk6V9uwxbRkZ0r//LS1fLnXuHNDLU6ShcbE5MBA8TO0DACBC1bhGJj9fuuUW6e9/97Y5HNLVV0v33Se1bGnaLCugfaRIQ+OKxHVngF0RSAEAEGFqXCOTYEnPPivdeKO0Y4f3pJ49pSeekHr3Dkmf7VSkIZyx7gwIHqb2AQAQYapbI/POQ/+TBgyQxo71BlHNm0uPPCKtXh2yIAqNh3VnQPCQkQIAIIJUtTdT85g9Gv7pPRr4+jzJKvMePHKk2Vi3bdvQdBYB4V5flpNjpvMlJbHuDAgEW2akFixYoA4dOighIUF9+vTR6tWr63Teiy++KIfDoeHDhwe2gwCAoMrPlzZtYqF8XbjXyKSkmMdHbXxDEx4/WoPWzpHTHUQdcYT0zjvSiy8SRIUBf8e/e93ZrFnS3Xebn9nZlD4HGpvtMlIvvfSSpkyZ
ooULF6pPnz6aP3++Bg0apI0bNyojI6Pa87Zs2aKpU6fqpJNOCmJvAQCBxH44/nOvkXH++rMu/OI6dd74hue5AzFxOnDDNDW95xbfBTSwpYaOf9adAYFlu4zUvHnzdMUVV2js2LE6+uijtXDhQjVt2lRPP/10teeUl5fr4osv1t13362OHTsGsbcAgECKpv1wGivrlpZ8QFcWztEdLx7tE0R9nT5Qb878Vk3n3EUQFSaiafwD4chWGanS0lKtWbNG06ZN87TFxMRo4MCBWrVqVbXn3XPPPcrIyNC4ceP08ccf13iN/fv3a//+/Z7HRUVFkiSXyyWXy9XAd9BwLpdLlmXZoi+wP8YL/BVOY6agwGxx1Lq12dZIMp//HQ7TPnhwZFQg27dPeuUV857cWYd+/aQLLqhH1u3jj+WYMEHH/+9/nqadCa31Sr8HFTd6pC74m8Pv//Z2HzMFBd6y6cEaD8G4ZjiPf7uPGdiPncaMP32wVSC1Y8cOlZeXK9O9OvZPmZmZ2rBhQ5XnfPLJJ/rHP/6htWvX1ukas2bN0t13312pffv27SopKfG7z43N5XKpsLBQlmUpJsZ2CUPYDOMF/gqnMfP776agXHq6FBvrbU9MlLZvN4voy8qqPz9cLF8urVljCkN07Gj2xF2zRmrSRDrttLq9hiM/X83vvVdNX3zR02bFxGjXRZfq58tv1oA2yUpO3q7iYqm42L/+2XXM7N8vffKJtH69qUyXkCB16SKdeKIUHx/+1wzn8W/XMQP7stOYKfbjj6StAil/FRcXa/To0XrqqafUqlWrOp0zbdo0TZkyxfO4qKhIWVlZSk9PV3JycqC6Wmcul0sOh0Pp6ekhH0iwP8YL/BVOY6ZJE/Ohf+9e7zfykpSXJ5WXm6lOdv1Gvq4KCqSPPjJTtmJiTHYqJsZ8QP7oI2nQoFreo8slPfOMHLfcIkdBgafZOv54WY8/rpRevXRsA/to1zHzr39Jb7xhAtDkZPO7fPVV87sbPTr8rxnO49+uYwb2Zacxk+DH1GdbBVKtWrWS0+lUbm6uT3tubq5at25d6fgff/xRW7Zs0dChQz1t7nRckyZNtHHjRh1++OE+58THxyu+iq+NYmJiQv4fzs3hcNiqP7A3xgv8FS5jplUrM8VtyRLJskwVusJCsx/OsGHm+XC3c6f5sNy+vW97crLJOOzcWcP7/PZb6eqrzTyviifOnCnH1VfL4XQ2Wj/tNmbcJd4zMnynvVmWaR8ypPGLLAT7muE+/u02ZmB/dhkz/lzfVqM7Li5OvXr10vLlyz1tLpdLy5cvV9++fSsd37lzZ3377bdau3at53bOOefo1FNP1dq1a5WVlRXM7gNAVAhmKfIRI8yHxvJyE1iUl0fWfjjuCnuFhb7thYWmvcqMw+7d0o03Sscd5xtEXXSRtHGjNGGCSXFFsINLvLulpJj2Csm5sL5mpI9/INzZKiMlSVOmTFF2draOP/549e7dW/Pnz9eePXs0duxYSdKYMWPUrl07zZo1SwkJCeratavP+S1atJCkSu0AgIYJRSly9344Q4Z4F/dHUjnntDTzO1yyxDw+OOvg814tyxx43XXSL7942zt1kh5/XBo4sNH6lZ/v/X23bNloL9toKgagFWfh1BiAhuE1I338A+HOdoHUyJEjtX37dt15553atm2bevTooWXLlnkKUGzdujXkKT8AiEbuUsyZmWYqWmGhNwDIzg7stSN5Pxx3diEnx2QdkpKqyDps2SJde6305pvetvh46dZblT/uJhXsTVBqfsN/R9UFy6ec0rDXbWx+BaBhfM2K147U8Q+EM4dlWVaoOxFKRUVFSklJUWFhoW2KTeTl5SkjI4OAEbVivMBf9R0z+fnStGlmxljFwqq5uWa60axZfNBrqIpZIM/vsrRUmjdPuuceE+W4nXGGSh5coJfWHNGoGcJnn/UGy+5AIS/PpfPOy9OoUfb6OxOKDCkbRNcN/2+Cv+w0ZvyJDWyXkQIA2I97
fcjBRRFSUkwWpaAg/AKpKgOXEKqUdfjwQ2n8eGndOm9bmzbS/PnS3/6ml/7paNQMobuYQmamN1h271u0fr35XdmpwEEopr0x1Q5ARQRSAIBahWJ9SKDYPquwfbspJvHss962mBhTROLee6Xk5GqDHqn+FeSqC5aTk82eSXYLpNxCMe2NqXYAJJtV7QMA2JN7fUhurrmVlHjv9+8fXh8q3Wu9nE4TNDid5vHixSHumMslPfmkdNRRvkHUCSdIn38uPfKIiWoUmApy1VUQLCoyQVo4BcsAEAwEUgCAOomEUswHZ3ISErz3c3KCU9K9Sl9/LZ14onTVVWbzKMlERY8/Lq1aJfXs6XN4vcqm16KmYLlLFwIpADgYU/sAAHUSCetD6rrWK2jrp4qLpenTTbapvNzbfvHF0ty5UhWb0UuBqyBXVQXBc84xMR4AwBeBFADAL+G8PqS2tV6JiWZWXcDXT1mW9Oqr0qRJ0m+/eduPPFJ64gnpr3+t9SXqVDbdT1UFyy1bSnl59X9NAIhUBFIAgLDmT/aotkzO8uVB2Cvrp5/MnlBLl3rbEhKk224zRSbi4+v0MoHMEFYMll2uxnlNAIg0BFIAgLBU3+p71WVyTjvNbNfUmJXwfJSWmul6M2aYBUhuZ54pPfaYdPjh9XrZcM4QAkA4I5ACAIQld/U9f7NH1WVyNm0K4F5ZH3xg9oTasMHb1rat9PDD0vnnm82aAABhhap9AICw0xjV99LSpE6dvMFRICrhKTdXGj3arHlyB1ExMdL115vHF1xAEAUAYYpACgAQdgKxj1Kj7pXlckkLF0qdO0v//re3vU8fac0a6aGHpObN/e8kAMA2mNoHAAg7tVXfq++eR41SCe+rr6Srr5ZWr/a2tWgh3X+/dPnlJiMFvwStHD0A+IFACgAQdgK1j1KDKuEVFUl33ik9+qhvqbsxY6QHHpAyMurXqShW34IiABAMBFIAgLAUiH2U3PyqhGdZ0iuvmHVPv//ube/SxewJNWBAwzsUpepbUAQAgoFACkDYY9pPdArkPkp19uOP0oQJ0ttv+3bsjjukG26Q4uKC3KHIcXBBEamRy9HbDH/HgPBDIAUgbDHtB1KI9lHav1+aM0e67z5z323IEOmRR6TDDgtyhyKPu6BIQMrR2wh/x4DwxYpXAGHLPe3H6TQftpxO83jx4lD3DBFt+XLp2GPNeih3EHXIIdJrr0lvvEEQVYP8fLNfV13K0wekHL0N8XcMCF9kpACEpWib9gMb2LbNTNd7/nlvm9MpTZ4sTZ9uPt2jSvXJugSqoIid8HcMCG9kpACEpUDsIwRUqbxcWrDA7AlVMYjq10/68ktTkY8gqkb1zbqMGGGCpvJyM52vvLzxCorYAX/HgPBGRgpAWArUPkJ2xUL0EFmzxuwJ9cUX3rbUVLM+auxY9oSqg4ZkXWxRUCSAou3vGBBpCKQAhKVomPYjsRA9ZAoLTeW9BQt894QaO9ZsrJueHrq+hZnGKBoRkoIiQRAtf8eASEUgBUSxcM9yBHIfIbtgH50gsyzzS7/+erMmyu2YY8yeUCedFLKuhSuyLjWLhr9jQKQikAKiUKRkOSJ92g8L0YNs0yazJ9S773rbmjY1hSQmT5ZiY0PXtzBG1qVmkf53DIhkBFJAFIq0LEekTvuJln10Qq6kxEzXmzXLd0+oc84xe0Idemjo+hYhyLrULlL/jgGRjEAKQRXuU8kiAVmO8MGUqCB4911p/Hjphx+8be3bmwBq2LDQ9SvCkHUBEIkIpBAUkTKVLBKQ5QgfTIkKoD/+kKZMkV580dvWpIlpu/NOqVmz0PUtgpF1ARBJqNuKoGDndvuomOWoiCyHPUX6PjpBV14uPfaY2ROqYhB14onSV1+ZKX4EUQCAOiAjhYBjKpm9kOUIL0yJakRffGH2hFqzxtuWlmY21M3OZk+oRsIUbgDRgkAKAcdUMvth4Xf4YUpUA+zaJd1+u/T446a8udu4cSYDxS+2UTCFG0C0
IZBCwLFg3n7IciAqWJb0wgtm3VNurre9a1dp4ULzKR+NJtKqgQJAbZjHgIBzTyXLzTW3khLv/f79+QAfSmlpUqdO/DdABNq4UTr9dOnii71BVNOm0pw50pdfEkQ1soOncCckeO/n5JjnASDSEEghKFgwDyAo9u0zVfeOPVZavtzbPny4tH69dOONbKwbAO4p3Ckpvu0pKaa9oCA0/QKAQGJqH4KCqWQAAm7ZMmnCBOmnnzxN5VmHKve2RxV/wVD+5gQQU7gBRCMCKQQVC+YBNLrffpMmT5ZeftnTZDVpou8GTdX/pd+ugg+bKWkNhQ8CiWqgAKIRgRQAhLmoLTddVmb2hLrjDjN/zO3kk/XGoMf17BfHKDNBap9J4YNgoBoogGhDIAUAYSqqy01/9pnZE2rtWm9bq1bS3LnKHzxGb93qYO+6IGMKN4BoQ7EJAAhT7nLTTqcpN+10mseLF4e6ZwG0c6cJoPr29Q2irrjCVOrLzlbBTgeFD0KIaqAAogWBFACEoagrN21Z0r/+JR11lPR//+fdWPfYY6WVK6Unn/RUNKhY+KAiCh8AABoTgRQAhKGoKje9YYP0179KY8ZI27ebtqQkad48ac0ak52qgL3rAADBwBopAKiC3Qs4REW56b17pfvukx54QDpwwNt+/vnS/PnSIYdUeyqFDwAAgUYgBQAVhEsBh4gvN710qTRxorR5s7ftsMNMlb6zz671dAofAAACjal9AFBBOBVwGDHCBE3l5SbrUl4eAVmXX3+VLrhAGjzYG0TFxkq33SZ9912dgqiKKHwAAAgUMlIA8KeDCzhI9i6bHVFZl7Iy6dFHpTvv9N0T6pRTpMcfl7p0CVnX7MDuU00BIBoRSAHAn9wFHNq3921PSTEZn4ICe36ITUuzZ7/qbNUq6ZprpK+/9ralp0sPPihdconkcISubyEWLlNNASAaMbUPAP5E2ewgKyiQrrpK6tfPG0Q5HGafqI0bpdGjozqIksJrqikARBsCKQD4E2Wzg8SypGeflTp3Nvs/ufXoYbJTTzwhtWwZsu7ZRdTtFQYAYYZACgAqiMgCDnaybp106qnSpZf67gk1f770+edSnz6h7J2tRNVeYQAQhlgjBaBRhfui+Igq4GAne/dKM2ZIc+eawhJuf/ub9NBDUrt2oetbI2usfwNRsVcYAIQxAikAjSLSFsWHfQEHO3nzTenaa6UtW7xtHTtKCxZIZ54Zsm41tsb+NxDxe4UBQJhjah+ARsGieFSydat07rnS0KHeICouTrr9drMnVJgGUfn50qZNldcoBeLfAFNNAcC+yEgBaLBw238JAXbggPTww9Jdd0l79njb//pXsyfUUUeFrGsNUVPGae/ewPwbYKopANgXGSkADcaieHisXCn16iXdeKM3iMrMlJ57TnrvvbANoqSaM06B/jeQliZ16kQQBQB2QiAFoMHYfwnKz5cuv9ykaL791rQ5HNL48dKGDdJFF4X1nlC1lSKX+DcAANGGQApAg7H/UhSzLOmZZ0ym6R//8Lb37Cl99pkpKNGiRci611hqyzhJ/BsAgGjDGikAjcK9+D0nxyyKT0piUXzE++476ZprpE8+8bY1by7dd5/JRDmdoetbI6tLKXL+DQBAdCGQAtAoWBQfRfbske65R5o3z3dPqJEjTVvbtqHrW4DUtRQ5/wYAIHoQSAFoVOy/FOHeeMPsCbV1q7ftiCNMNb7TTw9dv4Kgrhkn/g0AQHQgkAIA1O7nn6XrrjOBlFt8vDRtmnTzzb7z3SIUWVcAQEUEUgCA6h04YKbr3XOP2SzJ7fTTTSGJTp1C17cQIeMEAJAIpAAA1fnoI1NMYt06b1ubNtJDD5n5bGFczhwAgIai/DkAwNf27dLYsdKAAd4gKibGrI1av94UlSCIAgBEOTJSAADD5ZKeftqseSoo8LYff7y0cKHUq1fo+gYAgM0QSAEApG++ka6+Wlq1ytuWnCzNnGnaI2hPKAAAGgNT+wAgmu3eLU2d
KvXs6RtEXXSRtHGjNGECQRQAAFUgIwUA0ciypNdeMyXNf/3V237kkWZPqNNOC13fAAAIA7bMSC1YsEAdOnRQQkKC+vTpo9WrV1d77FNPPaWTTjpJLVu2VMuWLTVw4MAajweAaOfculWOc86RzjvPG0TFx5sS5998QxAFAEAd2C6QeumllzRlyhRNnz5dX375pbp3765BgwYpLy+vyuNXrFihUaNG6YMPPtCqVauUlZWlM844Q7/99luQew4ANldaKs2erVYDBsixdKm3fdAg6bvvpDvuMAEVAACole0CqXnz5umKK67Q2LFjdfTRR2vhwoVq2rSpnn766SqPf+655zR+/Hj16NFDnTt31t///ne5XC4tX748yD0HABv78EOpRw/F3HabHCUlpq1NG+mll6T//lc64ojQ9g8AgDBjqzVSpaWlWrNmjaZNm+Zpi4mJ0cCBA7Wq4iLoGuzdu1cHDhxQampqlc/v379f+/fv9zwuKiqSJLlcLrlcrgb0vnG4XC5ZlmWLvsD+GC+oVV6eHDfdJMe//uVpsmJiZE2YYKbyJSeb9VKWFcJOws74OwN/MWbgLzuNGX/6YKtAaseOHSovL1dmZqZPe2ZmpjZs2FCn17j55pvVtm1bDRw4sMrnZ82apbvvvrtS+/bt21Xi/pY2hFwulwoLC2VZlmJibJcwhM0wXlAtl0uJzz2n5jNnyrFrl6e59Ljj9Otttymhb1/FlJRINvi7B3vj7wz8xZiBv+w0ZoqLi+t8rK0CqYaaPXu2XnzxRa1YsUIJCQlVHjNt2jRNmTLF87ioqEhZWVlKT09XcnJysLpaLZfLJYfDofT09JAPJNgf4wVV+vprOcaPl+PTTz1NVosWsu67TzHjximxoIAxgzrj7wz8xZiBv+w0ZqqLIapiq0CqVatWcjqdys3N9WnPzc1V69atazx37ty5mj17tt577z0de+yx1R4XHx+v+CoWU8fExIT8P5ybw+GwVX9gb4wXeBQXS9OnS488IpWXe9svuUSOuXPlyMyU/vyfFWMG/mDMwF+MGfjLLmPGn+vbanTHxcWpV69ePoUi3IUj+vbtW+15c+bM0YwZM7Rs2TIdf/zxwegqAiQ/X9q0yfwEUEeWJf3nP1KXLtJDD3mDqKOOkpYvl/71L+mgKdMAAKBhbJWRkqQpU6YoOztbxx9/vHr37q358+drz549Gjt2rCRpzJgxateunWbNmiVJuv/++3XnnXfq+eefV4cOHbRt2zZJUlJSkpKSkkL2PuCfffukxYulnBxp924pKUnq318aMUJKTAx17+whP18qKJBSU6W0tFD3Brbx00/StddKFcuZJyRIt98uTZ1KOXMAAALEdoHUyJEjtX37dt15553atm2bevTooWXLlnkKUGzdutUn5fbEE0+otLRUF1xwgc/rTJ8+XXfddVcwu44GWLxYWrLEfGnevr1UWGgeS1J2dmj7Fmo1BZl8Ro5i+/dLc+dK997rWzDirLOkxx6TOnYMXd8AAIgCtgukJGnixImaOHFilc+tWLHC5/GWLVsC3yEEVH6+CRIyM72zj9zr/HJypCFDojsDU1OQOXp0aPuGEPngA+maa6SNG71t7dpJDz8snXee5HCErm8AAEQJW62RQnQqKDCZlpQU3/aUFNNeUBCaftnBwUFmQoL3fk5OdP9uolJurome//pXbxDldEqTJ0vr10vnn08QBQBAkBBIoV4asyhEaqqZrlZY6NteWGjaq9lbOSId/HslyIQkyeWSFi6UOneW/v1vb/tf/iKtWSPNmyc1bx66/gEAEIVsObUP9hWIohBpaeY13NPVUlJMEJWbKw0bFh3T+qr7vZ52mjfIrLitQcUgs6wsdP1GEHz1lXT11dLq1d62li2l+++Xxo2TKC0MAEBI8H9gmykokH7/3b6ZBvd6HafTrNdxOs3jxYsb9rojRpigqbxc2rrV/Bw2zLRHg+p+r8uXm4AqN9fcSkq89/v3j65sXdQpKpImTZKO
P943iMrONtP6rriCIAoAgBAiI2UT7ozEypVmhk5xsdSvn73KfweyKERiovl8OGRIZJf4rqqEeW2/1zvv9N7futVkokIdZFKKPYAsS3r5Zen666U//vC2H3209MQT0sknh6xrAADAi0DKJtwZidatpfR0ae9e+5X/dq/Xad/etz0lxXzALyho+IfqtLTgfjAPVkBQ05TI2n6v+/ZVH2S6XIHrs7/vwy4Bf1j74Qdp4kTp7be9bYmJJpqeMkWKiwtd3wAAgA8CKRuomJHIyJBiY81Py7JX+e+KRSGqW68TLoIdENRUwnzIkLr9XoMdZFaF/b4CZP9+s+Zp5kxz323IEOnRR6UOHULWNQAAUDUm2NtAuFRmcxeFqG69Tqg/5PsjUGu9qlJbCXMpPH6vtb2PxqjgGJXee0/q1k2aPt0bRGVlSa+9Jr3xBkEUAAA2RSBlA+FU/jsSikIEOyCoS6AcDr/XcAn4w8a2bdJFF0mnn25q3kuynE7tHDdV+R+vk4YPZ08oAABsjKl9NlCx/LfDYaaW5eXZs/x3JBSFCMZar4rqMiUyHH6vkTS1M6TKy82eULfeairz/Sm3U38tPPYJbdzbTUn3sfYMAAC7I5CyCXfmYeVKaft2e2YkKrLDep36CnRAcHABC3/2ybLz75X9vhrBmjVmT6gvvvC2paYqZ9gczds5VhnpMWqfwtozAADCAYGUTbgzEoMHm6xI+/ZSq1ah7lVkClRAUFMBC3dAbKcS5vURKe8j6AoLpdtvlxYsMFVk3C67TAU3369n57ZSRuvG31YAAAAEDoGUzaSmSmVlTJMKtEAEBLVVtLP71L26CIcpiLZiWdKLL5rS5du2eduPOcZM7zvxROVvCu5UUwAA0DgIpBCVGjsgqOtmxXaeuuePSHkfAbVpkzR+vKnK59a0qfbceJd+H3G9UjNjlSbWngEAEK4IpFCtYG1WG0qNFRAEu4AFbGzfPrMn1KxZUmmpp7l8yDC9dsrDemfjodp9r+/UT9aeAQAQfgikUEmwN6uNBGQVoNxc6YknpMcfNxVj3Nq3lx59VP/eeU61Uz9ZewYAQPghkEIlta31QWVUtKu7iMt0fvut9NBD0nPP+WSg1KSJWRt1553KL2mmnGk1T/1k7RkAAOGFQAo+6rrWB5WRVahZxGQ6LUv5v+5TybIVSn/uIcV9+J7v806ndMEFpkpf166SpILf6zb1k7VnAACEDwIp+GCtT/1R0a5mtst0Hjgg7dzpvRUUVP2zwn0rv0Cu/J1KK9tf+fVSUqQrr5QmTqz0D8iuUz8jLjsIAEAQEUjBh10/8IUTsgqVBSzT6XJJRUU1Bj/Vtu3e7fflHJKcB7X93uwI/XLuJPV54lLzj6QKdpv6uX+/9K9/RUB2EACAECKQgg+7feBDZKgx0/mzpZ2/7VNayU7t+qlAxb/sVEp5gZLL/gx4asoU7dplgqkgsJo1006lam9CSx1onqrdSa313TEX6uOUISqznDpiv5RWdRwlyV5TPz/5RHrjDSkjwybZQQAAwhCBFCqx0wc+hJGKU+UOCnra/Vqgy77dqaQ1O9VSBUrct1OJ+woUt3enmpYUKPZFU6ShxZ+3gImLM2nVli3NzX2/prbUVKlFC/3wc5ymTzeBR8VsbXJJ3aa92mXqZ0GBtH69yQxmZJg21kECAOA/AqkI0ZhrHezygQ8h4HKZ9IQf64Y8P2uYKtdU0sDG6qPDIbVo4Rv8uAMe9626gCgx0ZxfD4017TXUUz8LCqSSEik52be9MddBsvYKABANCKTCXCAroYX6Ax/qybLMwKguCKrp565d5vwgKIlNUlnzloprnarNO1tqb0JLWSkttS8xVfsSU5V3oKWKY1M1anxLJR9aIVhKSZFiYoLSx4oiZdpraqoJBAsKGn8dZMRUZgQAoA4IpMKc7SqhofG4p8rVkA1y5OerRW6uHHv2+D5XcT+jQIqNrToLVDFDdFB7gdVS+a6WSm0dp7Q0adMm6e4qpsyV
/Dll7q/HSsmdgvN2ahOO014Pzg6lpkpdukivvmpi5sYMCPl7BACIJgRSYYw9n8JAVVXlapoeV/G5OlSVc0hKqPWo2l7koKlydV031LKl1LSp31PlUv+8eR6HUaXIcJr2Wl126IILpBNPlMrKGjcg5O8RACDaEEiFMfZ8ChLLkvburd+6oSBWlfNEHf4UUmjZMmRT5dzCccpcOEx7rSk7NGiQNHp04waE/D0CAEQbAqkwFk7f5NtCaWnlgKeuAVGopsrVMl3O1aKFdpSXq1WnToqJjw9OHwMgHKfM2VlN2aGVK6W+fU3FvsYMCPl7BACINgRSdmJZlW81SEuV+vf781vmqtY6pJr2iOFyma+8CwvNrajIe//gSnNVZY327AlOPytOlasqA3RwQNSQqXIul1x5eSYAC2PhNGUuHNSUHfrll3rtRVyrcMwsAgDQEARSdnLooYr55Re19uOU7D9vlbwuaWxjdCqKNWtmgpu0tJqnyh3cnpwc0qly4SwcpsyFg9qyQ0k1bBzcEGQWAQDRhEAKkS0uruZiCdUFRC1amHOBMFRbdujgPaQaC5lFAEA0IZCyk169ZB1yiA4cOKDY2FjVb9vQCJeUZD4VVrwlJ5ufVRVUaMAGrIh8kbxxbHXZoQsukIqLA3ttMosAgGhAIGUnr70my+VSQV6eMjIy5GB6GOrIvRQsEgOCQIiGjWOryw65XIEPpAAAiAYEUkAY27dPWr5c+ugj8+E4EgOCQIimjWPJDgEAEBikPKJcfr60aZP5ifDzyivSZ59JTqcJCJxOExAsXhzqntnXwaXBExK893Ny+LcAAADqhoxUlIqGqU2RLj/f7AmUmektEuiu0JaTY6Z0kYmojI1jAQBAYyAjFaXcU5uiIZMRqVk3d0DQrJlve0qKaS8oCE2/7K5iafCK2DgWAAD4g4xUFDp4apMUmZmMSM+6uQOCPXt8t60iIKgZG8cCAIDGQEYqCrkzGSkpvu2RlsmI9KxbWprUr5+0a5eUlyeVlJhgIDfXBAoEBNUbMcIETeXlZjpfeTkbxwIAAP+QkYpCFac2uTNRUmRlMqIl63bBBVKTJqZqX8W9gggIasbGsQAAoKEIpKJQNExtipaCAomJ0mmnSYMGsY9UfVAa3F4ieYNkAEDkIZCKUu6MRU5OZGYyoiHrVlFqqtSqVah7AdRPpK9nBABEJgKpKBXpU5uiIetWV3zLD7uLpg2SAQCRg0AqykXy1KZIz7rVhm/5EQ6iZT0jACDyEEghYkV61q02fMuPcBAt6xkBAJGH8ueIeGlpUqdO0fVh7OBv+RMSvPdzciJvc2KELzZIBgCEKwIpIAJFy15hCH/u9YzuPdDYDw0AEC6Y2hdGKBqAuoq2qoUIb9G+nhEAEJ4IpMIARQPgL6oWIpxE+3pGAEB4IpAKAxQNQH3wLT/CTSRXEQUARB4CKZujNDDqi2/50RBMJQYAoGYEUjZHaWA0FN/ywx9MJQYAoG6o2mdzlAYGEEzuqcROp/kCx+k0jxcvDnXPAACwFwIpm6M0MIBgYf8xAADqjkAqDIwYYYoElJeb6Xzl5RQNAND42H8MAIC6Y41UGKBoQGiw2B7Rhv3HAACoOwKpMBIuRQPCPQBhsT2iFfuPAQBQdwRSaDSREoCwbxeiGfuPAQBQNwRSaDSREICwbxeiHVOJAQCoG4pNoFFESrUvFtsDRlqa1KkTQRQAANUhkEKjiJQAhH27AAAAUBcEUmgUkRKAsG8XAAAA6oJACo0ikgIQ9u0CAABAbSg2gUYTKdW+WGwPAACA2hBIodFEWgASLvt2AQAAIPhsObVvwYIF6tChgxISEtSnTx+tXr26xuNffvllde7cWQkJCerWrZuWLl0apJ6iKtFQ7Ss/X9q0KXyqEQIAAKBx2S6QeumllzRlyhRNnz5dX375pbp3765BgwYpLy+vyuNXrlypUaNGady4cfrqq680fPhwDR8+XN99912Qe45osG+f9Oyz0rRp0vTp5uez
z5p2AAAARA/bBVLz5s3TFVdcobFjx+roo4/WwoUL1bRpUz399NNVHv/www/rzDPP1I033qguXbpoxowZ6tmzpx577LEg9xzRwL3psNNpNh12Os3jxYtD3TMAAAAEk63WSJWWlmrNmjWaNm2apy0mJkYDBw7UqlWrqjxn1apVmjJlik/boEGD9Prrr1d5/P79+7V//37P46KiIkmSy+WSy+Vq4DtoOJfLJcuybNEX+CookFaulFq3ljIyTFtCguRwmPbBg4Nf5p3xAn8xZuAvxgz8xZiBv+w0Zvzpg60CqR07dqi8vFyZmZk+7ZmZmdqwYUOV52zbtq3K47dt21bl8bNmzdLdd99dqX379u0qKSmpZ88bj8vlUmFhoSzLUkyM7RKGUe3336XmzaX0dCk21tuemCht324qFZaVBbdPjBf4izEDfzFm4C/GDPxlpzFTXFxc52NtFUgFw7Rp03wyWEVFRcrKylJ6erqSk5ND2DPD5XLJ4XAoPT095AMJvpo0kYqLpb17vRkpScrLM3tNtW8fmowU4wX+YMzAX4wZ+IsxA3/ZacwkJCTU+VhbBVKtWrWS0+lUbm6uT3tubq5at25d5TmtW7f26/j4+HjFx8dXao+JiQn5fzg3h8Nhq/7AaNVK6tfPrImyLCklRSosNJsODxtmng8Fxgv8xZiBvxgz8BdjBv6yy5jx5/q2Gt1xcXHq1auXli9f7mlzuVxavny5+vbtW+U5ffv29Tlekt59991qjwcaYsQIEzSVl5upfOXl4bnpMAAAABrGVhkpSZoyZYqys7N1/PHHq3fv3po/f7727NmjsWPHSpLGjBmjdu3aadasWZKkSZMmacCAAXrwwQc1ePBgvfjii/riiy/05JNPhvJtIEJF2qbDAAAAqB/bBVIjR47U9u3bdeedd2rbtm3q0aOHli1b5ikosXXrVp+UW79+/fT888/r9ttv16233qpOnTrp9ddfV9euXUP1FhAC+fnBDWzS0gigAAAAopnDsiwr1J0IpaKiIqWkpKiwsNA2xSby8vKUkZER8jmi4WDfPrOHU06OtHu3lJQk9e9vptolJoa6d4HHeIG/GDPwF2MG/mLMwF92GjP+xAaMboQ1NsgFAABAKBBIIWzl55tMVGamuSUkeO/n5Jjngfx8adMmxgMAAGhctlsjBdRVQYGZzte+vW97SoqpqFdQwDqmaBbt0z4BAEBgkZFC2EpNNR+OCwt92wsLTXuwN8eFvTDtEwAABBKBFMJWWprJMOTmmltJifd+//5ko6IZ0z4BAECgEUghrLFBLqrinvaZkuLbnpJi2gsKQtMvAAAQOVgjhbDGBrmoSsVpnwkJ3namfQIAgMZCRgoRIS1N6tSJIApGKKd9UiUQAIDoQEYKQERyT+/MyTHTPpOSAjvtkyqBAABEFwIpAFXKz7fXdEl/+xPsaZ/uKoGZmaZKYGGheSyZfgAAgMhCIAXAh90yKw3tT1pa4APBg6sESt61WTk5JpizQzAKAAAaD2ukogBrNuAPu+2/ZLf+VIUqgQAARB8yUhHMbpkF2J/dMit26091qBIIAED0ISMVwcLhm3zYi90yK3brT3XYHBoAgOhDIBWhDv4mPyHBez8nh2l+qFrFzEpFocqs2K0/NWFzaAAAogtT+yKU+5v89u1921NSzIe8ggK+JUdl7syKu9pcSooJWnJzTVAQ7DFjt/7UhM2hAQCILgRSESoS12zYrRx3pAr2/kvh1p/aBKNKIAAACD0CqQgVTt/k14aiGcFlt8yK3foDAAAgEUhFtHD7Jr86bHQaGnbLrNitPwAAILoRSEWwSPgmP1zKXwMAACC6ULUvCqSlSZ06hWfAES7lrwEAABBdCKRga+FU/hoAAADRg0AKtsZGpwAAALAj1kjB9iKlaAYAAAAiB4EUbC8SimYAAAAgshBIIWxQ/hoAAAB2wRopAAAAAPATgRQAAAAA+IlACgAAAAD8RCAFAAAAAH4ikAIAAAAAPxFIAQAA
AICfCKQAAAAAwE8EUgAAAADgJwIpAAAAAPATgRQAAAAA+IlACgAAAAD8RCAFAAAAAH4ikAIAAAAAPxFIAQAAAICfmoS6A6FmWZYkqaioKMQ9MVwul4qLi5WQkKCYGOJc1IzxAn8xZuAvxgz8xZiBv+w0ZtwxgTtGqEnUB1LFxcWSpKysrBD3BAAAAIAdFBcXKyUlpcZjHFZdwq0I5nK59Pvvv6t58+ZyOByh7o6KioqUlZWlX375RcnJyaHuDmyO8QJ/MWbgL8YM/MWYgb/sNGYsy1JxcbHatm1ba3Ys6jNSMTExOuSQQ0LdjUqSk5NDPpAQPhgv8BdjBv5izMBfjBn4yy5jprZMlBsTVwEAAADATwRSAAAAAOAnAimbiY+P1/Tp0xUfHx/qriAMMF7gL8YM/MWYgb8YM/BXuI6ZqC82AQAAAAD+IiMFAAAAAH4ikAIAAAAAPxFIAQAAAICfCKQAAAAAwE8EUkG2YMECdejQQQkJCerTp49Wr15d4/Evv/yyOnfurISEBHXr1k1Lly4NUk9hF/6MmaeeekonnXSSWrZsqZYtW2rgwIG1jjFEHn//zri9+OKLcjgcGj58eGA7CNvxd8zs2rVLEyZMUJs2bRQfH68jjzyS/z9FGX/HzPz583XUUUcpMTFRWVlZmjx5skpKSoLUW4TaRx99pKFDh6pt27ZyOBx6/fXXaz1nxYoV6tmzp+Lj43XEEUdo0aJFAe+nvwikguill17SlClTNH36dH355Zfq3r27Bg0apLy8vCqPX7lypUaNGqVx48bpq6++0vDhwzV8+HB99913Qe45QsXfMbNixQqNGjVKH3zwgVatWqWsrCydccYZ+u2334Lcc4SKv2PGbcuWLZo6dapOOumkIPUUduHvmCktLdXpp5+uLVu26JVXXtHGjRv11FNPqV27dkHuOULF3zHz/PPP65ZbbtH06dO1fv16/eMf/9BLL72kW2+9Ncg9R6js2bNH3bt314IFC+p0/ObNmzV48GCdeuqpWrt2ra6//npdfvnlevvttwPcUz9ZCJrevXtbEyZM8DwuLy+32rZta82aNavK40eMGGENHjzYp61Pnz7WVVddFdB+wj78HTMHKysrs5o3b249++yzgeoibKY+Y6asrMzq16+f9fe//93Kzs62hg0bFoSewi78HTNPPPGE1bFjR6u0tDRYXYTN+DtmJkyYYP31r3/1aZsyZYrVv3//gPYT9iTJeu2112o85qabbrKOOeYYn7aRI0dagwYNCmDP/EdGKkhKS0u1Zs0aDRw40NMWExOjgQMHatWqVVWes2rVKp/jJWnQoEHVHo/IUp8xc7C9e/fqwIEDSk1NDVQ3YSP1HTP33HOPMjIyNG7cuGB0EzZSnzHzxhtvqG/fvpowYYIyMzPVtWtXzZw5U+Xl5cHqNkKoPmOmX79+WrNmjWf6308//aSlS5fq7LPPDkqfEX7C5TNwk1B3IFrs2LFD5eXlyszM9GnPzMzUhg0bqjxn27ZtVR6/bdu2gPUT9lGfMXOwm2++WW3btq30xwiRqT5j5pNPPtE//vEPrV27Ngg9hN3UZ8z89NNPev/993XxxRdr6dKl+uGHHzR+/HgdOHBA06dPD0a3EUL1GTMXXXSRduzYoRNPPFGWZamsrExXX301U/tQreo+AxcVFWnfvn1KTEwMUc98kZECItTs2bP14osv6rXXXlNCQkKouwMbKi4u1ujRo/XUU0+pVatWoe4OwoTL5VJGRoaefPJJ9erVSyNHjtRtt92mhQsXhrprsKkVK1Zo5syZevzxx/Xll1/q1Vdf1VtvvaUZM2aEumtAg5CRCpJWrVrJ6XQqNzfXpz03N1etW7eu8pzWrVv7dTwiS33GjNvcuXM1e/Zsvffeezr22GMD2U3YiL9j5scff9SWLVs0dOhQT5vL5ZIkNWnSRBs3btThhx8e2E4jpOrzd6ZNmzaKjY2V0+n0tHXp0kXbtm1TaWmp4uLiAtpnhFZ9xswd
d9yh0aNH6/LLL5ckdevWTXv27NGVV16p2267TTExfK8PX9V9Bk5OTrZNNkoiIxU0cXFx6tWrl5YvX+5pc7lcWr58ufr27VvlOX379vU5XpLefffdao9HZKnPmJGkOXPmaMaMGVq2bJmOP/74YHQVNuHvmOncubO+/fZbrV271nM755xzPFWSsrKygtl9hEB9/s70799fP/zwgyfolqTvv/9ebdq0IYiKAvUZM3v37q0ULLkDccuyAtdZhK2w+Qwc6moX0eTFF1+04uPjrUWLFlnr1q2zrrzySqtFixbWtm3bLMuyrNGjR1u33HKL5/icnByrSZMm1ty5c63169db06dPt2JjY61vv/02VG8BQebvmJk9e7YVFxdnvfLKK9Yff/zhuRUXF4fqLSDI/B0zB6NqX/Txd8xs3brVat68uTVx4kRr48aN1ptvvmllZGRY9957b6jeAoLM3zEzffp0q3nz5tYLL7xg/fTTT9Y777xjHX744daIESNC9RYQZMXFxdZXX31lffXVV5Yka968edZXX31l/fzzz5ZlWdYtt9xijR492nP8Tz/9ZDVt2tS68cYbrfXr11sLFiywnE6ntWzZslC9hSoRSAXZo48+arVv396Ki4uzevfubX366aee5wYMGGBlZ2f7HL948WLryCOPtOLi4qxjjjnGeuutt4LcY4SaP2Pm0EMPtSRVuk2fPj34HUfI+Pt3piICqejk75hZuXKl1adPHys+Pt7q2LGjdd9991llZWVB7jVCyZ8xc+DAAeuuu+6yDj/8cCshIcHKysqyxo8fb+3cuTP4HUdIfPDBB1V+PnGPk+zsbGvAgAGVzunRo4cVFxdndezY0XrmmWeC3u/aOCyLnCoAAAAA+IM1UgAAAADgJwIpAAAAAPATgRQAAAAA+IlACgAAAAD8RCAFAAAAAH4ikAIAAAAAPxFIAQAAAICfCKQAAAAAwE8EUgAAAADgJwIpAEDUuuSSS5SQkKDvv/++0nOzZ8+Ww+HQm2++GYKeAQDszmFZlhXqTgAAEAp5eXnq3LmzevTooffff9/TvnnzZh1zzDE6++yz9corr4SwhwAAuyIjBQCIWhkZGbr//vv1wQcf6Nlnn/W0jx8/XrGxsXr44YdD2DsAgJ2RkQIARDXLsnTSSSdp48aN2rBhg959912NGjVKjzzyiK699tpQdw8AYFMEUgCAqPe///1Pxx13nIYPH66PP/5YhxxyiD777DPFxDBxAwBQNQIpAAAk3XrrrZo1a5acTqdWr16tnj17hrpLAAAb46s2AAAktWrVSpLUtm1bde3aNcS9AQDYHYEUACDq/fLLL5o+fbq6du2qX375RXPmzAl1lwAANkcgBQCIehMnTpQk/fe//9Xf/vY33Xffffrpp59C3CsAgJ0RSAEAotprr72mN954QzNmzNAhhxyi+fPnKy4uThMmTAh11wAANkaxCQBA1CouLtbRRx+t9PR0ff7553I6nZKkRx55RJMmTdLixYv1t7/9LcS9BADYEYEUACBqTZo0SY899pg+/fRTnXDCCZ728vJy9e7dW9u2bdOGDRvUvHnzEPYSAGBHTO0DAESlNWvWaMGCBRo/frxPECVJTqdTCxcu1LZt23T77beHqIcAADsjIwUAAAAAfiIjBQAAAAB+IpACAAAAAD8RSAEAAACAnwikAAAAAMBPBFIAAAAA4CcCKQAAAADwE4EUAAAAAPiJQAoAAAAA/EQgBQAAAAB+IpACAAAAAD8RSAEAAACAnwikAAAAAMBP/x9fhj3JC0suZQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 1000x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final MSE loss: 0.011899\n"
     ]
    }
   ],
   "source": [
    "# Restore the trained model\n",
    "model = brainstate.graph.treefy_merge(graphdef, params_, counts_)\n",
    "\n",
    "print(f\"Model called {model.count.value} times during training\")\n",
    "print(f\"Expected: {total_steps} times\\n\")\n",
    "\n",
    "# Make predictions\n",
    "y_pred = model(X)\n",
    "\n",
    "# Visualize results\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.scatter(X, Y, alpha=0.5, s=20, label='Training data', color='blue')\n",
    "plt.plot(X, y_pred, color='red', linewidth=2, label='Model prediction')\n",
    "plt.xlabel('X', fontsize=12)\n",
    "plt.ylabel('Y', fontsize=12)\n",
    "plt.title('Model Fit After Training', fontsize=14, fontweight='bold')\n",
    "plt.legend()\n",
    "plt.grid(alpha=0.3)\n",
    "plt.show()\n",
    "\n",
    "# Compute the final MSE on the training data (X, Y)\n",
    "final_loss = jnp.mean((Y - y_pred) ** 2)\n",
    "print(f\"Final MSE loss: {final_loss:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcfedf3d2d0054c5",
   "metadata": {},
   "source": [
    "## Advanced Graph Patterns\n",
    "\n",
    "Because a `Node` can hold other `Node`s as attributes, common architectural patterns can be assembled by composition. Two examples follow.\n",
    "\n",
    "### Skip Connections (ResNet-style)"
   ]
  },
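  {
   "cell_type": "markdown",
   "id": "5c0e7a2b91d34f68",
   "metadata": {},
   "source": [
    "A skip connection does not depend on `brainstate` itself. The block returns `relu(f(x) + x)`, so the main path `f` must preserve the shape of `x`. The next cell is a minimal functional sketch in plain JAX. The parameter layout (`w1`, `b1`, `w2`, `b2`) and the `0.1` initialization scale are illustrative assumptions and do not reflect the internals of `brainstate.nn.Linear`. As a sanity check, when every weight is zero the main path vanishes and the block reduces to `relu(x)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d3f61c0a7e24b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def residual_block(params, x):\n",
    "    # Main path: dense -> ReLU -> dense (hypothetical parameter layout)\n",
    "    h = jax.nn.relu(x @ params['w1'] + params['b1'])\n",
    "    h = h @ params['w2'] + params['b2']\n",
    "    # Skip connection: add the input back, so h and x must share a shape\n",
    "    return jax.nn.relu(h + x)\n",
    "\n",
    "\n",
    "demo_dim = 10\n",
    "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
    "demo_params = {\n",
    "    'w1': 0.1 * jax.random.normal(k1, (demo_dim, demo_dim)), 'b1': jnp.zeros(demo_dim),\n",
    "    'w2': 0.1 * jax.random.normal(k2, (demo_dim, demo_dim)), 'b2': jnp.zeros(demo_dim),\n",
    "}\n",
    "x_demo = jnp.ones((5, demo_dim))\n",
    "print(residual_block(demo_params, x_demo).shape)  # (5, 10)\n",
    "\n",
    "# With all-zero weights the main path is zero and the block returns relu(x)\n",
    "zero_params = jax.tree_util.tree_map(jnp.zeros_like, demo_params)\n",
    "print(bool(jnp.allclose(residual_block(zero_params, x_demo), jax.nn.relu(x_demo))))  # True"
   ]
  },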
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5494b013a7f7b348",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.059453Z",
     "start_time": "2025-10-10T15:54:21.567395Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:16.688440Z",
     "iopub.status.busy": "2026-05-30T17:01:16.688275Z",
     "iopub.status.idle": "2026-05-30T17:01:17.599215Z",
     "shell.execute_reply": "2026-05-30T17:01:17.598491Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet Architecture:\n",
      "  Input shape: (5, 10)\n",
      "  Output shape: (5, 10)\n",
      "  Number of blocks: 3\n",
      "  Total parameters: 0\n"
     ]
    }
   ],
   "source": [
    "class ResidualBlock(brainstate.graph.Node):\n",
    "    \"\"\"Residual block with skip connection.\"\"\"\n",
    "    \n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.linear1 = brainstate.nn.Linear(dim, dim)\n",
    "        self.linear2 = brainstate.nn.Linear(dim, dim)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Main path\n",
    "        residual = x\n",
    "        x = jax.nn.relu(self.linear1(x))\n",
    "        x = self.linear2(x)\n",
    "        \n",
    "        # Skip connection\n",
    "        return jax.nn.relu(x + residual)\n",
    "\n",
    "\n",
    "class ResNet(brainstate.graph.Node):\n",
    "    \"\"\"Simple ResNet with multiple residual blocks.\"\"\"\n",
    "    \n",
    "    def __init__(self, dim, n_blocks):\n",
    "        super().__init__()\n",
    "        self.blocks = [ResidualBlock(dim) for _ in range(n_blocks)]\n",
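    "        # Also expose each block as a named attribute (block_0, block_1, ...);\n",
    "        # these are the same objects as in the list, not copies\n",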
    "        for i, block in enumerate(self.blocks):\n",
    "            setattr(self, f'block_{i}', block)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Test ResNet\n",
    "resnet = ResNet(dim=10, n_blocks=3)\n",
    "x = brainstate.random.randn(5, 10)\n",
    "output = resnet(x)\n",
    "\n",
    "print(f\"ResNet Architecture:\")\n",
    "print(f\"  Input shape: {x.shape}\")\n",
    "print(f\"  Output shape: {output.shape}\")\n",
    "print(f\"  Number of blocks: {len(resnet.blocks)}\")\n",
    "\n",
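    "# Count parameters. A ParamState value may be a pytree (e.g. a dict holding a\n",
    "# weight and a bias), so sum the sizes of all of its array leaves.\n",
    "params = brainstate.graph.states(resnet, brainstate.ParamState)\n",
    "total_params = sum(int(jnp.size(leaf)) for p in params.values() for leaf in jax.tree_util.tree_leaves(p.value))\n",
    "print(f\"  Total parameters: {total_params:,}\")"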
   ]
  },
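  {
   "cell_type": "markdown",
   "id": "2b7e94d1c6a05f3e",
   "metadata": {},
   "source": [
    "A note on counting parameters. A `ParamState` can hold a pytree instead of a single array. That appears to be the case for `brainstate.nn.Linear` here, which holds a dict with a weight and a bias. In that situation a check such as `hasattr(p.value, 'shape')` skips the whole state, and the count silently becomes zero. Summing over `jax.tree_util.tree_leaves` works for bare arrays and for nested containers alike. The sketch below uses a hand-built dict as a stand-in for such a value; its keys are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e41a8c7b3d90f256",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical stand-in for the value of one dense layer's ParamState\n",
    "layer_value = {'weight': jnp.zeros((10, 10)), 'bias': jnp.zeros((10,))}\n",
    "\n",
    "\n",
    "def count_params(value):\n",
    "    # Sum the number of elements over every array leaf of the pytree\n",
    "    return sum(int(jnp.size(leaf)) for leaf in jax.tree_util.tree_leaves(value))\n",
    "\n",
    "\n",
    "print(hasattr(layer_value, 'shape'))   # False: a dict has no .shape\n",
    "print(count_params(layer_value))       # 110 = 10*10 weights + 10 biases\n",
    "print(count_params(jnp.ones((3, 4))))  # 12: a bare array is a single leaf"
   ]
  },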
  {
   "cell_type": "markdown",
   "id": "1d48d56f17a8fffd",
   "metadata": {},
   "source": [
    "### Multi-Path Networks (Inception-style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a5d74feeaa8a6be1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.721597Z",
     "start_time": "2025-10-10T15:54:22.060930Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:17.600965Z",
     "iopub.status.busy": "2026-05-30T17:01:17.600689Z",
     "iopub.status.idle": "2026-05-30T17:01:18.509409Z",
     "shell.execute_reply": "2026-05-30T17:01:18.508624Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inception Block:\n",
      "  Input shape: (4, 16)\n",
      "  Output shape: (4, 32)\n",
      "  4 parallel paths with concatenated outputs\n",
      "  Total parameters: 0\n"
     ]
    }
   ],
   "source": [
    "class InceptionBlock(brainstate.graph.Node):\n",
    "    \"\"\"Inception-style block with multiple parallel paths.\"\"\"\n",
    "    \n",
    "    def __init__(self, in_dim, out_dim):\n",
    "        super().__init__()\n",
    "        # Each branch is a dense stand-in for a convolutional Inception path\n",
    "        # and produces out_dim // 4 features.\n",
    "        # Path 1: stands in for a 1x1 convolution (one projection)\n",
    "        self.path1 = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        \n",
    "        # Path 2: stands in for 1x1 -> 3x3 (two stacked dense layers)\n",
    "        self.path2_a = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        self.path2_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)\n",
    "        \n",
    "        # Path 3: stands in for 1x1 -> 5x5 (two stacked dense layers)\n",
    "        self.path3_a = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        self.path3_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)\n",
    "        \n",
    "        # Path 4: stands in for pool -> 1x1 (no pooling is applied here)\n",
    "        self.path4 = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Every branch reads the same input x; the branches are independent\n",
    "        out1 = jax.nn.relu(self.path1(x))\n",
    "        \n",
    "        out2 = jax.nn.relu(self.path2_a(x))\n",
    "        out2 = jax.nn.relu(self.path2_b(out2))\n",
    "        \n",
    "        out3 = jax.nn.relu(self.path3_a(x))\n",
    "        out3 = jax.nn.relu(self.path3_b(out3))\n",
    "        \n",
    "        out4 = jax.nn.relu(self.path4(x))\n",
    "        \n",
    "        # Concatenate outputs\n",
    "        return jnp.concatenate([out1, out2, out3, out4], axis=-1)\n",
    "\n",
    "\n",
    "# Test Inception block\n",
    "inception = InceptionBlock(in_dim=16, out_dim=32)\n",
    "x = brainstate.random.randn(4, 16)\n",
    "output = inception(x)\n",
    "\n",
    "print(f\"Inception Block:\")\n",
    "print(f\"  Input shape: {x.shape}\")\n",
    "print(f\"  Output shape: {output.shape}\")\n",
    "print(f\"  4 parallel paths with concatenated outputs\")\n",
    "\n",
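    "# Count parameters. A ParamState value may be a pytree (e.g. a dict holding a\n",
    "# weight and a bias), so sum the sizes of all of its array leaves.\n",
    "params = brainstate.graph.states(inception, brainstate.ParamState)\n",
    "total_params = sum(int(jnp.size(leaf)) for p in params.values() for leaf in jax.tree_util.tree_leaves(p.value))\n",
    "print(f\"  Total parameters: {total_params:,}\")"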
   ]
  },
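  {
   "cell_type": "markdown",
   "id": "9f26b0d4e83c17a5",
   "metadata": {},
   "source": [
    "Each branch of `InceptionBlock` produces `out_dim // 4` features, so the concatenated output has `4 * (out_dim // 4)` features. That equals `out_dim` only when `out_dim` is divisible by 4. Otherwise the block returns fewer features than requested. The sketch below checks this with zero-filled placeholder branches. `concat_width` is a hypothetical helper and is not part of `brainstate`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5e8137a2f6d9b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def concat_width(out_dim, n_paths=4, batch=4):\n",
    "    # One placeholder branch output per path, each with out_dim // n_paths features\n",
    "    branches = [jnp.zeros((batch, out_dim // n_paths)) for _ in range(n_paths)]\n",
    "    # Concatenate along the feature axis, as InceptionBlock.__call__ does\n",
    "    return jnp.concatenate(branches, axis=-1).shape[-1]\n",
    "\n",
    "\n",
    "print(concat_width(32))  # 32\n",
    "print(concat_width(30))  # 28, not 30"
   ]
  },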
  {
   "cell_type": "markdown",
   "id": "d0aface45892c8ab",
   "metadata": {},
   "source": [
    "## Graph Analysis and Statistics\n",
    "\n",
    "This section builds a utility that walks a model's graph and collects simple statistics: the number of `Node` objects, the nesting depth, and parameter counts."
   ]
  },
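  {
   "cell_type": "markdown",
   "id": "71d4a9e0b2c68f3d",
   "metadata": {},
   "source": [
    "A graph can reach the same object along several paths. In `ResNet` above, for example, each block is reachable both through `blocks` and through its own `block_i` attribute. A correct walk therefore has to remember which nodes it has already visited. The plain-Python sketch below is independent of `brainstate` and illustrates the idea on a hypothetical `Toy` container class. Without the `seen` set, the shared block and its child would each be counted twice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0f5c2e7d41a6938",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Toy:\n",
    "    # Hypothetical container: attributes hold child nodes or lists of child nodes\n",
    "    def __init__(self, **children):\n",
    "        self.__dict__.update(children)\n",
    "\n",
    "\n",
    "def walk(root):\n",
    "    seen = set()   # ids of nodes already visited\n",
    "    visited = []   # (class name, depth) for each distinct node\n",
    "\n",
    "    def visit(n, depth):\n",
    "        if id(n) in seen:\n",
    "            return\n",
    "        seen.add(id(n))\n",
    "        visited.append((type(n).__name__, depth))\n",
    "        for value in vars(n).values():\n",
    "            items = value if isinstance(value, list) else [value]\n",
    "            for item in items:\n",
    "                if isinstance(item, Toy):\n",
    "                    visit(item, depth + 1)\n",
    "\n",
    "    visit(root, 0)\n",
    "    return visited\n",
    "\n",
    "\n",
    "toy_leaf = Toy()\n",
    "toy_block = Toy(layer=toy_leaf)\n",
    "toy_root = Toy(blocks=[toy_block], block_0=toy_block)  # same block, reachable twice\n",
    "print(walk(toy_root))  # [('Toy', 0), ('Toy', 1), ('Toy', 2)]"
   ]
  },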
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f0810b6995fbed70",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.915378Z",
     "start_time": "2025-10-10T15:54:22.721597Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:01:18.511639Z",
     "iopub.status.busy": "2026-05-30T17:01:18.511483Z",
     "iopub.status.idle": "2026-05-30T17:01:19.001982Z",
     "shell.execute_reply": "2026-05-30T17:01:19.001180Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph Analysis Comparison:\n",
      "================================================================================\n",
      "Model           Nodes    Depth    Param Tensors   Total Params   \n",
      "--------------------------------------------------------------------------------\n",
      "MLP             4        1        6               0              \n",
      "ResNet          16       2        24              0              \n",
      "Inception       7        1        12              0              \n",
      "\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "def analyze_graph(node):\n",
    "    \"\"\"Analyze computation graph statistics.\"\"\"\n",
    "    stats = {\n",
    "        'total_nodes': 0,\n",
    "        'total_params': 0,\n",
    "        'param_tensors': 0,\n",
    "        'node_types': {},\n",
    "        'depth': 0\n",
    "    }\n",
    "    \n",
    "    def traverse(n, depth=0):\n",
    "        stats['total_nodes'] += 1\n",
    "        stats['depth'] = max(stats['depth'], depth)\n",
    "        \n",
    "        # Count node type\n",
    "        node_type = n.__class__.__name__\n",
    "        stats['node_types'][node_type] = stats['node_types'].get(node_type, 0) + 1\n",
    "        \n",
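    "        # Traverse children\n",
    "        for _, child in brainstate.graph.iter_node(n):\n",
    "            if child is not n:\n",
    "                traverse(child, depth + 1)\n",
    "    \n",
    "    traverse(node)\n",
    "    \n",
    "    # Count parameters once at the root: states() already collects the states of all\n",
    "    # descendants, so counting at every node would count the same tensors repeatedly\n",
    "    params = brainstate.graph.states(node, brainstate.ParamState)\n",
    "    stats['param_tensors'] = len(params)\n",
    "    for p in params.values():\n",
    "        # A ParamState value may be a pytree (e.g. a dict), so sum the sizes of all array leaves\n",
    "        for leaf in jax.tree_util.tree_leaves(p.value):\n",
    "            stats['total_params'] += jnp.size(leaf)\n",
    "    return stats\n",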
    "\n",
    "\n",
    "# Analyze different models\n",
    "models = {\n",
    "    'MLP': MLP(),\n",
    "    'ResNet': ResNet(dim=10, n_blocks=3),\n",
    "    'Inception': InceptionBlock(in_dim=16, out_dim=32)\n",
    "}\n",
    "\n",
    "print(\"Graph Analysis Comparison:\")\n",
    "print(\"=\" * 80)\n",
    "print(f\"{'Model':<15} {'Nodes':<8} {'Depth':<8} {'Param Tensors':<15} {'Total Params':<15}\")\n",
    "print(\"-\" * 80)\n",
    "\n",
    "for name, model in models.items():\n",
    "    stats = analyze_graph(model)\n",
    "    print(f\"{name:<15} {stats['total_nodes']:<8} {stats['depth']:<8} \"\n",
    "          f\"{stats['param_tensors']:<15} {stats['total_params']:<15,}\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 80)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "840d30c8b943804f",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "`pygraph` is the data structure `brainstate` uses to make JAX function transformations work on models with graph (rather than tree) structure. Its main strengths are:\n",
    "\n",
    "### 1. Flexible Expression of Graph Structures\n",
    "- Compared with tree structures (pytrees), graphs can represent richer node relationships and dependencies\n",
    "- Supports shared and cyclic references between nodes, which suits models whose components refer to one another, such as recurrent architectures\n",
    "- Nodes can contain arbitrary pytree array data or subgraph structures\n",
    "\n",
    "### 2. Seamless Integration with the JAX Ecosystem\n",
    "- Works with JAX's core capabilities, including automatic differentiation (`grad`), vectorization (`vmap`), compilation (`jit`), and parallelization\n",
    "- Converts between graph structures and pytrees via `treefy_split` and `treefy_merge`\n",
    "\n",
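    "The round trip through `treefy_split` and `treefy_merge` can be sketched as follows. This minimal example assumes that `treefy_split(node)`, called without filters, returns a `(graphdef, state)` pair whose state part is an ordinary pytree. The `Counter` class is a hypothetical toy module used only for illustration.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "\n",
    "class Counter(brainstate.nn.Module):  # hypothetical toy module\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = brainstate.ParamState(jnp.ones(3))\n",
    "        self.count = brainstate.ShortTermState(jnp.zeros(()))\n",
    "\n",
    "\n",
    "model = Counter()\n",
    "\n",
    "# Assumption: with no filters, treefy_split returns (graphdef, state)\n",
    "graphdef, state_tree = brainstate.graph.treefy_split(model)\n",
    "\n",
    "# The state part is a pytree, so ordinary JAX tree utilities apply to it\n",
    "doubled = jax.tree_util.tree_map(lambda x: x * 2, state_tree)\n",
    "\n",
    "# Rebuild a new, independent model from the structure and the new values\n",
    "model2 = brainstate.graph.treefy_merge(graphdef, doubled)\n",
    "print(model2.weight.value)  # doubled weights\n",
    "print(model.weight.value)   # the original model is unchanged\n",
    "```\n",
    "\n",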
    "### 3. Rich API for Graph Operations\n",
    "- **Structure operations**: `graphdef`, `iter_node`, `iter_leaf`, `nodes`, `states`\n",
    "- **Transformations**: `treefy_states`, `clone`, `treefy_split`, `treefy_merge`, `flatten`, `unflatten`\n",
    "- **Modifications**: `pop_states`, `update_states`\n",
    "- **Conversions**: `graph_to_tree`, `tree_to_graph`\n",
    "\n",
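    "As a small illustration of one of these operations, the sketch below uses `clone` to obtain an independent copy of a module. It assumes that `brainstate.graph.clone(node)` takes a single node and returns a deep copy with its own `State` objects. The `Layer` class is hypothetical.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "\n",
    "class Layer(brainstate.nn.Module):  # hypothetical toy module\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = brainstate.ParamState(jnp.zeros(2))\n",
    "\n",
    "\n",
    "original = Layer()\n",
    "# Assumption: clone copies both the graph structure and the state objects\n",
    "duplicate = brainstate.graph.clone(original)\n",
    "\n",
    "# Writing to the duplicate's state leaves the original untouched\n",
    "duplicate.weight.value = jnp.ones(2)\n",
    "print(original.weight.value, duplicate.weight.value)\n",
    "```\n",
    "\n",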
    "### 4. Powerful Filter System\n",
    "- Filter by type: `brainstate.ParamState`, `brainstate.ShortTermState`, etc.\n",
    "- Combine filters: a tuple or list of filters matches a state if any of its elements matches\n",
    "- Match everything: `...` or `True`\n",
    "- Match nothing: `None` or `False`\n",
    "\n",
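    "The filter rules above can be sketched on `Cell`, a hypothetical module with one state of each kind. The example assumes that `brainstate.graph.states(node, filter)` accepts a single filter (a type, a tuple of filters, or `...`) and returns a mapping with one entry per matching state.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "\n",
    "\n",
    "class Cell(brainstate.nn.Module):  # hypothetical toy module\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(jnp.ones((2, 2)))\n",
    "        self.h = brainstate.ShortTermState(jnp.zeros(2))\n",
    "\n",
    "\n",
    "cell = Cell()\n",
    "\n",
    "# A single type filter selects only states of that type\n",
    "params = brainstate.graph.states(cell, brainstate.ParamState)\n",
    "\n",
    "# A tuple of filters matches a state if any element matches\n",
    "both = brainstate.graph.states(cell, (brainstate.ParamState, brainstate.ShortTermState))\n",
    "\n",
    "# Assumption: `...` matches every state\n",
    "everything = brainstate.graph.states(cell, ...)\n",
    "\n",
    "print(len(params), len(both), len(everything))\n",
    "```\n",
    "\n",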
    "### 5. Structure-Preserving State Management\n",
    "- Separate graph structure (`graphdef`) from state values (`treefy_states`)\n",
    "- Retrieve selected subsets of states by passing filters to `states`\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
