{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro",
   "metadata": {},
   "source": [
    "# ``Module`` System Protocol\n",
    "\n",
    "The module system is the foundation for building neural networks in BrainState. It provides a clean, object-oriented interface for organizing stateful computations.\n",
    "\n",
    "In this tutorial, you will learn:\n",
    "\n",
    "- 🏗️ The `Module` base class and its role\n",
    "- 🔨 How to create custom modules\n",
    "- 🧩 Module composition and nesting\n",
    "- 🎯 Parameter management and initialization\n",
    "- 📦 Working with module hierarchies\n",
    "\n",
    "## Why Modules?\n",
    "\n",
    "Modules (via `brainstate.nn.Module`) provide:\n",
    "\n",
    "✅ **Automatic state management** - States are tracked automatically  \n",
    "✅ **Clean abstractions** - Encapsulate related computations  \n",
    "✅ **Reusability** - Build once, use everywhere  \n",
    "✅ **Composability** - Combine simple modules into complex systems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "imports",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:10.100523Z",
     "start_time": "2025-10-11T08:24:08.256525Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "node_intro",
   "metadata": {},
   "source": [
    "## 1. The Module Base Class\n",
    "\n",
    "`brainstate.nn.Module` is the base class for all modules in BrainState. It provides:\n",
    "\n",
    "- Automatic registration of child modules\n",
    "- State collection and management\n",
    "- Pretty printing and inspection\n",
    "- Integration with JAX transformations\n",
    "\n",
    "### Creating Your First Module\n",
    "\n",
    "The simplest module inherits from `Module` and implements `update()`. Calling the instance, as in `module(x)` in the next cell, runs `update(x)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "first_module",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:10.223290Z",
     "start_time": "2025-10-11T08:24:10.109530Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: [1. 2. 3.]\n",
      "Output: [6. 7. 8.]\n",
      "\n",
      "Module:\n",
      "SimpleModule(\n",
      "  constant=5.0\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "class SimpleModule(brainstate.nn.Module):\n",
    "    \"\"\"A minimal module that adds a constant.\"\"\"\n",
    "    \n",
    "    def __init__(self, constant=1.0):\n",
    "        super().__init__()  # Always call parent __init__\n",
    "        self.constant = constant\n",
    "    \n",
    "    def update(self, x):\n",
    "        return x + self.constant\n",
    "\n",
    "# Create and use the module\n",
    "module = SimpleModule(constant=5.0)\n",
    "result = module(jnp.array([1.0, 2.0, 3.0]))\n",
    "\n",
    "print(\"Input:\", jnp.array([1.0, 2.0, 3.0]))\n",
    "print(\"Output:\", result)\n",
    "print(\"\\nModule:\")\n",
    "print(module)"
   ]
  },
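  {
   "cell_type": "markdown",
   "id": "call_dispatch_sketch",
   "metadata": {},
   "source": [
    "The `module(x)` call above returns the result of `update(x)`. As a rough illustration of that call pattern, the plain-Python sketch below forwards `__call__` to `update()`. `TinyModule` and `AddConstant` are hypothetical names invented for this sketch; they are not BrainState classes, and the real `Module` does more than this (for example, registering child modules and states).\n",
    "\n",
    "```python\n",
    "class TinyModule:\n",
    "    '''Hypothetical stand-in for a module base class (not BrainState code).'''\n",
    "\n",
    "    def __call__(self, *args, **kwargs):\n",
    "        # Calling the instance forwards to update(), mirroring module(x) above.\n",
    "        return self.update(*args, **kwargs)\n",
    "\n",
    "\n",
    "class AddConstant(TinyModule):\n",
    "    def __init__(self, constant=1.0):\n",
    "        self.constant = constant\n",
    "\n",
    "    def update(self, x):\n",
    "        return x + self.constant\n",
    "\n",
    "\n",
    "adder = AddConstant(constant=5.0)\n",
    "print(adder(1.0))  # 6.0\n",
    "```\n",
    "\n",
    "Keeping the computation in `update()` and the invocation in `__call__` lets a base class add behavior around every call without subclasses having to repeat it."
   ]
  },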
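  {
   "cell_type": "markdown",
   "id": "stateful_module",
   "metadata": {},
   "source": [
    "### Adding States to Modules\n",
    "\n",
    "Modules become more useful once they hold state. The `Counter` in the next cell stores its call count in a `ShortTermState` and reads and writes it through `.value`:"
   ]
  },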
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "stateful_example",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:10.291549Z",
     "start_time": "2025-10-11T08:24:10.232214Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial count: 0\n",
      "Call 1: count=1, result=10.0\n",
      "Call 2: count=2, result=20.0\n",
      "Call 3: count=3, result=30.0\n",
      "Call 4: count=4, result=40.0\n",
      "Call 5: count=5, result=50.0\n"
     ]
    }
   ],
   "source": [
    "class Counter(brainstate.nn.Module):\n",
    "    \"\"\"A module that counts how many times it's called.\"\"\"\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Create a state to track the count\n",
    "        self.count = brainstate.ShortTermState(jnp.array(0))\n",
    "    \n",
    "    def update(self, x):\n",
    "        # Increment counter\n",
    "        self.count.value = self.count.value + 1\n",
    "        # Return input with count\n",
    "        return x * self.count.value\n",
    "\n",
    "# Test the counter\n",
    "counter = Counter()\n",
    "print(\"Initial count:\", counter.count.value)\n",
    "\n",
    "for i in range(5):\n",
    "    result = counter(jnp.array(10.0))\n",
    "    print(f\"Call {i+1}: count={counter.count.value}, result={result}\")"
   ]
  },
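  {
   "cell_type": "markdown",
   "id": "counter_trace_sketch",
   "metadata": {},
   "source": [
    "To see why the results above grow as 10, 20, 30, ..., the update rule can be traced by hand. The sketch below is plain Python with an ordinary variable standing in for the state; it involves no BrainState objects and only mirrors the arithmetic in `Counter.update()`.\n",
    "\n",
    "```python\n",
    "# Plain-Python trace of the Counter logic (illustrative only).\n",
    "count = 0\n",
    "results = []\n",
    "for _ in range(5):\n",
    "    count = count + 1             # same increment as in update()\n",
    "    results.append(10.0 * count)  # input 10.0 scaled by the updated count\n",
    "\n",
    "print(results)  # [10.0, 20.0, 30.0, 40.0, 50.0]\n",
    "```\n",
    "\n",
    "Because the increment happens before the multiplication, the first call already returns `10.0` and not `0.0`."
   ]
  },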
  {
   "cell_type": "markdown",
   "id": "linear_module",
   "metadata": {},
   "source": [
    "## 2. Creating Custom Modules\n",
    "\n",
    "To see how a module combines configuration, parameters, and a forward pass, this section builds a linear layer from scratch.\n",
    "\n",
    "### Example: Custom Linear Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "custom_linear",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:10.684872Z",
     "start_time": "2025-10-11T08:24:10.297350Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Module:\n",
      "Linear(in_features=5, out_features=3, use_bias=True)\n",
      "\n",
      "Weight shape: (5, 3)\n",
      "Bias shape: (3,)\n",
      "\n",
      "Input shape: (5,)\n",
      "Output shape: (3,)\n",
      "Output: [ 0.3793956  -0.9351347  -0.94997764]\n"
     ]
    }
   ],
   "source": [
    "class Linear(brainstate.nn.Module):\n",
    "    \"\"\"A linear transformation: y = W @ x + b\"\"\"\n",
    "    \n",
    "    def __init__(self, in_features, out_features, use_bias=True):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.use_bias = use_bias\n",
    "        \n",
    "        # Initialize weight with Xavier/Glorot initialization\n",
    "        std = jnp.sqrt(2.0 / (in_features + out_features))\n",
    "        self.weight = brainstate.ParamState(\n",
    "            brainstate.random.randn(in_features, out_features) * std\n",
    "        )\n",
    "        \n",
    "        # Initialize bias to zero\n",
    "        if use_bias:\n",
    "            self.bias = brainstate.ParamState(jnp.zeros(out_features))\n",
    "    \n",
    "    def update(self, x):\n",
    "        \"\"\"Forward pass.\n",
    "        \n",
    "        Args:\n",
    "            x: Input tensor of shape (..., in_features)\n",
    "            \n",
    "        Returns:\n",
    "            Output tensor of shape (..., out_features)\n",
    "        \"\"\"\n",
    "        out = x @ self.weight.value\n",
    "        if self.use_bias:\n",
    "            out = out + self.bias.value\n",
    "        return out\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return f\"Linear(in_features={self.in_features}, out_features={self.out_features}, use_bias={self.use_bias})\"\n",
    "\n",
    "# Create and test the linear layer\n",
    "brainstate.random.seed(42)\n",
    "linear = Linear(in_features=5, out_features=3)\n",
    "\n",
    "# Forward pass\n",
    "x = jnp.ones(5)\n",
    "y = linear(x)\n",
    "\n",
    "print(\"Module:\")\n",
    "print(linear)\n",
    "print(f\"\\nWeight shape: {linear.weight.value.shape}\")\n",
    "print(f\"Bias shape: {linear.bias.value.shape}\")\n",
    "print(f\"\\nInput shape: {x.shape}\")\n",
    "print(f\"Output shape: {y.shape}\")\n",
    "print(f\"Output: {y}\")"
   ]
  },
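  {
   "cell_type": "markdown",
   "id": "glorot_std_sketch",
   "metadata": {},
   "source": [
    "The initialization scale and the output shape of the layer above can be checked by hand. The sketch below uses only the standard library and the same sizes as the example (`in_features=5`, `out_features=3`). The weight values in `W` are arbitrary placeholders chosen for this illustration, not the random values BrainState draws.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "in_features, out_features = 5, 3\n",
    "\n",
    "# Same formula as in Linear.__init__: std = sqrt(2 / (fan_in + fan_out))\n",
    "std = math.sqrt(2.0 / (in_features + out_features))\n",
    "print(std)  # 0.5\n",
    "\n",
    "# x has shape (in_features,) and W has shape (in_features, out_features),\n",
    "# so x @ W has shape (out_features,). A nested-list product shows this.\n",
    "x = [1.0] * in_features\n",
    "W = [[0.1 * (i + j) for j in range(out_features)] for i in range(in_features)]\n",
    "y = [sum(x[i] * W[i][j] for i in range(in_features)) for j in range(out_features)]\n",
    "print(len(y))  # 3\n",
    "```\n",
    "\n",
    "Storing the weight as `(in_features, out_features)` is what lets `update()` compute `x @ W` directly, without a transpose."
   ]
  },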
  {
   "cell_type": "markdown",
   "id": "activation_module",
   "metadata": {},
   "source": [
    "### Example: Custom Activation Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "custom_activation",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:11.047274Z",
     "start_time": "2025-10-11T08:24:10.692880Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Activation: LeakyReLU(negative_slope=0.1)\n",
      "Input:  [-2. -1.  0.  1.  2.]\n",
      "Output: [-0.2 -0.1  0.   1.   2. ]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAArMAAAHWCAYAAABkNgFvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjYsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvq6yFwwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAaMVJREFUeJzt3Qd4VFX6+PGX9ARCCJAQIEDoIEgHBXVFRbGsil2WlSb+bFh3LaxrwboWLOsKrKsQxAI20L8FQVZ0UTFBUBAEKYFQQwKkkz7/5z1hxklIQhKSuXNnvp/nGZgzczNz5sydO++ce857mjgcDocAAAAANhRgdQUAAACA+iKYBQAAgG0RzAIAAMC2CGYBAABgWwSzAAAAsC2CWQAAANgWwSwAAABsi2AWAAAAtkUwCwAAANsimAVQbzt27JAmTZrIc889Z3VV/J6+D4888oglz52QkCATJ0605Lnt+plJTEy0uiqAzyCYBXyEfjnql+Tq1avFl4wcOdK8LuclPDxc+vXrJy+++KKUlZXV6zE18GrWrFm92/KPf/yjCeDqYtiwYeYxZ82aJfX12WefWRawfvfdd+a5MzMzxVs436eqLvfff7+ldXv77bfNPgqg8QV54DkA4ITEx8fLU089Za5nZGSYQOGuu+6S9PR0eeKJJ8TbbdmyRZKTk00A/NZbb8nNN99c72D2lVdeqTKgPXLkiAQFBTVqMDt9+nTzQ6BFixYV7tu8ebMEBFjXN/Loo49K586dK9zWt29fsZLuo7/88ovceeedFW7v1KmTea+Cg4MtqxvgawhmAXi9qKgo+fOf/+wq33TTTdKrVy95+eWXTSATGBgo3uzNN9+U2NhYmTFjhlx55ZXmVHNde3aPJywsTKwSGhoqVrrgggtkyJAhYgfaa2zlewX4IoYZAH5mz549MnnyZGnTpo0JQvr06SNz5sypsE1RUZE89NBDMnjwYBNINm3aVM444wz56quvjvv4DodD/u///k9CQkLkww8/lDPPPFP69+9f5bY9e/aU0aNH1/k1aDAwdOhQycnJkQMHDhwTOGq9dThCy5Yt5dprr5Vdu3aJ1b10GsTq8ARtTy1X5YcffpALL7xQoqOjTZvrcIqXXnrJ3Kc9otorq9xPp1c1Zvb999835a+//vqY5/j3v/9t7tNeQ7Vu3Trz2F26dDHtGhcXZ/aPgwcPuv5GH/eee+4x17UH1PncGpRXN2Z2+/btctVVV5n3ICIiQk499VT59NNPK2yzYsUK8zjvvvuu6WHXHnitwznnnCNbt26VxhxLXLnOziEL3377rdx9990SExNj3oPLLrvMnAGo7PPPPzf7dmRkpDRv3tzsj873VYfG6GvduXOnq62cP16qGzP73//+13zG9Dm15/vSSy+VX3/9tcI2+jr0b7VtnD3kuj9NmjRJ8vPzG6S9ADuiZxbwI2lpaSao0C/EqVOnmi9s/VK+/vrrJTs723VKVK+/9tprMnbsWLnhhhtM0Pj666+bwDMpKUkGDBhQ5eOXlpaaQGjhwoWyaNEiueiii+TQoUPmMTR4cj/1q6fdf/vtN/n73/9er9fiDArcT3lrQPTggw/K1VdfLVOmTDFBiPbe/uEPf5C1a9cec3rcEzRA1eBj7ty5JsC//PLLzVCDv/3tbxW2W7ZsmQl227ZtK3fccYcJKjWY+eSTT0z5xhtvlL1795rt5s+fX+NzarvrmGANEjXgcqfvjf6Acb4X+ngaeGpApM+5YcMGefXVV83/q1atMm2sddb36p133pEXXnhBWrdubf5W95/q9rMRI0aYAOv222+XVq1aybx58+SSSy4xgbYGiO7+8Y9/mGEKf/3rXyUrK0ueeeYZGTdunGm72tC/0eEn7px1rKvbbrvN/Jh4+OGHzT6m4171s6Lt5qSBqO7n2o7Tpk0z+5XuX0uWLJE//elP8sADD5g67d6927SXqmmM9pdffml6l/UHhQasOgxB99vT
TjtN1qxZc0wvvu7f+qNCh97o/fpZ1Z7/p59+ul6vGbA9BwCfMHfuXId+pJOTk6vd5vrrr3e0bdvWkZGRUeH2a6+91hEVFeXIz8835ZKSEkdhYWGFbQ4fPuxo06aNY/Lkya7bUlJSzHM+++yzjuLiYsc111zjCA8Pd3zxxReubTIzMx1hYWGO++67r8Lj3X777Y6mTZs6cnNza3xdZ555pqNXr16O9PR0c9m0aZPjnnvuMc970UUXubbbsWOHIzAw0PHEE09U+Pv169c7goKCKtw+YcIE89z1bUt93k6dOjlqY+rUqY4OHTo4ysrKTHnp0qXmsdeuXevaRtu7c+fO5jG1nd05/07deuut5m+rorc//PDDrvLYsWMdsbGx5rGd9u3b5wgICHA8+uijrtuc77m7d955xzzeN99847pN32O9Td/zyrTe2qZOd955p9n2f//7n+u2nJwc8xoTEhIcpaWl5ravvvrKbNe7d+8K+9tLL71kbtf3ribO96mqS3XtUl2dnY81atSoCm1+1113mf1K92Ol/0dGRjpOOeUUx5EjRyo8pvvfVbePOD8z+nxOAwYMMO/VwYMHXbf9/PPP5r0aP3686zZ9Hfq37p9BddlllzlatWpVY1sBvoxhBoCf0O/1Dz74QC6++GJzXXuynBftcdWeJO3lUToGVXsRlWYM0N7VkpISMy7RuU3lYQl6Sll7EXWS0nnnnee6T0+D6ilT7dUrjy3Ke3C1p2vMmDHmtOrxbNq0yfQC6kXHyj777LOml8/9VK0OadC6aq+V+2vT3sbu3bvXaohEQ9M209d5zTXXuIYEnH322aYXTXtnnbRXLyUlxfSMV+49dh9KUBf6nDoEQ0/lO2mvqLaR3uekwzGcCgoKTJtp772q6r2uDd0HNHvD6aef7rpNeyZ1+In2dm7cuLHC9tor7NzflJ5uV9pjXBs6/EJ7mN0v9aV1dG9zrYvurzpkQOlj65kKzZZQeexrfd6rffv2yU8//WSGDeiQDCcdYnLuueeatqxMx4y70zrqsBA9owL4I4YZAH5CT7lrWiU9hayXqriPP9XTwjphSQPJ4uJi1+2VZ40rPd2Zm5trhizoeMHKxo8fb4K6//3vf+aUv55W1VPR1113Xa3qrqdZ//Of/5hAbNu2bWY4gb4e92BCMwZosKyBa1UaevZ4bQKXpUuXmnpqYOc+BvSss84ywb2eFtbT6/qaGnoG/vnnn29+SGi76xhUpdd1iEiPHj1c2+kPFc1SsGDBgmPGH+sPnPrQwO+UU0455vbevXu77nd/rR07dqywnZ7mV4cPH67V82n7NtQEsOPVpaHfK2eQrOPHq2qvL774QvLy8ir86Kupjjp+F/A3BLOAn3DmZNWsABMmTKhyG+0Nck6i0p4i7TnViT/ak6i9tRq0Or/M3WnPro4X1LGOGsxW7rHS+3XCmT6uBrP6v/aYjho1qlZ11y9y9211LOGgQYPMuNN//vOfrtenAaYG1FVlN6hpzGJlzvrr2MWq6FjQ2sxId/a+am9xVXSClga2jUEn9+n7p2OXZ86caX486OSmJ598ssJ2WjdNu6Xvswa62k7alhoM1zePb11Vl43C2ZPfGLS31VvqUld2qCPgSQSzgJ/QU/Q681q/xI8XROrpaJ2Moqfu3XsgdVJMVfS0tJ761AlMOtxAAyj3nKf65asTY3RYgPZGLl682EwKq29KLQ26NSjXmfk6aUh7qrp27Wq+zLXn2L3nsT40F6gzf6rzlLc7nQx1vJ457U376KOPzCl9zWRQmU6M0mBXg1mtu9JJcjW9N3U9ja3PrT3sy5cvN5PJtH3chxhoT57epz2zmr3CvZf7RJ5b20/brjLt5Xfe7ynaa1l5oQcdFqOn9+vD/b3q1q1btdvVtr3c97Wq2ksnstVmKA7gzxgzC/gJDRyvuOIKM27WmZbJnXv6IWeQ6d7TozPLv//++2ofX4MwPVWtPbQ6fKByr57epsGTzsrXIQnueWPr4957
7zXDH55//nlT1hn3Wm8NzCr3UGnZPdXU8WhqL+2N1lnihYWFFe7TQFzTm+ns85poQK8B7a233mqC2coXDfz1vdDH115mDcJ15nzlwMv9tTiDmtquwqXviY7D1OEFetHT8e7DRKp6n1VVK1fV5bk1vZhmvXDfX7QtdHiLDhk56aSTxFM0+Pzmm28q3Kb1qK5n9nh0PLj+KNSzFDrGuKb3qjbDNDR7hfaI648O97bVz6gOU9G2BFAzemYBH6M5YzWgrEzTO2kKJJ0IpeMZtWdUgwodM6kTfXQcq15XGmhpr6ymUNI0Tzo5afbs2WZ7DUSro6e1NQWVjpHVsXvac+o0cOBA05v53nvvmbGAGsCdCK2LftFrwKnpuDRoefzxx02qJJ1kpHXRoEPrroGlTuzRXlwnDYR1+8o0+LvlllvkueeeM8MxNH+o9mZqeimdqKXtqz3D+ng10V5X/RtNUVUVncCm44A1H6kG4rrMrU7O08BGJ0RpkKM9c5oiS8dNOoNsZ6+uDt3QYFTz6FZHxwnrY+uPDA0m9TW50/dIh33o8BBtj/bt25sAStusMudza9opfU59bK1vVb2GOjlKxwRrwK911TbVYE0fVwN4T64Wpina9KyB/pDTCVU///yzac/6pu7SNtN0W/q4um/oGQft/dXH1eEn+jqd7aU/IDRnrW6nwze0vaqiExq1rYYPH27S5DlTc+mYZ6uWLwZsxep0CgAaRk1pivSya9cus11aWppJ8aTpooKDgx1xcXGOc845x/Hqq69WSDH05JNPmtRCoaGhjoEDBzo++eQTk8rIPd2Qe2oudzNnzjS3//Wvf61w+zPPPGNu18euLU3N1adPnyrvW7FixTGplz744APH6aefblJv6UXTeunr3bx5s2sbfR3VtVPXrl1d233++eeOs846y9G8eXPTVppa6u677z4mfVZl2saaDuy6666rdhtNiRUREWHSKjmtXLnSce6555rUT1r3fv36OV5++WXX/Zpm67bbbnPExMQ4mjRpUqsUVMuWLTP36fbOfcDd7t27TR1atGhh0rNdddVVjr1791b5eI899pijffv2JmWUe5quymmu1LZt2xxXXnmleVxNzTZs2DCzD7lzpuZ67733KtxeVfqq+qaj0zRgmhaudevWpr1Hjx7t2Lp1a7WpuSo/lrOO+r+7jz/+2DFixAiTik73D319mtLMSVPO/elPfzKvX//e+bmp7rV9+eWXjtNOO831eBdffLFj48aNFbZxpubSFHVVtUNVadMAf9BE/7E6oAbgH3Q1q7vuusv0nFaekQ0AQH0QzALwCD3U6LK2eurdipyvAADfxJhZAI1Kx2p+/PHHJoBdv369meEPAEBDoWcWQKPSIQU6g15XttKJVbrgAQAADYVgFgAAALZFnlkAAADYFsEsAAAAbMvvJoDpqkR79+41ydTrujQkAAAAGp+Ogs3JyZF27dodd6EVvwtmNZDt0KGD1dUAAADAcezatUvi4+Nr3MbvglntkXU2ji5L6ImeYF3zPiYmxqNLOPo72t0atLs1vRdFRUWm3XUJXF3iFp7B/m4N2t0/2j07O9t0Pjrjtpr4XTDrHFqggayngtmCggLzXHzoPId2twbtbk2b//bbb+bA36NHDwkK8rvDumXY361Bu/tXuzepxZBQ9gIAAADYFsEsAAAAbItgFgAAALbF4KpqJlSUlJRIaWlpg4wxKS4uNuNMGNvjOXZpd52so2McSRMHAED9EMxWorOC9+3bJ/n5+Q0WGGtgpbnSCFg8x07tHhERYWahh4SEWF0VAABsh2DWjQY/KSkpprdMk/RqcHGigZCzl5feN8+yQ7u7p1TS/a579+5e3YsMAIA3Iph1o4GFBrSa10x7y/wlqPJFdmn38PBwCQ4Olp07d5r9LywszOoqwWZ0/46KijLDorx5XweAxkIwWwV6x+BJ7G84ERrAtmnTxvxPMAvAH/EtCgAAANsimAUAm9MhBg2RfQUA7MjSYHbW
rFnSr18/19Kyw4cPl88//7zGv3nvvfekV69eZmzhySefLJ999pnH6uuvduzYYU5f/vTTT+ILdGxqt27d5LvvvmvU58nIyJDY2FjZvXt3oz4P/JuO89+2bZvs2rXLXAcAf2NpMBsfHy//+Mc/5Mcff5TVq1fL2WefLZdeeqls2LChyu01+Bg7dqxcf/31snbtWhkzZoy5/PLLL+LvJk6caNrCm4Nh56Vly5Zy5plnyv/+978Ge40jR46UO++885jbExMTpUWLFhVumz17tnTu3FlGjBjhuu3QoUMybtw486NKt9d9LDc3t8b6vPrqq+Z59W/0dWVmZla4v3Xr1jJ+/Hh5+OGH6/Q6AQCATYLZiy++WC688EKTkqhHjx7yxBNPSLNmzWTVqlVVbv/SSy/J+eefL/fcc4/07t1bHnvsMRk0aJD861//8njdUXdffvmlyeH7zTffmNRnf/zjHyUtLc3jWQ50f9Fg1Z0GsvojatmyZfLJJ5+YOv7f//1fjY+luYh1f/zb3/5W7TaTJk2St956ywTLAADY1bdbM+TblCzxRl6TzUDHe+kQgry8PDPcoCrff/+93H333RVuGz16tCxevLjaxy0sLDQXp+zsbPO/no6rfEpOyxrsOC8NxflYDfmYx3sud9pzfe+995qe0KZNm8p5550nzz//vOk5VEuWLDE/JHQ7zbGr7f/iiy9K165dj6m/XvS9uuGGG8z78e9//1vOOuss+eGHH2TIkCGu59S/18v27dtdf689sjrrWi/Tpk2TBQsWmB8ul1xySa3qWdNrdK9fVds5r+sZAD0lqz+inLf9+uuvpg2SkpJcr+Gf//ynXHTRRfLss8+awLsqd9xxh/l/xYoVVT6/Oumkk8zff/jhh8cE0JXrXdU+aTfOz5DdX4eduO83vrAP2Qn7uzVod8/7dV+23PTWGikoKpWCJqEydljHRn/Oury/lgez69evN8GTLjuqvbKLFi0yAUBV9u/fbwIhd1rW26vz1FNPyfTp04+5XRPV63O60+VPtfE0P6lenC6btUrSc38PiOtM45t6ZMyJaRYqi24+tVbbOr/E3Out9NT3OeecY3oIn3nmGTly5Ig88MADcvXVV8vSpUtdAf7tt99uxiDrqXVtr8suu8wEfpo2yvmY+r/+2Pjzn/9s8qL+97//lZiYGPP4c+bMkQEDBried+7cuXLddddVqJOzXbUO8+bNM7dp8Ky31aae1b1G94DQeZ8z6HZOinHe/vXXX5szAZrf1Xnbt99+a4YWaP2dt+nwAX3tOrTleMM33J+jqrppgKw9vRMmTKjy7/Vv9HUdPHjQ5Jy1M30dWVlZpv1JOebZNtczBQcOHDC5leEZ7O/WoN09a392kdywcJPkFZZ/1y1Zt1vO7hTa6KkAdQXP2rL8qNezZ08zsUh3zPfff9984WvAUV1AW1faA+jem6uBmy6KoEGYjnV0p8GtNp5+Gbh/IWTkFkla9gkEs/XURJrU+otJP9B6qby9jg8dOHCgGZvspIFnx44dTa+pDu/QgNGdBqI6cem3336Tvn37uh5T20cDO+3p/uqrr0yidjVlyhS5+eab5YUXXpDQ0FBZs2aN6WX96KOPKrSljpPVOuqXrh6EBg8ebHpf9f7a1LO612ja6uh43Mr3abCsnLfrJJn27dtX2E4DAH297rfpde1J1h89x3sP3J+jqm31+XQfr+5x9HZ9Xa1atbL9ogn6JaPvg36++JLxXJvrj0Ft98r7MRoX+7s1aHfPyTpSLPe8vUrS84pNuW9cU5n556HSNKzxO17q8n1o+VFPl4zVmeVKg5vk5GQzNlZPX1cWFxd3zBhLLevt1dHgSi+VOQOjyre5T1Ryiok89u/rwiEOE5jWlT5vXX/5VN5+3bp1JvCMjIw8ZlsNEvXHxJYtW+Shhx4yQwV0Br6za18DP+2tdT7mn/70JzNpT3tktWfTSXtxp06daoZ7XHvttabXVYce6CQr9zotXLjQZKJwDifQ
yVn6/te2ntW9RvfbnfdpsOy+nfO6BuT6AanqvqoetzaJ6N3/vqptdTU5DeCPV++q9kk78qXXYhfOtqbdPY/93Rq0e+MrKC6VG99cI1sOlE+GTmgVIc9d2s0Esp5o97o8h+XBbGUaSLmPcXWnwxGWL19eYda6TtipboxtQ/l/t51u22VVddiATrR7+umnj7mvbdu25n+9v1OnTvKf//zHjO/U90B7ZDWFlTsdZ/rmm2+asbKaecJJA1Kdta89updffrm8/fbb5gdJZdojrqf49aJtokGwBrb6Y6M29ayJ9rJr735l2mPl7EFWOv5Wh7a40x9D2jvrTuunk7Zq+qFUW/o42oMANAY9ruj+r/ssK4ABaAhlZQ75y3s/S1JK+eTl1s1CZO7EIRJemifeyNKfNDoEQMcSauomDTC0rJNpdGa50gBJb3OfcKMTdWbMmCGbNm2SRx55xIzr1F5BVE2zPegs/YSEBNMD7n7RSVY6TnPz5s3y97//3YxZ1SwRhw8frvKxdCiBDgPQCVs6FMSdDjXQbAUzZ840X6oa1NbkyiuvNAG+bl+beh6P9tzq8IbKNIWbDlFw0qEMuu+4T9TSH0Ma9GqKOCftfdag/pRTTpETpQG7Pi/QGDSA1R9d+kONYBZAQ3jys1/l03X7zPXw4EB5fcJQ6dTq+N/FfhnMam+YBqwaiGggpUMMvvjiCzn33HPN/ampqSaVk5PmBdVeP83v2b9/fzPGVk9tay8ixPRM6thM94uml9KeQc3Pq+2rM/m1jXWilU5cio6ONmM1tU23bt1qgrjKGSPc3XbbbfL444+btForV6503a5B8Kmnnir33XefeS73YQhV0S9dnXSmwbGegr/11ltrrGdNr1GHQ2igrWN89TF1yIIG6JpN4Z133pG//OUvrr/X4Q/aC+yey1jrrim2NEODZjTQCWH6A0mHTDgzGezZs8cMkdD7nXTioT6/tpvSH2Radk/Dpa9Ng2QdGwwAgLd7fWWKvLYyxVwPDGgiM8cNkv4dKuZr9zoOP5OVlaVdcub/yo4cOeLYuHGj+b+hlJWVOYqKisz/jWnChAnmdVW+XH/99Y7ffvvNcdlllzlatGjhCA8Pd/Tq1ctx5513uuq0bNkyR+/evR2hoaGOfv36OVasWGH+dtGiReb+lJQUU167dq3r+WbMmOGIjIx0fPvtt67bXn/9dbNdUlJShbpV9fcqLy/PER0d7Xj66adN+Xj1rOk1Kn3ec8891xETE+OIiopyDBs2zPHhhx8e01ZXX3214/77769w28GDBx1jx451NGvWzNG8eXPHpEmTHDk5Oce8hq+++sp128MPP1xlfebOneva5u2333b07NmzxveuMfY7q5SWljr27dtn/ofnFBcXO/bs2UO7exj7uzVo98bzyc97HQn3f+LodF/55Z0fdlrW7jXFa5U10X/Ej2g2Ax1DqT18VWUzSElJMROXGmpWudVjZj1JF7HQXMHaM2q1mtpd66e9/9r7q+ngGpP2VmtvsU6eq05j7HdW0aEZzuwQTMzwXJvrWQkdKqNp4Mhm4Dns79ag3RvHD9sPynWvJ0lRafkk8NvP6S53n9vDsnavKV6rjL0AJ0xP2+u4UF1ZS4cheLt+/fqZiWYaQDYmzQyhY4d16AQAAN7qt7QcueGN1a5A9qrB8XLXqO5iFwSzOGE6vlTTqulCA5MnTxY7mDhxokk71ph0Qo6mIPP1HnkAgH2lZRfIxDlJkl1QvujPmT1i5MnLf0/LaQecj8IJ03yxegEAAPaRU1AsE+Ykyd6s8hVR+7ZvbiZ8BQfaq6/TXrUFAADACSsqKZOb31wjm/aXLxsbHx0ucyYOlaah9uvnJJitgp/NiYPF2N8AAJ7+3rnvg3WycmuGKbeICJZ5k4dJbKQ9JyETzLoJDg525QYFPMW5vzn3PwAAGtOzX2yWRWv3mOuhQQHy+oQh0jWmcbP7NCb79SU3osDAQGnRooVradOIiIgTHgDtT6m5vIkd
2l3rqIGs7m+63+n+B9SV7t+aYq64uNhr93UA3mP+qp0yc8U2c10PGS9dO1AGd2opdkYwW4kuC6mcAW1DBCyam01zsvFF4zl2ancNZJ37HVBXun/rSnXe/MMNgHdYumG/PPzRL67yIxf3kfP72v/7h2C2Ev0yaNu2rUkKrD0dJ0oDqoMHD5olY0nu7Dl2aXcdWkCPLACgsa1JPSy3L1grZUenadx4ZheZMCJBfAHBbDU0wGiIIEODKg1YdGUnbw6qfA3tDgBAue3puXJ9YrIUFJcvijBmQDu5b3Qv8RV8ywOADyxnu2PHDnMdANyl5xTKxLnJcji//GzziK6t5Jkr+0tAgO8MSyKYBQAA8EF5hSVy/bxkST1UnjWnV1ykzL5usIQE+Vb451uvBgAAAFJSWiZT314j63ZnmXK7qDBJnDRMmof5XhpIglkAAAAfy+jzwKJf5KvN6aYcGRYkiZOHSVyUPRdFOB6CWQAAAB/y0vItsnD1LnM9JDBA/jN+iPRoEym+imAWAADARyxMTpUXv9ziKs+4ur+c2qWV+DKCWQAAAB/w1aYD8rdFvy+K8MCFveXi/u3E15FnFgBsvtBL06ZNpbCwkBXAAD+2bnem3Pr2Gik9uirCpNMSZMoZncUfEMwCgI1pANu+fXuzSAjBLOCfUg/my+TEZMkvKjXlC/rGyYMXneQ3xwSGGQAAANjUobwimTA3STJyi0x5aEK0vHDNAJ9aFOF4CGYBAABs6EhRqVkUISUjz5S7xTYzmQvCggPFnxDMAoCN6RK2W7ZskdTUVJazBfyIjo29Y8FaWZuaacqxkaGSOGmotIgIEX9DMAsAPpAgnUAW8K/P/CMfb5ClG9NMuVlokMydNFTioyPEHxHMAgAA2Mjsr7fL/FU7zfWggCYy68+DpE+7KPFXBLMAAAA2sXjtHnl6ySZX+Zkr+8kZ3WPEnxHMAgAA2MC3WzPknvd/dpXvGd1TLh8UL/6OYBYAAMDLbdybLTfO/1GKS8sXRRh3Ske5ZWRXq6vlFQhmAQAAvNiezCMyKTFJcgtLTHlU7zby6KV9/WZRhOMhmAUAmwsPD5ewsDCrqwGgEWTlF8uEOUmSll1oygM6tJCXxw6UQD9aFOF4WM4WAGwsICBAOnToIKGhoeY6AN9RUFwqN8xfLVsP5JpyQqsIeX3CEAkP8a9FEY6HIx8AAICXKStzyF/e/VmSUg6ZcqumITJv8jBp1SzU6qp5HYJZAAAAL/PEZ7/Kp+v3mevhwYEyZ+JQ6dSqqdXV8koEswBgY7ry17Zt22TXrl2sAgb4iNf+t11eX5liruvY2FfGDZT+HVpYXS2vRTALADZXWlpqLgDs79N1+0yvrNMTY/rK2b3aWFonb0cwCwAA4AV+2H5Q7lr4kzjKU8nK7ed0l2uHdbS6Wl6PYBYAAMBiW9Jy5IY3VktRaflwoauHxMtdo7pbXS1bIJgFAACwUFp2gUycmyzZBeWLIpzZI0aeuOxkFkWoJYJZAAAAi+QUFJtAVlf5Un3bN5eZ4wZJcCAhWm3RUgAAABYoKimTm99cI7/uyzbl+Ohwk4KraShrWtUFwSwA2JwuZasrgAGwD4fDIfd/sE5Wbs0w5eiIYLMoQmwkS1PXFaE/ANiYLmHbsWNHE9CynC1gH88t3Swfrt1jrocGBchrE4ZI15hmVlfLliw98j311FMydOhQiYyMlNjYWBkzZoxs3ry5xr9JTEw0A6LdL3oQBwAAsIM3V+2UV77aZq7rHK+Xrh0ogzu1tLpatmVpMPv111/LrbfeKqtWrZJly5ZJcXGxnHfeeZKXl1fj3zVv3lz27dvnuuzcudNjdQYAAKivZRvT5KGPfnGVp1/SR87vG2dpnezO0mEGS5YsOabXVXtof/zxR/nDH/5Q7d9pb2xcHG88AOgSttu3b5fMzExp3bo1Qw0AL7Ym9bDc9s4aKTu6KMJNZ3aV8cMTrK6W7XnVmNmsrCzzf8uWNXe15+bmSqdOncxBfNCg
QfLkk09Knz59qty2sLDQXJyys8tnDOrfemIdc30OHeTNmumeRbtbg3b3PG3roqIiKSkp8dhxDeXY361h13ZPyciTKYnJUlBcXu9L+reVv57b3Tavo8zD7V6X5/GaYFYrfeedd8ppp50mffv2rXa7nj17ypw5c6Rfv34m+H3uuedkxIgRsmHDBomPj69yXO706dOPuT09PV0KCgrEE69L66k7AD0mnkO7W4N2t67N8/Pz5cCBAxIU5DWHdZ/H/m4NO7b7wbxi+b93N8mh/GJTHtIhUu75Q5xkZKSLXZR5uN1zcnJqvW0Th9bKC9x8883y+eefy8qVK6sMSquj42x79+4tY8eOlccee6xWPbMdOnSQw4cPm7G3nnjzNXCOiYmxzYfOF9Du1qDdrWnz3377zRzb9EwVwaznsL9bw27tnldYIn96LUnW7yk/+9wzLlIW/t8p0jwsWOykzMPtrse06OhoE0AfL17ziqPe1KlT5ZNPPpFvvvmmToGsCg4OloEDB8rWrVurvF9zL1aVf1HfCE99CHSMryefD+Vod2vQ7p7nbGva3fPY361hl3YvKS2T2xf85Apk20aFSeKkodIiwp55oZt4sN3r8hyW7gXaKayB7KJFi+S///2vdO7cuc6PUVpaKuvXr5e2bds2Sh0BAADqE+M8sOgX+Wpz+VCCyLAgsyhC26hwq6vmcyztmdW0XG+//bZ89NFHJtfs/v37ze1RUVESHl7+Zo8fP17at29vxr6qRx99VE499VTp1q2bmb377LPPmtRcU6ZMsfKlAAAAuLy0fIssXL3LXA8JDJD/jB8iPdpEWl0tn2RpMDtr1izz/8iRIyvcPnfuXJk4caK5npqaWqGrWce63nDDDSbw1bEUgwcPlu+++05OOukkD9ceALxDSEiIuQDwDu8m75IXv9ziKs+4ur+c2qWVpXXyZZYGs7WZe7ZixYoK5RdeeMFcAADl48oSEhJMJgNvHz8I+IOvNh+QaYvWu8oPXNhbLu7fztI6+TqOfAAAAA1g/e4sufWtNVJ6dFWESaclyJQz6j4fCHVDMAsAAHCCdh3Kl0mJyZJfVGrKF54cJw9edJLJAIDGRTALADamuR937Nghe/futc1KQoCvOZRXJBPmJElGbnle+6EJ0fL81QMkIIBA1hMIZgHA5nQ5W70A8LyC4lKZMi9ZtmfkmXK32GYmc0FYcKDVVfMbBLMAAAD1oGNjb39nraxJzTTl2MjQo4sikF3EkwhmAQAA6pGRafr/2yBLN6aZcrPQIJk7aajER0dYXTW/QzALAABQR7O/3i5vfL/TXA8KaCKz/jxI+rSLsrpafolgFgAAoA4Wr90jTy/Z5Co/fUU/OaN7jKV18mcEswAAALX07dYMuef9n13le0b3lCsGx1taJ39HMAsANhcUFGQuABrXxr3ZcuP8H6W4tHxRhHGndJRbRna1ulp+j6MfANiYLmHbpUsXlrMFGtmezCMyKTFJcgtLTHlU7zby6KV9WRTBC3DkAwAAqEFWfrFMnJMkadnliyIM6NBCXh47UAJZFMErEMwCAADUsCjCDfNXy5YDuaac0CpCXp8wRMJDWBTBWxDMAoCN6RK2qampsm/fPpazBRpYWZlD/vLez5KUcsiUWzcLkXmTh0mrZqFWVw1uCGYBwOYKCgqksLD89CeAhvPkZ7/Kp+v2mevhwYHy+oSh0qlVU6urhUoIZgEAACp5fWWKvLYyxVzXsbGvjBso/Tu0sLpaqALBLAAAgJvP1u+Txz/d6Co/MaavnN2rjaV1QvUIZgEAAI7S8bF3LvxJHOWpZOX2c7rLtcM6Wl0t1IBgFgAAQES2pOXIlHnJUlRSPpnyqsHxcteo7lZXC8dBMAsAAPxeWnaBTJybLNkF5YsinNkjRp68/GQWRbABglkAsLnAwEBzAVA/OQXFMmFOklnlS/Vt31xmjhskwYGESXbAcrYAYGO6hG3Xrl1ZzhaoJx1ScNObP8qm/TmmHB8dLnMmDpWm
oYRIdsGRDwAA+CWHwyH3fbBOvt160JRbRASbRRFiI8OsrhrqgGAWAAD4pWe/2CyL1u4x10ODAswytV1jmlldLdQRwSwA2JguYbtr1y7Zv38/y9kCdTB/1U6ZuWKbua5zvF66dqAM7tTS6mqhHghmAcDmjhw5Ypa0BVA7Szfsl4c/+sVVnn5JHzm/b5yldUL9EcwCAAC/sSb1sNy+YK2UHV0U4aYzu8r44QlWVwsngGAWAAD4he3puXJ9YrIUFJcPyRkzoJ3cO7qn1dXCCSKYBQAAPi89p1AmzE2Sw/nFpjyiayt55sr+EhDAogh2RzALAAB8Wl5hiVw/L1l2HSpfFKFXXKTMvm6whAQRBvkC3kUAAOCzSkrLZOrba2Td7ixTbhsVJomThknzsGCrq4YGQjALADana8ez+hdQ9aIIDyz6Rb7anG7KkWFBJpCNi2JRBF/CWm0AYGMaxHbv3p3lbIEqvLR8iyxcvctcDwkMkP+MHyI94yKtrhYaGEc+AADgcxYmp8qLX25xlWdc3V9O7dLK0jqhcRDMAgAAn/LV5gPyt0W/L4rw94t6y8X921laJzQeglkAsPmYwD179khaWpq5Dvi79buz5Na31kjp0VURJp2WINef3tnqaqEREcwCgI1pAJuXl2eWtCWYhb9LPZgvkxKTJL+o1JQvPDlOHrzoJDNJEr6LYBYAANjeobwisyhCRm6RKQ9NiJbnrx7Aogh+gGAWAADY2pGiUpkyL1lSMvJMuVtsM5O5ICw40OqqwQMIZgEAgG3p2Ng7FqyVNamZphwbGSqJk4ZKi4gQq6sGDyGYBQAAtqTjxB/5eIMs3Zhmys1Cg2TupKESHx1hddXgQQSzAADAlmZ/vV3mr9pprgcFNJFZfx4kfdpFWV0t+FMw+9RTT8nQoUMlMjJSYmNjZcyYMbJ58+bj/t17770nvXr1krCwMDn55JPls88+80h9AQCAd1j80x55eskmV/mZK/vJGd1jLK0T/DCY/frrr+XWW2+VVatWybJly6S4uFjOO+88k2amOt99952MHTtWrr/+elm7dq0JgPXyyy+/J0cGAH+hS9j26NFDEhISWM4WfiM5NVvu+2C9q3zP6J5y+aB4S+sE6zRxeFFiwvT0dNNDq0HuH/7whyq3ueaaa0yw+8knn7huO/XUU2XAgAEye/bs4z5Hdna2REVFSVZWljRv3lwaW1lZmVkzXV8XXzSeQ7tbg3a3Bu1uDdrdGhv2ZMrV//5e8orKTHncKR3l8TF9ySXrY/t7XeK1IPEiWmHVsmXLarf5/vvv5e67765w2+jRo2Xx4sVVbl9YWGgu7o3jfFP00tj0OfT3gieeC7+j3a1Bu1uDdrcG7e55ezKPyOTE1a5A9pxesfLwH3ub98GL+uZ8UpmH9/e6PI/XBLNa6TvvvFNOO+006du3b7Xb7d+/X9q0aVPhNi3r7dWNy50+fXqVvcAFBQXiidelQbruAPxy9xza3Rq0u+dpW2tvSW5urmn/wEDyanoK+7tnZReUyI3vbpa0nPIOqj5tIuTBc9rLoYMZVlfNL5R5eH/PycmxXzCrY2d13OvKlSsb9HGnTZtWoSdXe2Y7dOggMTExHhtmoKc+9Pk42HkO7W4N2t2aNj98+LAJYrXdg4K85rDu89jfPaewpFRun5MsKYfKO6Hio0JlzuRTJCYyzOqq+Y0yD+/vOsm/trziqDd16lQzBvabb76R+PiaB3DHxcVJWlp5PjknLevtVQkNDTWXyvSN8NTBR998Tz4fytHu1qDdPc/Z1rS757G/N76yMof89f31krTjsCm3ahoiL17W3QSytLtneXJ/r8tzWLoXaFe1BrKLFi2S//73v9K5c+fj/s3w4cNl+fLlFW7TTAh6OwAA8C1PfvarfLpun7keHhwor00YLPEtju2kgv8KsHpowZtvvilvv/22yTWr4171cuTIEdc248ePN0MFnO644w5ZsmSJzJgxQzZt2iSPPPKIrF692gTFAADA
d7y+MkVeW5lirgc0EXll3EDpH9/C6mrBy1gazM6aNcsMJh45cqS0bdvWdVm4cKFrm9TUVNm3r/wXmRoxYoQJfl999VXp37+/vP/++yaTQU2TxgAAgL1ob+zjn250lZ+47GQ5u1fFCeCA5WNma5NGY8WKFcfcdtVVV5kLAADwPT9sPyh3LfxJnGHC7Wd3k7HDOlpdLXgpRk4DAACvsSUtR254Y7UUlZbnGb1qcLzcdW4Pq6sFL+YV2QwAAPWf8dutWzeTa5aZ3bC7tOwCmTg32eSUVX/oESNPXn4yq3uhRhz5AMDmSA0FX5BTUCwT5iSZVb5U3/bNZea4QRIcyL6NmrGHAAAASxWVlMnNb66RTfvLV32Kjw6XOROHSrNQTiDj+AhmAcDGdCKtpjTMyMhgbXrYku6393+wTlZuLV+WtkVEsMybPExiWd0LtUQwCwA2DwR0me7c3FyCWdjSs19slg/X7jHXQ4MC5PUJQ6RrTDOrqwUbIZgFAACWmL9qp8xcsc1c1zleL107UAZ3aml1tWAzBLMAAMDjlm1Mk4c/+sVVfuTiPnJ+3zhL6wR7IpgFAAAetSb1sNz2zhopOzoy5sYzu8iEEQlWVws2RTALAAA8Znt6rlyfmCwFxeWLIlw6oJ3cN7qX1dWCjRHMAgAAj0jPKZQJc5PkcH6xKQ/v0kqeubKfBASwKALqj2AWAAA0urzCErl+XrLsOlS+KEKvuEj59/jBEhoUaHXVYHNkIwYAG9OVv7p27cpytvBqJaVlMvXtNbJud5Ypt40Kk7mThkrzsGCrqwYfwJEPAGwuMDDQXABvpPmPH1j0i3y1Od2UI8OCzKIIbaPCra4afATBLAAAaDT/XL5VFq7eZa6HBAbIq9cNkR5tIq2uFnwIwSwA2LzXKy0tTQ4ePMgKYPA67ybvkhe+/M1Vfu7q/jK8aytL6wTfQzALADamAWxWVpbk5OQQzMKrfLX5gExbtN5VfuDC3nJJ/3aW1gm+iWAWAAA0qHW7M+XWt9ZI6dFVESadliBTzuhsdbXgowhmAQBAg0k9mC+TE5Mlv6jUlC88OU4evOgkadKEXLJoHASzAACgQRzKK5KJc5MkI7fIlIcmRMvzVw9gUQQ0KoJZAABwwo4UlcqUecmyPSPPlLvFNpP/jB8iYcGkjUPjIpgFAAAnRMfG3rFgraxJzTTl2MhQSZw0VFpEhFhdNfgBglkAAFBvmkXjkY83yNKNaabcNCTQrO4VHx1hddXgJ1jOFgBsTCfVdO7c2SxnywQbWGH219tl/qqd5npQQBOZfd1g6dMuyupqwY/QMwsANqYBbHBwsLkQzMLTFq/dI08v2eQqP31FPzmje4yldYL/IZgFAAB19u3WDLnn/Z9d5XtG95QrBsdbWif4J4JZALD5eMX09HQ5fPgwK4DBYzbuzZYb5/8oxaXl+9y4UzrKLSO7Wl0t+CmCWQCwMQ1gNZDVJW0JZuEJezKPyKTEJMktLDHlUb1jZfolfRjmAssQzAIAgFrJyi+WiXOSJC270JQHdGghL48dJEGBhBOwDnsfAAA4roLiUrlh/mrZciDXlBNaRcjrE4ZIeAiLIsBaBLMAAKBGZWUO+ct7P0tSyiFTbtU0ROZNHiatmoVaXTWAYBYAANTsyc9+lU/X7TPXw4MDZc7EodKpVVOrqwUYBLMAAKBar69MkddWppjrgQFN5JVxA6V/hxZWVwtwIZgFAABV0t7Yxz/d6Co/MaavnN2rjaV1AipjOVsAsDFNh9SpUyeJiIggNRIalI6Pvevdn8SZ8e32c7rLtcM6Wl0t4Bj0zAKAjWkAGxoaKiEhIQSzaDBb0nJkyrxkKSopM+UrB8fLXaO6W10toEoEswAAwCUtu0Amzk2W7ILyRRHO7BEjT11+Mj+W4LUIZgHAxnTVr4yMDMnMzGQFMJywnIJimTAnyazypfq2by4zxw2SYBZFgBdj7wQAG9MA9tChQwSzOGE6pODmN9fIpv05phwfHW5ScDUN
ZXoNvBvBLAAAfk5/CN3/wTpZuTXDlFtEBJtFEWIjw6yuGnBcBLMAAPi5Z7/YLB+u3WOuhwYFmGVqu8Y0s7paQOMFs126dJGDBw8ec7ue5tL7AACAPcxftVNmrthmruscr5euHSiDO7W0ulpA4wazO3bskNLS0mNuLywslD17yn/Z1cY333wjF198sbRr187Mkly8eHGN269YscJsV/myf//++rwMAAD82tIN++Xhj35xlR+5uI+c3zfO0joBdVWnUd0ff/yx6/oXX3whUVFRrrIGt8uXL5eEhIRaP15eXp70799fJk+eLJdffnmt/27z5s3SvHlzVzk2NrbWfwsAAETWpB6W2xeslbKj8wZvPLOLTBhR++9wwJbB7JgxY8z/2hs6YcKECvcFBwebQHbGjBm1frwLLrjAXOpKg9cWLVgXGgCA+tienivXJyZLQXH5oghjBrST+0b3srpaQOMHs2Vl5Tt9586dJTk5WVq3bi1WGDBggBnS0LdvX3nkkUfktNNOq3Zb3U4vTtnZ2a7X4nw9jUmfQ2eJeuK58Dva3Rq0u+dpe8fHx0tYWBht72F23d/TcwpNLtnD+cWmPLxLK/nH5Sfr3iRlzm5aL2bXdre7Mg+3e12ep17J41JSUsQKbdu2ldmzZ8uQIUNMgPraa6/JyJEj5YcffpBBgwZV+TdPPfWUTJ8+/Zjb09PTpaCgwCNvRlZWltkBAgJIHuEptLs1aHfr2v3IkSPmuEa7e44d9/f8olK55f3fZNfh8kURurYOl8dGd5DMQ+UpuezAju3uC8o83O45OeX5jmujiaMeWbYfffTRGu9/6KGH6vqQZujCokWLXEMZauvMM8+Ujh07yvz582vdM9uhQwc5fPhwhXG3jfnm6xdMTEwMHzoPot2tQbtbg3a3ht3avaS0TP7vzTWyYnO6KbeNCpP3bzpV2kaFi53Yrd19RZmH213jtejoaBNAHy9eq1fPrAad7oqLi01vbVBQkHTt2rVewWx9DRs2TFauXFnt/aGhoeZSmb4RnvoQaKDuyedDOdrdGrS7Z2l/hKZF1AO/zieg3T3LLvu77icPfrTRFchGhgVJ4qRh0j66qdiRXdrd1zTxYLvX5TnqFcyuXbv2mNv0QDpx4kS57LLLxJN++uknM/wAAPyRBikZGRksZ4savbR8iyxcvctcDwkMkP+MHyI94yKtrhbQIBpswWXtAtaxqZo39rrrrqvV3+Tm5srWrVtdZe3d1eC0ZcuWZujAtGnTTN7aN954w9z/4osvmslnffr0MeNddczsf//7X1m6dGlDvQwAAHzKu8m75MUvt7jKM67uL6d2aWVpnQCvDGaVjmvQS22tXr1azjrrLFf57rvvNv9r2q/ExETZt2+fpKamuu4vKiqSv/zlLybAjYiIkH79+smXX35Z4TEAAEC5rzYfkGmL1rvKD1zYWy7u387SOgFeEcz+85//rFDWU1saeOokrLrkjdVMBDWdFtOA1t29995rLgAAoGbrdmfKrW+tkdKj6bYmnZYgU87obHW1AO8IZl944YVjBunq7DbtUdWhAQAAwDqpB/NlcmKyScWlLjw5Th686CQzgQfwNbbKMwsAAGp2KK9IJsxNkozcIlMemhAtz189QAICCGThm044t8KuXbvMBQAAWOtIUalMmZcsKRl5ptwttpnJXBAWHGh11QDvCmZLSkrkwQcflKioKElISDAXvf73v//d5JwFAHiGnjbW5Wzj4uI4hezndGzsHQvWyprUTFOOjQyVxElDpUVEiNVVA7xvmMFtt90mH374oTzzzDMyfPhwc9v3338vjzzyiBw8eFBmzZrV0PUEAFRBA1jN7hIWFkYw68d0MvUjH2+QpRvTTLlZaJDMnTRU4qMjrK4a4J3B7Ntvvy0LFiyokLlA02TpMrFjx44lmAUAwINmfb1N5q/aaa4HBTSRWX8eJH3aRVldLcB7hxno8rA6tKAyXdAgJITTGQDg6eVsc3JyWAHMTy1au1ue
WbLZVX76in5yRvcYS+sEeH0wO3XqVHnssceksLDQdZtef+KJJ8x9AADP0AD2wIEDZogXwaz/+XZrhtz7/jpX+Z7RPeWKwfGW1gmwxTCDtWvXyvLly82kg/79+5vbfv75Z7NC1znnnCOXX365a1sdWwsAABrWxr3ZcuP8H6W4tPxHzLhTOsotI7taXS3AHsFsixYt5Iorrqhwm46XBQAAjW9P5hGZlJgkuYUlpjyqdxt59NK+TAKEX6pXMDt37tyGrwkAADiurPximTgnSdKyy4f6DejQQl4eO1ACWRQBfqpeY2bPPvtsM+GgsuzsbHMfAABoeAXFpXLD/NWy5UCuKSe0ipDXJwyR8BAWRYD/qlcwu2LFCjM+trKCggL53//+1xD1AgAAbsrKHPKX936WpJRDpty6WYjMmzxMWjULtbpqgH2GGaxb9/uMyY0bN8r+/ftd5dLSUlmyZIm0b9++YWsIAADkyc9+lU/X7TPXw4MD5fUJQ6VTq6ZWVwuwVzA7YMAAM7hcL1UNJwgPD5eXX365IesHAKiBHo/btWtncnwz+cd3vb4yRV5bmWKu69jYV8YNlP4dWlhdLcB+wWxKSorJY9ilSxdJSkqSmJjfkzLrgTQ2NlYCAxm3AwCeogFss2bNJD8/n2DWR2lv7OOfbnSVHx/TV87u1cbSOgG2DWY7depk/i8rK2us+gAAgKN+2H5Q7lr4kzjXw7j9nO4ydlhHq6sF2D811xtvvFHj/ePHj69vfQAAdaBny7KysiQ3N7fC2TLY329pOXLDG6ulqLS8A+mqwfFy16juVlcL8I1g9o477qhQLi4uNqe4dKhBREQEwSwAeDCYTUtLM+kSExISrK4OGkhadoHJJZtdUL4owpk9YuTJy09mKAnQUKm5Dh8+XOGiPQKbN2+W008/Xd555536PCQAABCRnIJimTg3WfZmFZhy3/bNZea4QRIcWK+vbMDnNdgno3v37vKPf/zjmF5bAABQO0UlZXLTmz/Kr/uyTTk+OlzmTBwqTUPrdSIV8AsN+jMvKChI9u7d25APCQCA3wwZue+DdfLt1oOm3CIi2CyKEBsZZnXVAK9Wr596H3/88TEfwH379sm//vUvOe200xqqbgAA+I1nv9gsi9buMddDgwLMMrVdY5pZXS3AN4PZMWPGVCjrgHSdRasLKcyYMaOh6gYAgF+Yv2qnzFyxzVzXOV4vXTtQBndqaXW1AN8NZp15ZtPT083/pIMBAKB+lm7YLw9/9IurPP2SPnJ+3zhL6wT49JhZTf9y6623SuvWrSUuLs5c9PrUqVPNfQAAz9EzY23btjWdCqRtsp81qYfl9gVrpezoogg3ndlVxg8nxRrQaD2zhw4dkuHDh8uePXtk3Lhx0rt3b3P7xo0bJTExUZYvXy7fffedREdH16kSAID60QA2MjJSjhw5QjBrMykZeTJl3mopKC4/2zlmQDu5d3RPq6sF+HYw++ijj5qFEbZt2yZt2rQ55r7zzjvP/P/CCy80dD0BAPAZ6TmFMmFOkhzKKzLlEV1byTNX9peAAH6QAI06zGDx4sXy3HPPHRPIKh1u8Mwzz8iiRYvqXAkAQP1oNpmcnBzJy8sz1+H98gpL5Pp5yZJ6KN+Ue8VFyuzrBktIEIsiAPVRp0+Opt/q06dPtff37dtX9u/fX6+KAADqzpkaUSfkEsx6v5LSMpn69hpZtzvLlNtGhUnipGHSPCzY6qoB/hHM6kSvHTt2VHt/SkqKtGxJKhEAACrTHxsPLPpFvtpcngkoMizILIoQF8WiCIDHgtnRo0fLAw88IEVF5WN83BUWFsqDDz4o559//glVCAAAX/TS8i2ycPUucz0kMEBevW6I9GgTaXW1AP+bADZkyBDp3r27Sc/Vq1cv80vz119/lZkzZ5qAdv78+Y1XWwAAbOjd5F3y4pdbXOXnru4vw7u2srROgF8Gs/Hx8fL999/LLbfcItOmTXONz9J0MOeee65ZzrZDhw6NVVcAAGzn
q80HZNqi9a7yAxf2lkv6t7O0ToBfrwDWuXNn+fzzz+Xw4cOyZUv5r8xu3boxVhYAgErW7c6UW99aI6VHV0WYOCJBppzR2epqAT6lXsvZKl0YYdiwYQ1bGwAAfETqwXyZnJgs+UWlpnxB3zh58I8nsbgF4C3BLADAehoYae7voKAggiQvooshTJibJBm55ROmhyZEywvXDJBAFkUAGhzBLADYmAawUVFRZgIuwax3OFJUKlPmJZvlalW32Gbyn/FDJCw40OqqAT6J5UYAAGggOjb2jgVrZU1qpinHRoZK4qSh0iIixOqqAT6LYBYAbEyzyuTm5kp+fj4rgFlM23/6/9sgSzemmXKz0CCZO2moxEdHWF01wKcRzAKAzQOovXv3yoEDBwhmLTb76+3yxvc7zfWggCYy68+DpE+7KKurBfg8glkAAE7Q4rV75Oklm1zlZ67sJ2d0j7G0ToC/sDSY/eabb+Tiiy+Wdu3amYkLixcvPu7frFixQgYNGiShoaEmv21iYqJH6goAQFW+3Zoh97z/s6t8z+iecvmgeEvrBPgTS4PZvLw86d+/v7zyyiu12j4lJUUuuugiOeuss+Snn36SO++8U6ZMmSJffPFFo9cVAIDKtqTny81vrZHi0vIhHuNO6Si3jOxqdbUAv2Jpaq4LLrjAXGpr9uzZZgWyGTNmmHLv3r1l5cqV8sILL8jo0aMbsaYAAFS0J/OI3L14q+QWli+KMKp3rEy/pA8p0gAPs1We2e+//15GjRpV4TYNYrWHtjqae1EvTtnZ2eb/srIyc2ls+hw6KcMTz4Xf0e7WoN09z/1Y5qnjGkSyjhTLpMRkSc8rNuUBHaLkpWsGiK6JwHvQuDjO+Ee7l9XheWwVzO7fv9+sdONOyxqgHjlyRMLDw4/5m6eeekqmT59+zO3p6elSUFAgnngzsrKyzA4QEMB8O0+h3a1Bu1vX5pqaSzMa6EpgaFyFJWVy56ItsvVA+aII8VGh8tQFnSQn86DkWF05P8Bxxj/aPSen9p8mnz/qTZs2Te6++25XWQPfDh06SExMjDRv3twjb76ectLn40PnObS7NWh3z9MvlpCQEDl06JD5cR8YyCpTjalMF0VY+JOs3ZNrytHhQZI4eah0iYm0ump+g+OMf7R7WFiYbwazcXFxkpZWnozaScsalFbVK6s064FeKtM3wlMfAn3zPfl8KEe7W4N297yWLVtKSUmJCWRp98b15Gcb5dP1+8318OBAmXFpNxPI0u6exXHG99s9oA7PYau9YPjw4bJ8+fIKty1btszcDgBAY3p9ZYq8tjLFXA8MaCIvjx0gJ8U1tbpagN+zNJjVJRg1xZZenKm39HpqaqpriMD48eNd2990002yfft2uffee2XTpk0yc+ZMeffdd+Wuu+6y7DUAgNXDDHS8rM4BYAWwxvPpun3y+KcbXeUnxvSVs3vFWlonAF4QzK5evVoGDhxoLkrHtur1hx56yJT37dvnCmyVpuX69NNPTW+s5qfVFF2vvfYaabkA+C0NYHfv3m0myBLMNo6klENy18KfxNm8t5/TXa4d1tHqagHwhjGzI0eOrPHgW9XqXvo3a9eubeSaAQAgsiUtR6bMS5ai0vI0QVcPiZe7RnW3uloA7DpmFgAAT0nLLpCJc5Mlu6DElM/sESNPXHYyiyIAXoZgFgCASnIKimXCnCSzypfq2765zBw3SIID+doEvA2fSgAA3BSVlMlNb/4om/aXJ22Pjw6XOROHStNQW2WzBPwGwSwAAEfpPI77Plgn3249aMotIoJl3uRhEhtZ+wTuADyLYBYAgKOe+WKzLFq7x1wPDQqQ1ycMka4xzayuFoAacM4EAGxMJyO1bt3a/M/EpBMzf9VOmbVim7muTfnStQNlcKeWVlcLwHEQzAKAjWkA61zOlmC2/pZu2C8Pf/SLq/zIxX3k/L5xltYJQO0wzAAA4Nd+3HlYbntnrZQdTXt+45ldZMKIBKurBaCWCGYBwOYTlnQp
28LCQlYAq4ft6blmUYTCkvJFEcYMaCf3je5ldbUA1AHBLADYmAawuuy3Lv9NMFs36TmFMmFukhzOLzblEV1byTNX9peAAIZrAHZCMAsA8Dt5hSUyOTFZdh0qXxShV1ykzL5usIQE8bUI2A2fWgCAXykpLZOpb6+R9XuyTLldVJgkThomzcOCra4agHogmAUA+A0divHAol/kq83pphwZFiSJk4dJXBSLIgB2RTALAPAbLy3fIgtX7zLXQwID5D/jh0iPNpFWVwvACSCYBQD4hYXJqfLil1tc5RlX95dTu7SytE4AThzBLADA53216YD8bdHviyL8/aLecnH/dpbWCUDDYAUwAPCBFcCc13Gsdbsz5Za31kjp0VURJp2WINef3tnqagFoIASzAGBjGsC2bt1aysrKCGarkHow36TgOlJcasoXnhwnD150Em0F+BCGGQAAfNKhvCKZODdJMnKLTHloQrQ8f/UAFkUAfAzBLADYPNWULmVbVFTECmBujhSVyvXzkmV7Rp4pd41pajIXhAUHWl01AA2MYBYAbEwD2J07d8revXsJZo/SsbG3L1gra1MzTTk2MlTmTR4mLSJCrK4agEZAMAsA8Bka0D/y8QZZtjHNlJuFBsncSUMlPjrC6qoBaCQEswAAnzH76+0yf9VOcz0ooInM+vMg6dMuyupqAWhEBLMAAJ+weO0eeXrJJlf5mSv7yRndYyytE4DGRzALALC9b7dmyD3v/+wq3zO6p1w+KN7SOgHwDIJZAICtbdybLTfO/1GKS8snwI07paPcMrKr1dUC4CEEswAA29qTeUQmJSZJbmGJKY/q3UYevbQviyIAfoQVwADAxjRoi46ONrP4/S2Ay8ovlolzkiQtu9CUB3ZsIS+PHSiBLIoA+BWCWQCwMQ1gY2Ji/C6YLSgulRvmr5YtB3JNuXPrpvL6hKESHsKiCIC/YZgBAMBWysoc8pf3fpaklEOm3LpZiCROGiotm7IoAuCPCGYBwMa0R7a4uNhc/GUFsCc/+1U+XbfPXA8PDjQ9sp1aNbW6WgAsQjALADamAWxKSors2bPHL4LZ11emyGsrU8x1HRv7yriB0r9DC6urBcBCBLMAAFvQ3tjHP93oKj8+pq+c3auNpXUCYD2CWQCA1/th+0G5a+FP4ux8vv2c7jJ2WEerqwXACxDMAgC82pa0HLnhjdVSVFpmylcNjpe7RnW3uloAvATBLADAa6VlF8jEucmSXVC+KMKZPWLkyctP9qs0ZABqRjALAPBKOQXFMmFOklnlS/Vt31xmjhskwYF8dQH4HUcEAIDXKSopk5vfXCOb9ueYcnx0uMyZOFSahrLWD4CKOCoAgI3p6faoqCgpLS31mVPvmmLs/g/WycqtGaYcHREs8yYPk9jIMKurBsALEcwCgI1pANumTRvzv68Es89+sVk+XLvHXA8NCpDXJgyRrjHNrK4WAC/FMAMAgNeYv2qnzFyxzVzX2PyfYwfK4E4tra4WAC9GMAsANqdDDPRid0s37JeHP/rFVZ5+SR8Z3SfO0joB8H5eEcy+8sorkpCQIGFhYXLKKadIUlJStdsmJia6Tqc5L/p3AOCPysrKZNu2bbJr1y5z3a7WpB6W2xeslbKjiyLcdGZXGT88wepqAbABy4PZhQsXyt133y0PP/ywrFmzRvr37y+jR4+WAwcOVPs3zZs3l3379rkuO3fu9GidAQANJyUjT6bMWy0FxeXB+JgB7eTe0T2trhYAm7A8mH3++eflhhtukEmTJslJJ50ks2fPloiICJkzZ061f6O9sXFxca6LTn4AANhPek6hySV7KK/IlEd0bSXPXNlfAgJ8YzIbAB/PZlBUVCQ//vijTJs2zXVbQECAjBo1Sr7//vtq/y43N1c6depkTqkNGjRInnzySenTp0+V2xYWFpqLU3Z2tvlf/9YTp+T0OTTNjJ1P/9kR7W4N2t3z3I9lnjquNZS8whKZnJgsqYfyTblnXKTMHDdQggLKX4u3
Y3+3Bu3uH+1eVofnsTSYzcjIMJMWKvesannTpk1V/k3Pnj1Nr22/fv0kKytLnnvuORkxYoRs2LBB4uPjj9n+qaeekunTpx9ze3p6uhQUFIgn3gytp+4AGqjDM2h3a9Du1rV5fn6+GZ4VFGSPjIslZQ659+Otsn5PeQdDbLNgefaiBCnIPiwF5Td5PfZ3a9Du/tHuOTnlC6bUhj2Oem6GDx9uLk4ayPbu3Vv+/e9/y2OPPXbM9trrq2Ny3XtmO3ToIDExMWbsrSfefB0Woc/Hh85zaHdr0O7WtHlmZqZp99jYWFsEs/plOG3RL/LdjvKoNTIsSN64/hTp0SZS7IT93Rq0u3+0e1gdJvdbetRr3bq1BAYGSlpaWoXbtaxjYWsjODhYBg4cKFu3bq3y/tDQUHOpTN8IT30I9M335POhHO1uDdrd85xtbZd2f/HL3+Td1bvN9ZDAAPnP+CHSq22U2BH7uzVod99v94A6PIele0FISIgMHjxYli9fXiHy17J772tNdJjC+vXrpW3bto1YUwDw3i8XPcvUrFkzW6wA9m7yLnnxyy2u8oyr+8upXVpZWicA9mb5+SgdAjBhwgQZMmSIDBs2TF588UXJy8sz2Q3U+PHjpX379mbsq3r00Ufl1FNPlW7duplTa88++6xJzTVlyhSLXwkAeJ4zu4v2Ynh7MPvV5gMybdF6V/mBC3vLxf3bWVonAPZneTB7zTXXmMlYDz30kOzfv18GDBggS5YscU0KS01NrdDVfPjwYZPKS7eNjo42PbvfffedSesFAPBO63Znyq1vrZHSo6siTDotQaac0dnqagHwAU0cOhLfj+gEsKioKDMjz1MTwHSGsU7MYGyP59Du1qDdrVFSUmLa3dlD621SD+bL5bO+lYzc8lyyF54cJy+PHSSBNs8ly/5uDdrdP9o9uw7xGnsBANj8C0YnwOpZLG/Mu6mLIUycm+QKZIcmRMvzVw+wfSALwHsQzAIAGsWRolKZMi9ZtmfkmXK32GYmc0FYcKDVVQPgQwhmAQANTsfG3rFgraxJzTTl2MhQSZw0VFpEhFhdNQA+hmAWANCgdCrG9P+3QZZuLM8h3iw0SOZOGirx0RFWVw2ADyKYBQA0qNlfb5c3vt9prgcFNJFZfx4kfdrZc1EEAN6PYBYA0GAWr90jTy/Z5Co/fUU/OaN7jKV1AuDbCGYBAA3i260Zcs/7P7vK94zuKVcMjre0TgB8n+WLJgAA6k9X/dKlbIuLiy1dAWzj3my5cf6PUlxanrp83Ckd5ZaRXS2rDwD/QTALADamAWy7du0kKCjIsmB2T+YRmZSYJLmFJaY8qncbefTSvl6/vC4A38AwAwBAvWXlF8vEOUmSll1oygM6tJCXxw5kUQQAHkMwCwCol4LiUrlh/mrZciDXlBNaRcjrE4ZIeAiLIgDwHIJZALAxXcL2t99+kx07dnh0OduyMof85b2fJSnlkCm3ahoi8yYPk1bNQj1WBwBQBLMAgDp78rNf5dN1+8z18OBAmTNxqHRq1dTqagHwQwSzAIA6eX1liry2MsVc17Gxr4wbKP07tLC6WgD8FMEsAKDWtDf28U83uspPjOkrZ/dqY2mdAPg3glkAQK3o+Ni73v1JHOWpZOX2c7rLtcM6Wl0tAH6OYBYAcFxb0nJkyrxkKSopn2R21eB4uWtUd6urBQAEswCAmqVlF8jEucmSXVC+KMKZPWLkyctPZlEEAF6BFcAAwMY0oGzatKkUFhY2SnCZU1BsAlld5Uv1bd9cZo4bJMGB9IUA8A4EswBgYxrAtm/fXoKDgxs8mNUhBTe/uUZ+3ZdtyvHR4SYFV9NQvjoAeA9+WgMAjuFwOOT+D9bJyq0ZptwiItgsihAbGWZ11QCgAoJZAMAxnv1is3y4do+5HhoUYJap7RrTzOpqAcAxCGYBwMZ0CdstW7ZIampqgy1nO3/VTpm5Ypu5riMXXrp2oAzu1LJBHhsAGhrBLAD4wJCAhgpkl27YLw9/
9IurPP2SPnJ+37gGeWwAaAwEswAA48edh+W2d9ZK2dFFEW48s4uMH55gdbUAoEYEswAA2Z6eaxZFKDy6KMKYAe3kvtG9rK4WABwXwSwA+Ln0nEKZMDdJDucXm/KIrq3kmSv7S0AAiyIA8H4EswDgx/IKS+T6ecmy61D5ogi94iJl9nWDJSSIrwcA9sDRCgD8VElpmUx9e42s251lym2jwiRx0jBpHhZsddUAoNYIZgHA5sLDwyUsLKzOGRD+vvgX+WpzuilHhgWZRRHiolgUAYC9sCYhANhYQECAdOjQQUJDQ8312vrn8q2yIHmXuR4SGCD/GT9EerSJbMSaAkDjoGcWAPzMu8m75IUvf3OVZ1zdX07t0srSOgFAfRHMAoAf+WrzAZm2aL2r/PeLesvF/dtZWicAOBEEswBgY7ry17Zt22TXrl3HXQVs3e5MufWtNVJ6dFWESaclyPWnd/ZQTQGgcRDMAoDNlZaWmktNUg/my+TEZMkvKt/uwpPj5MGLTpImTcglC8DeCGYBwMcdyisyiyJk5BaZ8tCEaHn+6gEsigDAJxDMAoAPO1JUahZFSMnIM+Vusc1M5oKw4ECrqwYADYJgFgB8lI6NvWPBWlmbmmnKMZGhkjhpqLSICLG6agDQYAhmAcAH6aIIj3y8QZZuTDPlpiGBJpCNj46wumoA0KAIZgHAB836epvMX7XTXA8KaCKzrxssfdpFWV0tAGhwBLMAYHO6lK2uAOa0aO1ueWbJZlf56Sv6yRndYyyqHQA0LpazBQAb0yVsO3bsaAJavf7t1gy59/11rvvvGd1Trhgcb2kdAaAxEcwCgA8oKXPI6ytT5PllW6S4tHxRhHGndJRbRna1umoA4PvDDF555RVJSEgwPQunnHKKJCUl1bj9e++9J7169TLbn3zyyfLZZ595rK4A4G1+3pUpk9/5VZ74bJMcKS5fFGFU7zby6KV9WRQBgM+zvGd24cKFcvfdd8vs2bNNIPviiy/K6NGjZfPmzRIbG3vM9t99952MHTtWnnrqKfnjH/8ob7/9towZM0bWrFkjffv2rfXz6rKPVS39qAd+94P/8ZaH1NN6NW3rfB69HG/bujyup7fVmdF68eZt3d873a6697iqbevyuGxb/bbONne/r7Hq4Hw+b962sT+f2QXF8tySTbI8eb1oswVIpFZQxg7tKA9c1FuaiH4OHBwjfGB/Z9tjj+0n+h3uDdt64/d9WTXbVtXmjXmMON5ju2viqOmRPUAD2KFDh8q//vUvU9bKd+jQQW677Ta5//77j9n+mmuukby8PPnkk09ct5166qkyYMAAExBXVlhYaC5O2dnZ5vGTk5OlWbNmx2zftGlTad++vau8ZcuWahs/PDzcPJaTro9eeUlJfT36nBqYa++z0/bt26WkpKTKxw0JCamw7Y4dO6SoqHzlnsqCgoKkS5curnJqaqoUFBRUuW1gYKB07fr7KUddy/3IkSNVbqsfxO7du7vKe/bsMe1enR49eriu7927V3Jzc6vdtlu3bq6ddv/+/aZ9qqP11XqrtLQ0ycrKqnbbzp07S3BwsGvbnTt3SvPmzSt8QJw6derkmjCTkZEhhw4dqvZxneMRlW6n21cnPj5eIiLKUx9lZmbKgQMHqt22Xbt2rn1QX5fWuTpt27aVyMhIcz0nJ0f27dtX7bZt2rSRqKjyWev6Puj7UR3dL1u0aGGu5+fny+7du6vdtnXr1tKyZUtzXfcx3dcqc+7v+l7ExJRPONLPn74X1YmOjnZtW1xcLCkpKdVuq69LX5/Sz5p+5qqj731cXJyrXlu3bq12W30f9P1w+u2336rdtqGPEU66j+m+VtUxoqSsTLLyS+RQXqEcyi+WzEKHHJQoWbJhv6TnFMjJQfskUBwS1rylXDc8QbrE/H5s4xhR9TEiPT1dDh8+XO22tTlGOPd37Uhxfu45RtR8jHDS7XT7+hwj9HOk7V7V8d1fjxGeiCO0jXR/GTRokKvdG/MYoZ99
jQ9139f3ymt7ZrVhf/zxR5k2bZrrNm2gUaNGyffff1/l3+jt2pPrTntyFy9eXOX22oM7ffr0Y27XxqlqJ9APlfNg59yuul8H+ma7zyDWA1PlnVB3YD0A6EHTeQBzblvTTuh+gNNta9oJ3bfV53EP3t3pAb/yttV9qen74L6tHqCr22FV5cetaYfVbZ0fBH3cmr7UdFvnF5VuqwfpmrZ1vne6rdZB27+q06z6Xmg7O9tXL9XRg43zfdb9oaZt9WDjDHy1rjVtq8+v+4bSNqhpW31dzvbX11XTtrpPOPcBffyattW2de5bui/UtK22o3Of1cevalvn/n7w4EHXwVsfv6bHdf+1rl9UNW2rny/n+6nXa9pW6+rcz/QzXNO2+rzabk41bdvQx4gyh0huYYnklATKd7uOSFpOkblkZ+yTrPwiyTxSIjmFpabn1VVfCZD9ZeUH9yZHU2/1bxMqZ/dvK8FBJRXqzzGi4rbO907rUFPgW5tjhHN/18BYjxOKY0TNxwh3zs9NXY8R2sZa76qO7754jHDS7Zz7jhVxhMPhMPV1/3w25jGips+nV/XMauStv1506MDw4cNdt997773y9ddfyw8//FDlGzRv3jwz1MBp5syZJmCt6ldrdT2z+mVbVaTfGMMM9ECnvyrdPwjefCrBF04h6sFAPxja7lX1zHrjqTNf2Na5v2tvjjPA8OdhBvpaDh8pkX1ZBeay93Ce7MsqdJX3ZxXIgZwCM3nLbG9C06PPIzUfmp3bntMzRiacFCxBpUdMj4n7ccaJY4T993e2/X1bPb47v1crH9+94XPvy8MM0tPTTc+3+4+A2j5uXT/LGsy2atXK+3tmPUF/ybj/6nHSA35VB/3KqgqE6rKtvtF6kNPncr//RB/X09vaUVXtjsbl3N/14ol2t/KzYQ62R4plb6YGpkdkrwaomUeOBqrO/wukqKQ2476aVBustm4WKu1ahEnbKL2EH71e/n+H6Ahp3SzEnPLMzCys1XGNY4R993eU00Cxtsd3b/j+9JVty47u73q/c5vG3O9rE6O5thUL6XgZbZjKPapado5jqUxvr8v2AFAfOQXF5b2pzgDVFagWyF4NVjMLXJkD6is6IrhCgNq2RZi00/+PBq5tokIlNKi8x686dZkkAQC+yNJgVocMDB48WJYvX24yEjgPzFqeOnVqlX+jwxH0/jvvvNN127JlyyoMUwCAmhwpKnUFpL/3oh5x9bLq7TmFVY9Fq63IsKDywNTZkxoVJnFRYdKuxe/BanhIzYEqAOD4LB9moJO5JkyYIEOGDJFhw4aZ1Fw6eH3SpEnm/vHjx5txtTqRS91xxx1y5plnyowZM+Siiy6SBQsWyOrVq+XVV1+1+JUA8AaFJaVmLKorMHWe9s/UHtXy65n5xSf0HBEhgSYgdQamcUeD1bYtfv+/WWiQRzsGnJOVAMDfWB7MaqotHVD80EMPmTQMmmJryZIlrtQamiLCfUzGiBEjTG7Zv//97/K3v/3NpH3QTAZ1yTELwJ6KS8vkQE6hOeXvPkbVNRQg64hk5FY9Y7e2QoICygNS5+l+tzGq5T2s4dI8PMhrFiPQ46Om4HGfYQwA/sTyYFbpkILqhhWsWLHimNuuuuoqcwHgO0rLHJKRW+gKTN0DVGcva3pOoUllVV+awqpN87Aqx6g6e1lbNg3xmkAVAGCTYBaAb9OZ/wfzio6e6tdT/kcDVrfe1bTs31NU1UdAE5HYyLCKk6jcTvvr/5oZIEA3BAD4DIJZACfEPUXVnsx8+W1XuuSWHZb92YX1SFFVvZpSVOn/sZGhEhTof6fZddKsru6jSdE1QwxDDQD4G4JZAD6Rosqf6co+1a3uAwC+jmAW8GPuKarcU1Xp6f/9jZiiSk/7lweqpKgCAJwYglnAj1JU/T6pqnFSVOnkqmYBxdIzPkbaR0d4PEUVAMD/8C0D2FBJaZmkeSBFlbP3tELPag0pqnT8pqaIio09ds10AAAaA8EsYLMUVdrb
eiCnoFFSVJkVqo6WW5GiCgBgAwSzgAeRogoAgIZFMAs0Qooq5yQq99P/+7NJUYXGERQUZC4A4I84+gE2SVGlk6x0aICOZQWcdGxyly5dWM4WgN8imAWOk6LKGbTmkqIKAACvQzALv0pRtT/7yO+pqkzg2jApqsKDA2sco0qKKgAAGgffrrC14tIyOVBNiqrUjBxJz1tvJlydCD2trwGp+0z/CkMBosIkKjyYmf+whKZDS01NlcOHD7OcLQC/RDAL26ao0v/TcwpPOEVVnNtp/spjVPX/lqSogpcrKCiQwsJCq6sBAJYgmIWPp6gKPXqqnxRVAAD4IoJZNEqgqmNQNUg1Y1UtSFHVJjJUpCBb2sW14bQrAAA+jGAWDZKiyjmRqjwLQOOmqNJym6hQCQ2qfuZ/+bKqOSdUBwAA4P0IZlFjiqr9lcao6u05J5qiKjSoyklUzjGqpKgCAAC1RTDrRzyRoioiJLBCYBrnlk+VFFUAAKChEVX4UIoqnTBV1RhVZwaAjNyGSVHl7Emt3LuqwwCahwcx8x/wsMDAQHMBAH9EMGsDnkhRFRzYxCyVqgGpSVVFiirAFnSCY9euXVnOFoDfIpj1gpn/2mNqTve796g2eIqqigEqKaoAAIAvIJht5ED1cH6RbEnPl/WHDsj+7PKVqtx7V3VYQFFp46Wo0v8112pQID02AADA9xDMNiINUgc/vrxBUlRVN0b1eCmqAPg2TUO3a9culrMF4LcIZhuRBpmtmoaYla6qEhkWVH7a361XlRRVAOrqyJEjZklbAPBHBLON7IK+cZKVmycJbaKlvQappKgCAABoMERTjezRS/uYWcaxsbGc/gMAAGhgRFcAAACwLYJZAAAA2BbBLAAAAGyLYBYAbE5X5mNMPgB/xQQwALAxDWK7d+/OcrYA/BZHPgAAANgWwSwAAABsi2AWAGzM4XDInj17JC0tzVwHAH9DMAsANqYBbF5enlnSlmAWgD8imAUAAIBtEcwCAADAtghmAQAAYFsEswAAALAtglkAAADYlt+tAOac7Zudne2R5ysrK5OcnBwJCwtjdR4Pot2tQbtb0+a5ubkmo4Ee14KC/O6wbhn2d2vQ7v7R7tlH47TaZGnxu6OevhGqQ4cOVlcFAAAAx4nboqKiatpEmjj8LDGh/rLYu3evREZGSpMmTTzyy0ID5127dknz5s0b/flQjna3Bu1uDdrdGrS7NWh3/2h3h8NhAtl27dodtyfY73pmtUHi4+M9/rz6xvOh8zza3Rq0uzVod2vQ7tag3X2/3aOO0yPrxGATAAAA2BbBLAAAAGyLYLaRhYaGysMPP2z+h+fQ7tag3a1Bu1uDdrcG7W6NUC9ud7+bAAYAAADfQc8sAAAAbItgFgAAALZFMAsAAADbIpgFAACAbRHMetgll1wiHTt2NGsbt23bVq677jqzIhkaz44dO+T666+Xzp07S3h4uHTt2tXMyCwqKrK6aj7viSeekBEjRkhERIS0aNHC6ur4rFdeeUUSEhLMceWUU06RpKQkq6vk07755hu5+OKLzcpEupLk4sWLra6SX3jqqadk6NChZgXP2NhYGTNmjGzevNnqavm8WbNmSb9+/VyLJQwfPlw+//xz8SYEsx521llnybvvvms+gB988IFs27ZNrrzySqur5dM2bdpkljH+97//LRs2bJAXXnhBZs+eLX/729+srprP0x8MV111ldx8881WV8VnLVy4UO6++27zA23NmjXSv39/GT16tBw4cMDqqvmsvLw80876IwKe8/XXX8utt94qq1atkmXLlklxcbGcd9555v1A49FVU//xj3/Ijz/+KKtXr5azzz5bLr30UvN96i1IzWWxjz/+2Py6LCwslODgYKur4zeeffZZ82tz+/btVlfFLyQmJsqdd94pmZmZVlfF52hPrPZW/etf/zJl/eGm66ffdtttcv/991tdPZ+nPbOLFi0yx3F4Vnp6uumh1SD3D3/4g9XV8SstW7Y036N6
1tMb0DNroUOHDslbb71lTsMSyHpWVlaW+TACdu/51t6SUaNGuW4LCAgw5e+//97SugGeOI4rjuWeU1paKgsWLDC94TrcwFsQzFrgvvvuk6ZNm0qrVq0kNTVVPvroI6ur5Fe2bt0qL7/8stx4441WVwU4IRkZGebLpU2bNhVu1/L+/fstqxfQ2PQMhJ7tOe2006Rv375WV8fnrV+/Xpo1a2ZW/7rpppvM2YiTTjpJvAXBbAPQU3l6qqmmi47bdLrnnntk7dq1snTpUgkMDJTx48cLoz0av93Vnj175PzzzzfjOG+44QbL6u5v7Q4ADUnHzv7yyy+mlxCNr2fPnvLTTz/JDz/8YOZATJgwQTZu3CjegjGzDTRu5+DBgzVu06VLFwkJCTnm9t27d5vxbd99951Xddn7Yrtr1oiRI0fKqaeeasZw6ulYeGZ/Z8xs4w0z0EwR77//foUxm/pFo23NWZ/Gx5hZz5s6darZtzWrhGapgefpUCbNDKQTq71BkNUV8AUxMTHmUt9TJUongKHx2l17ZDWTxODBg2Xu3LkEshbt72hY+oNB9+nly5e7gik9pmhZv/ABX6J9bzqxUX88rFixgkDWQmVlZV4VtxDMepB2zycnJ8vpp58u0dHRJi3Xgw8+aH7d0CvbeDSQ1R7ZTp06yXPPPWd6Fp3i4uIsrZuv0zHhOtFR/9exnXqaSnXr1s2Mv8KJ07Rc2hM7ZMgQGTZsmLz44otmcsakSZOsrprPys3NNWPvnVJSUsy+rRORNI84Gm9owdtvv216ZTXXrHNceFRUlMkhjsYxbdo0ueCCC8y+nZOTY94D/THxxRdfiNfQYQbwjHXr1jnOOussR8uWLR2hoaGOhIQEx0033eTYvXu31VXzaXPnztWhNFVe0LgmTJhQZbt/9dVXVlfNp7z88suOjh07OkJCQhzDhg1zrFq1yuoq+TTdf6var3V/R+Op7jiux3g0nsmTJzs6depkji8xMTGOc845x7F06VKHN2HMLAAAAGyLgYMAAACwLYJZAAAA2BbBLAAAAGyLYBYAAAC2RTALAAAA2yKYBQAAgG0RzAIAAMC2CGYBAABgWwSzAAAAsC2CWQDwAhMnTpQxY8Z49DkTExOlRYsWHn1OAGhoBLMAAACwLYJZAPAyI0eOlNtvv13uvfdeadmypcTFxckjjzxSYZsmTZrIrFmz5IILLpDw8HDp0qWLvP/++677V6xYYbbJzMx03fbTTz+Z23bs2GHunzRpkmRlZZnb9FL5OQDADghmAcALzZs3T5o2bSo//PCDPPPMM/Loo4/KsmXLKmzz4IMPyhVXXCE///yzjBs3Tq699lr59ddfa/X4I0aMkBdffFGaN28u+/btM5e//vWvjfRqAKDxEMwCgBfq16+fPPzww9K9e3cZP368DBkyRJYvX15hm6uuukqmTJkiPXr0kMcee8xs8/LLL9fq8UNCQiQqKsr0yGrPr16aNWvWSK8GABoPwSwAeGkw665t27Zy4MCBCrcNHz78mHJte2YBwFcQzAKAFwoODq5Q1h7UsrKyWv99QED54d3hcLhuKy4ubsAaAoB3IJgFAJtatWrVMeXevXub6zExMeZ/HQvrPgGs8lCD0tJSj9QVABoLwSwA2NR7770nc+bMkd9++82Mr01KSpKpU6ea+7p16yYdOnQwGQq2bNkin376qcyYMaPC3yckJEhubq4Zi5uRkSH5+fkWvRIAqD+CWQCwqenTp8uCBQvM+No33nhD3nnnHTnppJNcwxS0vGnTJnP/008/LY8//vgxGQ1uuukmueaaa0xPrmZNAAC7aeJwH1AFALAFHUO7aNEij68aBgDehp5ZAAAA2BbBLAAAAGwryOoKAADqjhFiAFCOnlkAAADYFsEsAAAAbItgFgAAALZFMAsAAADbIpgFAACAbRHMAgAAwLYIZgEAAGBbBLMAAAAQu/r/Ewi2IXDFPmsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 800x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "class LeakyReLU(brainstate.nn.Module):\n",
    "    \"\"\"Leaky ReLU activation: y = max(alpha * x, x)\"\"\"\n",
    "    \n",
    "    def __init__(self, negative_slope=0.01):\n",
    "        super().__init__()\n",
    "        self.negative_slope = negative_slope\n",
    "    \n",
    "    def update(self, x):\n",
    "        return jnp.where(x > 0, x, self.negative_slope * x)\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return f\"LeakyReLU(negative_slope={self.negative_slope})\"\n",
    "\n",
    "# Test the activation\n",
    "activation = LeakyReLU(negative_slope=0.1)\n",
    "x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])\n",
    "y = activation(x)\n",
    "\n",
    "print(\"Activation:\", activation)\n",
    "print(f\"Input:  {x}\")\n",
    "print(f\"Output: {y}\")\n",
    "\n",
    "# Visualize\n",
    "x_plot = jnp.linspace(-3, 3, 100)\n",
    "y_plot = activation(x_plot)\n",
    "\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.plot(x_plot, y_plot, linewidth=2, label='LeakyReLU(0.1)')\n",
    "plt.axhline(0, color='gray', linestyle='--', alpha=0.3)\n",
    "plt.axvline(0, color='gray', linestyle='--', alpha=0.3)\n",
    "plt.grid(alpha=0.3)\n",
    "plt.xlabel('Input')\n",
    "plt.ylabel('Output')\n",
    "plt.title('Leaky ReLU Activation Function')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
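  {
   "cell_type": "markdown",
   "id": "leaky_relu_formula",
   "metadata": {},
   "source": [
    "As a worked check, the `update()` method above implements the piecewise form below, writing $\\alpha$ for `negative_slope` (the symbol is only a shorthand introduced here):\n",
    "\n",
    "$$\n",
    "f(x) = \\begin{cases} x, & x > 0 \\\\ \\alpha x, & x \\le 0 \\end{cases}\n",
    "$$\n",
    "\n",
    "For $0 \\le \\alpha \\le 1$ this coincides with $\\max(\\alpha x, x)$. With $\\alpha = 0.1$, the test input $[-2, -1, 0, 1, 2]$ maps to $[-0.2, -0.1, 0, 1, 2]$, up to floating-point rounding."
   ]
  },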
  {
   "cell_type": "markdown",
   "id": "composition",
   "metadata": {},
   "source": [
    "## 3. Module Composition and Nesting\n",
    "\n",
    "The real power of modules comes from composing them into larger networks.\n",
    "\n",
    "### Sequential Composition\n",
    "\n",
    "Build a network by stacking layers sequentially:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "sequential",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:11.901818Z",
     "start_time": "2025-10-11T08:24:11.078282Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP Architecture:\n",
      "MLP(\n",
      "  layers=[\n",
      "    Linear(in_features=10, out_features=64, use_bias=True),\n",
      "    LeakyReLU(negative_slope=0.0),\n",
      "    Linear(in_features=64, out_features=32, use_bias=True),\n",
      "    LeakyReLU(negative_slope=0.0),\n",
      "    Linear(in_features=32, out_features=5, use_bias=True)\n",
      "  ],\n",
      "  layer_0=Linear(in_features=10, out_features=64, use_bias=True),\n",
      "  activation_0=LeakyReLU(negative_slope=0.0),\n",
      "  layer_1=Linear(in_features=64, out_features=32, use_bias=True),\n",
      "  activation_1=LeakyReLU(negative_slope=0.0),\n",
      "  layer_2=Linear(in_features=32, out_features=5, use_bias=True)\n",
      ")\n",
      "\n",
      "Input shape: (10,)\n",
      "Output shape: (5,)\n",
      "Output: [-0.49218872  0.5558434  -0.6296929   0.25295696  0.37388656]\n"
     ]
    }
   ],
   "source": [
    "class MLP(brainstate.nn.Module):\n",
    "    \"\"\"Multi-layer perceptron with customizable architecture.\"\"\"\n",
    "    \n",
    "    def __init__(self, layer_sizes, activation='relu'):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.layers = []\n",
    "        \n",
    "        # Create layers\n",
    "        for i in range(len(layer_sizes) - 1):\n",
    "            # Add linear layer\n",
    "            layer = Linear(layer_sizes[i], layer_sizes[i+1])\n",
    "            setattr(self, f'layer_{i}', layer)  # Register as attribute\n",
    "            self.layers.append(layer)\n",
    "            \n",
    "            # Add activation (except for last layer)\n",
    "            if i < len(layer_sizes) - 2:\n",
    "                if activation == 'relu':\n",
    "                    act = LeakyReLU(negative_slope=0.0)  # Standard ReLU\n",
    "                else:\n",
    "                    act = LeakyReLU(negative_slope=0.01)\n",
    "                setattr(self, f'activation_{i}', act)\n",
    "                self.layers.append(act)\n",
    "    \n",
    "    def update(self, x):\n",
    "        \"\"\"Forward pass through all layers.\"\"\"\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "        return x\n",
    "\n",
    "# Create a 3-layer MLP\n",
    "brainstate.random.seed(0)\n",
    "mlp = MLP(layer_sizes=[10, 64, 32, 5])\n",
    "\n",
    "# Forward pass\n",
    "x = brainstate.random.randn(10)\n",
    "y = mlp(x)\n",
    "\n",
    "print(\"MLP Architecture:\")\n",
    "print(mlp)\n",
    "print(f\"\\nInput shape: {x.shape}\")\n",
    "print(f\"Output shape: {y.shape}\")\n",
    "print(f\"Output: {y}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "residual",
   "metadata": {},
   "source": [
    "### Residual Connections\n",
    "\n",
    "Implement skip connections for deeper networks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "residual_block",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:14.082398Z",
     "start_time": "2025-10-11T08:24:13.803132Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ResNet:\n",
      "ResNet(\n",
      "  input_proj=Linear(in_features=10, out_features=32, use_bias=True),\n",
      "  blocks=[\n",
      "    ResidualBlock(\n",
      "      linear1=Linear(in_features=32, out_features=32, use_bias=True),\n",
      "      activation=LeakyReLU(negative_slope=0.0),\n",
      "      linear2=Linear(in_features=32, out_features=32, use_bias=True)\n",
      "    ),\n",
      "    ResidualBlock(\n",
      "      linear1=Linear(in_features=32, out_features=32, use_bias=True),\n",
      "      activation=LeakyReLU(negative_slope=0.0),\n",
      "      linear2=Linear(in_features=32, out_features=32, use_bias=True)\n",
      "    ),\n",
      "    ResidualBlock(\n",
      "      linear1=Linear(in_features=32, out_features=32, use_bias=True),\n",
      "      activation=LeakyReLU(negative_slope=0.0),\n",
      "      linear2=Linear(in_features=32, out_features=32, use_bias=True)\n",
      "    )\n",
      "  ],\n",
      "  block_0=ResidualBlock(...),\n",
      "  block_1=ResidualBlock(...),\n",
      "  block_2=ResidualBlock(...),\n",
      "  output_proj=Linear(in_features=32, out_features=5, use_bias=True)\n",
      ")\n",
      "\n",
      "Output shape: (5,)\n"
     ]
    }
   ],
   "source": [
    "class ResidualBlock(brainstate.nn.Module):\n",
    "    \"\"\"Residual block: y = F(x) + x\"\"\"\n",
    "    \n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Two linear layers with activation in between\n",
    "        self.linear1 = Linear(dim, dim)\n",
    "        self.activation = LeakyReLU(0.0)\n",
    "        self.linear2 = Linear(dim, dim)\n",
    "    \n",
    "    def update(self, x):\n",
    "        # Compute residual\n",
    "        residual = x\n",
    "        \n",
    "        # Forward through layers\n",
    "        out = self.linear1(x)\n",
    "        out = self.activation(out)\n",
    "        out = self.linear2(out)\n",
    "        \n",
    "        # Add residual\n",
    "        return out + residual\n",
    "\n",
    "class ResNet(brainstate.nn.Module):\n",
    "    \"\"\"Simple ResNet with multiple residual blocks.\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, n_blocks=3):\n",
    "        super().__init__()\n",
    "        \n",
    "        # Input projection\n",
    "        self.input_proj = Linear(input_dim, hidden_dim)\n",
    "        \n",
    "        # Residual blocks\n",
    "        self.blocks = []\n",
    "        for i in range(n_blocks):\n",
    "            block = ResidualBlock(hidden_dim)\n",
    "            setattr(self, f'block_{i}', block)\n",
    "            self.blocks.append(block)\n",
    "        \n",
    "        # Output projection\n",
    "        self.output_proj = Linear(hidden_dim, output_dim)\n",
    "    \n",
    "    def update(self, x):\n",
    "        # Project to hidden dimension\n",
    "        x = self.input_proj(x)\n",
    "        \n",
    "        # Pass through residual blocks\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        \n",
    "        # Project to output\n",
    "        x = self.output_proj(x)\n",
    "        return x\n",
    "\n",
    "# Create ResNet\n",
    "brainstate.random.seed(0)\n",
    "resnet = ResNet(input_dim=10, hidden_dim=32, output_dim=5, n_blocks=3)\n",
    "\n",
    "# Forward pass\n",
    "x = brainstate.random.randn(10)\n",
    "y = resnet(x)\n",
    "\n",
    "print(\"ResNet:\")\n",
    "print(resnet)\n",
    "print(f\"\\nOutput shape: {y.shape}\")"
   ]
  },
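  {
   "cell_type": "markdown",
   "id": "residual_formula",
   "metadata": {},
   "source": [
    "Written as an equation, the block above computes the following, where $\\sigma$ is the activation and $\\mathrm{Linear}_1$, $\\mathrm{Linear}_2$ stand for `linear1` and `linear2` (this notation is introduced here only for illustration):\n",
    "\n",
    "$$\n",
    "y = F(x) + x, \\qquad F(x) = \\mathrm{Linear}_2\\bigl(\\sigma(\\mathrm{Linear}_1(x))\\bigr)\n",
    "$$\n",
    "\n",
    "Since the block uses `LeakyReLU(0.0)`, $\\sigma(z) = \\max(z, 0)$ for finite $z$. The skip connection adds $F(x)$ and $x$ elementwise, so both need the same shape; that is why both linear layers map `dim` to `dim` and why `ResNet` projects the input to `hidden_dim` before the first block."
   ]
  },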
  {
   "cell_type": "markdown",
   "id": "size_inference",
   "metadata": {},
   "source": [
    "## 4. Automatic Input/Output Size Inference\n",
    "\n",
    "One of BrainState's most powerful features is **automatic input/output size inference**. Every `brainstate.nn.Module` instance has `in_size` and `out_size` properties that track the shape of data flowing through the module (excluding the batch dimension).\n",
    "\n",
    "### Key Concepts\n",
    "\n",
    "✅ **`in_size`**: Input shape without batch dimension  \n",
    "✅ **`out_size`**: Output shape without batch dimension (automatically inferred)  \n",
    "✅ **Automatic propagation**: When `in_size` is known, `out_size` is computed automatically  \n",
    "✅ **Sequential composition**: Output size of one layer becomes input size of next layer\n",
    "\n",
    "This mechanism eliminates the need to manually calculate dimensions through network layers, making it much easier to build complex architectures."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "size_basic",
   "metadata": {},
   "source": [
    "### Example 1: Basic Size Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "size_basic_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:14.681746Z",
     "start_time": "2025-10-11T08:24:14.087343Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Layer: Linear(\n",
      "  in_size=(10,),\n",
      "  out_size=(5,),\n",
      "  w_mask=None,\n",
      "  weight=ParamState(\n",
      "    value={\n",
      "      'bias': ShapedArray(float32[5]),\n",
      "      'weight': ShapedArray(float32[10,5])\n",
      "    }\n",
      "  )\n",
      ")\n",
      "Input size:  (10,)\n",
      "Output size: (5,)\n",
      "\n",
      "Input shape:  (32, 10)  (batch_size=32, in_features=10)\n",
      "Output shape: (32, 5)  (batch_size=32, out_features=5)\n",
      "\n",
      "Note: in_size and out_size DO NOT include the batch dimension!\n"
     ]
    }
   ],
   "source": [
    "# Create a linear layer with explicit in_size and out_size\n",
    "layer = brainstate.nn.Linear(in_size=(10,), out_size=(5,))\n",
    "\n",
    "print(\"Layer:\", layer)\n",
    "print(f\"Input size:  {layer.in_size}\")\n",
    "print(f\"Output size: {layer.out_size}\")\n",
    "\n",
    "# Forward pass with batch dimension\n",
    "x = brainstate.random.randn(32, 10)  # (batch_size, in_features)\n",
    "y = layer(x)\n",
    "\n",
    "print(f\"\\nInput shape:  {x.shape}  (batch_size=32, in_features=10)\")\n",
    "print(f\"Output shape: {y.shape}  (batch_size=32, out_features=5)\")\n",
    "print(\"\\nNote: in_size and out_size DO NOT include the batch dimension!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "size_conv",
   "metadata": {},
   "source": [
    "### Example 2: Size Inference with Convolution\n",
    "\n",
    "Convolution layers automatically compute output spatial dimensions based on:\n",
    "- Input spatial size\n",
    "- Kernel size\n",
    "- Stride\n",
    "- Padding mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "size_conv_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:15.016234Z",
     "start_time": "2025-10-11T08:24:14.686828Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conv2d Layer:\n",
      "  in_size:  (28, 28, 3)\n",
      "  out_size: (28, 28, 32)\n",
      "\n",
      "  Input:  (H, W, C) = (28, 28, 3)\n",
      "  Output: (H', W', C') = (28, 28, 32)\n",
      "\n",
      "With 'SAME' padding and stride=1, spatial dimensions are preserved!\n",
      "\n",
      "With 'VALID' padding and stride=2:\n",
      "  in_size:  (28, 28, 3)\n",
      "  out_size: (13, 13, 32)\n",
      "  Spatial dimensions are reduced!\n"
     ]
    }
   ],
   "source": [
    "# Create a 2D convolution layer\n",
    "conv = brainstate.nn.Conv2d(\n",
    "    in_size=(28, 28, 3),      # (height, width, channels)\n",
    "    out_channels=32,\n",
    "    kernel_size=3,\n",
    "    stride=1,\n",
    "    padding='SAME'\n",
    ")\n",
    "\n",
    "print(\"Conv2d Layer:\")\n",
    "print(f\"  in_size:  {conv.in_size}\")\n",
    "print(f\"  out_size: {conv.out_size}\")\n",
    "print(f\"\\n  Input:  (H, W, C) = {conv.in_size}\")\n",
    "print(f\"  Output: (H', W', C') = {conv.out_size}\")\n",
    "print(\"\\nWith 'SAME' padding and stride=1, spatial dimensions are preserved!\")\n",
    "\n",
    "# Test with different padding\n",
    "conv_valid = brainstate.nn.Conv2d(\n",
    "    in_size=(28, 28, 3),\n",
    "    out_channels=32,\n",
    "    kernel_size=3,\n",
    "    stride=2,\n",
    "    padding='VALID'\n",
    ")\n",
    "\n",
    "print(f\"\\nWith 'VALID' padding and stride=2:\")\n",
    "print(f\"  in_size:  {conv_valid.in_size}\")\n",
    "print(f\"  out_size: {conv_valid.out_size}\")\n",
    "print(\"  Spatial dimensions are reduced!\")"
   ]
  },
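  {
   "cell_type": "markdown",
   "id": "size_conv_arithmetic",
   "metadata": {},
   "source": [
    "The sizes printed above can be checked by hand. The following is a sketch of the usual convolution arithmetic, assuming BrainState follows the JAX/XLA `'SAME'` and `'VALID'` conventions with no dilation (an assumption not stated in this tutorial, although it agrees with the printed output). Here $n$ is the input extent along one spatial axis, $k$ the kernel size, and $s$ the stride:\n",
    "\n",
    "$$\n",
    "n_{\\text{out}}^{\\text{SAME}} = \\left\\lceil \\frac{n}{s} \\right\\rceil, \\qquad n_{\\text{out}}^{\\text{VALID}} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1\n",
    "$$\n",
    "\n",
    "For the two layers above: $\\lceil 28 / 1 \\rceil = 28$ with `'SAME'` and stride 1, and $\\lfloor (28 - 3) / 2 \\rfloor + 1 = 12 + 1 = 13$ with `'VALID'` and stride 2. The channel dimension is set by `out_channels`, giving `(28, 28, 32)` and `(13, 13, 32)`."
   ]
  },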
  {
   "cell_type": "markdown",
   "id": "size_pooling",
   "metadata": {},
   "source": [
    "### Example 3: Size Inference with Pooling and Flatten\n",
    "\n",
    "Pooling layers reduce spatial dimensions, and Flatten layers convert multi-dimensional tensors to 1D vectors. BrainState tracks all these transformations automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "size_pooling_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:15.028083Z",
     "start_time": "2025-10-11T08:24:15.021240Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MaxPool2d Layer:\n",
      "  in_size:  (28, 28, 32)  (H=28, W=28, C=32)\n",
      "  out_size: (14, 14, 32)  (H=14, W=14, C=32)\n",
      "  Spatial dimensions reduced by 2x!\n",
      "\n",
      "Flatten Layer:\n",
      "  in_size:  (14, 14, 32)  (3D tensor)\n",
      "  out_size: (6272,)  (1D vector)\n",
      "  Total elements: 6272 = 6272\n"
     ]
    }
   ],
   "source": [
    "# MaxPool reduces spatial dimensions\n",
    "pool = brainstate.nn.MaxPool2d(\n",
    "    in_size=(28, 28, 32),\n",
    "    kernel_size=(2, 2),\n",
    "    stride=(2, 2),\n",
    "    channel_axis=-1\n",
    ")\n",
    "\n",
    "print(\"MaxPool2d Layer:\")\n",
    "print(f\"  in_size:  {pool.in_size}  (H=28, W=28, C=32)\")\n",
    "print(f\"  out_size: {pool.out_size}  (H=14, W=14, C=32)\")\n",
    "print(\"  Spatial dimensions reduced by 2x!\")\n",
    "\n",
    "# Flatten converts to 1D\n",
    "flatten = brainstate.nn.Flatten(in_size=(14, 14, 32))\n",
    "\n",
    "print(f\"\\nFlatten Layer:\")\n",
    "print(f\"  in_size:  {flatten.in_size}  (3D tensor)\")\n",
    "print(f\"  out_size: {flatten.out_size}  (1D vector)\")\n",
    "print(f\"  Total elements: {14 * 14 * 32} = {flatten.out_size[0]}\")"
   ]
  },
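  {
   "cell_type": "markdown",
   "id": "size_pooling_arithmetic",
   "metadata": {},
   "source": [
    "The pooling and flatten sizes above can be checked by hand. As a sketch, assuming the pooling window is applied without padding (an assumption; this tutorial does not state the padding mode), an input extent $n$ with window $k$ and stride $s$ gives:\n",
    "\n",
    "$$\n",
    "n_{\\text{out}} = \\left\\lfloor \\frac{n - k}{s} \\right\\rfloor + 1 = \\left\\lfloor \\frac{28 - 2}{2} \\right\\rfloor + 1 = 14\n",
    "$$\n",
    "\n",
    "Flatten then multiplies the remaining dimensions: $14 \\times 14 \\times 32 = 6272$, which matches the printed `out_size` of `(6272,)`."
   ]
  },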
  {
   "cell_type": "markdown",
   "id": "sequential_intro",
   "metadata": {},
   "source": [
    "## 5. Sequential Composition and Deep Networks\n",
    "\n",
    "`brainstate.nn.Sequential` is a powerful container that chains multiple modules together. It automatically propagates `out_size` from one layer to the `in_size` of the next layer, enabling effortless construction of deep networks.\n",
    "\n",
    "### The `.desc()` Pattern\n",
    "\n",
    "For layers that need to infer their `in_size` from the previous layer, BrainState provides the `.desc()` method, which creates a **layer descriptor** that will be instantiated when the input size becomes available.\n",
    "\n",
    "```python\n",
    "# Instead of:\n",
    "brainstate.nn.Linear(in_size=(10,), out_size=(5,))\n",
    "\n",
    "# Use descriptor in Sequential:\n",
    "brainstate.nn.Linear.desc(out_size=5)  # in_size will be inferred!\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sequential_basic",
   "metadata": {},
   "source": [
    "### Example 1: Simple Sequential Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "sequential_basic_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:16.027483Z",
     "start_time": "2025-10-11T08:24:15.053887Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential MLP:\n",
      "Sequential(\n",
      "  in_size=(10,),\n",
      "  out_size=(5,),\n",
      "  layers=[\n",
      "    Linear(\n",
      "      in_size=(10,),\n",
      "      out_size=(64,),\n",
      "      w_mask=None,\n",
      "      weight=ParamState(\n",
      "        value={\n",
      "          'bias': ShapedArray(float32[64]),\n",
      "          'weight': ShapedArray(float32[10,64])\n",
      "        }\n",
      "      )\n",
      "    ),\n",
      "    ReLU(),\n",
      "    Linear(\n",
      "      in_size=(64,),\n",
      "      out_size=(32,),\n",
      "      w_mask=None,\n",
      "      weight=ParamState(\n",
      "        value={\n",
      "          'bias': ShapedArray(float32[32]),\n",
      "          'weight': ShapedArray(float32[64,32])\n",
      "        }\n",
      "      )\n",
      "    ),\n",
      "    ReLU(),\n",
      "    Linear(\n",
      "      in_size=(32,),\n",
      "      out_size=(5,),\n",
      "      w_mask=None,\n",
      "      weight=ParamState(\n",
      "        value={\n",
      "          'bias': ShapedArray(float32[5]),\n",
      "          'weight': ShapedArray(float32[32,5])\n",
      "        }\n",
      "      )\n",
      "    )\n",
      "  ]\n",
      ")\n",
      "\n",
      "Input size:  (10,)\n",
      "Output size: (5,)\n",
      "\n",
      "Forward pass:\n",
      "  Input:  (8, 10)\n",
      "  Output: (8, 5)\n"
     ]
    }
   ],
   "source": [
    "# Build a simple MLP with Sequential\n",
    "brainstate.random.seed(42)\n",
    "\n",
    "mlp = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear((10,), (64,)),        # First layer needs explicit in_size\n",
    "    brainstate.nn.ReLU(),                      # Element-wise, preserves shape\n",
    "    brainstate.nn.Linear.desc(out_size=32),    # in_size inferred from previous layer\n",
    "    brainstate.nn.ReLU(),\n",
    "    brainstate.nn.Linear.desc(out_size=5)      # Final output layer\n",
    ")\n",
    "\n",
    "print(\"Sequential MLP:\")\n",
    "print(mlp)\n",
    "print(f\"\\nInput size:  {mlp.in_size}\")\n",
    "print(f\"Output size: {mlp.out_size}\")\n",
    "\n",
    "# Test forward pass\n",
    "x = brainstate.random.randn(8, 10)  # batch of 8 samples\n",
    "y = mlp(x)\n",
    "print(f\"\\nForward pass:\")\n",
    "print(f\"  Input:  {x.shape}\")\n",
    "print(f\"  Output: {y.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cnn_example",
   "metadata": {},
   "source": [
    "### Example 2: CNN Network with Automatic Size Propagation\n",
    "\n",
    "Let's build a complete CNN for image classification, demonstrating how `in_size` and `out_size` propagate through convolutional, pooling, flattening, and fully-connected layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cnn_example_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:17.445590Z",
     "start_time": "2025-10-11T08:24:16.065619Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN Network Architecture:\n",
      "CNNNet(\n",
      "  layer=Sequential(\n",
      "    in_size=(28, 28, 3),\n",
      "    out_size=(10,),\n",
      "    layers=[\n",
      "      Conv2d(\n",
      "        in_size=(28, 28, 3),\n",
      "        out_size=(28, 28, 32),\n",
      "        channel_first=False,\n",
      "        channels_last=True,\n",
      "        in_channels=3,\n",
      "        out_channels=32,\n",
      "        stride=(1, 1),\n",
      "        kernel_size=(3, 3),\n",
      "        lhs_dilation=(1, 1),\n",
      "        rhs_dilation=(1, 1),\n",
      "        groups=1,\n",
      "        dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),\n",
      "        padding=SAME,\n",
      "        kernel_shape=(3, 3, 3, 32),\n",
      "        w_mask=None,\n",
      "        w_initializer=XavierNormal(\n",
      "          scale=1.0,\n",
      "          mode='fan_avg',\n",
      "          in_axis=-2,\n",
      "          out_axis=-1,\n",
      "          distribution='truncated_normal',\n",
      "          rng=RandomState([1825841970 3512247751]),\n",
      "          unit=Unit(10.0^0)\n",
      "        ),\n",
      "        b_initializer=None,\n",
      "        weight=ParamState(\n",
      "          value={\n",
      "            'weight': ShapedArray(float32[3,3,3,32])\n",
      "          }\n",
      "        )\n",
      "      ),\n",
      "      ReLU(),\n",
      "      MaxPool2d(\n",
      "        in_size=(28, 28, 32),\n",
      "        out_size=(14, 14, 32),\n",
      "        init_value=-inf,\n",
      "        computation=<function max at 0x0000012125AD8360>,\n",
      "        pool_dim=2,\n",
      "        return_indices=False,\n",
      "        kernel_size=(2, 2),\n",
      "        stride=(2, 2),\n",
      "        padding=VALID,\n",
      "        channel_axis=-1\n",
      "      ),\n",
      "      Conv2d(\n",
      "        in_size=(14, 14, 32),\n",
      "        out_size=(14, 14, 64),\n",
      "        channel_first=False,\n",
      "        channels_last=True,\n",
      "        in_channels=32,\n",
      "        out_channels=64,\n",
      "        stride=(1, 1),\n",
      "        kernel_size=(3, 3),\n",
      "        lhs_dilation=(1, 1),\n",
      "        rhs_dilation=(1, 1),\n",
      "        groups=1,\n",
      "        dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),\n",
      "        padding=SAME,\n",
      "        kernel_shape=(3, 3, 32, 64),\n",
      "        w_mask=None,\n",
      "        w_initializer=XavierNormal(\n",
      "          scale=1.0,\n",
      "          mode='fan_avg',\n",
      "          in_axis=-2,\n",
      "          out_axis=-1,\n",
      "          distribution='truncated_normal',\n",
      "          rng=RandomState([1825841970 3512247751]),\n",
      "          unit=Unit(10.0^0)\n",
      "        ),\n",
      "        b_initializer=None,\n",
      "        weight=ParamState(\n",
      "          value={\n",
      "            'weight': ShapedArray(float32[3,3,32,64])\n",
      "          }\n",
      "        )\n",
      "      ),\n",
      "      ReLU(),\n",
      "      MaxPool2d(\n",
      "        in_size=(14, 14, 64),\n",
      "        out_size=(7, 7, 64),\n",
      "        init_value=-inf,\n",
      "        computation=<function max at 0x0000012125AD8360>,\n",
      "        pool_dim=2,\n",
      "        return_indices=False,\n",
      "        kernel_size=(2, 2),\n",
      "        stride=(2, 2),\n",
      "        padding=VALID,\n",
      "        channel_axis=-1\n",
      "      ),\n",
      "      Flatten(\n",
      "        in_size=(7, 7, 64),\n",
      "        out_size=(3136,),\n",
      "        start_axis=0,\n",
      "        end_axis=-1\n",
      "      ),\n",
      "      Linear(\n",
      "        in_size=(3136,),\n",
      "        out_size=(1024,),\n",
      "        w_mask=None,\n",
      "        weight=ParamState(\n",
      "          value={\n",
      "            'bias': ShapedArray(float32[1024]),\n",
      "            'weight': ShapedArray(float32[3136,1024])\n",
      "          }\n",
      "        )\n",
      "      ),\n",
      "      ReLU(),\n",
      "      Linear(\n",
      "        in_size=(1024,),\n",
      "        out_size=(512,),\n",
      "        w_mask=None,\n",
      "        weight=ParamState(\n",
      "          value={\n",
      "            'bias': ShapedArray(float32[512]),\n",
      "            'weight': ShapedArray(float32[1024,512])\n",
      "          }\n",
      "        )\n",
      "      ),\n",
      "      ReLU(),\n",
      "      Linear(\n",
      "        in_size=(512,),\n",
      "        out_size=(10,),\n",
      "        w_mask=None,\n",
      "        weight=ParamState(\n",
      "          value={\n",
      "            'bias': ShapedArray(float32[10]),\n",
      "            'weight': ShapedArray(float32[512,10])\n",
      "          }\n",
      "        )\n",
      "      )\n",
      "    ]\n",
      "  )\n",
      ")\n",
      "\n",
      "Network input size:  None\n",
      "Network output size: None\n",
      "\n",
      "============================================================\n",
      "Size transformations through the network:\n",
      "============================================================\n",
      "Layer  0 (Conv2d         ): (28, 28, 3)          -> (28, 28, 32)        \n",
      "Layer  1 (ReLU           ): None                 -> None                \n",
      "Layer  2 (MaxPool2d      ): (28, 28, 32)         -> (14, 14, 32)        \n",
      "Layer  3 (Conv2d         ): (14, 14, 32)         -> (14, 14, 64)        \n",
      "Layer  4 (ReLU           ): None                 -> None                \n",
      "Layer  5 (MaxPool2d      ): (14, 14, 64)         -> (7, 7, 64)          \n",
      "Layer  6 (Flatten        ): (7, 7, 64)           -> (3136,)             \n",
      "Layer  7 (Linear         ): (3136,)              -> (1024,)             \n",
      "Layer  8 (ReLU           ): None                 -> None                \n",
      "Layer  9 (Linear         ): (1024,)              -> (512,)              \n",
      "Layer 10 (ReLU           ): None                 -> None                \n",
      "Layer 11 (Linear         ): (512,)               -> (10,)               \n"
     ]
    }
   ],
   "source": [
    "class CNNNet(brainstate.nn.Module):\n",
    "    \"\"\"Convolutional Neural Network for image classification.\"\"\"\n",
    "    \n",
    "    def __init__(self, in_size):\n",
    "        super().__init__()\n",
    "        self.layer = brainstate.nn.Sequential(\n",
    "            # Convolutional block 1\n",
    "            brainstate.nn.Conv2d(in_size, out_channels=32, kernel_size=(3, 3), \n",
    "                               stride=(1, 1), padding='SAME'),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),\n",
    "            \n",
    "            # Convolutional block 2\n",
    "            brainstate.nn.Conv2d.desc(out_channels=64, kernel_size=(3, 3), \n",
    "                                    stride=(1, 1), padding='SAME'),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),\n",
    "            \n",
    "            # Flatten and fully-connected layers\n",
    "            brainstate.nn.Flatten.desc(),\n",
    "            brainstate.nn.Linear.desc(out_size=1024),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.Linear.desc(out_size=512),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.Linear.desc(out_size=10)\n",
    "        )\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.layer(x)\n",
    "\n",
    "# Create CNN with image size (28, 28, 3)\n",
    "example_image = brainstate.random.normal(size=(28, 28, 3))\n",
    "cnn = CNNNet(example_image.shape)\n",
    "\n",
    "print(\"CNN Network Architecture:\")\n",
    "print(cnn)\n",
    "print(f\"\\nNetwork input size:  {cnn.in_size}\")\n",
    "print(f\"Network output size: {cnn.out_size}\")\n",
    "\n",
    "# Trace size transformations through the network\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"Size transformations through the network:\")\n",
    "print(\"=\"*60)\n",
    "for i, layer in enumerate(cnn.layer.layers):\n",
    "    if hasattr(layer, 'in_size') and hasattr(layer, 'out_size'):\n",
    "        print(f\"Layer {i:2d} ({layer.__class__.__name__:15s}): \"\n",
    "              f\"{str(layer.in_size):20s} -> {str(layer.out_size):20s}\")"
   ]
  },
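  {
   "cell_type": "markdown",
   "id": "cnn_size_arithmetic",
   "metadata": {},
   "source": [
    "The sizes in the trace above can be reproduced by hand. The sketch below is plain Python rather than BrainState code: `spatial_out` and `trace_cnn` are hypothetical helpers, and the padding rules are the usual `SAME`/`VALID` conventions, assumed here to match what the layers apply.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def spatial_out(size, kernel, stride, padding):\n",
    "    # Output length along one spatial axis (assumed standard rules).\n",
    "    if padding == 'SAME':\n",
    "        return math.ceil(size / stride)          # padded, so only the stride matters\n",
    "    if padding == 'VALID':\n",
    "        return (size - kernel) // stride + 1     # no padding\n",
    "    raise ValueError(f'unsupported padding: {padding!r}')\n",
    "\n",
    "def trace_cnn(in_size):\n",
    "    # Mirror CNNNet: two (3x3 SAME conv, 2x2 stride-2 VALID pool) blocks, then flatten.\n",
    "    h, w, c = in_size\n",
    "    sizes = [tuple(in_size)]\n",
    "    for out_channels in (32, 64):\n",
    "        h, w = spatial_out(h, 3, 1, 'SAME'), spatial_out(w, 3, 1, 'SAME')\n",
    "        c = out_channels\n",
    "        sizes.append((h, w, c))                  # after Conv2d\n",
    "        h, w = spatial_out(h, 2, 2, 'VALID'), spatial_out(w, 2, 2, 'VALID')\n",
    "        sizes.append((h, w, c))                  # after MaxPool2d\n",
    "    sizes.append((h * w * c,))                   # after Flatten\n",
    "    return sizes\n",
    "\n",
    "for size in trace_cnn((28, 28, 3)):\n",
    "    print(size)\n",
    "```\n",
    "\n",
    "The last entry, 7 × 7 × 64 = 3136, is the `in_size` that the first `Linear` layer infers."
   ]
  },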
  {
   "cell_type": "markdown",
   "id": "cnn_forward",
   "metadata": {},
   "source": [
    "### Example 3: Forward Pass Through CNN\n",
    "\n",
    "Let's actually run a forward pass and see how data flows through the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "cnn_forward_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:18.048845Z",
     "start_time": "2025-10-11T08:24:17.508357Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch shape: (4, 28, 28, 3)\n",
      "  (batch_size, height, width, channels) = (4, 28, 28, 3)\n",
      "\n",
      "Output shape: (4, 10)\n",
      "  (batch_size, num_classes) = (4, 10)\n",
      "\n",
      "Output logits for first sample:\n",
      "[ 0.13889284  0.49220082 -0.6353385   0.36826375 -0.55741405 -0.22296685\n",
      "  1.5445015   0.7295152   0.04205686 -0.02874903]\n"
     ]
    }
   ],
   "source": [
    "# Create a batch of images\n",
    "batch_size = 4\n",
    "batch_images = brainstate.random.normal(size=(batch_size, 28, 28, 3))\n",
    "\n",
    "print(f\"Input batch shape: {batch_images.shape}\")\n",
    "print(f\"  (batch_size, height, width, channels) = ({batch_size}, 28, 28, 3)\")\n",
    "\n",
    "# Forward pass\n",
    "output = cnn(batch_images)\n",
    "\n",
    "print(f\"\\nOutput shape: {output.shape}\")\n",
    "print(f\"  (batch_size, num_classes) = ({batch_size}, 10)\")\n",
    "print(f\"\\nOutput logits for first sample:\")\n",
    "print(output[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sequential_benefits",
   "metadata": {},
   "source": [
    "### Benefits of Automatic Size Inference\n",
    "\n",
    "The automatic `in_size`/`out_size` inference system provides several key advantages:\n",
    "\n",
    "1. **🎯 No manual dimension calculations**: You don't need to compute output sizes after each layer\n",
    "2. **🔧 Easy architecture modifications**: Change one layer without updating all subsequent layers\n",
    "3. **🐛 Early error detection**: Shape mismatches are caught at construction time\n",
    "4. **📊 Built-in documentation**: Network architecture is self-documenting with size information\n",
    "5. **🚀 Rapid prototyping**: Quickly experiment with different architectures\n",
    "\n",
    "### Key Pattern: `.desc()` for Layer Descriptors\n",
    "\n",
    "When building networks with `Sequential`, use the `.desc()` pattern for all layers except the first:\n",
    "\n",
    "```python\n",
    "brainstate.nn.Sequential(\n",
    "    FirstLayer(in_size, ...),         # Explicit in_size\n",
    "    SecondLayer.desc(...),            # in_size inferred\n",
    "    ThirdLayer.desc(...),             # in_size inferred\n",
    "    # ...\n",
    ")\n",
    "```\n",
    "\n",
    "This pattern ensures that:\n",
    "- The first layer knows the input size\n",
    "- All subsequent layers automatically infer their input sizes\n",
    "- The network construction is clean and maintainable"
   ]
  },
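  {
   "cell_type": "markdown",
   "id": "desc_toy_sketch",
   "metadata": {},
   "source": [
    "One way to picture the `.desc()` pattern is as deferred construction: the descriptor stores every argument except `in_size`, and the container supplies that argument from the previous layer's `out_size`. The toy below illustrates the idea in plain Python. It is not BrainState's implementation, and `ToyLinear` and `build_sequential` are hypothetical names.\n",
    "\n",
    "```python\n",
    "from functools import partial\n",
    "\n",
    "class ToyLinear:\n",
    "    # Stand-in layer that only records its sizes (no weights).\n",
    "    def __init__(self, in_size, out_size):\n",
    "        self.in_size = tuple(in_size)\n",
    "        self.out_size = (out_size,)\n",
    "\n",
    "    @classmethod\n",
    "    def desc(cls, **kwargs):\n",
    "        # Defer construction: in_size is filled in later by the container.\n",
    "        return partial(cls, **kwargs)\n",
    "\n",
    "def build_sequential(first, *rest):\n",
    "    layers = [first]\n",
    "    for item in rest:\n",
    "        if isinstance(item, partial):\n",
    "            # Complete the descriptor with the previous layer's out_size.\n",
    "            item = item(in_size=layers[-1].out_size)\n",
    "        layers.append(item)\n",
    "    return layers\n",
    "\n",
    "layers = build_sequential(\n",
    "    ToyLinear((10,), 64),\n",
    "    ToyLinear.desc(out_size=32),\n",
    "    ToyLinear.desc(out_size=5),\n",
    ")\n",
    "for layer in layers:\n",
    "    print(layer.in_size, '->', layer.out_size)\n",
    "```\n",
    "\n",
    "The printed sizes follow the same 10 -> 64 -> 32 -> 5 chain as the `Linear` layers in the MLP example."
   ]
  },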
  {
   "cell_type": "markdown",
   "id": "advanced_sequential",
   "metadata": {},
   "source": [
    "### Example 4: Complex Architecture with Mixed Layer Types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "advanced_sequential_code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T08:24:19.906254Z",
     "start_time": "2025-10-11T08:24:18.049433Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complex Network:\n",
      "Input size: (32, 32, 3)\n",
      "After features: (8, 8, 64)\n",
      "After flatten: (4096,)\n",
      "Final output: (10,)\n",
      "\n",
      "Forward pass: (2, 32, 32, 3) -> (2, 10)\n"
     ]
    }
   ],
   "source": [
    "# Build a more complex network with different layer types\n",
    "class ComplexNet(brainstate.nn.Module):\n",
    "    \"\"\"Complex network demonstrating various layer types.\"\"\"\n",
    "    \n",
    "    def __init__(self, in_size):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.features = brainstate.nn.Sequential(\n",
    "            # Initial conv block\n",
    "            brainstate.nn.Conv2d(in_size, out_channels=16, kernel_size=3, padding='SAME'),\n",
    "            brainstate.nn.ReLU(),\n",
    "            \n",
    "            # Strided conv (reduces spatial size)\n",
    "            brainstate.nn.Conv2d.desc(out_channels=32, kernel_size=3, stride=2, padding='SAME'),\n",
    "            brainstate.nn.ReLU(),\n",
    "            \n",
    "            # Another conv + pool\n",
    "            brainstate.nn.Conv2d.desc(out_channels=64, kernel_size=3, padding='SAME'),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),\n",
    "        )\n",
    "        \n",
    "        self.classifier = brainstate.nn.Sequential(\n",
    "            brainstate.nn.Flatten(in_size=self.features.out_size),\n",
    "            brainstate.nn.Linear.desc(out_size=256),\n",
    "            brainstate.nn.ReLU(),\n",
    "            brainstate.nn.Linear.desc(out_size=10),\n",
    "        )\n",
    "    \n",
    "    def update(self, x):\n",
    "        x = self.features(x)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "\n",
    "# Create network\n",
    "net = ComplexNet(in_size=(32, 32, 3))\n",
    "\n",
    "print(\"Complex Network:\")\n",
    "print(f\"Input size: {net.features.in_size}\")\n",
    "print(f\"After features: {net.features.out_size}\")\n",
    "print(f\"After flatten: {net.classifier.layers[0].out_size}\")\n",
    "print(f\"Final output: {net.classifier.out_size}\")\n",
    "\n",
    "# Test\n",
    "x = brainstate.random.randn(2, 32, 32, 3)\n",
    "y = net(x)\n",
    "print(f\"\\nForward pass: {x.shape} -> {y.shape}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
