{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f6",
   "metadata": {},
   "source": [
    "# Functional API and Graph Operations\n",
    "\n",
    "This tutorial shows how to use BrainState's functional API to manage model state explicitly through graph operations. Separating a model's structure from its states gives you fine-grained control over which states are read and updated, which is particularly useful for advanced use cases such as custom training loops and JAX functional transformations (for example, `jax.jit` and `jax.grad`).\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "By the end of this tutorial, you will:\n",
    "- Understand the functional API and graph operations in BrainState\n",
    "- Learn how to split and merge model states using `treefy_split` and `treefy_merge`\n",
    "- Build a training loop with explicit state management\n",
    "- Apply JAX transformations with separated states\n",
    "- Track custom states (e.g., function call counts)"
   ]
  },
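  {
   "cell_type": "markdown",
   "id": "9e7d1c0b4a52",
   "metadata": {},
   "source": [
    "The pattern this tutorial builds toward can be sketched first in plain JAX. The block below is an illustrative analogue, not BrainState's API: the model's structure lives in ordinary functions, its state (trainable parameters plus a call counter) lives in an explicit pytree, and the jitted step function takes the state in and returns an updated copy instead of mutating anything. The names `init_state`, `forward`, `loss_fn`, `train_step`, and the `'count'` entry are hypothetical. In BrainState, `treefy_split` and `treefy_merge` serve to move between a stateful module and this kind of (structure, state) pair.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "LEARNING_RATE = 0.1\n",
    "\n",
    "def init_state():\n",
    "    # Hypothetical state pytree: trainable parameters plus a non-trainable call counter.\n",
    "    return {\n",
    "        'params': {'w': jnp.zeros(3), 'b': jnp.array(0.0)},\n",
    "        'count': jnp.array(0),\n",
    "    }\n",
    "\n",
    "def forward(params, x):\n",
    "    # Cubic polynomial features [x, x^2, x^3] followed by a linear readout.\n",
    "    feats = jnp.stack([x, x**2, x**3], axis=-1)\n",
    "    return feats @ params['w'] + params['b']\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "    # Mean squared error between predictions and targets.\n",
    "    return jnp.mean((forward(params, x) - y) ** 2)\n",
    "\n",
    "@jax.jit\n",
    "def train_step(state, x, y):\n",
    "    # Differentiate only with respect to the trainable parameters.\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(state['params'], x, y)\n",
    "    new_params = jax.tree_util.tree_map(\n",
    "        lambda p, g: p - LEARNING_RATE * g, state['params'], grads\n",
    "    )\n",
    "    # Nothing is mutated in place: a new state pytree is returned.\n",
    "    return {'params': new_params, 'count': state['count'] + 1}, loss\n",
    "\n",
    "# Usage: fit y = 0.5 x^3 - x + 0.2 on a grid of 64 points.\n",
    "x = jnp.linspace(-1.0, 1.0, 64)\n",
    "y = 0.5 * x**3 - x + 0.2\n",
    "state = init_state()\n",
    "for _ in range(500):\n",
    "    state, loss = train_step(state, x, y)\n",
    "print(int(state['count']), float(loss))\n",
    "```\n",
    "\n",
    "Because `train_step` is a pure function of its inputs, JAX transformations such as `jax.jit` apply to it directly, and the call counter advances only because each returned state is fed back into the next call."
   ]
  },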
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5f6a7",
   "metadata": {},
   "source": [
    "## Setup and Imports\n",
    "\n",
    "First, let's import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c3d4e5f6a7b8",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.179979Z",
     "start_time": "2025-10-11T10:20:17.348456Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import brainstate\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e5f6a7b8c9",
   "metadata": {},
   "source": [
    "## Problem: Polynomial Regression\n",
    "\n",
    "We'll solve a simple polynomial regression problem to demonstrate the functional API. Let's create a synthetic dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e5f6a7b8c9d0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.284790Z",
     "start_time": "2025-10-11T10:20:19.183144Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAArMAAAHWCAYAAABkNgFvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAYzFJREFUeJzt3QmYXGWV//HT3dVr9ZKtO52EYMgCARJAQRiCimgg/lWUcXQYYFgiorINAy6ssiibiAyjICgjwowiiIIbyG5GERwUBEkkQBJCgKSXbN2d7vRa9X9+b7xNdXf1Ut213Fv3+3mepqhKLbfqreXcc8973oJ4PB43AAAAIIAKc70BAAAAwHgRzAIAACCwCGYBAAAQWASzAAAACCyCWQAAAAQWwSwAAAACi2AWAAAAgUUwCwAAgMAimAUAAEBgEcwCObZ+/XorKCiw66+/PiuP9/73v9/9IbP+9Kc/2ZIlSywajbrxff755y0oTjnlFJszZ47lgxUrVrjXX6epuvzyy91tMTZ8tyBXCGYROi+++KJ98pOftHe84x1WVlZms2bNsiOPPNK+/e1vZ/RxH3zwQffjmA1/+9vf3GMpUM4k/XDpx15/hYWFVl1dbXvttZedeOKJ9uijj07ovr/zne/YHXfcYX6wceNG93qONSDt6emxT33qU7Z161b7j//4D/uf//kf937LdMDm/RUXF9vcuXPtpJNOsnXr1mXscfF28J/4+peWltqee+5pl156qXV2duZ684C8F8n1BgDZ9NRTT9kRRxxhu+++u5122mlWX19vb7zxhv3xj3+0//zP/7Szzz47o8HszTffnJWAVsHsFVdc4YLNwRm2Rx55JK2Ptdtuu9k111zj/r+9vd3WrFlj9913n/3whz+0f/7nf3anCq7GE8xOmzbNBQp+CGb1euq1POCAA0a9/tq1a+3111+32267zT7zmc9Ytvzbv/2bvfvd73bB9HPPPWff+9737IEHHnA7cDNnzrQwet/73mc7d+60kpKSjD6OAtj/+q//cv/f0tJiv/jFL+xrX/uaey/86Ec/sjBI93cLMFYEswiVq666ympqatwh4EmTJg34t6amJguDdP+o6/X813/91wGXXXvttS6wUkCqAPDrX/+6hYn3Xhr8HpsI7SioZGEk733ve91RB1m+fLnLDmoc7rzzTrvwwgstjHTEQEdgMi0SiQz4HJxxxhmuzOTHP/6x3XDDDTZ9+nTLlt7eXovFYhkP4AfL9uMBHsoMECrKkuy7775Jg4y6urr+/z/88MNt//33T3ofOoy+bNmyIfWuyoLNmzfPZWiUHVPA7FF2UVlZSTwcOdhI9+FZvXq1C1imTJnifqQPOugg++Uvf9n/7zo0r0Pcoiy091hezWCyujYdClXGWMGP7nPGjBn2iU98wr1e41FUVGTf+ta3bJ999rGbbrrJZao8P/jBD+wDH/iAe731PHWdW265ZcDtFQCvWrXK/vd//7d/+71t1qH7L37xi7Z48WKrrKx0pQ3/7//9P3vhhReGbIdKRzTeFRUVNnnyZPda3XXXXQOu89Zbb9mnP/1pF2xoe3T922+/vf/f9bppLLwA0due4UogNNZ6/4jGIXHb5YknnnBBpwJTvQ8//vGP20svvZS0VlMZ9uOPP95t+3ve8x5LlV5nee211/ov0w6GnqOeq7K1Z555pm3fvn3Y+4jH4248tJ2D6X2jnZnPfe5zA8odfvKTn7gdR2Xt9X764Ac/6DL2g91777124IEHWnl5ucvCKxjUeAx+PTXOGzZssI9+9KPu/1Ua5H2elHXW89TrqVKOweObrGb297//vRsbHaHR6zB79mw799xzXQY3XfSYGjO9foNLPX7zm9/0vweqqqrsIx/5iHu/J3t99PnQa7ho0SK7//77h9QzJ34H3Xjjjf3fH3rvjOX7QpTJ15GHBQsWuOtMnTrVbXtiqVBDQ4N7/2tMdf/6jtB7IrGUKdl3
i3bsTj31VPf50n3re1U7V4nG+j0KDIfMLEJFP3ZPP/20rVy50v04DEc1nypDGHw9fbG+8sordskllwy4vn5A29ra3I+6vpSvu+46FwzqR0yH2HW5DlXrx0H1k8mMdh+iH7zDDjvM/ZhfcMEF7sdQgcMxxxxjP/vZz+wf//Ef3WFVZeMUTF500UW29957u9t6p4P19fW5IOHxxx+3f/mXf7FzzjnHbYe2Vc9fPyzjDWiPO+44+8pXvmJPPvmk+8EWBa4Kpj72sY+5bNavfvUrl8VSJkmBlehHWSUfClwuvvhid5mX2dLr8fOf/9wFI3vssYc1Njbad7/7XRdA6gfcO5yuQ/x6HfRDruekwOuvf/2r/d///Z8LEEW3/Yd/+Af3ep911llWW1vrAg39+La2ttq///u/u9ftq1/9qqt//OxnP+uCEFHWLRmNn8bn6quv7j/s7237Y4895gJv1bMqYFXwpIBbY6qygMElIXqOCjB0XwqKUuXtjCg4ET2mgpalS5fa6aefbi+//LIbD72v//CHPyQtB9FroyBT70ftSCgo8mjs9Doly8wrI6qdDu3I6LYnnHCCe+092hlQcKTXR2UqGguV+mg7/vKXvwzY4dR7VK+b3tu6Lx2213jp/a/3h+5bn5Vbb73V1Qkfeuih7r0xHAWJHR0d7jXQa/PMM8+4cXjzzTfdv6WLF+hpZ8Sjz//JJ5/sdoh1xELboTFQ8Kjn7b0HVB5y7LHHup02vT7btm1z70u9t5LRTqLe43qPKhDUOI3l+8J7X+gxVBJz8MEHuzH985//7N6Tmk8g//RP/+TuT59LbaOCVH1HaCdjuMmCen8ruNWOjMZLY6LXVwG5dqD0uUz1OxBIKg6EyCOPPBIvKipyf4ceemj8y1/+cvzhhx+Od3d3D7je9u3b42VlZfHzzz9/wOX/9m//Fo9Go/EdO3a486+99poijPjUqVPjW7du7b/eL37xC3f5r371q/7LzjzzTHfZYKncxwc/+MH44sWL452dnf2XxWKx+JIlS+ILFizov+zee+91t/3tb3875PEOP/xw9+e5/fbb3XVvuOGGIdfVfY9E97PvvvsO++/333+/u+///M//7L+so6NjyPWWLVsWnzt37oDLdL+J2+nRc+/r6xvyGpaWlsa/+tWv9l/28Y9/fMRtk1NPPTU+Y8aM+ObNmwdc/i//8i/xmpqa/m3905/+5J7HD37wg/hY6HXX9TUOiQ444IB4XV1dfMuWLf2XvfDCC/HCwsL4SSed1H/ZZZdd5m5/3HHHpfR4Gsvm5ub4xo0b4w888EB8zpw58YKCArf9TU1N8ZKSkvhRRx014PW76aab+m/rOfnkk+PveMc7+s+//PLL7jq33HLLgMf92Mc+5h7De59427H33nvHu7q6+q+n8dflL774ojuvz5teh0WLFsV37tzZf71f//rX7nqXXnrpgG3RZVdffXX/Zdu2bYuXl5e753b33Xf3X7569Wp3Xb1+g1+bxM9CsvfgNddc4+7v9ddfHzIOo9E26ntBr73+1qxZE7/++uvd/ek5eq9PW1tbfNKkSfHTTjttwO0bGhrc+y3xcn3Od9ttN3cbz4oVK9z2JI6N9/1RXV3txjjRWL8v9t9///hHPvKRYZ+fXm89xje+8Y0RX4fB3y033niju90Pf/jD/ss09vruraysjLe2tqb8HQgkQ5kBQkVZBmVmlRXUYWnt+StDosxF4qE3HTrVITTVu3kZMWWH7rnnHpfVGFy7qAxKYvbFy96lMpN8tPtQVkyHqDWpStmLzZs3u78tW7a45/Dqq68OOUQ7FsrQ6BBvsslvE21LpMyqaHs9OqTsUdZOz0FZVT3PxHKE4SjrpKyfNyZ6/noclX8ok+RRZk+ZtuEOU2pc9dyPPvpo9//e66k/vZ7alsT7m6hNmza5bgjKSiVmN/fbbz/3vtQEwcE+//nPp/QYKpdQdlnZaWXCVWerQ7o6tKyscHd3t8s2e6+f6AiESjWU
CRyOyk8OOeSQAROZ9H5UFltZ0cHvE2VcE+snB7+XlfVTZk8Z+cR6Vm3zwoULk25L4kQ6ja3GW59DfR48ukz/NtrnLvE9qNdIY65Mu94Hyo6Oh+5Hr73+5s+f77LSyopqIpj3+iiTqYykjlgkvt90FEOv729/+1t3PR3FUfmEsszeZ0j0OVGmNhllTvXYnlS+L/SaKeuqy4Z7vTSeKtVQhnis9J7WJFs9X48yrDpisWPHDldGlO7vUYQTwSxCR4c1NdteX8o6vKiJMfqy1+For85M9EOiQ2iqrxMFAzoUqhKEwVR7l8j7Qk7li3+0+9ChOv3Y6rC996Pp/V122WXjnsSmQ9EKAnTIP930gyWqC/ToMLIOc3s1o9p+lUPIWIJZlSOo3ZUOvyuwVSCu+1AJQeLtzz//fBcI6LCprqsSBj22p7m52QUWqtEb/HoqGEv3pEB1NxC91oOplEGBhgKiRCMdKk9GpRAKmBTE6PVQUOS9X4d7fAUpKnvw/n04+jzo9fOup8PFqrUcz+dhpNdCwezgbVHAmxioeTucqt8cHEjr8tE+d/pcezsVeo/ovr0657G8B5PRNuq1158O+WtM9f5JDJy9YFE1voPfc+oE4L3fvOevoHiwZJcle6+k8n2hMhp9FrTTomD5S1/6knv/ePQ5U0mEdl5UMuOVe6iOdiR6HvrsJe48JZY8DR7ndHyPIpyomUVo6Udcga3+9CWuAEY/0N4XvbIX+uJWayl9eetUWQYFYoMps5JMKnWOo92HgjhRxsebgDbWH7pcUc1t4nYpcNZkIAUsmuGtiTcaB2VwFKB6z3Ekqh/VD7SykGp9pIBEP5bKOCbeXj+Yqgn99a9/bQ899JDLwmrykwI+1Y1611W9p2oYk1HWNJcSA6GxUCCS7P2ZDqqn1iQpZWe186HPgzK+yQLSdHwexnJ/43kcZfOVCVfmUjs8ei9qx0pZSgW4Y3kPDrctia+9PqO6b9V/ekd9vPtW3ay+SwabyA7l4PdKKt8X+n7TZ1NZZAXVajGmz6NqkL2MuD5fOoqhevWHH37YfQZVZ6sdp3e+852WDul+3yA8CGYBM/ej7B0KTvxi1UQhTVRRVkJf4jokO9wX7mgmeshe2TPvMN1oAUsqj6UJXpqYoyxbOidZKGjQhA51EvBm4mvCUFdXl/txT8zCeIdXx/IcfvrTn7ouDd///vcHXK7MkrK0iRSk6NCl/nSIXZNJNMte2XhlqJQx1nam8/UcjrdoggLswTTjXNs+WuutdD2+914SvS7qdjDaa6CdBpUBKJhVaYGytJqoN9Ft8ToueHRZJheY0OF7TeJU+YWyzZ6JLvIxmGb7K/jXjpP6WGuioTeZUp08Rnq9veefrANEsssm+n3hja926PWnIyoKcDUxLLG8Q9v/hS98wf0py6yey9/85jfdjs1wz0MZXgXWidlZvd8TnycwUZQZIFQUNCXby/fqFQdnmXQIVYe4lF3RF/zgWdup8AKVkdogjUQ/gJoZrJn7iUF34mHz8TyWau10iFsttNKVEVGAqLo4tZzSqWoyxdsRSLxfHdbVYdnB9BySbb/uY/B2KaM+uF5YtYGJlAFWmyPdVoG77kfPXRlbL4OcjtdzpOBGP/4KohLvR4+tbNiHP/xhyyQFNHoN1OUi8fXTToHGwOs2MRJ9HlSKo8PQev2UrR3vzqPez8r8aefGo8PYes+MZVvGK9l7UP+vTgrppjp07cypu4MoQ6rPgo4u6D043HtONc/qovLf//3f/aU6ohpTBePp/r4Y/FlR6YWytt7YqOPC4JXMFNhqZzBx/AbTe1qlCJprkNgDV50j9BheaQcwUWRmESr6cdEXs1rS6BCgslJaFUxftmov49VKenT4TD8qCpZ02Ppd73rXuB9b/TRFwZ1+1MYTDKi3prKcOpysLLGyL6rj1aQ2TXbyeq0qaNL9K6OsQEU1b15v18GUndKP5nnnnedqiDXpQrWbqhHWBJ1k/UUT6f69zIxeW28F
MB221PNTKYDnqKOOcgGVDld6OwhqoaXtGvyDq9dLLYuuvPJK98Oq6+g5qI2Yavw0Vpq0ox93ZQsTs43eY+lQribhqFxEQZICdgVKXg2vggzt4GjyjV5PBbs6/KyJX3r++n/vh1v1vQq+dFsFt7pNqjWt3/jGN1yLKbWOUpslrzWX6jwzvTKcMtHKSCtT+KEPfchNglQWVKUXKrUZy46aXju1stLnQc8j2ftpLJQt1HtTY6iARhOEvNZc+hwqo5kp+txrPHX4XTtACi61Q5OJuky9VnqOeo31/tN3iN7T2inQd4k+HxoX1fBq0pveq95OpQJeffZ0me5D26d/0/dRYoCbju8Lve8V+OozpwytJujpCIjaaYky2SoP0mQyXVflEOp5q/sa6TtMbcIUTKt849lnn3Vjq/v1svqJtfTAhCTtcQDkqd/85jfxT3/60/GFCxe61jBqVTR//vz42WefHW9sbEx6m+uuu25IayCP11ImWcuawS2Cent73ePU1ta6lj3exy+V+5C1a9e6Nk719fXx4uLi+KxZs+If/ehH4z/96U8HXO+2225z7a7UhiyxNdHg9jleq6KLL744vscee7j71H1/8pOfdI81Et2P7tv702uqlj//+q//6tqgJfPLX/4yvt9++7nWZ2rr9PWvf72/PZhei8R2RWoXVFVV5f7N22a1GfrCF77gWmqpPdNhhx0Wf/rpp4c8r+9+97vx973vfa7dj9p2zZs3L/6lL30p3tLSMmB7NO5qmzZ79uz+566WRt/73vcGXE9tgvbZZ594JBIZtU3XcK255LHHHnPbrG1XO6Wjjz46/re//W3AdbyWUGrzNBYjPd5gasWl97+e6/Tp0+Onn366a72UaHBrrkRnnHGGe6y77rprzNvhvccHv2b33HNP/J3vfKcbnylTpsRPOOGE+Jtvvpm07dVY28JpuxPbTCVrzaXXe+nSpe79Om3aNNcSSy3SBm9jqq25ktFnSJ9BXSdxm9SOTu249DnQe/OUU06J//nPfx5wW7Ud01jp9VGLL312/umf/sld5hnp+2Os3xdXXnll/OCDD3Ztw/S+1P1fddVV/S0L1bpOnxFdruep7T7kkEPiP/nJTwY8VrLvFn2+li9f7l5nfd+qVdjg90Gq34HAYAX6z8TCYSC/KVukTJEaoA+ebQuEjT4LKk3Q4WMdQkd26aiLsrnprvEFgoyaWWAE2tfTD7cOhRLIIuxUN6mSEtUaE8hmlmpqVV+aSH1eVRoweMlYIOyomQWSUM2oZtyrnlI1mWpZA4SV+pGqhlj1jposNHgZUqSf6nk1aU+1zJoQpg4AqtlWHXiqi2kA+Y5gFkhCM33VlkuTftRTU5NlgLBSBwO149KEL3VD0KFuZJYWDNCELPV81feRJh1qAp4mLWpiGYC3UTMLAACAwKJmFgAAAIFFMAsAAIDACl3NrJbV27hxo2vWnI4lKgEAAJBeqoJta2tzEyATl0NOJnTBrALZ2bNn53ozAAAAMIo33njDdttttxGvE7pg1ls+Ty+Ot158pjPBmomqJtej7VnAnxjD4GMMg48xDDbGL/hiWR7D1tZWl3wcy7LHoQtmvdICBbLZCmbVaFyPxQc4mBjD4GMMg48xDDbGL/hiORrDsZSE8o4CAABAYBHMAgAAILAIZgEAABBYoauZHWs7iN7eXuvr60tLjUlPT4+rMwlTnVBRUZFFIhHanwEAgIwimB2ku7vbNm3aZB0dHWkLjBXQqlda2AK7iooKmzFjhpWUlOR6UwAAQJ4imE2goPO1115zWUU16VUQNtEA1MvyhilLqeesnQK18NDruWDBglBlpQEAQPYQzCZQAKaAVn3NlFVMhzAGs1JeXm7FxcX2+uuvu9e1rKws15sEAADyEOmyJMgipgevIwAAyDSiDQAAAAQWZQYAAAAhE4vF7a3tO629u9eiJRGbNancCguDWQ5JMAsAABAia5ra7OGVjba2eYd19vZZWaTI5tVW2rJF021+
XZUFDWUGeeKUU05xE8z0p4lX06dPtyOPPNJuv/12N6ltrO644w6bNGlSRrcVAADkLpD9wR/W28qNLTapotjmTqt0pzqvy/XvQUMwm8H0/RtbO2x1Q6u9uW2nO59pH/rQh1yP3PXr19tvfvMbO+KII+ycc86xj370o66jAgAACK9YLO4yslvbu21BXaVVlRVbUWGBO9V5Xf7IqsasxCzpRDCbAdqruWXFWvuPR1+xbz++xv7z8TV2y/+uzfjeTmlpqdXX19usWbPsXe96l1100UX2i1/8wgW2yrjKDTfcYIsXL7ZoNOpakJ1xxhm2Y8cO928rVqyw5cuXW0tLS3+W9/LLL3f/9j//8z920EEHWVVVlXuM448/3pqamjL6fAAAQPq8tX2nKy2YUVM2pF2oztdXl9oLb2y3373a7BJyQQlqCWYznb6vjdrkimJbtbE1J+n7D3zgA7b//vvbfffd198u61vf+patWrXK7rzzTnviiSfsy1/+svu3JUuW2I033mjV1dUuw6u/L37xi+7ftCTv1772NXvhhRfs5z//ucv+qrQBAAAE40jx2uYdtrOnzypKhk6ZUlb2b5taXfzy/SfXuYScEnNBKDtgAlgG0/e79nriVlkWsQXlxfZqU7tL36s+JZszBhcuXGh//etf3f//+7//e//lc+bMsSuvvNI+//nP23e+8x234llNTc2uvbP6+gH38elPf7r//+fOnesC4ne/+90uq1tZWZm15wIAAMY30auvL25vbOuw8uJCmz0l2n89xS3Pv7HdWnf2WFlxke0xtdIiRQUusN3YstOWHzbH5k57+/p+Q2Y2i+l7Xb6maYe7XjZpFTJvex577DH74Ac/6EoRVDJw4okn2pYtW6yjo2PE+3j22Wft6KOPtt13393d7vDDD3eXb9iwISvPAQAATGyi18xJZdYbi9uf1m+zLTs6+2MExSYd3b0WKTSbXl3mrp9YR/vwygZ7Y2u7vbmtw/35rfyAYDaN1KtNez7J0vdSXlJkXb197nrZ9NJLL9kee+zhSgM0GWy//fazn/3sZy5Avfnmm911tOTscNrb223ZsmWu/OBHP/qR/elPf7L7779/1NsBAIDhD/tnqi41NsxEr+ryEjt4zhR3nWde22atO7ttW0e3NbV1Wl9fzCpKIzavNtqfANOpsrgPvNhg1/xmtf3qhY1242Ov+q78gDKDNIqWRFyvNu3d6I0z2M7uPiuNFLnrZYtqYl988UU799xzXfCqNl3f/OY3+5ea/clPfjLg+io16OvrG3DZ6tWrXfb22muvdZPG5M9//nPWngMAAPkgW/1d3xrhSPHUylJ795zJtrphh21s6XQxS2d3n82eWuEC3ynR0v7rKhh+uXGHbW3vsoXTo1ZfXWZFvcUDyg/80JeWYDaNtHqG3pQa5MrSyIA3kNL4m1o6bfGsGne9TOjq6rKGhgYXjDY2NtpDDz1k11xzjcvGnnTSSbZy5Uo3kevb3/62Kxn4wx/+YLfeeuuA+1AdrepgH3/8cTdxrKKiwpUWKMjV7VRfq/vRZDAAAJDaYX8FiAoyK0rKXSCZicCwvf9IcfJ4Y8akcuvs6bN/fvfu7vyPn9lgM2vKrbr87UScV36wo7PXasqLraa8xAoLulyyrrJM84B25GQeUDKUGaSRBlN7V1OiJW6Q2zp7XG1KW2evm/yly4/ad3rGBl3B64wZM1xAqp6zv/3tb91ELbXnKioqcsGpWnN9/etft0WLFrmSAQW7idTRQAHrsccea7W1tXbddde5U7X2uvfee22fffZxGdrrr78+I88BAIB8k+3+rtGEI8XJ6EhxWbFKCirtfQtqbb9Zk6yhtdMFsB7FLtvau6zA4i6bW1UW8cU8oGQK4olbHgKtra1uxr56qaoGNFFnZ6e99tprrr60rKwsbYcRSgoLbEF9tS3bN5jLxI1Xul7PXFNphnrq1tXV9ZdnIFgYw+BjDIMt7OOn2li1uvImVg2m5Nf2jh4798g9
bfaUigk/XiwWd3Wtyvq+3V1pF4V9SrjpSPHnD5/nEmyDs8aa4/PWtp329LotNq2yxN61+2SbUlFiFX07rKOoUtGs9cZitn5zu539wQW2sH5gPJXpeG0wygwyQAHr3PdXur2VHV09VlZUYLtPrbSiovB9gAEACLvRDvsreGxs7UzbBPHCvx8pVvmCAlcvQFVGViWPg48UK25RmYOXiNO29PbFbVq0xPaaXrWrjnZQ7jMX84CGk/styFN6g2jvSntAWko21/UkAAAgN6I5mCA+P0mAqsdQRvaoJEeKExNxCqrLi4tc9wIt+uRafCZcNxvzgFJBMAsAAJCHE8TnDwpQoyUR9xjDJdi8RJznQ4vq3ba57G51qZVGNA+oxza1dmV8HlAqCGYBAAAyKNXD/ul+7NnjrMMdmN1ts1i801oLhs/u5grBbBIhmxOXMbyOAACM77C/X8z/e3b3zW3t/ZP4dpsc9UVG1kMwm6C4eFcdi5Z2LS/PfQ1I0HlL5HqvKwAAYZbqYX+/KCwssN0mV1hJT4XVTa7w3fYSzCZQL9ZJkya5PQ/RggGDV85IlTcBLBIZWCOTz/ScFcjqddTrqdcVAABM7LA/kiOYHaS+vt6degFtOgI79ddTX72wBLMeBbLe6wkAAJAJBLODuFUtZsxwNSFa+nWiFMhu2bLFpk6dGqpG0SotICMLAAAyjWB2GArE0hGMKZhVYKcVsMIUzAIAAGQD0RUAAAACi2AWAAAAgUUwCwAAgMAimAUAAEBgEcwCAAAgsAhmAQAAEFgEswAAAAgsglkAAAAEFsEsAAAAAotgFgAAAIFFMAsAAIDAIpgFAABAYOU8mL355pttzpw5VlZWZocccog988wzI17/xhtvtL322svKy8tt9uzZdu6551pnZ2fWthcAAGRPLBa3N7Z22OqGVneq837i9+0Lg0guH/yee+6x8847z2699VYXyCpQXbZsmb388stWV1c35Pp33XWXXXDBBXb77bfbkiVL7JVXXrFTTjnFCgoK7IYbbsjJcwAAAJmxpqnNHl7ZaGubd1hnb5+VRYpsXm2lLVs03ebXVU34/hV4vrV9p7V391q0JGKzJpVbYWGBb7YPAQhmFYCedtpptnz5cndeQe0DDzzgglUFrYM99dRTdthhh9nxxx/vziuje9xxx9n//d//ZX3bAQBA5ihQ/MEf1tvW9m6bUVNmFSXl1tHdays3ttjGlp22/LA5EwoYJxqIZnr7EIBgtru725599lm78MIL+y8rLCy0pUuX2tNPP530NsrG/vCHP3SlCAcffLCtW7fOHnzwQTvxxBOHfZyuri7352ltbXWnsVjM/WWaHiMej2flsZAZjGHwMYbBxxiGa/yUMX34xQbb1t5lC2or3RFYqSqNWGVt1NY077BHVjbYnPdVpJRJ9axtbrM7nnrdtrV3W321AtEyF4iu2rjdNrZ02ClL3mHzaqtytn1+FMvyZzCVx8lZMLt582br6+uz6dOnD7hc51evXp30NsrI6nbvec973Ava29trn//85+2iiy4a9nGuueYau+KKK4Zc3tzcnJVaWw1GS0uL214F6wgexjD4GMPgYwzDNX7NbV22fWuz7VkdsfJY+5B/37M6Ztu2NNtLr5VYbVVpitsSt9+tbLCiznbbb2qZFViXWbzLqorN6qaabdzear9/YZ1FF9UPG4hmcvv8Kpblz2BbW1swygxStWLFCrv66qvtO9/5jquxXbNmjZ1zzjn2ta99zb7yla8kvY0yv6rLTczMauJYbW2tVVdXZ2Xwtcemx+MLOJgYw+BjDIOPMQzX+G2Lt9rG7mbbo6rSOpIElL0FMdvU1m6Ryhqrq0vtt/zNbR22aluDTYpOsp1FxUP+PRIts5XbemxpaZXtNrki69vnV7EsfwbVGMD3wey0adOsqKjIGhsbB1yu8/X19Ulvo4BVJQWf+cxn3PnFixdbe3u7ffazn7WLL7446YtbWlrq/gbTdbP1hajBz+bjIf0Y
w+BjDIOPMQzP+FWWllhpJGIdPX1WVTY04NzZE7OSSMRdL9X3Q0dPzDp7Y1ZRWqyNGvLv5aURa2zrctcb7r4zuX3ZmLgWhM9gKo+Rs2C2pKTEDjzwQHv88cftmGOO6Y/6df6ss85KepuOjo4hT04BsSjtDQAAgk/BmSZjaTJVZWmkvybV+73f1NJpi2fVuOulKloScZO9VCObNBDt7rPSSJG7Xi62bzR0UPBZmYEO/5988sl20EEHuQldas2lTKvX3eCkk06yWbNmubpXOfroo10HhHe+8539ZQbK1upyL6gFAADBpiyjgjN1BXi1aYfrFlBeUuQCTQWKU6IldtS+08eVjUxHIJrJ7RsJHRR8GMwee+yxbiLWpZdeag0NDXbAAQfYQw891D8pbMOGDQMysZdccol70+n0rbfecnUbCmSvuuqqHD4LAACQbgrKFJx5WcjG1k6XMVWgqUBxvEFbugLRTG3fcOUE5cVF9tDKBhfILqhL6KBQVuyCcj2XR1Y12txplXnTQWGsCuIhOz6vCWA1NTVuRl62JoA1NTW5RSCo8womxjD4GMPgYwzDO36Zqg9NPFzf1burtGB+XWXKgehw25fuBRn6+uL2xrYOW1hfZbOnRIdcv62zx7Z39Ni5R+5ps6ckn7gWpM9gKvFaoLoZAACAcFEAmIngTAHr3PdXTjhQTrZ9mViQQV0YtrR328uNbRYtjdiU6MDJ7eUlRS47rOcSNgSzAAAglDIRKE+0rtUtyLCycUg5weSKEptcXmztnb22trndnU+s9905holr+YpjNQAAAGkwOBBVPWtRYYE71XldrrpWXW84yhQro6tAODFYrSqL2ORoqcWtwLbs6LK2zt4hE9fm11VmpIOC3xHMAgAApMFwgajovC5f07TDXW84KhNQaULFoAyrbq9gtbIsYi07e2xbR7f1xmKuVlaTvzLVQSEIwpeLBgAAyIC3A9Hk2dGx1LVGR+iDq4B1r+mVtjquhRn6bP3m9rR3UAgiglkAAIA0iGZhQQatLvaRxfV29P4z3Qpk0SyuAOZXBLMAAAAJxttWK1sLMixbVG+7Tx3aniusCGYBAADS0FYrSAsy5BOCWQAAgDQtF5uuQDRdfXDDgGAWAACE3nD9XcezXGwmF2TAUASzAAAg9FJpqzWWAJNANHvoMwsAAEJvuP6uHtW+dvX2hXK5WL8jmAUAAKEXTWirlUyYl4v1O4JZAAAQel5bLXUdUButRGFfLtbvCGYBAEDoeW211D5Lk720TCzLxQYDwSwAAEBCW61FM2tse0ePWy5Wp2qrNZa2XMgNCj8AAAD+jv6uwUMwCwAAkMO2WuNdPhe7EMwCAAAEcPlc7EIwCwAAENDlc8EEMAAAgJwvn6tlc4sKC9ypzutyLZ+r62FkBLMAAAA+Xj4XIyOYBQAAyDKWz00fglkAAIAsi7J8btoQzAIAAGQZy+emD8EsAABAlrF8bvoQzAIAAOQAy+emB4UYAADkGVaUCg6Wz504glkAAPIIK0qFc/ncWIh3YAhmAQDIE6woFU5rQr4DQzALAEAerijlNeLXilKVpRE3qUgrSs2dVhmajF0YrGEHhglgAADkA1aUCh+WxN2FYBYAgDzAilKZpYDwja0dtrqh1Z36IUBkB2YXygwAAMgD0YQVpZSZG4wVpfKvJvXtHZjyYXdgGls7834HhswsAAB5IB9WlPJj9tOrSVUN6qSKYldzrFOd1+X691yJsiSuk9/PDgCAkK0opUk/muylQ8zKzCmgUSDr9xWl/Jj99PukOm8HZuXGFrc9iaUG3g6MFmDw8w5MOpCZBQAgTwR1RSm/Zj/9XpPKkri7kJkFACAAxtoUP2grSvk5+xmEmtT5f9+B8bLa2h6VFmgHRoGsX3dg0olgFgAAn0v1EHw6VpTyY/Yz288pGpBJdfMDtgOTbgSzAAD4WL43xfdz9jNINamFAdqBSTdqZgEA8KkwNMWP+nhGPjWpwUAwCwCAT/l9AlIYWooF
dVJdmFBmAACAT/n5EHyYWoqFvSbV7whmAQDwqWhAJiCFYUZ+mGtS/S7Y734AAPJYkCYgTRTZT4wXwSwAACE+BD/W/rXZQPYT40EwCwBASA/B+3EJ2UzwU8CO9COYBQAghIfg871/bdgC9jAjmAUAIADSeQjez0vIplNYAvawo88sAAAhE4b+tWFYcAK7EMwCABDa/rXJD9BqkllXb1+g+9eGIWDHLgSzAACETNTHS8imSxgCduxCMAsAQMj4fQnZdIiGIGDHLgSzAACEtH+t+tRqsldbZ4/1xmLuVOf9sITsRIUhYMcuBLMAAIS4f+2imTW2vaPH1m9ud6fqX5sPs/zDELBjF3LrAACEVL4vIZvJBSfgHwSzAACEWFCXkB3rql75HrCDYBYAAOT5ql5BDdgxNgSzAAAga5nSiWJVLwxGMAsAALKaKR2vsCzDi9QQzAIAgIxmSudOi6Yli5vKql6UFYRHzltz3XzzzTZnzhwrKyuzQw45xJ555pkRr799+3Y788wzbcaMGVZaWmp77rmnPfjgg1nbXgAAkDxTqgxpUWGBO9V5Xa5Mqa6XGPzesmKt/cejr9i3Hn/Vneq8Lh8Nq3rBd5nZe+65x8477zy79dZbXSB744032rJly+zll1+2urq6Idfv7u62I4880v3bT3/6U5s1a5a9/vrrNmnSpJxsPwAAYTbWTKkytCVmtra5ze54asO4612jCat6KWBOtqpXSVGhte7ssdUNrXQuCImcBrM33HCDnXbaabZ8+XJ3XkHtAw88YLfffrtdcMEFQ66vy7du3WpPPfWUFRfvehMrqwsAALLv7Uxp+bCZUvV21fUi8bg98lLThOpdvVW9FPzqNokBtFb1erVxh1mB2Y//b4N19cUyVrsLf8lZMKss67PPPmsXXnhh/2WFhYW2dOlSe/rpp5Pe5pe//KUdeuihrszgF7/4hdXW1trxxx9v559/vhUVFSW9TVdXl/vztLa2utNYLOb+Mk2PoQ9YNh4LmcEYBh9jGHyMoT9VFBdaWaTQOrp6kmdKu3qtNFLorrd5a5et3dxmM6pLFW8q+uy/ns7rcpUavLmt3XabPHy961H71trGlg533frqMhcwKyOrDHBDa6fL+E6uKHalCMrgrtq43V3/lCXvsHm1BLRB+Qym8jg5C2Y3b95sfX19Nn369AGX6/zq1auT3mbdunX2xBNP2AknnODqZNesWWNnnHGG9fT02GWXXZb0Ntdcc41dccUVQy5vbm62zs5Oy8ZgtLS0uDeAgnUED2MYfIxh8DGG/hSJxW3fyWavb9lu5cVlVrArTHXiFrdt7Z22aGrUIl07rKVlu1XH2602ErPCvreTTJ7SSNxi8U5ramqykp7hg1mFo5/cu8qe27DdGlparb0jZpHCAqsv7bX62gKbOy1iBdZlFu+yqmKzuqlmG7e32u9fWGfRRfWUHATkM9jWNnoNdSC7GeiFVL3s9773PZeJPfDAA+2tt96yb3zjG8MGs8r8qi43MTM7e/Zsl9Wtrq7OyjbrMIgejy/gYGIMg48xDD7G0L/eV1Rhdzz1uv11S7fVV5f2Z0qVJZ0crbb37v8Omz41alvau621IGaFvSVJs7htnT3WWlDkfufrRsjMiqbVLJ6/u6uzVQlD685eu+eZDTaposR2Fg2970i0zFZu67GlpVUjZn3hn8+gGgP4PpidNm2aC0gbGxsHXK7z9fX1SW+jDgaqlU0sKdh7772toaHBlS2UlKi8fCB1PNDfYBqIbH0havCz+XhIP8Yw+BjD4MvVGGZrMQC/GevzXjC9xpYftkd/n9nGti4rjRTZolmT7Kh9d9WqKhCaVllq86ZV2cpNrVZZVjyk3nVTa5ctnlVju02Ojun11dtg96mV7v812auzL24VpcV6owy5bnlpxG1XR0+M74CAfAZTeYycBbMKPJVZffzxx+2YY45xl+nNrvNnnXVW0tscdthhdtddd7nreU/ylVdecUFuskAWAIAgLAYQ
9Oety+a+v3LE4Ff/f9SiOtvY2ukme6m21cvibmrptCnREhf8jmdHYSxdDhRg63rIPzndPdHh/9tuu83uvPNOe+mll+z000+39vb2/u4GJ5100oAJYvp3dTM455xzXBCrzgdXX321mxAGAEAmFgPQzPlJFcVulr1OdV6Xj6Uvapiet4JQLVSwsL7anSYLSjUBS+23Fs2sse0dPbZ+c7s7VUZ2IsvQel0OFBQry5vIZX1bOm1+XaW7HvJPTndRjj32WDcR69JLL3WlAgcccIA99NBD/ZPCNmzYMCDNrFrXhx9+2M4991zbb7/9XJ9ZBbbqZgAAQLqEddnUbDzvsWRxU6XbKmusGtp0Z33hfznPt6ukYLiyghUrVgy5TK25/vjHP2ZhywAAYRXWZVOz9by9LG46KUhWdre/dre105UWKOvr1e4iP+U8mAUAIMiLAeTTZLHxPm+/yETWF/5HMAsAwCDRcUwoyofJYtE8mEiViawv/I3+FAAATHBCUb5MFmMiFYKIYBYAgGEmFGnikCYUqaF/byzmTnU+cULR4ElTymgWFRa4U53X5Zo0pevl0/MG/IJgFgCAESYUjdZGKpVJU/n0vAG/8G/RCwAAAZhQFPRJU6k+76BPckP+IZgFAGACE4qieTBpaqzPOx8muSH/UGYAAMAEhGXSVL5MckP+IZgFAGACwjBpKp8muSH/EMwCADBB+T5pKt8muSG/BKuABwAAn8rn1afycZIb8gfBLAAAPlp9yo/dAvJ1khvyA+86AAB8wq/dArxJbprsVVkaGVBq4E1yU0lF0Ce5IZgIZgEA8FG3AE2mUg2qDukrE6oAcmPLzpzW3nqT3LQdmtSm7VNpgTKyCmTzYZIbgosJYAAA5FgQugXk+yQ3BBeZWQAAciyVbgETrcmdiHye5IbgIpgFACDHgtQtIB2T3IB0IpgFACDHouPsFuDHzgdAthHMAgCQY+PpFuDXzgdAthHMAgCQ4SznaPedarcAP3c+ALKNYBYAEHqZzHKO9b69bgHedVUjq9ICZWQVyHrXHdz5wMviqjxBWV0Fw+p8MHdaJSUHCAWCWQBAqGUyy5nqfY+lW0BQOh8A2UIwCwAIrUxmOcd736N1CwhS5wMgG1g0AQAQWqlkOf1y39GEzgfJDNf5AMhXBLMAgNB6O8sZGTbL2dXbN64sZ6bu2+t8oIlh6nSQyOt8ML+uckDnAyCfEcwCAEIrmsEsZ6bu2+t8oA4HKlVo6+yx3ljMner84M4HQL4jmAUAhFYms5yZvG+v88GimTW2vaPH1m9ud6fqfEBbLoQNBTUAgNBKtb+rX+57rJ0PgDAgmAUAhNpY+7v67b7H0vkACAOCWQBA6GUyy0kGFcgsglkAADKc5SSDCmQOwSwAwHe04EBiJnNGdWmuNwmATxHMAgB8RUvAejWm6tOq9lbzpkVtyayI1dXleusA+A3BLADAV4HsD/6w3i0Bq9n/WrJVfVpXbWqxjtZeq5w8xRZMr8n1ZgLwEfrMAgB8U1qgjKwC2QV1lVZVVmxFhQXudH5tpbV19tqjq5rc9QDAQzALAPAF1ciqtEAZ2YKCgTP9dV59Wdc073DXAwAPwSwAwBc02Us1shXDLO9aEim07t4+dz0A8BDMAgB8IVoScZO9VCObTHdvzEoiRe56AOAhmAUA+IIWEphXW+mWeo3HB9bF6rxqaVU7q+sBgIdgFgDgC1pYYNmi6a429tWmHdbW2WO9sZg7Va1sVVnEjty3jpWzAAxAMAsA8A0t/br8sDm2aGaNbe/osfWb292pzi/de7rNq63K9SYC8BkKjwAAvgto576/csgKYJs3N+d60wD4EMEsAMB3VEowe0pF//lYLJbT7QHgXwSzAACkSAs3JGaONSmNWl4gNwhmAQBIccldrVSmBR7UF1ftxNSFQZPXVCIBILsIZgEAoyIT+XYg+4M/rHdtwrRSWUVJueuLu3Jji21s2ekmrxHQAtlFMAsAGHcmcu60ytAEuQro
9TookF1QV9m/5G5VWbFVlkZcO7FHVjW61yRfXwPAjwhmAQDjykS+1NBqdVWlrnVWGA63K2hXQK/XwQtkPTqvy9c07XDXS5y8BiCz6DMLABhTJlIZyKLCAnc6NVpiL7yx3Z58dbPVlBe7bOSkimIX5Cr4VRCcb5R9VtBeMcxyuuUlRdbV2+euByB7CGYBACllIrW07NrmdissKDDvYi/IVdCr4FeH2xUM55NoScRln5WZTmZnd5+VRorc9QBkD8EsACClTGRbZ69t6+h2mdhYPG7dfbFhD7fnE9UDq4xiU0unC+gT6bwun19X6a4HIHsIZgEASUWHyUQqeO3ti1nc4lZUWGglRYWhONyuSV2qB54SLXGTvdo6e6w3FnOnOq/Lj9p3OpO/gCwjmAUApJSJVPAaKSyw1o4eF8BVlUVCc7hdE9vUfmvRzBo38W395nZ3unhWTcptuVSG8cbWDlvd0OpO860sA8iW/PumAQCkNROp/qnKPKp8QFlXs7ipsKAvbjZ3WsWQeloFvwru8vVwuwLWue+fWEsyFl4A0odgFgAwaibSC7waWztd1vW986dZY1uXbWnvsZJIkQtylZFVIBuGw+16buNtv8XCC0B6EcwCAMaViVy3eceQIFcZWQWyBGPJsfACkH4EswCAcWUi03G4PWzCtvCCtwzyjq5u693RZdOmxa2Q2TpIM4JZAEBODrcPDnjCEBC/3e4seT2xyjWU5c6HThCJdcFdvb02s6Tbfreh25Ytridzj9wEsxs3brSZM2em99EBAKEWtolQ0YR2ZyotGCxfOkEMqQsuLrNId6ut2tRiG1s7qQtGWo052b/vvvvaXXfdld5HBwCElhfwaOKTFmAIw5K4YVh4YbhlkMuLi2x+bf6uEIcABLNXXXWVfe5zn7NPfepTtnXr1sxuFQAgrw0X8OT7krhhWHghlbpgIKvB7BlnnGF//etfbcuWLbbPPvvYr371K0uXm2++2ebMmWNlZWV2yCGH2DPPPDOm2919993ug3HMMcekbVsAAJkX5oAnnQsvBGkZ5HxfIQ65k1JRzh577GFPPPGE3XTTTfaJT3zC9t57b4tEBt7Fc889l9IG3HPPPXbeeefZrbfe6gLZG2+80ZYtW2Yvv/yy1dXVDXu79evX2xe/+EV773vfm9LjAQByL0wToZLJ504Q0ZDUBcM/Un4nvf7663bffffZ5MmT7eMf//iQYDZVN9xwg5122mm2fPlyd15B7QMPPGC33367XXDBBUlv09fXZyeccIJdccUV9vvf/962b98+oW0AAGRXlIAnLZ0g/FwXrNpn9c4N2wpxyL6UviVuu+02+8IXvmBLly61VatWWW1t7YQevLu725599lm78MIL+y8rLCx09//0008Pe7uvfvWrLmt76qmnumB2JF1dXe7P09ra6k5jsZj7yzQ9hj682XgsZAZjGHyM4dipRlWrUHnZwpk1mckWzqgutXnTom52e2VJ0ZCAp6FlpzsMr+t539eMYXActW+tbWzpcJP46qvLrLy40HZ299ratjabEi21I/dR/BDPu5rofBbL8mcwlccZczD7oQ99yNWyqsTgpJNOsnTYvHmzy7JOnz59wOU6v3r16qS3efLJJ+373/++Pf/882N6jGuuucZlcAdrbm62zs5Oy8ZgtLS0uDeAAnUED2MYfEEYQ/2ob2nvdrWEykhOjZZk/ZDzpu077bkN210g2d0Xs5KiQquvKbd37T7JZmQgi7ZkVsQ6Wntt+9bNbuJTSaTQuntjtq292/aIRuzQmRHbvLk5MGOIt6nq95N7V/39/dRq7R19NiXSawfW1dg7d6+yqvhOa2rKv3rofBbL8mewra0t/cGsgk5NANttt90sV/TETjzxRJchnjZt2phuo6yvanITM7OzZ892WeXq6mrLxuAr46DH4ws4mBjD4PP7GK5tbrNHXmqytZt39Aez86ZV2lGL6mxebVXWtuGnLzXZtvY+q6+utqqSiCsBeLap09a1
t9kpS6akfVs0LaJy8hR7ZOWu597d22clkWKbP32yHbnvwOc+2hhmK6OM1MZ38fzd3bjs6Oqx3h0ttnDOLItEinK9aQjA96iaAqQ9mH300Uct3RSQFhUVWWNj44DLdb6+vn7I9deuXesmfh199NFD0tCq3dWksXnz5g24TWlpqfsbTAORrR81DX42Hw/pxxgGn1/HUIdh73hqw9vN5f8eRK7c1Jq15vIKBB9Z1Wxb23tsQV1V/yH/qvISqywrdi2jHv1bs807vDrtAeKC6TU2r7Z6TBOhhhvDsC28ECQaqt2nVrrf6qamLhfI+u0zCH9+j6byGDl9R5WUlNiBBx5ojz/+eP9lesPr/KGHHjrk+gsXLrQXX3zRlRh4fx/72MfsiCOOcP+vjCsABIVfeq3muk2WNxFqYX21O00lYA7jwgsABsr5NFGVAJx88sl20EEH2cEHH+xac7W3t/d3N1B97qxZs1ztq1LOixYtGnD7SZMmudPBlwOA36USRGZy1ntQ22QN3hnozyiXFbtZ9Mooa2dAAS4lB0D+ynkwe+yxx7rJWJdeeqk1NDTYAQccYA899FD/pLANGzZwSAJAXvJLEBkNaJssv+wMAMgtX3wznXXWWe4vmRUrVox42zvuuCNDWwUAmRX1SRAZ1L6gftkZAJBbpDwBIEe8IFLBooLGRF4QOb+uMuNBpA7Ba7KU2mPp0HxbZ4/1xmLuVOd1+VH7Tvfdofpows5AMn7NKANIL4JZACPWJL6xtcNWN7S6Uxqc528QqVn/6pyghQq2d/TY+s3t7lQZ2Wx0VAjyzgCA3GJ3FUBStDvKDi+I9F5rHRZXNlFBpALZbL7Weqy5768cU5ssP+0MqI+pgn/VyKq0QBlZBbJ+zSgDSC+CWQDDtjt6u/dp+a7epxtbXODg10xdUPkpiPTaZAWFn3YGAOQGwSyAAWh3lBtBCyL9xE87AwCyj2AWwAC0O0IQsTMAhBcTwAAM0+4o+b6uahK7evtodwQA8AWCWQADRGl3BAAIEIJZAAPQ7gh+Rrs4AIORWgEwAO2O4Febtu+0n61aZ2s3t9MuDkA/glkAQ9DuCH6ztrnNHnup0V5rj1h9TTnt4gD0I5gFkBTtjoJHh9zzcbz0vB5Z2WRtnb02v3aSFRTuqpCjXRwAIZgFMCzaHQVHEFZsG2+w7drFbd5hu0dLLE67OACDEMwCQMAFYcW2iQTbCn7VDq40UmidSf5dNd0qhaFdHBBOBLMAEGBBWLFtosF2tCTiara7enuS/mrRLg4IN1pzAUBIVmzzQ7CtILuosMCd6rwuV7A9Uost1y5u2q7r0i4OwGAEswAQ4L6qfl+xLR3BtjLKRy2qs6qyiK1p3mFtnT3WG4u5U2WeaRcHhBvHZADAp8ZSZxpNWLFN2U6/HYJ/O9hOnjUda73rvNoqW7r3dHvqrV7XZ5Z2cQA8BLMA4ENjrTP1VmzT5aqRTcx+eofgFfDl6hB8NI3B9oxJ5fa5+bW2qbUr79qPARg/glkACPikLj+v2JbuYJt2cQAGo2YWAHwm1TpTb8W2RTNrbHtHj63f3O5OFSTmui2XF2wrqFawTb0rgHQjMwsAPjOeOlM/r9jG8sgAMolgFgB8JjrOOlM/H4L3c7ANINgIZgGEpg71zW0d1tET830g5fdJXeOV7WB7uOVzx7usLgB/IpgFkPfWNrfZ71Y22KptDdbZG0tpKdVc8PukriC3NVs4o8pWb2ob17K6APyJYBZA3gc1dzz1uhV1ttuk6CSrKC1OaSnVXKHONP1tzf64bovd//xb7jJ1iUh1WV0A/kQwCyDvW1xta++2/aaW2c6iYrUDGLbFld9QZ5q64dqaaby9Lgq1lSX95RtBeS8AGB6tuQDkfYur+uoyK7DxLaXqlzrThfXV7pRAa3xtzdo6e21bR49NjZa4U50P2nsBQHIEswBC0OIq+UEo1aF29faNupQqgj/m3X0x6+2LuTHvi8Xc+US8F4DgIpgFkLeiCS2u
kkllKVUEQ3SYMS8pKrRIUaEb86LCQnc+Ee8FILgIZgHkLa/FVUNrp8UtPuDfvBZX8+sqA9fiCqOPucZWY+ypKovY5Ipi29Le7U51PozvBdUUv7G1w1Y3tLpTnQeCjl1QAHnr7RZXHbZxe6tFomVWXhqhxVUeG6mtWaSw0E34UoZ2R1dv6NqdDdeujLZkCDqCWWAUylw0t3XZtnirVZaWMJs8YPQjfcqSd9jvXlhnq7b1WGNbFy2u8txwbc0OnTfV9qp/u89smNqdDdeujLZkyAcEs8BomYwXG2z71mbb2N1spZEImYwU+WG1pXm1VRZdVG9HllYFYgUwZLat2RF71eX8PemHdmW0JUO+IJgFRslkbGvvsj2rI7ZHVaV19PSRyQjoYU39SO82Wa2twjdVINs7FH7YgRlp+dxsL6vr13ZlydqShel1Qf4gmAVGy2TUVlp5rN06CmmwngoOa4Zzh8JPOzAY3K4s+eQ21Q6r5IK2ZAiq8KUogDRnMjD6YU3tBBT9fWdA53W5dgaYSZ2dHQrtQEyqKHY7XzrVeV2ufw/y42FsorSoQ54jmAWSoNn+xLAzEL4dCnZggteuLGxtyZC/CGZDgt6CqYmSyZgQdgbCt0PBDoz/25Wp/ZhKpNo6e6w3FnOnOh+GtmTIb/wShwA1bOPPZOjwaGVtNGkmQ+18yGQkF03YGVBmbjB2BvKvTpK6zGC2KwtDWzLkP35J8hyTcCbeeH1N8w7bszpmvQUx29kTC02D9bTtDJRGBmTq2BnIjmiWdyiy/XhIb7syIMj4VgmA8ba5obdgmjIZf+8zu6mt3UoiETIZE1yFyW87A35pI5WLHYpFM6vd/6v8KDrB584OTDCErS0ZwoFgNo9LBOgtOHF6jeccXmEvvVZikcoaVgDLs8Oa+VyCM9oOhSZnbWnvthsfezUtzz1IOzAA8gvBbB6XCFDDlh768a2tKrW6uupQNtzP18OaYSjBGW6HYmZNmVvWV0FmOp97EHZgAOQfglmfSkeJQJQaNoTwsGaysoEwl+AM3qEoLy6yX72w0Ta2dGbkuft5BwZAfiKK8al0lAhQw4awGa5s4Kh9a60qxCU4iTsUas23rrk9o8+dukwA2cQx0zzu00lvQYTJSKtP3fHU67Ypob9pmPvghvm5A8hPBLM+FU1T036vhm3RzBrb3tFj6ze3u1NlZPOhJhAYy+pT29q77S8btvcvFhIN8aIY0RA/dwD5iW8rn0pniQA1bMh3o5UN1FeX2aaWVje5afeplaEuwQnzcweQnwhmfSrdbW6oYUM+G0vnjvaOWP+h87F8vpbuU5eXO4C00AKQbwhmfYw2N8DYRMfQuaO4qHDAofORPl971VfZo6ua8rL/rPDdAiCfEMz6HCUCCJPxrsY12qHzhtZOO6iu3GbWlI/6+VLge+fTmes/65cVx/huAZAvCGYDgBIBhMFEVuMay6Hzd+5elTRQS/x8KdC8ZcXajPWf9duKY3y3AMgHBLMA8mI1rpEOnR+5T61Vxd9uzTX+iWSl9sIb2+13rza7IDSVTGYYVhwDgFwgmAWQU+lcjWu4Q+dmcWtq2jmhiWTavlcaW+3NbTvt+0+us9rKsjFnVcO04hgAZBt9ZgHkVCqrcaVy6HxhfbU7TSU4jA7Tg1VB6PNvbLeGli4rKy6yPaa+vSCDsq3KumbzOQIA3kYwCyCn/LQilTeRTHW2mjgmOlWgqQA3Umg2vbrMBbLeggwKdJVV9RZk8PtzBIB8QzALIKeiPlqRKtkS0Ns6uq2prdP6+mJWURqxebXR/uzqWLOqfnqOAJBvCGYB5FSybOjgFanm1+2abJUNQ5aA3tJund19Vj+p3A6YPcmmREtTzqr67TkCQD4hDQAgp/y4IlXiRDLVuv74mQ2uR211efG4sqp+fI4AkC/IzALIuSHZ0M3t7lRttXLVssqbSPa+BbW236xJbuGFiWRV/fgcASAf+CIze/PNN9s3vvENa2hosP3339++/e1v
28EHH5z0urfddpv993//t61cudKdP/DAA+3qq68e9vpAvvPLilL5uiJVOrOqfn2OABBkOQ9m77nnHjvvvPPs1ltvtUMOOcRuvPFGW7Zsmb388stWV1c35PorVqyw4447zpYsWWJlZWX29a9/3Y466ihbtWqVzZo1KyfPAcgVv60ola8rUo20IIMC2VRea78+RwAIqoL44ONmWaYA9t3vfrfddNNN7nwsFrPZs2fb2WefbRdccMGot+/r67PJkye725900kmjXr+1tdVqamqspaXFqqurLdP0fJqamlxgXlhIVUcQ+XUMh64oFXGz5b1sIYeu0z+GqWbB8yVr7gd+/RxibBi/4ItleQxTiddympnt7u62Z5991i688ML+y/QCLV261J5++ukx3UdHR4f19PTYlClTkv57V1eX+0t8cbxB0V+m6TG0v5CNx0J4xtCtKPVig21r77IFtQkrSpVGrLI2amuad9gjKxtszvtSWzQgX6VzDGdNKks4p/tMng9Y29xmj6xssrWbd7huB8rkzptWaUctqrN5texk5MPnEGPH+AVfLMtjmMrj5DSY3bx5s8usTp8+fcDlOr969eox3cf5559vM2fOdAFwMtdcc41dccUVQy5vbm62zs5Oy8ZgaK9CbwD2RoPJj2PY3NZl27c2257VESuPtQ/59z2rY7ZtS7O99FqJ1VYNbCUVRtkew03bd9pjLzVaW2ev7R4tsdJIoXX19lhjU6P9/KkttnTv6TaDNlyB/xxi7Bi/4ItleQzb2kZeWdFXNbMTce2119rdd9/t6mhVP5uMsr6qyU3MzKqMoba2NmtlBsqa6fH4AAeTH8dwW7zVNnY32x5VldaRJPPaWxCzTW3tFqmssbq6zL/P/S6bY6hM7c9WrbPX2iM2v3aSxQsKzO02R8wmlcZd1vzpjb322fnaFrLmQf4cYuwYv+CLZXkMh4vrfBfMTps2zYqKiqyxsXHA5TpfX18/4m2vv/56F8w+9thjtt9++w17vdLSUvc3mAYiWx8oDX42H2+sqOcL7hhWlirbF7GOnj63rOpgO3tiVhKJuOtlYpuD+N7J1hi+tb3D1m5ut/qacisY9FjaBl2+prndNrV2MREs4J9DpIbxC76CLI5hKo+R02C2pKTEtdZ6/PHH7ZhjjumP/HX+rLPOGvZ21113nV111VX28MMP20EHHZTFLc4f+TYLPmy8FaVWbmyxytJIf81sYu9TzbTPxIpSvHdGpgBfr0tFSfLXXm291A1hpBXDAAAWnDIDlQCcfPLJLihVr1i15mpvb7fly5e7f1eHArXcUu2rqBXXpZdeanfddZfNmTPH9aaVyspK94fxzIIvd7PgFRiplyaz4P0vVytK8d4ZXbQk4gJ8vS5Js+ZjWDEMADB2Of82PfbYY91kLAWoCkwPOOAAe+ihh/onhW3YsGFAqvmWW25xXRA++clPDrifyy67zC6//PKsb3/QuFnwKxtdMLKgLmEWfFmxy/ApMHpkVaPNnVbp+8PGYZfO3qdjwXvH/1lzAAijnAezopKC4coKNLkr0fr167O0VfnJW2teWbXEH1nReV2+pmmHux71fP6XzRWleO/4O2sOAGHli2AW2UM9X/7J1opSvHf8mzUHgDAjmPWRbMwQ1/1Sz4fxiPLe8W3WHADCjF8dn8jWDPGJ1PMFsR1TUPnxtaYW1L9ZcwAIM4JZH8jmDPHx1vPRjil7/PpaUwsKAPAjgtkcy8UM8VTr+YLcjsmPGc6R+P21phYUAOA3BLM5lqsZ4mOt5wtyOya/ZjiHE5TXmlpQAICfEMxmODh5c1uHNW3rsO7iDtttcnTID34uZ4iPpZ4vqO2Y/J7hDPprTS0oAMAvCGYznhVss6p4u7UVtNi82qohWcGoz2eIB7EdU1AynPnwWgMAkGtvL62FtGcFlQWcVF5s9dVl7lTndbn+ffAMcU2g0YzwRN4M8fl1lTmbIR5NCLaTyXWwPdEMZ6aC6Te2dtjqhlZ3qvNj
EQ3gaw0AQK7xq5jprKD2GPq6dmUFy4qHZAX9PkM8iO2YcpnhnEidbhBfawAAco3MrA+ygt4M8UUza2x7R4+t39zuThW45Lq20wu2FVQr2G7r7LHeWMyd6nyug+1kojnKcA7IyFcUux0WnSbLyOfytR5v5hgAAD8iM+uTrKCfZ4insx1TNlpl5SLDma463Uy3vgpahwcAAEZDMJtm0QlM6PLzDPF0BNvZCqRyUbqRzk4EmdqxCWKHBwAARkMwm2ZDsoIJ/xb0useJBNvZDqSy3dw/3XW66d6xCWqHBwAARkMwm2ZDsoLVpVYaibu6x02tXb6sMc20XAVS2Szd8HuLtSD1sAUAIBVMAMuAARO6dvZYU2unO/XDhK6wtcryMpwL66vdaaZ2IvzeYu3tzHFk2MxxV28fPWwBAIFDZjZDvKzgm9varampyerq6pKuABYGYVgMwO8t1qI+zxwDADBeZGYzSIHLbpMr+v/CGMhKNCSLAfi5xZrfM8dIDe3VAOBtwY4eEAhhWgzAry3W/J45xtjRXg0ABiKYRcblayA1XM9cv7ZYG6nDw9J96tz/K9MX9UkAjqForwYAQxHMIiuy3Sor04KaHUuWOdZOxaOrgvdcwob2agCQHMEsLOyH4MOWHUvMHOu53Pl0cJ+LX2RjZTvaqwFAcgSzyCq/HoIPY3Ysn55LGLL0YegKAgDjQTcDICA9c9M94z1Iz8XvWXplsidVFLvAX6c6r8v17+kSDUlXEABIFd96QB5mx8aSLQzKc/GrbGe2w9QVBABSQWYWSEE0ANmxsWYLg/Bc/CzbmW2vK4i6fyhQ1hLZvbGYO9X5oHYFAYCJIpgF8mjxgcHZQmUJiwoL3KnO63JlC3U9vz8Xv8vFEsF+XpgDAHKFlAuQw5656Z4Fn+qM93zs/5st0RwtEZwvXUEAIF0IZoEc9czNxCz4VOtg863/bzblsoY16F1BACCdCGbztCclMmui2bFM9aodT7aQTN/45OvKdgAQNASzWRbUlaOQvuxYJmfBjzdbmI5MXy520nK9Y0hmGwByj2A2i4K+chT8v5JTrrKFudhJ88uOIZltAMgtgtksSXc2briMVK4zVRhdpvu7ZjtbON6dtIm8V/22Y0gNKwDkDsFslugHNl3ZuOEyUgtnVNnqTW05z1RhZNEszILPVrZwvDtpE8mqsgwvACARwWzAsnHDZaT+uG6L3f/8W+4y/cDnOlOF3M+Cz0a2cDwlExPNqmayTAMAEDwsmpAl0TSstjRcQ3wFRN5KQL19MXd+uEb5yL18Wskp1YUDUlnUIV2PCQDIbwSzWTKzZuKrLQ2XkWrr7LVtHT02NVriTnU+k8tqYuLyZSWnaIo7aelYAjbVxwQA5De+7bMkHbPMhytV6O6LuYxsTUWxte7scefTOaFoPJiIFo5Z8KmWTKSj3CaXixUAAPyHYDaLJjrLPDrMxKGSokKLFBW6wLiosNCdz1amKlnQum7zjoy2TMqnQDnos+BT3UmLpmHyG4sVAAASEcwGKBs3XEaqqixikyuKbd3mdps7LerOTyRTNdZgMdmM9Enlxda0o8v6YvGMtEzyS29RjG8nLV1ZVRYrAAB4CGYDlI0bKSMVKSx0mS5laHd09Y47UzXWYDHZjPT2rh57cu1m97iH7zmtP/OWrpZJfustitR30tKZVc2HMg0AwMQRzAbMcBmpQ+dNtb3q3+4zO55M1ViDxeH6fJoVuBmFRQVm6zZ32JRoaf+/TbRlEr1F82cnLZ1Z1aCXaQAAJo5gNoBGykgdsVfduDJVqQSLw81IdxPRYnGrrih296OuCtXlxWmZiEZv0fxCVhUAkC4EswE1XEZqvJmqVILF4WakexPRCqzA+mJ9Q7oqTGQiWqaXgEX2kVUFAKQDfWaRciP66DB9PndNRCtx/VILCwoGdFUYay/d4Qz3mB56iwIAEE4Es3CiKQSL3oz0wQtAKIM7rzZqsXjcdl0cT9vKVsM9ZjoCZQAAEFyksZBy
y6SRZqRvae+2/WdPsrqqUpehbWrrSkvLJHqL5ld/XQAA0oVgFuMKFkebke5NFEtn4JXOWfBBCwzprwsAQHIEsxh3sDjajPRMTO5Jxyz4oAWG9NcFAGB4BLOYULCYixnpE3nMoAWG9NcFAGBkBLMITcukIAaG9NcFAGBkdDNAaKQSGAaxZRoAAGFEMIvQCGJgGKW/LgAAIyKYRV6VEbyxtcNWN7S6U51PFA1gYEh/XQAARuafX21gAsbSoSCVXrp+QX9dAABGRjCLwBtrh4KgBobp7K8bVEHrCwwAyB6CWYSqQ0FQA8N09NcNqqD1BQYAZBfBLAJtPK2rghoY5mvLtHzqCwwAyD6CWeRJh4Lkda4qI1D2dXCHgjAGhkETxL7AAIDso5sBAi0awA4FyN++wACA7COYRaAFpXXVaG3DkB99gQEAIQ1mb775ZpszZ46VlZXZIYccYs8888yI17/33ntt4cKF7vqLFy+2Bx98MGvbCn/xOhSoE4EOO7d19lhvLOZOdd4PHQpU93nLirX2H4++Yt96/FV3qvO6HMOLknUHAAQhmL3nnnvsvPPOs8suu8yee+4523///W3ZsmXW1NSU9PpPPfWUHXfccXbqqafaX/7yFzvmmGPc38qVK7O+7fAHr0PBopk1tr2jx9Zvbnen6lCQ6wlC3gQmTViaVFHs6jt1qvO6nIA2+Fl3AEBuFcQH/0pkmTKx7373u+2mm25y52OxmM2ePdvOPvtsu+CCC4Zc/9hjj7X29nb79a9/3X/ZP/zDP9gBBxxgt95666iP19raajU1NdbS0mLV1dWWaXo+Cszr6uqssDDn+w55LVO9SMc7htoeZWAVuCZOYBJ97JQ5VsD9+cPnMYFpjN0MBvcFHuvOCp/D4GMMg43xC75YlscwlXgtp8fnuru77dlnn7ULL7yw/zK9QEuXLrWnn3466W10uTK5iZTJ/fnPf570+l1dXe4v8cXxBkV/mabHUOCSjceCsnllCef0usdzNoZvbuuwtc1tNqO61FyomrDfqPO6XMHam9vabbfJdFZIZu60qJ2yZHd7ZGWTrd28w5pa+6xEfYFnVtuR+9a5fx/LuPA5DD7GMNgYv+CLZXkMU3mcnAazmzdvtr6+Pps+ffqAy3V+9erVSW/T0NCQ9Pq6PJlrrrnGrrjiiiGXNzc3W2dnp2VjMLRXoTcAe6P+p+B3S3u3m1ikesyp0RIXFI9nDJu2dVhVvN1qI2VW2Pf2DpWnNBK3WLzT7emW9BDMDkd513/cu9K2tJcMGJfC+E5rahpbJwM+h8HHGAYb4xd8sSyPYVvb2Mvw8n7mhLK+iZlcZWZVxlBbW5u1MgMdXtbj8QH2N2VRH3lpVwbQC5rmTau0I/edZpMmpT6G3cUd1lbQYkW9xa436mCapNZaUOQO2dSRmR1V/QRuy+cw+BjDYGP8gi+W5THUJP9ABLPTpk2zoqIia2xsHHC5ztfXJ//p0uWpXL+0tNT9DaaByNYHSoOfzcdD6nS4/46nNiSsNBXZtdLUplbb2LrTPrl3lU2fntoY7jY5avNqq1zNbGVZ8ZCa2U2tXa5mVtejZjbz+BwGH2MYbIxf8BVkcQxTeYycvqNKSkrswAMPtMcff3xA5K/zhx56aNLb6PLE68ujjz467PWBVFeaUha1qLDAner8tvZu+8uG7SnX3wahbRgAAEGX890jlQDcdtttduedd9pLL71kp59+uutWsHz5cvfvJ5100oAJYuecc4499NBD9s1vftPV1V5++eX25z//2c4666wcPgvk80pT9dVltqllp21s2ZlXbcMAAMgHOa+ZVastTca69NJL3SQutdhSsOpN8tqwYcOAVPOSJUvsrrvusksuucQuuugiW7BggetksGjRohw+C+THSlPJ+5WqHVR7R2zcK00pYJ37/sqMtA0DACDsch7MirKqw2VWV6xYMeSyT33qU+4PSIdowkpTySZqqa9pcVHhhFaaUuA6ewqTvAAAyLsyA8Dv
K001tHbajJpym1nDSlMAAPgNwSxCb7SJWpOjJfbO3SdRFgAAgA/5oswAyDVvopa6GmgyWGNrp+szq4laR+5Ta1Xx1Cd/AQCAzCOYBUaZqKUVwMa60hQAAMgugllglIlaqfaXBQAA2UPNLAAAAAKLYBYAAACBRTALAACAwCKYBQAAQGAxAQy+pYlXLAELAABGQjALX1rT1Nbf87Wzt88tN6tVurS4gVpoAQAACMEsfBnI/uAP621re7fNqCmzipJy6+jutZUbW2xjy063uAEBLQAAEGpm4bvSAmVkFcguqKu0qrJiKyoscKc6r8sfWdVI71cAAOAQzMJXVCOr0gJlZAsKBtbH6rwuX9O0w10PAACAYBa+osleqpGtKEleAVNeUmRdvX3uegAAAASz8JVoScRN9lKNbDI7u/usNFLkrgcAAEAwi5xT/esbWztsdUOrxeJxm1sbtU0tnRaPD6yL1XldPr+u0rXpAgAAIL0F37XgmlS+a9LXq027amdVWqCMrALZKdESO2rf6fSbBQAADsEsfNeCS0Grglldtr2jxxpbO11pweJZNS6QpS0XAADwEMzCFy24vM4FasFVWRpxWdmp0RI7Zckc6+jpYwUwAACQFMEsfNuCa21zu/v/hfXVOdtOAADgb0wAQ07QggsAAKQDwSxyIkoLLgAAkAYEs8gJ1b/Oq62kBRcAAJgQglnkhCZyLVs03bXa0mSvts4e643F3KnO04ILAACMBcEsckYttpYfNscWzaxxLbjWb253p2rBpctpwQUAAEZDQSJySgHr3PdXuu4GmuwVpQUXAABIAcEsck6B6+wpFbneDAAAEECUGQAAACCwCGYBAAAQWASzAAAACCyCWQAAAAQWwSwAAAACi2AWAAAAgUUwCwAAgMAimAUAAEBgEcwCAAAgsAhmAQAAEFihW842Ho+709bW1qw8XiwWs7a2NisrK7PCQvYdgogxDD7GMPgYw2Bj/IIvluUx9OI0L24bSeiCWQ2EzJ49O9ebAgAAgFHitpqampGuYgXxsYS8ebZnsXHjRquqqrKCgoKs7FkocH7jjTesuro644+H9GMMg48xDD7GMNgYv+BrzfIYKjxVIDtz5sxRM8Ghy8zqBdltt92y/rgaeD7AwcYYBh9jGHyMYbAxfsFXncUxHC0j66FwBQAAAIFFMAsAAIDAIpjNsNLSUrvsssvcKYKJMQw+xjD4GMNgY/yCr9THYxi6CWAAAADIH2RmAQAAEFgEswAAAAgsglkAAAAEFsEsAAAAAotgNg1uvvlmmzNnjluv+JBDDrFnnnlmxOvfe++9tnDhQnf9xYsX24MPPpi1bcXEx/C2226z9773vTZ58mT3t3Tp0lHHHP77HHruvvtutxrgMccck/FtRHrHcPv27XbmmWfajBkz3AzrPffck+/TAI3fjTfeaHvttZeVl5e7laXOPfdc6+zszNr2YqDf/e53dvTRR7sVt/Sd+POf/9xGs2LFCnvXu97lPn/z58+3O+64w3JC3QwwfnfffXe8pKQkfvvtt8dXrVoVP+200+KTJk2KNzY2Jr3+H/7wh3hRUVH8uuuui//tb3+LX3LJJfHi4uL4iy++mPVtx/jG8Pjjj4/ffPPN8b/85S/xl156KX7KKafEa2pq4m+++WbWtx3jG0PPa6+9Fp81a1b8ve99b/zjH/941rYXEx/Drq6u+EEHHRT/8Ic/HH/yySfdWK5YsSL+/PPPZ33bkfr4/ehHP4qXlpa6U43dww8/HJ8xY0b83HPPzfq2Y5cHH3wwfvHFF8fvu+8+dbmK33///fGRrFu3Ll5RURE/77zzXDzz7W9/28U3Dz30UDzbCGYn6OCDD46feeaZ/ef7+vriM2fOjF9zzTVJr//P//zP8Y985CMDLjvkkEPin/vc5zK+rUjPGA7W29sbr6qqit95550Z3Eqkeww1bkuWLIn/13/9V/zkk08mmA3YGN5yyy3xuXPnxru7u7O4lUjX+Om6H/jABwZc
pqDosMMOy/i2YnRjCWa//OUvx/fdd98Blx177LHxZcuWxbONMoMJ6O7utmeffdYdZvYUFha6808//XTS2+jyxOvLsmXLhr0+/DeGg3V0dFhPT49NmTIlg1uKdI/hV7/6Vaurq7NTTz01S1uKdI7hL3/5Szv00ENdmcH06dNt0aJFdvXVV1tfX18WtxzjHb8lS5a423ilCOvWrXMlIh/+8Ieztt2YGD/FM5GsP2Ie2bx5s/vi1BdpIp1fvXp10ts0NDQkvb4uRzDGcLDzzz/f1RgN/lDDv2P45JNP2ve//317/vnns7SVSPcYKvh54okn7IQTTnBB0Jo1a+yMM85wO5ZapQj+Hr/jjz/e3e4973mPjhBbb2+vff7zn7eLLrooS1uNiRounmltbbWdO3e6WuhsITMLTMC1117rJhDdf//9btID/K+trc1OPPFEN5Fv2rRpud4cjFMsFnOZ9e9973t24IEH2rHHHmsXX3yx3XrrrbneNIyBJg4pk/6d73zHnnvuObvvvvvsgQcesK997Wu53jQEEJnZCdAPYVFRkTU2Ng64XOfr6+uT3kaXp3J9+G8MPddff70LZh977DHbb7/9MrylSNcYrl271tavX+9m7SYGRhKJROzll1+2efPmZWHLMZHPoToYFBcXu9t59t57b5ct0mHvkpKSjG83xj9+X/nKV9xO5Wc+8xl3Xp192tvb7bOf/azbKVGZAvytfph4prq6OqtZWeHdMgH6slRG4PHHHx/wo6jzquVKRpcnXl8effTRYa8P/42hXHfddS6D8NBDD9lBBx2Upa1FOsZQbfFefPFFV2Lg/X3sYx+zI444wv2/WgTB/5/Dww47zJUWeDsi8sorr7ggl0DW/+OnuQaDA1Zvx2TX/CP43aF+imeyPuUsD9uRqL3IHXfc4VpTfPazn3XtSBoaGty/n3jiifELLrhgQGuuSCQSv/76611bp8suu4zWXAEbw2uvvda1oPnpT38a37RpU/9fW1tbDp9FuKU6hoPRzSB4Y7hhwwbXReSss86Kv/zyy/Ff//rX8bq6uviVV16Zw2cRXqmOn377NH4//vGPXYunRx55JD5v3jzX8Qe50dbW5lpO6k/h4Q033OD+//XXX3f/rvHTOA5uzfWlL33JxTNqWUlrrgBTb7Xdd9/dBThqT/LHP/6x/98OP/xw90OZ6Cc/+Ul8zz33dNdXW4sHHnggB1uN8Y7hO97xDvdBH/ynL2cE53OYiGA2mGP41FNPudaGCqLUpuuqq65yLdfg//Hr6emJX3755S6ALSsri8+ePTt+xhlnxLdt25ajrcdvf/vbpL9t3rjpVOM4+DYHHHCAG3N9Bn/wgx/kZNsL9J/s54MBAACAiaNmFgAAAIFFMAsAAIDAIpgFAABAYBHMAgAAILAIZgEAABBYBLMAAAAILIJZAAAABBbBLAAAAAKLYBYAAACBRTALAAHU19dnS5YssU984hMDLm9pabHZs2fbxRdfnLNtA4BsYjlbAAioV155xQ444AC77bbb7IQTTnCXnXTSSfbCCy/Yn/70JyspKcn1JgJAxhHMAkCAfetb37LLL7/cVq1aZc8884x96lOfcoHs/vvvn+tNA4CsIJgFgADTV/gHPvABKyoqshdffNHOPvtsu+SSS3K9WQCQNQSzABBwq1evtr333tsWL15szz33nEUikVxvEgBkDRPAACDgbr/9dquoqLDXXnvN3nzzzVxvDgBkFZlZAAiwp556yg4//HB75JFH7Morr3SXPfbYY1ZQUJDrTQOArCAzCwAB1dHRYaeccoqdfvrpdsQRR9j3v/99Nwns1ltvzfWmAUDWkJkFgIA655xz7MEHH3StuFRmIN/97nfti1/8opsMNmfOnFxvIgBkHMEsAATQ//7v/9oHP/hBW7Fihb3nPe8Z8G/Lli2z3t5eyg0AhALBLAAAAAKLmlkAAAAEFsEsAAAAAotgFgAAAIFFMAsAAIDAIpgFAABAYBHMAgAAILAI
ZgEAABBYBLMAAAAILIJZAAAABBbBLAAAAAKLYBYAAAAWVP8fCgIN2bNvocUAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate synthetic data: y = 0.8 * x^2 + 0.1 + Gaussian noise (std 0.1)\n",
    "X = np.linspace(0, 1, 100)[:, None]\n",
    "Y = 0.8 * X ** 2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)\n",
    "\n",
    "def dataset(batch_size):\n",
    "    \"\"\"Infinite generator of random batches (indices are sampled with replacement).\"\"\"\n",
    "    while True:\n",
    "        idx = np.random.choice(len(X), size=batch_size)\n",
    "        yield X[idx], Y[idx]\n",
    "\n",
    "# Visualize the data\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.scatter(X, Y, alpha=0.5, label='Data')\n",
    "plt.xlabel('X')\n",
    "plt.ylabel('Y')\n",
    "plt.title('Synthetic Dataset for Polynomial Regression')\n",
    "plt.legend()\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a7b8c9d0e1",
   "metadata": {},
   "source": [
    "## Building the Model\n",
    "\n",
    "### Step 1: Define Basic Components\n",
    "\n",
    "First, let's create a simple `Linear` layer and a custom state type to track function calls:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a7b8c9d0e1f2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.298818Z",
     "start_time": "2025-10-11T10:20:19.294038Z"
    }
   },
   "outputs": [],
   "source": [
    "class Linear(brainstate.nn.Module):\n",
    "    \"\"\"A simple linear layer: y = x @ w + b\"\"\"\n",
    "    \n",
    "    def __init__(self, din: int, dout: int):\n",
    "        super().__init__()\n",
    "        # Initialize weights and biases as trainable parameters\n",
    "        self.w = brainstate.ParamState(brainstate.random.rand(din, dout))\n",
    "        self.b = brainstate.ParamState(jnp.zeros((dout,)))\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        return x @ self.w.value + self.b.value\n",
    "\n",
    "\n",
    "class Count(brainstate.State):\n",
    "    \"\"\"Custom state type for tracking function calls.\"\"\"\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c9d0e1f2a3",
   "metadata": {},
   "source": [
    "### Step 2: Build the MLP Model\n",
    "\n",
    "Now let's create a multi-layer perceptron (MLP) with a call counter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c9d0e1f2a3b4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.327704Z",
     "start_time": "2025-10-11T10:20:19.323025Z"
    }
   },
   "outputs": [],
   "source": [
    "class MLP(brainstate.graph.Node):\n",
    "    \"\"\"Multi-layer perceptron with call counting.\"\"\"\n",
    "    \n",
    "    def __init__(self, din, dhidden, dout):\n",
    "        # Custom state to count how many times the model is called\n",
    "        self.count = Count(jnp.array(0))\n",
    "        \n",
    "        # Two linear layers\n",
    "        self.linear1 = Linear(din, dhidden)\n",
    "        self.linear2 = Linear(dhidden, dout)\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Increment call counter\n",
    "        self.count.value += 1\n",
    "        \n",
    "        # Forward pass\n",
    "        x = self.linear1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0e1f2a3b4c5",
   "metadata": {},
   "source": [
    "## Understanding Graph Operations\n",
    "\n",
    "### What are Graph Operations?\n",
    "\n",
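    "A BrainState model is represented as a graph of Python objects, in which:\n",
    "- **Nodes** are modules or other components (here, `MLP` and its two `Linear` layers)\n",
    "- **States** are the mutable variables held by those nodes (here, the `ParamState` weights and biases and the `Count` counter)\n",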
    "\n",
    "Graph operations allow you to:\n",
    "1. **Split** a model into its graph definition and separate state pytrees\n",
    "2. **Merge** a graph definition with state pytrees to reconstruct the model\n",
    "\n",
    "This is essential for functional programming with JAX, as it allows you to:\n",
    "- Pass states as explicit function arguments\n",
    "- Apply JAX transformations (jit, grad, vmap) to functions operating on states\n",
    "- Manage different types of states independently (e.g., parameters vs. counters)\n",
    "\n",
    "### Splitting the Model\n",
    "\n",
    "Let's create a model and split it into its components:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e1f2a3b4c5d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.775901Z",
     "start_time": "2025-10-11T10:20:19.349483Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph definition (model structure):\n",
      "NodeDef(\n",
      "  type=MLP,\n",
      "  index=0,\n",
      "  attributes=('count', 'linear1', 'linear2'),\n",
      "  subgraphs={\n",
      "    'linear1': NodeDef(\n",
      "      type=Linear,\n",
      "      index=2,\n",
      "      attributes=('_in_size', '_name', '_out_size', 'b', 'w'),\n",
      "      subgraphs={\n",
      "        '_in_size': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        ),\n",
      "        '_name': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        ),\n",
      "        '_out_size': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        )\n",
      "      },\n",
      "      static_fields={},\n",
      "      leaves={\n",
      "        'b': NodeRef(\n",
      "          type=ParamState,\n",
      "          index=3\n",
      "        ),\n",
      "        'w': NodeRef(\n",
      "          type=ParamState,\n",
      "          index=4\n",
      "        )\n",
      "      },\n",
      "      metadata=(<class '__main__.Linear'>,),\n",
      "      index_mapping=None\n",
      "    ),\n",
      "    'linear2': NodeDef(\n",
      "      type=Linear,\n",
      "      index=5,\n",
      "      attributes=('_in_size', '_name', '_out_size', 'b', 'w'),\n",
      "      subgraphs={\n",
      "        '_in_size': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        ),\n",
      "        '_name': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        ),\n",
      "        '_out_size': NodeDef(\n",
      "          type=PytreeType,\n",
      "          index=-1,\n",
      "          attributes=(),\n",
      "          subgraphs={},\n",
      "          static_fields={},\n",
      "          leaves={},\n",
      "          metadata=PyTreeDef(None),\n",
      "          index_mapping=None\n",
      "        )\n",
      "      },\n",
      "      static_fields={},\n",
      "      leaves={\n",
      "        'b': NodeRef(\n",
      "          type=ParamState,\n",
      "          index=6\n",
      "        ),\n",
      "        'w': NodeRef(\n",
      "          type=ParamState,\n",
      "          index=7\n",
      "        )\n",
      "      },\n",
      "      metadata=(<class '__main__.Linear'>,),\n",
      "      index_mapping=None\n",
      "    )\n",
      "  },\n",
      "  static_fields={},\n",
      "  leaves={\n",
      "    'count': NodeRef(\n",
      "      type=Count,\n",
      "      index=1\n",
      "    )\n",
      "  },\n",
      "  metadata=(<class '__main__.MLP'>,),\n",
      "  index_mapping=None\n",
      ")\n",
      "\n",
      "Parameters (trainable weights):\n",
      "{\n",
      "  'linear1': {\n",
      "    'b': TreefyState(\n",
      "      type=<class 'brainstate.ParamState'>,\n",
      "      value=(32,),\n",
      "      tag=None\n",
      "    ),\n",
      "    'w': TreefyState(\n",
      "      type=<class 'brainstate.ParamState'>,\n",
      "      value=(1, 32),\n",
      "      tag=None\n",
      "    )\n",
      "  },\n",
      "  'linear2': {\n",
      "    'b': TreefyState(\n",
      "      type=<class 'brainstate.ParamState'>,\n",
      "      value=(1,),\n",
      "      tag=None\n",
      "    ),\n",
      "    'w': TreefyState(\n",
      "      type=<class 'brainstate.ParamState'>,\n",
      "      value=(32, 1),\n",
      "      tag=None\n",
      "    )\n",
      "  }\n",
      "}\n",
      "\n",
      "Counters:\n",
      "{\n",
      "  'count': TreefyState(\n",
      "    type=<class '__main__.Count'>,\n",
      "    value=Array(0, dtype=int32, weak_type=True),\n",
      "    tag=None\n",
      "  )\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Create the model\n",
    "model_initial = MLP(din=1, dhidden=32, dout=1)\n",
    "\n",
    "# Split the model into graph definition and states\n",
    "graphdef, params_, counts_ = brainstate.graph.treefy_split(\n",
    "    model_initial, \n",
    "    brainstate.ParamState,  # Split out trainable parameters\n",
    "    Count                    # Split out call counters\n",
    ")\n",
    "\n",
    "print(\"Graph definition (model structure):\")\n",
    "print(graphdef)\n",
    "print(\"\\nParameters (trainable weights):\")\n",
    "print(jax.tree.map(jnp.shape, params_))\n",
    "print(\"\\nCounters:\")\n",
    "print(counts_)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2a3b4c5d6e7",
   "metadata": {},
   "source": [
    "**Key Points:**\n",
    "- `graphdef`: Contains the model structure (immutable)\n",
    "- `params_`: PyTree of trainable parameters (`ParamState`)\n",
    "- `counts_`: PyTree of counters (`Count` state)\n",
    "\n",
    "This separation is crucial because:\n",
    "1. Only `params_` needs gradients during training\n",
    "2. `counts_` needs to be updated but not differentiated\n",
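    "3. `graphdef` remains constant throughout training, so the jitted functions below can capture it from the enclosing scope instead of receiving it as an argument"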
   ]
  },
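  {
   "cell_type": "markdown",
   "id": "pytree-analogy-md",
   "metadata": {},
   "source": [
    "The same structure/state separation exists at the level of plain JAX pytrees, which may help build intuition. The cell below is an analogy only, not BrainState's implementation: a hypothetical nested dict of arrays (`demo_tree`) stands in for the parameters, `jax.tree_util.tree_flatten` plays the role of splitting (returning the array leaves plus a static `treedef`), and `jax.tree_util.tree_unflatten` plays the role of merging updated leaves back into the same structure. All `demo_` names are illustrative and prefixed so they do not overwrite variables used later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pytree-analogy-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical parameter pytree mirroring the MLP's shapes (analogy only).\n",
    "demo_tree = {\n",
    "    'linear1': {'w': jnp.ones((1, 32)), 'b': jnp.zeros((32,))},\n",
    "    'linear2': {'w': jnp.ones((32, 1)), 'b': jnp.zeros((1,))},\n",
    "}\n",
    "\n",
    "# 'Split': separate the static structure (treedef) from the array leaves.\n",
    "demo_leaves, demo_treedef = jax.tree_util.tree_flatten(demo_tree)\n",
    "print(demo_treedef)\n",
    "print([leaf.shape for leaf in demo_leaves])  # dict keys are flattened in sorted order\n",
    "\n",
    "# Update the leaves functionally, e.g. scale every array by 0.5.\n",
    "demo_scaled = [0.5 * leaf for leaf in demo_leaves]\n",
    "\n",
    "# 'Merge': rebuild a pytree with the original structure from the new leaves.\n",
    "demo_rebuilt = jax.tree_util.tree_unflatten(demo_treedef, demo_scaled)\n",
    "print(demo_rebuilt['linear1']['w'][0, :3])"
   ]
  },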
  {
   "cell_type": "markdown",
   "id": "a3b4c5d6e7f8",
   "metadata": {},
   "source": [
    "## Functional Training Loop\n",
    "\n",
    "### Step 1: Define Training Step\n",
    "\n",
    "With the states separated, we can write the training step as a pure function of `params`, `counts`, and a batch, and compile it with `jax.jit`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b4c5d6e7f8a9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.803377Z",
     "start_time": "2025-10-11T10:20:19.798078Z"
    }
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def train_step(params, counts, batch):\n",
    "    \"\"\"Perform one training step with explicit state management.\"\"\"\n",
    "    x, y = batch\n",
    "    \n",
    "    def loss_fn(params):\n",
    "        # Merge graph definition with states to reconstruct the model\n",
    "        model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "        \n",
    "        # Forward pass\n",
    "        y_pred = model(x)\n",
    "        \n",
    "        # Extract updated counters (model was called, so count changed)\n",
    "        new_counts = brainstate.graph.treefy_states(model, Count)\n",
    "        \n",
    "        # Compute loss\n",
    "        loss = jnp.mean((y - y_pred) ** 2)\n",
    "        \n",
    "        return loss, new_counts\n",
    "    \n",
    "    # Compute gradients with respect to parameters\n",
    "    grad, counts = jax.grad(loss_fn, has_aux=True)(params)\n",
    "    \n",
    "    # Simple SGD update: params = params - lr * grad\n",
    "    params = jax.tree.map(lambda w, g: w - 0.1 * g, params, grad)\n",
    "    \n",
    "    return params, counts"
   ]
  },
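  {
   "cell_type": "markdown",
   "id": "has-aux-sketch-md",
   "metadata": {},
   "source": [
    "Stripped of BrainState, the core of `train_step` is a standard JAX pattern: `jax.grad(..., has_aux=True)` differentiates only the loss and passes an auxiliary value through, and `jax.tree.map` applies SGD to every leaf of the parameter pytree. The cell below is a minimal stand-alone sketch under stated assumptions: a hypothetical one-feature linear model fitting y = 2x + 1, with parameters in a dict and an integer call counter as the auxiliary output. The `demo_` names are illustrative only and are prefixed so they do not overwrite `train_step` or other variables used later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "has-aux-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical toy problem: fit y = 2x + 1 with y_pred = x * w + b.\n",
    "demo_x = jnp.array([0.0, 1.0, 2.0])\n",
    "demo_y = 2.0 * demo_x + 1.0\n",
    "\n",
    "def demo_loss_fn(params, calls):\n",
    "    y_pred = demo_x * params['w'] + params['b']\n",
    "    loss = jnp.mean((demo_y - y_pred) ** 2)\n",
    "    # Return (loss, aux); with has_aux=True only `loss` is differentiated.\n",
    "    return loss, calls + 1\n",
    "\n",
    "@jax.jit\n",
    "def demo_train_step(params, calls):\n",
    "    # argnums defaults to 0: gradients w.r.t. params only; calls comes back as aux.\n",
    "    grads, calls = jax.grad(demo_loss_fn, has_aux=True)(params, calls)\n",
    "    # SGD applied leaf by leaf across the params pytree (learning rate 0.1).\n",
    "    params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)\n",
    "    return params, calls\n",
    "\n",
    "demo_params = {'w': jnp.array(0.0), 'b': jnp.array(0.0)}\n",
    "demo_calls = jnp.array(0)\n",
    "for _ in range(500):\n",
    "    demo_params, demo_calls = demo_train_step(demo_params, demo_calls)\n",
    "print(demo_params, demo_calls)"
   ]
  },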
  {
   "cell_type": "markdown",
   "id": "c5d6e7f8a9b0",
   "metadata": {},
   "source": [
    "**Understanding the Training Step:**\n",
    "\n",
    "1. **Loss Function Definition:**\n",
    "   - Takes `params` as input (what we differentiate)\n",
    "   - Merges `graphdef`, `params`, and `counts` to reconstruct the model\n",
    "   - Computes predictions and loss\n",
    "   - Returns both loss and updated counts (auxiliary output)\n",
    "\n",
    "2. **Gradient Computation:**\n",
    "   - `jax.grad` computes gradients of loss w.r.t. parameters\n",
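    "   - `has_aux=True` tells `jax.grad` that `loss_fn` returns a `(loss, aux)` pair: only `loss` is differentiated, and the transformed function returns `(grads, aux)`, where `aux` is the updated counts\n",
    "\n",
    "3. **Parameter Update:**\n",
    "   - Plain gradient descent with a fixed learning rate of 0.1: `new_params = old_params - 0.1 * grads`\n",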
    "   - Uses `jax.tree.map` to apply the update to all parameters in the pytree\n",
    "\n",
    "### Step 2: Define Evaluation Step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6e7f8a9b0c1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:19.844478Z",
     "start_time": "2025-10-11T10:20:19.839915Z"
    }
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def eval_step(params, counts, batch):\n",
    "    \"\"\"Evaluate the model on a batch.\"\"\"\n",
    "    x, y = batch\n",
    "    \n",
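    "    # Reconstruct the model. The counter incremented by this forward pass is\n",
    "    # not returned, so evaluation calls do not add to the training count.\n",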
    "    model = brainstate.graph.treefy_merge(graphdef, params, counts)\n",
    "    \n",
    "    # Forward pass\n",
    "    y_pred = model(x)\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = jnp.mean((y - y_pred) ** 2)\n",
    "    \n",
    "    return {'loss': loss}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7f8a9b0c1d2",
   "metadata": {},
   "source": [
    "### Step 3: Run Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f8a9b0c1d2e3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:20.741207Z",
     "start_time": "2025-10-11T10:20:19.875240Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training the model...\n",
      "\n",
      "Step:     0, Loss: 2.924491\n",
      "Step:  1000, Loss: 0.007898\n",
      "Step:  2000, Loss: 0.007857\n",
      "Step:  3000, Loss: 0.007903\n",
      "Step:  4000, Loss: 0.007962\n",
      "Step:  5000, Loss: 0.007840\n",
      "Step:  6000, Loss: 0.007831\n",
      "Step:  7000, Loss: 0.007867\n",
      "Step:  8000, Loss: 0.007872\n",
      "Step:  9000, Loss: 0.007837\n",
      "\n",
      "Training complete!\n"
     ]
    }
   ],
   "source": [
    "# Training parameters\n",
    "total_steps = 10_000\n",
    "\n",
    "# Training loop\n",
    "print(\"Training the model...\\n\")\n",
    "for step, batch in enumerate(dataset(32)):\n",
    "    # Update parameters and counters\n",
    "    params_, counts_ = train_step(params_, counts_, batch)\n",
    "    \n",
    "    # Log progress every 1000 steps\n",
    "    if step % 1000 == 0:\n",
    "        logs = eval_step(params_, counts_, (X, Y))\n",
    "        print(f\"Step: {step:5d}, Loss: {logs['loss']:.6f}\")\n",
    "    \n",
    "    # Stop after total_steps\n",
    "    if step >= total_steps - 1:\n",
    "        break\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b0c1d2e3f4",
   "metadata": {},
   "source": [
    "## Analyzing Results\n",
    "\n",
    "### Reconstruct the Final Model\n",
    "\n",
    "After training, we merge the learned parameters and the final counters back into a regular model object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b0c1d2e3f4a5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:23.037388Z",
     "start_time": "2025-10-11T10:20:22.873575Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total model calls during training: 10000\n",
      "Final predictions shape: (100, 1)\n"
     ]
    }
   ],
   "source": [
    "# Reconstruct the trained model\n",
    "model = brainstate.graph.treefy_merge(graphdef, params_, counts_)\n",
    "\n",
    "# Check how many times the model was called during training\n",
    "print(f\"Total model calls during training: {model.count.value}\")\n",
    "\n",
    "# Make predictions on the full dataset\n",
    "y_pred = model(X)\n",
    "\n",
    "print(f\"Final predictions shape: {y_pred.shape}\")"
   ]
  },
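  {
   "cell_type": "markdown",
   "id": "counter-threading-md",
   "metadata": {},
   "source": [
    "The count of 10000 equals the number of `train_step` calls, even though `eval_step` also ran the model at each of the ten logging points. The counters live in the `counts_` pytree, which only `train_step` returns and the loop passes back in; `eval_step` returns only the loss, so its increments are dropped. The cell below reproduces this behavior in plain JAX with a hypothetical pure function, `demo_forward`, standing in for the model and its counter (an illustrative sketch, not BrainState code)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "counter-threading-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def demo_forward(count, x):\n",
    "    # Pure function: the incremented counter is returned, never mutated in place.\n",
    "    return count + 1, 2.0 * x\n",
    "\n",
    "demo_count = jnp.array(0)\n",
    "demo_input = jnp.ones(3)\n",
    "\n",
    "# Training-style calls: the returned counter is fed back in each time.\n",
    "for _ in range(5):\n",
    "    demo_count, _ = demo_forward(demo_count, demo_input)\n",
    "\n",
    "# Evaluation-style calls: the returned counter is ignored, so it never accumulates.\n",
    "for _ in range(3):\n",
    "    _, demo_out = demo_forward(demo_count, demo_input)\n",
    "\n",
    "print(int(demo_count))  # 5, not 8"
   ]
  },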
  {
   "cell_type": "markdown",
   "id": "c1d2e3f4a5b6",
   "metadata": {},
   "source": [
    "### Visualize Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d2e3f4a5b6c7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T10:20:23.175407Z",
     "start_time": "2025-10-11T10:20:23.042957Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAJOCAYAAACqS2TfAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAsqhJREFUeJzs3Qd4U2UXB/B/96ItdLH33nvKlo0Dt4KIAoooyhAUBygOUJYoCCiC43PhREQQAdlb9t5QdktLS+ke+Z5zr1ltupPmJvn/vidfyXtvktvktvbc877nuOl0Oh2IiIiIiIiIyOrcrf+URERERERERCQYdBMRERERERHZCINuIiIiIiIiIhth0E1ERERERERkIwy6iYiIiIiIiGyEQTcRERERERGRjTDoJiIiIiIiIrIRBt1ERERERERENsKgm4iIiIiIiMhGGHQTEbmgrl27ws3NTbk9+eSTcGXWfi/0zyW3L7/80irHSMUjn4Pp50Ilp1q1aob3/a233rL34RAR2QWDbiIiDdiwYYNZUGB6K1WqFBo0aIAXXngBZ8+etfehkg1IMGLps/f29kZERAS6dOmCjz76CCkpKfY+VLIRVzwHGJATkavwtPcBEBFR3hITE3Hs2DHltmTJEvz+++/o0aOHvQ/LaYwcORJ33XWX8u9GjRpBS9LT0xEdHa3cNm3ahF9//RX//PMPPDw87H1oDqV169aYMWMGHBHPASIix8egm4hIgx555BG0atUKaWlp2L59O1asWKGMJyUlYfDgwTh//jx8fHzsfZhO815rzWuvvYbSpUvj2rVr+OabbxAVFaWMS9D1559/4p577oEW3bp1C0FBQdCahg0bKjdH4qjnABER5cTp5UREGtSnTx+MHz9e+cP7jz/+wKBBgwzb5I/wrVu3mu2/Z88ePPHEE6hevTp8fX2VKemStX3ppZdw6dKlAr3m+vXrzaa1njx50mx7VlYWypUrZ9j+wQcfWFwvm5qaivfeew916tRRLgxUqlRJ+V5k3JJffvkF/fv3V55bptKWKVMGHTp0wKxZs5SLDPmtmf7f//6HZs2awc/PD7Vq1cKHH36o7JeRkYF3331XeU/kOOrXr49FixYVak23ZEcHDBigfC8hISHw8vJSAqE2bdoo36PMQrCFp59+GhMmTFDeg6+++sps29GjR3PsL5+NvA+9evVSpiLL+xgeHq68rytXrrT4GvL+TJ8+HbVr11ben5o1a2Lq1KlKZjW3denZP2v5fF5//XXUqFFDeW8mT55s2Fc+73nz5qFz587KeyfHVL58eTz00EPKhSRL5Pnl8wgLC1OeT86FunXrKhdG5s+fb7bvhQsXMGLECOX45bOX875ixYq44447MG7cOGVmSG7HnV1ycrJy3shj5TXlWMuWLYt+/frhxx9/zHc5iCz7kONr0qSJchzyGQwfPhw3b96Els+B5cuXK79r5HuV91sumMh5IOf8tGnTlOcsyFRw06nxsl9+5OdM9pXPUG/KlCkWP6PCfM5ERJqlIyIiu1u/fr1OfiXrb1988YXZ9nnz5plt//bbbw3bPvzwQ527u7vZdtNbcHCw8vymunTpYtg+ZMgQw3ijRo0M4xMmTDB7zD///GPY5uHhobty5YoyLsdq+nodO3a0eByDBw82e76MjAzdww8/nOtxy61+/fqG19Ez3d6yZUuLj5s0aZLu3nvvtbht8eLFBXovRGhoaJ7H17hxY11CQkKux5f9c8zNm2++afa4c+fOGbYdPHjQbNuiRYvMHpuUlKTr0aNHnsc5bty4HK/56KOPWtz37rvvzvV7yP5Zd+rUyez+6NGjlf2ioqJ0zZo1y/V45HydM2dOnu9B9lvZsmUN+16/fl0XHh6e5/4LFizI9bhNXb16VdewYcM8n+uBBx7Qpaen5/rzmts537lz5wJ9/vY4B7K/J5ZuycnJhv2rVq1qGJdjze3YZT9Tlh4nP2f5vXZRPmciIq3i
9HIiIgeQPTMoWWH9VFPJ9qixHlClShU89thjuH37Nr744gslExkfH48HHngAp0+fVrJ4eRk1ahSeffZZ5d9ff/21ks2VDJj46aefDPtJdkyylpZs2bIF9913n1L87dtvv1Wmwgv59/vvv48KFSoo9yWrappFbNeunZKlk8yV/rXk35LllzWslkiGv3379ujZsyeWLl2KEydOKOPvvPOO8lWKT0mmVTLcMkNASHZ36NChKAjJ0nfr1g1Vq1ZV3jt5n8+dO6e8lmS5Dx06pGQ4X375ZVibvJYcs+laZMn06def640dOxZr165V/i3ZzUcffVTJCsqxyfsozzN79my0bNkSAwcOVPb7+eef8cMPPxieQzLV8rjIyEjlcyqozZs3o23btsr7L++HnH9ClkDs379f+XdgYKDyuvJeygyNv/76S8mgynHLEgrJWIoFCxYYnldqFkjGW57z4sWLyjkl2WjT2RGyxlnI5/LUU08hNDQUV65cwfHjx5XjKig5v44cOWK4/+CDDyrn7po1aww/d/J6cr6aZvJNyfHdeeedygyNZcuWKe+9/udzx44dyrmttXPA9P2WNe/ynDL7Qd7vnTt32jSDLMcnM3HkPdXPBpBzSH7+TVnzcyYisit7R/1ERJQzc/bII4/oZsyYoXvvvfdyZB4l46fPQJlmcwMDA5XMkN7KlSvNHicZ8fyyu7dv39aVLl3asO2XX34xZKXldbOPW8qYjRkzxrBt//79ZtuWL1+ujGdmZupCQkIM4+3bt1deQ+/ll182e9y+ffsM20zHGzRooEtLS1PGV69ebbatadOmhudcuHCh2bZbt27l+17oxcXFKe+lPMesWbOUz0UymPrHdO/e3Wx/a2S6Ld0qVKigW7NmjdnjYmJidJ6enoZ9lixZYrb9ueeeM2xr3ry5Ybx3796G8VKlSimZ6dyOJa9M9/333698lqYOHDhgto/MkDDVr18/w7b77rvPMB4UFGQYl+xzdmfOnDH8e/bs2YZ9R4wYkWNfOY+vXbuW63HryXllOi7nnZ6cO3Je6rfJ+ar/XrP/vMr3kZWVZfhMZCaIftvHH3+s0+I50KRJE8P49u3bcxyPZNpNP1trZroLsq0onzMRkVYx001EpEGSSZVbdrKeUdZ3ytfsGXDJPstaTr2+ffsqazr1mSLZd8yYMXm+bkBAgJIFlqyYkAzx/fffr2Tsrl+/rozJetu777471+d47rnnDP+W9bim9FktyUjHxsYaxh9//HGzasxDhgxRMtJ6cuyybju7hx9+2JCJz76WVI5b/5yyTjX7cUgGNi+SjZ04caLSpkkK2uWmoGvmi8PT01P57CSbakoykpKd1JPPLrcsvmSeZeaDv78//v333xzniZ5kE2V9bUFIzQF3d/PyMNnrDXTv3j3Xx2/bts3w706dOikFwoRkQSWDLtlaKYAmsw1kvb6eZMdl3a9c4/j000+xe/duJTst55tkz2V/Wadc2Bkkct7pybkj56V+Hzlf5byV2gCWKuDr1yHL+nX5GdH/vBRnXbctzwF5vw8ePGjIMsuMEXm/5X2U2SGNGzeGvVnrcyYisjcG3UREGifTSWV6swQvMo3UNPgwDVwt/fEpY/qgu6B//MsU8zlz5ihB599//61MNzWdBi6BiD7QtcQ0+M1eYV1fmMn0uC0de/b7uR27fqq6flptbtskYLF0HHn5+OOPC9RmKrcCccUhway8d9Ie6sCBA0pQJVPYJWB68803Dftlfx/zIoFLTEyMEnDFxcXlWKqQ2/281KtXL8dYYY5Jf27qpzvLRRSZji3Hmb34l2z7/vvvlSBfCtnJhaFJkyYpSyn27t2r3PQk6JVp1TJFPS/WOg+zX/AxPe8Lcq7Z4xyQqd1SAG7VqlXKeyjT6eWmJ0sz5CKIXIiz9Dy2/hkQ1vqciYjsjUE3EZEGyXrs7JW0LZGsmr6VkD6zZsp0LL/13HpS7VsqHkvVdAkYJNstf/ibZkLzYhqQW6oUrT/u
3I7T0v3cjj2v4D97oF1YpjMNJID/7bfflGy7BPcS/Niy77NUrpZATqpXyzph/fpoCZTkooc+c5/9fZSLMqYXG7ILDg5WvkoFdgm+hP780dOvfS8ISwFZ9mN6++23lQtH+alcubKSVZbaA7t27cKpU6eUNcnSl14CTrnwI7M59OefZH2feeYZJUiXNdmyv6wXl683btxQstam1bELeh7KmmHT+0U5D3M777V0DkilcrmwITM15D2UbgVSFV3OcwnsN27cqMw20c96MJ3RYLq+Xsh7bivW+JyJiOyNQTcRkQPTF24S8oeoBFD6KeaSwTLNJMq+BfXCCy8oQbeQ4DIlJUX5txRikrZIxSXTQyVY0GfppA+xtAXSTwfP3iKpMMduLfqgVMhUVsm6CXkv9O+NrUmwKq2sZBqtkGnu0gZNLsoImYIt71lmZqYh+JP2bNlJMTuZGq3voS3fz+rVq5V/y1fJ4OoDSv1zF1X2z0qykTL9OjsJoEwzx5LNlSnNMpPDdDbHvffeq7S2EpLllKBbCmnJ9y2ZaJkBop/Cvm/fPrRo0UL5txSFk8/QNIjO71jlvNO3wpP3VM5LPTlfsy+XcORz4PDhw8r3IwXupHic3ujRo5VZHsI0qywXavTkoohku+XiglwYKerPg+nFCkvtAa31ORMR2RuDbiIiByZZLckEyh/ACQkJShViqU4sUzGXLFliFjCYrlfNj1SPlqnDUiFYH3AXJMtdUJI1k2OXaaNCMpwdO3ZUqhfLa5pOZ5dgo2nTpihpEpDoM3grVqxQLgrI1Gup/C3HWFJk6qwEh/r1zxIISl9kWXIgn6us39X3H5fMpKzXlv1l3f/ly5eVDKEEKfL59+7d25BF1QfdMtVcAjeZvi0BjGmgWRTyWckaYf1UZVmuIBeA5IKNfO6SlZTvRapjyzRp+dyF9OKWSvvyeUsfZvnezpw5YzbNXB/4SY0BqTouj5U11pLZlaDTdEaGzEiQadT5HauskV63bp3h/ZMp17KOXJZWmK75lmA0+/p1Rz4HJDCX4Fm+f5llIOv6Jcg1vehiGmjL7xZ5DiFZcKnILu+7VE3Pq+ZBXuRzlpkN+l7qcoFBai1IFl86IFjrcyYisjt7V3IjIqL8+3TnxZp9uvPqDe7j46OLjY3NsV9ePZBFbt+XVId+6KGH8u3Tffny5QI9n1Rbzm1b9vfXtAdybu/F5s2bzapCm1b7lqrduVVrtkb1ctPjEytWrDDbPnLkSMO2xMTEfHs0W/qcc+vT3bdvX7P7X331VYE/az2pop9Xn25LFavr1q2b575SPfz8+fPKvt9//32+z23alzq/Pt1SBb84fbqzf175VeXWwjlgWsHe0s3X11e3a9cuw/5HjhxRfgdk38/Pz0/XtWvXXH8e8novPvroI4uv3b9//yJ9zkREWmWfS7ZERGTVNY9SwVh6I0vmSzI/kjGSzJBkk2X6Z1EKDUlWTD8VVQwYMKDA68ILQqaNSkZbCiH169dPmRYv67BlzalkXmVau1Qrzmt9qi1Jdk2ywZIxlIJWclxynJJtLOnKzrLG3jTbL7MYrl69qvxbsnxynN99951yfDIVV95HOQckYyhThz/77DNDRXq9//3vf0rfdNlHpvnK+mGZeWDavzl7trOg5LOUc1KeS6YEyxRz+bxlDbjMoJA1ydIPXNYr602bNk3pES8ZcZlRIMck35vsLxXxpSe7nN/6z0Z6yMv7Iscv2VH5niVbK5lbyZrOmjWrQMcqryXnmewvFbzlc9Y/l6whl37mMruhuDUCtHYOyHsv2XvJWEvGWX5vyHkuPdvlZ1+y4JLd1pOq4ZLVlqrn8rzyu0G6GMjnLEXXiuL5559XMvbympbeX2t+zkRE9uQmkbddj4CIiDRLAnf9VGpZM66fmkqOT4phWSpwNm/ePGVNv55MT7bXhQ8iIiJnwDXdRERkRqokSwE2aRekD7jr1KmjrLcm5yEzI6TVk3yukkFOTEzE5s2bsXjxYsM+
DzzwAANuIiKiYmKmm4iIzMhUdCmUpCcViqU6sUzxJOchywWkCF9upFq7zG6w5pICIiIiV8RMNxERWSTrRGUd5xtvvMGA2wnJul25oCJtoaTfcXp6utJ2SXqRSyVzyYTbex0zERGRM2Cmm4iIiIiIiMhGWL2ciIiIiIiIyEYYdBMRERERERHZiMsv1srKysKVK1eU3o+yto2IiIiIiIgoP7JSOyEhQen04e6eez7b5YNuCbgrV65s78MgIiIiIiIiB3Tx4kVUqlQp1+0uH3RLhlv/RgUFBUHLGXnpmxseHp7nVRSiksTzkrSI5yVpFc9N0iKel6RFWQ5yXt66dUtJ4Opjyty4fNCtn1IuAbfWg+6UlBTlGLV84pFr4XlJWsTzkrSK5yZpEc9L0qIsBzsv81umrP3vgIiIiIiIiMhBMegmIiIiIiIishEG3UREREREREQ24vJrugsqMzMT6enpdl3XIK8vaxscYV0D2Ya3tzc/fyIiIiIiB8KguwC9165du4a4uDi7H4cE3tIHjv3EXZcE3NWrV1eCbyIiIiIi0j4G3fnQB9wRERHw9/e3W8ArQXdGRgY8PT0ZdLsouegifeWvXr2KKlWq8DwgIiIiInIADLrzmVKuD7hDQ0PteiwMuklIr0IJvOVc8PLysvfhEBERERFRPrg4NA/6NdyS4SbSAv20crkgRERERERE2seguwCYWSat4LlIRERERORYGHQTERERERER2QiDbiqwatWqYc6cOQXef8OGDUpm1t6V34mIiIiIiOyFQbcTkkA3r9tbb71VpOfdvXs3nnnmmQLv36FDB6XSdnBwMGxJH9zLTVpqyes1b94cL7/8svL6hSXPs2zZMpscKxERERERuRZWL3dCpoHm0qVLMXnyZJw4ccIwVqpUKbOq6FKUS6qiF6RydmGLfpUrVw4lRb7HoKAg3Lp1C3v37sX06dOxePFiJShv3LhxiR0HERERERGRHjPdJSAtDdi4EZg6FRg9Wv0q92XcFiTQ1d8k6yuZW/3948ePIzAwEKtWrULLli3h4+ODLVu24MyZM7j33ntRtmxZJShv3bo11q5dm+f0cnnezz//HPfdd59S4b127dpYvnx5rtPLv/zyS5QuXRqrV69G/fr1ldfp06eP2UUCaYX14osvKvtJm7ZXXnkFQ4YMwYABA/L9vqW1m3yPderUwaOPPoqtW7cqFwpGjhxplq3v2bMnwsLClPemS5cuSoBu+j0K+Z7k2PX3C/L+EBERERERZceg28YksF60CJg3Dzh0CEhKUr/KfRm3VeCdn4kTJ+L999/HsWPH0KRJE9y+fRv9+vXDunXrsG/fPiUYvvvuuxEZGZnn80yZMgUPP/wwDh48qDx+0KBBiI2NzXX/pKQkzJw5E//73/+wadMm5fnHjx9v2P7BBx/g22+/xRdffKEEzZK1LupUbz8/Pzz77LPK80RFRSljCQkJShAvFxp27NihXCiQ45ZxfVAu5PXlYoD+flHfHyIiIiIicm2cXm5j27dLxheoUgUIDDSOS4wn440aAV26lPxxvf3220rGVy8kJARNmzY13H/nnXfw22+/KZnrUaNG5fo8Tz75JB577DHl31OnTsXHH3+MXbt2KUFpbr3PFy5ciJo1ayr35bnlWPTmzp2LV199Vck0i3nz5mHlypVF/j7r1aunfD1//rySCe/evbvZ9s8++0zJqm/cuBF33XWXYQq9jJlOjZf3pijvDxERERERuTZmum1s61ZAlkubBtxC7su4bLeHVq1amd2XTK5knGXatwScMoVasuD5ZXIlS64XEBCgrKnWZ5UtkWno+oBblC9f3rB/fHw8rl+/jjZt2hi2e3h4KNPgi0rWrJv2t5bnf/rpp5UMt0wvl+OV7z2/77Oo7w8REREREbk2Zrpt7Pp1KVxmeZuMy3Z7kADZlASUa9asUaZ+16pVS5ma/eCDDyItn/nvXl5eZvcluM3KyirU/vrA2BYkMBb6tdkytTwmJgYf
ffQRqlatqqxpb9++fb7fZ1HfHyIiIiIicm0Mum2sbFl1Dbclt28D1atDE2Tds0wV10/rlsyuTMkuSZJ5lkJlso66c+fOyphUVpdCZ82aNSv08yUnJyvTx+W59NPG5fucP3++sj5bXLx4ETdu3MhxYUBeV2vvDxERERGRK9DZMClnDwy6beyOO4B9+9Q13NnXdGdkqNu1QKZb//rrr0pxMMk+T5o0Kc+Mta288MILmDZtmpJNlvXYssb75s2bhunheZFp6ikpKUpRtD179igtwySglu/L9PuUIm4yvV6KtE2YMEHJWpuSrLgUTLvjjjuUTHiZMmU08/4QEREREWmVTAKVmlayhFZm9EoC8o47gPbtpZ1wwZ7jfNx5LNy9EE/VeQoRiIAz4JpuG5MTrGtXQJb+Sqvsy5fVr3JfxmW7FsyePVsJLjt06KAElr1790aLFi1K/DikRZgUZnviiSeUad+ydlqOxdfXN9/H1q1bFxUqVFDWgEtl9h49euDw4cNo0KCBYR/p2y1BvHxvgwcPVtqTSYE1U7NmzVKmkleuXBnNmzfX1PtDREREROSsXZsOXj+IV9a+gsPRhzF7z2ykZKTAGbjpnC13X0iS7ZRpzVLES4pqmZKs6blz51C9evUCBX22vOIjH5P0sPb09CxQ1tdZSDZZipdJWzKpGO7qrHVOWvPzkRkGcuHC3Z3X8EgbeF6SVvHcJC3ieUnWsnGjGmBb6toUGSldi/Lu2rTj0g58sPUDZGRlKLFPZb/KmNZnGkr7lYYjxpKmOL28BEhgLSeYPVqDOZoLFy7g77//RpcuXZCamqq0DJMgc+DAgfY+NCIiIiIiKkbXpi55xENVgqvA38sft1JvoU2FNhhSewiCfHIPZB0Jg27SFLnC+uWXXyrVwuUKV6NGjbB27Vol201ERERERI7btSktjxnAFQIr4M0ub2LD+Q14qulTiLkRA2fBoJs0RdZRS6VwIiIiIiLSjvyWzObXtalyZXVt94YNaubbr1Qarh1yw759Xjh8GHj6aaBOaB3l5mwFixl0ExERERERUb5F0vQBs2SuJcCWLk36gDm/rk3BwerjZc23T+BtbMJ78EVpNEl4GRs2uKFRI+ddjsugm4iIiIiIiHIlGW59wJw9oJZxCZgl4y0BuGlgLhnujAy1a1NMjDruHhiNNXgTt3BReY5SgeXg5zkk3zXfjoxBNxERERERERW7SJpkvCUA109Br17dOAV9wgQgK/gs1mAKkhGrPN4HwaiMDkixsOZbgnV57sJ2fdIiBt1ERERERERUrCJp+XVtyozYh+2xU+EHtfd2KZRHV0xBIMrjRLY1315easB+/Lj5FHZHDbwZdBMREREREWm0QJkW5FckTQLkvKw7uw77g+ciIzZTyWCX9ayLzpgEXwRbXPMtGXRpey3Bt+kUdkedfs6gm4iIiIiISKMFyrQQeOdXJE22WyItgH84/AO+O/wdgssAobcA3aV2qBQ3HjEBPhbXfBe1z7eWudv7AMhxbdiwAW5uboiLiyvwY6pVq4Y5c+ZAC9566y00a9bMcP/JJ5/EgAEDivWc1ngOIiIiInK9AmV16wIVK6pf5b6My3YtkKy7BMaRkcCJE8Dly+pXuS/jst2SFSdXKAG3cHcHRt55N+Y+/CqaNfKBvz/QuDEwapR6cUGCbtMp7G6ZGRansDsiBt1OSoI/CYifffbZHNuef/55ZZvsozUSCMuxyc3T01MJ0seOHYvbchnMxj766CN8+eWXBdr3/PnzyjHu37+/yM9BRERERK6tIAXKtECy7RIYS4AsgXL2gDm3bPydNe5E9dLq3POhzYZiZJun0a2rO157Tf5uhvJVstf6Pt/6P/mrnd+AJ6d3QeVI9Q2QcdnuqDi93IlVrlwZP/zwAz788EP4+fkpYykpKfjuu+9QRS6faVTDhg2xdu1aZGRkYOvWrRg6dCiSkpLw6aef5tg3LS0N3laacxMsC0k08BxERERE5BoKWqBMC/Iq
kpYbfy9/vNnlTZyIOYEOlTvkua9MUT++6xZ6/fYyOhxU/+6/d/nTmDFoPzIyfHOdwu4ImOl2Yi1atFAC719//dUwJv+WgLt58+Zm+6ampuLFF19EREQEfH190bFjR+zevdtsn5UrV6JOnTpKAN+tWzcl25vdli1b0KlTJ2UfeW15zsTExEIdt2S4y5Urh0qVKuGRRx7BoEGDsHz5crMp4Z9//jmqV6+uHKuQKe7Dhw9HeHg4goKC0L17dxw4cMDsed9//32ULVsWgYGBGDZsmHIBIq+p4VlZWZg+fTpq1aoFHx8f5X177733lG3y2kLeR8l4d5V5NRaeI7/3VT9Ff926dWjVqhX8/f3RoUMHnJD5OkRERETk1Eyzu9k5YnY3Mj4SsclqOzC9UP/QfANu0SFuJT7Z2NAQcIsYtzDcPBOb5xR2R8Cg28lJlviLL74w3F+yZAmeeuqpHPu9/PLL+OWXX/DVV19h7969SqDZu3dvxMaqPzQXL17E/fffj7vvvluZUi0B7sSJE82e48yZM+jTpw8eeOABHDx4EEuXLlWC8FEy76QYJICXjLbe6dOnlWOVCwj66d0PPfQQoqKisGrVKuzZs0e54HDnnXcajv/HH39UAvapU6fi33//Rfny5TF//vw8X/fVV19VAvVJkybh6NGjygwBCdrFrl27lK+Skb969arZhY3CvK96r7/+OmbNmqUcm1x0kM+NiIiIiJybZG+lkJgUJDOVX4EyLTp4/SBeXvMypmyYguT05II/8MYN4PHH4TWgPwLjLilDad4BWNXvPfwyegMGTaigmYJyRcXp5UXRqhVw7VqJv6yHBHz//luoxzz++ONK8HjhwgXlvkzXlinnkmHVk0z0ggULlLXIffv2VcYWLVqENWvWYPHixZgwYYKyvWbNmkpgKOrWrYtDhw7hgw8+MDzPtGnTlKz0mDFjlPu1a9fGxx9/jC5duiiP12elC0MCaAl2JXOtJwH4119/rWS1hQT2EgRL0C0ZaTFz5kwsW7YMP//8M5555hmleJtkt+Um3n33XSVgzp7t1ktISFDWZ8+bNw9DhgxRxuT7l0y10L92aGiokpW3pCDvq55k0OV9EnIxo3///sqxFeU9IyIiIiLHINlbqVJuWr3ctKK3o2R3/zn3D+bumouMrAycjTuL7w59h2Et1L+7c6XTAT/9pC4Mj442jvfsCc+FC9Hc3x+9I9yVAmyOjkF3UUjALSX7SpBbER8nwaEEcBL4Scl++XdYWFiODHV6ejruMLmU5uXlhTZt2uDYsWPKffnatm1bs8e1z/ZbQKZzS4b722+/NYzJa8o07XPnzqF+/foFOmYJ5kuVKoXMzEwlwJZjluBXr2rVqoagV/+6UmhNAmBTycnJyvemP/7sReXk+NevX2/xGGR/mRou2fKiKsj7qtekSRPDvyULL+QigpbX3hMRERGRdQqUSQ9qfZ9uWcWotT7duZG/9ZceWYpvDxn//m9doTUGNRmU9wOvXgWeew5Ytsw4Vro08OGHgCS8JCCPioKzYNBdFLlkNm1JJ7eyZYsUfMtUZf0U708++QS2IoHviBEjlDXM2RUmeJQsuqzhlmnWFSpUyFEoLSAgIMfrSqBqmr3XKy0/vEWgLzxXUiQY15M13kIuVhARERGRcytKgTItkKz2vF3zsO7cOsNYv1r98EzLZ+Dh7mH5QTodIJ1+xo2TokzG8fvvl0DFGGfJfk6EQXdRFHKKt1XodMjMyCjSBybrrCVjLMGcrCfOTqZNS2ArU88liywkQysFv/RTxSVLrS9mprdjxw6z+7KOWtY+y7rl4pBjKcxzyOteu3bN0GLMEjn+nTt34oknnsj1+E3J1HgJvKXAmaxft3SMQrLxuSnI+0pERERE5GgS0xLx/pb3sf+6sX2utAQbUG+AIYGUw/nzwDPPAGvWGMciItRg+8EH4cwYdLsADw8Pw3Rm+Xd2kjkeOXKkssY4JCREyUpL1W5p06VfAy1Ts2U9t+wj
Qaistc7ej/qVV15Bu3btlKy67CPPK0G4rGE2nR5ubT169FCmikvVcDluqbB+5coV/Pnnn7jvvvuUquCjR49WKovLv2W6t0yBP3LkCGrUqGHxOWUttXw/UghNAmd5THR0tPIYeU+kGrkE5X/99ZdSZV32z94urCDvKxERERGRI7mRdANvbXgLF+LVmlFe7l4Y134cOlZRax/lkJUFSCwgTblNuxoNHqxOJ8+2RNQZMeh2EdJGKy9SpVumMw8ePFgpIibB6erVq1GmTBlluwSMUoV77NixmDt3rrIuWSqBm1bZlnXJGzduVCpxS9swWeMh2V5p+2VLcjVN2pnJ60pldgmOpbhZ586dDdXG5RhkjbUE0VKgTCqsS0As32NupGq5ZM8nT56sBPEyhV2/LlzGpUjc22+/rWyX79fS9Pb83lciIiIiIkey9uxaQ8Ad6B2ISZ0noX54LrWbJPE3fDiwbZtxrHJl4NNPgf8KDbsCN51ERi7s1q1bSoYyPj4+R2AqwZkUADPtB20v8jFlyPRyT8/cp2yQ09PSOSnkgoIUfJPMv7szlJYkp8DzkrSK5yZpEc9LKkpcIlPLz8Wdw1td30KFwAo5d0pPB2bMAKZMkdZDxvGRIyUrJRlBpzgv84olTTHTTURERERERAUiCUCZTp6amYogHwuB5r59UskZ2G9c7w2p17R4MdC5M1yRdi8bEBERERERkd1k6bLwxb4vcOLGCbNxH0+fnAF3Soq6brt1a2PALVnqCROAgwddNuAWzHQTERERERGRmZSMFMzcNhM7L+9U2oLN6jULZUup9ZJykDXbUij4+HHjmDQfX7JEDcJdHDPdREREREREZHAz+SZeW/eaEnCLhLQEnIo9lXPH27eB0aOBjh2NAbeXl7qWe88eBtz/YaabiIiIiIiIFBfiLmDKximITopW7vt7+ePVjq+iWblm5jtKv23puy39t/XatFHXbkuWmwyY6SYiIiIiIiLsv7YfL6992RBwh/uHY3qP6eYBd1ycOpW8Vy9jwO3nB8ycqU4zZ8CdAzPdRERERERELu6v039hwb8LlOJpolaZWpjUZRJC/EKMO/3+u9r26+pV41jXrsCiRWqFcrKIQTcREREREZEL+/rA1/jp6E+G+20rtsX4DuPh6+mrDkRFAS+8APz4o/FBgYFqL+6nn1arlFOuGHQTERERERG5sMpBlQ3/HlB3AJ5q/hTc3dwBnQ747ju1WFpMjPEB/foBCxcClY2Po9wx6CaX5Obmht9++w0DBgyw96EQEREREdlVt+rdcD3xOoJ9gtG3dl918OJFdSr5n38adwwNBT76CBg4UP6gttvxOhrOA3DSgDKv21tvvVVix9K1a1eLx5CRkVEiry/fa7Nm2SotQpahXEXfvv/9QiEiIiIiciHxKfE5xh5t9KgacGdlAZ9+CjRsaB5wP/wwcPQoMGgQA+5CYqbbCUlAqbd06VJMnjwZJ06cMIyVKlXK8G+dTofMzEx4etruVHj66afx9ttvm43Z8vUKoly5cnZ9fSIiIiIie9h9eTemb5uOka1Gonv17uYbT59W12hv2GAcK18emD8f4AzRImOm2wlJQKm/BQcHK5ll/f3jx48jMDAQq1atQsuWLeHj44MtW7bgySefzDHVesyYMUqmWi8rKwvTpk1D9erV4efnh6ZNm+Lnn3/O93j8/f3Njkkf8Mpzy2uYkmOQY9GrVq0apk6diqFDhyrHXaVKFXz22Wdmj7l06RIee+wxhISEICAgAK1atcLOnTvx5ZdfYsqUKThw4IAhwy5jQv69bNkyw3McOnQI3bt3V76v0NBQPPPMM7h9+7Zhu/79mTlzJsqXL6/s8/zzzyM9Pb0QnwwRERERkX1Isu2PE3/gnU3vICUjBXN3zcXJmJPqxsxMYNYsoEkT84B76FDgyBEG3MXETLeLmjhxohJA1qhRA2XKlCnQYyTg/uabb7Bw4ULUrl0bmzZtwuOPP47w8HB06dLFZsc6a9Ys
vPPOO3jttdeUIH/kyJHK69WtW1cJjOXfFStWxPLly5WAfu/evcoFgkceeQSHDx/GX3/9hbVr1yrPJRchsktMTETv3r3Rvn177N69G1FRURg+fDhGjRplCNLF+vXrlYBbvp4+fVp5fpm6Lpl8IiIiIiKtysjKwKI9i7Dy9ErDWLuK7VCtdDXg8GG17/auXcYHVK2qtgHr2dM+B+xkGHQX0bLjy5RbfmqWqan0tzP1zsZ3cObmmXwfO6DeAOVmCzLdu2chfohSU1OVjLMErxKcCgnYJUv+6aef5hl0z58/H59//rnh/ogRI5RAuqD69euH5557Tvn3K6+8gg8//FAJfCXo/u677xAdHa0Ey5LpFrVMegTKVHqZyp7XdHJ5jpSUFHz99ddKplzMmzcPd999Nz744AOULVtWGZOLEzLu4eGBevXqoX///li3bh2DbiIiIiLSrMS0RLy/5X3sv77fMPZwg4fxeL2H4fbuNOC99wD97E1Zqy2twWTMZEkqFQ+D7iJKSk9CTLJJ2fxchPmH5RiLT40v0GPlNWxFpmAXhmR2k5KScgTqaWlpaN68eZ6PHTRoEF5//XXD/dKlSxfqtZvINJf/6KfKSzZa7N+/X3l9fcBdFMeOHVOmyusDbnHHHXco2XJZC68Puhs2bKgE3HqS9ZZp6UREREREWnTt9jW8vfFtXLx1Ubnv6e6JF9q8gO43AoHWrWWNpXHnunWBxYvlD2H7HbCTYtBdRP5e/gj1C813Pym7b2msII+V17AV0wBTuLu7K+s8TJmuV9avb/7zzz+VqdymZF14XmRKt2n2uaCvqefl5WV2XwJvCYiFrMEuKXkdBxERERGRlhyNPop3N72LhLQE5X6gdyBeb/MSGs5dqq7f1v8dK0mlV14BJk0CfH3te9BOikF3ERVn6nf26eZaIOuyZf2zKcki6wPNBg0aKMF1ZGSk1dZvy2uaVlqXKupyDN26dStUFlymrsfGxlrMdnt7eyvPm5f69esra7dlbbf+YsTWrVuViwIyhZ2IiIiIyNHWcM/cNtMQcFcKrITJ3j1RvvujaoVyPWmtu2QJkM/MVSoeVi8nhVTu/vfff5V1zadOncKbb75pFoRL5fDx48dj7Nix+Oqrr3DmzBmlYNncuXOV+0V9Tcmcy02qqkuBtLi4uEI9h1Qtl+nmUllcAuWzZ8/il19+wfbt2w3Vz8+dO6dcQLhx44ayNt3S9HdfX18MGTJE+Z5lvfgLL7yAwYMHG6aWExERERE5CplG/sodr8DL3QvNytTHzJXpKN/7AWPA7e2trtuW4mkMuF0v6P7kk0+UQEmCoLZt22KXaRU9C+bMmaNkI2WaceXKlZWgUIpiUeFI9e5Jkybh5ZdfRuvWrZGQkIAnnnjCbB+pIC77SBVzyQ736dNHCZilhVhRSBswCXTldSR7LoXZCpPl1mey//77b0RERCgF1xo3boz333/fsPb6gQceUI5Tnlcy699//73FlmarV69WsuXyvT/44IO48847laJpRERERESOqG5YXbzv1RdvjvgeAQsWGzd06CBTWoHXXpP1k/Y8RJfhpsu+qNaOli5dqgRg0pJKAm4JqH/66SelmJUEVZaqTkvgtmTJEnTo0AEnT55U+ik/+uijmD17doFe89atW8qa4/j4eAQFBZltk+BdsqQSVMpFAHuSjykjI0OpxC1rick1aemcFLKmXYrayc+nTMcn0gKel6RVPDdJi3heFk9aGiATLLduBa5fB2SSpNQhk2Y/kkwuKTFJMfjj5B94oukTcHdzB2JigLFjgf/9z7iTLKOcNg2QrkAmxYG1KMtBzsu8YknNrumWQFnaLz311FPKfQm+JZMqQbX0lc5u27ZtSpXpgQMHKvclQy7TjXfu3Fnix05ERERERNoKRm39PUor6w0bAE9PtcOWFAPft09tfS1dZUviez0Vcwrvbn4Xscmx0Omy8NTpQGDUKOC/bj+KHj3Ug61WzfYHRNoNuqX11J49e/Dq
q68axuSqRo8ePQzrc7OT7PY333yjTEFv06aNsp535cqVylrc3MiaXtN1vXJ1Qn81JXslarkvGWb9zd70x6CFYyH70J+Lls5Xe9D/jGjhWIj0eF6SVvHcJFc5LyUYlc5TGzcag1EJQmVGs3wdNqxkglE5jh07JFGnxp8ycVZmVrdrZ53Xl+eV77FKFal/ZBxPSFDHGzYEOneGTW2J3II5O+cgLTMNSEnF5i+m4OGFkQj4ryGQLjgYupkzAUlqymxZB/n9k+Ugvy8LenyaCbqlyJVUmc5euEruS5EtSyTDLY/r2LGjYfr1s88+i9dkfUIuZD3ylClTcoxHR0fnWAsu7avkjZTnlZs9yfenr8LN6eWuS85DOSdjYmJytDCzBzkWmU4j56eWp/6Qa+F5SVrFc5Nc5byUwPrUKaBpU2nvahxPTlbHt2wBGjWCTUkX2r//VjPPMpNaZlbLjOtly4AzZ4BevYq/nPngQaByZSBbN13oZxnL9nr1YBPyef12+jcsO71M7sH9yhU02X4ar69NMwTcKX364Na0acgqV06CHTiSLAf5fSl1sBwq6C6KDRs2YOrUqZg/f76yBvz06dMYPXq0oeCXJZJJHzdunFmmWwqwSZEtS2u65Y2UddRy0wItBFpkP3Ieyi+e0NBQzazplotA8vOj5V+I5Fp4XpJW8dwkVzkvpQ6yZJZLl1aDX1MyLtu7d4dNbdoErF6tBsX6LLT8GS0xkozXrFn8LPTZs0BSksQTObfJ68h2C2Wpik2y2nN2zMGWi1vgnZkJt0MH0WPnDTy3G/DKAnTh4dB9/DG8H3oIYQ6arMtykN+XBf17XBuRJICwsDCl4vR1WfRhQu5LSyhLJLCWqeTDhw9X7kvlaum1/Mwzz+D111+3+AFJr2m5ZSf7Zt9f7suHrb/Zk1zl0R+DvY+F7Ed/Llo6X+1Fa8dDJHheklbx3CRXOC/lz3nJLFsi47Ld1j8CMvVbMtym076F3Jdx2d61a/FeQwJqyaRbIkG3LJ+29vcp67bf3fQuTsWchNv5C3A7dgxP/ZuJAccBJUJ4/HG4ffgh3MLC4OjcHOD3ZUGPTTPfgbR+atmyJdatW2d2hUPut5eKCxYkJSXl+Eb1raKsue5Z62sJyHVwPT8RERFpnawWvX3b8jYZz7aa1CYksJe15JbIeLY8X5FIYThZgZp9hrHcl3HZbk2Xb13G2NVjcSpyn3LVwHf/YbyxPhP3ScBdqRKwYoVardwJAm5no5lMt5Bp39K3uVWrVkphNGkZJplrfTVzaSdWsWJFZV22uPvuu5WK582bNzdML5fst4zrg+/iXgiQoP7KlSvK1Aa5b68sM1uGkZwDUntAPn8uMyAiIiKtkmBTKnhL8Jm9wJgtglFLJLDPLQstgX/16sV/DckLyvp10+rl8tzyPUoWPZe8YZGFeZdGmcOnEXt8ByISsjBpE1AtDsCIEcD06cbF5KQ5mgq6H3nkESWomDx5Mq5du4ZmzZrhr7/+MhRXi4yMNMtsv/HGG0oAIl8vX76sBMYScL/33ntWOR55LemHfPXqVSXwtid99T79lHdyTfLZV6pUySoXlYiIiIhsoaSDUXsF/lIBXdqCSVE4fWs0CeZt0hrtwAH4DB2KN47vxaIWUNZvB1eqCfz2eYHmybtCCzctc9O5+HzVgjQ012eZ9dXD7UFfsVoKaGl5XQPZlmS4tRRwy3kZFRWFiIgInpekGTwvSat4bpIrnZf2DvIs9dA2DfxLqod2caRkpOB2QgzCZi0APvhAPXghn9PYscDbbwP+/k75XmQ5yO/LgsSSmst0a5V+Oq89p/TKiSevLxXytHziERERERFJENeli3qz1+uXWBbaBqISo/DO0ueh+2cdZiy9CT9992Jp/i1N0Nu2LfBzycUPCbgt9ROXcXmP7PU5uQoG3URERERE5HTsHfgX1bHIPXhv4SDEnz+h3J/fGnhptyfw+uvAa68V+oqBXHSQDLelSu4yLtsd7T1yNAy6
iYiIiIiINGDtz9PxycopyEhJUu6Xvw087NEE2PuN9Ecu0nT+kqjkTnlj0E1ERERERGQjBVnfnhkbgy+nDMCy6C2GxzW94YFXer6FwLET1ZR0IdZsS+V2KSQnBe1CQ4GoKNtWcqe8MegmIiIiIiKygfwCYll3nvbnUkz/Yjj2lTI2N79LVwvD5i6HZ936xV6z3bGjsZ+4vVq4uToG3URERERERDaQV0C8d3U0jqwYhgVef+Dqf9O/Pdw98GyrZ9HnxY/VKuVWWLMdH69WKbdnCzdXx6CbiIiIiIjIBiwGxDodOpz/Hr1Wv4hVtWNwtYk6HBRaHq8O/wqNmvYs1Gvkt2Y7JgaYONFxK7k7AwbdRERERERENpA9IA66dQn9/xyJuidXKPcfOgKcqeiHq3064o2n/4eIUmUL/RqyRlymrOe1ZttRK7k7CwbdRERERERENmAIiHU6tNj7OXqueQl+qQmG7ccbPYixS2YqO/p6+hbpNSRjLWvEuWZbuxh0ExERERER2YAEvFc2n8GgL55GmZj1mNQFeHI/UCWxHD5t8gnaTLsf9SsW7zVkirgUZeOabe1i0E1ERERERGRtmZm4Y+dH6LDuDZwOTsbY3kCcLzCmTw2E3l6Fzt3rWCUglqnjUgWda7a1i0E3ERERERGRNR05AgwbBs+dO/FXLeDTlkCyjx8uhzaBT+lGGNHWD/d2s15AzDXb2sagm4iIiIiIXLaPtrT10meIZQ12sTLE8oQffAC88w4yMtPxWWtgVS0A1avBq1493FOhBSZ2nIhAn2z9vcipMegmIiIiIiKXI/HxokXma6Gl6JkUJZM10jJlu1CB97//KtltHDyoTCOf1gU4Wj0AaNoUCAnBPXXuwdDmQ5Ve3ORaGHQTEREREZHLkQy3BNxVquSs+i3jska6QNO1k5OBt94CZs4EsrJwOgR4r7MbbjSuAdSpCy8vHzzf+nncWeNOW347pGEMuomIiIiIyOXIlHLJcJsG3ELuy7hszzfo3rxZzW6fOqXcTfEEJt9XGgmtGgPBwQj1C8VrnV5DndA6tvtGSPPc7X0AREREREREJU3WcMuUcktkXLbnStLhzz8PdO5sCLhlLrrvW+9i5Ku/KgF3/bD6+LD3hwy4iZluIiIiIiJyPVI0TdZwWyJ9rqXtlkWrVwPPPANERhrH2rUDFi8GGjRAJwBePn5oWb4lvDy8bHLs5FiY6SYiIiIiIpcjVcozMtSktSm5L+Oy3UxsLPDkk0CfPoaA+1x5Xyyd/gSwZYsScOu1q9SOATcZMNNNREREREQuR9qCSZVy0+rlkuGWgLtrV3W7wS+/qNPJTeacb7qvBT66OxxpvrEIu7CBhdIoVwy6iYiIiIjI5Ug7MGkLJlXK9X26ZUq5WZ/ua9fUYPvXXw2PywwOxFdT7sdv4TGAmzq25uwadK/eHW5u/w0QmWDQTURERERELkkCa6lQnqNKuU4HfPU1MHYscPOmYfjWgL6Y/mRtHEg6axjrUb0HRrYeyYCbcsWgm4iIiIiISO/CBWDECLVgml5YGM7OfB3vlT6IqP8Cbg83DzzT8hn0rdWXATfliUE3ERERERFRVhawYAEwcaK6uFtv4EBsmPAQ5p78H9KS0pSh0r6l8WrHV9Eg3Fg8jSg3DLqJiIiIiMi1nTwJDBumViHXq1hRCcJX1vPAgn8XGIbrhtZVAu5Q/1D7HCs5HLYMIyIiIiIi1ySlyj/4AGjSxDzglj7cR44Ad9+N9pXaI8QvRBnuWaMnpt05jQE3FQoz3URERERE5HoOHFCz23v2GMdq1AA+/xzo1s0wVMavjJLZPnfzHPrU6mOX9dtpacD27cYq62XLZquyTprGTDcREREREbmO1FRg0iSgVStjwO3urlYqP3gQ66sBCakJZg+pF1YPfWv3tVvAvWgRMG8ecOgQkJSkfpX7Mi7bSduY6SYiIiIiItewY4ea3T561DjWoAGwZAnSWjXHwn8XKj23W5ZvicldJsPdzf45
Sslwb9gAVKkCBAYaxxMS1HHpM56j5Rlpiv3PIiIiIiIiIltKTATGjQM6dDAG3J6easZ7715EN6qBiWsnKgG32HN1D/Ze3QstkCnlcqimAbeQ+zIu20nbmOkmIiIiIiLntX49MHw4cFbtr61o2VLJbksBtYPXD+KDrR/gVuotZZO3hzdGtR6FVhVaQQtkDXepUpa3ybhsJ21j0E1ERERERM4nPh6YMEFd+Kzn6wtMmaJkvXUeHvj16C/46sBX0EGnbC4bUBavdXoNNcrUgFZI0TRZw22JtBOvXr2kj4gKi0E3ERERERE5lxUrgGefBS5fNo516qRWJq9TB0npSZizZQ62X9pu2NyiXAuM7zAegT7Z5nHbmVQp37dPXcOdfU23dDyT7aRtDLqJiIiIiMg5REcDo0cD339vPgf7/feBkSOVKuVSmXzCmgm4nGAMyB9t+Cgea/yYJgqnZSdtwQ4fVoumyRpu+XYkwy0Bd9eu6nbSNgbdRERERETk2HQ6YOlS4IUXgBs3jOO9ewOffgpUrWoYKuVdCnVC6yhBd4BXAF5q/xJaV2wNrZI+3E8/rVYp1/fplinl7NPtOBh0ExERERGR45Ip5M89ByxfbhwrUwaYMwcYPBjI1ltbem0/3/p56HQ6DGoyCOVKlYPWSWAtbcHYGswxaW/+BBERERERUUGy27JGu2FD84D7gQfUtmBPPKEE3LHJsTgSdcTsoT6ePnipw0sOEXCT42Omm4iIiIiIHIu0/3rmGWDdOvMy3598ogbd/5FgW9qBpWWm4cPeH6J8YHn7HC+5NGa6iYiIiIjIMWRmqtPGGzc2D7iHDFGz2/8F3DJ1/Ldjv+G1f17DzZSbSExPxKK9Jq3DiEoQM91ERERERKR9x44Bw4YB241tvlCliloorU8fw1BiWiI+3vkxtl3aZhhrEtEEo9uOLukjJlIw6CYiIiIiykdamhrr6atHy0xmVo8uIenpwPTpwNtvqx+E3vPPA9OmmTWvPh93HlM3T8XV21cNYw83eFgpmKbFdmBawnPcdhh0ExERERHlE4wsWmTeJ/nQIWDfPrV/srRzYlBiI3v3AkOHAgcOGMdq1wYWLwY6dTLb9Z9z/+CT3Z8o67eFI7QD0wqe47bFoJuIiIiIKA+S/ZNgRGYymyRVkZCgjkv/ZLZysrKUFGDKFGDGDHUdt/DwAMaPB958E/DzM9v9y/1f4pdjvxju1yxTE692fBVlS5Ut6SN3SDzHbYtzLIiIiIiI8iDTbSX7ZxqMCLkv47KdrGjLFqBpU+D9940Bd5MmwM6d6li2gFs0DG9o+Hfvmr0xved0BtyFwHPctpjpJiIiIiLKg6xvlem2lsi4bCcruH0bePVVte2X9OAWMqd50iTg5ZfznN8sU8gfb/w4wvzDcGeNO0vumJ0Ez3HbYqabiIiIiCgPUlBK4kFLZFy2UzH9/bc6h3nePGPA3batuqj4jTfMAu7MrExsPL9RaQtm6pFGjzDgLiKe47bFTDcRERERUR6kgrPEfrK+Nft614wMdTsV0c2bwLhxwJdfGsdk+vh77wEvvqiu4zYRmxyLGVtn4HD0YdxOu43+dfrb9PBcpaI3z3HbYtBNRERERJQHCbCkgrNpZWfJ/kkw0rWrup2K4LffgOeeA65dM451766W0a5RI8fuB68fxIxtMxCXEqfc/2L/F+hUtROCfIJscniuVNGb57htMegmIiIiIsqDBFYSYMnsZ33Gs3p158x4lgh5A194AfjpJ+NYUBAwcyYwfDjg5ma2u0wj/+noT/jm4DfQQZ1SHuoXilfueMVmAberVfTmOW5bDLqJiIiIiPIhQYcEWM4SZNllyrWswf7mG2DMGCA21jh+113AggVApUo5HpKQmoBZ22dhz9U9hrHm5Zor/beDfYNh74reznQ+8By3HQbdRERERERk2ynXFy8CI0YAq1YZx8LCgI8/Bh59NEd2Wxy/cRwfbP0AN5JuKPfd4IaBjQfi4YYPw93N9vWg
WdGbrIVBNxERERER2WbKdVYW8Nlnassv2UlPAm0JuMPDLb7Orsu7MHXzVGTq1D7dwT7BGN9hPJqVa4aSIhl8uaBgiax3lunXRAXBoJuIiIiInJqrVKAuCYWacn3qlJr63rjRuGOFCupU8nvuyfN1GoQ3UNZtRyVFoUFYA7x8x8sI9Q9FSWJFb7IWBt1ERERE5LRcqQJ1SSjQlGuJSD/8EJg8GUhJMe4gRdJmzABKl873dUp5l8LEjhOx/dJ2DGo8CB7u5q3DSgIrepO1MOgmIiIiIqflShWotTDluo3fIaD9UODff40bZB62XPm4806Lj5Pq5KtOr0K7Su0Q4hdiGK8dWlu52QsrepO1MOgmIiIiIqflahWo7TXlOikuDfcfeA8PnJwKZGaog1IcbfRo4N13gYAAi8+XmJaIj3d+jG2XtmHzhc14t/u7dslq54YVvckaGHQTERERkdNiBWrbT7kue2EXRu0ciiq3jhh3rF8fWLw4zznYp2JOYfrW6biWeE25fzj6MPZf24+WFVqWxLdCVGIYdBMRERGR02IFattNud65Pgmtlk9G1/0fwl2Xpe4gkfjEicAbbwA+PrlOJ//j5B/4Yv8XyMjKMKzhHttuLANuckoMuomIiIjIabECtY2mXOs2oMs3w4EzZ4wbWrQAliwBmjbN9bG3027jox0fYcflHYaxOiF18ErHVxAREGHrQyeyCwbdREREROS0WIHaym7dUntuf/qpcUwy2lOmAC+9pL7JuThx44QynVzagOndV+8+PNH0CXi6Mywh58Wzm4iIiIicFitQW9HKlcCIEcClS8axjh2Bzz8H6tbN86FXE67ilbWvIFOXqdwP9A5UppO3rtja1kdNZHcMuomIiIjIqbECdTHduAGMGQN8+61xTKqRf/ABMHIk4O6e71OUDyyP3jV7Y+XplagfVh8v3/EywvzDbHvcRBrBoJuIiIiIiHLS6YCffgJGjQKio43jvXoBn30GVK1aqKcb1mIYKgZVRP/a/TXVFozI1hh0ExERERGVgLQ0YPt24zR3qayu2WnuV64Azz8PLFtmHCtdGvjwQ2DIELUHdy6ydFn48ciPSmG07tW7G8a9PbxxT917bH3kRJrDoJuIiIiIqAQC7kWLzAu6SSszqawuhd5k3bkmAm/Jbn/xBTBuHBAfbxy//37gk0+AcuXyfHhscixmbZuFg1EH4ePhgzqhdVApqJJNDtWhLmKQS2PQTURERERkYxIcSsBdpUrO1mUyLoXe7L7m/Px54JlngDVrjGMREWqw/eCD+T58z5U9mL1jNm6l3lLup2Wm4Wj0UZsE3Q5zEYOIQTcRERERke1JNlaCQ9OAW8h9GZftdgu6s7KAuXOB118HEhON4088AcyeDYSG5vnwjKwMfH3ga/x2/DfDWKhfKMZ3GI9GEY1c9yIG0X8YdBMRERER2ZhMf5ZsrCUyLtvt4vhxhDz5JNx37zaOVaqkFkrr2zffh1+7fQ0zts7AydiThrHWFVpjTLsxCPIJcs2LGETZMOgmIiIiIrIxWW8s058tuX1b7R1eotLTgZkz4TZlCrxTU43j0gLs/feBoPwD5i2RWzB311wkpScp9z3dPfFUs6dwd5274ZZHoTWnvohBZAGDbiIiIiIiG5MCX7LeWKY/Z58OnZGhbi8xciDDhilf9aGxrlYtuH3+eYHTwykZKfh87+eGgLt8qfJK7+1aIbXgkhcxiPKQfyd7IiIiIiIqFqmo3bUrEBkJnDgBXL6sfpX7Mi7bbS4lRV233bq1GnhLsO3ujsSRI6GT+4WYj+3r6YuX2r8EN7ihS9Uu+KjPRyUWcAu5SCEXK+SihSm7XMQgygcz3URERERENiaVtKWithT40re4kmxsibW42rZNzW4fP24ca9QIukWLkFCtGvz8/fN8uE6nQ2pmqhJs6zUu2xhz+sxB9dLVbT6dPDt5z6RKuWn1cslwS8BdYhcxiAqIQTcRERERUQmQwFqSySVa4EsiUclu
S3Vy6cEtvLzUsVdfVSPWqKg8n0JagM3ZMQeZWZl4q+tbZgF2jTI14JIXMYgKgUE3EREREZEzkn7b0ndb+m/rtWkDLF6sRqv6dmF5OHj9IGZtn4XY5Fjl/rLjy3Bf/fvgshcxiIqAQTcRERERkTO5eRMYPx5YssQ45ucHvPMOMGYM4OGR71NI7+3vDn2Hn4/+DB3UDHmwTzCqBFex5ZETOSUG3UREREREzmLZMrXt17VrxjFZ5LxoEVCrYIXOriZcxcxtM816bzcr2wxj249FiF+ILY6ayKkx6CYiIiIicnSyLvuFF4AffzSOSW+yGTPUxc/u+TctkmJp68+vx4J/FygtwYSHmweeaPoE7qt3X4kXSyNyFgy6iYiIiIgclRRH++47YPRoICbGON6/P7BwIVCpUoGeRoqkzdk5B5siNxnGpPf2hA4TUDu0ttUPOy0N2L7dWARN+m6zCBo5KwbdRERERESO6OJFdSr5n38ax0JDgY8+AgYOBAqRmfZw94CXh5fh/p3V78SIliPg5+Vnk4BbZrubtvs6dEhtHS5twCQxz8CbnAmDbiIiIiIiRyIVxyVqnTABSEgwjj/8sNoaLCKiSE8rQXZkfKQylbxT1U6wFclwS8BdpYo6A15PvhUZl8LqrEhOzoRBNxERERGRozh9Wk0FS3SqV748MH8+MGCAOm17Y/7TtqMSo3D51mU0LdvUMCZZ7Vm9Ztl87bYcm2S4TQNuIfdlXLYz6CZnwqCbiIiIiEjrMjOBOXOASZOA5GTj+NChwMyZQJkyBZ62veH8BqVYmvio90dmL1MSxdLkYoAcmyUyLtuJnEn+ZQxL2CeffIJq1arB19cXbdu2xa5du/LcPy4uDs8//zzKly8PHx8f1KlTBytXriyx4yUiIiIisimJmDt0UHtv6wPuatWAv/8GFi9WAu7s07br1gUqVlS/yn0Z/2dLotIKbNb2WUhKT1JuXx/4OsfLSfC+cSMwdapan02+yn0ZtwbJvt++bXmbjMt2ImeiqUz30qVLMW7cOCxcuFAJuOfMmYPevXvjxIkTiLCwNiUtLQ09e/ZUtv3888+oWLEiLly4gNKlS9vl+ImIiIiIrEai3GnTgPfeA9LT1THJREtrMBnLli7Oa9r2bf/DeH3LbFSoHW0Y71atG55u/jQS4xJLtMiZTHeX55M13NnXdGdkqNuJnImmgu7Zs2fj6aefxlNPPaXcl+D7zz//xJIlSzBx4sQc+8t4bGwstm3bBi8vtdqiZMmJiIiIiBza7t3AsGFqxKsnaWvJbOcSlVqatp2FDBzC9zha/ie4p+pQAUCAVwCea/0cOlftjKysLCQisUSLnMn6cgngTQN7yXBLwN21q7qdyJloJuiWrPWePXvw6quvGsbc3d3Ro0cPbJeffguWL1+O9u3bK9PLf//9d4SHh2PgwIF45ZVX4OHhYfExqampyk3v1q1bylf5hSM3rZJj0+l0mj5Gcj08L0mLeF6SVvHcpAJJSoLbW28BH34It//OFZ38Xfvyy9C98Qbg66tWL7dApmVLMKt3C5exHbMRi1PIyAQC/YGGYY0wtt1YhAeEG/7+NT0vJVsuuSxL2XIZl+2dilnYXAJtuZ7QsCGwbRsQFQVUr67OoG/XTt2ekgLs2GHcLpNe9dvZTsz5ZTnI78uCHp9mgu4bN24gMzMTZbMt4pD7x48ft/iYs2fP4p9//sGgQYOUddynT5/Gc889h/T0dLz55psWHzNt2jRMmTIlx3h0dDRS5Kdbwx9ofHy8cvLJxQgiLeB5SVrE85K0iucm5cdr2zYEjx8Pz3PnDGPpjRohfvZsZDRuLNki9ZaLNm3UAFWCYx/fTPyVOhEJWdFwywJ8vd3xcJ0H8WyjftAl6pTq5ZbOS8k2SwAcFJTz+WVctstrWEO9eurNVFycOpNelqtLkl+uNwQEADExwLJlwJkzQK9e6vdIzivLQX5fJpi27HOEoLuoH4as5/7ss8+UzHbLli1x+fJlzJgx
I9egWzLpsm7cNNNduXJlJUseZOm3i4a+V6kmKcep5ROPXAvPS9IinpekVTw3KVe3bsFt4kS4ffqpYUjn7Q3dm2/C46WXEFLACLNjR+DECbXomWSLy5cZg4v+U+CbVhGP1xyHVwfUzpElzn5eyuMk32XpJeVagEwvL2Ib8ALbtAlYvRqoXNmYcZfjkfhGxmvWBDp3tu0xkH1lOcjvSyn+7VBBd1hYmBI4X8/WI0DulytXzuJjpGK5rOU2nUpev359XLt2TZmu7m1h7olUOJdbdvJhavkDFXLiOcJxkmvheUlaxPOStIrnJuWwahUwYgRw8aJxrEMHuC1eDLfsaeB8eHln4umnPZTAWO3T3Rrhwa/gwQ6t0bWjb67Tsk3Py7yKnEkGWrbb+vSVKeXy572lKe4yLttl7Tc5NzcH+H1Z0GPTTNAtAbJkqtetW4cBAwYYrnDI/VGjRll8zB133IHvvvtO2U//DZ88eVIJxi0F3EREREREmiDzpceOBf73P+OYv79arfz559XosoBSMlKweO9i3Ey5idc7vY4uXdxMip11crgiZ+zjTc5GU5cNZNr3okWL8NVXX+HYsWMYOXIkEhMTDdXMn3jiCbNCa7JdqpePHj1aCbal0vnUqVOVwmpERERERJqj0wE//QQ0aGAecPfooUa7L75YqID7ZMxJjF41Gn+d+Qs7L+/E32f+LtbhSd5K2oJJzkuWkct1APkq963RLqwg2MebnI1mMt3ikUceUQqaTZ48WZki3qxZM/z111+G4mqRkZFmKXxZi7169WqMHTsWTZo0Ufp0SwAu1cuJiIiIqHikZ7M0kVGnKqvBjkwvlmwnJxUWwdWrahb7t9+MY8HB0jcXkCST9OAuoMysTPx45Ef8cOQHZOnUCso+Hj7wcC94wJ4b+WwlU17c1mBFxT7e5Gw0FXQLmUqe23TyDTLPJRtpGbZD+gkQERERkVUD7kWLzKcZSzVpCYYkIVtSWU+nyW5/+aVM61TLc+vdey8wfz5QoUKhLnJcSbiC2dtn40TMCcNYnZA6GNd+HCoGVYSj08IUdyKnDrqJiIiIyP4k+JOgp0qVnNlGGZdiXfbKhDqU8+fVQmnSA0svPByYOxd4+GFDdrsgFzm8vHRYdXoVluxbgtTMVOVx7m7ueLTho3io4UPwdHeOP+31U9yNBeHUdmWcZUGOyjl+MomIiIjIqiTYkeDPUgVpGZftDLrzkJWlZrEnTgQSE43jjz8OfPihtO4p1EWOug3SsMVtKvZc3WPYVr5UebzU/iXUDasLZ2PvKe5E1sSgm4iIiIhyYAXpYpBm2cOGqVcm9CpVAhYuBPr3L9JFjl3bvRHQJcAw3rdWXwxtPhS+ngXrE0xE9sOgm4iIiIhykPXEMr3ZEllfK9N9KRtpZD1rFvDWW0CqOv1bIdPLp08HgoKKdZHj3VbP4urtqxjUeBBaVmhpg2+AiGyBQTcRERER5cAK0oW0fz8wdKj6punVrAl8/rla/auQFzmuYT8ykYaKaGO4yBHoE4hZvWbBrRBVzonI/hh0ExEREZHNKkg7fdsxyWi/8w7wwQfqmyOkxe3YscDbb6uNrgtxkeNmQirOBH6JU1gBbwSiU8I8ZGSEGC5yMOAmcjwMuomIiIjIJhWknb7tmFxNkLXbx44Zx+QNW7wYaNOmUE8l7+na/cfx7dkPkZp4BR4ewO3MBOyJX4Unug5imywiB8agm4iIiIhsUkHaaduOSTXy118HPv5Y7cEt5KqCjL32WqGvJKRnpuP7o99hX9gvKOehQ2wskJHqjfalnsSw/nehQwcHvzhB5OIYdBMRERGRTThl27F169QU/blzxrFWrYAlS4DGjQv9dGdvnsXs7bNxIf4C3NyB0FCgQ526GNtuLCoGVbTusRORXTDoJiIiIiKbcKq2Y3FxwIQJamE0PV9fdT33mDHqVYRCyMzKxE9Hf8IPh39Api5TGfN098TARgNxf/374eHuYe3vgIjshEE3EREREdmE
07QdW74cGDkSuHLFONa5sxqA165dpKe8nXYby08sNwTc1UtXx7j241CtdDVrHTURaQSDbiIiIiKyCYdvOxYdDbz4IvDDD+Ypeum5Lb23pUp5EQX7BuO51s9h+tbpeKjBQ3is8WNKphuuXg2eyAkx6CYiIiIiTbcdK3FSHO3779WAOybGON6nD/Dpp2pluEK6mnAVAd4BCPIJMox1rNIRNcrUQIXACgV6DqevBk/kpBh0ExEREZFm246VuEuX1KnkK1YYx0JCgDlzgMcfl0bZhXo6nU6HFSdX4MsDX6JNhTZ4peMrZtsLGnA7dTV4IifHoJuIiIiINNt2rESz27JGe/x44NYt4/hDDwFz56rzuAvp2u1r+GjHRzgcfVi5v+XiFnS51AXtKrUr0iE6ZTV4IhfAoJuIiIiIXNuZM2pKfv1641i5csD8+cB99xX66SS7ver0Knyx/wukZKQYxu+qfRealWtW5MN0qmrwRC6EQTcRERERuabMTOCjj4A33gCSk43jTz4JzJ4NlClT6KeMSozCxzs/xoHrBwxjEf4RGN1uNJqUbVKsw3WaavBELoZBNxERERG5niNHgGHDgJ07jWNVqwKffQb06lWk7PbqM6uxeN9is+x2n5p9MLT5UPh5+RX7kB2+GjyRi2LQTURERESuQ0qAf/AB8M47QHq6OibF0UaNAqZOzX3+dj72XduHT3Z/Yrgf5h+GF9u8iOblm1vryB23GjyRi2PQTUREREQOrcC9q/fsAYYOBQ4eNI7VrasWUOvYsVjH0Lxcc7St2BY7L+9Erxq9lOy2tAiDq1eDJyIG3URERETkuArUuzozGXjrLWDmTCArS32ghwcwYQLw5puAr2+hXzc+JR7BvsGG+25ubniu9XPoX7u/VbPbDlsNnogMGHQTERERkcPKr3f1HVmb0WzuMODUKePGpk2BJUuAFi0K/Xr6vttfH/wa49uPR9tKbQ3bQvxClBsRkSkG3URERETksHLrXR3qnYDe+yei2c/zzdPEkyYBr7wCeHkV+rWuJFzB3J1zDX23ZQ13g/AGCPTJ9uJERCYYdBMRERGRw7LUu7rW6b9w14oRKB0faRxs1w5YvBho0KDQr5Gly8Lvx3/HN4e+QVpmmmG8faX28HTnn9NElDf+liAiIiIih2Xau9ovORa9V49FswNfG7anefnDe8ZUtTq5rOMupMj4SHy04yOcjD1pfM2Asnix7YvF7rtNRK6BQTcREREROSx97+pqe37Bg+ufR6nE64ZtByPuRPKcz9D2sRqFft6MrAz8cvQX/HDkB+Xfwg1uuLvO3RjcdDB8PQtffI2IXBODbiIiIiJyWO2rX0OlU8+j5oFfDWOJXsH4usksZD05FE8/4Fak5/1i3xdYfnK54X7FwIoY3XY06ofXt8pxE5HrYNBNRERERI5HpwO+/hreY8ei5s2bhuFD1e/B+ocXoGnfCoXqXZ2913epiPtwu8walCqdigcb3I/HGj8Gbw82wiaiwmPQTURERESOJTISGDEC+Osv41h4ODB3Lho//DAauxUuuy0B9yefJWPbRj9Dr++ow2GA1xi0bBWBxxrUgnfhl4MTESkYdBMRERG5oOyZXSlIJuujC5MdLvHjyMoCFi5UW37dvm0cHzgQ+OgjICys0K+flJ6EN37+Cr+e3oV7qsxDSGCAYVuFhA44shnY3hTo0qU43yURuTIG3UREREQuRgLdRYuADRtgyOxKBXApSHb4MPD00yUTeBfqOE6eBIYPBzZvNj5BxYrAggXA3XcX6fV3X96N+f/Ox/bTN5DuBZwO+BJt8Lxhu/T+luOSCwIMuomoqBh0ExEREbkYySxLoFulihpY6iUkqOONGpVMkFmg47gjA5g9G3jzTSAlxbiTROQzZgDBwYV+3fiUeCzauwgbL2xU7qemAt7uPghG5Rz7yoUAycATERUVg24iIiIiFyOZW8ngmga69sjs5nccJ385iC4vDQX27DFurFFDTY93717o19Pp
dEqg/dmez5CQlmAYr+bbDKVPjELdWmVzPEZmsVevXuiXIiIyYNBNRERE5GKU6tylLG8rycxubsfhkZGKgSfeQ78fpwH/9ciGFEcbMwZ45x0gwLjuuqCiE6Mxf/d8/Hv1X8NYKe9SGN58ODzKd8cnB92UDHv2jHtGhrrGnIioqBh0ExEREbkYKVYma6ctKcnMrqXjqHhpJ+5dPhQR0UeNgw0aAIsXA+3aFel1MrIyMH7NeMQmxxrGOlbuiBGtRqC0b2mkVQKOHDFfWy7vgwTcXbuqRd2IiIqKQTcRERGRi5HMrRQrs3dm1/Q4QnyT0P2fN9Buxxy4Qadsz/LwhPtrrwKvvw74+BT5dTzdPfFIw0ew4N8FCPELwXOtnkPbSm0N26VYmywRlzXk+irqcuHBHtXcicj5MOgmIiIicjESSEp1cHtndvXHEf3jegzfMxxlE88atkVVbokyvy2Be8smhX7etMw0ZGZlws/LzzDWt1ZfpGSkoHfN3gjwzjk9XQJrWcfOKuVEZG0MuomIiIhcjFYyu97J8Xh238vw2PSZYSzd0xeRw99G5Vlj4eVf+D9VD14/iHm75qFp2aZ4vo2x/Zebmxvur3+/1Y6diKigGHQTERERuSC7Z3ZXrACefRYely8bxzp1gtfnn6NmnTqFfrqE1AQs2bcEa8+tVe5fvX0VXat1RcOIhtY8aiKiQmPQTUREREQlJzoaGD0a+P5745jMb58+HRgxAnB3L3QbsA3nN2DxvsWIT403jNcPq49g38L38CYisjYG3URERERkezodsHQp8MILwI0bxvE+fYBPPwWqVCn0U15NuKq0Adt/fb9hzN/LH082fRJ9avVRppQTEdkbg24iIiIisq0rV4CRI4Hly41jZcoAc+YAgwerPbgL2QLs12O/YumRpUrRNL07Kt+BZ1o+o1QoJyLSCgbdRERERGS77Lb01x4/Hog3Tv3GAw8A8+YB5coV6Wm3Rm7F/w7+z3A/3D8cI1uNROuKra1x1EREVsWgm4iIiIis7+xZ4JlngHXrjGNlywKffKIG3cXQuWpn/HnqTxy/cRwD6g3AwMYD4evpW/xjJiKyAQbdRERERFaUlgZs325sxSVxprTiatsWriEzE5g7F3j9dSApyTg+ZAgwezYQElLoQmknYk6gXlg9w5is1X6x7YvK1PIaZWpY8+iJiKyOQTcRERGRFQPuRYuADRsAT0+1KPehQ8C+fcDhw8A998C5HTsGDBumXnXQkwJpUihNCqYV0uVbl7Hg3wU4cP0Apt05DY0iGhm2VQqqZK2jJiKyqcL1ZCAiIiKiXEmsKQG3xJl16wIVK6pf5f7GjcCJE3BO6enAe+8BzZqZB9zPP69ebShkwC0Z7O8PfY9Rq0YpAbeQKuWZWZnWPnIiIptjppuIiIjISmRKuWS4AwPNx+W+jB89CnTvDueyd6+a3d5vbNuF2rXVAmqdOhX66fZf248Fuxfgyu0rZoXSnmz2JDzcPax11EREJYZBNxEREZGVyBpumVJuiYzHxcF5pKQAU6YAM2ao67iFh4daqfzNNwE/v0I9XVxKHBbvXYwNFzYYxjzcPJRCaY82epSF0ojIYTHoJiIiIrISKZoma7gtuX0bKF0azmHLFjW7ffKkcaxJE2DJEqBly0IXSlt1ehW+PvA1EtMTDeP1w+rjudbPoVrpatY8ciKiEsegm4iIiMhKpEq5FE1LSDCfYi73MzKABg3g2OQbee01te2X9OAW3t7ApEnAyy+r/y6CLZFbDAF3Ke9SeKrZU+hZo6dSpZyIyNEx6CYiIiKykvbt1bphptXLJcMtAXfXrmpRNYf199/A008DkZHGMemDJtntYlxNkMB6ZKuRePGvF9Glahcl4A72DbbOMRMRaQCDbiIiIiIrkUSvxKWNGhn7dFevbuzT7ZBrum/eBMaNA7780jgm67WlWvmLL6rruAsxlXz9+fUI8w9Dk7JNDOOVgyvjs7s+Q3hAuLWPnojI7hh0ExER
EVk58O7SRb2ZysqC4/n1V7Xt17VrxjEpvy7NyGvUKNRTRcZHKlXJD0cfRoVSFTCv3zx4eXgZtjPgJiJnxaCbiIiIiMxJin7UKODnn41jQUHAzJnA8OEyJ7zAT5WcnowfDv+A30/8jkydWuVc2oHtvLwTHat0tMXRExFpCoNuIiIiIlJJcbRvvgHGjAFiY43jd90FLFgAVKpUiKfSKQXSPt/3OWKTjc9VvlR5jGg5Ai0rFK7KORGRo2LQTURERETAxYvAiBHAqlXGsbAw4OOPgUcfLVR2+2L8RSz8dyEORh00jHm5e+HBBg8qN2+PolU5JyJyRAy6iYiIiFyZLDb/7DO15Ze0BNOTQFsC7vDCrbVed3Yd5u6aa5hKLlpXaI1nWj6DcqXKWfPIiYgcAoNuIiIiIld16pRabn3jRuNYhQrqVPJ77inSU9YPr6/219YBZQPKKsF2m4ptrHfMREQOhkE3ERERkauRxuFz5gCTJgEpKcZxKZI2YwZQunSBnyotM81suniFwAp4tOGjSqabU8mJiBh0ExEREbmWQ4eAoUOBf/81jkkzcZli3qNHgZ8mMS0R/9v/Pf7cvwuto+ciJsoHZcuqPcnva/+I0jqNiIgYdBMRERG5hrQ04L33gKlT1Uy3kGngo0cD774LBAQUuCr5unPrsGTvlzh0Mh4xMcCN2z+jUdYgJZ7ftw84fFidtc7Am4iIQTcRERGR89u1S81uHzliHKtfH1i8GGjfvsBPcyrmFD7d8ylOxJxQgm25Bfh5o3KgLyr+t4/UYtuwAWjUCOjSxfrfChGRo2HQTURERA6TqN2+Hdi6Fbh+HYapzBIzMqOai6QkYPJk4MMP1SrlwtMTmDgReOMNwMenQE8TnxKPrw98jTVn10AnFdKgtvEOT7kDvQKHIgARhn0DA9WXkM+JQTcREYNuIiIicpCAe9EiNYMqAV2pUrA4lZmBuQl5s6Qw2pkzxrEWLdTsdrNmBXqKzKxMrDy1Et8e+haJ6YmG8cpBlREYNwIBaU1haVK6fD7y/hMREYNuIiIicgASSEsMWaWKmkmFhanMElgXJDB3erduqT23P/3UOCYZ7bfeAsaPV9+cAopLicNXB75Camaqct/P0w+DGg9C/zr9Mf2AJw5dtfy427fV2mxERMSgm4iIiByAZK4lVjQNuLNPZRb5BeZOP9155UpgxAjg0iXjmKT6Jbtdt26hny7UPxQPNXgI3xz6BndWvxNDmg5BGb8yhqeVCxry/mZ/v6VOm2wnIiIG3UREROQAZKqyZK6Rx1TmggTmTht037gBjBkDfPutcUyqkb//PvDcc4C7e75PkZKRgt+P/467694Nfy9/w/h99e9D8/LNUSe0jtn+MrNAZhCYziyQDLcE3F27Fqo+GxGRU2PQTURERJona7NlqnheU5kLEpg7HZ0O+OknYNQoIDraON6zp9p3u1q1AjyFDpsubMIX+79ATHIMktKT8FTzpwzbvT28cwTcyri3OmVfZhDo19DL5+Cya+iJiHLBoJuIiIg0ryBTmSXwyy8wdypXrgDPPw8sW2YcK11arVQ+ZIjag7sALcAW7V2EYzeOGcZWnl6Jhxo+hFLeuVzBMCGBtcwecNoZBEREVsCgm4iIiDSvoFOZXWKNsWS3v/gCGDcOiI83jt93H/DJJ0D58vk+RWxyrNICbN25dWbjrSu0xrDmwwoUcBMRUcEw6CYiIiLNK8hUZpdYY3zuHPDMM8DatcaxiAhg3jzgwQfzzW6nZaZh2fFl+OnoT8oabr2KgRUxvMVwtKrQClrDNnBE5OgYdBMREZFDyG8qs1OvMc7MVLPYr74KJCUZx594Apg9GwgNLdDa7Ql/T8DZuLOGsQCvAAxsPBD9aveDp7unw/ZnJyLSMu39diUiIiIqIqdcY3zsGDB8OLBtm3GscmW1D3ffvgV+Gjc3N3Sr3g1n952FG9zQt1ZfDGoyCEE+QXDk/uxO9VkTkVNi0E1ERESkRenpwIwZwJQp
aspXb+RItRVYUFC+67a93L0Q6GOMVu+qcxci4yNxb917UbV0VWidS7eBIyKnwaCbiIiISGtk/vTQocD+/caxWrWAxYuBzp3zfGhqRip+O/4bfjn2C7pX646RrUcatskU8hfbvghH4ZJt4IjI6bgXdMf27dvjsCyeISIiIiLbSEkBXnsNaN3aGHC7uwMTJgAHD+YZcMua7fXn1mPEihH49tC3SqG0VadXKZltRyVF06QYniUyLtuJiJwm033+/Hm0bNkSL730EiZPngxfX1/bHhkRERGRK5E128OGAcePG8caN1az2xKE5+FI1BEs3rcYp2JPGcbc3dzRv3Z/lPEtA2fuz07Frwbfpg3QsSPAP++J7Bx0nzhxAq+++iqmT5+On376CQsWLECPHj1sdFhERERELkJStq+/Dsydq/bgFl5ewBtvABMn5lme+2rCVXx14Ctsvbg1R7/toc2HolJQJTgyl2gDZ+dq8PL+RkXJ3/qsBk9k96A7KCgIn3zyCZ588kk8++yz6N27NwYOHIjZs2cjPDzcZgdIRERE5LR9o6XftkQ6588bxyTtKNltKc2dhx+P/IjvD3+PjKwMw1i14GoY1mIYmpVrBmfg1G3gNFQNXq7xbNzIavBEmimk1rp1a+zevRtz587FpEmTsGLFClSWthUW2lIcOHDAWsdJRERE5Dx9o+PigJdeApYsMY75+QHvvAOMGQN4eOT7FME+wYaAW/49uMlg9KzZU5lW7kycsg2cxqrBy6nHavBEGqtenpGRgejoaKSmpiI0NFS5ERERETk7q/SNXrYMGDUKuHrVOCYP+vxztUJ5LkXSUjNT4etpXHTbo0YPrD6zGs3LNccDDR6Av5d/sb8/cm6sBk9kH4W+FLp27Vo0atQI77//PkaMGKGs9V6/fr3FW1HJNPZq1aopxdratm2LXbt2FehxP/zwg5JhHzBgQJFfm4iIiKg4faNzFRWF4BEj4P7AA8aAWx64YAHwzz+5BtzHoo9hwpoJmL97vtm4h7sHZvWahcFNBzPgpgJhNXgijQfdktl+/PHHlbXc/v7+2LZtGz7++GMEZv+vTjEtXboU48aNw5tvvom9e/eiadOmymtGSYWHfKqrjx8/Hp06dbLq8RAREREVK1MoxdG++QZuDRvCb/ly43i/fsCRI8Czz6ptwbK5knAF0zZPw8trX8aJmBNYf349TseeNttHkg1EBSVr4aUInczMMJWczGrwRJoIuuvWrYtly5YpGe49e/agjRT5sAEpzPb000/jqaeeQoMGDbBw4UIlyF9iuuYpm8zMTAwaNAhTpkxBjRo1bHJcRERERIXOFF68CNx1FzB4MNxiY5UhnSzL++YbYMUKwEJdnPiUeCz8dyGe+/M5bLu0zTBeOagy0jPTrfsNkUuR4nNS9T0yUq1WfvkycPKkJNfUFQ6sBk9k5zXd7dq1w/z585Vp37aSlpamBPTSmkzP3d1daU22XRZR5eLtt99GREQEhg0bhs2bN+f5GrIOXW56t27dUr5mZWUpN62SY5P1XFo+RnI9PC9Ji3heki116ADs32+5b3RmprpdOfXk/z77DG4TJ8LNJK2YfPfd8Fq4EO7lyqkZcH2LMAApGSlYdnwZfj3+q/JvvdK+pTGw0UD0rNFTmVLOc5uKSpZASCv4hg3VtvAykbR69Sy0bq3DHXdkKdt5epEWZDnIf8sLenwFDrpXrlwJW7tx44aStS6b7TKx3D9+/LjFx2zZsgWLFy/GfvkvYAFMmzZNyYhbmj6fkmL8D5wWP9D4+Hjl5JMLEURawPOStIjnJdlSzZpA795qxXKJpX19AfnzQQJuGZftMTvPIfill+BtkjDILFsW8VOn4nr79giWpEK2ZXM7r+7Et8e/RVxKnGHMx9MH/ar3Q99qfZUCajE3Ykr0eyXnVa+eejP9nRkXx9+ZpB1ZDvLf8oTsazWsWb1cS9/k4MGDsWjRIoSFhRXoMZJFlzXjppluaXkmvcalF7mWTzxZ
tyXHqeUTj1wLz0vSIp6Xjt2Oa8cOYwYuIkLNHLdrp61+zIMG5XKcrTPhPX8O3CZPhpvJhXzdU0/BbcYMBAUHIzU62uK5WTqlNJJ0SfD28VZafvWu2RuPNnwUZfzK2OE7JFfC35mkRVkOcl5K4W+HC7olcPbw8MD1bFVI5H45mYaVzZkzZ5QCanfffXeOFL+np6dSWb2mXHI24ePjo9yykw9Tyx+okBPPEY6TXAvPS9IinpeOGXAvXmy5/7XUGitw/+sSIH9jybpYuRlIk+5uQ4Hdu41jsiRPppj37Aml3Nl/f0TKeZmpy4SXh5dh185VO+P3E78jzD8MQ5oOQcWgiiX6PZFr4+9M0iI3BzgvC3psmgq6vb290bJlS6xbt87Q9kuCaLk/SvpZZlOvXj0ckv8im3jjjTeUDPhHH32kZLCJiIhcIWCVmczSrkquW8sqLalCLEWRtBKolkj/a3u9+dOmAe+9B6T/V+RMKoq/8II6lq3U+eXbl7HoxCLl35O6TDL743LqnVMNfbid4TMlIiINBt1Cpn4PGTIErVq1Uiqkz5kzB4mJiUo1c/HEE0+gYsWKytpsSedLz3BTpUuXVr5mHyciInJGEpwtWmQ5QyzJVy1liIvb/1pzQbdktaUqlWkCoG5dNWWfrffSjaQb+Pbgt1h5fCW8vL2UIPto9FE0CG9g2Mc04HaGz5SIiDQadD/yyCNKUbPJkyfj2rVraNasGf766y9DcbXIyEhNTzEgIiIqSQ6bIbZG/2t7SUoC3noLmDXLWOrZwwN45RVg0iR1/vl/ElIT8PPRn/HHyT+QlpmmFAUSIX4huJ1226k/UyIi0mjQLWQquaXp5GKD/NcmD19++aWNjoqIiEh7HDJDbIFcW8+2Ysys/3X16tCGjRuB4cOB06eNY82aAUuWAM2bG4ZSM1KVQFsC7sT0RMO4v5c/BjYZiHvr3atUJ3fmz5SIiDQcdBMREZETZojzILOxZfq0pf7XGRk5ZmuXvFu31Ez2woXGMSnMOnkyMGEC4GUsirbt4jYs/HchbqbcNIx5uXuhX+1+6BrWFTUq1chz1p6zfKZERKRi0E1EROTAHCZDnA8pECbrlU3XMcvxS8AtVcJlu92sWgWMGAFcvGgckx5hsnZb3+zYhI+HjyHgdoMbulfvjkGNByHULxRR2fpzO/NnSkREKgbdREREDkzzGeICksJgUiBM1ivrK3ZLcGnXit0xMcDYscD//mccCwhQq5U/95yyjlvWaCelJyHAO8CwS4vyLdAovBGCfILweJPHUTm4sllbU1f5TImISMWgm4iIyIFpOkNcSBJYy1plu69XlmJnP/8sRWYA08x0jx5qWXHpvy2tuaMO4+sDX8PDzUNp9yUVyYV8ndJtCrw9inalwJk+UyIiYtBNRETk0DSZIXZkV68Czz8P/PabcSw4GJg9G5D2pW5uOBVzSgm291/fb9hl37V9SoZbr6gBt/JYfqZERE6FQTcREZGD00yG2JFJdls6oIwbB8TFGcfvvReYPx+oUAHn487jm4PfYOflnWYPrRxUGZ7u1v2Tip8pEZHzYNBNROQg0tLU/r36zJcUW2Lmi8gKzp9XC6X9/bdxLDwcmDcPeOghXLl9Fd9unYHNkZuhg9pnW0T4R2Bg44HoVr0b3N1yr0ZORESujUE3EZGDBNyylNR0jadUN5ZiS7L2U6aiMvAmKiQpbCZZ7IkTgURjL208/jjw4YdAWBh+P/47luxfgiydsQhaiF8IHm34KHrW7Gn1DDcRETkf/peCiMgBSIZbAu4qVXJWM5ZxWfvJaahEhXDiBDB8OLBli3GsUiW1D3f//oahWiG1DAG3VCN/qMFDSr/t4qzZJiIi18Kgm4jIAciUcslwmwbcQu7LuGxn0E1UAFICfOZM4K23gNRU4/iIEYh7+zXEe+tQ1WT3hhEN0blKZ1QJroJ76t4DPy8/exw1ERE5MAbdREQOQNZwy5RyS2RcthNRPvbvB4YNA/buNY7V
rIlbCz/Cr2HRWLFhnFIUbXbv2Yb2X2LCHRPsc7xEROQUGHQTETkAKZoma7gtkf690k6IiHIhGe133gE++EDNdAt3d9we8xx+e7Qplp//EinRKcrw6ZunsevyLrSt1Na+x0xERE6DQTcRkQOQKuVSNE3WcGdf0y0xhGwnolwKIkh2+9gxw1Bik/pYPuUxLMs4jKTTvxvGvdy90KdWH9QOrW2ngyUiImfEoJuIyAFIWzCpUm5avVwy3BJwd+2qbiciE1KN/PXXgY8/Vntwy5CfB/54eQCWNXBHYvIuw65SgbxXjV54qOFDCPMPs+NBExGRM2LQTUTkAKQdmLQFkyrl+j7dMqWcfbqJLFi3Tv2BOXfOONaqFeZP7IRNaaeATHVIemv3qN4DjzR6BBEBESXeBlCS8PqfZ1lCwp9nIiLnxKCbiMhByB/iUqGcVcqJchEXB4wfDyxebBzz9VXXc48Zg/tvXcCm1WOUYPvO6nfi4YYPo1ypciV+mBJwyyGazlyRmg2yhERmtMj1AgbeRETOg0E3EREROb7ly4GRI4ErV5DoBfxRF6hWrRnazf4RqK2u0a4ZUhPPtHgGrSu2tkuwrbdjhxpwV6mSs0aDjMuMFl5cIyJyHgy6iYiIyHFFRwMvvAAsXYrb3sDyxsDyhp5KsbTKjduhTa2acDfZ/e66d8Petm1TM9ymAbeQ+zIuU84ZdBMROQ8G3UREROR4pDja998DL76IhIQYLGuiZreTK4QDTZoAfn64fPsKTsWcQt2wutCSqCh1SrklMi5rvImIyHkw6CYiIiLHKvx1+TLw7LOIX7MCv9UH/qwNpPh7AQ0bAhUrwcPdA92rd8dDDR5C+cDy0JqICHUNtyXSlUCKJBIRkfNg0E1ERET5BtyLFmmg8Jdktz//HLrxL+HLGglYcS+Q5gGgQnllIbSnX4BSjVxafxWnGrmtLzB06KC+d7KGO/uabmkDKK9FRETOg0E3ERER5UkCULsX/jpzRo3u16+HG4Ab/kCavw/QuDE8K1ZC75q98WCDB4vdZ7skLjC0awccOWL+GpLhloC7a1c1uCciIufBoJuIiIjyJBlfuxX+yszElQ/fRsSb0+GZlGIYfqTBQ9jZJQW969+N++vfj1D/UIe5wCBBuwTv8lz6bLpMKWefbiIi58Sgm4iIiPIkQaE9Cn9d2LUGP344DJvdLmJUeaDXGQBVqwKffYYqvXrh6/Qk+Hv5O+QFBgms5XlYpZyIyPkx6CYiIqI8yZrmkiz8dfraUfz42YvYfvIfwE2njP3UELiz3/PwmPq+4QqAtQNue15gICIi58Wgm4iIyMkVtzCY7Gvrwl86nQ6How7jp7UfYd/mH4FbCYZtQT6B6PnAeGQOmAgPD2+nusBARETOj0E3ERGRE7NGYTAJzmVfWxT+kmB795Xd+PHAdzixY4VaME1NbiMkxQ33178PvV9dDN9SpVESSuICAxERuRYG3URERE7MGoXBbFn462bKTUxbNh4Z+/cCiYnKWNlE4IGU6rjzve/h3aotSpItLzAQEZFrYtBNRETkxKxVGMxahb8ks+3m5maI/EMmTkK3PZuxpgZQNR546KQnOg6ZDI9XJgJeXihprCxORETWxqCbiIjIiWmlMFhiWiJWnlqJjRc2YlavWfBZtwF45hkgMhIPlQLaXQJaV24Lt9+XAA0awJ5YWZyIiKyJQTcREZETs3dhsJikGPx+4nf8dfovJGckA2npWPtCf/RfuM6wT/ksf5SfMBUYNQrw8LDtAREREZUwBt1EREROzF6FwS7GX8Svx37FhgsbkJGVoQ5evQq3Q4dw9WCaccc771T6bqNGDdscCBERkZ0x6CYiInJiJV0Y7MSNE/j56M/YcXmHcTAlFV6HjuDOLVdw/zGg/G0AwcHArFnA0KGAfo03ERGRE2LQTURE5MRKsjDYqlOrMP/f+cYBHRBwOQr9fjuMu/cloUzKf+P33APMnw9UrGi9FyciItIoBt1EROS0/al3
7jQGmrK22VUrUJdUYbB2ldrhs72fKdPJQzJ9MGDlWfT+eR/80//bITwcmDsXePhhZretdI5LS7jczvH8thMRUclg0E1ERE4nPR1YvNh8SrUUE5O1zTLVWjK/DDqKLiE1QalE7u3hjfvq32cYL+NXBo81eAQh67ahy5tL4HVL7butGDgQ+OgjICzMPgftZCSgXrQo93N8yBDgq6/4M0BEpAUMuomIyOmcOAFs3AhUqZKzeJgEITLVmu2gCu9qwlUsP7Eca86uQWpmKkp5l0Lf2n3h6+mr7nDyJB5+fj6webPxQTKFfOFC4K677Hbczkgy2HIu53aOy5r9LVv4M0BEpAXu9j4AIiIiazt6VM3umQYbQu7LuEy3pYLR6XQ4Fn0MUzdPxYgVI7Di1Aol4Nb33j4cdViN8KZPB5o2NQ+4pQ/3kSMMuG1AzuG8zvE//+TPABGRVjDTTURETicuTp1Oa4mMy/pWyltmVia2X9qO3479hpOxJ822+Xj4oFfNXri37r0oe/Y6cE87YM8e4w7S/uvzz4Fu3Ur+wF2EnMN5nePHjgFVq+a+nT8DREQlh0E3ERE5ndKlgePHLW+TdllSvZvyzm5PWDMBp2JPmY2H+IXgrtp3oU+tPgiEN/Dee8C0aWqmW7i7A2PGAO+8A/j72+fgXYQURZM12rmd4xER6tfctvNngIio5DDoJiIip9OggbqeVdavZl/PKvGhVHCm3Lm5uaFVhVaGoLt66eoYUG8AOlftDE93T7UsvPTXlnn8pm/6kiVA27b2O3AXIuewFEXL7Rzv358/A0REWsGgm4iInE7dumqRKNPKzZLdk2Cja1e1ZRKpGe2j0UeV4mhPt3waYf7GyuL9a/fH2ZtncU/de9A4orESiCMpCXjjZWDOHHmwuqO8wa++Crz+OuDjY79vxsXIOSxVyHM7x6V6uYzzZ4CIyP4YdBMRkdPx8gKGDVMrNOt7FMt0WvYoVqVnpmNL5Bb8fuJ3nLl5RhmrGFQRTzR9wrBPsG8w3uj8hvFB69cDw4cDZ88ax1q2VLPbTZqU6PGTeg5L26+8zvH8thMRUclg0E1ERE5JggrJdrMtklF8SjxWnV6l9Ni+mXLTbNvuy7sxuMlgNaNt9qB44OWXgc8+M475+gJTpgDjxqlpVNLkOc6fgZLvnS6t3PQXOWTdPS9yEJHgfymJiIic3JnYM/jj5B/YdGET0rPSzbbVDqmtVCG/o8odOQPuFSuAZ58FLl82jnXqpFYmr1OnhI6eyDEC7kWLzKfzS6E7WXcvywBk1gEDbyLXxaCbiIr1R4b8MbFrF6/qE2nVrsu78M6md8zG3OCGOyrfoazXrhdWL2ewfeMGMHo08N13xjGJIj74QA3CpUo5ERlIhlsC7ipVchauk3GZ5s8ZB0Sui0E3ERU54F68GDh1CoiKAgICeFWfSCvF0UyD6GblmiHIJwi3Um8hwCtA6a99V527EBEQYenBwNKlwAsvqIG3Xp8+wMKFuTd+JnJxMqVcMtymAbeQ+zIu2xl0E7kuBt1EVOSr+hs3Ak2bqj2R9XhV37FxTaLjOh17GitOrlCy2KPbjTaMe3t4Y1DjQXB3c0fXal3h6+lr+QlkysobbwBr1hjHypRRK5UPHix9xErguyByTPL7UiaDWCLjsp2IXBeDbiIq1lV9Pz8g3WSJKK/qOy6uSXTcKuR/nvoTJ2JOKGMebh4Y3HQwQvxCDPv1q90v9yeRD3jyZHX9tqkHHwTmzVOvvBQAL9iQK5PzXX5fWiKt2qRyPBG5LgbdRFQkvKrvfLgm0XFEJ0bjr9N/YfWZ1YhPjTfbJpns83HnzYJui+RKyptvAr/+aj4u0cHMmcD99zvlBRteHCBbkHNIznf5fZn996f0RpftROS6GHQTUZHIH6ryx7QlvKrvmLgmUftrtQ9eP6hMId95eSd00Jltr166OvrX7o8u1brkPoVcnDgBvPWWunZb1nDrVaoETJoEPPWU2ujcCS/YONLFAXIs
ctFGziHTc0v+WygBd9eu6nYicl0MuomoSOSq/f79QHKy+Tiv6jsuzl7QvsX7FuNc3DnDfZlKLlXI+9fpj/ph9XNWITclVQ/ffRf45hsgK8s4Xq4c8NprasQp/bed+IKNo1wcIMcjF2vkR0jOIf0sCrn4zFkURCQYdBNRsa7qm1Yv51V9x8Y1idpyIe4CqgRXMQTS8lWqjs/dNVeZOt6nZh/0rtU7/2nkx48D772ntv8yDbbDwoCJE4GRIwF/f5e4YOMoFwfIMUlgLecPzyEiyo5BNxEV+Y+LYcOALVuMfbp5Vd+xcU2i/aVmpGJz5GasPLUSp2JPYWbPmagbVtewvUvVLvD38ke7Su3g6Z7Pf8KPHFEz29mnkUtF8gkT1LZguUXKTnrBxlEuDhARkXNh0E0lgoVrnJN8djKVrnt3wN3d3kdDxcU1ifZz6dYlpTDa2rNrkZieaBhfdXqVWdDt4+mDjlU65v1kBw8C77wD/Pyz+XhoKPDSS8DzzwNBQS55wcZRLg4QEZFzYdBNNsfCNUSOgWsSS77d17aL25Rg+3B0zqqENUrXQOOIxgV/QvmlKsH2b7+Zj4eHA+PHA889Z7XMtqNesHGUiwNERORcGHSTzbFwDZHj4JrEkrHh/AZ8tuczJKQlmI17e3ijU5VO6FurL+qE1sm7MJre7t1qsP3HHznTui+/DIwYoRZdsCFHuWDjKBcHiIjIuTDoJptP/WbhGiLSInsue5HiZ6YBd6XASuhTqw+6V++OQJ9svyxzs2MH8PbbwKpV5uMVKgCvvKJGwX5+KCmOcMHGUS4OEBGRc2HQ7eJKYuo3C9cQkasue7kYfxGrz6xW2nndUcU4d1mmjVcLroaqpasqwXbD8IYFy2oLiRanTAHWrDEflz7bUo1cKhwWsfWXK9T3cISLA0RE5FwYdLu4kpj6zcI1RORKv/tSMlKwNXKrEmwfu3FMGTsVc8os6JYAe06fOfBw9yj4E2/cqGa2//nHfFy+Cemz/eSTgI8PtIr1PYiIyFUx6HZxJTH1m4VriMgVfvedjj2Nv8/8jY0XNiIpPclsm7T/upl8E2X8yhjGChRwS6svCbIl2N60yXybXLF8/XVg8GCHiFZZ34OIiFwVg24XVxJTv1m4hoic9XefBNf/nPsHa86swdm4szm2Vwmqgt61eqNbtW4FX6utD7Zl+rgE23IFwFStWmqwPWgQ4OUFR8H6HkRE5KoYdLu4kpj6zcI1RFSSa4MlgCup330SdEsVch10hjFfT1+lAnnvmr0LXoHcNNiWwmgSbO/cab6tbl012H7ssYJ9kxrD+h5EROSqHO+/2mRVJTX1m4VriKik1gZLHTFb/O6LSozClYQraFaumWEszD8MLcq3wJ6re1A7pDZ61eyFzlU7w9/Lv3DflATbK1aowfa//5pva9AAmDQJeOghwKMQa8A1hvU9iIjIVTHodnGc+k1Ejia/tcENGwL16lnnd19qRiq2XdyGdefW4cD1A0qrryX3LDFbj/1E0yfwZLMnUa10tcJ/M1lZwO+/q8H2/v3m2xo3BiZPBu6/H3B3h6NjfQ8iInJVDLpdHKd+E5GjyW9t8LZt+Qfdef3ua9dOhzPxx7Hm7BpsidyC5Ixkw+Nik2Ox9+petK7Y2jBWo0yNogXbv/wCvPNOzvRv8+ZqsH3PPU4RbOvxIi8REbkqBt3Eqd9E5FRrg6Oiiva7LzoxGuvPr8cLf6/F1dtXc+xfvlR59KzREzVDahb94DMzgR9/BN59Fzh61Hxbq1ZqsH3XXdJTDM6GF3mJiMhVMegmIiK4+trgLF0Wxq4ei/jUeLNxb3dfVMzoBPfzPZAVVR+nyrohoihBoqRzf/hBDbZPnDDf1rYt8OabQJ8+Thlsm+JFXiIickUMuomIyKnWBnfokH+AfT7uvNm0cHc3d6UA2h8n/4Ab3NCkbBN0rnwnjq5uj60bfS0WbJOsbb6B96VLwOefq7fLl3N+IxJs9+jh9ME2ERGRK2PQ
TURETrU2uF07IC4u5+Mi4yOVntobzm/AzZSb+GrAVyjtW9qwXSqPB/sEo1v1bogIiMDGjcDWjbkXbJNp0hYztjKFfPVq4NNP1Yrksn7blDxIppF368Zgm4iIyAUw6CYiohLroW2Ndbv5rQ02bWF9M/kmNl3YpKzVPnPzjNnzyPg9de8x3Jfq46YVyPMr2CbbzYLuq1eBxYvVfmaRkeYPkoJoslZ73LhCz6229ftJREREtsWgm4iISqyHdoGmZBdzbXBiajI2X9qMA0cO4GDUQeigM9vu4eaBluVbompw1WIVbJPtSlb777/Vb3r5cvW+qYoVgeHD1ebhlStr9v0kIiIi22HQTUREJdZDO9cp2Vb03ub38O/Ff+Ht4w03k+nbtUNqo1u1bsra7WDf4GIVbPO5HokHE5dAV30J3C5eNNumc3ODrlcfuD/3LNCvn3nq3QHfTyIiIioeBt1ERGQ1hZ6SXQw6nQ7HbxxHvbB6ZsF1h8odlKBbRPhHoGu1rsqtcnDlYhVsc89MR92Tf6DJrkWoe3413LNl0OP8ymF99WFYU3U4Gvavhqf7AN6ejvN+EhERkW0w6CYiIqsp0JTsYgba5+LOYeP5jdgUuQk3km5geo/pqB9e37BPxyodceTSEdzV6C40iGhgFpAXpWDb8T9OocWFz9H1/JcISjFvAp4Fdxyv3hcH2jyNU7X7IcvDC35WzELb+v0kIiIi22PQTUREmu6hLS7fuqwUPtt4YSMuJ5i33pIx06A7yCcIQxoOQUR4RJEDbiQnw/vXXzHyx0Vw37Qxx2Zd1arYVGsYfgx4CuHNK9ksC22r95OIiIhKDoNuIiIqsR7asr2gohOjsSVyixJUZ688ru+t3bxcc6WnttVIhCuVy775Brh5E+6m27y8gHvvVaqXufXogV/HuiMtybZZaGu+n0RERGQfDLqJiKjEemjL9oJaemQpVp9ZbTbmBjc0DG+oFEO7o8odSla72OQAf/hBDbZ37cq5vW5dtQL5E08AERElmoW25vtJRERE9sGgm4iIrCa/HtqW2ltJL+2tF7eiS9UuCPQxpnM7VelkCLql8rgE2rJeO8w/rPgHqtOpAfbnn6sBt0Sypnx9gYcfVoPtjh0BC9PUSyILXZT3k4iIiLSFQTcREVlVXj20TQPt7Ze2K9PHD0cdVnpp+3j4oGfNnoZ9GkU0wpNNn0T7yu1RIbCCdQ4uNladOi7BtqU0dbNmapQ7cCBQurQmstAFeT+JiIhIuxh0ExFRiYhNjsXWyK3YdnEbjkQfUQJtU5sjN5sF3R7uHnigwQPWyWpv3KgG2j//DKSmmm+XNLUE2RJst2hhMattCbPQREREVBAMuomIyKYk0P79xO84duOYxe0VSlVAp6qdlKnjViVR8NdfA4sXA6dO5dwukbFEzTKNPCCgSC/BLDQRERE5ZND9ySefYMaMGbh27RqaNm2KuXPnok2bNhb3XbRoEb7++mscljl+AFq2bImpU6fmuj8REZWs6KToHAF3xcCKSpB9R+U7UK10taK39souMxP46y+U/uQTuK1Zo871NhUSohZEk7XaDRta5zWJiIiIHCnoXrp0KcaNG4eFCxeibdu2mDNnDnr37o0TJ04gwqRqrN6GDRvw2GOPoUOHDvD19cUHH3yAXr164ciRI6hYsaJdvgcicj5pacD27cZpxFK5mtOIVTqdDhfiL2D7xe3K1PEX2r6AOqF1DNs7VO6AxfsWo3JQZUOgXSW4ivUCbREZCSxZotzcL16Eb/bt3burWe0BA9QiaUREREQlxE0nfy1piATarVu3xrx585T7WVlZqFy5Ml544QVMnDgx38dnZmaiTJkyyuOfkGxGPm7duoXg4GDEx8cjKMgKrWdsRN6HqKgo5cKDu7tZ51giu3GV81ICbukmlVvBLInlXC3wlv90nIw5qQTZUhDt6u2rhm0P1n8QQ5oNMdv/asJVlA8sb92DSE8H/vhD/XBWr1bXbpseY7lycHvqKWDY
MKBmTeu+NlERuMrvTHIsPC9Ji7Ic5LwsaCypqUx3Wloa9uzZg1dffdUwJm9yjx49sF1STAWQlJSE9PR0hMgUQgtSU1OVm+kbpf9g5aZVcmzyR66Wj5Fcj6ucl9u2qXW4qlTJ2RpKxmWWcufOcHrpmek4GHUQOy/txM7LO3Ez5abF/a7fvp7jnCgbUNZ658nJk3CTrPZXX8EtKspsk87dHbo+fRD34IMIeuwxuOuvhmj4HJWLOjt2qOeZfDsyqatDB6BdO9e7mOPsXOV3JjkWnpekRVkOcl4W9Pg0FXTfuHFDyVSXlXmbJuT+8ePHC/Qcr7zyCipUqKAE6pZMmzYNU6ZMyTEeHR2NlJQUaPkDlSsocvJp+WoPuRZXOS8PHgQqVwayr1jRX9CU7fXqwel9cfgLrL+4Pse4u5s76obUReuyrdGibAuE+IYoV6etKjkZvn/+Cf/vvoO3hYuwmZUqIemxx5D86KPIKFdOOS9TYmM1f15Ksv7vv9XuZR4eaj23mBhg2TLgzBmgVy/Ay8veR0nW4iq/M8mx8LwkLcpykPMyQTIwjhZ0F9f777+PH374QVnnLeu7LZEsuqwZN810y/T18PBwzU8vl/WPcpxaPvHItbjKeXn2rMyikd8XObfJ71rZbqHkhMOKTozGriu70LlKZwT6GFP73ep2w9aorcq/vT280bxcc7Sp2AZtK7ZFkI+Nfn8ePAg3afX17bdwi4sz26STaPTee6EbNgxuPXogwN0dAQ52Xm7apM6Ml4s6+lkU8m3JeSXjMiveFWZRuApHOjfJdfC8JC3KcpDzMreYU9NBd1hYGDw8PHBdqhSZkPvlypXL87EzZ85Ugu61a9eiSZMmue7n4+Oj3LKTD1PLH6iQE88RjpNciyuclxJQSybSEgmOqlWT3yFwWHIV+VTsKey6vEuZOn4+/rwyXsq7FLpV72bYr0WFFuhZo6cSaDcv3xy+njYqSCZv6g8/qGu1d+/Oub1uXWUhvdvgwcqH4+bA56VMKZcMt+myBSH3ZVy2S90Ach6Ocm6Sa+F5SVrk5gDnZUGPTVNBt7e3t9Lya926dRggFWb/u8oh90eNGpXr46ZPn4733nsPq1evRqtWrUrwiInIFUiV8n371Fgw+5puKaYm2x1NakYq9l/brwTau6/strg+W7aZBt2S3R7dbrRtDkiKoO3cCUhWWwLuxETz7X5+wEMPqVXr5A23ZuVzO5JrzFKYzxIZz3YNmoiIiByQpoJuIVO/hwwZogTP0mtbWoYlJibiKalAC2mv+oTSCkzWZgtpETZ58mR89913qFatmtLbW5QqVUq5EREVl7QFO3w49+rlst2RWpt9uf9L/HHyD6Rlpll8jtohtZUp420rtbX9AcfGAv/7nxpsy5ucXbNmaqA9cCBQujScjXw+uc2ikHOsevWSPiIiIiJy+qD7kUceUYqaSSAtAXSzZs3w119/GYqrRUZGmqXxFyxYoFQ9f/DBB82e580338Rbb71V4sdPRcc+yKRVcv5J3NeokfH8lGBIa+dn9tZm/qUycerECezdVx+HD7sZWpsFeAWYBdySwW5WtpkSZLeq0Aohfpa7P1iFdI+QjPb69eqByvxpOXBTMp1Agmw54JYt4cyccRYFERERaTzoFjKVPLfp5FIkzdT58+raQ3K+PsiS/ZE/RiX55Yp9kElb5Pzr0kW9aZVctPp7Uxy8q+/BDf89uIq9SEciOt7+CBs21FAuGsjxt67YGitOrUDrCq2V9dlNyzaFj2fOWhdW++GWddmmQXZysuV9pU/W8OHAww+rZbxdgCPNoiAiIiInCrrJ9UiwIH90WuqDLOP6YIGIzGVmZeJkzEn8e+VfLNy0B+ernEEpf/N94kvthpdnDSVLLz9HVYOr4st7v1QKlFidtF6UIFsamEtpbnlRKf2eG5kycO+9arAtDc9djKPMoiAiIqKiY9BNmiB/bEqWx1IFXxnXBwtEZLRg9wJsvLARielq0bErKWrFaz0v
BKA8WiAUtZFlUpTLqsG2FDyTq2b6IFumjssU8txIb6xu3Yy3qlXh6hxhFgUREREVHYNu0gRW8CXKnay/Ph93HnVC65iNJ6UnGQJuId0QdTE10MC3pRJsh6Ee3P/7NX/CWkW5bt5Ur4JJgC23PXvUudC5qVDBPMiWg3CSyuNEREREBcGgmzSBFXyJzPtmX7x1EXuv7lVuR6KPKNPIv3vgO/h7GeeOt6zQUmn31bxcc6UAWlJ4c3z9aQiqeFqxKJdc8dq82RhkHzyotvfKTY0aQOfOxpvcZ5BNRERELoxBN2kCK/iSq4tNjsWBaweU3tn7r+9X7md36PohszZeHat0RKcqneDhrs4pT6sERJ4oZlGuCxfMg+wTJ/Lev149dV60PsiuVKnw3zwRERGRE2PQ7QIcoRUXK/iSq0rPTMe41eNwPj73TgyhfqFoUb4FQv1DzcY93T2LV5RLMtYnTxoDbLlFRuZ+sJKxlr7Z+gC7Y0cgIqJo3zgRERGRi2DQ7eQcpRUXK/iSKwTXJ2JOIDEt0Sxb7eXhZchUm/bNbhTeSAm05VYpqFKBi5/lWZQrM1P9wTcNsqOicn8yLy+gdWs1wO7USf2BDA4u+DdNRERERAy6nZ0jteIqbgVfR8jok33Y49yQNdinY0/j4PWDyu3ojaNKQbSyAWXNgm4hgbVkrZuVa6b0zK4XVk8JxostPV0tdCbBtUwZ37IFiIvLfX8/P/VN0Wey27YF/LP1HyMiIiKiQmHQ7eRcpRWXo2T0yXnPjSxdFs7ePKusuz4UdQiHow4jOSM5x37XE6/j+u3rKFuqrGFscJPB1mnjlZystuzSZ7HlSkNePbKDgtQp4vogu2VL/qAQERERWRmDbifnKq24HCmjT855bpyMOYkJaybkuj3MP0zJYjcp2wSlvM1/KIsccN+6BWzbZgyyd+1Ss9u5CQ83ryzeuLF5Y28iIiIisjoG3U7OVVpxuUpGvzA43d6650ZGVoYyXVwy2HJrU7EN+tXuZ9heK6QWfDx8kJqZqtwP9gk2BNlyK1eqXPGz2VlZamC9bBmwdq2arpex3EglcdPK4nXrsn0XERERUQlj0O3kXKUVl6tk9AuK0+2Lf26kZKTg+I3jOBJ1BEejj+J4zHFlTbaeu5u7WdAta7IfbPAgAr0D0SiiEaoEV7HOlPHUVOCff9RAe/ly4Nq13PetXdsYYEvhs2rVGGQTERER2RmDbifnKq24XCWjX1Ccbl/0c2PThU34/fjvOH3ztLJOOzdXEq5Ap9OZBdaPNnrUOgctxc5WrgR+/139KgdqiUwPNw2yy5e3zusTERERkdUw6HagKcISKEvgXJgpwq7SistVMvoFxen2eZ8bOuhw9fYlXPI9hifbdQRgrNCdnJ6Mk7EnczxPuH+4ksFuEN5A+VoxsKJ1Mtl6ly6pmWzJaK9fr5642fn6Ar16AffeC9x1F3tkExERETkABt0OMkVY2uVKsHz8eOGnCBe3FZcjcJWMvjNOt7f12nN5nn2HUrByxykkJhxHSqljuIHjSHdLQGhtILhmKICWhv0bRjRUvlYOqoyG4Q2V+/I1PCAcVqXTAUePqkG23P791/J+ISHA3XergbYE3AEB1j0OIiIiIrIpBt0ONEVYuvtI8O2KU4Tz4yoZfWebbm/LtecyTVzWY8u67LMh53G7XRZiY9Ul0n4+QMUQoEwZ4HT8MbQzCbolg/3Nfd8g2DcYVpeZCezYYQy0T5+2vJ+sxZYg+7771JNY3hwiIiIickj8S07DOEW4cFwho+9s0+2tsfb8dtpt3Ei6gWqlq5mN/3L0F5yNO6vecQNCQ9WbkGJn9cPqo354fbSq0MrscTJl3KoBt/TOXrfOWAgtOtryfs2bAwMGqMF2kyYsgEZERETkJBh0a5gjTREmbXGU6faFvbCUnpmOc3HnlJ7Y+tvlhMuI8I/A4nsXmz1HvbB6StDtBjdUDa6q3K8bVlcJtisEVrDueuzsJKX+559qoL16NZCY
mGOXLHcPxDftjMDHB8DzgXuBqlVtdzxEREREZDcMujXMUaYIk/Y4ynT7/C4sRUbFY82ZXUp/7FOxp5SAW/plZxeVFIX4lHizDHXf2n3RoXIH1AmtAz8vP9hcZKRabVwC7Y0b1ank2aR4+ONA+T44UvNebAu5C/EeIejqAzxdHtDIR0JEREREVsagW8McZYqwq7F14S9Xmm6vv7CkQxZu4TK8EQA/hBguLJWvFoWPd32c6+OlN3bNMjWVwDp7MJ59urnVSSE0OXj9+mz5YbUkLAxX29yDr+IG4GbLHvALUS8AVAIQzPoMRERERE6PQbeDTBHWVy8/dw5IT9fWFGFXYsvCX65CpohHxkfi7M2zuFHtNPZcPYM9Gefh5pmKphiCBnjQcGGpf4dqmBflaQiopchZ3dC6SpAtt+plqiuBd4mRg5KrLfpA+/x5y/vVqKEWQZM12u3b44sPPJTzpK56PcGA9RmIiIiInB+DbgeaIix/78u/tZhVdRXWKPzlirZGbsWeq3twJvYMIm9FGoLoLB3gUx6IiQHc3IHzmafhcdW49rzTHV7ApVEoW6osapSpAX8vYz/tEpOUBKxZowbZf/yhHqwlLVuqQbbcGjY0K4TG+gxERERErotBt8bppwh36gRERQEREYC7u72PynWxorxlOp0OUYlRSvb62u1rGFBvgNl2CbjXnF2T43FyLlepClQLK4+06zXgF98cjRubX1i6s8adKHE3bgArVqiB9t9/qxXIs5MPXK4MSJB9zz1A5cq5Ph3rMxARERG5LgbdRIXAjKXaoutC3AVciL+AczfP4eiVo7iRfgNJGUmGfXrW6IkA7wDDfclSC6kkXimoEmqF1FLWYsu43Ez3tRtZu6EvhLZ5M5CVZflD7ttXbevVvz9QunSBnpr1GYiIiIhcF4NuokJwpYylZK9N22pdTbiK1/55TemJbbpPWmoavH28zfaVKuONIhoZ7ksV8dohtZXiZj6ePtAEKYS2f79xffbBg7l/6BJkS0a7WzfA19dpW7gRERERkfUx6CYqBGfMWKZmpOLirYu4GH9RKXAmGWz52qdWHzzY4EHDfiF+IYhJsryeOcw/TMlYVy9dXSluJn2xTclj5WZ3UoVQstgSZEtWW9p8WVKnjhpoSzG0tm2LvabDUVq4EREREZH1MegmcrGM5c5LO3Eo6pASZF+6dUnpcW2JBN6mJEMtFcOlWrgE1ZK1rhxUGf5p/qhWsRrctVpsQD6g1avVIFvWad+8aXk/Ca71hdDq1XPJFm7kuhylFSIREZEjYtBN5EQZy7TMNGUa+OWEy7iScAWxybF4puUzZvvsvLzTYlEzU76elqdQz+w10+x+VlYWoqTCn9bIMUmlccloS+Xx1NSc+0gfvu7djYXQKlSwx5ES2R1bIRIREdkWg24iB8xYJqUn4UjUEVy9fVUJri/fUoPs6KRo6KAz2/fxJo+btdqS7LSejMt9KW4m2evKwZVRJbgKwv3DzdZoa56szz52DFi5Us1oyxURGctO1gT066cG2lIQLTjYHkdLpClshUhERGRbDLqJNEiC6uu3r+N64nUlc924bGOl4reebHt709sFei4Jxk0fe0eVO1AzpKYSaJfxLWPT4NqmU1avXAHWrVMz2WvXAlevWt6vfHljITRZA+CjkUJuFnCKL9kDWyESERHZFoNuIjs6Fn1MKVwmQbT0t5YgW74mpCWY7Tek6RCzwLlcqXI5nivAKwAVAyuiYlBFVAisYPi3aWZbRAREKDeHmrIqTxYbC+zZYwyyjxzJff/69Y3rs1u1cojm9pziS/bCVohERES2xaCbyMqkjZZkqmWqd3RitPI1KjEKXu5eGNRkkNm+Xx34Ckei8wge/yPZblN+Xn4Y3GQwQv1CUT6wvBJgB/kEaWpKuKUpqx6ZaciIvonTy2NwJCsWzavGqsG03GJiLP9bbjLPNS/+/moqrkcP4K671OrjDoZTfJ2TI8xecKVWiERERPbAoJuokAG1
MA1uJVsthcmkf7XcJMhOyUix2FYre9AtGWfToNsNbsp+ZQPKKtnssqXUrzXL1MzxfA83fBh2I4XJpAp4dDS8zpxR10/HxZkFzBFbYzE5OhZhbjHwS45Vbj5pt43P8XcxXl8y123aqEF2z55Au3baiWCKiFN8nY+jzF5wxlaIREREWsKgm8gkOx2XEoebKTeVqt83k28iJjlG6U0t95V/J8dgeo/pyppoPQmy86sGLuR5MrIylJZbet2qdUOD8AaGIDs8INxse4kFz7llmXP7d2Ki8nCZtB2ay1PXL+6xSZQSEmK8hYYCVasCd96prs0uXRrOhFN8nY+jzF5whlaIREREWsagm5w6kL6ddhvxqfFKMB2fEq/8W75K0NymYhvDvulZ6Xj0l0cL9LwSgNeEMeiWSt963h7eCPMLU7LVEkDLNslmy7/lq4ebh9lzNS/f3BrfqPoXsgTDEkDLTf9v0695BM+2lOnuiWS/UCT7lVG+JvmH4lpaCPwqhKBtX5OAukwZ9av+vvzlr6Hp8raeRswpvs7HUWYvaL0VIhERkaNj0O0IJk2C26xZKCvBlQMFIdaU6aZDohdw2xu47aVDgjeQ4Q60vWpeIOt/DTOwq3wWbnkD8T46ZOZSP6v3OQ+02Ws8/eVvyoB70pDoZaHN1H/8090Qmgzg9TXANeMTV/fQ4aNSOoQnu6FUmjpFvMSjvcxM27+O9LU2zTqHhEBXpgyS/PzgV7Ei3MPCjEFzSAh2nAzBJz+EIqJ6AAKD3MyyfJGRwKhRADQQcGhlGjGn+DofR5q9oIVWiERERM6KQbcjSE+HW7JEe45JwtgUTyDZC0j2BJK81Fuit/q19WUgONW4/8GywNKG/wXY/91kv+yCUoFvz5qPxXgC50uZvHAusWicRxaQnG421u6CBPdAmRSgTLL6NTQJCEmGEmz7Zuj3NH9SXwA1TJYqa54EzybBcY4ss0lQbXazkHnWZWUhISoKfhEROSqEt2gItLvhGlNWrTGNmFN8nQ9nLxAREZFg0O0IypWDrkkTZGRkwNPTs9h51CzokOGmQ7o7lK+SRZassYhINT8ljgWmIMY7E+luOqS5q7dU9yykeMhXHVI8spSvrW76oXOMMaUj+zzd/BKSPdR98jLtSHkE35LQVZVcJgkHa+efApJ61llNqsHd5B0JCo+FV6lbCMrwQHC6O0qny1cPk6/uytey8n02MU89jkkyueP93y0I2icRmgTNEhjLV9N/m37VB9NS6bsEZky40pRVa0wjdqX3y1Vw9gIREREJBt2OYMwYXB/2KF5b/Rp8/XzN1iwr/zP5mqXLwvSe0xHqbyxvtfLUSizZt0TZJoW8ZF9LpJ/z/P7zzca++2cS9l/fn+8hBtW9F51bDDfc99bpcPOHewr07SVNmgSYrK8OiDoMrHtVWf9cyrtUjlugd6DyVVpkZc3vC3eTwmNDsjLxlJu7plpnuTJXmbJqrWnErvJ+uQrOXiAiIiLBoNtBpGem41LCJXineecbUEpRMFOZWZlIzTSZv50LCcizk8JgBZGaYf78coyVAispX/08/ZS+0qZfA7wD4O/lr9yqBlc1e2y9sHr48cEf4evpW+jg2cPdvFAZOT8t9EHmNGKyhLMXiIiISDDodhDubu5KAKwPROV/Mib09+WrfsxUsG+wEtjKNmlHJTcvdy8lQDX9d4hfSI7H3lnjTjQu21h5bdlPfww+nj7w8fAxfJWsc3YL7lpQpO9Vf4xEjtIHmdOIKTecvUBERESMbBxE+cDy+LzX54iIiIB7toJV+elctbNyK4oOlTsU6XFErtQHmdOIiYiIiCg3DLodnBam1hK5eh9kTiMmIiIiotww6HZgWplaS2QvWuqD7CjTiB3lQp2jHCcRERFRfhh0OzCtTK0lshcWMHPOC3WOcpxEREREBcGg24FpZWqtYFaK7IEFzJzzQp2jHCcRERFRQTDodmAlNbU2v4CaWSmyFxYws8+FOltfZNPSBUUiIiKi
4mLQ7cBKYmptQQJqZqXIXljArOQv1JXERTYtrdUnIiIiKi4G3Q6sJKbWFiSgZlaK7MlRCpg5y4W6krjIxrX6RERE5EwK1/CZNEUyeTKFNjISOHECuHxZ/Sr3rTW1tiABNbNSRI5BLsTJBTkJkE0V5kJdQX4naOE4iYiIiLSCmW4HVhJTawsSUDMrReQ6a+BL4iIb1+oTERGRM2HQ7eBsPbW2IAE1K0gTuc6FupK4yMa1+kRERORMGHRTngoSUDMrReQ6F+pK6iIb1+oTERGRs2DQTXkqSEDtSFmpgrQ/00K/ca0cB1F2vMhGREREVDgMuilPBQ2oHSErlV+royFDgK++sn+/cfY9Jy1zpItsRERERFrAoJvypZWAurjZ3/xaHUmmbssW+/cbZ99z0jqt/E4gIiIicgRsGUYOQZ/9nTdPzfomJalf5b6My/bitjr680/bt0IqiJJoyURERERERCWDmW5yCNbI/ubX6ujYMaBqVfv3G2ffcyIiIiIi58FMNzkEa2R/ZTq6FHyyRMYjIvLeLo8vCfkdZ0kdBxERERERFR+DbnII1sj+yvpvWbct2XFT+lZH/fvnvb2k+o3nd5zse05ERERE5Dg4vZwcgmR3ZQ13btlfqZ5c3FZHUr1cxu3dCoktmYiIiIiInAeDbnIIkt2VllmS7c2+prug2d+CtDrSQiskrRwHEREREREVH4NucgjWyv7m1+pIK62QtHIcRERERERUPAy6ySEw+0ta7gFPRERERJQbBt3kMJj9JVv2gDedRSH1A2Q5g8yukIs9DLyJiIiIqKgYdBO5KGZ3rdcDnoiIiIgoNwy6iVwQs7uF6wHPoJuIiIiIiopBN5ELYnbXuj3giYiIiIhy457rFiJy6eyuq5Bp9VIJ3xIZl+1EREREREXFoJvIBTG7ayTr2KX1nGT5TRWmBzwRERERUW44vZzIBUn2VtZw55bdlXZsrsJaPeCJiIiIiCxh0E3kgiR7K0XTJJubfU23q2V32QOeiIiIiGyJQTeRC2J21xx7wBMRERGRrTDoJnLBHtvM7hIRERERlQwG3UQu2mOb2V0iIiIiIttj0E3kYNhjm5xBcWdrEBERETkKBt1ETthjm0E3ucJsDSIiIiJHwKCbyMGwxzY5Os7WICIiIlfibu8DIKLCkWm4UmncEhmX7USOPluDiIiIyFkw6CZyMLLuVVp7SVbQlCv22CbHxNkaRERE5Eo4vZzIwbDHNjk6mY0ha7gtkXNZ2tcREREROQsG3UQOhj22ydHJuSpF02R2RvY13ZytQURERM5Gk0H3J598ghkzZuDatWto2rQp5s6dizZt2uS6/08//YRJkybh/PnzqF27Nj744AP069evRI+ZqCS5Uo9ttpZyPpytQURERK5Ec2u6ly5dinHjxv2/vfsPsqqs/wD+WX5uTmA1iAZimaWWWOQvEmKwxmJGpuIfY6jMaSxrAMdy+kFkUZlBjjVUUozmpP8YGROOg0WRYj9EspBmwASnoHAsJGcqVs34db7zOTtLu3hR8btn77l7X6+Z6/U+56w+y364e9/nec7zxKJFi+LBBx8sQ/eMGTNi9+7dDc9fv359zJkzJy677LLYtGlTzJo1q3xsyU90wKDYWuqGG7qnIz/9dPdzvs72PE7rztaYPz/izDMjjjmm+zlf2y4MABhsOoqiKKJGJk+eHOeee27ckJ+qI+LgwYMxYcKEuOKKK2LBggXPOn/27Nnx1FNPxerVqw+1veUtb4lJkybF8uXLn/f/t2fPnjj22GPj3//+d4wePTrqKv8c8sLD2LFjY8iQ2l0roU1VXZe//GV3wG60tdTOnd0hrR1G+zk63i+pK7VJHalL6uhgi9TlC82StZpevnfv3ti4cWN89rOfPdSWf8gXXnhh3J/zSxvI9hwZ7y1Hxu+4446G5//3v/8tH73/oHp+sPmoq+xbXh+pcx9pP1XXZU4pHz688dZS2Z7Hp02r5H9NC/N+SV2pTepIXVJHB1uk
Ll9o/2oVup944ok4cOBAHH/YRsP5euvWrQ2/Ju/7bnR+tjeyePHi+NKXvvSs9n/84x/xzDPPRJ1/oHkFJYuvzld7aC9V12Xe45uLxDW6cJjtefwId57QxrxfUldqkzpSl9TRwRapy67D9/BthdA9EHIUvffIeI505/T14447rvbTyzs6Osp+1rnwGDzyfukNG3LdhO5gO3ZsxJQpefvG/+65rbouc5GtvN6Wo9qH27GjewX37Bf05v2SulKb1JG6pI4OtkhddnZ2tl7oHjNmTAwdOjQezyWKe8nXJ5xwQsOvyfajOX/kyJHl43D5w6zzDzRl4bVCPxkcgfvmm/uuLp0LmOU2Tw891Hexqyrr8rm2ltq3r/u4vw404v2SulKb1JG6pI46WqAuX2jfavUdjBgxIs4+++y4++67+1zlyNfnH2EPmWzvfX5au3btEc8Hnl8uoZCBOxcwO+20iPHju5/zdbYfYYmFfpd/jXMLqVw0bdu2iMce637O17aWAgCgFdRqpDvl1O9LL700zjnnnHJv7qVLl5ark3/oQx8qj3/wgx+M8ePHl/dmpyuvvDKmT58eX//612PmzJmxYsWK+P3vfx833nhjk78TaF25QFmOcDdawCzb8/hArBres7VUTiPv2ac77+W2TzcAAK2idqE7twDLRc2+8IUvlIuh5dZfa9asObRY2s6dO/sM40+ZMiVuu+22uPrqq2PhwoXxute9rly5fGJ+SgdelAy3OaW8kWw/7I6OSmWwzoBvazAAAFpR7UJ3mj9/fvlo5N6c23qYiy++uHwA/SOvceU93I08+WT3aDMAANBi93QD9ZDTt3M7rsN3QcjX2Z7HAQCAFh3pBpor75fesqXv6uU5wp2B2wJmAADwwgndwLNYwAwAAPqH0A00ZAEzAAD4/3NPNwAAAFRE6AYAAICKCN0AAABQEaEbAAAAKiJ0AwAAQEWEbgAAAKiI0A0AAAAVEboBAACgIkI3AAAAVEToBgAAgIoI3QAAAFCRYVX9h4HG9u6NuP/+iPvui3j88Yjjj4+YOjXi/PMjRoxodu8AAID+JHTDAAfum26KuPfeiGHDIl760ojNmyM2bYrYsiXiIx8RvAEAYDARumEA5Qh3Bu6TTooYNep/7V1d3e0TJ0ZMn97MHgIAAP3JPd0wgHJKeY5w9w7cKV9nex4HAAAGD6EbBlDew51TyhvJ9jwOAAAMHkI3DKBcNO3JJxsfy/Y8DgAADB5CNwygXKV8//7ue7h7y9fZnscBAIDBw0JqMIByW7Bcpbz36uU5wp2B+4ILuo8DAACDh9ANAyi3A8ttwXKV8p59uk8+2T7dAAAwWAndMMAyWOe2YLYGAwCAwU/ohn62d2/3ftw9I9m5OJqRbAAAaE9CN/Rz4L7ppr73bG/eHLFpU/e93Dm1XPAGAID2IXRDP8oR7gzcJ50UMWpU39XJsz3v5TatHAAA2octw6Af5ZTyHOHuHbhTvs72PA4AALQPoRv6Ud7DnVPKG8n2PA4AALQPoRv6US6alvtuN5LteRwAAGgfQjf0o1ylfP/+7nu4e8vX2Z7HAQCA9mEhNehHuS1YrlLee/XyHOHOwH3BBd3HAQCA9iF0Qz/K7cByW7Bcpbxnn+6TT7ZPNwAAtCuhG/pZBuvcFszWYAAAgHu6AQAAoCJCNwAAAFRE6AYAAICKCN0AAABQEaEbAAAAKiJ0AwAAQEWEbgAAAKiI0A0AAAAVEboBAACgIkI3AAAAVEToBgAAgIoI3QAAAFARoRsAAAAqInQDAABARYRuAAAAqIjQDQAAABUZFm2uKIryec+ePVFnBw8ejK6urujs7IwhQ1wroR7UJXWkLqkrtUkdqUvq6GCL1GVPhuzJlEfS9qE7f5hpwoQJze4KAAAALZgpjz322CMe7yieL5a3wVWUv/3tbzFq1Kjo6OiIOl9FyQsDjz76aIwePbrZ3YGSuqSO1CV1pTapI3VJHe1pkbrMKJ2Be9y4cc85It/2I935
h3PiiSdGq8iiq3Ph0Z7UJXWkLqkrtUkdqUvqaHQL1OVzjXD3qO8EeQAAAGhxQjcAAABUROhuESNHjoxFixaVz1AX6pI6UpfUldqkjtQldTRykNVl2y+kBgAAAFUx0g0AAAAVEboBAACgIkI3AAAAVETorpFly5bFq1/96ujs7IzJkyfHAw888Jzn/+hHP4rTTz+9PP/MM8+Mn/zkJwPWV9rH0dTlTTfdFNOmTYuXv/zl5ePCCy983jqGgXi/7LFixYro6OiIWbNmVd5H2tPR1ua//vWvmDdvXrzyla8sFww69dRT/T6n6XW5dOnSOO200+IlL3lJTJgwIT7xiU/EM888M2D9ZfD71a9+Fe9617ti3Lhx5e/lO+6443m/5t57742zzjqrfK987WtfG7fccku0CqG7Jn74wx/GVVddVa7S9+CDD8ab3vSmmDFjRuzevbvh+evXr485c+bEZZddFps2bSo/QOZjy5YtA953Bq+jrct8M8y6XLduXdx///3lL+p3vvOd8dhjjw143xm8jrYue/zlL3+JT37yk+WFIahDbe7duzfe8Y53lLW5cuXK2LZtW3nxcvz48QPedwavo63L2267LRYsWFCe//DDD8fNN99c/jcWLlw44H1n8HrqqafKWswLQi/Ejh07YubMmfG2t70t/vCHP8THP/7x+PCHPxw/+9nPoiXk6uU033nnnVfMmzfv0OsDBw4U48aNKxYvXtzw/Pe+973FzJkz+7RNnjy5+OhHP1p5X2kfR1uXh9u/f38xatSo4tZbb62wl7SbF1OXWYtTpkwpvve97xWXXnpp8Z73vGeAeks7Odra/O53v1u85jWvKfbu3TuAvaTdHG1d5rlvf/vb+7RdddVVxdSpUyvvK+0pIopVq1Y95zmf/vSnizPOOKNP2+zZs4sZM2YUrcBIdw3kle6NGzeWU3F7DBkypHydo4WNZHvv81NetTzS+TAQdXm4p59+Ovbt2xeveMUrKuwp7eTF1uWXv/zlGDt2bDk7COpSm3feeWecf/755fTy448/PiZOnBhf/epX48CBAwPYcwazF1OXU6ZMKb+mZwr69u3by1seLrroogHrNwy27DOs2R0g4oknnih/weYv3N7y9datWxt+za5duxqen+3QrLo83Gc+85nyXp3D3yRhIOvyN7/5TTk9MqejQZ1qM8PMPffcE+9///vLUPOnP/0p5s6dW16szKm90Iy6fN/73ld+3Vvf+tacERv79++Pj33sY6aX01S7jpB99uzZE//5z3/K9QfqzEg3UIklS5aUi1atWrWqXLgFmqGrqysuueSS8j7ZMWPGNLs70MfBgwfLGRg33nhjnH322TF79uz43Oc+F8uXL29212hjuT5Lzrj4zne+U94D/uMf/zjuuuuuuOaaa5rdNWhZRrprID8IDh06NB5//PE+7fn6hBNOaPg12X4058NA1GWP66+/vgzdv/jFL+KNb3xjxT2lnRxtXf75z38uF6nKFVJ7B500bNiwcuGqU045ZQB6zmD3Yt4zc8Xy4cOHl1/X4/Wvf305opPTgkeMGFF5vxncXkxdfv7zny8vVuYiVSl3yMlFry6//PLyolBOT4eBdsIRss/o0aNrP8qd/K2pgfylmle477777j4fCvN13uvVSLb3Pj+tXbv2iOfDQNRluu6668qr4WvWrIlzzjlngHpLuzjausxtFTdv3lxOLe95vPvd7z60+mmusA/Nes+cOnVqOaW850JQeuSRR8owLnDTrLrM9VgOD9Y9F4a617yCgXd+q2efZq/kRrcVK1YUI0eOLG655Zbij3/8Y3H55ZcXL3vZy4pdu3aVxy+55JJiwYIFh86/7777imHDhhXXX3998fDDDxeLFi0qhg8fXmzevLmJ3wXtXpdLliwpRowYUaxcubL4+9//fujR1dXVxO+Cdq/Lw1m9nLrU5s6dO8sdHubPn19s27atWL16dTF27NjiK1/5ShO/C9q9LvMzZdblD37wg2L7
9u3Fz3/+8+KUU04pd86B/tLV1VVs2rSpfGQk/cY3vlH++1//+tfyeNZk1maPrMVjjjmm+NSnPlVmn2XLlhVDhw4t1qxZU7QCobtGvv3tbxcnnXRSGVpye4cNGzYcOjZ9+vTyg2Jvt99+e3HqqaeW5+cS+nfddVcTes1gdzR1+apXvap84zz8kb/AoZnvl70J3dSpNtevX19u+ZmhKLcPu/baa8st7qBZdblv377ii1/8Yhm0Ozs7iwkTJhRz584t/vnPfzap9wxG69ata/iZsacW8zlr8/CvmTRpUlnH+X75/e9/v2gVHfmPZo+2AwAAwGDknm4AAACoiNANAAAAFRG6AQAAoCJCNwAAAFRE6AYAAICKCN0AAABQEaEbAAAAKiJ0AwAAQEWEbgAAAKiI0A0AHPKBD3wgOjs745FHHnnWsSVLlkRHR0esXr26KX0DgFbUURRF0exOAAD1sHv37jj99NNj0qRJcc899xxq37FjR5xxxhlx0UUXxcqVK5vaRwBoJUa6AYBDxo4dG1/72tdi3bp1ceuttx5qnzt3bgwfPjy++c1vNrV/ANBqjHQDAH3kR4Np06bFtm3bYuvWrbF27dqYM2dOfOtb34orrrii2d0DgJYidAMAz/LQQw/Fm9/85pg1a1b8+te/jhNPPDF++9vfxpAhJskBwNEQugGAhhYuXBiLFy+OoUOHxgMPPBBnnXVWs7sEAC3H5WoAoKExY8aUz+PGjYuJEyc2uzsA0JKEbgDgWR599NFYtGhRGbbz36+77rpmdwkAWpLQDQA8y/z588vnn/70p3HxxRfHtddeG9u3b292twCg5QjdAEAfq1atijvvvDOuueaacgG1pUuXxogRI2LevHnN7hoAtBwLqQEAh3R1dcUb3vCGOO644+J3v/tduYhayu3Crrzyyrj99tvLkW8A4IURugGAQzJY33DDDbFhw4Y499xzD7UfOHAgzjvvvNi1a1e5d/eoUaOa2k8AaBWmlwMApY0bN8ayZcti7ty5fQJ3yhHv5cuXl6H76quvblofAaDVGOkGAACAihjpBgAAgIoI3QAAAFARoRsAAAAqInQDAABARYRuAAAAqIjQDQAAABURugEAAKAiQjcAAABUROgGAACAigjdAAAAUBGhGwAAACoidAMAAEBFhG4AAACIavwfKJ4M3aS2fV0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1000x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(10, 6))\n",
    "\n",
    "# Plot data points\n",
    "plt.scatter(X, Y, color='blue', alpha=0.5, label='Training Data', s=30)\n",
    "\n",
    "# Plot predictions\n",
    "plt.plot(X, y_pred, color='red', linewidth=2, label='Model Prediction')\n",
    "\n",
    "# Plot true function\n",
    "X_true = np.linspace(0, 1, 100)[:, None]\n",
    "Y_true = 0.8 * X_true ** 2 + 0.1\n",
    "plt.plot(X_true, Y_true, color='green', linewidth=2, \n",
    "         linestyle='--', label='True Function', alpha=0.7)\n",
    "\n",
    "plt.xlabel('X', fontsize=12)\n",
    "plt.ylabel('Y', fontsize=12)\n",
    "plt.title('Polynomial Regression Results', fontsize=14, fontweight='bold')\n",
    "plt.legend(fontsize=10)\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f4a5b6c7d8",
   "metadata": {},
   "source": [
    "## Key Concepts Summary\n",
    "\n",
    "### Graph Operations\n",
    "\n",
    "1. **`treefy_split(model, *state_types)`**:\n",
    "   - Splits a model into a graph definition (its static structure) and one state pytree per requested state type\n",
    "   - Returns: `(graphdef, state1, state2, ...)`\n",
    "   - Allows independent management of different state types\n",
    "\n",
    "2. **`treefy_merge(graphdef, *states)`**:\n",
    "   - Reconstructs a model from graph definition and states\n",
    "   - Returns: a model with the given states attached\n",
    "   - Used inside pure functions to rebuild the model from states passed in as arguments\n",
    "\n",
    "3. **`treefy_states(model, state_type)`**:\n",
    "   - Extracts states of a specific type from a model\n",
    "   - Useful for reading the updated states back out after a forward pass\n",
    "\n",
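    "For readers coming from plain JAX, the split into a graph definition and states is similar in spirit to flattening a pytree into its static structure and its array leaves. The sketch below uses only `jax.tree_util` and is an analogy, not BrainState's implementation:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}\n",
    "\n",
    "# Static structure (like a graph definition) and array leaves (like states)\n",
    "leaves, treedef = jax.tree_util.tree_flatten(params)\n",
    "\n",
    "# Operate on the leaves functionally, then rebuild the original structure\n",
    "new_leaves = [0.5 * leaf for leaf in leaves]\n",
    "rebuilt = jax.tree_util.tree_unflatten(treedef, new_leaves)\n",
    "```\n",
    "\n",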
    "### Advantages of Functional API\n",
    "\n",
    "1. **Explicit State Management**: Full control over which states are updated and how\n",
    "2. **JAX Compatibility**: States are explicit function arguments, making JAX transformations straightforward\n",
    "3. **Flexibility**: Separate handling of parameters, hidden states, and custom states\n",
    "4. **Debugging**: State changes appear as explicit function inputs and outputs, so they are easier to inspect and trace\n",
    "\n",
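    "The same pattern can be sketched in plain JAX. The snippet below does not use BrainState: the two dictionaries are hypothetical stand-ins for the parameter and counter pytrees that `treefy_split` would return, and `loss_fn` and `train_step` are illustrative names. Because both states enter `train_step` as arguments and leave as return values, `jax.jit` and `jax.value_and_grad` can treat the step as a pure function:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical stand-ins for the split-out states (not BrainState objects)\n",
    "params = {'w': jnp.array([[0.5]]), 'b': jnp.array([0.0])}\n",
    "counter = {'calls': jnp.array(0)}\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "    pred = x @ params['w'] + params['b']\n",
    "    return jnp.mean((pred - y) ** 2)\n",
    "\n",
    "@jax.jit\n",
    "def train_step(params, counter, x, y):\n",
    "    # Gradients flow only through `params`; `counter` is plain carried state\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)\n",
    "    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)\n",
    "    new_counter = {'calls': counter['calls'] + 1}\n",
    "    return new_params, new_counter, loss\n",
    "\n",
    "x = jnp.linspace(0.0, 1.0, 8)[:, None]\n",
    "y = 2.0 * x + 1.0\n",
    "losses = []\n",
    "for _ in range(3):\n",
    "    # Updated states are returned and fed back in explicitly\n",
    "    params, counter, loss = train_step(params, counter, x, y)\n",
    "    losses.append(float(loss))\n",
    "```\n",
    "\n",
    "After the loop, `counter['calls']` is 3 and `losses` is decreasing. Nothing is mutated in place, which is what makes the step safe to `jit`.\n",
    "\n",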
    "### When to Use Functional API\n",
    "\n",
    "- Custom training loops with complex state management\n",
    "- Implementing advanced optimization algorithms\n",
    "- Fine-grained control over gradient computation\n",
    "- Functional programming style with JAX transformations\n",
    "- Distributed training scenarios"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a5b6c7d8e9",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "Try these exercises to deepen your understanding:\n",
    "\n",
    "1. **Add Momentum to SGD**:\n",
    "   - Create a custom state type for momentum\n",
    "   - Modify the training step to include momentum updates\n",
    "\n",
    "2. **Track More Statistics**:\n",
    "   - Add states to track training loss history\n",
    "   - Add states to track gradient norms\n",
    "\n",
    "3. **Implement Learning Rate Scheduling**:\n",
    "   - Create a state for the current learning rate\n",
    "   - Implement exponential decay or step decay\n",
    "\n",
    "4. **Multi-Task Learning**:\n",
    "   - Modify the MLP to have multiple output heads\n",
    "   - Track separate counters for each task"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b6c7d8e9f0",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Now that you understand the functional API and graph operations, you can:\n",
    "\n",
    "1. **Explore Lifted Transforms**: Learn about higher-level state management with automatic lifting\n",
    "2. **Advanced Optimizers**: Use BrainTools optimizers that handle state management for you\n",
    "3. **Complex Architectures**: Apply these concepts to recurrent networks and spiking neural networks\n",
    "4. **Checkpointing**: Learn how to save and load model states\n",
    "\n",
    "## References\n",
    "\n",
    "- [BrainState Graph API Documentation](https://brainstate.readthedocs.io/en/latest/apis/graph.html)\n",
    "- [JAX Transformations](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)\n",
    "- [Flax Functional API](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/functional_api.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
