{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f6a7b8",
   "metadata": {},
   "source": [
    "# Graph Compilation & Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5f6a7b8c9",
   "metadata": {},
   "source": [
    "In `braintrace`, models are compiled into an `ETraceGraph` -- an intermediate representation that captures the\n",
    "structural relationships between weight parameters, ETP primitives (the operations that connect inputs to\n",
    "hidden states), and hidden state groups. This compilation step is what enables efficient online learning:\n",
    "by analyzing the computation graph, `braintrace` can automatically determine which weights influence which\n",
    "hidden states, and how eligibility traces should propagate.\n",
    "\n",
    "The `show_graph()` method visualizes these relationships, providing a human-readable summary of:\n",
    "\n",
    "- **Hidden groups**: clusters of hidden states that evolve together (e.g., the membrane potential and\n",
    "  adaptation current of a neuron population)\n",
    "- **Weight-primitive-hidden connections**: which weight parameters are associated with which hidden groups\n",
    "  through which ETP primitives\n",
    "- **Non-ETP weights**: parameters that exist in the model but do not participate in online learning\n",
    "\n",
    "Understanding the compiled graph is essential for debugging model structure, verifying that the correct\n",
    "parameters are included in online learning, and optimizing model design."
   ]
  },
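  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a6978",
   "metadata": {},
   "source": [
    "To show roughly what analyzing the computation graph involves, the next cell uses plain JAX rather than\n",
    "`braintrace`. It traces a hypothetical RNN step with a linear readout (`toy_step`, with made-up weight names)\n",
    "using `jax.make_jaxpr`. It then propagates a forward dependency mark through the jaxpr equations to find which\n",
    "weights can influence the new hidden state. This is only a sketch of the idea. `braintrace`'s compiler has its\n",
    "own logic built around ETP primitives. Here `w_out` should come out as not reaching the hidden state, because it\n",
    "only affects the readout `y`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a6979",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def toy_step(w_in, w_rec, w_out, h, x):\n",
    "    # Hypothetical RNN step: new hidden state plus a linear readout\n",
    "    h_new = jnp.tanh(x @ w_in + h @ w_rec)\n",
    "    y = h_new @ w_out\n",
    "    return h_new, y\n",
    "\n",
    "\n",
    "def reaches(jaxpr, source, target):\n",
    "    # Forward propagation: mark every variable computed (directly or not) from `source`\n",
    "    tainted = {id(source)}\n",
    "    for eqn in jaxpr.eqns:\n",
    "        if any(id(v) in tainted for v in eqn.invars):\n",
    "            tainted.update(id(v) for v in eqn.outvars)\n",
    "    return id(target) in tainted\n",
    "\n",
    "\n",
    "closed = jax.make_jaxpr(toy_step)(\n",
    "    jnp.zeros((3, 4)), jnp.zeros((4, 4)), jnp.zeros((4, 2)), jnp.zeros(4), jnp.zeros(3)\n",
    ")\n",
    "jaxpr = closed.jaxpr\n",
    "print('Primitives:', [eqn.primitive.name for eqn in jaxpr.eqns])\n",
    "\n",
    "weight_names = ['w_in', 'w_rec', 'w_out']\n",
    "hidden_out = jaxpr.outvars[0]  # the new hidden state h_new\n",
    "feeds_hidden = {\n",
    "    name: reaches(jaxpr, var, hidden_out)\n",
    "    for name, var in zip(weight_names, jaxpr.invars)\n",
    "}\n",
    "print('Weights that influence the hidden state:', feeds_hidden)"
   ]
  },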
  {
   "cell_type": "markdown",
   "id": "c3d4e5f6a7b8c9d0",
   "metadata": {},
   "source": [
    "## Single-Layer RNN\n",
    "\n",
    "We start with the simplest case: a single recurrent layer followed by a linear readout.\n",
    "The `ValinaRNNCell` contains one hidden state and one recurrent weight, and the `Linear`\n",
    "readout has its own weight that feeds into the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7b8c9d0e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f6a7b8c9d0e1f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SingleLayerRNN(brainstate.nn.Module):\n",
    "    def __init__(self, n_in, n_rec, n_out):\n",
    "        super().__init__()\n",
    "        self.rnn = braintrace.nn.ValinaRNNCell(n_in, n_rec)\n",
    "        self.out = braintrace.nn.Linear(n_rec, n_out)\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.out(self.rnn(x))\n",
    "\n",
    "\n",
    "model = SingleLayerRNN(10, 32, 5)\n",
    "brainstate.nn.init_all_states(model)\n",
    "\n",
    "# Compile the graph and visualize\n",
    "algo = braintrace.D_RTRL(model)\n",
    "algo.compile_graph(jnp.zeros(10))\n",
    "algo.show_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a7b8c9d0e1f2a3",
   "metadata": {},
   "source": [
    "The printed summary should show:\n",
    "\n",
    "- **Hidden Group 0**: the hidden state of the `ValinaRNNCell` (path `('rnn', 'h')`)\n",
    "- **Weight 0**: the recurrent weight inside the RNN cell, associated with Hidden Group 0\n",
    "- **Weight 1**: the readout weight. Its output goes straight to the model output rather than into a hidden\n",
    "  state, so check whether the summary lists it as a separate relation or among the non-ETP weights\n",
    "\n",
    "This tells us that `D_RTRL` will maintain an eligibility trace for the recurrent weight,\n",
    "tracking how it influences the hidden state over time."
   ]
  },
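  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a697a",
   "metadata": {},
   "source": [
    "To make *eligibility trace* concrete, here is a minimal sketch in plain JAX for a hypothetical one-unit RNN\n",
    "$h_t = \\tanh(w\\,h_{t-1} + u\\,x_t)$. It is not `braintrace`'s implementation. The trace\n",
    "$e_t = \\partial h_t / \\partial w$ is carried forward alongside the hidden state using the chain rule,\n",
    "$e_t = (1 - h_t^2)\\,(h_{t-1} + w\\,e_{t-1})$, so no backward pass over the sequence is needed. For this scalar\n",
    "model the forward trace is exact. It should therefore agree with the gradient that `jax.grad` obtains by\n",
    "backpropagating through the unrolled loop. `D_RTRL` handles full weight matrices with its own scheme, which\n",
    "this toy does not try to reproduce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a697b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def step(w, u, h, x):\n",
    "    # Hypothetical one-unit RNN: h_t = tanh(w * h_{t-1} + u * x_t)\n",
    "    return jnp.tanh(w * h + u * x)\n",
    "\n",
    "\n",
    "def final_state(w, u, xs):\n",
    "    # Unrolled forward pass; jax.grad of this backpropagates through time\n",
    "    h = jnp.zeros(())\n",
    "    for x in xs:\n",
    "        h = step(w, u, h, x)\n",
    "    return h\n",
    "\n",
    "\n",
    "def forward_trace(w, u, xs):\n",
    "    # Online alternative: carry e_t = dh_t/dw forward together with h_t\n",
    "    h = jnp.zeros(())\n",
    "    e = jnp.zeros(())\n",
    "    for x in xs:\n",
    "        h_new = step(w, u, h, x)\n",
    "        e = (1.0 - h_new ** 2) * (h + w * e)  # chain rule through tanh\n",
    "        h = h_new\n",
    "    return h, e\n",
    "\n",
    "\n",
    "w, u = 0.5, 1.2\n",
    "xs = jnp.array([0.3, -0.7, 1.0, 0.2])\n",
    "h_T, e_T = forward_trace(w, u, xs)\n",
    "bptt_grad = jax.grad(final_state)(w, u, xs)\n",
    "print('forward trace :', e_T)\n",
    "print('BPTT gradient :', bptt_grad)"
   ]
  },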
  {
   "cell_type": "markdown",
   "id": "a7b8c9d0e1f2a3b4",
   "metadata": {},
   "source": [
    "## Understanding ETraceGraph\n",
    "\n",
    "The compiled graph is an `ETraceGraph` named tuple with several key fields:\n",
    "\n",
    "| Field | Type | Description |\n",
    "|---|---|---|\n",
    "| `module_info` | `ModuleInfo` | Jaxpr and state mappings extracted from the model |\n",
    "| `hidden_groups` | `Sequence[HiddenGroup]` | Discovered hidden state groups |\n",
    "| `hid_path_to_group` | `Dict[Path, HiddenGroup]` | Mapping from hidden state path to its group |\n",
    "| `hidden_param_op_relations` | `Sequence[HiddenParamOpRelation]` | Weight-primitive-hidden connections |\n",
    "| `hidden_perturb` | `HiddenPerturbation` or `None` | Perturbation structure for Jacobian computation |\n",
    "\n",
    "Each `HiddenGroup` records a cluster of hidden states that are updated together in one\n",
    "recurrent step. Each `HiddenParamOpRelation` records the connection between a weight\n",
    "parameter and the hidden groups it feeds into through an ETP primitive.\n",
    "\n",
    "Let's inspect these programmatically:"
   ]
  },
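  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a697c",
   "metadata": {},
   "source": [
    "The relationship between `hidden_groups`, `hid_path_to_group`, and `hidden_param_op_relations` can be sketched\n",
    "with stand-in named tuples. `ToyHiddenGroup`, `ToyRelation`, and `ToyGraph` are hypothetical classes, not\n",
    "`braintrace` types. They only mirror a few field names from the table, with made-up paths for a two-layer model,\n",
    "and they omit fields such as the primitive. The point is that `hid_path_to_group` is an inverse index over the\n",
    "groups' hidden paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a697d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, NamedTuple, Tuple\n",
    "\n",
    "Path = Tuple[str, ...]\n",
    "\n",
    "\n",
    "class ToyHiddenGroup(NamedTuple):\n",
    "    # Hypothetical stand-in for a hidden group\n",
    "    index: int\n",
    "    hidden_paths: Tuple[Path, ...]\n",
    "\n",
    "\n",
    "class ToyRelation(NamedTuple):\n",
    "    # Hypothetical stand-in for a weight-primitive-hidden relation (primitive omitted)\n",
    "    weight_path: Path\n",
    "    hidden_groups: Tuple[ToyHiddenGroup, ...]\n",
    "\n",
    "\n",
    "class ToyGraph(NamedTuple):\n",
    "    hidden_groups: Tuple[ToyHiddenGroup, ...]\n",
    "    hid_path_to_group: Dict[Path, ToyHiddenGroup]\n",
    "    hidden_param_op_relations: Tuple[ToyRelation, ...]\n",
    "\n",
    "\n",
    "g0 = ToyHiddenGroup(0, (('rnn1', 'h'),))\n",
    "g1 = ToyHiddenGroup(1, (('rnn2', 'h'),))\n",
    "groups = (g0, g1)\n",
    "# Inverse index: every hidden path points back to the group that contains it\n",
    "path_to_group = {path: g for g in groups for path in g.hidden_paths}\n",
    "relations = (\n",
    "    ToyRelation(('rnn1', 'weight'), (g0,)),\n",
    "    ToyRelation(('rnn2', 'weight'), (g1,)),\n",
    ")\n",
    "toy = ToyGraph(groups, path_to_group, relations)\n",
    "\n",
    "print(toy._fields)\n",
    "print(toy.hid_path_to_group[('rnn2', 'h')].index)"
   ]
  },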
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c9d0e1f2a3b4c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = algo.graph\n",
    "\n",
    "print(\"=== Hidden Groups ===\")\n",
    "for g in graph.hidden_groups:\n",
    "    print(f\"  Group {g.index}: {g.num_state} state(s), shape {g.varshape}\")\n",
    "    print(f\"    Paths: {g.hidden_paths}\")\n",
    "\n",
    "print(\"\\n=== Weight-Primitive-Hidden Relations ===\")\n",
    "for i, r in enumerate(graph.hidden_param_op_relations):\n",
    "    print(f\"  Relation {i}:\")\n",
    "    print(f\"    Weight path: {r.weight_path}\")\n",
    "    print(f\"    Primitive: {r.primitive}\")\n",
    "    print(f\"    Hidden groups: {[g.index for g in r.hidden_groups]}\")\n",
    "\n",
    "print(f\"\\n=== Perturbation ===\")\n",
    "print(f\"  Has perturbation: {graph.hidden_perturb is not None}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9d0e1f2a3b4c5d6",
   "metadata": {},
   "source": [
    "The `HiddenGroup.num_state` property returns the total number of state variables in the group,\n",
    "and `HiddenGroup.varshape` returns the shape of each state variable. The\n",
    "`HiddenParamOpRelation.primitive` field identifies which ETP primitive (e.g., `etp_matmul_p`)\n",
    "connects the weight to the hidden state."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0e1f2a3b4c5d6e7",
   "metadata": {},
   "source": [
    "## Two-Layer RNN\n",
    "\n",
    "With multiple recurrent layers, the graph becomes richer. Here two `GRUCell` layers are stacked, and each\n",
    "introduces its own hidden group. The compiler discovers which weights feed into which hidden groups.\n",
    "In a stacked RNN, each layer's weights are associated only with that layer's hidden group, so the layers are\n",
    "structurally independent from the perspective of eligibility trace propagation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f2a3b4c5d6e7f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TwoLayerRNN(brainstate.nn.Module):\n",
    "    def __init__(self, n_in, n_rec, n_out):\n",
    "        super().__init__()\n",
    "        self.rnn1 = braintrace.nn.GRUCell(n_in, n_rec)\n",
    "        self.rnn2 = braintrace.nn.GRUCell(n_rec, n_rec)\n",
    "        self.out = braintrace.nn.Linear(n_rec, n_out)\n",
    "\n",
    "    def update(self, x):\n",
    "        h1 = self.rnn1(x)\n",
    "        h2 = self.rnn2(h1)\n",
    "        return self.out(h2)\n",
    "\n",
    "\n",
    "model2 = TwoLayerRNN(10, 32, 5)\n",
    "brainstate.nn.init_all_states(model2)\n",
    "\n",
    "algo2 = braintrace.D_RTRL(model2)\n",
    "algo2.compile_graph(jnp.zeros(10))\n",
    "algo2.show_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2a3b4c5d6e7f8a9",
   "metadata": {},
   "source": [
    "Notice that:\n",
    "\n",
    "- Each GRU layer creates its own hidden group (the GRU hidden state `h`)\n",
    "- Each layer's recurrent and input weights are associated with that layer's hidden group\n",
    "- The readout weight forms its own relation if it uses an ETP primitive\n",
    "\n",
    "This structural analysis is what allows `D_RTRL` to maintain separate eligibility traces\n",
    "for each layer, avoiding the need to backpropagate through time across the entire network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b4c5d6e7f8a9b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the two-layer graph programmatically\n",
    "graph2 = algo2.graph\n",
    "\n",
    "print(f\"Number of hidden groups: {len(graph2.hidden_groups)}\")\n",
    "print(f\"Number of weight-hidden relations: {len(graph2.hidden_param_op_relations)}\")\n",
    "\n",
    "print(\"\\nHidden groups:\")\n",
    "for g in graph2.hidden_groups:\n",
    "    print(f\"  Group {g.index}: {g.hidden_paths}\")\n",
    "\n",
    "print(\"\\nRelations:\")\n",
    "for i, r in enumerate(graph2.hidden_param_op_relations):\n",
    "    groups = [g.index for g in r.hidden_groups]\n",
    "    print(f\"  Weight {i}: {r.weight_path} -> hidden group(s) {groups}\")"
   ]
  },
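  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a697e",
   "metadata": {},
   "source": [
    "The cell above lists relations weight by weight. Grouping weight paths by hidden group instead makes the\n",
    "per-layer separation easy to check at a glance. `weights_by_group` is a hypothetical helper, not a `braintrace` API.\n",
    "It only uses attributes already read in this notebook (`hidden_param_op_relations`, `weight_path`,\n",
    "`hidden_groups`, `index`). So that the cell runs on its own, it is demonstrated on a `types.SimpleNamespace`\n",
    "stand-in with made-up weight paths. On the real model you would call `weights_by_group(algo2.graph)`,\n",
    "assuming those attributes behave as in the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a697f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def weights_by_group(graph):\n",
    "    # Invert the relations: hidden group index -> list of weight paths feeding it\n",
    "    table = defaultdict(list)\n",
    "    for rel in graph.hidden_param_op_relations:\n",
    "        for group in rel.hidden_groups:\n",
    "            table[group.index].append(rel.weight_path)\n",
    "    return dict(table)\n",
    "\n",
    "\n",
    "# Stand-in graph with hypothetical paths, mimicking a two-layer model\n",
    "grp0, grp1 = SimpleNamespace(index=0), SimpleNamespace(index=1)\n",
    "fake_graph = SimpleNamespace(hidden_param_op_relations=[\n",
    "    SimpleNamespace(weight_path=('rnn1', 'w_a'), hidden_groups=[grp0]),\n",
    "    SimpleNamespace(weight_path=('rnn1', 'w_b'), hidden_groups=[grp0]),\n",
    "    SimpleNamespace(weight_path=('rnn2', 'w_a'), hidden_groups=[grp1]),\n",
    "])\n",
    "table_by_group = weights_by_group(fake_graph)\n",
    "print(table_by_group)"
   ]
  },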
  {
   "cell_type": "markdown",
   "id": "b4c5d6e7f8a9b0c1",
   "metadata": {},
   "source": [
    "## Convolutional Network\n",
    "\n",
    "ETP primitives also support convolutional operations via `braintrace.nn.Conv2d`. When a\n",
    "convolutional layer feeds into a recurrent layer, the compiler discovers the connection\n",
    "between the convolution kernel and the downstream hidden state. This demonstrates the\n",
    "generality of the graph compilation -- it works with any ETP primitive, not just matrix\n",
    "multiplication."
   ]
  },
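  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a6980",
   "metadata": {},
   "source": [
    "The data-flow argument above can be checked numerically in plain JAX. The next cell does not use\n",
    "`braintrace.nn.Conv2d`. Instead it defines a hypothetical hand-written conv-then-RNN step built on\n",
    "`jax.lax.conv_general_dilated`, using channels-last `NHWC` images, `HWIO` kernels, and random weights. Tracing it\n",
    "with `jax.make_jaxpr` should list the convolution as its own `conv_general_dilated` primitive. A nonzero gradient\n",
    "of the new hidden state with respect to the kernel then confirms that the kernel influences that hidden state,\n",
    "which is the kind of kernel-to-hidden connection the compiler records."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a6981",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def conv_rnn_step(kernel, w_in, w_rec, h, img):\n",
    "    # img: (1, 28, 28, 1) in NHWC layout; kernel: (3, 3, 1, 8) in HWIO layout\n",
    "    feat = jax.lax.conv_general_dilated(\n",
    "        img, kernel, window_strides=(1, 1), padding='SAME',\n",
    "        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))\n",
    "    feat = feat.reshape(-1)  # 28 * 28 * 8 = 6272 features\n",
    "    return jnp.tanh(feat @ w_in + h @ w_rec)\n",
    "\n",
    "\n",
    "k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)\n",
    "kernel = 0.1 * jax.random.normal(k1, (3, 3, 1, 8))\n",
    "w_in = 0.01 * jax.random.normal(k2, (28 * 28 * 8, 64))\n",
    "w_rec = 0.1 * jax.random.normal(k3, (64, 64))\n",
    "img = jax.random.normal(k4, (1, 28, 28, 1))\n",
    "h0 = jnp.zeros(64)\n",
    "\n",
    "closed = jax.make_jaxpr(conv_rnn_step)(kernel, w_in, w_rec, h0, img)\n",
    "prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]\n",
    "print('Primitives:', prims)\n",
    "\n",
    "# Gradient of the summed new hidden state with respect to the conv kernel\n",
    "grad_kernel = jax.grad(lambda k: conv_rnn_step(k, w_in, w_rec, h0, img).sum())(kernel)\n",
    "print('kernel grad shape:', grad_kernel.shape)\n",
    "print('max |grad|:', float(jnp.abs(grad_kernel).max()))"
   ]
  },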
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d6e7f8a9b0c1d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvRNN(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Assumes brainstate-style convolutions: the first argument is the unbatched\n",
    "        # input shape in channels-last (height, width, channels) layout\n",
    "        self.conv = braintrace.nn.Conv2d((28, 28, 1), 8, kernel_size=3, padding='SAME')\n",
    "        self.rnn = braintrace.nn.ValinaRNNCell(8 * 28 * 28, 64)\n",
    "        self.out = braintrace.nn.Linear(64, 10)\n",
    "\n",
    "    def update(self, x):\n",
    "        # x: (28, 28, 1), a single-channel 28x28 image; the 8 output feature maps are flattened\n",
    "        features = self.conv(x).reshape(-1)\n",
    "        return self.out(self.rnn(features))\n",
    "\n",
    "\n",
    "model3 = ConvRNN()\n",
    "brainstate.nn.init_all_states(model3)\n",
    "\n",
    "algo3 = braintrace.D_RTRL(model3)\n",
    "algo3.compile_graph(jnp.zeros((28, 28, 1)))\n",
    "algo3.show_graph()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6e7f8a9b0c1d2e3",
   "metadata": {},
   "source": [
    "In this model:\n",
    "\n",
    "- The `Conv2d` kernel weight is discovered as an ETP parameter because `braintrace.nn.Conv2d`\n",
    "  uses the `etp_conv` primitive internally\n",
    "- The RNN's recurrent weight uses `etp_matmul`\n",
    "- Both are associated with the RNN's hidden group, since the convolution output flows into\n",
    "  the recurrent computation\n",
    "- The readout `Linear` layer also uses an ETP primitive\n",
    "\n",
    "This shows how the compiler traces data flow across different layer types to build\n",
    "the complete eligibility trace graph."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7f8a9b0c1d2e3f4",
   "metadata": {},
   "source": [
    "## Using `compile_etrace_graph` Directly\n",
    "\n",
    "For advanced users who want to inspect the graph without wrapping the model in an algorithm\n",
    "like `D_RTRL`, `braintrace` exposes the `compile_etrace_graph()` function directly. This\n",
    "is useful for:\n",
    "\n",
    "- Debugging model structure before training\n",
    "- Verifying that ETP primitives are correctly placed\n",
    "- Building custom online learning algorithms on top of the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a9b0c1d2e3f4a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_direct = SingleLayerRNN(10, 32, 5)\n",
    "brainstate.nn.init_all_states(model_direct)\n",
    "\n",
    "graph_direct = braintrace.compile_etrace_graph(model_direct, jnp.zeros(10))\n",
    "\n",
    "print(f\"Number of hidden groups: {len(graph_direct.hidden_groups)}\")\n",
    "print(f\"Number of relations: {len(graph_direct.hidden_param_op_relations)}\")\n",
    "print(f\"Has perturbation: {graph_direct.hidden_perturb is not None}\")\n",
    "\n",
    "print(\"\\nGraph fields:\")\n",
    "for key in graph_direct._fields:  # field names of the ETraceGraph named tuple\n",
    "    print(f\"  {key}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b0c1d2e3f4a5b6",
   "metadata": {},
   "source": [
    "The `compile_etrace_graph()` function returns the same `ETraceGraph` named tuple that is\n",
    "stored internally by `D_RTRL` and other algorithms. You can use it to build custom\n",
    "training loops or to programmatically analyze model structure."
   ]
  },
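  {
   "cell_type": "markdown",
   "id": "9f1e2d3c4b5a6982",
   "metadata": {},
   "source": [
    "For the verification use case listed above, one option is to compare the weight paths the compiler found with\n",
    "the ones you expect to be learned online. `missing_etp_weights` is a hypothetical helper, not a `braintrace` API.\n",
    "It assumes each relation exposes a hashable `weight_path`, for example a tuple of strings like the `('rnn', 'h')`\n",
    "paths shown earlier. It is demonstrated on a stand-in graph with made-up paths. With a real model you would pass\n",
    "`graph_direct` and read the expected paths off `show_graph()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f1e2d3c4b5a6983",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def missing_etp_weights(graph, expected_paths):\n",
    "    # Expected weight paths that do not appear in any weight-primitive-hidden relation\n",
    "    found = {rel.weight_path for rel in graph.hidden_param_op_relations}\n",
    "    return [p for p in expected_paths if p not in found]\n",
    "\n",
    "\n",
    "# Stand-in graph with hypothetical paths; pass graph_direct instead for a real model\n",
    "fake_graph = SimpleNamespace(hidden_param_op_relations=[\n",
    "    SimpleNamespace(weight_path=('rnn', 'w')),\n",
    "])\n",
    "missing = missing_etp_weights(fake_graph, [('rnn', 'w'), ('out', 'w')])\n",
    "print(missing)"
   ]
  },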
  {
   "cell_type": "markdown",
   "id": "b0c1d2e3f4a5b6c7",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, we covered the graph compilation and visualization tools in `braintrace`:\n",
    "\n",
    "- **`compile_graph()`** (on algorithm objects) and **`compile_etrace_graph()`** (standalone function)\n",
    "  analyze the model's computation graph to discover the structural relationships between weights,\n",
    "  ETP primitives, and hidden states\n",
    "- **`show_graph()`** provides a human-readable summary of the compiled graph, showing hidden groups,\n",
    "  weight-hidden associations, and non-ETP parameters\n",
    "- The compiled graph reveals **which weights participate in online learning** -- only weights used\n",
    "  through ETP primitives (`braintrace.nn.Linear`, `braintrace.nn.Conv2d`, etc.) are included\n",
    "- **Multi-layer** and **convolutional** models create richer graph structures with multiple hidden\n",
    "  groups and, as in the convolutional example, weights in one layer associated with a later layer's hidden group\n",
    "- The `ETraceGraph` named tuple can be inspected programmatically for custom analysis or to build\n",
    "  custom online learning algorithms\n",
    "\n",
    "Understanding the compiled graph is a key step in verifying that your model is correctly structured\n",
    "for online learning with `braintrace`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}