{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4",
   "metadata": {},
   "source": [
    "# Compiler Internals\n",
    "\n",
    "The braintrace compiler transforms a `brainstate.nn.Module` into an `ETraceGraph` -- a structured representation that captures all the relationships between weight parameters, hidden states, and ETP primitives needed for online learning.\n",
    "\n",
    "The compilation pipeline consists of four stages:\n",
    "\n",
    "1. **`extract_module_info`** -- Trace the model and extract its Jaxpr (JAX's intermediate representation)\n",
    "2. **`find_hidden_groups_from_minfo`** -- Identify groups of recurrent hidden states that are mutually connected\n",
    "3. **`find_hidden_param_op_relations_from_minfo`** -- Discover how ETP primitives connect weight parameters to hidden states\n",
    "4. **`add_hidden_perturbation_from_minfo`** -- Build the perturbation structure for computing the loss gradient with respect to each hidden state\n",
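    "\n",
    "To see what a Jaxpr looks like concretely, the sketch below traces one vanilla-RNN step with `jax.make_jaxpr` and lists its equations. It uses plain JAX, independent of braintrace, and the `rnn_step` function is hypothetical. `extract_module_info` produces the same kind of object for a whole `brainstate` module, with hidden states and parameters appearing as Jaxpr input variables.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def rnn_step(w, h, x):\n",
    "    # One vanilla-RNN step: h_new = tanh([x, h] @ w)\n",
    "    xh = jnp.concatenate([x, h])\n",
    "    return jnp.tanh(xh @ w)\n",
    "\n",
    "\n",
    "w = jnp.zeros((42, 32))\n",
    "h = jnp.zeros(32)\n",
    "x = jnp.zeros(10)\n",
    "\n",
    "closed = jax.make_jaxpr(rnn_step)(w, h, x)  # a ClosedJaxpr\n",
    "for eqn in closed.jaxpr.eqns:\n",
    "    in_shapes = [v.aval.shape for v in eqn.invars]\n",
    "    out_shapes = [v.aval.shape for v in eqn.outvars]\n",
    "    print(eqn.primitive.name, in_shapes, '->', out_shapes)\n",
    "```\n",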
    "\n",
    "Understanding these internals helps you debug compilation issues, inspect the computational graph, and customize behavior when working with non-standard model architectures."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We will use a two-layer vanilla RNN as a running example throughout this notebook. This model is simple enough to inspect manually, yet complex enough to demonstrate multi-group hidden state discovery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c3d4e5f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:10.216760Z",
     "iopub.status.busy": "2026-04-17T09:26:10.216613Z",
     "iopub.status.idle": "2026-04-17T09:26:14.208067Z",
     "shell.execute_reply": "2026-04-17T09:26:14.207434Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model created and states initialized.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace\n",
    "\n",
    "\n",
    "class TwoLayerRNN(brainstate.nn.Module):\n",
    "    \"\"\"A two-layer vanilla RNN with a linear readout.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.rnn1 = braintrace.nn.ValinaRNNCell(10, 32)\n",
    "        self.rnn2 = braintrace.nn.ValinaRNNCell(32, 16)\n",
    "        self.out = braintrace.nn.Linear(16, 5)\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.out(self.rnn2(self.rnn1(x)))\n",
    "\n",
    "\n",
    "model = TwoLayerRNN()\n",
    "brainstate.nn.init_all_states(model)\n",
    "\n",
    "# A dummy input matching the first layer's input dimension\n",
    "dummy_input = jnp.zeros(10)\n",
    "\n",
    "print(\"Model created and states initialized.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e5f6a7",
   "metadata": {},
   "source": [
    "## Step 1: ModuleInfo Extraction\n",
    "\n",
    "`extract_module_info(model, *args)` is the first stage of the compiler. It:\n",
    "\n",
    "1. Wraps the model in a `StatefulFunction` and traces it with JAX to produce a Jaxpr.\n",
    "2. Collects all states from the module hierarchy via `brainstate.graph.states(model)`.\n",
    "3. Classifies each state as a **hidden state** (`brainstate.HiddenState`) or a **weight parameter** (`brainstate.ParamState`).\n",
    "4. Builds bidirectional mappings between Jaxpr variables and their module paths.\n",
    "\n",
    "The result is a `ModuleInfo` named tuple containing the Jaxpr, state mappings, and variable-to-path dictionaries."
   ]
  },
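  {
   "cell_type": "markdown",
   "id": "7c1e0a01",
   "metadata": {},
   "source": [
    "The classification and mapping steps can be pictured with a small stand-alone sketch. The `HiddenState` and `ParamState` classes below are hypothetical stand-ins, not the real `brainstate` types. The parameter is assumed to hold a dict of weight and bias, which is consistent with the `2 variable(s)` per weight path reported below. Each state's value is flattened with `jax.tree_util.tree_leaves`, so a hidden state yields one array while such a parameter yields two.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class HiddenState:  # hypothetical stand-in for brainstate.HiddenState\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "\n",
    "class ParamState:  # hypothetical stand-in for brainstate.ParamState\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "\n",
    "states = {\n",
    "    ('rnn1', 'h'): HiddenState(jnp.zeros(32)),\n",
    "    ('rnn1', 'W', 'weight'): ParamState(\n",
    "        {'weight': jnp.zeros((42, 32)), 'bias': jnp.zeros(32)}),\n",
    "}\n",
    "\n",
    "hidden_path_to_leaf = {}\n",
    "weight_path_to_leaves = {}\n",
    "for path, state in states.items():\n",
    "    leaves = jax.tree_util.tree_leaves(state.value)  # flatten the pytree value\n",
    "    if isinstance(state, HiddenState):\n",
    "        hidden_path_to_leaf[path] = leaves[0]  # one array per hidden state\n",
    "    elif isinstance(state, ParamState):\n",
    "        weight_path_to_leaves[path] = leaves  # one entry per array leaf\n",
    "\n",
    "for path, leaves in weight_path_to_leaves.items():\n",
    "    print(path, '->', len(leaves), 'leaves')\n",
    "```"
   ]
  },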
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e5f6a7b8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.210443Z",
     "iopub.status.busy": "2026-04-17T09:26:14.209992Z",
     "iopub.status.idle": "2026-04-17T09:26:14.229654Z",
     "shell.execute_reply": "2026-04-17T09:26:14.228820Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jaxpr equations: 12\n",
      "Compiled model states: 5\n",
      "Hidden states: 2\n",
      "Weight parameters: 3\n"
     ]
    }
   ],
   "source": [
    "minfo = braintrace.extract_module_info(model, dummy_input)\n",
    "\n",
    "print(f\"Jaxpr equations: {len(minfo.jaxpr.eqns)}\")\n",
    "print(f\"Compiled model states: {len(minfo.compiled_model_states)}\")\n",
    "print(f\"Hidden states: {len(minfo.hidden_path_to_invar)}\")\n",
    "print(f\"Weight parameters: {len(minfo.weight_path_to_invars)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a7b8c9",
   "metadata": {},
   "source": [
    "### Inspecting State Mappings\n",
    "\n",
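    "The `ModuleInfo` maintains separate mappings for hidden states and weight parameters, keyed by module path (a tuple of attribute names). Each hidden-state path maps to a single Jaxpr variable. Each weight-parameter path maps to a list of variables, one per array leaf of the parameter's value; for these layers that is presumably the weight matrix and the bias, hence `2 variable(s)` below."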
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a7b8c9d0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.232057Z",
     "iopub.status.busy": "2026-04-17T09:26:14.231601Z",
     "iopub.status.idle": "2026-04-17T09:26:14.236516Z",
     "shell.execute_reply": "2026-04-17T09:26:14.235461Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Hidden State Paths ===\n",
      "  ('rnn1', 'h')  ->  Var(id=132835946131648):float32[32]\n",
      "  ('rnn2', 'h')  ->  Var(id=132832132601216):float32[16]\n",
      "\n",
      "=== Weight Parameter Paths ===\n",
      "  ('rnn1', 'W', 'weight')  ->  2 variable(s)\n",
      "  ('rnn2', 'W', 'weight')  ->  2 variable(s)\n",
      "  ('out', 'weight')  ->  2 variable(s)\n"
     ]
    }
   ],
   "source": [
    "print(\"=== Hidden State Paths ===\")\n",
    "for path, var in minfo.hidden_path_to_invar.items():\n",
    "    print(f\"  {path}  ->  {var}\")\n",
    "\n",
    "print()\n",
    "print(\"=== Weight Parameter Paths ===\")\n",
    "for path, invars in minfo.weight_path_to_invars.items():\n",
    "    print(f\"  {path}  ->  {len(invars)} variable(s)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c9d0e1",
   "metadata": {},
   "source": [
    "## Step 2: Hidden Group Discovery\n",
    "\n",
    "`find_hidden_groups_from_minfo(minfo)` identifies groups of recurrent hidden states that are interconnected through the computational graph.\n",
    "\n",
    "**Algorithm:**\n",
    "\n",
    "1. For each hidden state invar, trace forward through the Jaxpr to find which hidden state outvars it reaches.\n",
    "2. Build a connectivity graph: if hidden state A's invar reaches hidden state B's outvar, they are connected.\n",
    "3. Merge overlapping connected components into groups.\n",
    "4. Filter by shape compatibility -- hidden states in the same group must have compatible shapes.\n",
    "5. Filter by layer membership -- hidden states from different sequential layers (paths diverging at numeric indices) are placed in separate groups.\n",
    "\n",
    "For each group, the compiler also builds a **transition Jaxpr** that computes:\n",
    "\n",
    "$$h_1^t, h_2^t, \\ldots = f(h_1^{t-1}, h_2^{t-1}, \\ldots, x^t)$$\n",
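    "\n",
    "Given such a transition function, the hidden-to-hidden Jacobian $\\partial h^t / \\partial h^{t-1}$ is ordinary automatic differentiation. The following is a minimal sketch. It assumes a single-state vanilla RNN transition written directly in JAX; the `transition` function is hypothetical and is not braintrace's transition Jaxpr.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def transition(h_prev, x, w):\n",
    "    # h^t = tanh([x^t, h^{t-1}] @ W) for a single hidden state\n",
    "    return jnp.tanh(jnp.concatenate([x, h_prev]) @ w)\n",
    "\n",
    "\n",
    "w = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (10 + 32, 32))\n",
    "h_prev = jnp.zeros(32)\n",
    "x = jnp.ones(10)\n",
    "\n",
    "# Jacobian of h^t with respect to h^{t-1} (argument 0)\n",
    "jac = jax.jacrev(transition, argnums=0)(h_prev, x, w)\n",
    "print(jac.shape)  # (32, 32)\n",
    "```\n",
    "\n",
    "For this tanh cell the result equals $\\mathrm{diag}\\big(1 - (h^t)^2\\big)\\, W_h^\\top$, where $W_h$ holds the rows of `w` that multiply $h^{t-1}$. Autodiff recovers it without deriving it by hand.\n",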
    "\n",
    "In our two-layer RNN we expect two separate hidden groups, one per RNN layer. `rnn1`'s state does feed into `rnn2`, but the two hidden states have incompatible shapes, `(32,)` and `(16,)`, so the shape filter in step 4 keeps them apart."
   ]
  },
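  {
   "cell_type": "markdown",
   "id": "7c1e0a02",
   "metadata": {},
   "source": [
    "Steps 2 to 4 of the algorithm can be sketched as a union-find over hidden-state paths. This is a simplified illustration, not braintrace's implementation. The `reaches` dictionary (which hidden outvars each hidden invar reaches) is written by hand rather than traced from a Jaxpr. The shape filter is applied per edge and reduced to exact shape equality. The layer-membership filter (step 5) is omitted.\n",
    "\n",
    "```python\n",
    "def find_groups(reaches, shapes):\n",
    "    # reaches: path -> set of paths whose outvar this path's invar reaches\n",
    "    # shapes:  path -> shape tuple of that hidden state\n",
    "    parent = {p: p for p in shapes}\n",
    "\n",
    "    def find(p):\n",
    "        while parent[p] != p:\n",
    "            parent[p] = parent[parent[p]]  # path halving\n",
    "            p = parent[p]\n",
    "        return p\n",
    "\n",
    "    for src, dsts in reaches.items():\n",
    "        for dst in dsts:\n",
    "            if shapes[src] == shapes[dst]:  # simplified shape filter\n",
    "                parent[find(dst)] = find(src)  # merge the two components\n",
    "\n",
    "    groups = {}\n",
    "    for p in shapes:\n",
    "        groups.setdefault(find(p), []).append(p)\n",
    "    return list(groups.values())\n",
    "\n",
    "\n",
    "# Two stacked layers: rnn1's state reaches rnn2's, but their shapes differ.\n",
    "reaches = {('rnn1', 'h'): {('rnn1', 'h'), ('rnn2', 'h')},\n",
    "           ('rnn2', 'h'): {('rnn2', 'h')}}\n",
    "shapes = {('rnn1', 'h'): (32,), ('rnn2', 'h'): (16,)}\n",
    "print(find_groups(reaches, shapes))  # [[('rnn1', 'h')], [('rnn2', 'h')]]\n",
    "```\n",
    "\n",
    "In this sketch, two states of equal shape that reach each other (for example an LSTM-style `h` and `c`) would land in the same group."
   ]
  },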
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9d0e1f2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.239316Z",
     "iopub.status.busy": "2026-04-17T09:26:14.239009Z",
     "iopub.status.idle": "2026-04-17T09:26:14.244242Z",
     "shell.execute_reply": "2026-04-17T09:26:14.243347Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of hidden groups discovered: 2\n",
      "\n",
      "Group 0:\n",
      "  Number of states: 1\n",
      "  Variable shape: (32,)\n",
      "  Hidden paths:\n",
      "    - ('rnn1', 'h')\n",
      "  Transition Jaxpr equations: 5\n",
      "  Transition const vars: 3\n",
      "\n",
      "Group 1:\n",
      "  Number of states: 1\n",
      "  Variable shape: (16,)\n",
      "  Hidden paths:\n",
      "    - ('rnn2', 'h')\n",
      "  Transition Jaxpr equations: 5\n",
      "  Transition const vars: 3\n",
      "\n"
     ]
    }
   ],
   "source": [
    "hidden_groups, hid_path_to_group = braintrace.find_hidden_groups_from_minfo(minfo)\n",
    "\n",
    "print(f\"Number of hidden groups discovered: {len(hidden_groups)}\")\n",
    "print()\n",
    "\n",
    "for g in hidden_groups:\n",
    "    print(f\"Group {g.index}:\")\n",
    "    print(f\"  Number of states: {g.num_state}\")\n",
    "    print(f\"  Variable shape: {g.varshape}\")\n",
    "    print(f\"  Hidden paths:\")\n",
    "    for path in g.hidden_paths:\n",
    "        print(f\"    - {path}\")\n",
    "    print(f\"  Transition Jaxpr equations: {len(g.transition_jaxpr.eqns)}\")\n",
    "    print(f\"  Transition const vars: {len(g.transition_jaxpr_constvars)}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0e1f2a3",
   "metadata": {},
   "source": [
    "## Step 3: Finding ETP Relations\n",
    "\n",
    "`find_hidden_param_op_relations_from_minfo(minfo, hid_path_to_group)` connects ETP primitives to their weight parameters and the hidden states they influence.\n",
    "\n",
    "**Algorithm:**\n",
    "\n",
    "For each equation in the Jaxpr:\n",
    "\n",
    "1. **Primitive identification**: Check `eqn.primitive in ETP_PRIMITIVES` (identity of the registered primitive objects, not string matching). This is robust: renaming a function or wrapping it in `jax.jit` does not break identification.\n",
    "\n",
    "2. **Weight extraction**: Extract the weight variable from `eqn.invars` at the primitive's `weight_invar_index` (index 1 for the matmul, conv, sparse, and LoRA primitives; index 0 for element-wise).\n",
    "\n",
    "3. **Backward tracing**: Trace the weight variable backward through the Jaxpr (following producer equations) to find the originating `ParamState`. This handles cases where weight transformations (e.g., `weight_fn`, masking) are applied before the primitive.\n",
    "\n",
    "4. **Forward BFS**: From the primitive's output variable, perform a breadth-first search forward through the Jaxpr to find reachable hidden-state outvars.\n",
    "\n",
    "5. **Shape compatibility**: Filter out hidden outvars whose shapes are not broadcast-compatible with the primitive output.\n",
    "\n",
    "6. **Transition Jaxpr**: Build a sub-Jaxpr mapping `y -> h` for each connected hidden group, used later for computing `dh/dy`.\n",
    "\n",
    "The result is a sequence of `HiddenParamOpRelation` named tuples."
   ]
  },
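  {
   "cell_type": "markdown",
   "id": "7c1e0a03",
   "metadata": {},
   "source": [
    "The backward tracing and forward BFS in steps 3 and 4 are ordinary graph walks over `jaxpr.eqns`. The sketch below runs both on a toy Jaxpr from `jax.make_jaxpr`. Several pieces are hypothetical: the `layer` function, the use of `dot_general` as a stand-in for an ETP matmul primitive, and the `trace_back` / `forward_bfs` helpers. braintrace's own helpers (such as `_trace_var_to_param`, used later in this notebook) differ in detail. The toy sticks to `jax.lax` calls so that the Jaxpr contains no literals and every operand is a hashable `Var`.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def layer(w, mask, x, h_prev):\n",
    "    wm = jax.lax.mul(w, mask)  # weight transformation applied before the product\n",
    "    y = jax.lax.dot_general(x, wm, (((0,), (0,)), ((), ())))  # y = x @ wm\n",
    "    return jax.lax.tanh(jax.lax.add(y, h_prev))  # new hidden state\n",
    "\n",
    "\n",
    "closed = jax.make_jaxpr(layer)(\n",
    "    jnp.ones((4, 3)), jnp.ones((4, 3)), jnp.ones(4), jnp.zeros(3))\n",
    "jaxpr = closed.jaxpr\n",
    "param_invars = {jaxpr.invars[0]}  # pretend only `w` comes from a ParamState\n",
    "hidden_outvars = set(jaxpr.outvars)  # pretend the output is a hidden state\n",
    "\n",
    "producers = {out: eqn for eqn in jaxpr.eqns for out in eqn.outvars}\n",
    "consumers = {}\n",
    "for eqn in jaxpr.eqns:\n",
    "    for v in eqn.invars:\n",
    "        consumers.setdefault(v, []).append(eqn)\n",
    "\n",
    "\n",
    "def trace_back(var):\n",
    "    # Walk producer equations backward, collecting parameter invars reached.\n",
    "    found, stack = set(), [var]\n",
    "    while stack:\n",
    "        v = stack.pop()\n",
    "        if v in param_invars:\n",
    "            found.add(v)\n",
    "        elif v in producers:\n",
    "            stack.extend(producers[v].invars)\n",
    "    return found\n",
    "\n",
    "\n",
    "def forward_bfs(var):\n",
    "    # Breadth-first search over consumer equations for reachable hidden outvars.\n",
    "    seen, queue, hits = {var}, deque([var]), set()\n",
    "    while queue:\n",
    "        v = queue.popleft()\n",
    "        if v in hidden_outvars:\n",
    "            hits.add(v)\n",
    "        for eqn in consumers.get(v, []):\n",
    "            for out in eqn.outvars:\n",
    "                if out not in seen:\n",
    "                    seen.add(out)\n",
    "                    queue.append(out)\n",
    "    return hits\n",
    "\n",
    "\n",
    "eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'dot_general')\n",
    "weight_var = eqn.invars[1]  # operand index 1, as for matmul-like primitives\n",
    "print(trace_back(weight_var) == param_invars)  # True: mask is skipped, w is found\n",
    "print(forward_bfs(eqn.outvars[0]) == hidden_outvars)  # True\n",
    "```"
   ]
  },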
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1f2a3b4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.246415Z",
     "iopub.status.busy": "2026-04-17T09:26:14.246233Z",
     "iopub.status.idle": "2026-04-17T09:26:14.258460Z",
     "shell.execute_reply": "2026-04-17T09:26:14.257837Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of ETP relations discovered: 2\n",
      "\n",
      "Relation 0:\n",
      "  Primitive: etp_mv\n",
      "  Weight path: ('rnn1', 'W', 'weight')\n",
      "  x_var: Var(id=132831597789632):float32[42]\n",
      "  y_var: Var(id=132835947956608):float32[32]\n",
      "  Connected hidden groups: [0]\n",
      "  Connected hidden paths:\n",
      "    - ('rnn1', 'h')\n",
      "  Equation params: {'has_bias': True}\n",
      "\n",
      "Relation 1:\n",
      "  Primitive: etp_mv\n",
      "  Weight path: ('rnn2', 'W', 'weight')\n",
      "  x_var: Var(id=132831595499392):float32[48]\n",
      "  y_var: Var(id=132831595499648):float32[16]\n",
      "  Connected hidden groups: [1]\n",
      "  Connected hidden paths:\n",
      "    - ('rnn2', 'h')\n",
      "  Equation params: {'has_bias': True}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.\n",
      "  _emit_no_relation_diag(\n"
     ]
    }
   ],
   "source": [
    "relations = braintrace.find_hidden_param_op_relations_from_minfo(minfo, hid_path_to_group)\n",
    "\n",
    "print(f\"Number of ETP relations discovered: {len(relations)}\")\n",
    "print()\n",
    "\n",
    "for i, r in enumerate(relations):\n",
    "    print(f\"Relation {i}:\")\n",
    "    print(f\"  Primitive: {r.primitive.name}\")\n",
    "    print(f\"  Weight path: {r.weight_path}\")\n",
    "    print(f\"  x_var: {r.x_var}\")\n",
    "    print(f\"  y_var: {r.y_var}\")\n",
    "    print(f\"  Connected hidden groups: {[g.index for g in r.hidden_groups]}\")\n",
    "    print(f\"  Connected hidden paths:\")\n",
    "    for path in r.connected_hidden_paths:\n",
    "        print(f\"    - {path}\")\n",
    "    print(f\"  Equation params: {r.eqn_params}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2a3b4c5",
   "metadata": {},
   "source": [
    "## Step 4: Hidden Perturbation\n",
    "\n",
    "`add_hidden_perturbation_from_minfo(minfo)` builds the perturbation structure used to compute the gradient of the loss with respect to each hidden state, $\\partial L^t / \\partial h^t$.\n",
    "\n",
    "**How it works:**\n",
    "\n",
    "For each hidden state outvar $h^t = f(x)$ in the Jaxpr, the compiler rewrites the equation to:\n",
    "\n",
    "$$\\hat{h}^t = f(x), \\quad h^t = \\hat{h}^t + \\Delta$$\n",
    "\n",
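    "where $\\Delta$ is a new perturbation variable added to the Jaxpr's invars. $h^t$ depends on $\\Delta$ only through the sum $\\hat{h}^t + \\Delta$. Differentiating the perturbed Jaxpr with respect to $\\Delta$ at $\\Delta = 0$ (where the forward pass is unchanged) therefore yields the loss gradient with respect to the hidden state:\n",
    "\n",
    "$$\\frac{\\partial L^t}{\\partial h^t} = \\left.\\frac{\\partial L^t}{\\partial \\Delta}\\right|_{\\Delta = 0}$$\n",
    "\n",
    "The online learning algorithms (D-RTRL, ES-D-RTRL) use this gradient as the learning signal. They combine it with the eligibility traces to form weight updates."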
   ]
  },
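  {
   "cell_type": "markdown",
   "id": "7c1e0a04",
   "metadata": {},
   "source": [
    "The identity can be checked numerically. This minimal sketch uses a hypothetical one-layer model and a toy squared-output loss; neither is braintrace code. It shows that the gradient with respect to a zero perturbation `delta` added to the hidden state equals the gradient of the loss taken directly with respect to that hidden state.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
    "w_rec = 0.1 * jax.random.normal(k1, (8, 8))\n",
    "w_out = 0.1 * jax.random.normal(k2, (3, 8))\n",
    "h_prev = jnp.ones(8)\n",
    "x = jnp.ones(8)\n",
    "\n",
    "\n",
    "def loss_with_perturb(delta):\n",
    "    h_hat = jnp.tanh(w_rec @ h_prev + x)  # h_hat = f(...)\n",
    "    h = h_hat + delta                      # h = h_hat + Delta\n",
    "    return jnp.sum((w_out @ h) ** 2)       # toy loss L\n",
    "\n",
    "\n",
    "def loss_from_h(h):\n",
    "    return jnp.sum((w_out @ h) ** 2)       # same loss, taking h directly\n",
    "\n",
    "\n",
    "h_hat = jnp.tanh(w_rec @ h_prev + x)\n",
    "dL_ddelta = jax.grad(loss_with_perturb)(jnp.zeros(8))  # evaluated at Delta = 0\n",
    "dL_dh = jax.grad(loss_from_h)(h_hat)\n",
    "print(jnp.allclose(dL_ddelta, dL_dh))  # True\n",
    "```"
   ]
  },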
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a3b4c5d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.260118Z",
     "iopub.status.busy": "2026-04-17T09:26:14.259942Z",
     "iopub.status.idle": "2026-04-17T09:26:14.264287Z",
     "shell.execute_reply": "2026-04-17T09:26:14.263681Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of perturbation variables: 2\n",
      "Number of perturbation paths: 2\n",
      "\n",
      "Perturbed hidden states:\n",
      "  Path: ('rnn1', 'h')\n",
      "  Perturbation var: Var(id=132832205799808):float32[32]\n",
      "  State type: HiddenState\n",
      "\n",
      "  Path: ('rnn2', 'h')\n",
      "  Perturbation var: Var(id=132832133512512):float32[16]\n",
      "  State type: HiddenState\n",
      "\n"
     ]
    }
   ],
   "source": [
    "perturb = braintrace.add_hidden_perturbation_from_minfo(minfo)\n",
    "\n",
    "print(f\"Number of perturbation variables: {len(perturb.perturb_vars)}\")\n",
    "print(f\"Number of perturbation paths: {len(perturb.perturb_hidden_paths)}\")\n",
    "print()\n",
    "\n",
    "print(\"Perturbed hidden states:\")\n",
    "for path, var, state in zip(\n",
    "    perturb.perturb_hidden_paths,\n",
    "    perturb.perturb_vars,\n",
    "    perturb.perturb_hidden_states,\n",
    "):\n",
    "    print(f\"  Path: {path}\")\n",
    "    print(f\"  Perturbation var: {var}\")\n",
    "    print(f\"  State type: {type(state).__name__}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4c5d6e7",
   "metadata": {},
   "source": [
    "## The Complete Pipeline\n",
    "\n",
    "`compile_etrace_graph(model, *args)` runs all four steps in sequence and returns an `ETraceGraph` named tuple containing everything the online learning algorithms need.\n",
    "\n",
    "The function also performs an additional step: it rewrites the Jaxpr to return extra intermediate variables (weight inputs, transition constants, etc.) that the graph executor needs at runtime."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5d6e7f8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.266104Z",
     "iopub.status.busy": "2026-04-17T09:26:14.265914Z",
     "iopub.status.idle": "2026-04-17T09:26:14.275217Z",
     "shell.execute_reply": "2026-04-17T09:26:14.274548Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hidden groups: 2\n",
      "ETP relations: 2\n",
      "Module info available: True\n",
      "Hidden perturbation available: True\n",
      "\n",
      "ETraceGraph fields: ['module_info', 'hidden_groups', 'hid_path_to_group', 'hidden_param_op_relations', 'hidden_perturb', 'diagnostics']\n"
     ]
    }
   ],
   "source": [
    "graph = braintrace.compile_etrace_graph(model, jnp.zeros(10))\n",
    "\n",
    "print(f\"Hidden groups: {len(graph.hidden_groups)}\")\n",
    "print(f\"ETP relations: {len(graph.hidden_param_op_relations)}\")\n",
    "print(f\"Module info available: {graph.module_info is not None}\")\n",
    "print(f\"Hidden perturbation available: {graph.hidden_perturb is not None}\")\n",
    "print()\n",
    "print(f\"ETraceGraph fields: {list(graph._asdict().keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6e7f8a9",
   "metadata": {},
   "source": [
    "## How Primitive Identification Works\n",
    "\n",
    "A key design decision in braintrace is **type-based** primitive identification rather than **name-based** matching.\n",
    "\n",
    "### Old system (string matching)\n",
    "\n",
    "The old `ETraceOp` system identified weight operations by matching JIT function names as strings. This was fragile:\n",
    "\n",
    "- Renaming a function broke the match.\n",
    "- Wrapping a function in `jax.jit` changed the name.\n",
    "- Name collisions could cause false matches.\n",
    "\n",
    "### New system (type identity)\n",
    "\n",
    "The ETP system registers custom JAX primitives (`etp_mm_p`, `etp_mv_p`, `etp_elemwise_p`, `etp_conv_p`, `etp_sp_mm_p`, `etp_sp_mv_p`, `etp_lora_mm_p`, `etp_lora_mv_p`) and identifies them by checking:\n",
    "\n",
    "```python\n",
    "eqn.primitive in ETP_PRIMITIVES   # set membership, O(1)\n",
    "```\n",
    "\n",
    "This is robust because:\n",
    "\n",
    "- Primitive identity is a Python object reference, not a string.\n",
    "- Wrapping in `jax.jit` does not change the primitive type.\n",
    "- No name collisions are possible.\n",
    "\n",
    "### Weight variable extraction\n",
    "\n",
    "Once an ETP equation is found, the weight variable is extracted from `eqn.invars` at a position that depends on the primitive type. The same indices are recorded by `ETPPrimitiveSpec` (see `weight_invar_index` / `x_invar_index`) and can be queried at runtime through `braintrace.get_primitive_spec(prim)`:\n",
    "\n",
    "| Primitive | `invars[0]` | `invars[1]` | `invars[2]` | Notes |\n",
    "|---|---|---|---|---|\n",
    "| `etp_mm_p` / `etp_mv_p` | input `x` | weight `w` | bias `b` (optional) | `weight_invar_index=1` |\n",
    "| `etp_elemwise_p` | processed weight `y` | -- | -- | `weight_invar_index=0`, `x_invar_index=None` |\n",
    "| `etp_conv_p` | input `x` | kernel `w` | bias `b` (optional) | `weight_invar_index=1` |\n",
    "| `etp_sp_mm_p` / `etp_sp_mv_p` | input `x` | sparse `weight_data` | bias `b` (optional) | `weight_invar_index=1`; the static sparse pattern is in `eqn.params['sparse_mat']` |\n",
    "| `etp_lora_mm_p` / `etp_lora_mv_p` | input `x` | LoRA factor `B` | LoRA factor `A` (then bias `b` optional) | `weight_invar_index=1`; the originating `ParamState` holds both `B` *and* `A` as a dict, so the compiler back-traces from `B` and recovers both |\n",
    "\n",
    "The weight variable is then traced **backward** through the Jaxpr's producer map to find the originating `ParamState`, handling intermediate transformations like masking, weight standardization, or sign constraints."
   ]
  },
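  {
   "cell_type": "markdown",
   "id": "7c1e0a05",
   "metadata": {},
   "source": [
    "The same mechanism can be tried with JAX's built-in primitives. In the sketch below, `jax.lax.dot_general_p` stands in for an ETP primitive. A plain dict stands in for the `weight_invar_index` recorded by `ETPPrimitiveSpec`. Both are illustrative substitutions, not braintrace's registry. Equations are matched by looking up the primitive object itself, so no name comparison is involved.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical registry keyed by primitive object: primitive -> weight operand index.\n",
    "# jax.lax.dot_general_p stands in for an ETP matmul primitive such as etp_mv_p.\n",
    "WEIGHT_INVAR_INDEX = {jax.lax.dot_general_p: 1}\n",
    "\n",
    "\n",
    "def f(x, w):\n",
    "    y = jax.lax.dot_general(x, w, (((0,), (0,)), ((), ())))  # y = x @ w\n",
    "    return jax.lax.tanh(y)\n",
    "\n",
    "\n",
    "closed = jax.make_jaxpr(f)(jnp.ones(4), jnp.ones((4, 3)))\n",
    "matches = []\n",
    "for i, eqn in enumerate(closed.jaxpr.eqns):\n",
    "    idx = WEIGHT_INVAR_INDEX.get(eqn.primitive)  # lookup by object, no string compare\n",
    "    if idx is None:\n",
    "        continue\n",
    "    weight = eqn.invars[idx]\n",
    "    matches.append((i, eqn.primitive.name, weight))\n",
    "    print(i, eqn.primitive.name, 'weight shape:', weight.aval.shape)\n",
    "```"
   ]
  },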
  {
   "cell_type": "markdown",
   "id": "e7f8a9b0",
   "metadata": {},
   "source": [
    "## Debugging Compilation Issues\n",
    "\n",
    "When compilation produces unexpected results (e.g., missing relations, wrong group assignments), start with the records described in the *Compiler Diagnostics* section below. Then inspect the raw Jaxpr to see exactly which equations the compiler was given.\n",
    "\n",
    "### Common issues\n",
    "\n",
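    "- **Missing ETP relations**: A weight parameter uses a regular JAX op (e.g., `x @ w`) instead of an ETP primitive (e.g., `braintrace.matmul(x, w)`); the compiler only recognizes ETP primitives. A relation is also dropped when the primitive's output never reaches a hidden state. The `out` readout above is an example: it is treated as a non-temporal parameter for exactly this reason.\n",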
    "- **Shape mismatches**: The output of an ETP primitive is not broadcast-compatible with the target hidden state. The compiler will warn and skip the connection.\n",
    "- **Hidden states in control flow**: Hidden states computed inside `jax.lax.scan`, `jax.lax.while_loop`, or `jax.lax.cond` are currently unsupported and will raise an error.\n",
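    "\n",
    "For the first issue above, a quick check is to list the non-ETP operations that read each weight input directly. The following is a minimal sketch on a toy Jaxpr, and the `non_etp_consumers` helper is hypothetical. With braintrace you could instead pass `minfo.jaxpr`, an invar from `minfo.weight_path_to_invars`, and `is_etp_primitive` (imported in the listing cell below). The helper only inspects direct consumers, so a weight that is masked or transformed first will report that transform rather than the final operation.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def non_etp_consumers(jaxpr, var, is_etp):\n",
    "    # Names of primitives that read `var` directly but are not ETP primitives.\n",
    "    # Uses `is` comparison, so literal operands never need to be hashed.\n",
    "    return [eqn.primitive.name for eqn in jaxpr.eqns\n",
    "            if any(v is var for v in eqn.invars) and not is_etp(eqn.primitive)]\n",
    "\n",
    "\n",
    "def model(x, w):\n",
    "    return jnp.tanh(x @ w)  # plain matmul, so no ETP primitive is involved\n",
    "\n",
    "\n",
    "closed = jax.make_jaxpr(model)(jnp.ones(4), jnp.ones((4, 3)))\n",
    "w_var = closed.jaxpr.invars[1]\n",
    "# Plain JAX has no ETP primitives, so the toy predicate always answers False.\n",
    "print(non_etp_consumers(closed.jaxpr, w_var, lambda prim: False))\n",
    "```\n",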
    "\n",
    "### Inspecting the Jaxpr\n",
    "\n",
    "You can iterate over the Jaxpr equations and flag ETP primitives to verify the compiler sees what you expect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f8a9b0c1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.277060Z",
     "iopub.status.busy": "2026-04-17T09:26:14.276807Z",
     "iopub.status.idle": "2026-04-17T09:26:14.281318Z",
     "shell.execute_reply": "2026-04-17T09:26:14.280663Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jaxpr equations (ETP primitives marked with **):\n",
      "\n",
      "  [ 0]        convert_element_type: [(32,)] -> [(32,)]\n",
      "  [ 1]        concatenate: [(10,), (32,)] -> [(42,)]\n",
      "  [ 2] **ETP** etp_mv: [(42,), (42, 32), (32,)] -> [(32,)]\n",
      "  [ 3]        mul: [(32,), ()] -> [(32,)]\n",
      "  [ 4]        custom_jvp_call: [(32,)] -> [(32,)]\n",
      "  [ 5]        convert_element_type: [(16,)] -> [(16,)]\n",
      "  [ 6]        concatenate: [(32,), (16,)] -> [(48,)]\n",
      "  [ 7] **ETP** etp_mv: [(48,), (48, 16), (16,)] -> [(16,)]\n",
      "  [ 8]        mul: [(16,), ()] -> [(16,)]\n",
      "  [ 9]        custom_jvp_call: [(16,)] -> [(16,)]\n",
      "  [10] **ETP** etp_mv: [(16,), (16, 5), (5,)] -> [(5,)]\n",
      "  [11]        mul: [(5,), ()] -> [(5,)]\n"
     ]
    }
   ],
   "source": [
    "from braintrace._etrace_op import is_etp_primitive\n",
    "\n",
    "print(\"Jaxpr equations (ETP primitives marked with **):\\n\")\n",
    "for i, eqn in enumerate(minfo.jaxpr.eqns):\n",
    "    primitive_name = eqn.primitive.name\n",
    "    in_shapes = [\n",
    "        v.aval.shape if hasattr(v, 'aval') else 'literal'\n",
    "        for v in eqn.invars\n",
    "    ]\n",
    "    out_shapes = [v.aval.shape for v in eqn.outvars]\n",
    "    is_etp = is_etp_primitive(eqn.primitive)\n",
    "    marker = \"**ETP**\" if is_etp else \"      \"\n",
    "    print(f\"  [{i:2d}] {marker} {primitive_name}: {in_shapes} -> {out_shapes}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b0c1d2",
   "metadata": {},
   "source": [
    "### Verifying backward tracing\n",
    "\n",
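    "If an ETP relation is missing, you can check manually whether the weight variable traces back to a `ParamState` by building the producer map and calling the internal tracing function. Both `_build_producer_map` and `_trace_var_to_param` are private (underscore-prefixed) helpers, so their signatures may change between releases."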
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b0c1d2e3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.283593Z",
     "iopub.status.busy": "2026-04-17T09:26:14.283304Z",
     "iopub.status.idle": "2026-04-17T09:26:14.288545Z",
     "shell.execute_reply": "2026-04-17T09:26:14.287746Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Primitive: etp_mv\n",
      "  weight_invar_index: 1\n",
      "  Weight var: Var(id=132831598399488):float32[42,32]\n",
      "  Traced to ParamState: (('rnn1', 'W', 'weight'), ())\n",
      "\n",
      "Primitive: etp_mv\n",
      "  weight_invar_index: 1\n",
      "  Weight var: Var(id=132831595498880):float32[48,16]\n",
      "  Traced to ParamState: (('rnn2', 'W', 'weight'), ())\n",
      "\n",
      "Primitive: etp_mv\n",
      "  weight_invar_index: 1\n",
      "  Weight var: Var(id=132831595549632):float32[16,5]\n",
      "  Traced to ParamState: (('out', 'weight'), ())\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from braintrace._etrace_compiler.hid_param_op import _build_producer_map, _trace_var_to_param\n",
    "from braintrace import get_primitive_spec\n",
    "\n",
    "producers = _build_producer_map(minfo.jaxpr)\n",
    "\n",
    "for eqn in minfo.jaxpr.eqns:\n",
    "    if not is_etp_primitive(eqn.primitive):\n",
    "        continue\n",
    "    # Look up the spec rather than hard-coding the weight index per primitive.\n",
    "    spec = get_primitive_spec(eqn.primitive)\n",
    "    weight_var = eqn.invars[spec.weight_invar_index]\n",
    "\n",
    "    path = _trace_var_to_param(\n",
    "        weight_var, producers, minfo.invar_to_weight_path\n",
    "    )\n",
    "    print(f\"Primitive: {eqn.primitive.name}\")\n",
    "    print(f\"  weight_invar_index: {spec.weight_invar_index}\")\n",
    "    print(f\"  Weight var: {weight_var}\")\n",
    "    print(f\"  Traced to ParamState: {path}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd89dae5",
   "metadata": {},
   "source": [
    "## Compiler Diagnostics\n",
    "\n",
    "Every call to `compile_etrace_graph` annotates each weight/primitive decision with a `CompilationRecord`. The full list lives at `graph.diagnostics`. When the compiler skips a weight or merges hidden groups in a way you did not expect, this list is the first place to look -- it tells you *which* weight, *which* primitive, and *why*.\n",
    "\n",
    "Each `CompilationRecord` has these fields:\n",
    "\n",
    "| Field | Type | What it carries |\n",
    "|---|---|---|\n",
    "| `kind` | `DiagnosticKind` | Decision category, e.g. `RELATION_INCLUDED`, `RELATION_EXCLUDED_NON_TEMPORAL`, `RELATION_EXCLUDED_WEIGHT_TO_WEIGHT`, `TRANSITION_TAIL_BOUNDED`, `STATE_MISMATCH`. |\n",
    "| `level` | `DiagnosticLevel` | `INFO`, `WARNING`, or `ERROR`. |\n",
    "| `message` | `str` | Human-readable summary, including the weight path and primitive name. |\n",
    "| `primitive` | `Primitive \\| None` | The ETP primitive involved (if any). |\n",
    "| `weight_path` | `tuple[str, ...] \\| None` | Module path (tuple of attribute names) to the `ParamState` (e.g. `('cell', 'Wr', 'weight')`). |\n",
    "| `hidden_paths` | `tuple[tuple[str, ...], ...] \\| None` | Hidden-state paths the relation reaches. |\n",
    "| `context` | `dict` | Free-form extra info -- group indices, classification tags, etc. |\n",
    "\n",
    "Warning-level records are also surfaced as `UserWarning`s during compilation (the Step 3 output above shows one on stderr). Querying `graph.diagnostics` lets you filter, log, or assert on them programmatically. In the run below, `graph.explain()` returns the records as a tuple, which Jupyter displays as the cell's result."
   ]
  },
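  {
   "cell_type": "markdown",
   "id": "7c1e0a06",
   "metadata": {},
   "source": [
    "To group the records by kind yourself, a few lines of plain Python suffice. The helper below is a hypothetical sketch that relies only on the `kind`, `level`, and `message` fields listed in the table above. It is demonstrated on `SimpleNamespace` stand-ins so it runs without braintrace, and it should accept `graph.diagnostics` unchanged.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "from enum import Enum\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def group_by_kind(records):\n",
    "    # Map each kind name to the messages recorded under it, in order.\n",
    "    grouped = defaultdict(list)\n",
    "    for rec in records:\n",
    "        grouped[rec.kind.name].append(f'[{rec.level.name}] {rec.message}')\n",
    "    return dict(grouped)\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins mimicking two of the records for the two-layer RNN.\n",
    "Kind = Enum('Kind', ['RELATION_INCLUDED', 'RELATION_EXCLUDED_NON_TEMPORAL'])\n",
    "Level = Enum('Level', ['INFO', 'WARNING'])\n",
    "records = [\n",
    "    SimpleNamespace(kind=Kind.RELATION_INCLUDED, level=Level.INFO,\n",
    "                    message='etp_mv(rnn1.W.weight) -> [0]'),\n",
    "    SimpleNamespace(kind=Kind.RELATION_EXCLUDED_NON_TEMPORAL, level=Level.WARNING,\n",
    "                    message='etp_mv(out.weight) has no connected hidden states'),\n",
    "]\n",
    "\n",
    "for kind, messages in group_by_kind(records).items():\n",
    "    print(kind)\n",
    "    for m in messages:\n",
    "        print('   ', m)\n",
    "```"
   ]
  },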
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f850f7df",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:14.291367Z",
     "iopub.status.busy": "2026-04-17T09:26:14.291071Z",
     "iopub.status.idle": "2026-04-17T09:26:14.297971Z",
     "shell.execute_reply": "2026-04-17T09:26:14.297285Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Diagnostics: 3\n",
      "  [INFO   ] RELATION_INCLUDED: etp_mv(('rnn1', 'W', 'weight')) -> [0]\n",
      "  [INFO   ] RELATION_INCLUDED: etp_mv(('rnn2', 'W', 'weight')) -> [1]\n",
      "  [WARNING] RELATION_EXCLUDED_NON_TEMPORAL: ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.\n",
      "\n",
      "errors: 0, weight->weight exclusions: 0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(CompilationRecord(kind=relation_included, level=info, primitive='etp_mv', weight_path=('rnn1', 'W', 'weight'), hidden_paths=[('rnn1', 'h')], message=\"etp_mv(('rnn1', 'W', 'weight')) -> [0]\", context={'hidden_group_indices': (0,), 'path_classification': {('rnn1', 'h'): 'all_direct'}}),\n",
       " CompilationRecord(kind=relation_included, level=info, primitive='etp_mv', weight_path=('rnn2', 'W', 'weight'), hidden_paths=[('rnn2', 'h')], message=\"etp_mv(('rnn2', 'W', 'weight')) -> [1]\", context={'hidden_group_indices': (1,), 'path_classification': {('rnn2', 'h'): 'all_direct'}}),\n",
       " CompilationRecord(kind=relation_excluded_non_temporal, level=warning, primitive='etp_mv', weight_path=('out', 'weight'), message=\"ETP primitive etp_mv (weight=('out', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.\"))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from braintrace import DiagnosticKind, DiagnosticLevel\n",
    "\n",
    "print(f\"Diagnostics: {len(graph.diagnostics)}\")\n",
    "for d in graph.diagnostics:\n",
    "    print(f\"  [{d.level.name:7s}] {d.kind.name}: {d.message}\")\n",
    "\n",
    "# Common queries: did anything error? Was any relation excluded as weight-to-weight?\n",
    "errors = [d for d in graph.diagnostics if d.level == DiagnosticLevel.ERROR]\n",
    "weight_to_weight = [\n",
    "    d for d in graph.diagnostics\n",
    "    if d.kind == DiagnosticKind.RELATION_EXCLUDED_WEIGHT_TO_WEIGHT\n",
    "]\n",
    "print(f\"\\nerrors: {len(errors)}, weight->weight exclusions: {len(weight_to_weight)}\")\n",
    "\n",
    "# `graph.explain()` returns the records; as the cell's last expression, its value is displayed below.\n",
    "graph.explain()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1d2e3f4",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "The braintrace compiler follows a 4-step pipeline to transform a neural network module into an `ETraceGraph` for online learning. Along the way it records each decision as a `CompilationRecord`. You can inspect these via `graph.diagnostics` (or `graph.explain()`) to debug missing or misplaced relations.\n",
    "\n",
    "| Step | Function | Output | Purpose |\n",
    "|------|----------|--------|---------|\n",
    "| 1 | `extract_module_info` | `ModuleInfo` | Trace model, extract Jaxpr, classify states |\n",
    "| 2 | `find_hidden_groups_from_minfo` | `List[HiddenGroup]` and a `hid_path_to_group` mapping | Identify connected recurrent state groups |\n",
    "| 3 | `find_hidden_param_op_relations_from_minfo` | `List[HiddenParamOpRelation]` | Connect ETP primitives to weights and hidden states |\n",
    "| 4 | `add_hidden_perturbation_from_minfo` | `HiddenPerturbation` | Build the perturbation Jaxpr for computing $\\partial L / \\partial h$ |\n",
    "\n",
    "**Key design decisions:**\n",
    "\n",
    "- **Type-based primitive identification** (`eqn.primitive in ETP_PRIMITIVES`) is robust and extensible, replacing the old fragile string-matching approach.\n",
    "- **Backward tracing** from weight variables to `ParamState` handles weight transformations transparently.\n",
    "- **Forward BFS** from primitive outputs to hidden outvars with shape compatibility filtering ensures correct connectivity.\n",
    "- **Perturbation rewriting** of the Jaxpr lets automatic differentiation deliver the loss gradient with respect to every hidden state, $\\partial L^t / \\partial h^t$.\n",
    "\n",
    "Understanding this pipeline is essential for debugging compilation failures, extending braintrace with new primitives, and reasoning about the structure of online learning graphs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
