{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff7c015e07e0c6c9",
   "metadata": {},
   "source": [
    "# Graph and Node System\n",
    "\n",
    "In JAX, function transformations such as automatic differentiation (`jax.grad`), vectorization (`jax.vmap`), and parallelization (`jax.pmap`) operate on pytrees: nested containers such as dictionaries, lists, and tuples. However, many practical models, particularly complex neural networks and physical systems, contain shared or cyclic references between their components, which are naturally described as graphs rather than trees. To support such models, the `brainstate` library adopts the graph-based representation used in [flax](https://flax.readthedocs.io/) to represent complex model architectures.\n",
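    "\n",
    "As a small illustration in plain JAX, the parameters below form a pytree (a dict of arrays), and `jax.grad` returns gradients with exactly the same structure. The names `params` and `loss` are illustrative only.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# A pytree of parameters: a dict whose leaves are arrays.\n",
    "params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}\n",
    "\n",
    "def loss(p):\n",
    "    return jnp.sum(p['w'] ** 2) + p['b']\n",
    "\n",
    "# The gradient has the same pytree structure as the input.\n",
    "grads = jax.grad(loss)(params)\n",
    "print(grads['w'], grads['b'])  # [2. 4.] 1.0\n",
    "```\n",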
    "\n",
    "## What is `pygraph`?\n",
    "\n",
    "`pygraph` is a specialized data structure within `brainstate` designed to apply JAX transformations to graph-based models. Unlike trees, graphs can express shared references (one object reachable from several parents) and cycles, which arise naturally in intricate model architectures. When a model's state depends on interactions among multiple nodes, a graph can represent it without duplicating the shared objects.\n",
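    "\n",
    "To see why sharing matters, the following plain-JAX sketch (it does not use `brainstate`) flattens a pytree in which one array is referenced twice. The pytree records two independent leaf slots, so once the leaves are transformed and the tree is rebuilt, the two entries are no longer the same object. The keys `encoder` and `decoder` are illustrative only.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "shared = jnp.ones(3)\n",
    "# One array object referenced from two places in the same pytree.\n",
    "tree = {'encoder': shared, 'decoder': shared}\n",
    "\n",
    "leaves, treedef = jax.tree_util.tree_flatten(tree)\n",
    "print(len(leaves))  # 2: one leaf per occurrence, not per object\n",
    "\n",
    "# Transforming the leaves and rebuilding yields two separate arrays.\n",
    "rebuilt = jax.tree_util.tree_unflatten(treedef, [leaf + 1.0 for leaf in leaves])\n",
    "print(rebuilt['encoder'] is rebuilt['decoder'])  # False\n",
    "```\n",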
    "\n",
    "`pygraph` uses **`brainstate.graph.Node`** as its building block: nodes are the vertices of the graph, and relationships (edges) between nodes are created by assigning one node as an attribute of another. Each node can hold arbitrary pytree data, such as arrays, or nested `pygraph` substructures, which makes graph construction modular.\n",
    "\n",
    "We can think of `brainstate.graph.Node` as a **container class** whose attributes hold its leaves (data) and its child nodes. Because these containers can reference one another, they can form complex graph structures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "21085866d34afe68",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.229813Z",
     "start_time": "2025-10-10T15:54:05.432653Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import brainstate\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "brainstate.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a69bd0578a8467",
   "metadata": {},
   "source": [
    "## Basic `brainstate.graph.Node` Example\n",
    "\n",
    "Let's start by defining a simple class that inherits from `brainstate.graph.Node`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "77f9ae362b940fe1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.506364Z",
     "start_time": "2025-10-10T15:54:07.238063Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node A created:\n",
      "  w shape: (2, 3)\n",
      "  b shape: (3,)\n",
      "  b type: ShortTermState\n"
     ]
    }
   ],
   "source": [
    "class A(brainstate.graph.Node):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.random.rand(2, 3)\n",
    "        self.b = brainstate.ShortTermState(brainstate.random.rand(3))\n",
    "\n",
    "# Create an instance\n",
    "a = A()\n",
    "\n",
    "print(\"Node A created:\")\n",
    "print(f\"  w shape: {a.w.shape}\")\n",
    "print(f\"  b shape: {a.b.value.shape}\")\n",
    "print(f\"  b type: {type(a.b).__name__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b58124d7b328ac",
   "metadata": {},
   "source": [
    "In the code above, class `A` inherits from `brainstate.graph.Node` and defines two attributes in `__init__`: `w`, a randomly generated 2×3 JAX array, and `b`, a randomly generated array of length 3 wrapped in a `brainstate.ShortTermState`. Both hold JAX pytree data, but the graph treats them differently: the `graphdef` output below lists the `State` object `b` as a leaf and records the plain array `w` as a static field.\n",
    "\n",
    "### Cyclic References\n",
    "\n",
    "Unlike a pytree, **`brainstate.graph.Node` allows cyclic references**. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "9270a4ee302d629e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.540157Z",
     "start_time": "2025-10-10T15:54:07.537307Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Created cyclic reference: a.self = a\n",
      "Are they the same object? True\n"
     ]
    }
   ],
   "source": [
    "# Create a cyclic reference\n",
    "a.self = a\n",
    "\n",
    "print(\"Created cyclic reference: a.self = a\")\n",
    "print(f\"Are they the same object? {a.self is a}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e66524c99968322",
   "metadata": {},
   "source": [
    "Setting the `self` attribute of `a` to `a` itself creates a cyclic reference. Such a reference **cannot exist in a tree structure**, which is acyclic by definition, but it is valid in a graph. Cyclic and shared references let us express architectures whose components refer to one another, such as feedback or recurrent connections between modules.\n",
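    "\n",
    "How can a cycle be flattened at all? A common approach, sketched below in plain Python as an assumption rather than `brainstate`'s actual implementation, is to give each node an index the first time it is visited and to emit a reference to that index on any later visit. The helper `flatten_graph` and the class `Obj` are hypothetical; the result mirrors the `NodeRef(type=A, index=0)` entry in the `graphdef` output below.\n",
    "\n",
    "```python\n",
    "# Hypothetical sketch, not brainstate's implementation.\n",
    "class Obj:\n",
    "    pass\n",
    "\n",
    "def flatten_graph(node, seen=None):\n",
    "    seen = {} if seen is None else seen\n",
    "    if id(node) in seen:\n",
    "        return ('ref', seen[id(node)])  # already visited: emit a back-reference\n",
    "    index = len(seen)\n",
    "    seen[id(node)] = index  # first visit: assign the next free index\n",
    "    attrs = {}\n",
    "    for name, value in sorted(vars(node).items()):\n",
    "        if isinstance(value, Obj):\n",
    "            attrs[name] = flatten_graph(value, seen)  # child node or reference\n",
    "        else:\n",
    "            attrs[name] = ('leaf', value)\n",
    "    return ('node', index, attrs)\n",
    "\n",
    "node_a = Obj()\n",
    "node_a.w = 1.0\n",
    "node_a.self = node_a  # cyclic reference\n",
    "print(flatten_graph(node_a))\n",
    "# ('node', 0, {'self': ('ref', 0), 'w': ('leaf', 1.0)})\n",
    "```\n",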
    "\n",
    "### Inspecting the Graph Structure\n",
    "\n",
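    "We can inspect the graph definition using `brainstate.graph.graphdef()`. Note that the cyclic `self` attribute is stored as a `NodeRef` with `index=0`, a reference back to `A` itself, rather than as a second copy of the node:"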
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c572e954c97bd3da",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:07.575041Z",
     "start_time": "2025-10-10T15:54:07.568785Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NodeDef(\n",
       "  type=A,\n",
       "  index=0,\n",
       "  attributes=('b', 'self', 'w'),\n",
       "  subgraphs={\n",
       "    'self': NodeRef(\n",
       "      type=A,\n",
       "      index=0\n",
       "    )\n",
       "  },\n",
       "  static_fields={\n",
       "    'w': Array([[0.72766423, 0.78786755, 0.18169427],\n",
       "           [0.26263022, 0.11072934, 0.20263076]], dtype=float32)\n",
       "  },\n",
       "  leaves={\n",
       "    'b': NodeRef(\n",
       "      type=ShortTermState,\n",
       "      index=1\n",
       "    )\n",
       "  },\n",
       "  metadata=(<class '__main__.A'>,),\n",
       "  index_mapping=None\n",
       ")"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the graphdef of a\n",
    "graphdef = brainstate.graph.graphdef(a)\n",
    "graphdef"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f60e6d557d37001",
   "metadata": {},
   "source": [
    "## Building Hierarchical Neural Networks\n",
    "\n",
    "`brainstate.graph.Node` is the foundational class in `pygraph`, defining the structure and behavior of each node within the graph. The **attributes** of a node hold its leaves, such as JAX arrays, `State` objects, or other transformable data, as well as references to child nodes. The **methods** of a node implement operations that manipulate, update, and transform the data it contains.\n",
    "\n",
    "In `brainstate`, every neural network module is a subclass of `brainstate.graph.Node`, so modules compose simply by being assigned as attributes of other modules. This makes it straightforward to build larger architectures while keeping the whole model compatible with JAX transformations. For instance, the following multi-layer perceptron (MLP) nests three `brainstate.nn.Linear` layers and one `ShortTermState`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a41868f76bec2452",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:08.564682Z",
     "start_time": "2025-10-10T15:54:07.598547Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP structure:\n",
      "MLP(\n",
      "  l1=Linear(\n",
      "    in_size=(2,),\n",
      "    out_size=(3,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[3]),\n",
      "        'weight': ShapedArray(float32[2,3])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l2=Linear(\n",
      "    in_size=(3,),\n",
      "    out_size=(4,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[4]),\n",
      "        'weight': ShapedArray(float32[3,4])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l3=Linear(\n",
      "    in_size=(4,),\n",
      "    out_size=(5,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[5]),\n",
      "        'weight': ShapedArray(float32[4,5])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  st=ShortTermState(\n",
      "    value=ShapedArray(float32[5])\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class MLP(brainstate.graph.Node):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.l1 = brainstate.nn.Linear(2, 3)\n",
    "        self.l2 = brainstate.nn.Linear(3, 4)\n",
    "        self.l3 = brainstate.nn.Linear(4, 5)\n",
    "        self.st = brainstate.ShortTermState(brainstate.random.rand(5))\n",
    "\n",
    "# Create MLP instance\n",
    "mlp = MLP()\n",
    "\n",
    "print(\"MLP structure:\")\n",
    "print(mlp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8757f357fcd874b4",
   "metadata": {},
   "source": [
    "## Traversing the Graph Structure\n",
    "\n",
    "A graph built from nodes has the following properties:\n",
    "\n",
    "- **Data Storage**: Each `brainstate.graph.Node` can store JAX arrays, `State` objects, or other transformable data in its *attributes*.\n",
    "\n",
    "- **Node Connections**: The *attributes* of a `brainstate.graph.Node` can reference other `brainstate.graph.Node` instances, forming complex dependency graphs, as illustrated by the three linear layers within the `MLP` class above.\n",
    "\n",
    "- **Attributes and Their Paths**: Every node and leaf is addressed by a path: a tuple of keys (attribute names or container indices) leading from the root node to it, such as `('l1', 'weight')`. The path records the item's position in the nested hierarchy of the graph and is used to retrieve and transform it.\n",
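    "\n",
    "The notion of a path can be sketched in plain Python on nested dicts, assuming each dict key plays the role of one attribute name. The hypothetical helper `iter_paths` yields `(path, leaf)` pairs in the same spirit as `brainstate.graph.iter_leaf`, which works on graphs of nodes rather than plain dicts.\n",
    "\n",
    "```python\n",
    "def iter_paths(tree, prefix=()):\n",
    "    # Yield (path, leaf) pairs; a path is the tuple of keys from the root.\n",
    "    for key, value in sorted(tree.items()):\n",
    "        path = prefix + (key,)\n",
    "        if isinstance(value, dict):\n",
    "            yield from iter_paths(value, path)  # descend into a child container\n",
    "        else:\n",
    "            yield path, value\n",
    "\n",
    "model = {'l1': {'weight': 'W1', 'bias': 'b1'}, 'st': 'state'}\n",
    "for path, leaf in iter_paths(model):\n",
    "    print(path, leaf)\n",
    "# ('l1', 'bias') b1\n",
    "# ('l1', 'weight') W1\n",
    "# ('st',) state\n",
    "```\n",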
    "\n",
    "### Iterating Over Leaf Nodes\n",
    "\n",
    "For instance, `brainstate.graph.iter_leaf` yields a `(path, leaf)` pair for every leaf in the `MLP` graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "dc8d3b5e89e9863c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.396842Z",
     "start_time": "2025-10-10T15:54:10.392984Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Leaf nodes in MLP:\n",
      "============================================================\n",
      "('l1', '_in_size', 0)                    2\n",
      "('l1', '_out_size', 0)                   3\n",
      "('l1', 'weight')                         ParamState\n",
      "('l2', '_in_size', 0)                    3\n",
      "('l2', '_out_size', 0)                   4\n",
      "('l2', 'weight')                         ParamState\n",
      "('l3', '_in_size', 0)                    4\n",
      "('l3', '_out_size', 0)                   5\n",
      "('l3', 'weight')                         ParamState\n",
      "('st',)                                  ShortTermState\n"
     ]
    }
   ],
   "source": [
    "print(\"Leaf nodes in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "for path, leaf in brainstate.graph.iter_leaf(mlp):\n",
    "    leaf_info = leaf.__class__.__name__ if isinstance(leaf, brainstate.State) else leaf\n",
    "    print(f\"{str(path):<40} {leaf_info}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9b026896800605e",
   "metadata": {},
   "source": [
    "### Iterating Over All Nodes\n",
    "\n",
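    "Similarly, `brainstate.graph.iter_node` yields a `(path, node)` pair for every node in the `MLP` graph, including the root `MLP` itself at the empty path `()`:"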
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "417155450f649750",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.424489Z",
     "start_time": "2025-10-10T15:54:10.418566Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "All nodes in MLP:\n",
      "============================================================\n",
      "('l1',)                        Linear\n",
      "('l2',)                        Linear\n",
      "('l3',)                        Linear\n",
      "()                             MLP\n"
     ]
    }
   ],
   "source": [
    "print(\"\\nAll nodes in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "for path, node in brainstate.graph.iter_node(mlp):\n",
    "    print(f\"{str(path):<30} {node.__class__.__name__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cefa4a93599dbcb0",
   "metadata": {},
   "source": [
    "## Common Functions in `pygraph`\n",
    "\n",
    "`brainstate` provides numerous utilities for operating on `pygraph`, which can be found in the `brainstate.graph` module.\n",
    "\n",
    "The commonly used functions fall into the following categories:\n",
    "\n",
    "### Graph Structure Operations\n",
    "Functions for constructing and retrieving graph structures:\n",
    "- `brainstate.graph.graphdef`: Return the graph definition (structure) of a graph\n",
    "- `brainstate.graph.iter_node`: Iterate over all nodes in the graph, including the root, with their paths\n",
    "- `brainstate.graph.iter_leaf`: Iterate over all leaves in the graph with their paths\n",
    "- `brainstate.graph.nodes`: Collect all nodes in the graph\n",
    "- `brainstate.graph.states`: Collect all `State` instances in the graph\n",
    "\n",
    "### Graph Structure Transformations\n",
    "Functions for transforming and manipulating graph structures:\n",
    "- `brainstate.graph.treefy_states`: Convert all `State` instances in the graph structure into a pytree\n",
    "- `brainstate.graph.clone`: Copy the graph structure\n",
    "- `brainstate.graph.treefy_split`: Split the graph structure into a `graphdef` and a pytree representation of `State`\n",
    "- `brainstate.graph.treefy_merge`: Merge the `graphdef` and pytree representation of `State` into a graph structure\n",
    "- `brainstate.graph.flatten`: Flatten the graph structure into collections of `graphdef` and `State`\n",
    "- `brainstate.graph.unflatten`: Restore the flattened graph structure to its original form\n",
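    "\n",
    "The split/merge pair follows a pattern familiar from plain JAX: separate the static structure from the array values, transform only the values, then rebuild. The sketch below shows that pattern on an ordinary pytree with `jax.tree_util`; it is an analogy for `treefy_split`/`treefy_merge`, not their implementation, and the rule inside the hypothetical `update` function is an arbitrary stand-in.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "params = {'l1': {'weight': jnp.ones((2, 3)), 'bias': jnp.zeros(3)}}\n",
    "\n",
    "# Split: static structure (treedef) vs. array values (leaves).\n",
    "leaves, treedef = jax.tree_util.tree_flatten(params)\n",
    "\n",
    "@jax.jit\n",
    "def update(leaves):\n",
    "    # Stand-in transformation: subtract 0.1 from every value.\n",
    "    return [leaf - 0.1 for leaf in leaves]\n",
    "\n",
    "# Merge: rebuild the same structure around the transformed values.\n",
    "new_params = jax.tree_util.tree_unflatten(treedef, update(leaves))\n",
    "print(new_params['l1']['bias'])  # [-0.1 -0.1 -0.1]\n",
    "```\n",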
    "\n",
    "### Graph Structure Modifications\n",
    "Functions for modifying and updating graph structures:\n",
    "- `brainstate.graph.pop_states`: Remove `State` instances from the graph structure that meet certain conditions\n",
    "- `brainstate.graph.update_states`: Update `State` instances in the graph structure that meet certain conditions\n",
    "\n",
    "### Graph Structure Conversions\n",
    "Functions for converting between pygraph and pytree data structures:\n",
    "- `brainstate.graph.graph_to_tree`: Convert a graph structure to a pytree\n",
    "- `brainstate.graph.tree_to_graph`: Convert a pytree to a graph structure\n",
    "\n",
    "Together, these functions cover the construction, inspection, transformation, and modification of graph structures."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c60fb47f5a07efd",
   "metadata": {},
   "source": [
    "## Splitting and Merging Graphs\n",
    "\n",
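    "A fundamental operation is splitting a graph into its structure definition (`graphdef`) and a pytree of its state values, here called `tree_states`. Because JAX transformations operate on pytrees, this split lets the state values pass through a transformation, after which `brainstate.graph.treefy_merge` rebuilds the graph from the two parts."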
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "d0fbdc61e856ba1b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.448799Z",
     "start_time": "2025-10-10T15:54:10.442142Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph split into:\n",
      "  1. graphdef (structure)\n",
      "  2. tree_states (state values)\n",
      "\n",
      "Number of states: 4\n"
     ]
    }
   ],
   "source": [
    "# Split the graph structure into graphdef and treefy_states\n",
    "graphdef, tree_states = brainstate.graph.treefy_split(mlp)\n",
    "\n",
    "print(\"Graph split into:\")\n",
    "print(f\"  1. graphdef (structure)\")\n",
    "print(f\"  2. tree_states (state values)\")\n",
    "print(f\"\\nNumber of states: {len(tree_states)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "2ebfd319743f6d58",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.473793Z",
     "start_time": "2025-10-10T15:54:10.466509Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph definition:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "NodeDef(\n",
       "  type=MLP,\n",
       "  index=0,\n",
       "  attributes=('l1', 'l2', 'l3', 'st'),\n",
       "  subgraphs={\n",
       "    'l1': NodeDef(\n",
       "      type=Linear,\n",
       "      index=1,\n",
       "      attributes=('_in_size', '_name', '_out_size', 'w_mask', 'weight'),\n",
       "      subgraphs={\n",
       "        '_in_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 2\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_name': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_out_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 3\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        'w_mask': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        )\n",
       "      },\n",
       "      static_fields={},\n",
       "      leaves={\n",
       "        'weight': NodeRef(\n",
       "          type=ParamState,\n",
       "          index=2\n",
       "        )\n",
       "      },\n",
       "      metadata=(<class 'brainstate.nn.Linear'>,),\n",
       "      index_mapping=None\n",
       "    ),\n",
       "    'l2': NodeDef(\n",
       "      type=Linear,\n",
       "      index=3,\n",
       "      attributes=('_in_size', '_name', '_out_size', 'w_mask', 'weight'),\n",
       "      subgraphs={\n",
       "        '_in_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 3\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_name': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_out_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 4\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        'w_mask': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        )\n",
       "      },\n",
       "      static_fields={},\n",
       "      leaves={\n",
       "        'weight': NodeRef(\n",
       "          type=ParamState,\n",
       "          index=4\n",
       "        )\n",
       "      },\n",
       "      metadata=(<class 'brainstate.nn.Linear'>,),\n",
       "      index_mapping=None\n",
       "    ),\n",
       "    'l3': NodeDef(\n",
       "      type=Linear,\n",
       "      index=5,\n",
       "      attributes=('_in_size', '_name', '_out_size', 'w_mask', 'weight'),\n",
       "      subgraphs={\n",
       "        '_in_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 4\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_name': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        '_out_size': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(0,),\n",
       "          subgraphs={},\n",
       "          static_fields={\n",
       "            0: 5\n",
       "          },\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef((*,)),\n",
       "          index_mapping=None\n",
       "        ),\n",
       "        'w_mask': NodeDef(\n",
       "          type=PytreeType,\n",
       "          index=-1,\n",
       "          attributes=(),\n",
       "          subgraphs={},\n",
       "          static_fields={},\n",
       "          leaves={},\n",
       "          metadata=PyTreeDef(None),\n",
       "          index_mapping=None\n",
       "        )\n",
       "      },\n",
       "      static_fields={},\n",
       "      leaves={\n",
       "        'weight': NodeRef(\n",
       "          type=ParamState,\n",
       "          index=6\n",
       "        )\n",
       "      },\n",
       "      metadata=(<class 'brainstate.nn.Linear'>,),\n",
       "      index_mapping=None\n",
       "    )\n",
       "  },\n",
       "  static_fields={},\n",
       "  leaves={\n",
       "    'st': NodeRef(\n",
       "      type=ShortTermState,\n",
       "      index=7\n",
       "    )\n",
       "  },\n",
       "  metadata=(<class '__main__.MLP'>,),\n",
       "  index_mapping=None\n",
       ")"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the graphdef\n",
    "print(\"Graph definition:\")\n",
    "print(\"=\" * 60)\n",
    "graphdef"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "4126740c3a85f543",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.501202Z",
     "start_time": "2025-10-10T15:54:10.493332Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Tree states:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{\n",
       "  'l1': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[-0.42749512, -0.46094188, -1.1391693 ],\n",
       "               [-1.5929017 ,  2.1525893 ,  0.7923302 ]], dtype=float32),\n",
       "        'bias': Array([0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'l2': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[ 0.5922906 , -0.44910267, -0.2513227 ,  1.5940351 ],\n",
       "               [ 0.5387001 , -1.538589  , -1.8159095 , -1.4499966 ],\n",
       "               [-1.0424379 , -1.6111814 ,  1.5264777 ,  1.1270016 ]],      dtype=float32),\n",
       "        'bias': Array([0., 0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'l3': {\n",
       "    'weight': TreefyState(\n",
       "      type=<class 'brainstate.ParamState'>,\n",
       "      value={\n",
       "        'weight': Array([[ 0.90832496, -0.60686237,  0.29780853, -0.90121883,  0.36897153],\n",
       "               [ 0.19299707,  0.06321475, -0.4093656 , -0.25156304,  0.58528906],\n",
       "               [ 0.08484001,  0.17979762,  0.39775968,  1.2874871 ,  0.64356524],\n",
       "               [-0.87185615, -0.48536623,  1.1163356 ,  1.2303296 ,  0.23622379]],      dtype=float32),\n",
       "        'bias': Array([0., 0., 0., 0., 0.], dtype=float32)\n",
       "      },\n",
       "      tag=None\n",
       "    )\n",
       "  },\n",
       "  'st': TreefyState(\n",
       "    type=<class 'brainstate.ShortTermState'>,\n",
       "    value=Array([0.7928349 , 0.14884543, 0.46262634, 0.9012923 , 0.20398748],      dtype=float32),\n",
       "    tag=None\n",
       "  )\n",
       "}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the tree states\n",
    "print(\"\\nTree states:\")\n",
    "print(\"=\" * 60)\n",
    "tree_states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6711263fa16c3034",
   "metadata": {},
   "source": [
    "### Merging Back\n",
    "\n",
    "`brainstate.graph.treefy_merge` reverses the split, combining `graphdef` and `tree_states` into a new graph with the same structure and state values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "916d6b14d5004a24",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.542506Z",
     "start_time": "2025-10-10T15:54:10.537923Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Merged MLP:\n",
      "MLP(\n",
      "  l1=Linear(\n",
      "    in_size=(2,),\n",
      "    out_size=(3,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[3]),\n",
      "        'weight': ShapedArray(float32[2,3])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l2=Linear(\n",
      "    in_size=(3,),\n",
      "    out_size=(4,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[4]),\n",
      "        'weight': ShapedArray(float32[3,4])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  l3=Linear(\n",
      "    in_size=(4,),\n",
      "    out_size=(5,),\n",
      "    w_mask=None,\n",
      "    weight=ParamState(\n",
      "      value={\n",
      "        'bias': ShapedArray(float32[5]),\n",
      "        'weight': ShapedArray(float32[4,5])\n",
      "      }\n",
      "    )\n",
      "  ),\n",
      "  st=ShortTermState(\n",
      "    value=ShapedArray(float32[5])\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Merge the graphdef structure and treefy_states\n",
    "mlp2 = brainstate.graph.treefy_merge(graphdef, tree_states)\n",
    "\n",
    "print(\"Merged MLP:\")\n",
    "print(mlp2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57c8f14de00956ed",
   "metadata": {},
   "source": [
    "## Collecting States\n",
    "\n",
    "`brainstate.graph.states` collects all `State` instances in a graph into a mapping from path to `State`, and can also select only specific types of states:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "d28ce6624733022a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.568611Z",
     "start_time": "2025-10-10T15:54:10.564664Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All states in MLP:\n",
      "============================================================\n",
      "('l1', 'weight')               ParamState\n",
      "('l2', 'weight')               ParamState\n",
      "('l3', 'weight')               ParamState\n",
      "('st',)                        ShortTermState\n",
      "\n",
      "Total states: 4\n"
     ]
    }
   ],
   "source": [
    "# View all states in the graph structure\n",
    "states = brainstate.graph.states(mlp2)\n",
    "\n",
    "print(\"All states in MLP:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in states.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")\n",
    "\n",
    "print(f\"\\nTotal states: {len(states)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17d6f5468409d56d",
   "metadata": {},
   "source": [
    "## The `filter` Syntax\n",
    "\n",
    "Most graph functions that retrieve `State` instances accept one or more **`filter`** arguments that select only the `State` instances meeting certain conditions. For example, to select only the `ShortTermState` instances, pass that type as a filter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8d40272b2916a52",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.595632Z",
     "start_time": "2025-10-10T15:54:10.591320Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ShortTermState instances:\n",
      "============================================================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{\n",
       "  ('st',): ShortTermState(\n",
       "    value=ShapedArray(float32[5])\n",
       "  )\n",
       "}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Filter only ShortTermState instances\n",
    "short_term_states = brainstate.graph.states(mlp2, brainstate.ShortTermState)\n",
    "\n",
    "print(\"ShortTermState instances:\")\n",
    "print(\"=\" * 60)\n",
    "short_term_states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9572255d91de1f2",
   "metadata": {},
   "source": [
    "### Filter DSL (Domain-Specific Language)\n",
    "\n",
    "In general, a `filter` has the following function form:\n",
    "\n",
    "```python\n",
    "def predicate(path: tuple[Key, ...], value: Any) -> bool:\n",
    "    ...\n",
    "```\n",
    "\n",
    "Here, `Key` is a hashable and comparable data type (often a string), `path` is a tuple of `Key` representing the nested structure corresponding to `value`, and `value` is the value at that path. If the value should be included in the `filter`, the function returns `True`; otherwise, it returns `False`.\n",
    "\n",
    "However, to simplify the creation of `filter` functions, `brainstate` provides a small domain-specific language (DSL). This allows users to pass types, boolean values, ellipses, tuples/lists, etc., which are internally converted into the corresponding predicates.\n",
    "\n",
    "| Literal | Description |\n",
    "|---------|-------------|\n",
    "| `...` or `True` | Matches all values |\n",
    "| `None` or `False` | Matches no values |\n",
    "| `type` | Matches instances of type `type`, or values with a `type` attribute of `type` |\n",
    "| `'{filter}'` (str) | Matches values with a string `tag` attribute equal to `'{filter}'` |\n",
    "| `(*filters)` (tuple) or `[*filters]` (list) | Matches values satisfying any of the internal `filters` |\n",
    "\n",
    "For example, we can use the `filter` DSL to select all `ParamState` instances and other remaining `State` instances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "e5e14c2fa1dc9a6c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.622972Z",
     "start_time": "2025-10-10T15:54:10.612465Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ParamState instances:\n",
      "============================================================\n",
      "('l1', 'weight')               ParamState\n",
      "('l2', 'weight')               ParamState\n",
      "('l3', 'weight')               ParamState\n",
      "\n",
      "Other states:\n",
      "============================================================\n",
      "('st',)                        ShortTermState\n"
     ]
    }
   ],
   "source": [
    "# Split states into params and others\n",
    "params, others = brainstate.graph.states(mlp, brainstate.ParamState, ...)\n",
    "\n",
    "print(\"ParamState instances:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in params.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")\n",
    "\n",
    "print(\"\\nOther states:\")\n",
    "print(\"=\" * 60)\n",
    "for path, state in others.items():\n",
    "    print(f\"{str(path):<30} {state.__class__.__name__}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e9e8b84efaa6e4",
   "metadata": {},
   "source": [
    "## `pygraph` and JAX Transformations\n",
    "\n",
    "After converting `pygraph` to `pytree`, we can utilize JAX's function transformation capabilities to operate on it. For example, we can use JAX's `jit` function to compile models, use the `grad` function for automatic differentiation, and use the `vmap` function for batching operations.\n",
    "\n",
    "Let's demonstrate this with a complete training example. We'll define a simple MLP model with a counter to track how many times it's been called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "c168978ecb23519f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:10.932142Z",
     "start_time": "2025-10-10T15:54:10.660584Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model created and split into:\n",
      "  - graphdef (structure)\n",
      "  - params_ (trainable parameters)\n",
      "  - counts_ (counter state)\n"
     ]
    }
   ],
   "source": [
    "# Define Linear layer\n",
    "class Linear(brainstate.nn.Module):\n",
    "    def __init__(self, din: int, dout: int):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(brainstate.random.randn(din, dout) * 0.1)\n",
    "        self.b = brainstate.ParamState(jnp.zeros((dout,)))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.w.value + self.b.value\n",
    "\n",
    "\n",
    "# Define custom Count state\n",
    "class Count(brainstate.State):\n",
    "    pass\n",
    "\n",
    "\n",
    "# Define MLP with counter\n",
    "class TrainableMLP(brainstate.graph.Node):\n",
    "    def __init__(self, din, dhidden, dout):\n",
    "        super().__init__()\n",
    "        self.count = Count(jnp.array(0))\n",
    "        self.linear1 = Linear(din, dhidden)\n",
    "        self.linear2 = Linear(dhidden, dout)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        self.count.value += 1\n",
    "        x = self.linear1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Create model and split it\n",
    "model = TrainableMLP(din=1, dhidden=32, dout=1)\n",
    "graphdef, params_, counts_ = brainstate.graph.treefy_split(\n",
    "    model, brainstate.ParamState, Count\n",
    ")\n",
    "\n",
    "print(\"Model created and split into:\")\n",
    "print(f\"  - graphdef (structure)\")\n",
    "print(f\"  - params_ (trainable parameters)\")\n",
    "print(f\"  - counts_ (counter state)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad9670aed54d1882",
   "metadata": {},
   "source": [
    "### Creating a Dataset\n",
    "\n",
    "We'll create a simple regression dataset: $y = 0.8x^2 + 0.1 + \\text{noise}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "ccfaad421f26c87a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:11.539566Z",
     "start_time": "2025-10-10T15:54:10.938175Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAArMAAAHaCAYAAAAT9MEUAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAUQNJREFUeJzt3QmcU+W5+PEnsyQzDDMMA8Pq1HEpFVyggnIRvVbF4kZrvbZW/Qtal7rWSt13axVrrdVWlLpQ7e2itrXWq1wsotSq3KJYlargwirLMCyzkJlJZjn/z/PSjJkhwyST5OS8ye/7+cR4TnKSk7wJ8+Q5z/scn+M4jgAAAAAWysv0DgAAAAB9RTALAAAAaxHMAgAAwFoEswAAALAWwSwAAACsRTALAAAAaxHMAgAAwFoEswAAALAWwSwAAACsRTALAAAAaxHMAgA8IRQKyXe+8x35whe+IGVlZfIf//Efsnjx4kzvFgCPI5gFAHhCW1ubVFdXy2uvvSZ1dXXy/e9/X6ZNmyY7duzI9K4B8DCCWcDDHn/8cfH5fLJ69WpXtwUyoaSkRG6++WaTmc3Ly5Nvf/vb4vf7ZcWKFZnetZyiPx70/f/Zz36W6V0B4kIwCyRAg8N4LosWLZJcFgmkI5eioiIZMWKETJ06VX7+859LY2Njnx/7jTfekFtvvdVk7jLNS/uSzKH9a665xoxPcXGxTJw4URYsWBDXth9//LEJOPfYYw/p16+f7LfffvLDH/5QmpqaUrJv+vjbtm2TfffdV7z6Hmjgd8stt8hxxx0nFRUV5vOun3+b/etf/xLHceSggw7K9K4AcfE5+okFEJff/OY3XZZ//etfmz96//3f/91l/bHHHitDhw5N+vna29ultbVVAoGA+SPp1rbJ0j/m55xzjgls9tprL7MfmzZtMkG+vl+aeXvuuef69MfynnvukauuukpWrVplDklnkpf2pa9OP/10+eMf/2gO6X/xi180Y/fmm2/KK6+8IocffniP261bt86M34ABA+TCCy80gZzWt+r2X/va1+Qvf/lLUvvV3NwsX/nKV+SEE04wwaIX3wOlRz70M66f6b333tt8xn/1q1/J2WefLbbq6OiQcDickX87gD7RYBZA31xyySX6YzDu++/YscPJBb/61a/M+/Lmm2/uctvChQud4uJiZ88993SampoSfuyf/OQn5rFXrVrlZJqX9qUv/vGPf5j919cR0dzc7Oyzzz7OpEmTdrvtHXfcYbb917/+1WX99OnTzfpt27b1eb/C4bBz4oknOmeccYbT0dGR0LZHHnmkM2PGDFfeA9XS0uJs3LjR/L9+3vWx9PPfF4nuO4CdKDMA0kQPP2tW44MPPpAzzjhDBg4caLI8a9askYsvvli+9KUvmUOagwYNkm9+85sxa1u7171GHvOTTz4xmZ/y8nKTGdMsaPdDu8lsqzTDNGHCBFMisM8++8gvf/nLzsdIxtFHHy033XSTeR+iM93xvC/6/JoJVZoNi5Qx6H0SeV+1zEGzcJpN1ezTkCFDTDb97bff7rzP+vXrzcx6zbDrffbff3+ZO3duXPsSsXz5clm7dm2v74lmH2NldvXA2cEHHyxHHHGEpINmI/Pz8+WCCy7oXKfjfe6555osq2Zfe9LQ0GCuux+BGD58uKm31FrX6PdSH1ffz2gvvfSSFBYWyhVXXNElK3jWWWeZ9/KJJ55Ie2YwmfdA6Wdj2LBhkilHHXWU/Od//qf57B5//PFSWloqI0eOlPvvv3+X+86bN89ku7VThH7/zzzzTNmyZcsu99PvwuTJkzuX9TN83nnnmcyzvjf6evUzq/+WdP83Q9frY2um/qSTTpJPP/00Ta8c+FxB1P8DSAMNqPTQ5Z133mmCEz18qbWWkVpDDX4eeugh80dGA1+tPezNt771LRNAzZo1y/wRe/TRR01A9uMf/zgl2/7zn/80NYAamNx2222mZEFLBiorKyUVNFi5
/vrr5a9//aucf/75Zl0878spp5wiH330kfz+9783k1MGDx5sttX9+t///d+431c9LK5BzKWXXipjxoyRrVu3mhn0H374oQkea2pqTFsoDaT0PpHH1wBHgzgNhHe3LxGjR4+WI488stca6kMOOcQ8/vbt282Pnognn3zSjIXuW3daulFfXx/X+62BhQaY3eljjxo1ygQ30Q499FBz/c4770hVVVXMx9T3VT8z+p7oZ0R/POj7r+/59773PTOZK0KDKw2GHn74YVMysOeee5pAX78bGoD99Kc/7bzvd7/7Xdm4caO8+OKLUlCQ/j9RybwHXrBs2TJT66tdH/SH6cknnyyPPPKI+YGgPxwPPPDALiUxevtPfvIT+eyzz8znVt/rl19+uctjvvfee+bzrfS7oD9qdXz1u6r/VmiA/+c//7nL+OiPZ/0saCD8ox/9yPxA/sUvfiFTpkwx3z/9gQmkzb8ztABSXGZwyy23mNtOP/30LutjHVpfvHixue+vf/3rmIfrI4exI4/5ne98p8v9vvGNbziDBg1K2bbTpk1z+vXr56xfv75z3ccff+wUFBTEVVaxuzKDiAEDBjhf/vKXE35fejq0n8j7qs+tY9eTc8891xk+fLizZcuWLuu//e1vm20jz9VbmYHepoeOe/Pcc8+Z+2oJRvShdj3UrWMRyyuvvGK2iefS0/7tv//+ztFHH73L+vfff99sN2fOnN3u9+23325KRqKf64Ybboh5388++8wJBALORRddZN5XfW3jxo3rUnqzevVq8xhFRUVOSUlJ5+XVV1910nWoPtn3IJrbZQYbNmwwz1dZWemsW7euc/0HH3xg1j/xxBNmWd8/n8/n3HzzzV22nz17trnfkiVLOtfV1NSYdQ8++KBZvueee8x41NfX97gfy5Ytc/x+v/k8RHv33XfNY/3pT3+K+zUBfUFmFkgzzQJGi85QaHZNM306W1sPzWmmVLOWiT6mHobWTIk+VvcMU6LbahZWD/9+4xvfMBmfCN1HzaL9z//8j6RC//79u3Q1SPZ9SWR7XfePf/xDNmzY0OU1Ko1B//SnP5kMtv5/9GFY7cag2VJ9vOjDsD2Jd36tZmaVPq5m05RmMXVi2bPPPhtzm7Fjx8Y9476nw+A6yUoPk3enh5Ijt++OlkboIe7/+q//Mpm7F154wRyB0OfTjHY0zc5qZk+zhvo69bH/9re/dcngasY2kTnJsbLTuk67E3Q/fN5TdjrZ96CvUrHvmpVVmu3WoxERWrqhIqUed9xxh3n/te1ZtMhnWLPkkc+gZmVVJKOrnTp0v3TMNBsfix610QlwmlWP3nf9bum+rFy5MoF3BkgcwSyQZnpIP5r+cdRD/DrjWWsJo/94x3vYWP9wRIscmtbD1L0Fs71tu3nzZrOPsdohpbJFkrY00kOWqXpfEtn+7rvvlhkzZpjDx+PHjzd1ftOnTzc1gbW1teYPuAaTeolF36NU0uBPgw095K2CwaDcfvvt8v/+3/+TAw44IOY2Om56CDcZ+gNAg6fuWlpaOm/viQb1WmeqpRaRQEoPTWvNq7a50g4BGuBGu/LKK+WBBx4wAdPf//5385qT8frrr5ua0e603EH3L1pPHSeSeQ8yve+RYFZLB6JpcKq0flwP9y9cuFBmzpxpaoOj6edMRdc3Rx4zEszq9+Kxxx4z+6olOFrGo59LLUFS+t7pjxh9nujvczSt4wXSiWAWSLPufwwvu+wyE3Bp3eWkSZPMJCytzdQ/EhoIxKP7H6WIeLJayWybKlqvpwFmdHCc7PuSyPaadY1kpLVuV2sItf7zmWeekS9/+cvmPvoHWwPeWNLRf1MzY5Fg9t577zU/LjTj1RNtnaQ9WOOhdbyxxl0DEg38u9M6StU9ax3twQcfNO9VdEZQaVsurZ/U19I92NYMYeRMX5ptTFas7PQPfvAD8+MgMjmvt+x0Mu9BpvddfxREfghFe/fdd009q9aDa2Cr77cGtt1FJnDp/aIfU3/w
6vdHab2//mDR74YeldHJm/q5nD9/vsnsatZVA1n98aV15j29ViCdCGYBl+nEIw2Soie9aBbIK433Nbuih1i7z1RWsdb1RaQvrx62T/R96Wl2e6LvqwYx2v1AL5pp1ayTBlt66FszSVpu0VvmM5Uz7TWY1d67OnNcJ+tcdNFF5rB7TzSDFyuzF0tPmb1x48aZXqrdy1O0BCNye090YlD0ZLUIPSStNICKpj8YdLKhZmY1WNP3WpeTESs7ret0bOPNWifzHmR63zWLGitQ1IBUJ7VF94mNVUqhGVf9jEVn/3Xb7j/WtCRIM7R60bOxadb26aefNsFspFRIJzsme6QA6CtacwEu0wxZ9yyozvrV4Mkr+6d/lLRWU2tKowNZnXGfLJ05rVkcLb/Q1kCJvi+RGsvuQWq82+ty97IDDeA1A6eHTPVxtAZU62b1TEjdaRlCb/uSaGsupTPGNYOsbdz0ddxwww1xZfbiufSU2Tv11FPN+xFdTqHvgWa49SxY0bP4NfumrydSE6nBkmZfNWsXTbs7aH1ndECkn6Vrr73WjPsll1xiyhP0hCMaZGdaMu9BJuk+a/eNWMGsZmYj77+21dMsrZY1RNNg9NVXXzUlIZGAVx9TOw9Eto31OvVIk94vkrHWH0m6vX5futMfNHqEAUg3MrOAy7T3omYm9TCeHt7TXpY64ap7fWEmaQ9VPfyumRfNEOofL82oaQZHWxXFS4PfyGFOzeRpIKvBlWaDNAsZmWSTyPuiNa5Kgz0tIdAJJtqWKN7tNZOkh8Y1iNFAQLNOej9tDRbJ6t51110mW6fBjE5a0sfTQ/o6CUbvGzm839O+RILceFtzRYJZpUGHvv+9tUFLRc2svj5tj3XdddeZ7LSWfWhvV21rplm7aEuWLDGZYJ1sFOmxq+Or5Ro62Uvf5+eff96s0zZckWBn6dKl5keLXiIB+tVXXy1z5sxJSXY2Wcm8BxH63dAfNJEff3o4XktpIuUvkUP2qaSn+tUjD92DWa0d1x+ekRIZ/SzqxCx9vzWo1UBVs87aM1knRer3u/tjRupldd/1B52WjuiPTy290KBfvz+Rlnr6Q1Dro3/3u9+Z7LZOEtV/L3QftDRB6397O4sakLQ+9UAAEHdrrtra2i7rt2/f7pxzzjnO4MGDnf79+ztTp051li9fbs6I1b0tT0/ttbo/Zvf7Jbut0jZR2jpLW+5oG6VHH33U+cEPfmDa9PQm8piRiz7GsGHDnGOPPda5//77nYaGhl22SeR90RZAI0eOdPLy8jr3Pd7tQ6GQc9VVVzljx451SktLTesn/f9IK6LoFkU6vlVVVU5hYaHZ/2OOOcZ5+OGHe92XRFtzRVRXV5s2S42NjY5b9GxXV155pXl92jrrkEMOcebPn99jKzD9HEWfPev444832+p7NGrUKHNmsNbWVnO7tovSFmeTJ082Z8qKpi26dJuVK1em9PX05SxaybwHSj9jibZFS3bfn3766ZhnYNM2W7r++eef71ynreQuvfRS893QVmr6eX/ooYd2Obta5DG1LZl67LHHTNsy/Uzq+7Lvvvs63/ve95zNmzd32U7H9s4773TGjBljHl9b/el7qO9TMBiM+/UDfeXT/yQfEgPIBTpr+v333zcZHKSWTqTRQ/c6+UtPOgAAiA81swBi6t5fUwPYyOkwkXp6mFvrD7v3AQYA7B41swBi0p6rZ599trles2aNOU2p9qPUekekhtZZao2p1tT+4Q9/MP8f3fMTANA7glkAMR133HFmZvqmTZtMWx/t3apnd9K+k0gNbWav3Qt0Qs0vf/nLLq3KAADxoWYWAAAA1qJmFgAAANYimAUAAIC1CGYBAABgrZybAKani9SztOi511N5XnUAAACkhk7p0jM26tkE9RTZu5NzwawGstHn2gYAAIA3rVu3znR82Z2cC2Y1Ixt5c8rKylzJBNfW1przrPf2ywLexBjajzG0H2NoN8bPfh0uj2FDQ4NJPkbitt3JuWA2
UlqggaxbwWxLS4t5Lr7AdmIM7ccY2o8xtBvjZ7+ODI1hPCWhfKIAAABgLYJZAAAAWItgFgAAANYimAUAAIC1CGYBAABgLYJZAAAAWItgFgAAANYimAUAAIC1CGYBAABgLYJZAAAAWCvnTmcLAADgNduDYalrbpXy4kIZWOLP9O5YhWAWAAAgQ1pa2+X59zbIW6u3S1O4Tfr5C2RC9UA56aARUlSYn+ndswJlBgAAABmigeyCD2okz+eTEeXF5lqXdT3iQzALAACQodICzcgOKglIZWlAAgX55lqXl67ebm5H7whmAQAAMkBrZLW0oKy4a9WnLgfDbeZ29I5gFgAAIAN0spfWyDY0t3VZr8sl/gJzO3pHMAsAAJAB2rVAJ3ttDYaktjEkobZ2c63L46sH0tUgTnQzAAAAyBDtWqC0RnZDXbPJyB47ZmjnevSOYBYAACBDtP3WqeOr5Jj9htJnto8IZgEAADJMA1iC2L6hZhYAAADWIpgFAACAtTIazL766qsybdo0GTFihPh8Pnn22Wd73WbRokVy8MEHSyAQkH333Vcef/xxV/YVAAAA3pPRYDYYDMrYsWNl9uzZcd1/1apVcuKJJ8pRRx0l77zzjnz/+9+X8847T1588cW07ysAAICNtgfDsmpLMGvPKJbRCWDHH3+8ucRrzpw5stdee8lPf/pTszx69Gh57bXX5Gc/+5lMnTo15jahUMhcIhoaGsx1R0eHuaSbPofjOK48F9KDMbQfY2g/xtBujF9mtLS2y7xlG+WtNdulOdwmxf4CmbDnQDnhwOGmi4KXxzCR57Gqm8HixYtlypQpXdZpEKsZ2p7MmjVLbrvttl3W19bWSktLi7gxGPX19eYDkJdHibKNGEP7MYb2Ywztxvhlxuuf1Mq76+qkIlAo/crypSnULO9+3CC+lnqZvG+lp8ewsbExO4PZTZs2ydChQ7us02XNtjY3N0txcfEu21x33XUyc+bMzmW9b1VVlVRWVkpZWVna91kHX+uB9fn4AtuJMbQfY2g/xtBujJ/76oJhebOmRvICA6Swf0BaRaSwUKRdQvJWTYccNbZcyhNoBeb2GBYVFWVnMNsXOlFML93pQLj1hdLBd/P5kHqMof0YQ/sxhnZj/NxVH2qXpnC7jCgv1je/c31ZcaE505jeXlGa59kxTOQ5rPpEDRs2TGpqarqs02XNsMbKygIAAOSi8uJC6ecvkIbmti7rdbnEX2BuzxZWBbOTJk2ShQsXdlm3YMECsx4AAAA76dnEJlQPlK3BkNQ2hiTU1m6udXl89cCsOttYRoPZHTt2mBZbeom03tL/X7t2bWe96/Tp0zvvf+GFF8rKlSvl6quvluXLl8uDDz4oTz/9tFxxxRUZew0AAABedNJBI+TYMUPNpC0tLdBrXdb12SSjNbNvvfWW6RkbEZmoNWPGDHMyhI0bN3YGtkrbcr3wwgsmeL3//vtljz32kEcffbTHtlwAAAC5qqgwX04dXyXH7DdU6ppbTWlBNmVkPRHMfuUrXzG/EnoS6+xeus0///nPNO8ZAABAdhhY4s/KINbKmlkAAAAgGsEsAAAArEUwCwAAAGsRzAIAAMBaBLMAAACwFsEsAAAArEUwCwAAAGsRzAIAAMBaBLMAAACwFsEsAAAArEUwCwAAAGsRzAIAAMBaBLMAAACwFsEsAAAArEUwCwAAAGsRzAIAAMBaBLMAAACwFsEsAAAArEUwCwAAgN2qC4alpqHFXHtNQaZ3AAAAAN7U0touz7+3Qd5avU36te+QpmX1MqG6Qk46aIQUFeaLF5CZBQAAQEwayC74oEbyfD6pKPGba13W9V5BMAsAAIBdbA+G5a3V22VQSUAq+wekMC/PXOvy0tXbze1eQDALAACAXdQ1t0pTuE3KirtWpepyMNxmbvcCglkAAADsory4UPr5C6Shua3Lel0u8ReY272AYBYAAAC7GFjilwnVA2VrMCS1O0LS2tFhrnV5fPVAc7sX0M0AAAAA
MWnXArV09TbZFgyLk++XY8cM7VzvBQSzAAAAiEnbb506vkqOHlUp6zZukqrhw6SitEi8hGAWAAAAu1Ve4pdwWZG59hpqZgEAAGAtglkAAABYizIDAACQs7Txv/ZL1TZTqZ6dn87HxucIZgEAQM5paW03p2TVM1zpiQG0n6q2odJZ+jrpyauPjV1RZgAAAHKOBpsLPqiRPJ9PRpQXm2td1vVefmzsimAWAADkFD38r1nTQSUBqSwNSKAg31zr8tLV283tXnxsxEYwCwAAcorWserh/7LirtWWuhwMt5nbvfjYiI1gFgAA5BSdkKV1rA3NbV3W63KJv8Dc7sXHRmwEswAAIKdoZwGdkLU1GJLaxpCE2trNtS6Prx6YVOeBdD42YqObAQAAyDnaWUBpHeuGumaTNT12zNDO9V59bOyKYBYAAOQcbZF16vgqOWa/oXH3go23b2xfHht9RzALAACslIqTEuh2vW3b176x8Tw2kkcwCwAArOL2SQkifWO1vZb2jdXJXLqsNAOLzGICGAAAsIqbJyWgb6z3EcwCAABruB1c0jfW+whmAQCANdwOLukb630EswAAwBpuB5f0jfU+glkAAGCNTASXOrFM+8Q6jmP6xuo1fWO9g24GAADAKm6flIC+sd5GMAsAAKySqeCSvrHeRDALAACsRHAJRc0sAAAArEUwCwAAAGsRzAIAAMBaGQ9mZ8+eLdXV1VJUVCQTJ06UJUuW7Pb+9913n3zpS1+S4uJiqaqqkiuuuEJaWlpc218AAIBssD0YllVbgl3OmhZrnddldALYU089JTNnzpQ5c+aYQFYD1alTp8qKFStkyJAhu9z/d7/7nVx77bUyd+5cOeyww+Sjjz6Ss88+W3w+n9x7770ZeQ0AAAA2aWltl+ff22BOC6xnU9OTUIytGiAiPnl3XV3nOu3nq+3OtHuEl2U0M6sB6Pnnny/nnHOOjBkzxgS1/fr1M8FqLG+88YZMnjxZzjjjDJPN/epXvyqnn356r9lcAAAA7KSB7IIPaiTP55MR5cXm+rf/WCu//b81XdbpffS+XpexzGw4HJalS5fKdddd17kuLy9PpkyZIosXL465jWZjf/Ob35jg9dBDD5WVK1fKvHnz5KyzzurxeUKhkLlENDQ0mOuOjg5zSTd9Dj1TiBvPhfRgDO3HGNqPMbQb4+cddcGwvLV6mwwq8Utl/51tzfoH8qUl1CY+n0hpIF8C+Xn/vs2Rpau3ydGjKqWsuMDVMUzkeTIWzG7ZskXa29tl6NChXdbr8vLly2NuoxlZ3e7www83b2hbW5tceOGFcv311/f4PLNmzZLbbrttl/W1tbWu1NrqYNTX15v91WAd9mEM7ccY2o8xtBvj5x01DS3Sr32HVJT4pbC91azraGuT6pI28/+Btkbpl7czPBwR6JBtwbCs27jJBLdujmFjY2N2njRh0aJFcuedd8qDDz5oamw/+eQTufzyy+X222+Xm266KeY2mvnVutzozKxOHKusrJSysjJXvsBa06vPxxfYToyh/RhD+zGGdmP80pNhrWtplfKiQilP4MQR/pKwNC2rl5aQTyr7B8y6UEG7rA42mczsPgWl4svfWSNb2xwSJ98vVcOHmcysm2OojQE8H8wOHjxY8vPzpaampst6XR42bFjMbTRg1ZKC8847zywfeOCBEgwG5YILLpAbbrgh5psbCATMpTu9r1tfKB18N58PqccY2o8xtB9jaDfGL32TtyYkMFGrorRIJlRXmHpYnfClQeqOULsUBQq0qkAaQ+3iy/NJQ3ObbA2G5dgxQ802kR8kbo1hIs+RsU+U3++X8ePHy8KFCzvX6Ruly5MmTYq5TVNT0y4vTgNipWlvAACAXJu8tSDBiVoa+GqQqrHThrpmc33mxC/Imf+xZ5d1eh+9r9dltMxAD//PmDFDJkyYYCZ0aWsuzbRqdwM1ffp0GTlypKl7VdOmTTMdEL785S93lhlotlbXR4JaAACAbKS9XzUjO6gk
IJWlO486V5bujH+Wrt4ux+w3VAbGUXKgGdxTx1eZ+9c1t0p5cWHndsftP2yXdV6X0WD2tNNOMxOxbr75Ztm0aZOMGzdO5s+f3zkpbO3atV0ysTfeeKNJcev1+vXrTd2GBrJ33HFHBl8FAABA+mmQqaUFmpGNVlZcYLKpensiAajet/v9Y63zuoxPALv00kvNpacJX9EKCgrklltuMRcAAIBcotlSrZHVetZIRlbpcom/wNyei6jCBgAAsIBmTHWy19ZgSGobQxJqazfXujy+eqB1GdWsycwCAAAgPpEJWVojq6UFJf4CayZqpQvBLAAAgCV2N3krVxHMAgAAWMbGiVrpQjALAADgcostsqqpQzALAABgwdm7EBvdDAAAACw5exd2RTALAADg8tm7AgX55lqXtTOB3p6K51i1JZiSx7IJZQYAAACWnb0rWkuOly+QmQUAAHDx7F3RUnH2rudzvHyBYBYAAMDSs3dtd6F8wesIZgEAAFygh/31bF2O45jSAr1O9uxddf8uX9ByhWi6HAy3mduzHTWzAAAAlp69qzyqfKGyND+l5Qu2IDMLAACyXqIz/RO5f6KPrQHsXoNLUnLChIFpKl+wCZlZAACQtRKd6Z/I/b3SReCkf5cpaI2sli+U+AuSLl+wCcEsAADIWpGZ/johSmf66+F3XVZ6yD+Z+yf62DaVL9iEMgMAAJCVEp3pn8j93eoikEgJw8AUli/YhMwsAADISomeqCCR+6fzJAheKmGwAZlZAACQlRI9UUEi90/nSRBUrp8IIREEswAAIG308PjqrUHZ0eJ+v9NEZ/oncv90dhHgRAiJocwAAACkXPRh8uZwq4wItMo+Wxw5aexIT8/0T+T+6eoikO4ShmxDMAsAAFIueqb/8AHF4mtplZc+rBHx+Tw90z+R+6eriwAnQkgMwSwAAEip7ofJxXGkX3GhVEihyWJq8Od2ZlGfL5HnTOT+iT52PI+nJQyRNl+akdVAVksYNPNLVrYramYBAEBKRQ6TaxAWrayoQILhNnM7dk9LFTRwdRzHlBbodS6dCCERZGYBAEBK9XiYvIXD5PHK9RMhJILMLAAASKlYM/3rm1tlWwpm+ueaXD0RQiLIzAIAgJSLnum/sb5ZRgREpozmMDlSj2AWAACk9TD59qaQdATrZO8vjJS8PA4KI7UIZgEAQNro4fEBxQWyuT2Y6V1BluLnEQAAAKxFMAsAAABrEcwCAADAWtTMAgCAlJ35K9meqKl4DOQWglkAAJBUwNnS2i7Pv7fBnMJWz/ylJ0zQPrPahku7GsQjFY/R0/4huxHMAgCApALO1nZHFq3YLINKAjKivNic+WvBBzVmG23PFQ99XN0mmcdIZUAMe1AzCwAAesxwrtoSNNfRAWeez2cCTr1+4b2N8uw/PzNBaGVpQAIF+eZal/WECZFte3seDUCTeYye9k+XdT2yF5lZAADQa4Zz9PBS+df6hs6AU1WW5ktDS6t8WrtDRg8f0OUxyooLZENdszncr31md0fvo8+jAWhPj9FbuUD3gDiyf0oDYj15AyUH2YnMLAAAiCvD+fHmRhNgRhvcf2eAuHVHqMt6LRMo8ReYutXe6H00YNZt+voYkYC4+/7pcjDcZm5HdiKYBQAAvR7yH1paZILLLY1dg9ZwmyPDy4plR6hNahtDEmprN9dbgyEZXz0wrmyo3kdrW3Wbvj5GKgJi2IlgFgAA9JrhrCwLmHWbGnYNOE8+eISceNBwcRzHlAXo9bFjhpqJV/HS++o2fX2MVATEsBM1swAAIGaGM1JzqnT5i0P6y4Ejy+XDjQ0m4CzxF3QGnNotQOtS+9oSS7fXrgXJPEYk8NUa2e77h+xFMAsAAHbJcEbaYmk2VgNZzXBqYKgBZ099XPX/k82AJvMYqQiIYR+CWQAAkFCGMxVBazp5ff+QWgSzAACgCzKcsAnBLAAAiIkMJ2xAMAsAQAr1VE8KID0IZgEASNNZs3QiVWSmP4D0oM8sAABp
PGuWrgeQPgSzAACk6axZuqwdAfR2AOlBMAsAQJrOmqXLwXCbuR1AehDMAgCQwrNmRdPlEn+BuR1AehDMAgCQorNm6VmyahtDEmprN9e6PL56IF0NgGwOZmfPni3V1dVSVFQkEydOlCVLluz2/nV1dXLJJZfI8OHDJRAIyKhRo2TevHmu7S8AALFo1wI9S5bjOOasWXodfdYsAFnYmuupp56SmTNnypw5c0wge99998nUqVNlxYoVMmTIkF3uHw6H5dhjjzW3/fGPf5SRI0fKmjVrpLy8PCP7DwBABGfNSh9698Kzwey9994r559/vpxzzjlmWYPaF154QebOnSvXXnvtLvfX9du2bZM33nhDCgt31h9pVhcAAK/w+lmzbAoM6d0LTwezmmVdunSpXHfddZ3r8vLyZMqUKbJ48eKY2zz33HMyadIkU2bwl7/8RSorK+WMM86Qa665RvLzY3+oQ6GQuUQ0NDSY646ODnNJN30OPdTkxnMhPRhD+zGG9mMMUxMYzlu2Ud5as12aw21SrIHhngPlhAOHpz0w7Ov4Pf/uennpwxqpKAnIiAFF0tDSJi99sEnEceSUg/dI2/4i89/BRJ4nY8Hsli1bpL29XYYOHdplvS4vX7485jYrV66Ul19+Wc4880xTJ/vJJ5/IxRdfLK2trXLLLbfE3GbWrFly22237bK+trZWWlpaxI3BqK+vNx8ADdZhH8bQfoyh/RjD5L3+Sa28u65OKgKF0q8sX5pCzfLuxw3ia6mXyftWem78drS0yqdrN8jeJSIDirW1WasMLBYpl3azfuVgn/QvoktEtn4HGxsbs/N0tvpGar3sww8/bDKx48ePl/Xr18tPfvKTHoNZzfxqXW50ZraqqspkdcvKylzZZ5/PZ56Pf4DtxBjajzG0H2OYnLpgWN6sqZG8wAAp7B8QDQ21Wq9dQvJWTYccNbZcytNYctCX8WvaGpQNoc0yfECxNEUdfe0oapeN9c2SV1IuQwaVpG2fkdnvoDYG8HwwO3jwYBOQ1tTUdFmvy8OGDYu5jXYw0FrZ6JKC0aNHy6ZNm0zZgt+/6xdROx7opTsdCLf+QdTBd/P5kHqMof0YQ/sxhn1XH2qXpnC7Oc2u+Hyd68uKC03nBb29ojTPU+M3sF9Aiv2F0tDSLpWln4crutzPX2hu57PgLje/g4k8R8Y+BRp4amZ14cKFXaJ+Xda62FgmT55sSgui6yg++ugjE+TGCmQBAIA3T+qgE9FWbQn2eKpfevfCijIDPfw/Y8YMmTBhghx66KGmNVcwGOzsbjB9+nTTfkvrXtVFF10kDzzwgFx++eVy2WWXyccffyx33nmnfO9738vkywAAwNMigeGCD2o6T7OrgawGhtoL183AMJEOBZEevUtXbzcZ5BJ/Ab174a1g9rTTTjMTsW6++WZTKjBu3DiZP39+56SwtWvXdkkza63riy++KFdccYUcdNBBJtDVwFa7GQAAcq91E+LnlcBQA1kNqgdph4LyYhNUR4Js7dMbjd69iIfP0WlpOUQngA0YMMDMyHNrAtjmzZvNxDVqe+zEGNqPMbS/pydjaPePlcj4+UvK5e6/fiR5OpGo9PP5LFo+oOHI1cftR7DqUR0ufwcTidf4FwEA0CVjpoGGZsz0Wpd1PbKHBot7DS7JSNBY19JqfihpmUM0XQ6G20yQDSSKYBYAYLJ1mpHVQ7+aMQsU5JtrXdbD0j1N0sllvU1gwq7Ki7w3EQ32s6rPLAAgPTQjphkz07qpW8ZM6yv1dg7/er8cw+vKPTQRDdmDzCwAwJOtm7yKcozkaNCvgavWyOoPJb2mQwGSQWYWAOCp1k02lWOoytKd2Vgtx9BZ97xXu0eHAqQamVkAgEHGLP5yDCYw2T0RDdmFzCwAwCBjllg5RiQjqyjHADKHzCwAoAsyZj3jFKuA95CZBQDAwjNpAdiJYBYAgARQjgF4C8EsAAB9oAEsQSyQedTMAgAAwFoEswAAALAWZQYAAFh0
0gbqdIGuCGYBAPC4ltZ2c7pcPfuYnrRBe91qizDtoKAT0oBcRpkBAAAep4Gsnmo4z+eTEeXF5lqXdT2Q6whmAQBwqURg1ZaguU50O83IDioJSGVpQAIF+eZal7XXbaKPB2QbygwAAPBwiYDWyOp2mpGNVlZcYE7aoLdTP4tcRmYWAAAPlwjoZC8NgBua27qs1+USf4G5HchlBLMAAGsPwXtdKkoENOuqmdytwZDUNoYk1NZurnV5fPVAsrLIeZQZAAAyJttn6aeqREDfD6UBsG5X4i+QY8cM7VwP5DKCWQBIAfp/JncIXjOVGvDpoXNdVqeOrxLbRZcIVJbm97lEQAN7fT+O2W8onzOgG4JZAEhCtmcW3TwEryIBn2YgNXCzPWCLlAhEAnTNyGogqyUCmllN9PXp/W1/T4BUo2YWAJJA/8/kD8FrgBdNl4PhNnN7NtAfNhq4Oo5jSgT0mhIBIHXIzAJAH+VCZtGGQ/BeR4kAkF5kZgGgj3Ils5guuTZLX1/PXoNLsu51AdYEsxs2cMgMAKLR/zN5HIIH4FqZwf777y+zZ8+WM844I+knBYBskOrJPbmIQ/AAXMvM3nHHHfLd735XvvnNb8q2bduSfmIAyAZkFlODQ/AA0p6Zvfjii+X444+Xc889V8aMGSOPPPKITJs2rc9PDADZgMwiAFjUzWCvvfaSl19+WR544AE55ZRTZPTo0VJQ0PUh3n777VTvIwB4Hv0/AcCS1lxr1qyRZ555RgYOHChf//rXdwlmAQAAALckFIlqacEPfvADmTJlirz//vtSWVmZvj0DACAGTh0MoE/B7HHHHSdLliwxJQbTp0+PdzMAAFIi3NYhz7z9mby1po5TBwNIPJhtb2+X9957T/bYY494NwEAIGXeXL1VXlrVIhUlRebUwdoGLdIWTSfhAchNcQezCxYsSO+eAADQg7pgWD7ZHJSKkn6cOhhAF5zOFgDgeXUtreZ0t2VFnDoYQFcEswAAzysvKpRAQb40tHDqYABdEcwCADyvvMQv+w4pkW3BkNQ2hkyWVq/11MHjqwdmXYmBdmxYtSVorgHsHk1iAQBWOKR6kDhFbbJ0TZ05dXCJvyDrTh3c0touz7+3Qd5avZ2ODUCcCGYBAFbwF+TJKQfvIceMHpa1fWY1kNUODYNKAnRsAOJEmQEAwCoawO41uCTrAlktKdCMrAay2rFBa4T1Wpe1YwMlB0BsBLMAgJzjxZpUzTZraYF2aIhGxwZg9ygzAADkDC/XpGrZhO6PlhZEeugqOjYAu0dmFgCQMyI1qXk+n6lJ1Wtd1vWZpmUTGlhvzZGODUCqEMwCAHKCDTWpmiHWDg2O45iODXqdbR0bgFSjzAAAkBMiNamake1ek6qBo96e6eynljpo1wI9Pa+bHRs0kM/WDhHIfgSzAICcYFNNqgaUbgSVXq4hBuJFmQEAICdQk2pXDTEQL4JZAEDOtMSiJtWuGmIgHpQZAABy5nB2pmpSvciGGmIgHmRmAQA5dzg7W88i1tca4mherCEGPB/Mzp49W6qrq6WoqEgmTpwoS5YsiWu7J598Unw+n5x88slp30cAyJZD7W7jcLY3P2fUECNbZLzM4KmnnpKZM2fKnDlzTCB73333ydSpU2XFihUyZMiQHrdbvXq1XHnllXLEEUe4ur8AYPuhdrdxONu7n7NIrbD+qNCxKPEX5GwNMeyV8czsvffeK+eff76cc845MmbMGBPU9uvXT+bOndvjNu3t7XLmmWfKbbfdJnvvvber+wsA2XKo3a2MMoez3c3SJ/I5i9QQX33cfnLFsV8y17qczT+ukH0ympkNh8OydOlSue666zrX5eXlyZQpU2Tx4sU9bvfDH/7QZG3PPfdc+fvf/77b5wiFQuYS0dDQYK47OjrMJd30OXS2rBvPhfRgDO3nhTGsM4fat8mgEr9U9t+Zhdx57cjS1dvk6FGVUu7R7KRm+uYt2yhvrdkuzeE2KdZM354D5YQDh8cV9Awo
1vuXy0sf1pjXW1ZUIA0tbbItGJIpo4ea23sbm0yPoY5fXUurlBcVpmSckn1PU/050zHQi0rHe5zp8YN9Y5jI82Q0mN2yZYvJsg4dOrTLel1evnx5zG1ee+01eeyxx+Sdd96J6zlmzZplMrjd1dbWSktLi7gxGPX19eYDoIE67MMY2s8LY1jT0CL92ndIRYlfCttbO9ePCHTItmBY1m3cJOGyIvGi1z+plXfX1UlFoFD6leVLU6hZ3v24QXwt9TJ538q4HuPQYQXiaymSTzYHJbSjXQYU5Mv4vUrkkGEFsnnzZs+OYbitQ95cvXXnfre1m3rffYeUyCHVg8RfkJfR99Smz5kXvoOwawwbGxvtqZlN9IWdddZZ8sgjj8jgwYPj2kazvlqTG52ZraqqksrKSikrKxM3Bl8nqenz8QW2E2NoPy+Mob8kLE3L6qUl5JPK/oHO9bXNIXHy/VI1fJgnM7Oa6XuzpkbyAgOksH9ANDwqLBRpl5C8VdMhR40tj3u/9xgxrM8ZzkyN4TNvfyYvrWqRipJ+Uta/QOpb2mTBqhZxitrklIP3yPh7asvnzAvfQdg1htoUwIpgVgPS/Px8qanRQ0+f0+Vhw4btcv9PP/3UTPyaNm3aLmnogoICM2lsn3326bJNIBAwl+50INz6Qungu/l8SD3GMDfHMJXnq68oLZIJ1RWmdlHEZyY/ac3o1mDYTLjR272oPtQuTeH2nZO3fL7O9WXFhWbCkN5eURr/e6qvs6+v1e3voenCsKZOKkqKTPcFVVmofzZ9snRNnRwzelifPhepfk9t+Zzx76j9fC6OYSLPkdFg1u/3y/jx42XhwoWd7bU0ONXlSy+9dJf777fffrJs2bIu62688UaTsb3//vtNxhUAvNp1wMaZ49GTtypL812dvBX9YyJSz5kNXRjS/Z7a+DkDkpHxMgMtAZgxY4ZMmDBBDj30UNOaKxgMmu4Gavr06TJy5EhT+6op5wMOOKDL9uXl5ea6+3oASHY2uPZB1UBGg4ydmS4xM71z6exTkV6kkdf/eaYvZAKknvY/max2zB8Te5abuls3pSvo7Ot7ms2fM8DqYPa0004zk7Fuvvlm2bRpk4wbN07mz5/fOSls7dq1HJIAkLEG/yoSyGimSwOEZAMD3T7Zx0hlCUQqM32pyGrH+jGhnRB0ApnW3bolnUGnG9nTVHzOABtkPJhVWlIQq6xALVq0aLfbPv7442naKwC5yOsN/jNx4oVEMn3JZrV7/jHhmI4COnnKzbrPdAWdZE+BLAtmAcArMlkjmskSiFRk+lKR1e7xx0RRgWnppZ0Q3Axm0x10kj0FksfxewCw5Hz13YNF7Xmq17qswWIqzyLVF5FAVLPY0XQ5GG4zt/f5bGEtbeb1akuvTNBx32twCYEn4EEEswDQjR5C1kPJ2hxcDy3rtRdmg8cTLKbj9Khunra2px8TerYwPVmBF3vxAsgsygwAwJJ6xt2VQAQK8uTvH9XKBxsbXKulTdeEqVh1qnraWz1bGAB0x78MAOCResbeOhTsLlgc2K9QXv90S0ZqaVM9YSrWjwntMxvPaW8B5B6CWQDIsEQ6FMQKFg/bZ5D8a31DWtuJZSKrHf1jInK2RwDojmAWADIskQ4FsYJFvV66ZrsM6u/3TDsxZukDcAsTwAAgg/raoSB6dn0qJl4BgK0IZgHA8nZWXm4nZptMdoMA0DeUGQBAFpykwY3To2azTJxZDUBqEMwCOai3WfNwT6raWXm1nZgtMnlmtWTxfUauI5gFcgjZJ29KZVaViVeJS8VpeDOB7zOwE8EskENszj5lM7Kqmc1aRuqW9TvhlW4Q8eD7DOxEMAvkCFuzT7mErGpmspapqlt2E99n4HN0MwByRCpmzQO2iGQt83w+k7XUa13W9dnQDYLvM/A5glkgR9jei5SWSUhn717N2GqdsuM4prRAr73cDcL27zOQSpQZADkyuzhVs+bdxiQXJKovNbC21S3b+n0G0oFgFsih
wMvGXqRMckGikqmBtalu2cbvM5AOBLNADgVetmWfmOSCvsiVrKVt32cgXaiZRUzUJ7pTp5cp+gdvr8Elnv/DxyQX9JVtNbC58H0G0oXMLKw6TO5ltvaq9DIbWybBG8haArmDzCz63M4GXTG7OPVsbJkEbx1RImsJZD8ys+hEfWJycqVOz21MckFvOKIE5DaCWXTiMHnyCLxSj8PFsH3iJYD0IphFJ+oTk0fglT42tUyCeziiBICaWXSiPjF1qNNDKtFdpGd0vABAZhZdcJgc8A5qQXvHESUABLPogsPksI0XTx2cKtSC9o6JlwAIZhET9YnwumzPWlILGj+OKAG5jWAWgJWyPWtJd5H4cUQJyG1MAANgHZtOHdxXnIQjcUy8BHITwSwA62bY58IMdrqLAEB8KDMAYF2taq7MYKcWFAB6RzALwLpa1VyZwW5LLWg2d5QA4H0EswCsnGGfS1lLr3YX8UKWHgAIZgFYOcPelqxlNvNClh4AmAAGwOoZ9sxgz4xc6CgBwA4Es4BHZbJbQE+YYY9c6igBwA6UGQAe4/U6xFyqVc0FfZ28lSsdJQB4H8Es4DFer0OkVjU7gtZkfzTlSkcJAN5HMAt4iJe6Bdg6wx5d9RS0trY7smjF5qR+NJGlB+AFBLOAh3ipWwCyN9P/wnsbzedsn8rSpH40kaUH4AVMAAM8xIvdApB9HQdKAgWysb5F/AV5KZm8RUcJAJlEMAt4CN0C4EbHgcH9d36Otu4IdVnPjyYANiKYBTxG6w217tBxHFNaoNfUISKVmf5wmyPDy4plR6iNH00ArEfNLOAx1CEiVXbXceDkg0dIYX4ek7cAWI9gFvAougUgFXbXcUB/OPGjCYDtCGYBIIcz/fxoAmA7glkAyAEErQCyFRPAAAAAYC2CWQAAAFiLYBYAAADWIpgFAACAtTwRzM6ePVuqq6ulqKhIJk6cKEuWLOnxvo888ogcccQRMnDgQHOZMmXKbu8PAACA7JXxYPapp56SmTNnyi233CJvv/22jB07VqZOnSqbN2+Oef9FixbJ6aefLq+88oosXrxYqqqq5Ktf/aqsX7/e9X0HAABAZvkcPVdmBmkm9pBDDpEHHnjALHd0dJgA9bLLLpNrr7221+3b29tNhla3nz59+i63h0Ihc4loaGgwj799+3YpKyuTdNPXU1tbK5WVlZKXl/HfDugDxtB+jKH9GEO7MX7263B5DDVe0/iuvr6+13gto31mw+GwLF26VK677rrOdfoGaemAZl3j0dTUJK2trVJRURHz9lmzZsltt922y3odkJaWFnFj8HUg9DcDX2A7MYb2YwztxxjajfGzX4fLY9jY2Bj3fTMazG7ZssVkVocOHdplvS4vX748rse45pprZMSIESYAjkUDZS1j6J6Z1V8WbmVmfT4fv0YtxhjajzG0H2NoN8bPfh0uj6HOo8qJM4Dddddd8uSTT5o62p5edCAQMJfudCDc+kLp4Lv5fOm0PRjOyfO4Z9MY5irG0H6Mod0YP/u5OYaJPEdGg9nBgwdLfn6+1NTUdFmvy8OGDdvttvfcc48JZl966SU56KCD0rynaGltl+ff2yBvrd4uTeE26ecvkAnVA+Wkg0aYc79nk1wN2AEAsFFGg1m/3y/jx4+XhQsXysknn9yZxtblSy+9tMft7r77brnjjjvkxRdflAkTJri4x7lLA9kFH9TIoJKAjCgvlobmNrOsTh1fJdkcsJ9wwO5/WAEAgMzJeK5f61m1d+wTTzwhH374oVx00UUSDAblnHPOMbdrh4LoCWI//vGP5aabbpK5c+ea3rSbNm0ylx07dmTwVWQ3zVRqgKeBbGVpQAIF+eZal5eu3m5uz6aAPc/nMwG7XuvyvGUbM71rAADAqzWzp512mukscPPNN5ugdNy4cTJ//vzOSWFr167tUjfx0EMPmS4Ip556apfH0T61t956q+v7nwuHz3UbzVRqgBetrLhANtQ1m9ttPxzfPWBXlaU7yyeW
rtku4wZXyJAM7yMAAPBgMKu0pKCnsgKd3BVt9erVLu1V9ki23lWDX91GSwsiAZ7S5RJ/gbnd9uB+dwH7xromCYbbM7afAADA48FsNqsLhqWmoUX8JWGpKI2/zYSX6l0166rBb2QbDfD0MbYGQ3LsmKHWZWVjBfejh5ea8olYAbveXuLPrkluAABkC4LZtAdM26Rf+w5pWlYvE6orXJ/9v9vD56u3yzH7xReM6n5HttHSghJ/gQlkI+ttKq+IFdy/8elWGdiv0ATouwTso4dI/6L0ZZ/pngAAQN8RzKbJ5wGTXypK/NIS2jmZyO3Z/6mqd9UAXPdbg990BF7pCOh6ysD+a31DzOC+tb1dJu8zWD7c2NAlYNduBg3bt6Zkn3K13RkAAOlCMJvubGh/vxS2t0plfw2cfAllQ1Mh1fWuut+p3Pd0BnQ9lVfsCLXJf+w9KEZw3ypHjKqUaWNHdAmstV1cQ5KvM5H9y6Z2ZwAAZH1rrmwUyYZqgBRNl4PhNnO7WyL1rnq4vLYxJKG2dnOty+OrB2b8sHZP7bB0fTraiQ0tLTJB45bGneUEsYJ7fU/2GlyS1vcmV9qd9YW+9lVbgjn9HgAA4kdmNt3Z0P7+jM/+92q9a6rqeRMpr6gsC5gfFZsaQuIvyM/YZLZcaHeWKMouAAB9QTCbBl1n/zsyItAhtc2aDQ1nZPZ/uutdvRTQRWpvxXF6LK/44pD+cuDI8l1qY90M7rO13VkyKLsAAPQFwWzas6HbZFswLE6+P+PZ0FTXu3opoIuV1dMfEpv/XU7QPQOrwVEmuwhkW7szL2fpAQDZjWA2zdnQo0dVyrqNm6Rq+LCM9Zn1qlQGdLGyehrIDu7vF8dxYmZgMx3ce7X8w02RHxT1TWHKLgAAfUIwm2blJX4JlxWZa6QnoNtdVk8D2Qv+c28Rn88z5RVeL/9wQ/dMen6ez5xcpNifL8PKPg9oc7nsAgAQH4LZLGRTE/5UBHS91d5qIKvdCbwq0xniTIiVSddTBr+/vkHyfXk5X3YBAIgfwWwWsXk2eDIBXS5NprLph0qimfQxw8tkzZYm8zneEWrNybILAEDiCGazSK7OBs+FyVS2/FCJJ9juKZNuzpTX2i7TD6uWAcWFVgfsAAD3EMxmiVyfDZ7tk6m8/kMlkWC7t0z6nhX90vZZrQuGpT7UTqAMAFmEYDZL5HoT/myeTBXPDxWVydedSLCdiUy6Btuvf1Irb9bUSFO43bOZbQBA4ghms0Qu1Y3m2mSq3f1QWbetSZ56c62s3tqUsfKDvhwVcDuTPm/ZRnl3XZ3kBQZ4MrMNAOg7gtksmeCTC3WjuWp3P1S2NIblH6u2yfABxRkL0vpyVMDNTLoJttdsl4pAoRT2D5juFsmU4GTDJDwAyCYEs1k0wSfb60ZzVU8/VDbWa9sxxwSymayTTuaogBuZdA08m/U7VZYvrUmU4NgyCQ8Acg3BbBZN8MnmutFcF+uHyqF7Vci/1teboCyTddJePCoQnT3VS7G/QJpCzVIYFVcnWoLj9Ul4AJCrCGazsBNBNtaN5rpYP1TUj7cu90SdtFeOCvSUPR27R7l8sLJB2iUkZcWFCQfbud4tBAC8jGDWA3K9EwGkzz9UvJIR9cpRgZ6yp18ZNVjGVZXLWzUdfQq2+Y4CgHcRzHoAnQhge0bUC0cFdpc9fe+zejl3fIUcNbayT31m+Y4CgHcRzHqAF2sOYQevZES9YHfZ0411TRIMt8veJX6pKM1L+LH5jgKAdxHMeoTXMmywC3XSu8+e6voSf3IdB/iOAoA3Ecx6BBk2b9vR0ipNW4MysF8g58bFlr6qu82ejh4i/YuSKwXgOwoA3kQw6zFk2LzFzI5/d718unaDbAhtlmJ/Yc70FrWxr2pP2dMTDhgmDdu3puQ5+I4CgLcQzFrMloyZzTSYe+nDGtm7RMzJCRpa2nOmt6iNfVV7yp52dHRIQ6Z3DgCQFgSzFrIxY2ajyOz4ipKADChulab8fKks
LciJ3qK291UlewoAuSPxab3wTMYsz+czGTO91mVdb0ugtGpL0FzbMDu+rGjXM2wFw23m9mzV+dqLc++1AwDsQmbWMjZnzGzLKHfOjm9pk4FR3Z5yobcofVUBALYgM2sZmzNmtmWUI7PjtwVDUt/cKqG2dqltDJnZ8eOrB3r2R0MqX7u+Vn3NufTaAQB2IZi1OGMWzesZs+4Z5UCB1p8GzLJmlL1acqBZ4ymjh4ojIhvrm8VxnJzpLaqvUV+rvmbtDJBLrz0Xy2oAwFaUGVjG1jMR2Xpuey1/OOXgPWTlYJ/klZTv0mc2mztK0Fc1t8pqAMBWBLMWsvFMRLbXYGrD/SGDSiQvLy/nAhU6A+ROazMAsBHBrIVszJjZmlHuCYEKsnWiJgDYhppZi+kfw70Gl1jzRzFbajBtrf+Fe2yeqAkAtiEzC9fYmFHOpvpfuMf2shoAsAmZ2QzJ5RnOtmWUs6WjBNxDazMAcA+ZWZfl0sShbJVt9b9IDxsnagKAjQhmXcbEoexAoJJd0tFiLVvKagDA6whmXcQM5+xBoJId3DhSQmszAEgvamZdxAzn7GN7/W+us+0UywCAXRHMuoiJQ4B30GINALIDwayLmOGcuFzu+oD04kgJAGQHamZdxsSh+ND1AelGL1gAyA4Eszk+cSgds7hTga4PSDdarAFAdiCYzZBMz3D2cuaTrg9wC0dKAMB+BLM5ysuZT04Xi1w9UgIASBwTwHKQ12dx0/UBbqPFGgDYi2A2B3l9FjddHwAAQLwIZnOQDZlPrVnU2kXHcUxpgV5TywgAALqjZjYH2TCLm1pGAABgTWZ29uzZUl1dLUVFRTJx4kRZsmTJbu//hz/8Qfbbbz9z/wMPPFDmzZvn2r5mC1synz3VMnIyBQAA4InM7FNPPSUzZ86UOXPmmED2vvvuk6lTp8qKFStkyJAhu9z/jTfekNNPP11mzZolJ510kvzud7+Tk08+Wd5++2054IADMvIabGRr5tPLLcUAAEAOZmbvvfdeOf/88+Wcc86RMWPGmKC2X79+Mnfu3Jj3v//+++W4446Tq666SkaPHi233367HHzwwfLAAw+4vu/ZwLZZ3JGWYnk+n2ndpde6rOsBAEDuyWhmNhwOy9KlS+W6667rXJeXlydTpkyRxYsXx9xG12smN5pmcp999tmY9w+FQuYS0dDQYK47OjrMJd30OfQQvhvPle3qTEuxbTKoxC+V/XcG3zuvHVm6epscPapSytMQlDOG9mMM7ccY2o3xs1+Hy2OYyPNkNJjdsmWLtLe3y9ChQ7us1+Xly5fH3GbTpk0x76/rY9FyhNtuu22X9bW1tdLS0iJuDEZ9fb35AGigjr6raWiRfu07pKLEL4Xtn7cPGxHokG3BsKzbuEnCZUUpf17G0H6Mof0YQ7sxfvbrcHkMGxsb7amZTTfN+kZncjUzW1VVJZWVlVJWVubK4Pt8PvN8fIGT4y8JS9OyemkJ+aSy/87T3Kra5pA4+X6pGj4sbZlZxtBujKH9GEO7MX7263B5DHWSvxXB7ODBgyU/P19qana2iIrQ5WHDhsXcRtcncv9AIGAu3elAuPWF0sF38/myVUVpkUyorvh3SzFfVEuxsOnEoLenC2NoP8bQfoyh3Rg/+/lcHMNEniOjnyi/3y/jx4+XhQsXdon8dXnSpEkxt9H10fdXCxYs6PH+yJx0tM+ypaUYAABwR8bLDLQEYMaMGTJhwgQ59NBDTWuuYDBouhuo6dOny8iRI03tq7r88svlyCOPlJ/+9Kdy4oknypNPPilvvfWWPPzwwxl+JXCjfZatLcUAAECWBrOnnXaamYx18803m0lc48aNk/nz53dO8lq7dm2XVPNhhx1mesveeOONcv3118sXv/hF08mAHrPea581qCRg2mdpKUDkbGMaiKaCBrAEsQAAwOfocdocohPABgwYYGbkuTUBbPPmzeYEELlQJ6QlBT+ev9z0f60sjZqk1RgyJQFXH7efdUForoyh
jl22ZrtzZQyzGWNoN8bPfh0uj2Ei8VrGM7PILhoMaWmBZmSj6WQtrXHV27MtULIdZ1UDANiMn0dIKc3qaTCkpQXRdLnEX2Buh7dwVjUAgM0IZpFSmnXVrN7WYMiUFoTa2s21Lo+vHkhW1oOlBZqR1fpmLQsJFOSba11eunp7SjtRAACQDgSzSDnaZ9lXFqJlINF0ORhuM7cDAOBl1Mwi5WifZWdZSGVpftrKQrJ5chkAILMIZpE2tM/yfkAXKQuJtE77/KxqIZNNT3Y/mVwGAEg3glkgzTIZ0MUTQEfKP7RGVstCSvwFKSsLcaPnMAAgtxHMWsKLWT14N6BLJIBOV1lI98llKlLKoIGzPh+fZQBAsghmPY7DtHbLVEDXlwA61WUh9BwGALiBbgYeRw9Qu2WiW4BX2m3RcxgA4AaCWQ/zSlACuwI6r7TboucwAMANBLMe5pWgpDcaVK/aEiS49khA56WMKD2HAQDpRs2sh7nVA7SvqOeNTzq7BWSi3VYi6DkMAEg3glkP81JQEgttl7wb0LkdQPeGnsMAgHQhmPU4rwUlEbRd8nZAR0YUAJArCGY9zqtBCW2X7EBGFACQ7QhmLeG1oMTr9bwAACA30M0AfULbJQAA4AVkZpF19bwAACB3EMwi6+p5AQBA7iCYRdbV8wIAgNxBzSwAAACsRTALAAAAaxHMAgAAwFoEswAAALAWwSwAAACsRTALAAAAaxHMAgAAwFoEswAAALAWwSwAAACsRTALAAAAaxHMAgAAwFoFkmMcxzHXDQ0NrjxfR0eHNDY2SlFRkeTl8dvBRoyh/RhD+zGGdmP87Nfh8hhG4rRI3LY7ORfM6kCoqqqqTO8KAAAAeonbBgwYsLu7iM+JJ+TNsl8WGzZskNLSUvH5fK78stDAed26dVJWVpb250PqMYb2YwztxxjajfGzX4PLY6jhqQayI0aM6DUTnHOZWX1D9thjD9efVweeL7DdGEP7MYb2YwztxvjZr8zFMewtIxtB4QoAAACsRTALAAAAaxHMplkgEJBbbrnFXMNOjKH9GEP7MYZ2Y/zsF/DwGObcBDAAAABkDzKzAAAAsBbBLAAAAKxFMAsAAABrEcwCAADAWgSzKTB79myprq425yueOHGiLFmyZLf3/8Mf/iD77befuf+BBx4o8+bNc21fkfwYPvLII3LEEUfIwIEDzWXKlCm9jjm89z2MePLJJ83ZAE8++eS07yNSO4Z1dXVyySWXyPDhw80M61GjRvHvqUXjd99998mXvvQlKS4uNmeWuuKKK6SlpcW1/UVXr776qkybNs2ccUv/TXz22WelN4sWLZKDDz7YfP/23XdfefzxxyUjtJsB+u7JJ590/H6/M3fuXOf99993zj//fKe8vNypqamJef/XX3/dyc/Pd+6++27ngw8+cG688UansLDQWbZsmev7jr6N4RlnnOHMnj3b+ec//+l8+OGHztlnn+0MGDDA+eyzz1zfd/RtDCNWrVrljBw50jniiCOcr3/9667tL5Ifw1Ao5EyYMME54YQTnNdee82M5aJFi5x33nnH9X1H4uP329/+1gkEAuZax+7FF190hg8f7lxxxRWu7zt2mjdvnnPDDTc4zzzzjHa5cv785z87u7Ny5UqnX79+zsyZM00884tf/MLEN/Pnz3fcRjCbpEMPPdS55JJLOpfb29udESNGOLNmzYp5/29961vOiSee2GXdxIkTne9+97tp31ekZgy7a2trc0pLS50nnngijXuJVI+hjtthhx3mPProo86MGTMIZi0bw4ceesjZe++9nXA47OJeIlXjp/c9+uiju6zToGjy5Mlp31f0Lp5g9uqrr3b233//LutOO+00Z+rUqY7bKDNIQjgclqVLl5rDzBF5eXlmefHixTG30fXR91dTp07t8f7w3hh219TUJK2trVJRUZHGPUWqx/CHP/yhDBkyRM4991yX9hSpHMPnnntOJk2aZMoMhg4dKgcccIDceeed0t7e7uKe
o6/jd9hhh5ltIqUIK1euNCUiJ5xwgmv7jeR4KZ4pcP0Zs8iWLVvMP5z6D2k0XV6+fHnMbTZt2hTz/roedoxhd9dcc42pMer+pYZ3x/C1116Txx57TN555x2X9hKpHkMNfl5++WU588wzTRD0ySefyMUXX2x+WOpZiuDt8TvjjDPMdocffrgeIZa2tja58MIL5frrr3dpr5GsnuKZhoYGaW5uNrXQbiEzCyThrrvuMhOI/vznP5tJD/C+xsZGOeuss8xEvsGDB2d6d9BHHR0dJrP+8MMPy/jx4+W0006TG264QebMmZPpXUMcdOKQZtIffPBBefvtt+WZZ56RF154QW6//fZM7xosRGY2CfqHMD8/X2pqarqs1+Vhw4bF3EbXJ3J/eG8MI+655x4TzL700kty0EEHpXlPkaox/PTTT2X16tVm1m50YKQKCgpkxYoVss8++7iw50jme6gdDAoLC812EaNHjzbZIj3s7ff7077f6Pv43XTTTeZH5XnnnWeWtbNPMBiUCy64wPwo0TIFeNuwHuKZsrIyV7Oyik9LEvQfS80ILFy4sMsfRV3WWq5YdH30/dWCBQt6vD+8N4bq7rvvNhmE+fPny4QJE1zaW6RiDLUt3rJly0yJQeTyta99TY466ijz/9oiCN7/Hk6ePNmUFkR+iKiPPvrIBLkEst4fP51r0D1gjfww2Tn/CF43yUvxjOtTzrKwHYm2F3n88cdNa4oLLrjAtCPZtGmTuf2ss85yrr322i6tuQoKCpx77rnHtHW65ZZbaM1l2RjeddddpgXNH//4R2fjxo2dl8bGxgy+ityW6Bh2RzcD+8Zw7dq1povIpZde6qxYscJ5/vnnnSFDhjg/+tGPMvgqclei46d/+3T8fv/735sWT3/961+dffbZx3T8QWY0NjaalpN60fDw3nvvNf+/Zs0ac7uOn45j99ZcV111lYlntGUlrbkspr3VvvCFL5gAR9uT/N///V/nbUceeaT5Qxnt6aefdkaNGmXur20tXnjhhQzsNfo6hnvuuaf5one/6D/OsOd7GI1g1s4xfOONN0xrQw2itE3XHXfcYVquwfvj19ra6tx6660mgC0qKnKqqqqciy++2Nm+fXuG9h6vvPJKzL9tkXHTax3H7tuMGzfOjLl+B3/1q19lZN99+h/388EAAABA8qiZBQAAgLUIZgEAAGAtglkAAABYi2AWAAAA1iKYBQAAgLUIZgEAAGAtglkAAABYi2AWAAAA1iKYBQAAgLUIZgHAQu3t7XLYYYfJKaec0mV9fX29VFVVyQ033JCxfQMAN3E6WwCw1EcffSTjxo2TRx55RM4880yzbvr06fLuu+/Km2++KX6/P9O7CABpRzALABb7+c9/Lrfeequ8//77smTJEvnmN79pAtmxY8dmetcAwBUEswBgMf0n/Oijj5b8/HxZtmyZXHbZZXLjjTdmercAwDUEswBgueXLl8vo0aPlwAMPlLffflsKCgoyvUsA4BomgAGA5ebOnSv9+vWTVatWyWeffZbp3QEAV5GZBQCLvfHGG3LkkUfKX//6V/nRj35k1r300kvi8/kyvWsA4AoyswBgqaamJjn77LPloosukqOOOkoee+wxMwlszpw5md41AHANmVkAsNTll18u8+bNM624tMxA/fKXv5Qrr7zSTAarrq7O9C4CQNoRzAKAhf72t7/JMcccI4sWLZLDDz+8y21Tp06VtrY2yg0A5ASCWQAAAFiLmlkAAABYi2AWAAAA1iKYBQAAgLUIZgEAAGAtglkAAABYi2AWAAAA1iKYBQAAgLUIZgEAAGAtglkAAABYi2AWAAAA1iKYBQAAgNjq/wOKf5OJgP6iCgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset created: 100 samples\n"
     ]
    }
   ],
   "source": [
    "# Create dataset\n",
    "X = jnp.linspace(0, 1, 100)[:, None]\n",
    "Y = 0.8 * X ** 2 + 0.1 + brainstate.random.normal(0, 0.1, size=X.shape)\n",
    "\n",
    "\n",
    "def dataset(batch_size):\n",
    "    \"\"\"Generate random batches from the dataset.\"\"\"\n",
    "    while True:\n",
    "        idx = brainstate.random.choice(len(X), size=batch_size)\n",
    "        yield X[idx], Y[idx]\n",
    "\n",
    "\n",
    "# Visualize dataset\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.scatter(X, Y, alpha=0.5, s=20)\n",
    "plt.xlabel('X')\n",
    "plt.ylabel('Y')\n",
    "plt.title('Training Dataset: $y = 0.8x^2 + 0.1 + noise$')\n",
    "plt.grid(alpha=0.3)\n",
    "plt.show()\n",
    "\n",
    "print(f\"Dataset created: {len(X)} samples\")"
   ]
  },
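  {
   "cell_type": "markdown",
   "id": "a1c3e5f7092b4d6e",
   "metadata": {},
   "source": [
    "The `dataset` generator above draws its indices from `brainstate.random`'s global random state. If you prefer explicit JAX PRNG keys, the same sampling-with-replacement logic can be written as shown below. This is a minimal sketch that assumes only `jax`. `keyed_dataset`, `X_demo`, and `Y_demo` are hypothetical names used for illustration, and the demo targets omit the noise term.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def keyed_dataset(X, Y, batch_size, seed=0):\n",
    "    \"\"\"Yield random (x, y) batches forever, sampling indices with replacement.\"\"\"\n",
    "    key = jax.random.PRNGKey(seed)\n",
    "    while True:\n",
    "        # Split off a fresh subkey for every batch so successive batches differ.\n",
    "        key, subkey = jax.random.split(key)\n",
    "        idx = jax.random.choice(subkey, X.shape[0], shape=(batch_size,))\n",
    "        yield X[idx], Y[idx]\n",
    "\n",
    "\n",
    "# Hypothetical demo data (noise-free version of the dataset above)\n",
    "X_demo = jnp.linspace(0, 1, 100)[:, None]\n",
    "Y_demo = 0.8 * X_demo ** 2 + 0.1\n",
    "xb, yb = next(keyed_dataset(X_demo, Y_demo, batch_size=32))\n",
    "print(xb.shape, yb.shape)  # (32, 1) (32, 1)\n",
    "```\n",
    "\n",
    "Threading the key explicitly makes the batch sequence reproducible from `seed`. The trade-off is that you must manage the key yourself."
   ]
  },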
  {
   "cell_type": "markdown",
   "id": "5a077a6cd13beb75",
   "metadata": {},
   "source": [
    "### Defining Training and Test Steps\n",
    "\n",
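    "Next, we define the training and test steps and compile them with `jax.jit`. Each step receives the model state as plain pytrees (`params` and `counts`) and rebuilds the model from `graphdef` with `brainstate.graph.treefy_merge`. Changes made inside a JIT-compiled function reach the caller only through its return values, so `train_step` returns the updated `Count` states alongside the new parameters:"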
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "7554f9042667b84a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:11.557478Z",
     "start_time": "2025-10-10T15:54:11.544113Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training and test functions defined and JIT-compiled\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def train_step(params, counts, batch):\n",
    "    \"\"\"Single training step with gradient descent.\"\"\"\n",
    "    x, y = batch\n",
    "\n",
    "    def loss_fn(params):\n",
    "        # Merge graph and compute loss\n",
    "        model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "        y_pred = model(x)\n",
    "        new_counts = brainstate.graph.treefy_states(model, Count)\n",
    "        loss = jnp.mean((y - y_pred) ** 2)\n",
    "        return loss, new_counts\n",
    "\n",
    "    # Compute gradients\n",
    "    grad, counts = jax.grad(loss_fn, has_aux=True)(params)\n",
    "    \n",
    "    # Update parameters (simple SGD)\n",
    "    params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad)\n",
    "\n",
    "    return params, counts\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def test_step(params, counts, batch):\n",
    "    \"\"\"Evaluate model on test batch.\"\"\"\n",
    "    x, y = batch\n",
    "    model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "    y_pred = model(x)\n",
    "    loss = jnp.mean((y - y_pred) ** 2)\n",
    "    return {'loss': loss}\n",
    "\n",
    "\n",
    "print(\"Training and test functions defined and JIT-compiled\")"
   ]
  },
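  {
   "cell_type": "markdown",
   "id": "b2d4f6081a3c5e7f",
   "metadata": {},
   "source": [
    "`train_step` combines two JAX idioms. `jax.grad(..., has_aux=True)` differentiates only the loss and passes the updated `Count` states back out as auxiliary data. `jax.tree.map` then applies the SGD update to every leaf of the parameter pytree. The sketch below isolates both idioms on a hypothetical linear model in plain JAX. An integer counter stands in for the `Count` states and is returned from the jitted function, because that is the only way a change made under `jax.jit` reaches the caller. `loss_fn` and `sgd_step` here are illustrative names, not brainstate APIs.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_fn(params, count, x, y):\n",
    "    \"\"\"MSE of a linear model; also returns an incremented call counter as auxiliary output.\"\"\"\n",
    "    y_pred = x @ params['w'] + params['b']\n",
    "    loss = jnp.mean((y - y_pred) ** 2)\n",
    "    return loss, count + 1\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def sgd_step(params, count, x, y):\n",
    "    # has_aux=True: differentiate only the first output (the loss) w.r.t. params;\n",
    "    # the auxiliary output (the new counter) is returned as-is.\n",
    "    grads, count = jax.grad(loss_fn, has_aux=True)(params, count, x, y)\n",
    "    # Apply a plain SGD update leaf-by-leaf over the parameter pytree.\n",
    "    params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grads)\n",
    "    return params, count\n",
    "\n",
    "\n",
    "params = {'w': jnp.zeros((1, 1)), 'b': jnp.zeros((1,))}\n",
    "count = jnp.array(0)\n",
    "x = jnp.linspace(0, 1, 8)[:, None]\n",
    "y = 2.0 * x + 1.0\n",
    "for _ in range(3):\n",
    "    params, count = sgd_step(params, count, x, y)\n",
    "print(int(count))  # 3\n",
    "```\n",
    "\n",
    "`sgd_step` has the same structure as `train_step`: gradients flow only through `params`, and the counter travels alongside as auxiliary data."
   ]
  },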
  {
   "cell_type": "markdown",
   "id": "2530077e5dba7459",
   "metadata": {},
   "source": [
    "### Training the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "c63154cbf63b9f43",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:21.259193Z",
     "start_time": "2025-10-10T15:54:11.579934Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training for 10000 steps...\n",
      "\n",
      "Step     0: loss = 0.168351\n",
      "Step  1000: loss = 0.011733\n",
      "Step  2000: loss = 0.011768\n",
      "Step  3000: loss = 0.011872\n",
      "Step  4000: loss = 0.011736\n",
      "Step  5000: loss = 0.011749\n",
      "Step  6000: loss = 0.011963\n",
      "Step  7000: loss = 0.011719\n",
      "Step  8000: loss = 0.011735\n",
      "Step  9000: loss = 0.011739\n",
      "\n",
      "Training complete!\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "total_steps = 10_000\n",
    "print(f\"Training for {total_steps} steps...\\n\")\n",
    "\n",
    "for step, batch in enumerate(dataset(32)):\n",
    "    params_, counts_ = train_step(params_, counts_, batch)\n",
    "\n",
    "    if step % 1000 == 0:\n",
    "        logs = test_step(params_, counts_, (X, Y))\n",
    "        print(f\"Step {step:5d}: loss = {logs['loss']:.6f}\")\n",
    "\n",
    "    if step >= total_steps - 1:\n",
    "        break\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
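  {
   "cell_type": "markdown",
   "id": "c3e5f7193b4d6f80",
   "metadata": {},
   "source": [
    "The loop above draws batches from an infinite generator and exits with an explicit `break`. An equivalent pattern bounds the generator with `itertools.islice`, so the loop runs exactly `total_steps` iterations. The sketch below uses a hypothetical stand-in generator, `batches`, in place of `dataset(32)`.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "\n",
    "def batches():\n",
    "    \"\"\"Hypothetical stand-in for the infinite dataset(32) generator.\"\"\"\n",
    "    step = 0\n",
    "    while True:\n",
    "        yield step\n",
    "        step += 1\n",
    "\n",
    "\n",
    "total_steps = 10_000\n",
    "seen = 0\n",
    "# islice stops after exactly total_steps items, so no manual break is needed.\n",
    "for step, batch in enumerate(itertools.islice(batches(), total_steps)):\n",
    "    seen += 1\n",
    "    if step % 1000 == 0:\n",
    "        pass  # evaluation and logging would go here\n",
    "\n",
    "print(seen)  # 10000\n",
    "```"
   ]
  },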
  {
   "cell_type": "markdown",
   "id": "694aba2f960c2b4f",
   "metadata": {},
   "source": [
    "### Verifying the Trained Model\n",
    "\n",
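    "Let's restore the trained model and check how many times it was called. `test_step` does not return its updated `Count` states, so the count includes only the calls made inside `train_step`, one per training step:"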
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "5d495f172fcca5e0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:21.563382Z",
     "start_time": "2025-10-10T15:54:21.299626Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model called 10000 times during training\n",
      "Expected: 10000 times\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1IAAAIoCAYAAABj6NoUAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAe4pJREFUeJzt3QeYFFX29/HfMCSRnCWJAQmCoigsoGJAcUUFw8qKIrIYAUUQAxgwAYqIGAh/dQ1rRtf0KosBs6AoZgVEETERBCTHmX6fU2VN9wyTeuhQ1fX9PE/b07dTdc9l7NPn3HOzIpFIRAAAAACAUitX+psCAAAAAAyBFAAAAADEiUAKAAAAAOJEIAUAAAAAcSKQAgAAAIA4EUgBAAAAQJwIpAAAAAAgTgRSAAAAABAnAikAAAAAiBOBFAAEwNtvv62srKy805IlS3z1eMU599xz857nyCOPlJ8sW7ZMAwYMUOPGjVW+fPm843zhhRfSfWiBYPMmdh7ZvMrk+QIAsQikAKCQwMJOJ598cqG3ffXVV3e6rX34CwP7YFvwtRd2Km1glsgPzc8+++xOx3HvvfcWeftIJKLTTz9dDz/8sH777Tfl5OTsdJvmzZvnPdYNN9wgP4h9z0p78suxA0AmKZ/uAwAAv3rllVe0ePFi7b333vnG77rrrrQdUxD985//VNu2bZ2fmzZtmrTneeihh3YasyBpyJAhhd5+6dKl+uCDD/Iun3jiiTr88MNVrly5vONF8WrXrq3bb7897/I+++wTmPkCALuKQAoAipCbm+tkNCZOnJg39t1332nmzJlpPS6/qFWrlkaNGlXkB2zP8ccf75ySXaJnmcKC5s2bp6+//rrQwOinn37Kd3nSpEkJCQQSYf369apWrVqJgYZn7NixWrNmjfOzBf4XX3xxvuu7dOlS5HNt3LhRu+22mxNAxqt69eoaMWKEEikV8wUAEiICAIi89dZbEfuT6J3KlSvnnNeoUSOyYcOGvNsNGTIk7zbZ2dl5P/fv33+nx/zll18iI0aMiLRt2zay++67RypVqhTZc889I2eddVbko48+KvQ4/vjjj8iFF14YqV+/fqRy5cqRDh06RJ566qmdju/HH3/Md7+cnJzIf/7zn8ixxx4bqVevXqRChQqRunXrRk444YTIK6+8UuLrLfh4RenWrVvefey1lIa9N9597P7moYceyvf8hZ3sGEtr/PjxeferWrVqpFGjRnmXL7/88p1uX9Jzxx5zUadYa9eujYwdOzbSsWPHSPXq1Z33v2nTps7jfP311zs9/+jRo/O9j/Z7HzRoUKRx48bO3Lvzzjsj8bDHKPgeF3W9Pfd7770XOeaYY5xjtbE1a9ZEtm/fHrn22msjf//73yN77723M/fLly8fqV27duSwww6L3H333ZFt27ble1ybN0X9zgq+xj///NP599CsWTPn/dlrr70iY8aMieTm5pY4Xwr7vdkceu211yJHHnmk8+/Lfu/HH398oe+3uf/++51/i/bvsEmTJs68sH/bBd8bACgtAikAKCSw6N27d97PkydPzvuwXK1aNWfsoIMOyvcBrGAg9c4770Rq1apV5Idw+7B8xx135LuPfZht1apVobfv2bNnkYHPpk2bIt27dy/2Q//w4cMzOpBq06ZN3v369u0bGTZsWN7lBg0aOEFCsgKp7777LtK8efMib2cf3KdPn57v+WODDAt4C/7ekxlIde7cOd+XAF4gtX79+hJfs82zHTt2xB1I1alTJ9K6detCH/O6664rUyDVtWvXSFZW1k6PZ8+1YsWKfPe7+uqrC31uC3xtfhBIASgLSvsAoBBnnXWW3n//ff3xxx9Oed+gQYOcNThWcmUuvfTSIhfw//nnnzr11FPzSq2sbMo6w1kZ1JNPPumUlFnZoJVEdejQQd26dXNud+2112rBggV5j2PjdrJ1PLZeqyjDhg3TG2+84fxcsWJFp/SrRYsW+uqrr/TM
M884TRWsPNGeq2/fvgl7j9atW6cJEybsNG7rWvr06VPsfQ899FBnbc3TTz+tTz75pNCStNKW2c2dO1fffvtt3mV7/Q0aNNCdd97pXF6+fLn+97//6aSTTsq7jT33Dz/8oGnTpuWNWZmilSsaK52zU2zJ3LHHHqvjjjsu33Nbg4pTTjklr7lGvXr1nPfYShut1HD27NnaunWrzjnnHOf9L7jeztgcs1P37t3VtWtXrVy50jn+ZJkzZ46qVKmis88+2+lW+Nlnnyk7O9tpSmHH97e//c0Zt/di+/btzpy0ebRjxw5nnv33v//VGWecEddzrlq1ynkf7X1o1KiRHnjgAec1e2sObe7b3I2H/bto1aqV82/t888/14wZM/Ke69///reuvvpq5/LHH3+s2267Le9+9evXV//+/Z1/yw8++KC2bdsW1/MCQJ4yhV8AkGEKZmj+3//7f5FRo0blXZ45c2Zk3333dX620rktW7YUmZGybELsY82YMSPvuuXLlzslSN51vXr1csYtYxI7fsQRRzjlesZKn4477rhCM0irVq1yyq+88QcffDDf67JyMe86y6IlMiNV1KlgFqG4DENx15XWxRdfnPcYlgXcunWrM77PPvvkjZ966qk73a8070FJZV8vvvhi3vWW5bHslMcyN+3atcu73rJkhWVr7HTZZZeV6bWXJSNlxzlv3rwiH8vmqL2uKVOmRCZMmBC5/fbbnZI47/7/+te/4s5I2WnSpEl5173wwgv5rvvyyy/jzkhZ6eS6devyrrP5Xdjv20plYzPBsaV/BTOjZKQAxIP25wBQBMtC2d5CZuDAgfr++++dny+44AJVqlSp2G/8PZah+Pvf/57v2/DYy95t7Vv/DRs25I2feeaZeYv/LVNgGbLCfPTRR06mwPOvf/0rX9vrKVOm5F1n39pv2rRJmcSyPU899VTeZctOeJmN2KzYyy+/7GQqEi22659lp/bbb7+8997mjmUFPZadKoplZFLF5t/BBx+80/jmzZudzOkee+yhXr16OfPfsqZXXHGF07DD88svv8T9nJbxuvDCC/Mut2zZMt/1XtYvHv369cvXkMPe+8Iez8t4GssK7r///nmXLSvn/RsHgHgRSAFAEay86bTTTnN+/vXXX53zChUqOB8wi7N69eq8nwsr0Yod8z7wWTlgLAu4irpPUc9VEvtCP5HBxJ577uk8ZsFTIjZlLS3bODf2Q7OV9cUGox4r33r88ccT/vzxvP9WsleYunXrqk6dOkoVK4crzMiRI5128VZ2WlLwGi+bv5UrV867XPCLiJKeszC2x1es2MeMfbzYf1sNGzbMdx8Louz9B4Cy4GsYACjG0KFDnXU8HgusbI1HcWJbf9v6nIJix7w1OTVr1sx3mxUrVhR5n6Key1svVdzx1ahRQ5nEPvjHsnVMxd3W1rYlUuz7b4HCzTffHPd7v/vuuyuVinq+2Hnerl07Zz2fZY4s2LA1UbZOqqzsC4hYlrHbVaV9zNh/WwX/XVk211urBQDxIpACgGJ07tzZaYxgC9ZNaT6I254906dPz8tCWKMDr5zPPsjZ5djbelmCqlWr5pX32YdYKyG08j7L8hSVTenUqZNTNmVlZd6Hy8L29bFmCAsXLnQaXvhJ7IfheMsOf/vtN73++uulvr01Vfjyyy91wAEHJOz4Yvdn2rJli1M2Flu6GVuCWVw5qB/EZiuPOuqovBI4m8OpzDIm2iGHHOLsJ+aV+VmJ7r777utcfuyxx/KVxgJAPAikAKAE//nPf5w1TPah2gKrklhHMMtMeB9MLYtla5csiHniiSfygiX7Bv2yyy5zfrZv/a2jmbem6d1339XRRx+d17Vv1qxZRWZE7LHvv/9+5/L48eOdD4v2Ad8yJFaS+OGHHzpBhB1Xjx495LfySY992LUMoHX9s3VOJQWt9nvxAkhjXfmsG10sK/GKzaRY50Wvm19pj89bG2cZLevAaOtyrKOgdevr2bOnWrdurfnz5zu3
6d27t7NOq02bNs5zW2dA+11ap0Z77vbt28uvLPvkrYWy+WRBvL2fjz76aJFliUFg6xvvu+8+5wsJmy9HHHGE82/Nuk5adz8AKCsCKQAogWWLilpXUlQp0XPPPecs2Lf1GbaIf/LkyfluYx9SLejxWp+bW265xWkv/d133zmX33nnHedkjjzyyCKzApMmTdKPP/6Y1wL9zTffdE5BYIGHBZ0WdNjp7rvvzis/KymQeuSRR/J+tnbvL730UqG3sw/O7733nvOzBbLW+ry0DQYsKPJ+BxZM3HTTTc7PFkBZIGWPY+u0LEC1rJ+txYptfhEk11xzTd66MpuzNq+MNZ+wksl4sn9+Yhnlq666Srfeeqtz+ffff89rh25NN+zLBq901mvwAgClwV8MAEgC+/Bu3+5ffvnlTomUfbNvWZZmzZo5Hfisg5tdF8vWS9neVeeff77T7c9KwQ488EAnkzF69Ogin8se2/YssiDhhBNOcBb22wd8y55Y5uT00093vpG3vaT8xjI0VsZoH2hjmxGUxLJssXtuWbe5osReZ6WVxe3JVdDgwYOd/cJsf6Wigi/rFmclgxYYWybQfo9WbmmZKysjPO+88/T8888ndA+vZLBGHVaSanPOsq/WAMM6H9p7XdK6QL8bN26c82/A/i3av0MLDocMGeJkei0z5Sm4VhEAipNlPdCLvQUAAECAWYbNvlgoyNrix27UbGW0seveAKA4lPYBAICMNmrUKGcfNQua9tprL6fBhK0ljN1nzZpSlGYNJAB4CKQAAEBG8/Y3K2qdoXXxs6YkiWjLDiA8CKQAAEBGs6Ym1lDC2tBb0xBrVW/rodq2bes0DbF1bAU7PgJASVgjBQAAAABxomsfAAAAAMSJQAoAAAAA4hT6NVK2AeRvv/3m7PfBIlMAAAAgvCKRiNavX+/sn1fSJt2hD6QsiGratGm6DwMAAACAT/z8889q0qRJsbcJfSBlmSjvzapevbovMmTWUahevXolRsEA8wXxYs4gXswZxIs5gyDPmXXr1jlJFi9GKE7oAymvnM+CKL8EUtaW1Y4l3RMJ/sd8QbyYM4gXcwbxYs4gE+ZMaZb8+ONIAQAAACBACKQAAAAAIE4EUgAAAAAQp9CvkSptG8QdO3YoJycnJTWi27dvd+pE/VIjCv/Izs5W+fLladUPAACQZgRSJdi2bZt+//13bdq0KWVBmwVT1r+eD8soTJUqVbTHHnuoYsWK6T4UAACA0CKQKoYFND/++KOTBbBNueyDa7KDGy/7RdYBhc0NC+ytPajNyxYtWqT7kAAAAEKLQKoY9qHVginrJW9ZgFQgkEJxdtttN1WoUEE//fSTMz/JSgEAAKQHi3BKgbVK8BPmIwAAQPrxiQwAAAAA4kQgBQAAAABxIpBCqTVv3lyTJk0q9e3ffvttZ53Xn3/+qVR7+OGHVbNmzZQ/LwAAAMKBQCoDWfBS3OmGG24o0+N+/PHHuuCCC0p9+y5dujit42vUqKFMDBQBAAAQXnTty0AWvHiefvppXX/99Vq4cGHeWNWqVfN1CbSNhq1LYEnq1asX13FYR7mGDRvGdR8AAAAgCMhIpdCqVdKiRe55Mlnw4p0sG2RZKO/yggULVK1aNf3vf/9Thw4dVKlSJb3//vv64Ycf1KtXLzVo0MAJtA499FC98cYbxWZs7HEfeOABnXLKKU57eNvX6KWXXiqytM8rt3v11VfVunVr53mOP/74fIGftX6/9NJLndvVqVNHV111lfr376/evXsX+5rtsZs1a+Ychx3PqgJvckmv78gjj3Raig8bNiwvc2fscc4880w1btzYeex27drpySefLPPvBgAAAJmBQCoFNm+WHnlEGjlSGj3aPbfLNp4uV199tW699VbNnz9fBxxwgDZs2KATTjhBs2bN0meffeYEOCeddJKWLl1a7OPceOONOuOMM/Tll1869z/rrLO0evXqIm+/adMmTZgwQY8++qjeffdd5/FHjBiRd/1tt92m
xx9/XA899JA++OADrVu3Ti+88EKxx/DRRx9p4MCBGjJkiD7//HMdddRRuuWWW/LdpqTX99xzz6lJkya66aabnMDOC+62bNniBJyvvPKKvv76a6e0sV+/fpo7d26p3mcAAICgf0mPwlHalwLTp0svvig1aCA1ayatXeteNv37p+eYLGA49thj8y7Xrl1bBx54YN7lm2++Wc8//7yTYbIApSjnnnuuk7ExY8eO1d133+0EGRaoFGb79u2aNm2a9tlnH+eyPbYdi+eee+7RyJEjnaySuffeezVjxoxiX8tdd93lPN+VV17pXN5vv/00e/ZszZw5M+829tqKe332+rOzs51sXWw5omWiYgO9Sy65xMmoTZ8+XR07diz2uAAAAJLBvoy3z5cffGBfFtuyDalrV+mMM6Tddkv30YUHGakks28IbJJbEGWnypWjP9t4ur5BOOSQQ3bK2FjAYCV3VlZn5W+WrSopI2XZLM/uu++u6tWra8WKFUXe3srjvCDK7LHHHnm3X7t2rZYvX54vQLHgxjJCxbHj7NSpU76xzp07J+T12foxC7qspM+CLbufBVIl3Q8AACDZX9JnZ7tf0tu5XbZxpA4ZqSSzKjf7psAmeSxrZGefxe36OnVSf1wW9MSyIOP11193yu723Xdf7bbbbjr99NO1bdu2Yh+nQoUK+S7b2qLc3Ny4bm8NL5KtrK/v9ttvdzJetjbMgil73y677LIS7wcAAJCKL+mNfVFvbPzEE9Pz2TKMCKSSrHZtN91q5XzeJDd22cbtej+w9UhWpueV1FkGZ8mSJSk9BmuMYc0grM36EUcckZcR+vTTT9W+ffsi72dZJlsnFevDDz+M+/VZl0F7voL3syYVZ599tnPZgsTvvvtObdq02cVXCwAAkDlf0ocRpX1JZhPZalaXL3dPW7ZEf7Zxv0x067hnDResWcMXX3yhvn37FptZShZbgzRu3Di9+OKLTsv2oUOHas2aNXld9ApjXf5sPZRlmxYtWuSsq4pdH1Xa12ddCa0Bxq+//qo//vgj736WybI1V1YKeOGFFzrlhwAAAOn+kj6W376kDwMCqRSwhX+9ell2xf2mwM7tso37xcSJE1WrVi1nE13rZtejRw8dfPDBKT8Oa3duzSvOOeccZ52TrUmyY6kcm84r4G9/+5vuv/9+pwTPGkq89tpruvbaa+N+fdb0wrJUtobL2zPLHsduZ7e3FunWiKKkVuwAAABh/5I+DLIiqVig4mPWXttKyqzRgTVKiGWtr3/88UfttddexX6Qj6em1dKt9k1BUZPcfh22l5JtkFtcFiYsLGtkpXvWYt2aPiD/vLRyRGvWUb9+fZUrx/ciKN2/KeYM4sGcQbyYM8mXaV37cn00Z4qLDQpijVQKWfDEtwTFs01xLaPUrVs3bd261SnTs6DBSvEAAADgBku2hY41lijpS3okD4EUfMW+hXj44YedLnuWnWvbtq3eeOMNJysFAACAKL6kTy8CKfhK06ZNnU55AAAAgJ9RuAoAAAAAcSKQAgAAAIAgB1K2h4+1pm7UqJHTse6FF14o8T5vv/220566UqVK2nfffZ31NQAAAAAQmkBq48aNzj5AkydPLtXtrZtbz549ddRRRzkbrV522WU677zz9Oqrryb9WAEAAACEl6+aTfz97393TqU1bdo0Zy+dO+64w7lsnd3ef/993Xnnnc4GqgAAAAASpzT7ooaFrwKpeM2ZM0fdu3fPN2YBlGWmimJ7E9kpdtMtbyMwO8Wyy9aC2zulivdcId8rGUXw5qM3Z72fgdJgziBezBnEizmTuZsAP/usNHt2dBPgLl2k00/f9U2A/TRn4jmGQAdSy5YtU4MGDfKN2WULjjZv3qzdCvmtjhs3TjfeeONO4ytXrtSWLVvyjW3fvt15M3fs2OGcUsEmUU5OjvOzrRPzs3feeUfHHnussxN1zZo1S3WfFi1a6JJLLtGll14qP7BA3MpJvaxmIo4v2a/R5qLNy1WrVik7
O9vZedvmTbp3Akcw2NxhziAezBnEizmTmWbNkubNs8/a0t5725Ic93L58tIxx2TOnFm/fn04AqmyGDlypIYPH5532YIu27uoXr16ql69er7bWmBlb2b58uWdUypVqFBhl+4/YMAAPfLII7rgggucEshYgwcP1tSpU9W/f3899NBDZX4O+xBv4n1/7B9Iqt/PoliwaifveObOnavdd9+9VMdnjU2GDRumNWvW5BuP5zHKwh7X3sM6deqoYsWKzvHb/E33Hx4Eg/3PijmDeDBnEC/mTOaxUr5337XPfvY5zs1OlStnX+6647aixkr9MmHOVK5cudS39cen2TJq2LChli9fnm/MLltAVFg2ylh3PzsVZL+0gr84u+x90E5Vdsgice+5dvU5LUB8+umnNWnSpLz3w4LDJ598Us2aNdvl54g9zngeJ9Hv57Zt25yAoqxij6d+/fpx3S/23BPPY+zK8XpzNvZnoDSYM4gXcwbxYs5kFvvO2BI1f318zFO9urR0qXt93bqZMWfief5Az+7OnTtrluUZY7z++uvOOOS0hbdg6rnnnssbs58tiDrooIPy3dbWjVkpmgUBFokfdthh+vjjj/PdZsaMGdpvv/2coMw6JS5ZsmSn57RmH4cffrhzG3tue0zrxlha5557rnr37u2UX3pZwosuusgJljxHHnmkhgwZ4qyFq1u3bl5jka+//tppVlK1alWnxLNfv376448/8u5nx3HOOec41++xxx555Xyxmjdv7gSenj///FMXXnih83j2vrRt21Yvv/yy03bfsn6WhvYCmxtuuKHQx1i6dKl69erlPK+9njPOOCPfFwB2v/bt2+vRRx917lujRg3985//jCu1DAAAkCyWbbI1UWvX5h9fu9Yd35VsVJD5KpDasGGD08bcTl57c/vZPoh6ZXn2QdhjH7AXL16sK6+8UgsWLNCUKVM0ffp0p9wqqQ45RGrSJDmnpk1Vfq+9nPOdrrPnjdO//vWvfOV7Dz74oBMAFGTv4X//+1+nHPDTTz919uSyAGW15XIl/fzzzzr11FOdfb7sd2Jt5q+++up8j/HDDz/o+OOP12mnnaYvv/zSyYZZYGVBTzwsOJ4/f74TrFj2zIK/guva7DgtC/XBBx84pYsW8Bx99NFOgPjJJ59o5syZTrBiQYvniiuucNZ1vfjii3rttdecx7fXWlya2QIze47HHntM3377rW699VanpLFLly5OsGSB0e+//+6cRowYUehjWBBl76M9twX6Nmf79Omz03tn+6ZZkGYnu609FwAAQLpZd76uXa3yyz1ZW4Hlf/1s46Ht3hfxkbfeesva1O106t+/v3O9nXfr1m2n+7Rv3z5SsWLFyN577x156KGH4nrOtWvXOs9h5wVt3rw58u233zrn+TRubP30Un+y5y0le6969eoVWbFiRaRSpUqRJUuWOKfKlStHVq5c6Vznva8bNmyIVKhQIfL444/n3X/btm2RRo0aRcaPH+9cHjlyZKRNmzb5nuOqq65y3rs1a9Y4lwcOHBi54IIL8t3mvffei5QrVy7vPdxzzz0jd955Z7HHXbt27cjGjRvzxqZOnRqpWrVqJCcnx7lsc+Cggw7Kd7+bb745ctxxx+Ub+/nnn53jW7hwYWT9+vXOHJk+fXre9atWrYrstttukaFDh+aNxR7fq6++6hy73b8wNtdq1Kix03jsY7z22muR7OzsyNKlS/Ou/+abb5zjmjt3rnN59OjRkSpVqkTWrVuXd5srrrgi0qlTp0KfN3Ze2nvy+++/5703QEmYM4gXcwbxYs5kpk2bIpGHH45Ezj8/EjnzTPfcLtt4Js2Z4mKDgny1RspKtopr+W2L+wu7z2effaaUatgwaQ8d++qzEvC8Vh5nmxbbe2fvrf1s5XAFsyHWobCrfaUQ0+yiY8eOTmbI2HmnTp3y3a9gCeUXX3zhZKIef/zx6Ov5q5WlZRdtn6/SsC56VapUyfc8lq20rNiee+7pjHXo0GGn
537rrbec8rmC7PVZF0crD4x9DbVr11bLli2LPA7LvDVp0sQpZywre9+sxNFOnjZt2jhdDu26Qw891Bmzkr5q1arl3cZKD60bIgAAgB/Ycvv+/aUTT2QfKY+vAqnA+OST5D12JOK0t3Y6viWoIYOV93nldZMnT1ayWLBj64kKa/vtNbdIFOuKV/C5rezwtttu2+m2FpR8//33cT9HUQ1LUtGl0dZc+WEvBQAAgFgWPIU9gPLlGikkh61bsmyMZZ28xgyx9tlnn7z1Rh67rTWbsOyJsWyStfWO9eGHH+7U3MLWEdn6qoKneLrqWXbJMkixz2OZptisTkH23N98842T2Sn43BZ02Wu0YOWjjz7Ku4+1Lf/uu++KfMwDDjhAv/zyS5G3sdfk7flVFHvfLJNmJ4+9R7amy3tvAQAAEDwEUiFgzRGsjMw+wHt7P8WyQOPiiy92mjFYkwa73fnnn69NmzZp4MCBeY09Fi1a5Nxm4cKFeuKJJ3Yqtbzqqqs0e/ZsJ/tlZXF2e2vsEG+zCQv67HntOKxT4OjRo53HKK4dpe2NZQ0dzjzzTCcAtHK+V1991WmsYcGOBWL2mHb8b775ptPhzzoEFveY3bp10xFHHOE0z7AmEVae+L///c95j4wFbZYJs+YY1h3Q3q/CNvxt166dzjrrLKexhQWj1jDFHvuQMjQPAQAAgD8QSIWEdZcruOFwLOsQZwGDtQy37I6VwlkgUqtWrbzSPOvqZ53lbA2TdcobO3bsThkc6zZnGRxrgW4d9K6//no1atQormM95phj1KJFCyeIse52J598cl5r8aLYc1hGzYKm4447zglerD26rUXygqXbb7/dOS4rAbQAx1q8F1xrVZC9ZlvHZAGaZZCsu6GXhbLOfRZg2jHaWrTx48fvdH8r0bNg0t5Hez32vHvvvbfT0RAAAADBlWUdJxRi69atc/btsf2ACgYatnmtZSH22muvuHY53hWRmDVSqdoE2E8sS2RlbxawoXCx89LKC60phe3/le4N7BAMtvaOOYN4MGcQL+YMgjxniosNCmJ2AwAAAECcCKQAAAAAIE60P4evFLZXGAAAAOA3ZKQAAAAAIE4EUqUQ8n4c8BnmIwAAQPoRSBXDNnA1he0PBKSLNx+9+QkAAIDUY41UMWzzWtuHyNoxmipVqiS9JXnY25+j+LlhQZTNR5uXNj+tXSgAAABSj0CqBA0bNnTOvWAqFR+W7cOx9dAnkEJhLIjy5iUAAADSg0CqBBbM7LHHHs4GYdu3b0/681kQtWrVKtWpUyftG5LBf6yczzJRAAAASC8CqVKyD6+p+ABrgZR9WK5cuTKBFAAAAOBTfFIHAAAAgDgRSAEAAABAnAikAAAAACBOBFIAAAAAECcCKQAAAACIE4EUAAAAAMSJQAoAAAAA4kQgBQAAAABxIpACAAAAgDgRSAEAAABAnAikAAAAAKTN6tXSb7+550FSPt0HAAAAACB8Nm+Wpk+XZs+WqlWT1q+XunSRzjhD2m03+R4ZKQAAAAApN3269OKLUna2VK+ee26XbTwICKQAAAAApNSqVdIHH0gNGkj160sVKrjndtnG7Xq/I5ACAAAAkFKrV0sbNkg1auQft8s2HoT1UgRSAAAAAFKqdm2palVp7dr843bZxu16vyOQAgAAAJBSdepIXbtKy5dLK1ZI27e753bZxu16v6NrHwAAAICUO+MM99y69q1cKeXkSL16Rcf9jkAKAAAAQMrttpvUv7/Us6e0dEmumjWX6tZVYFDaBwAAACA9tm1T7QfGq92/jlft3bcqSAikAAAAAKTeu+9KBx2kciNHqsJXX0l33KEgIZACAAAAkDorV0oDBkjduknffusMRcqVU5b1PQ8QAikAAAAAyZebK91/v9SypfTww3nDkY4dtWrmTEXGjlWQEEgBAAAAabRqlbRokXuesb78UjrsMOmCC6Q1a9yxmjWlqVMVef997WjXTkFD1z4AAAAghgU0q1e7m8Imcz+j
zZul6dOlDz6QrKrNNqK1PZSs/bd1tMsI69dLN9wg3XWX29/cc/bZ0oQJUoMGbqYqgAikAAAAgDQENvZcL77oxhLNmklr17qXjbUFD7RIRHruOWnoUOnXX6PjVtY3dap01FEKOkr7AAAAgJjAJjvbDWzs3C7beDKyXhawWRBlp8qVoz/beKDL/BYvlk48UTr99GgQZS/wllukL77IiCDKEEgBAAAg9FId2FjpoGW9atTIP26XbdyuD5ytW6UxY6T995dmzIiO//3v0jffSNdcI1WqpExBIAUAAIDQS3VgY+uvrHTQyvli2WUbt+sD5a23pPbtpWuvlbZscccaN5aefVZ65RVp772VaQikAAAAEHqpDmysiYWtv1q+3D1Z7OH9bOPJbHKRUMuXS/36SUcfLS1Y4I5ZTeSwYdL8+dJpp0lZWcpEBFIAAAAIvXQENtbEolcvt5nd0qXuuV22cd/LzZWmTZNatZIeeyw6/re/SfPmSRMnStWqKZPRtQ8AAAD4K7AxtibKAhvLRCUzsLFOgNadz/oypKLdesJ89pl00UXS3LnRsVq1pFtvlc47TyoXjlwNgRQAAACQxsDGniMQAdS6ddL110v33JN/7yd708aPl+rXV5gQSAEAAABBDGxSuSfUM8+4655++y063rq1uydUt24Ko3Dk3QAAAJBxrCX5okX+2XPJb8eTED/84LYv79MnGkRZb/ixY6XPPw9tEGXISAEAACBQNm92N8m1tUzWmtzWMllDCFvLZOV5YT+ehO0JZeV6ti+U/ezp2dMt7dtrL4UdGSkAAAAEigUtL77odtlu1sw9t8s2zvEkwKxZ0gEHuOuhvCCqSRPpueek//f/CKL+QiAFAACAwLCyOcv8NGjgnqzKzPvZxlNdVue349kly5ZJZ50lde8uffedO2ZR4eWXu3tCnXJKxu4JVRYEUgAAAAgM66Zn5XM1auQft8s2bteH+XjKxDawmjzZ3RPqiSei4507S59+Kk2Y4NYrIh8CKQAAAASGtSS3z/Rr1+Yft8s2bteH+XjiZpvn2ia6Q4ZEX4TtCXXffdL777slfigUgRQAAAACw9qSWyOH5cvd05Yt0Z9tPNVty/12PKVmQdOll0odO0qffBIdP/dcaeFC6fzzQ7OxblnRtQ8AAACBYt3wjK1BWrrUzfz06hUdD/vxlLgnlHXBuOwyd02Up00bd0+oI45I59EFCoEUAAAAAsVaivfvL514orsGycrn0pn58dvxFMk2uRo8WHr99fwHP3q0u9luxYrpPLrAIZACAABAIFmw4qeAxW/Hk8fqDW+7TRo3Lv+eUCedJN19t9S8eTqPLrAIpAAAAIBMZdmnQYOk77+PjjVt6m6qa/WHKDNWkAEAAACZ5vffpTPPlI47LhpElS8vXXmluycUQdQuIyMFAAAAZArbE8qaRlxzjbRuXXT8sMPc8bZt03l0GYVACgAAAMgE1sb8oovcvaE8tmjr9tvdbhi0M08o3k0AAAAgyP780+3GZ3tCxQZRAwe6e0INGEAQlQRkpAAAAIAgsj2hnnxSGj7c3QHY066dW8ZnOwIjaXwXmk6ePFnNmzdX5cqV1alTJ82dO7fY20+aNEktW7bUbrvtpqZNm2rYsGHaYi0eAQAAgExlmaZjj5XOOisaRO2+uzRhgpuVSkMQtWqVu1WVncdzXVD5KiP19NNPa/jw4Zo2bZoTRFmQ1KNHDy1cuFD169ff6fZPPPGErr76aj344IPq0qWLvvvuO5177rnKysrSxIkT0/IaAAAAgKTZvNndD8r2hdq2LTp+yinSXXe5rc3TcEjTp0sffCBt2CBVrerGcWec4V5f1HW2F3CQ+SqQsuDn/PPP1wCr45ScgOqVV15xAiULmAqaPXu2unbtqr59+zqXLZN15pln6qOPPkr5sQMAAABJNXOmNGSI9MMP0bE995TuvVc68cS0Hdb06dKLL0oNGkjNmklr17qXPUVdZ/0vgsw3gdS2bds0b948jRw5
Mm+sXLly6t69u+bMmVPofSwL9dhjjznlfx07dtTixYs1Y8YM9evXr8jn2bp1q3PyrPurLWRubq5zSjc7hkgk4otjgf8xXxAv5gzixZxBvJgzSfDrr8oaPlxZzz6bNxSxPaFGjFDE2pxXqWJvfFoObfVqS25IDRtKXgFZ5cpSVpb05pvu5cKus/v07CnVru2vORPPMfgmkPrjjz+Uk5OjBhauxrDLCxYsKPQ+lomy+x122GHOm79jxw5ddNFFGjVqVJHPM27cON144407ja9cudIXa6vsl7d27Vrn9VggCRSH+YJ4MWcQL+YM4sWcSaAdO1TloYdUdfx4ZVld3F+2/e1vWnvrrcpp2dKtl4u5LtV++02qVk2qV0+qUCE6bmV7S5e6P1smquB1K1e61+/Y4a85s379+uAFUmXx9ttva+zYsZoyZYqzpur777/X0KFDdfPNN+u6664r9D6W8bJ1WLEZKWtSUa9ePVWvXl3pZhPJ1njZ8aR7IsH/mC+IF3MG8WLOIF7MmQT56CNlDRqkrM8/zxuK1K2ryPjxKn/OOapjaR0fKF/egg9p06Zo1smsWCFt3Oj+/NNPO19n+wZbgOVlpPwyZ6zhXeACqbp16yo7O1vLY1s3ypqQLFdDywcWwoIlK+M777zznMvt2rXTxo0bdcEFF+iaa64p9BdRqVIl51SQ3TbdvziPTSQ/HQ/8jfmCeDFnEC/mDOLFnNkFa9ZIVl31f//ntjf3XHCBssaNU5ZFHrvIOudZSZ49lO3Xuyvq1rXlNu66JzvcGjXcdVD2kb5XL/c2RV1n9/XbnInn+X0TSFWsWFEdOnTQrFmz1Lt3b2fMolO7PMQW1RVi06ZNO71YC8aMpQYBAACAQLDPro8/Ll1+uZuy8RxwgHVgkzp3Tmp3vV3poHfGX9357HGtXM8e1wIlb7yk64LKN4GUsZK7/v3765BDDnGaR1j7c8sweV38zjnnHDVu3NhZ52ROOukkp9PfQQcdlFfaZ1kqG/cCKgAAAMDXrB/AoEHSW29FxyziuOkm6ZJL3Pq5JHfX25UOervt5t7fGgcWlukq7rog81Ug1adPH6fpw/XXX69ly5apffv2mjlzZl4DiqVLl+bLQF177bVOGtDOf/31V6eu0oKoMWPGpPFVAAAAAKVMEdnn1vHjpe3bo+OnnSZNmiQ1aZKwp7JyPssK2cdqr7ebtxzIxi3Q2dUAp06doh+juOuCyleBlLEyvqJK+ay5RKzy5ctr9OjRzgkAAAAIjBkz3D2hfvwxOrbXXu6eUCeckPCns2yQlfNZJiqWrVuykju7PtMCnWRjBSAAAACQKr/8Ip1+uruJkhdEWW9wazDx9ddJCaKMldRZtaCV88WyyzaegB4WoUMgBQAAACSbbZh0551S69bSf/8bHT/ySOmLL9wSP9tYN0ks22SNJaxjnp1s+1TvZxsnG5UBpX0AAABARpkzR7r4Yjdg8tgOtnfcIZ19tvX+TslhlKa7HkqPQAoAAABIBlt4NHKkdN990TELmi64QLIu1LVqpfRwSuquh/gQSAEAAACJ3hPqP/+RrrhCWrkyOt6+vbsnVKdO6Ty6jOyglw6skQIAAAAS5dtvpaOOks49NxpEWQ2drY/6+OO0B1FIHDJSAAAAwK7atEm65Rbp9tvdxhKef/zDDaIaN07n0SEJCKQAAACAXfHyy+6eUD/9FB3bZx93T6jjj0/nkSGJKO0DAAAAysJa351yinTSSdEgqmJF6brrpK++SkgQtWqVtGiRew5/ISMFAAAAxGP7dumuu6QbbpA2boyOH3OMNHmy1LLlLj/F5s3S9Oluq/ING9xlVrbfk7Uqt+57SD8yUgAAAEBpWWRz8MFuRz4viGrQQHr8cen11xMSRBkLol58UcrOlpo1c8/tso3DHwikAAAAgJJYbd1550mHHSZ9/XV0T6jBg6UFC6S+fRO2sa49lcVrFp/ZqXLl6M82nuwyP8oJS4fSPgAAAKAoubnSww9LV16Z
P7KwrJTtCXXooQl/Stss18r5LBMVq0YNd1mWXZ+MfaAoJ4wPGSkAAACgMJZ56tZNGjgwGkRVry7dfbc0d25SgihTu7YbxKxdm3/cLtu4XZ8MlBPGh0AKAAAAiGVrn666SjroIOn996Pj//ynW8Z3ySVulJEklm2yTNDy5e5py5bozzaejGxUussJg4hACgAAAPC89JLUpo00fnx0Y91995Vee0168klpjz1SchhWTterl5ST45bz2bldtvFk8MoJrXwwll22cbse+bFGCgAAALB9oC691A2kPLYn1KhRbnbKUjQpZGuS+veXTjzRDWKsnC8ZmajCygljX2qyywmDjEAKAAAA4d4T6s47pRtvlDZtio4fe6y7J1SLFuk8Oid4SmYAFfs8VjZoa6K8TJQFUVZOaJmwVBxD0BBIAQAAIFRsvY9leeovfE81rr5Y+uab6JUNG0qTJrk1dIW0M/fum+wMUTp4ZYO2JsrKCS0TlcxywqAjkAIAAEAoeO29v5j1h05+/0q1+PGh6JXlyrl7Qt18884LhRLQGjwIAViqywmDjkAKAAAAoTD9qVxtvOdB3fTtVaq6Ndo94Y+9DlHdZ6ZJHTqU2BrcuthZa3Are/PK4Cz4yKS9mVJVThh0dO0DAABAxvvzva/UdeThGvTZ+XlB1JZK1fXkYZN1zdEfalXzDklpDc7eTJmLQAoAAACZy9JAI0aoxlEHad/ls/OGv2zXV/cOWagvug7S+k3Zxbb3Lmtr8HTuzWSPvWgR+z8lE6V9AAAAyDyRiPTCC25L819+kdc2YnmN/TTz5Cn6ce9jnMtrl5fc3rusrcG9AMwyUQUDMGvmYNcnuoQuiKWEQUVGCgAAAJllyRLp5JOlU091gihHpUr67JSbdEm3L/Xh7sdoyxa3tbedLNAoLqDxWoN7ty/tfWMDsFjJ3JuJUsLUIZACAABAQlgZ2fffS+vWpekAtm2Txo2T2rSRXn45Ot6jh/T112r1+HXqeWol5eS4GSE7L217b7uN3Tae+5Y1ACurdJYShhGlfQAAAEhYOdnGjdJee7n72Ka0nOydd6SLL5bmz4+O7bGHuyfUP/7h7Allh1LW9t5lbQ2eyr2Z0lFKGGYEUgAAANglsa3BmzZ1t2R66aWSW4MnxIoV0hVXSP/5T3TMDuCSS6SbbpKqV09oe+9475vKvZnKupYLZUNpHwAAABJaTlazZgrKyXJzpfvuk1q1yh9EdewoffKJm4kqJIhKFwueLEuXzIxQqksJw45ACgAAAGVWVGtwi2GKaw2+S774wo0MLrxQWrPGHbPobepUafZs6aCDFFZlWcuFsqG0DwAAAAkvJ7OGEwkvJ1u/Xho9Wrr7bjdC8PTrJ91+u5sGC7lUlhKGHYEUAAAAdrmczNZIGctMWdWdlZNZB/KEfIi3PaH++1/pssukX3+NjltZ35Qp0lFHJeBJMsuurAND6VDaBwAAgISVk/38sxtIWRCVkHKyH36QTjjB7bznBVGW+hozxi3xI4hCmpCRAgAAQMLKybzmEvvu6zbPK7OtW91yPQuYrGuCx4Kqe+6R9t57l48b2BUEUgAAAEgIKyWrVcvtSL5L3nxTGjRIWrgwOta4sbs26pRTnD2hgHSjtA8AAAD+YAurzj5bOuaYaBCVnS0NG+ZutHvqqQRR8A0yUgAAAEgvW1xle0KNHOm2//P87W/StGnSgQem8+iAQhFIAQAAIH0+/VS6+GJp7tzomNUH3nabNHDgLi60ApKHmQkAAIDUs42mhg6VDj00fxBlXSsWLJDOP58gCr5GRgoAAACpY3tCPfOMuyfU779Hx1u3lqZOlbp1S+fRAaVGmA8AAIDU+P576fjjpT59okGU9U4fN076/HOCKAQKGSkAAACUmu0TtXq1VLu22+681HtC2ZqnsWPdnz228dQ992hVteZa/VOcjwmkGYEUAAAASrR5szR9uvTBB9KGDVLVqlLXrtIZZ7hJpSK98Ya7J9SiRdGxJk2cAGrz
cb00/Zms+B8zWQEfEAcCKQAAAJQYgFgQ9eKLUoMGUrNmbpdyu+z1h9jJsmXSiBHSk09Gx7w9oUaPdqKm6Y/E+ZjJDPiAOBFIAQAAoNgAZNMmd9wCHjuZypXdcxu3Cr28oCsnR1UeekhZVsoXuydUly5uM4kDDsgL2Er9mHGIO+ADyohmEwAAACFjQYxV2tl5wQDEkkYWgNi5XbZxy1BZcFWjRv7Hscs2btc75s1TVpcuqj5qlLK8IMpSWw88IL33Xl4QZUr9mHG+rtjgzAIz72cbj329wK4iIwUAABDyrNMxxxSfHbJkkt3WYiNv3NhlG69Tfq106XXS5MnKys2N3mDAAGn8eKlu3Z2OxeKr4h7Tro+XF5xZIFgwOFu61L2e9VJIFDJSAAAAIVFU1unxx4vPDhkLuJYvd09btvz187KIzsp+SrW7tHKaR+ivIGp7y5bKfftt6cEHCw2ijAU0hT7mcne8LAFPbHAWa1eCM6AoZKQAAABCoLg1Sd98I5UvX3x2yNZKGXsMy+7snbNI1y0apEYvvBG9Q5Uqyr3+eq3q21f1Gzcu8ZgKPqY9V69e0fF4ecGZtybKAkF7DRac2eOSjUIiEUgBAACEQHFlbxZstG8vffhh8QGINWs4sfsW6dZbVfv+W5UVuyeU3fCuu6SmTaUVK0p1TNZFz3nMExPXqjzRwRlQFAIpAACAEChpTVLfvtGmDEUGIK+/rjq2J9T330fHLDKzsr6TT3Yvx66RKiULnhKVLUpGcAYUhkAKAAAgBEoqe7M9cosMQH77TRo+XHr66egDWi3g5ZdL110n7b67/CaRwRlQGAIpAACAkChN2Vu+ACQnx+nEp2uvldavj97o8MPdPaH23z+1LwDwEQIpAACAkIir7O3jj6WLLpI+/TQ6Zh34br/dfZCsrFQdNuBLBFIAAAAhU2zZ259/SqNGSdOmSZFIdPy885wmE9TLAS4CKQAAALhB0xNPuGuhYrvutWvnBlW2Ky+APGzICwAAEHYLF0rdu0tnnx0NoqyBxIQJ0rx5BFFAIchIAQCAUG5OS2tsSZs3S2PHSuPHS9u2RcdPOSW6JxSAQhFIAQCAUMUN06e7Xetsc1rrWmctwa1rnTViCJWZM6XBg6XFi6NjzZtL994r9eyZziMDAoHSPgAAEBoWRNk+StnZ7j6ydm6XbTw0fv1V+sc/pL//PRpEVajgNpj45huCKKCUCKQAAEBoyvksE9WggXuqXDn6s43b9Rltxw5p0iSpVSvp2Wej4926SV98IY0ZI1Wpks4jBAKFQAoAAISCrYmycr4aNfKP22Ubt+sz1ocfSoceKg0b5r5YU6+e9Mgj0ltvSa1bp/sIgcAhkAIAAKFgjSVsTdTatfnH7bKN2/UZZ80ad1Nd67r3+efR8QsvlBYskM45h411gTIikAIAAKFg3fmsscTy5e5py5bozzaeUd37bE+oRx+VWraU/u//ohvrHnigNGeOuy9URkaOQOrQtQ8AAISGdecztiZq6VI3E9WrV3Q8I8yfLw0aJL39dnTMXuhNN0mXXCKV5+MfkJEZqcmTJ6t58+aqXLmyOnXqpLlz5xZ7+z///FODBw/WHnvsoUqVKmm//fbTjBkzUna8AAAgOKzFef/+0rhx0o03uud2OSNan2/a5Hbes6xTbBB12mlucGXrowiigITx1b+mp59+WsOHD9e0adOcIGrSpEnq0aOHFi5cqPr16+90+23btunYY491rnv22WfVuHFj/fTTT6pZs2Zajh8AAASDlfFlVCnfK69IQ4ZIS5ZEx/bay90T6oQTFGRsngy/8lUgNXHiRJ1//vkaMGCAc9kCqldeeUUPPvigrr766p1ub+OrV6/W7NmzVcH2P3D2kWue8uMGAABIi19+kYYOlZ57Ljpmn4muvNLNTgW4nTmbJ8PvfBNIWXZp3rx5GjlyZN5YuXLl1L17d82xRZGFeOmll9S5c2entO/FF19UvXr11LdvX1111VXKth32CrF1
61bn5Fm3bp1znpub65zSzY4hEon44ljgf8wXxIs5g3gxZ3y8J9Q99yjr+tHK2rQxbzhy5JGKTJ7s7hVl0vB7S9ScsSDqpZfcfb5s82T7yGaXTb9+iTlW+EOuj/7OxHMMvgmk/vjjD+Xk5KiB/WuJYZcXWHvOQixevFhvvvmmzjrrLGdd1Pfff69BgwZp+/btGj16dKH3GTdunG60ougCVq5cqS3WvscHv7y1a9c6k8kCSaA4zBfEizmDeDFn/KfCJ5+o2hVXqeKCb/PGNuxeV1/0G60mV52mSpWzpBUrAj1nLGhatEjaf3/JW7Fhqzzq1nXHv/9eql49sceN9Mn10d+Z9evXBy+QKuubbuuj7rvvPicD1aFDB/3666+6/fbbiwykLONl67BiM1JNmzZ1slnVffAv0l5TVlaWczzpnkjwP+YL4sWcQbyYMz6yerWyRo5U1gMP5A3lKktz21+gZw4eoyUraunkd9KfrUnEnLFA6scfpaZN3Z8927ZJP//s/lzI8nkEVK6P/s5Yw7vABVJ169Z1gqHltplDDLvcsGHDQu9jnfpsbVRsGV/r1q21bNkyp1SwYsWKO93HOvvZqSD7paX7F+exieSn44G/MV8QL+YM4sWcSTPbA+o//5FGjLASnrzhpXUP0qu9purXJp1UzQKLiu56ohNPTH9Thl2dM3b8u+/ubpYc+7nWLtu4Xc90zCxZPvk7E8/z+2YKWtBjGaVZs2bli07tsq2DKkzXrl2dcr7YWsbvvvvOCbAKC6IAAAAC5dtvpSOPlM49Ny+Iytm9mh45+C5NGzDXCaI8NWq4TRmsw13QhWrzZASWbwIpYyV3999/vx555BHNnz9fF198sTZu3JjXxe+cc87J14zCrreufUOHDnUCKOvwN3bsWKf5BAAAQKD3hLLPPLYn1LvvRsfPOENrP1ygDzpcqjXr8xcWWbbGOttZm/BMYN35bLPknBx382Q7z7jNkxFovintM3369HGaPlx//fVOeV779u01c+bMvAYUS5cuzZdus7VNr776qoYNG6YDDjjA2UfKgirr2gcAABBI/+//SZdcIv30U3Rsn32kKVOk446TxUmWlXnxxWgmyoIoy9ZYoOHXbE28+0F5mydbqSL7SMGPsiLWHiPErNlEjRo1nE4hfmk2sWLFCqeJRrprROF/zBfEizmDeDFnUsjSLpdeGo2QjC1VsL007RSzeZKf91gqOGf8fKzwh1wf/Z2JJzbwVUYKAIBMFu838giJ7dulSZOkG25wS/o83btLtifUfvsFOltjQZTFht5+UJY982JFew1AUBFIAQCQZHwjjyK9/74t+pa+/jo6Zt2KJ06U/vlPa2VW7N0tePJrAOV9eWDz3oIob6tQrwufXzoMAmVFjh4AgBR9I2+7ddg38nZul20cIWUd+AYOlA4/PBpEWdA0ZIi0YIF05pklBlFBYNky+/LA1nHFyqQOgwgvAikAAFL4jbx9G+/9bON2PZLD3ttFi3z2HtuWLf/+t9SqlfTgg9HxDh2kuXOle+7ZOeoIMCs5tAyslfNlcodBhBOlfQAApOAbectExbLPytZbwK6ntCkkpZRffeWW8dmBeWwx+9ix0kUXuanKDOPtBxW0DoNAaZCRAgAgifhGPvV8V0pp0dyVV0oHH5w/iLLyPSvjs/0vMzCI8rAfFDIVGSkAAJKIb+RD3tzAfvG2J9TPP0fHWrRw94SyrnwhEKQOg0A8yEgBAJBkfCMfwuYGtpnuySdLvXtHg6hKldwW519+GZogKpYFTxZDEkQhU5CRAgAgyfhGPj2llF4mKqWllNu2SXfeKd14o7tYy3Pcce6eUPvum+QDAJAqBFIAAKSI3/f8yQRpLaV87z23acS330bH9tjD3Wz3H//IiHbmAKIo7QMAABkl5aWUK1dKAwZIRxwRDaLKlZMuvdRtJmFPTBAFZBwyUgAAIKOkrJTS9oSyvaCsI9+aNdHxQw+Vpk1zu/QByFgEUgAAICMltZTS
GkZYGd+cOdExqyO0PaEuvDCj25kDcFHaBwAAUFrW+m/ECDfbFBtEnXWWW8Y3aBBBFBASZKQAAABKEolIzz8vDR0q/fJLdHy//dw9oY45Jp1HByANyEgBAAAU58cfpZNOkk47LRpE2Z5QN9/slvgRRAGhREYKAACgqD2hJkyQbrkl/55Qxx8v3XuvtM8+6Tw6AGlGIAUAAFDQ22+7653mz4+ONWok3XWXm5minTkQepT2AQAAeFascHunH3VUNIiyPaEuu8y9fPrpZQ6iVq2SFi1yzwEEHxkpAAAA2xPq/vulq6+W/vwzOt6pk7snVPv2ZX5oqwqcPl364AO36V/VqlLXru4+vbbnFYBgIiMFAADC7fPPpS5d3H2hvCCqZk03gJo9e5eCKGNB1Isvul3RmzVzz+2yjQMILgIpAAAQTuvXS8OGSR06SB99lDe8pc85Wvy/hVp1+oVuWd8usDI+y0Q1aOCeKleO/mzjlPkBwUVpHwAACN+eUM8+6657+u23vOHclq30Wu+pem71kdpwd2JK8Favdsv5LBMVq0YNaelS9/o6dXbx9QBICzJSAAAgPH74QTrhBDc68oIoSxONHavHr/hC9313ZEJL8GrXdgOytWvzj9tlG7frAQQTgRQAAMh8W7e6+0G1bSvNnBkdt6Dq22+16oKReu+jigkvwbNsk2W1li93T1u2RH+2cbJRQHARSAEAEEKhasX95pvSgQdK113nRjKmSRPpv/+VXn5Z2muvvBI8K7mLZZdt3K4vK0t+9eol5eS45Xx2bpdtHEBwsUYKAIAQCVUrbkv7XH659Pjj0TGr1xs6VLrhBqlatUJL8CwblcgSPHtfbWuqE090AzJ7LDJRQPCRkQIApEyosiA+FYpW3JbymTpVatkyfxDVubP06afSHXfkC6JSVYJnj9GiBUEUkCnISAEAki5UWRAfK9iK23jZFxu3jEngP+RboGT7QX38cXSsVi3pttukgQOLbWfuldrZe2EleDZPKcEDUBQCKQBAyrIg9uHdsiBWLmWXjZU8ITUyuhX3unXuGqh775Vyc6PjNsFuv12qV6/Eh6AED0A8CKQAAEkViixIQCRzHVBa94SySN021v399+h4mzZued8RR8T9kDYfmZMASsIaKQBAUiWzGxpC3or7+++l44+X/vnPaBBlaaVbb5U++6xMQRQAlBYZKQBAUmVkFiTAMmIdkO0JZWuexo51f/acdJJ0991S8+bpPDoAIVHqQKpz5866//771dY2sgMAIM4siLcmyjJRFkRZFsQ+wAcuCxJwgV8H9MYb0qBBbvtHT9Om0j33uBMKAPxW2rdkyRJ16NBBo0aN0hZvMzsAAEqBDUn9J3CtuJctk/r2lY49NhpElS8vXXGF9O23BFEA/JuRWrhwoUaOHKnx48frmWee0dSpU9W9e/fkHh0AICMEPguC9LGoe9o0adQotzOfx9Kc1kyiXbt0Hh2AECt1Rqp69eqaPHmy5syZ4/zco0cP9evXTytXrkzuEQIAMkbgsiBIr08+kTp1koYMiQZRFoX/+9/Su+8SRAEIVrOJQw89VB9//LHuueceXXfddXr55ZfV1GqTC8jKytIXX3yRqOMEAABhYYvorrlGmjLFbW/u+de/3CYTdeum8+gAoOxd+3bs2OFkorZu3ao6deo4JwAAgF1iQdNTT0nDh7trojzW6MrK+A47LJ1HBwC7Fki98cYbGjRokBYvXuycjxkzRtWqVYv3YQAAAKK++04aPNjtyuepUkW64QbpssukChXSeXQAUPZAyjJQw4YN05NPPql27dpp9uzZ6tixY2nvDgAAsDPrBDxunLuJ7rZt0fHevaW77pKaNUvn0QHArgdSLVu21LZt23Trrbdq+PDhys7OLu1dAQBABlm1KkHdF197zc1Cff99dGzPPd09oWxzXQDIhEDqb3/7m6ZMmaLm7BYOAEAobd4sTZ8uffCBtGGDVLWq24Xc9gOzFvel9ttv0rBh7oN5bE+o
ESOka6+Vdt89GYcPAOkJpGbMmJHYZwYAAIFicc+LL0oNGrgVd9Zczy4b2yesRDt2SJMnS9ddJ61fHx0/4gi3Q9/++yft2AEgbftIAQCA8LJSPstEWRBlp8qVoz/buJX7FWvuXMnWVlvjCC+IsjbmDz8svf02QRSAwCGQAgAApQqkrJyvRo3843bZxu36Qq1ZI118sa0RkD77LDp+/vnSwoVuKisrK6nHDgC+2UcKAACEizWWsDVRVs5n2SiPXbZxu36nPaEef1y6/HJpxYro+AEHuHtCdemSsmMHgGQgIwUAAEpkgZI1lli+3D1Z13LvZxvP171vwQLpmGOkfv2iQZQ1kLjjDmnePIIoABmBjBQAACgV685nbE3U0qVuJqpXr+i409ZvzBhp/Hhp+/boHU89VZo0SWraVJkuYa3hAfgegRQAACgVa3FuS5pOPLGQYMG6+w4ZIv34Y/QOtmXKvfdKPXsq0yWsNTyAwKC0DwAAxMWCpxYt/gqifvlFOv10N1jygqgKFaRRo6RvvglFEBXbGj47220Nb+d2OXarLACZhUAKAADEz/aEuvNOqXVr6b//jY536yZ98YVb4lelisLAyvl2qTU8gEAikAIAAPH58EPpkEOk4cPdOjZTr570n/9Ib73lBlchUubW8AACjUAKAACUjkUEF17odt2zrJOxPaBszDr1WZe+Mu4JZVmbRYuCmb2JbQ0fq8jW8AAyAs0mAABA8WxPKMs2XXmltHJldPzAA6Vp09zNdkPcpMHWitkx25ooLxNlQZS1hreuhnTvAzITGSkAAFC0+fNV+7TTVG7AgGgQZdHOxInSJ5/sUhCVSU0aLPCzoCknx20Nb+f5WsMDyDhkpAAAwM42bZJuuUVZEyaoYuyeUNahz/aEatw44U0ajDVqMDZubdaDks0ptjU8gIxEIAUAAPJ75RV3T6glS+SteIrsvbeybE+ov/894U0aLBMVy0rjLKtj1wctGLHj9csxszkwkFwEUgAAwPXzz9LQodLzz+cNRSpU0MbBg1XFslO77560Jg1eJsrQpGHXZMK6MyAIWCMFAEDYWeneHXe4bctjgigddZQin3+uDVddlZRP4F6TBmvKYKctW6I/2zhZlHCvOwP8jkAKAIAwmz3b3RNqxAhp40Z3rH596bHHpFmzpFatkvr0NGlILDYHBlKH0j4AADJUsWtk7Mqrr5YeeCA6ZntAXXSRNGaMVKtWtPV5EtGkIbEycd0Z4FcEUgAAhGmNTOWI9Mgj0hVXSH/8Eb3TwQdLU6dKHTsq7E0agox1Z0DqUNoHAEBI1si8duc3Urduku0J5QVR1apJd98tzZ2btiAKicO6MyB1yEgBAJBBCtubqVq5jer94U3q/sJEKbIjeuM+fdyNdRs1StvxIvG89WU2D6yczzJRrDsDQpKRmjx5spo3b67KlSurU6dOmmvfkpXCU089paysLPXu3TvpxwgASG1wsGgRC+XjWSNja2JMy4UvafCUNurx+Xhle0HUvvtKr71m/+MkiMrA+e+tOxs3TrrxRvfcLtP6HMjwjNTTTz+t4cOHa9q0aU4QNWnSJPXo0UMLFy5UfesiVIQlS5ZoxIgROvzww1N6vACA5GE/nLKvkcn+5Sf985NL1WrhS3nXbS9XUdsvH6kqN12dfwENMnL+s+4MCFlGauLEiTr//PM1YMAAtWnTxgmoqlSpogcffLDI++Tk5Oiss87SjTfeqL333julxwsASJ4w7YeTqKxbnerbdcHa8bruqTb5gqgv6nXXy2O/UpXxNxBEBUSY5j8QRL7KSG3btk3z5s3TyJEj88bKlSun7t27a86cOUXe76abbnKyVQMHDtR7771X7HNs3brVOXnWrVvnnOfm5jqndLNjiEQivjgW+B/zBZk8Z6xEzbY4atjQ3dbI2Od/69Bt4z17ZkYHMss6PPus+5q8rEOXLtLpp5ch6/bee8oaPFiHfPNN3tCayg31bJc7VLFfH53+j6y4f/d+nzM2
T7y26amaD6l4ziDPf7/PGfhPro/mTDzH4KtA6o8//nCySw281bF/scsLFiwo9D7vv/++/v3vf+vzzz8v1XOMGzfOyVwVtHLlSm2x1jY++OWtXbvWmUwWRALFYb4gk+fMb7+5DeXq1ZMqVIiOW3CxcqW7iH5HTN+EoLI9b+fNcxtDWFGF7Ylrl8uXl445pnSPkbVqlardcouq2Jqnv0TKldOffc/VT+ddpW57VFf16iu1fr2cUybMGftO9P33pfnz3c50FmS0bi0ddphUqVLwnzPI89+vcwb+leujObM+jj+SvgqkyvJC+/Xrp/vvv19169Yt1X0s22VrsGIzUk2bNlW9evVUvXp1+WEiWcMMO550TyT4H/MFmTxnLJCw/59t2hT9Rt6sWGEl3W6pk1+/kY8n6/Duu27Jlv06LDtl5/YB2cZ79CjhNdo3pw89pKyrr1aWPdhfIoccosiUKarRoYMOyNA58+ij0ksvuQGo/e/bXv5zz7nvXb9+wX/OIM9/v84Z+Feuj+aMNbsLZCBlwVB2draW22YHMexyQ8ttF/DDDz84TSZOOumkndJx5cuXdxpU7LPPPvnuU6lSJedUkP3S0v2L89hE8tPxwN+YL8jUOWPfj1mJm60JiUTcLnS2qaj9L8JaOZfy+zNfW7PG/bBsH4pj2Yd0yzjY9UW+zq++ki66yK3zir3j2LHKuugiZVl0lqFzxmvxbgFGbNmbzRMbP/HExDdZSPVzBn3++23OwP+yfDJn4nl+X83uihUrqkOHDppldQ4xgZFd7ty58063b9Wqlb766iunrM87nXzyyTrqqKOcny3TBAAIbity605mHxrtG3gLLOw8k/bD8Trs2QfkWHbZxgvNONhCqiuukA46KH8Q1bevtHChNHiwm+LKYAVbvHvsso3HJOcC/ZyZPv+BoPNVRspY2V3//v11yCGHqGPHjk77840bNzpd/Mw555yjxo0bO2udLPXWtm3bfPevWbOmc15wHAAQvFbk3n449m2/t7g/k9o522ux99CyDqZg1iHfa7W0hN3w0kuln3+OjrdoIU2ZInXvnrDjsiDZe79r1ZKvA9DYKpxiA9AAPmemz38g6HwXSPXp08dp/HD99ddr2bJlat++vWbOnJnXgGLp0qVpT/kBQJhbMdufYytFsw+QXgBgH/aSKZP3w/GyCxagWtbBPpTvlHVYskS65BLp5ZejY1amPmqUVg28Uqs3VVbtVbv+HhUVLB95pIIbgAb4OcMw/4Egy4pYe4wQs2YTNWrUcDqF+KXZxIoVK5x27gSMKAnzBamaM5ahsJ0prGIstrGqfYi0cqNx4/igl8gsUN57uW2bbbBo+3y4UY7nuOO05Y7JenrevgnNED7ySDRY9gKFFStydeqpK3Tmmf76O5OODCkbRJcO/29CkOdMPLGB7zJSAAD/8daHFGyKYB+2LYti1wctkCo0cEmjnbIO77wjDRokffttdGyPPaRJk6R//ENP/ycroRlCr5mCPZ4XLHv7Flm7b3uv/NTgIB1lb5TaAYhFIAUA8OX6kGTxfVbBNgmyZhKWHvLYN7TWROKWW5zOfEUFPaasHeSKCpbtC1nbM8lvgVQ6y94otQNgyLcCAEq9PsRK+exkH6y9n208SB8qvbVeVqZoQYOd22UbTyvbvuO++6SWLfMHUYceKn38sXT33W5Uk6QOckV1EFy3zg3SghQsA0AqEEgBAELTirlgJscCBO9nG09FS/dCffGFdNhh0oUXuptHeVGRdeObM0c6+OBdb5u+C8Fy69YEUgBQEKV9AIDQrA8p7VqvlK2fst14R492s00WmXrOOkuaMEEqZDP6ZHaQK6yD4MknuzEeACA/AikAQGjWh5S01suCRauqS/r6KWuY+9xz0tCh0q+/Rsf320+aOlU6+ujEtE1PQLBs+0itWFH2xwSATEUgBQAItHiyRyVlcmbNSsFeWYsXu3tCzZgRHbOo7ppr3CYTtj9UmjOEscGyLd0CAOyMQAoA
EEhl7b5XVCbnmGPc7ZoS2QkvH9sTysr1br7ZXYDkOf546d57pX32CV2GEACCjEAKABBIXve9eLNHRWVyFi1K4l5Zb73l7gm1YEF0rFEj6a67pNNOczdrAgAECl37AACBk4juexYUtWgRDY6S0QnPqRns189d8+QFUbYn1GWXuZdPP50gCgACikAKABA4ydhHKaF7ZdnComnTpFatpMcei4536iTNmyfdeadUrVr8BwkA8A1K+wAAGdd9r6x7HiWkE95nn0kXXSTNnRsdq1lTuu026bzz3IwU4pKydvQAEAcCKQBA4CRrH6Vd6oS3bp10/fXSPffkb3V3zjnS7bdL9euX7aBCrKwNRQAgFQikAACBlIx9lMrUCc/2hHr2WXfd02+/Rcdbt3b3hOrWbdcPKKTK2lAEAFKBQApA4FH2E07J3Eep1H74QRo8WHr11fwHdt110uWXSxUrpviAMrehSMLb0fsMf8eA4CGQAhBYlP0gbfsobd0qjR8vjRnj/uyxT/d33y3ttVeKDyhzG4okpR29j/B3DAguVrwCCHzZT3a2+2HLzu2yjQNJM2uWdMAB7nooL4hq0kR6/nnppZcIokrIuth+XaVpT5+UdvQ+xN8xILjISAEIpLCV/cAHli1zy/WeeCI6Zp96hw2TRo92P90jYVmXZDUU8RP+jgHBRkYKQCAlYx8hoFA5OdLkye6eULFBVJcu0qefuh35CKKSknWxQMuCJvsVWDmfnSeqoYgf8HcMCDYyUgACKVn7CPkVC9HTxDbPtT2hPvkkOma/BFsfNWAAe0IlOevii4YiSRS2v2NApiGQAhBIYSj7MSxETxObTNZ5zzJRsXtCWfBkG+vWq5fOowtd04i0NBRJgbD8HQMyFYEUEGJBz3Ikcx8hv2AfnRSzPaHsTbc9oWxNlGf//d09oQ4/PJ1HF0hkXYoXhr9jQKYikAJCKFOyHJle9sNC9BSzdnK2J9Trr0fHqlRxG0lYQ4kKFdJ5dIFF1iXcf8eATEYgBYRQpmU5MrXsJyz76KTdli1uud64cfn3hDr5ZHdPqD33TOfRZQSyLuH9OwZkMgIppFTQS8kyAVmO4KAkKgUs+zRokPT999Exi1wtgLJP+kgIsi4AMhGBFFIiU0rJMgFZjuCgJCqJfv9dGj5ceuqp6Fj58u6YbbS7++7pPLqMRdYFQCahbytSgp3b/ZnliEWWw58yfR+dlLM38N573T2hYoOoww6TPvvMLfEjiAIAlAIZKSQdpWT+QpYjWCiJSiDbC8r2hLK9oTz2ZtqGuvYmsydUQlDCDSAsCKSQdJSS+Q8Lv4OHkqhd8Oef0rXXSlOmuO3NPQMHuhko3tiEoIQbQNgQSCHpWDDvP2Q5EAoWND35pLvuyVKunrZtpWnT3E/5SJhM6wYKACWhjgEpKyWzzzF2sk7D3s82zgf49LH3vkULfgfIQAsXSsceK511VjSIsj2hxo+XPv2UICrJJdz2pZn3s43b9QCQaQikkBIsmAeQsvoy67p3wAHSrFnR8d69pfnzpSuuYGPdJJZwW8l2LLts43Y9AGQaSvuQEpSSAUi6mTOlwYOlxYvzhnKa7qnl19yjSqefxN+cJKKEG0AYEUghpVgwDyDhfv1VGjZMeuaZvKFI+fL6uscI/V+9a7X6nd1VdR6ND5KJbqAAwohACgACLrTtpnfscPeEuu46t37Mc8QReqnHFD3yyf5qUFlq1oDGB6lAN1AAYUMgBQABFep20x995O4J9fnn0bG6daUJE7Sq5zl6ZVQWe9elGCXcAMKGZhMAEPB209nZbrtpO7fLNp6x1qxxA6jOnfMHUeef73bq699fq9dk0fggjegGCiAsCKQAIIBC127a9oR69FGpZUvp//4vurGudeebPVu67768jgaxjQ9i0fgAAJBIBFIAEEChaje9YIF09NHSOedIK1e6YxYRTZwozZvnZqdisHcdACAVWCMFAAFs4BCKdtObNkljxki33y5t3x4dP+00adIk
qUmTIu9K4wMAQLIRSAFAABs4ZHy76RkzpCFDpB9/jI7ttZfbpe+EE0q8O40PAADJRmkfAAS0gYMFdxY05eS4WRc7D3zW5ZdfpNNPl3r2jAZRFSpI11wjff11qYKoWDQ+AAAkCxkpACiigYPf22ZnVNbF9oS65x7p+uvz7wl15JHSlClS69YKM7+XmgJAGBFIAUCBBg6WiYplZXOW8bHr/fgh1o7Jj8dVanPmSBdfLH3xRXSsXj3pjjuks8+WsrIUVkEpNQWAMKK0DwD+QtvsFLPI9MILpS5dokGUBU22T5TtCdWvX6iDqKCVmgJA2BBIAcBfaJudIrYH1COPSK1aufs/edq3d7NTU6dKtWop7EK3VxgABAyBFABkegMHP/n2W+moo6Rzz82/J5S1M//4Y6lTp3QfoW+Eaq8wAAgg1kgBSKigL4rPqAYOftsT6uabpQkT3MYSnn/8Q7rzTqlxY2WKRP0bCMVeYQAQYARSABIi0xbFB76Bg5+8/LJ0ySXSkiXRsb33liZPlo4/Xpki0f8GMn6vMAAIOEr7ACQEi+KxE6uNPOUU6aSTokFUxYrStde6e0IFNIiyjNOiRTuvUUrGvwFKTQHAv8hIAQjd/ktIsu3bpbvukm64Qdq4MTp+9NHunlAtWyrTMk5WuZiMfwOUmgKAf5GRArDLWBSPPLNnSx06SFdcEQ2iLLJ4/HHpjTcCG0SVlHFK9r8BC55atCCIAgA/IZACsMvYfwlOWvK889wUzVdfuWO2B9SgQdKCBVLfvoHeE6qkVuSGfwMAEC4EUgB2GfsvhXxPqIcecjNN//53dPzgg6WPPnIbStSsqaArKeNk+DcAAOHCGikACeEtfrdv521RvH0Lz6L4DGcNIy6+WHr//ehYtWrSmDFuJspq3zJEaVqR828AAMKFQApAQrAoPkRs7dNNN0kTJ+bfE6pPH3esUSNlmtK2IuffAACEB4EUgIRi/6UM99JL7p5QlnLx7Luv243v2GOVyUqbceLfAACEA4EUAKBkP/0kXXqpG0h5KlWSRo6Urroqf71bhiLrCgCIRSAFACh+Tygr17NSPtssyWPZJ2skYT25Q4aMEwDAEEgBAAr37rtuM4lvv42O7bGHdOedbj1bgNuZAwCwq2h/DgDIb+VKacAAqVu3aBBVrpy7Nmr+fLepBEEUACDkyEgBAFy5udKDD7prnmwRkOeQQ6Rp06QOHdJ5dAAA+AqBFABA+vJL6aKLpDlzomPVq0tjx7rjGbQnFAAAiUBpHwCE2YYN0ogR0sEH5w+i+vaVFi6UBg8miAIAoBBkpAAgjCIR6fnn3Zbmv/wSHd9vP3dPqGOOSefRAQDge77MSE2ePFnNmzdX5cqV1alTJ82dO7fI295///06/PDDVatWLefUvXv3Ym8PAGGXvXSpsk4+WTr11GgQZXtCWYtzK/EjiAIAIHiB1NNPP63hw4dr9OjR+vTTT3XggQeqR48eWrFiRaG3f/vtt3XmmWfqrbfe0pw5c9S0aVMdd9xx+vXXX1N+7ADga9u2SbfeqrrduilrxozoeI8e0tdfS9dd5wZUAAAgeIHUxIkTdf7552vAgAFq06aNpk2bpipVquhB6yRViMcff1yDBg1S+/bt1apVKz3wwAPKzc3VrFmzUn7sAOBb77wjtW+vctdco6wtW6J7Qj39tPS//0n77pvuIwQAIFB8tUZq27ZtmjdvnkaOHJk3Vq5cOadcz7JNpbFp0yZt375dtWvXLvT6rVu3OifPunXrnHMLvuyUbnYMkUjEF8cC/2O+oEQrVijryiuV9eijeUORcuUUsSYSVspnnflsvZSdgELwdwbxYs4gyHMmnmPwVSD1xx9/KCcnRw0aNMg3bpcXLFhQqse46qqr1KhRIyf4Ksy4ceN044037jS+cuVKbfG+pU3zL2/t2rXOZLIgEigO8wVFys3Vbo8/rmpjxyrrzz/zhrcddJB+ueYaVe7cWeXsb54P/u7B3/g7g3gxZxDkObN+
/fpgBlK76tZbb9VTTz3lrJuyRhWFsWyXrcGKzUjZuqp69eqpun0z64OJlJWV5RxPuicS/I/5gkJ98YWyBg1S1ocf5g1FatZUZMwYlRs4ULutXs2cQanxdwbxYs4gyHOmqBjC94FU3bp1lZ2dreXLl+cbt8sNGzYs9r4TJkxwAqk33nhDBxxwQJG3q1SpknMqyH5p6f7FeWwi+el44G/MF+Sxb9FGj5buvlvKyYmOn322siZMUJZl+//6nxVzBvFgziBezBkEdc7E8/y+mt0VK1ZUhw4d8jWK8BpHdO7cucj7jR8/XjfffLNmzpypQw45JEVHi2RYtUpatMg9B1BKtr7pv/+VWreW7rwzGkS1bCnZ31NbH1WgZBoAAOwaX2WkjJXd9e/f3wmIOnbsqEmTJmnjxo1OFz9zzjnnqHHjxs5aJ3Pbbbfp+uuv1xNPPOHsPbVs2TJnvGrVqs4JwbB5szR9uvTBB9KGDfb7k7p2lc44Q9ptt3QfnT9YcLl6tWR9VOrUSffRwDcWL5YuuUSKbWduZQnXXiuNGEE7cwAAwhJI9enTx2n8YMGRBUXW1twyTV4DiqVLl+ZLuU2dOtXp9nf66afnexzbh+qGG25I+fGjbCyIevFF90vzZs2ktWvdy6Z/f4VacUEmn5FDzLqPTpgg3XJL/oYRf/+7dO+90t57p/PoAADIeL4LpMyQIUOcU2GskUSsJUuWpOiokMxMiwUJFkR51UfeOj8bP/HEcGdgigsy+/VL99EhLd56S7r4YmnhwuhY48bSXXdJp55qhebpPDoAAELBV2ukEE5WrmaZlho18o/bZRu368OqYJBpAab3s42H+b0JJWvEY9Hz0UdHg6jsbGnYMGn+fOm00wiiAABIEQIppL0phK35sXI1y7TEsss2XsTeyqF4Xwky4bDNAadNk1q1kh57LDr+t79J8+ZJEydK1aql8wgBAAgdX5b2IVxNIaxszx7DK1ezIMGCKPvyvVevcJT1FfW+HnNMNMiM3dYgNsjcsSOdR46k++wz6aKLpLlzo2O1almnHWngQOvTms6jAwAgtPg/sM9YhuG33/ybafDW61g1ka3XsXO7bOO7wgIxC5qsa/PSpe65XbbxMCjqfbXO1RZQWVBpJ+sp4P1s42HK1oXOunXS0KGSbekQG0RZ9xUr6zv/fIIoAADSiIyUzzISs2e7FTq2r2aXLv5q/53MphD2Gu3zoT1GJrf4LqyFeUnv6/XXR3+2INMyUekOMmnFnuQ9oZ55RrrsMun336PjbdpYm1LpiCPSeXQAAOAvBFI+y0g0bCjVqydt2uS/9t/eeh3LmMSyUjz7gG/X7+qHart/Kj+YpyogKK4ksqT31e5bVJBpS2dSif2+kuz7761tqfTqq9Exe2Mtmh4+3HYtT+fRAQCAGARSPhCbkahfX6pQwT23L6b91P47tilEUet1giLVAUFxLczt91ua9zXVQWZh2O8riXtC2ZqnsWPdnz02Oe65R2rePJ1HBwAACkGBvQ8EpTOb1xSiqPU66f6Q74e1XmVpYW6C8L6W9DoS0cExlN54Q2rXznYRjwZRTZtKzz8vvfQSQRQAAD5FIOUDQWr/nQlNIVIdEJQmUA7C+xqUgD8wli2T+vaVjj3W7Xlvy6Oys7Vm4Aiteu9bqXdv9oQCAMDHKO3zgdj23/a5yUrLVqzwZ/vvTGgKkYq1XvGWRAbhfc2k0s60sijZ9oQaNcrtzPeX5S26atoBU7VwUztVHcPaMwAA/I5Ayie8zIN17Vu50p8ZiVh+WK/j14CgYAOLePbJ8vP7yn5fCWCb59qeUJ98Eh2rXVsf9BqviWsGqH69cmr21/vK2jMAAPyNQMonvIxEz55uVsSyJXXrpvuoMlOyAoLiGlh4AbGfWpiXRaa8jpSzCXbttdLkyW4XGc+//qXVV92mRybUVf2Gid9WAAAAJA+BlM9YFmPHDsqkghgQlNTRzu+le6URhBJEX7Gg6amn3NbltibKs//+bnnf
YYdp1aLUlpoCAIDEIJBCKCU6ICjtZsV+Lt2LR6a8jqSyBhKDBrld+TxVqmjjFTfotzMuU+0GFWRvIWvPAAAIJgIppH2z2kwICFLdwAI+ZjWetifUuHHStm15wzkn9tLzR96l1xbuqQ235C/9ZO0ZAADBQyCFtG9WmwnIKsCJfKZOlaZMcTvGeCy6vucePbbm5CJLP1l7BgBA8BBIIe61PtgZHe1CnOn86ivpzjulxx/Pl4FS+fLu2qjrr9eqLbvrg5HFl36y9gwAgGAhkEKZ1vpgZ2QVQpLpjES06pfN2jLzbdV7/E5VfCdmDZTJzpZOP93t0te2rTO0+rfSlX6y9gwAgOAgkEI+rPUpOzraBSzTuX27tGZN9GS/tMLOY36OrFqt3FVrVGfH1p0fz/6RXHCBNGTITv+A/Fr6mXHZQQAAUohACoH4wBckZBVSmOnMzZXWrSs2+ClyzL4xiFOWJZwKjP22+776+ZSh6jT1XPcfSQBKP7dulR59NAOygwAApBGBFHz9gQ8hyHT+FNGaXzerzpY1+nPxaq3/eY1q5KxW9R1/BTzFZYr+/NMNplIgsvvuWqPa2lS5lrZXq60NVRvq6/3/qfdqnKgdkWztu1WqU3gc5bvSz/ffl156Sapf3yfZQQAAAohACr7+wIcAiS2VKxD0NP5ltf711RpVnbdGtbRau21eo902r1bFTWtUZctqVXjKbdJQ869T0lSs6KZVa9VyT97PxY3Zec2a+v6niho92g08YrO11beUruzVL6Wf9tzz57uZQQukDOsgAQCIH4FUhkjkWge/fOBDGhRVKlea0rliSuWqSOqeqGPMynICm3zBjxfweKeiAiKb3Hb/NJa9prv0035dW7ZI1asnbx0ka68AAGFAIBVwyeyElu4PfCijSMSdGGVZN5TCUrktFapqR7Vaqtiwtn5cU8spmYvUqKXNu9V2Tiu219L6CrV15qBaqr5nTLBkn/jLlVOqZUrZq72FFgjarzvR6yAzpjMjAAClQCAVcL7rhIbEl8oVE/xkrVqlmsuXK2vjxvzXxe5nlEwVKhSeBYrNEBUYXx2ppVW5tVS7YUUn+Fi0SLqxkJI5y5pYhuToA6TqLeQLQSx7LZgdsvPWraXnnnNj7kQGhPw9AgCECYFUgLHnU0BL5YrLEMVeV4quclakFhN7lI2VuhUskSvNuiE7r1Il7lI5S3jUDminyCCVvRaVHbItrg47TNqxI7EBIX+PAABhQyAVYOz5lCL2tf2mTXHtN5R3nsJSubyoo6Tgp+CYLZZJQ6lckEvmglD2Wlx2qEcPqV+/xAaE/D0CAIQNgVSABembfF+wcreCAU9pA6J0lcqVUC6XW7Om/sjJUd0WLVSuUiUFVRBL5vysuOzQ7NlS585ux75EBoT8PQIAhA2BlN8yHwVPxahTW+ra5a9vmQtb62AfXIp/iGCxzI595W0v0k5WMuf9bKfYIKiwrJGtI0qF2K5yhWWDCgZEu1Iql5ur3BUr3AAswIJUMhcExWWHfv65THsRZ2RmEQCAXUEg5Sd77qlyP/+shnHcxdZvF7qG+wVJAxJ3aKG0++5ucGOfAIsrkSs4nuZSuSALQslcEJSUHbJTMpBZBACECYEUMpttwFrSeqHCAiLLKNl9gQAqKTtUcA+pRCGzCAAIEwIpP+nQQZEmTbR9+3ZVqFDB6ciGAuwrbvtUGHuyT4V2XlgjhV3YgBWZL5M3ji0qO2Rd+9avT+5zk1kEAIQBgZSfPP+8Irm5Wr1iherXr68sysNQSt5SsEwMCJIhDBvHFpUdsqWGyQ6kAAAIAwIpIOABwaxZ0rvvuh+OMzEgSIYwbRxLdggAgOQg5RFyVtq0aJF7juB59lnpo4+k7Gw3ILBzCwgsUEDpWoNbMwbvZxvn3wIAACgNMlIhFYbSpkxnH/htTyALALwqUK9Dm/1eraSLTMTO2DgWAAAkAhmpkJc2hSGTkalZNy8gsC7tBQMCG7frUXxr8FhsHAsAAOJB
RiqECpY2ZWomI9Ozbl5AYPsMx/YlISAoHhvHAgCARCAjFUJeJsM+QGZyJiPTs272gb9LF+nPP6UVK6QtW9xgwE4WKBAQFM2CaQuacnLccj47Z+NYAAAQDzJSIS9t8jJRmZbJCEvWzfYEKl/e7doXu1cQAUHx2DgWAADsKgKpEApDaVNYGgpYQHDMMVKPHuwjVRa0BveXTN4gGQCQeQikQsrLWFh2JhMzGWHIusWy11O3brqPAiibTF/PCADITARSIZXppU1hyLqVFt/yw+/CtEEyACBzEEiFXCaXNmV61q0kfMuPIAjLekYAQOYhkELGyvSsW0n4lh9BEJb1jACAzEP7c2Q8+xDWokW4PowV/JbfvuH3frbxTNucGMHFBskAgKAikAIyUFj2CkPmrGf09kBjPzQAQFBQ2hcgNA1AaYWtayGCLezrGQEAwUQgFQA0DUC86FqIIAn7ekYAQDARSAUATQNQFnzLj6DJ5C6iAIDMQyDlc7QGRlnxLT92BaXEAAAUj0DK52gNjF3Ft/yIB6XEAACUDl37fI7WwADSUUqcne1+gWPndtnGAQBAFIGUz9EaGECqsP8YAAClRyAVAFZSY00CcnLccj47p2kAgERj/zEAAEqPNVIBQNOA9GCxPcKG/ccAACg9AqkACUrTgKAHICy2R1ix/xgAAKVHIIWEyZQAhH27EGbsPwYAQOkQSCFhMiEAYd8uhB2lxAAAlA7NJpAQmdLti8X2gMuCpxYtCKIAACgKgRQSIlMCEPbtAgAAQGkQSCEhMiUAYd8uAAAAlAaBFBIikwIQ9u0CAABASWg2gYTJlG5fLLYHAABASQikkDCZFoAEZd8uAAAApJ4vS/smT56s5s2bq3LlyurUqZPmzp1b7O2feeYZtWrVyrl9u3btNGPGjJQdK8LZ7cu6EC5aFJxuhAAAAMjwQOrpp5/W8OHDNXr0aH366ac68MAD1aNHD61YsaLQ28+ePVtnnnmmBg4cqM8++0y9e/d2Tl9//XXKjx3h2HT4kUekkSOl0aPdc7ts4wAAAAgP3wVSEydO1Pnnn68BAwaoTZs2mjZtmqpUqaIHH3yw0NvfddddOv7443XFFVeodevWuvnmm3XwwQfr3nvvTfmxIzybDmdnu5sO27ldtnEAAACEh6/WSG3btk3z5s3TSPua/y/lypVT9+7dNWfOnELvY+OWwYplGawXXnih0Ntv3brVOXnWrVvnnOfm5jqndLNjiEQivjgW5GfrvmbPlho2lOrXd8ds4+GsLHe8Z8/Ut3lnviBezBnEizmDeDFnEOQ5E88x+CqQ+uOPP5STk6MGDRrkG7fLCxYsKPQ+y5YtK/T2Nl6YcePG6cYbb9xpfOXKldpiPbt98Mtbu3atM5ksiIR//PabVK2aVK+eVKFC/iYbK1e6nQp37EjtMTFfEC/mDOLFnEG8mDMI8pxZv359MAOpVLBsV2wGyzJSTZs2Vb169VS9enX5YSJlZWU5x5PuiYT8ype3f1zSpk3RjJSx5Xu215SV+qUjI8V8QTyYM4gXcwbxYs4gyHPGmtcFMpCqW7eusrOztdx2cY1hlxtaPVUhbDye21eqVMk5FWS/tHT/4jw2kfx0PHDVrSt16eKuiYpEpBo1pLVr3U2Hbb8suz4dmC+IF3MG8WLOIF7MGQR1zsTz/L6a3RUrVlSHDh00a9asfBGqXe7cuXOh97Hx2Nub119/vcjbA7vCNhe2oMkyUFbKZ+dB3HQYAAAAu8ZXGSljZXf9+/fXIYccoo4dO2rSpEnauHGj08XPnHPOOWrcuLGz1skMHTpU3bp10x133KGePXvqqaee0ieffKL77rsvza8EmSjTNh0GAABAhgRSffr0cRo/XH/99U7DiPbt22vmzJl5DSWWLl2aL+XWpUsXPfHEE7r22ms1atQotWjRwunY17Zt2zS+CqSabYybysDGnoMACgAAILyyItYeI8Ss2USNGjWcTiF+aTZhmw/Xr18/7TWiQWAb4doe
Th98IG3YIFWtKnXt6pbaWfYo0zFfEC/mDOLFnEG8mDMI8pyJJzZgdiPQ2CAXAAAA6UAghUCX81kmyqo+7WTdKr2fbdyuB2weLFrEfAAAABm+RgooLVsTZeV8lomKZW3JraOeXc86pvAKe9knAABILjJSCCxrLGEfjm0vp1h22cZTvTku/IWyTwAAkEwEUggsyzZZhsE2xLXTli3Rn22cbFR4UfYJAACSjUAKgcYGuSiu7NPKPGPZZRu36wEAAHYFa6QQaGyQi5LKPi0b5aHsEwAAJAoZKWQEC55atCCIQvrLPukSCABAOJCRApCRvPJOWxNlZZ+WiUpm2SddAgEACBcCKQCFsoyKn8ol4z2eVJd9el0CraGFdQm0MkK7bOw4AABAZiGQAuDrzMquHo8FT8kOBAt2CTTe2iwbt2DOD8EoAABIHNZIhQBrNhDk/Zf8djyFoUsgAADhQ0Yqg/ktswD/81tmxW/HUxS6BAIAED5kpDJYEL7Jh7/4LbPit+MpCptDAwAQPgRSGargN/n2Lbn3s41T5oeSMiux0pVZ8dvxFIfNoQEACBdK+zKU902+ZaIKfpNvH/Lser4lR1GZFa/bnM0XC1oss2JBQarnjN+OpzhsDg0AQLgQSGWoTFyz4bd23Jkq1fsvBe14/NAlEAAApB+BVIYK0jf5JaFpRrgzK347HgAAAEMglcGC9k1+UdjoND38llnx2/EAAIBwI5DKYJnwTX5Q2l8DAAAgXOjaFwIWaLRoEcyAIyjtrwEAABAuBFLwtSC1vwYAAEB4EEjB19joFAAAAH7EGin4XqY0zQAAAEDmIJCC72VC0wwAAABkFgIpBAbtrwEAAOAXrJECAAAAgDgRSAEAAABAnAikAAAAACBOBFIAAAAAECcCKQAAAACIE4EUAAAAAMSJQAoAAAAA4kQgBQAAAABxIpACAAAAgDgRSAEAAABAnAikAAAAACBOBFIAAAAAECcCKQAAAACIE4EUAAAAAMSpvEIuEok45+vWrZMf5Obmav369apcubLKlSPORfGYL4gXcwbxYs4gXswZBHnOeDGBFyMUJ/SBlP3STNOmTdN9KAAAAAB8EiPUqFGj2NtkRUoTbmV4BPzbb7+pWrVqysrKSvfhOFGwBXU///yzqlevnu7Dgc8xXxAv5gzixZxBvJgzCPKcsdDIgqhGjRqVmB0LfUbK3qAmTZrIb2wSpXsiITiYL4gXcwbxYs4gXswZBHXOlJSJ8lC4CgAAAABxIpACAAAAgDgRSPlMpUqVNHr0aOccKAnzBfFiziBezBnEizmDeAV1zoS+2QQAAAAAxIuMFAAAAADEiUAKAAAAAOJEIAUAAAAAcSKQAgAAAIA4EUil2OTJk9W8eXNVrlxZnTp10ty5c4u9/TPPPKNWrVo5t2/Xrp1mzJiRsmNF8ObM/fffr8MPP1y1atVyTt27dy9xjiHzxPt3xvPUU08pKytLvXv3TvoxIthz5s8//9TgwYO1xx57OF229ttvP/7/FDLxzplJkyapZcuW2m233dS0aVMNGzZMW7ZsSdnxIr3effddnXTSSWrUqJHz/5kXXnihxPu8/fbbOvjgg52/Mfvuu68efvhh+Q2BVAo9/fTTGj58uNPe8dNPP9WBBx6oHj16aMWKFYXefvbs2TrzzDM1cOBAffbZZ86HGzt9/fXXKT92BGPO2B8dmzNvvfWW5syZ4/zP6rjjjtOvv/6a8mNHMOaMZ8mSJRoxYoQTiCNc4p0z27Zt07HHHuvMmWeffVYLFy50vsRp3Lhxyo8dwZgzTzzxhK6++mrn9vPnz9e///1v5zFGjRqV8mNHemzcuNGZJxaAl8aPP/6onj176qijjtLnn3+uyy67TOedd55effVV+Yq1P0dqdOzYMTJ48OC8yzk5OZFGjRpFxo0bV+jtzzjjjEjPnj3zjXXq1Cly4YUXJv1YEcw5U9COHTsi1apVizzyyCNJPEoEfc7YPOnSpUvkgQceiPTv3z/Sq1ev
FB0tgjhnpk6dGtl7770j27ZtS+FRIshzxm579NFH5xsbPnx4pGvXrkk/VviPpMjzzz9f7G2uvPLKyP77759vrE+fPpEePXpE/ISMVIrYN3jz5s1zSq085cqVcy5b5qAwNh57e2Pf+BR1e2SWssyZgjZt2qTt27erdu3aSTxSBH3O3HTTTapfv76T/Ua4lGXOvPTSS+rcubNT2tegQQO1bdtWY8eOVU5OTgqPHEGaM126dHHu45X/LV682CkFPeGEE1J23AiWOQH5DFw+3QcQFn/88YfzPxn7n04su7xgwYJC77Ns2bJCb2/jyHxlmTMFXXXVVU49csE/RshMZZkz77//vlNmY6UTCJ+yzBn7EPzmm2/qrLPOcj4Mf//99xo0aJDzpY2VbiGzlWXO9O3b17nfYYcdZpVQ2rFjhy666CJK+1Ckoj4Dr1u3Tps3b3bW2vkBGSkgQ916661O84Dnn3/eWQwMFLR+/Xr169fPWd9St27ddB8OAiI3N9fJYN53333q0KGD+vTpo2uuuUbTpk1L96HBp2z9rmUtp0yZ4qypeu655/TKK6/o5ptvTvehAbuEjFSK2IeU7OxsLV++PN+4XW7YsGGh97HxeG6PzFKWOeOZMGGCE0i98cYbOuCAA5J8pAjqnPnhhx+chgHWSSn2Q7IpX76800Rgn332ScGRI0h/Z6xTX4UKFZz7eVq3bu18g2xlXxUrVkz6cSNYc+a6665zvrSxZgHGuhBb84ELLrjACcKtNBAozWfg6tWr+yYbZZi5KWL/Y7Fv7mbNmpXvA4tdtlrzwth47O3N66+/XuTtkVnKMmfM+PHjnW/5Zs6cqUMOOSRFR4sgzhnbWuGrr75yyvq808knn5zXJcm6PiKzleXvTNeuXZ1yPi/oNt99950TYBFEZb6yzBlbr1swWPICcbf3ABDQz8Dp7nYRJk899VSkUqVKkYcffjjy7bffRi644IJIzZo1I8uWLXOu79evX+Tqq6/Ou/0HH3wQKV++fGTChAmR+fPnR0aPHh2pUKFC5Kuvvkrjq4Cf58ytt94aqVixYuTZZ5+N/P7773mn9evXp/FVwM9zpiC69oVPvHNm6dKlTjfQIUOGRBYuXBh5+eWXI/Xr14/ccsstaXwV8POcsc8vNmeefPLJyOLFiyOvvfZaZJ999nG6EyMc1q9fH/nss8+ck4UfEydOdH7+6aefnOttvti88dg8qVKlSuSKK65wPgNPnjw5kp2dHZk5c2bETwikUuyee+6JNGvWzPmwa+1DP/zww7zrunXr5nyIiTV9+vTIfvvt59ze2kC+8soraThqBGXO7Lnnns4fqIIn+58YwiPevzOxCKTCKd45M3v2bGc7Dvswba3Qx4wZ47TRR3jEM2e2b98eueGGG5zgqXLlypGmTZtGBg0aFFmzZk2ajh6p9tZbbxX6+cSbJ3Zu86bgfdq3b+/MMfs789BDD0X8Jsv+k+6sGAAAAAAECWukAAAAACBOBFIAAAAAECcCKQAAAACIE4EUAAAAAMSJQAoAAAAA4kQgBQAAAABxIpACAAAAgDgRSAEAAABAnAikAAAAACBOBFIAgNA6++yzVblyZX333Xc7XXfrrbcqKytLL7/8clqODQDgb1mRSCSS7oMAACAdVqxYoVatWql9+/Z6880388Z//PFH7b///jrhhBP07LPPpvUYAQD+REYKABBa9evX12233aa33npLjzzySN74oEGDVKFCBd11111pPT4AgH+RkQIAhJr9b/Dwww/XwoULtWDBAr3++us688wzdffdd+uSSy5J9+EBAHyKQAoAEHrffPONDjroIPXu3VvvvfeemjRpoo8++kjlylG4AQAoHIEUAACSRo0apXHjxik7O1tz587VwQcfnO5DAgD4GF+1AQAgqW7dus55o0aN1LZt23QfDgDA5wikAACh9/PPP2v06NFOAGU/jx8/Pt2HBADwOQIpAEDoDRkyxDn/3//+p3/84x8aM2aMFi9enO7D
AgD4GIEUACDUnn/+eb300ku6+eabnSYTkyZNUsWKFTV48OB0HxoAwMdoNgEACK3169erTZs2qlevnj7++GOn0YSx1udDhw7V9OnTnQwVAAAFEUgBAELLgqV7771XH374oQ499NC88ZycHHXs2FHLli1z9paqVq1aWo8TAOA/lPYBAEJp3rx5mjx5sgYNGpQviDKWmZo2bZoTSF177bVpO0YAgH+RkQIAAACAOJGRAgAAAIA4EUgBAAAAQJwIpAAAAAAgTgRSAAAAABAnAikAAAAAiBOBFAAAAADEiUAKAAAAAOJEIAUAAAAAcSKQAgAAAIA4EUgBAAAAQJwIpAAAAAAgTgRSAAAAAKD4/H9f5j3JLEheZAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 1000x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final MSE loss: 0.011899\n"
     ]
    }
   ],
   "source": [
    "# Restore the trained model\n",
    "model = brainstate.graph.treefy_merge(graphdef, params_, counts_)\n",
    "\n",
    "print(f\"Model called {model.count.value} times during training\")\n",
    "print(f\"Expected: {total_steps} times\\n\")\n",
    "\n",
    "# Make predictions\n",
    "y_pred = model(X)\n",
    "\n",
    "# Visualize results\n",
    "plt.figure(figsize=(10, 6))\n",
    "plt.scatter(X, Y, alpha=0.5, s=20, label='Training data', color='blue')\n",
    "plt.plot(X, y_pred, color='red', linewidth=2, label='Model prediction')\n",
    "plt.xlabel('X', fontsize=12)\n",
    "plt.ylabel('Y', fontsize=12)\n",
    "plt.title('Model Fit After Training', fontsize=14, fontweight='bold')\n",
    "plt.legend()\n",
    "plt.grid(alpha=0.3)\n",
    "plt.show()\n",
    "\n",
    "# Compute final test loss\n",
    "final_loss = jnp.mean((Y - y_pred) ** 2)\n",
    "print(f\"Final MSE loss: {final_loss:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcfedf3d2d0054c5",
   "metadata": {},
   "source": [
    "## Advanced Graph Patterns\n",
    "\n",
    "Now let's explore some advanced architectural patterns using the graph system.\n",
    "\n",
    "### Skip Connections (ResNet-style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "5494b013a7f7b348",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.059453Z",
     "start_time": "2025-10-10T15:54:21.567395Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet Architecture:\n",
      "  Input shape: (5, 10)\n",
      "  Output shape: (5, 10)\n",
      "  Number of blocks: 3\n",
      "  Total parameters: 0\n"
     ]
    }
   ],
   "source": [
    "class ResidualBlock(brainstate.graph.Node):\n",
    "    \"\"\"Residual block with skip connection.\"\"\"\n",
    "    \n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.linear1 = brainstate.nn.Linear(dim, dim)\n",
    "        self.linear2 = brainstate.nn.Linear(dim, dim)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Main path\n",
    "        residual = x\n",
    "        x = jax.nn.relu(self.linear1(x))\n",
    "        x = self.linear2(x)\n",
    "        \n",
    "        # Skip connection\n",
    "        return jax.nn.relu(x + residual)\n",
    "\n",
    "\n",
    "class ResNet(brainstate.graph.Node):\n",
    "    \"\"\"Simple ResNet with multiple residual blocks.\"\"\"\n",
    "    \n",
    "    def __init__(self, dim, n_blocks):\n",
    "        super().__init__()\n",
    "        self.blocks = [ResidualBlock(dim) for _ in range(n_blocks)]\n",
    "        # Register blocks as attributes\n",
    "        for i, block in enumerate(self.blocks):\n",
    "            setattr(self, f'block_{i}', block)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "# Test ResNet\n",
    "resnet = ResNet(dim=10, n_blocks=3)\n",
    "x = brainstate.random.randn(5, 10)\n",
    "output = resnet(x)\n",
    "\n",
    "print(f\"ResNet Architecture:\")\n",
    "print(f\"  Input shape: {x.shape}\")\n",
    "print(f\"  Output shape: {output.shape}\")\n",
    "print(f\"  Number of blocks: {len(resnet.blocks)}\")\n",
    "\n",
    "# Count parameters\n",
    "params = brainstate.graph.states(resnet, brainstate.ParamState)\n",
    "total_params = sum(jnp.size(p.value) for p in params.values() if hasattr(p.value, 'shape'))\n",
    "print(f\"  Total parameters: {total_params:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d48d56f17a8fffd",
   "metadata": {},
   "source": [
    "### Multi-Path Networks (Inception-style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "a5d74feeaa8a6be1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.721597Z",
     "start_time": "2025-10-10T15:54:22.060930Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inception Block:\n",
      "  Input shape: (4, 16)\n",
      "  Output shape: (4, 32)\n",
      "  4 parallel paths with concatenated outputs\n",
      "  Total parameters: 0\n"
     ]
    }
   ],
   "source": [
    "class InceptionBlock(brainstate.graph.Node):\n",
    "    \"\"\"Inception-style block with multiple parallel paths.\"\"\"\n",
    "    \n",
    "    def __init__(self, in_dim, out_dim):\n",
    "        super().__init__()\n",
    "        # Path 1: 1x1 (direct projection)\n",
    "        self.path1 = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        \n",
    "        # Path 2: 1x1 -> 3x3 (simulated with linear)\n",
    "        self.path2_a = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        self.path2_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)\n",
    "        \n",
    "        # Path 3: 1x1 -> 5x5 (simulated)\n",
    "        self.path3_a = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "        self.path3_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)\n",
    "        \n",
    "        # Path 4: pool -> 1x1\n",
    "        self.path4 = brainstate.nn.Linear(in_dim, out_dim // 4)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Execute all paths in parallel\n",
    "        out1 = jax.nn.relu(self.path1(x))\n",
    "        \n",
    "        out2 = jax.nn.relu(self.path2_a(x))\n",
    "        out2 = jax.nn.relu(self.path2_b(out2))\n",
    "        \n",
    "        out3 = jax.nn.relu(self.path3_a(x))\n",
    "        out3 = jax.nn.relu(self.path3_b(out3))\n",
    "        \n",
    "        out4 = jax.nn.relu(self.path4(x))\n",
    "        \n",
    "        # Concatenate outputs\n",
    "        return jnp.concatenate([out1, out2, out3, out4], axis=-1)\n",
    "\n",
    "\n",
    "# Test Inception block\n",
    "inception = InceptionBlock(in_dim=16, out_dim=32)\n",
    "x = brainstate.random.randn(4, 16)\n",
    "output = inception(x)\n",
    "\n",
    "print(f\"Inception Block:\")\n",
    "print(f\"  Input shape: {x.shape}\")\n",
    "print(f\"  Output shape: {output.shape}\")\n",
    "print(f\"  4 parallel paths with concatenated outputs\")\n",
    "\n",
    "# Count parameters\n",
    "params = brainstate.graph.states(inception, brainstate.ParamState)\n",
    "total_params = sum(jnp.size(p.value) for p in params.values() if hasattr(p.value, 'shape'))\n",
    "print(f\"  Total parameters: {total_params:,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0aface45892c8ab",
   "metadata": {},
   "source": [
    "## Graph Analysis and Statistics\n",
    "\n",
    "Let's create a utility function to analyze graph structures:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "f0810b6995fbed70",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-10T15:54:22.915378Z",
     "start_time": "2025-10-10T15:54:22.721597Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph Analysis Comparison:\n",
      "================================================================================\n",
      "Model           Nodes    Depth    Param Tensors   Total Params   \n",
      "--------------------------------------------------------------------------------\n",
      "MLP             4        1        6               0              \n",
      "ResNet          16       2        24              0              \n",
      "Inception       7        1        12              0              \n",
      "\n",
      "================================================================================\n"
     ]
    }
   ],
   "source": [
    "def analyze_graph(node):\n",
    "    \"\"\"Analyze computation graph statistics.\"\"\"\n",
    "    stats = {\n",
    "        'total_nodes': 0,\n",
    "        'total_params': 0,\n",
    "        'param_tensors': 0,\n",
    "        'node_types': {},\n",
    "        'depth': 0\n",
    "    }\n",
    "    \n",
    "    def traverse(n, depth=0):\n",
    "        stats['total_nodes'] += 1\n",
    "        stats['depth'] = max(stats['depth'], depth)\n",
    "        \n",
    "        # Count node type\n",
    "        node_type = n.__class__.__name__\n",
    "        stats['node_types'][node_type] = stats['node_types'].get(node_type, 0) + 1\n",
    "        \n",
    "        # Count parameters\n",
    "        params = brainstate.graph.states(n, brainstate.ParamState)\n",
    "        stats['param_tensors'] += len(params)\n",
    "        for p in params.values():\n",
    "            if hasattr(p.value, 'shape'):\n",
    "                stats['total_params'] += jnp.size(p.value)\n",
    "        \n",
    "        # Traverse children\n",
    "        for _, child in brainstate.graph.iter_node(n):\n",
    "            if child is not n:\n",
    "                traverse(child, depth + 1)\n",
    "    \n",
    "    traverse(node)\n",
    "    return stats\n",
    "\n",
    "\n",
    "# Analyze different models\n",
    "models = {\n",
    "    'MLP': MLP(),\n",
    "    'ResNet': ResNet(dim=10, n_blocks=3),\n",
    "    'Inception': InceptionBlock(in_dim=16, out_dim=32)\n",
    "}\n",
    "\n",
    "print(\"Graph Analysis Comparison:\")\n",
    "print(\"=\" * 80)\n",
    "print(f\"{'Model':<15} {'Nodes':<8} {'Depth':<8} {'Param Tensors':<15} {'Total Params':<15}\")\n",
    "print(\"-\" * 80)\n",
    "\n",
    "for name, model in models.items():\n",
    "    stats = analyze_graph(model)\n",
    "    print(f\"{name:<15} {stats['total_nodes']:<8} {stats['depth']:<8} \"\n",
    "          f\"{stats['param_tensors']:<15} {stats['total_params']:<15,}\")\n",
    "\n",
    "print(\"\\n\" + \"=\" * 80)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "840d30c8b943804f",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "`PyGraph` is a data structure specifically designed in the `brainstate` library to provide JAX function transformation support for complex graph structures. Its core value lies in:\n",
    "\n",
    "### 1. Flexible Expression of Graph Structures\n",
    "- Compared to traditional tree structures (pytree), graph structures can represent more complex node relationships and dependencies\n",
    "- Supports cyclic references, making it suitable for expressing complex model structures like RNNs\n",
    "- Nodes can contain arbitrary pytree array data or subgraph structures\n",
    "\n",
    "### 2. Seamless Integration with the JAX Ecosystem\n",
    "- Supports JAX's core functionalities, including automatic differentiation, vectorization, and parallelization\n",
    "- Provides a mechanism for converting between graph structures and pytree via `treefy_split` and `treefy_merge`\n",
    "- Natively supports JAX function transformations such as `jit`, `vmap`, and `grad`\n",
    "\n",
    "### 3. Rich API for Graph Operations\n",
    "- **Structure operations**: `graphdef`, `iter_node`, `iter_leaf`, `nodes`, `states`\n",
    "- **Transformations**: `treefy_states`, `clone`, `treefy_split`, `treefy_merge`, `flatten`, `unflatten`\n",
    "- **Modifications**: `pop_states`, `update_states`\n",
    "- **Conversions**: `graph_to_tree`, `tree_to_graph`\n",
    "\n",
    "### 4. Powerful Filter System\n",
    "- Filter by type: `brainstate.ParamState`, `brainstate.ShortTermState`, etc.\n",
    "- Combine filters: Use tuples/lists for \"any\" logic\n",
    "- Match everything: `...` or `True`\n",
    "- Match nothing: `None` or `False`\n",
    "\n",
    "### 5. Structure preserving state management\n",
    "- Separate graph structure (`graphdef`) from state values (`treefy_states`)\n",
    "- Retrieve states according to state filters\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
