{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9f854a9",
   "metadata": {},
   "source": [
    "# Custom Algorithm Development\n",
    "\n",
    "This advanced tutorial covers how to develop custom online learning algorithms with braintrace.\n",
    "\n",
    "braintrace ships with two built-in algorithms:\n",
    "\n",
    "- **`D_RTRL`** (`ParamDimVjpAlgorithm`): Full eligibility traces with parameter-dimension complexity.\n",
    "- **`ES_D_RTRL`** (`IODimVjpAlgorithm`): Factorized eligibility traces with input/output-dimension complexity.\n",
    "\n",
    "For research purposes, you may want to implement custom online learning algorithms — for example, adding trace clipping, spectral normalization of Jacobians, or entirely different trace update rules. braintrace's algorithm hierarchy makes this straightforward.\n",
    "\n",
    "**The algorithm class hierarchy:**\n",
    "\n",
    "```\n",
    "ETraceAlgorithm              # Base: model wrapping, graph compilation, state separation\n",
    "  └─ ETraceVjpAlgorithm       # Adds VJP-based gradient computation, implements update()\n",
    "       ├─ ParamDimVjpAlgorithm # D-RTRL: traces stored per weight parameter\n",
    "       └─ IODimVjpAlgorithm    # ES-D-RTRL: traces factorized into input/output components\n",
    "```\n",
    "\n",
    "All algorithms share the same graph compilation infrastructure: `ModuleInfo`, `HiddenGroup`, `HiddenPerturbation`, and `ETraceGraph`. You only need to customize the trace update and gradient computation logic."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95c75af3",
   "metadata": {},
   "source": [
    "## Algorithm Architecture\n",
    "\n",
    "### `ETraceAlgorithm` (Base Class)\n",
    "\n",
    "The root of the hierarchy. It handles:\n",
    "\n",
    "- **Model wrapping**: Stores the target `brainstate.nn.Module` and its graph executor.\n",
    "- **Graph compilation**: Calls `compile_graph(*args)` to build the eligibility trace computation graph.\n",
    "- **State separation**: Splits model states into `param_states` (weights), `hidden_states` (recurrent states), and `other_states`.\n",
    "- **Running index**: Tracks the current time step via `self.running_index`.\n",
    "\n",
    "Key abstract methods: `init_etrace_state()`, `update()`, `get_etrace_of()`.\n",
    "\n",
    "### `ETraceVjpAlgorithm`\n",
    "\n",
    "Extends `ETraceAlgorithm` with VJP-based (reverse-mode) gradient computation. It:\n",
    "\n",
    "- Defines a `custom_vjp` function that wraps the forward pass and eligibility trace update.\n",
    "- Implements the `update()` method, which extracts state values, calls the forward+trace update, and writes states back.\n",
    "- Provides the backward pass (`_update_fn_bwd`) that computes weight gradients from eligibility traces and loss gradients.\n",
    "\n",
    "Key protocol methods (to be overridden by subclasses):\n",
    "- `_update_etrace_data()`: Core trace update logic.\n",
    "- `_solve_weight_gradients()`: Compute final weight gradients from traces and loss gradients.\n",
    "- `_get_etrace_data()`: Retrieve current trace values from states.\n",
    "- `_assign_etrace_data()`: Write trace values back to states.\n",
    "\n",
    "### `ParamDimVjpAlgorithm` (D-RTRL)\n",
    "\n",
    "Stores one eligibility trace tensor per (weight, hidden group) pair. Each trace has the shape of the weight parameter, extended by one axis that indexes the hidden states in the group. Memory complexity: $O(B\\theta)$, where $\\theta$ is the parameter count and $B$ the batch size.\n",
    "\n",
    "### `IODimVjpAlgorithm` (ES-D-RTRL)\n",
    "\n",
    "Factorizes the eligibility trace into separate input traces ($\\epsilon_x$) and output/transition traces ($\\epsilon_f$). This reduces memory to $O(B(I+O))$, where $I$ and $O$ are the input and output dimensions. The smoothing of both traces is controlled by the `decay_or_rank` parameter."
   ]
  },
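  {
   "cell_type": "markdown",
   "id": "c3a1f0e1",
   "metadata": {},
   "source": [
    "To make the two memory complexities concrete, the sketch below counts trace elements for a single dense weight. The helper names and the accounting (one weight-shaped trace versus one input trace plus one output trace) are illustrative assumptions rather than braintrace APIs; the sizes assume a weight that maps 42 inputs to 32 outputs.\n",
    "\n",
    "```python\n",
    "def param_dim_trace_size(n_in, n_out, n_hidden_states=1, batch=1):\n",
    "    # One weight-shaped trace per hidden state: O(B * theta).\n",
    "    return batch * n_in * n_out * n_hidden_states\n",
    "\n",
    "\n",
    "def io_dim_trace_size(n_in, n_out, n_hidden_states=1, batch=1):\n",
    "    # One input trace plus one output trace per hidden state: O(B * (I + O)).\n",
    "    return batch * (n_in + n_out * n_hidden_states)\n",
    "\n",
    "\n",
    "n_in, n_out = 42, 32  # assumed sizes, chosen only for illustration\n",
    "full = param_dim_trace_size(n_in, n_out)\n",
    "factored = io_dim_trace_size(n_in, n_out)\n",
    "print(full, factored)\n",
    "```\n",
    "\n",
    "For these sizes the full trace holds 1344 values and the factorized traces hold 74, and the gap widens as the layer grows."
   ]
  },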
  {
   "cell_type": "markdown",
   "id": "7ab0aa9a",
   "metadata": {},
   "source": [
    "## D-RTRL Mathematical Foundation\n",
    "\n",
    "The D-RTRL algorithm maintains an eligibility trace $\\epsilon^t$ that approximates the full Jacobian $\\partial h^t / \\partial \\theta$ using a diagonal approximation of the hidden-to-hidden Jacobian:\n",
    "\n",
    "$$\n",
    "\\epsilon^t \\approx D^t \\, \\epsilon^{t-1} + \\text{diag}(D_f^t) \\otimes x^t\n",
    "$$\n",
    "\n",
    "where:\n",
    "\n",
    "- $D^t = \\text{diag}(\\partial h^t / \\partial h^{t-1})$: the hidden-to-hidden Jacobian (diagonal approximation)\n",
    "- $D_f^t = \\partial h^t / \\partial y^t$: the transition Jacobian, where $y^t$ is the output of the weight operation\n",
    "- $x^t$: the input to the weight operation at time $t$\n",
    "- $\\otimes$: the outer product\n",
    "\n",
    "The weight gradient is then computed by combining the eligibility traces with the loss gradient:\n",
    "\n",
    "$$\n",
    "\\nabla_\\theta \\mathcal{L} = \\sum_{t' \\in \\mathcal{T}} \\frac{\\partial \\mathcal{L}^{t'}}{\\partial h^{t'}} \\circ \\epsilon^{t'}\n",
    "$$\n",
    "\n",
    "where $\\circ$ denotes the contraction over hidden dimensions that produces the weight-shaped gradient."
   ]
  },
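  {
   "cell_type": "markdown",
   "id": "c3a1f0e2",
   "metadata": {},
   "source": [
    "The recurrence and the gradient contraction above can be checked numerically on a toy model. The sketch below uses plain NumPy instead of braintrace and assumes a leaky linear unit $h^t = a \\odot h^{t-1} + x^t W$, for which the hidden-to-hidden Jacobian is exactly diagonal, so the trace should reproduce the true gradient. All names (`run`, `eps`, `a`) are illustrative, and the loss is taken only at the final step.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_hid, T = 3, 4, 6\n",
    "W = rng.normal(size=(n_in, n_hid))\n",
    "a = rng.uniform(0.5, 0.9, size=n_hid)  # per-unit leak, so dh^t/dh^{t-1} = diag(a)\n",
    "xs = rng.normal(size=(T, n_in))\n",
    "\n",
    "\n",
    "def run(weight):\n",
    "    # Toy model h^t = a * h^{t-1} + x^t @ weight; returns the final hidden state.\n",
    "    h = np.zeros(n_hid)\n",
    "    for x in xs:\n",
    "        h = a * h + x @ weight\n",
    "    return h\n",
    "\n",
    "\n",
    "# Online pass: carry the trace forward together with the inputs.\n",
    "eps = np.zeros((n_in, n_hid))  # same shape as W\n",
    "for x in xs:\n",
    "    D = a                      # diagonal of dh^t/dh^{t-1}\n",
    "    D_f = np.ones(n_hid)       # diagonal of dh^t/dy^t with y^t = x^t @ W\n",
    "    eps = D[None, :] * eps + np.outer(x, D_f)\n",
    "\n",
    "# Loss L = sum(h^T) at the final step, so dL/dh^T is a vector of ones.\n",
    "dl_dh = np.ones(n_hid)\n",
    "grad_online = eps * dl_dh[None, :]\n",
    "\n",
    "# Reference: central finite differences of the same loss.\n",
    "grad_fd = np.zeros_like(W)\n",
    "step = 1e-6\n",
    "for i in range(n_in):\n",
    "    for j in range(n_hid):\n",
    "        dW = np.zeros_like(W)\n",
    "        dW[i, j] = step\n",
    "        grad_fd[i, j] = (run(W + dW).sum() - run(W - dW).sum()) / (2 * step)\n",
    "\n",
    "print(np.max(np.abs(grad_online - grad_fd)))\n",
    "```\n",
    "\n",
    "The printed discrepancy should be at the level of floating-point noise. With a non-diagonal recurrence the same trace would only approximate the gradient."
   ]
  },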
  {
   "cell_type": "markdown",
   "id": "55cabf64",
   "metadata": {},
   "source": [
    "## ES-D-RTRL Mathematical Foundation\n",
    "\n",
    "ES-D-RTRL further approximates the D-RTRL trace by factorizing it into separate input and output components:\n",
    "\n",
    "$$\n",
    "\\epsilon^t \\approx \\epsilon_f^t \\otimes \\epsilon_x^t\n",
    "$$\n",
    "\n",
    "The two components are updated with exponential smoothing controlled by a decay factor $\\alpha$:\n",
    "\n",
    "**Input trace:**\n",
    "$$\n",
    "\\epsilon_x^t = \\alpha \\, \\epsilon_x^{t-1} + x^t\n",
    "$$\n",
    "\n",
    "**Output trace:**\n",
    "$$\n",
    "\\epsilon_f^t = \\alpha \\, D^t \\circ \\epsilon_f^{t-1} + (1 - \\alpha) \\, D_f^t\n",
    "$$\n",
    "\n",
    "where:\n",
    "\n",
    "- $\\alpha \\in (0, 1)$: exponential smoothing decay factor\n",
    "- $D^t$: hidden-to-hidden diagonal Jacobian (same as in D-RTRL)\n",
    "- $D_f^t$: transition Jacobian\n",
    "- $x^t$: input to the weight operation\n",
    "- $\\circ$: element-wise (Hadamard) product in the output-trace update\n",
    "\n",
    "The decay factor $\\alpha$ is controlled by the `decay_or_rank` parameter:\n",
    "- If a `float` in $(0, 1)$: used directly as $\\alpha$.\n",
    "- If an `int` $> 0$ (the approximation rank): $\\alpha = (\\text{rank} - 1) / (\\text{rank} + 1)$.\n",
    "\n",
    "The weight gradient formula is the same as D-RTRL, but uses the factorized traces."
   ]
  },
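  {
   "cell_type": "markdown",
   "id": "c3a1f0e3",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the two recurrences and of the `decay_or_rank` rule follows. `decay_from` and `es_step` are hypothetical helpers written for this tutorial, not braintrace functions, and the Jacobians are random stand-ins for the values the graph executor would supply.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def decay_from(decay_or_rank):\n",
    "    # Map decay_or_rank to alpha using the rule stated above.\n",
    "    if isinstance(decay_or_rank, int) and not isinstance(decay_or_rank, bool):\n",
    "        if decay_or_rank <= 0:\n",
    "            raise ValueError('rank must be a positive integer')\n",
    "        return (decay_or_rank - 1) / (decay_or_rank + 1)\n",
    "    if isinstance(decay_or_rank, float) and 0.0 < decay_or_rank < 1.0:\n",
    "        return decay_or_rank\n",
    "    raise ValueError('expected a float in (0, 1) or a positive int')\n",
    "\n",
    "\n",
    "def es_step(eps_x, eps_f, x, D, D_f, alpha):\n",
    "    # One step of the input-trace and output-trace recurrences.\n",
    "    eps_x = alpha * eps_x + x\n",
    "    eps_f = alpha * D * eps_f + (1.0 - alpha) * D_f\n",
    "    return eps_x, eps_f\n",
    "\n",
    "\n",
    "alpha = decay_from(19)  # (19 - 1) / (19 + 1) = 0.9\n",
    "n_in, n_out = 3, 4\n",
    "eps_x, eps_f = np.zeros(n_in), np.zeros(n_out)\n",
    "rng = np.random.default_rng(0)\n",
    "for _ in range(5):\n",
    "    x = rng.normal(size=n_in)\n",
    "    D = rng.uniform(0.5, 0.9, size=n_out)  # stand-in diagonal of dh^t/dh^{t-1}\n",
    "    D_f = np.ones(n_out)                   # stand-in diagonal of dh^t/dy^t\n",
    "    eps_x, eps_f = es_step(eps_x, eps_f, x, D, D_f, alpha)\n",
    "\n",
    "# The weight-shaped trace is only formed on demand, here laid out as (in, out).\n",
    "approx_trace = np.outer(eps_x, eps_f)\n",
    "print(alpha, eps_x.shape, eps_f.shape, approx_trace.shape)\n",
    "```\n",
    "\n",
    "Only `eps_x` and `eps_f` are carried between steps, which is where the $O(B(I+O))$ memory cost comes from."
   ]
  },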
  {
   "cell_type": "markdown",
   "id": "f54aa236",
   "metadata": {},
   "source": [
    "## Key Methods to Override\n",
    "\n",
    "When implementing a custom algorithm, you typically subclass `ParamDimVjpAlgorithm` or `IODimVjpAlgorithm` and override one or more of these methods:\n",
    "\n",
    "| Method | Purpose | Defined in |\n",
    "|--------|---------|------------|\n",
    "| `init_etrace_state(*args)` | Initialize trace storage (called during `compile_graph`) | `ETraceAlgorithm` |\n",
    "| `_get_etrace_data()` | Retrieve current trace values from internal states | `ETraceVjpAlgorithm` |\n",
    "| `_assign_etrace_data(vals)` | Write trace values back to internal states | `ETraceVjpAlgorithm` |\n",
    "| `_update_etrace_data(...)` | Core trace update logic (the trace recurrence equation) | `ETraceVjpAlgorithm` |\n",
    "| `_solve_weight_gradients(...)` | Compute final weight gradients from traces + loss gradients | `ETraceVjpAlgorithm` |\n",
    "| `reset_state(batch_size=None)` | Reset traces between epochs/episodes | `ParamDimVjpAlgorithm` / `IODimVjpAlgorithm` |\n",
    "\n",
    "### Method signatures\n",
    "\n",
    "```python\n",
    "def _update_etrace_data(\n",
    "    self,\n",
    "    running_index,            # int: current time step\n",
    "    etrace_vals_util_t_1,     # ETraceVals: traces accumulated until t-1\n",
    "    hid2weight_jac,           # Hid2WeightJacobian: current Jacobians\n",
    "    hid2hid_jac,              # Sequence[jax.Array]: hidden-to-hidden Jacobians\n",
    "    weight_vals,              # Dict[Path, PyTree]: current weight values\n",
    "    input_is_multi_step,      # bool: whether input spans multiple steps\n",
    ") -> ETraceVals:\n",
    "    ...\n",
    "\n",
    "def _solve_weight_gradients(\n",
    "    self,\n",
    "    running_index,            # int: current time step\n",
    "    etrace_h2w_at_t,          # eligibility trace data at time t\n",
    "    dl_to_hidden_groups,      # Sequence[jax.Array]: dL/dh per hidden group\n",
    "    weight_vals,              # Dict[WeightID, PyTree]: current weight values\n",
    "    dl_to_nonetws_at_t,       # gradients of non-etrace parameters\n",
    "    dl_to_etws_at_t,          # optional gradients of etrace parameters\n",
    ") -> Dict[Path, PyTree]:\n",
    "    ...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5080907a",
   "metadata": {},
   "source": [
    "## Example: Implementing a Clipped D-RTRL\n",
    "\n",
    "A common issue with eligibility traces in deep recurrent networks is trace explosion: trace magnitudes can grow without bound over time, for example when entries of the hidden-to-hidden Jacobian stay above 1 in magnitude. One practical mitigation is to clip the trace values after each update.\n",
    "\n",
    "Below we implement `ClippedDRTRL`, which inherits from `ParamDimVjpAlgorithm` and applies element-wise clipping to the updated traces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1d7b6e30",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:18.687036Z",
     "iopub.status.busy": "2026-04-17T09:26:18.686778Z",
     "iopub.status.idle": "2026-04-17T09:26:18.305661Z",
     "shell.execute_reply": "2026-04-17T09:26:18.304851Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace\n",
    "from braintrace._etrace_algorithms.d_rtrl import ParamDimVjpAlgorithm\n",
    "\n",
    "\n",
    "class ClippedDRTRL(ParamDimVjpAlgorithm):\n",
    "    \"\"\"D-RTRL with trace clipping for stability.\n",
    "\n",
    "    After each trace update, all trace values are clipped to\n",
    "    [-clip_value, +clip_value] to prevent trace explosion in\n",
    "    deep or long-horizon recurrent networks.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, clip_value=1.0, **kwargs):\n",
    "        super().__init__(model, **kwargs)\n",
    "        self.clip_value = clip_value\n",
    "\n",
    "    def _update_etrace_data(\n",
    "        self,\n",
    "        running_index,\n",
    "        hist_etrace_vals,\n",
    "        hid2weight_jac,\n",
    "        hid2hid_jac,\n",
    "        weight_vals,\n",
    "        input_is_multi_step,\n",
    "    ):\n",
    "        # Call parent's trace update (standard D-RTRL recurrence)\n",
    "        new_traces = super()._update_etrace_data(\n",
    "            running_index,\n",
    "            hist_etrace_vals,\n",
    "            hid2weight_jac,\n",
    "            hid2hid_jac,\n",
    "            weight_vals,\n",
    "            input_is_multi_step,\n",
    "        )\n",
    "        # Clip the traces element-wise\n",
    "        new_traces = jax.tree.map(\n",
    "            lambda t: jnp.clip(t, -self.clip_value, self.clip_value),\n",
    "            new_traces,\n",
    "        )\n",
    "        return new_traces"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01bab22c",
   "metadata": {},
   "source": [
    "Now we can use `ClippedDRTRL` exactly like the built-in `D_RTRL`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dfb10cee",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:18.308154Z",
     "iopub.status.busy": "2026-04-17T09:26:18.307804Z",
     "iopub.status.idle": "2026-04-17T09:26:18.741655Z",
     "shell.execute_reply": "2026-04-17T09:26:18.740751Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClippedDRTRL compiled successfully.\n",
      "Number of param states: 1\n",
      "Number of hidden states: 1\n"
     ]
    }
   ],
   "source": [
    "# Create a simple RNN model\n",
    "model = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)\n",
    "brainstate.nn.init_all_states(model)\n",
    "\n",
    "# Instantiate our custom algorithm with clip_value=5.0\n",
    "algo = ClippedDRTRL(model, clip_value=5.0)\n",
    "\n",
    "# Compile the computation graph with a dummy input\n",
    "algo.compile_graph(jnp.zeros(10))\n",
    "\n",
    "print(\"ClippedDRTRL compiled successfully.\")\n",
    "print(f\"Number of param states: {len(algo.param_states)}\")\n",
    "print(f\"Number of hidden states: {len(algo.hidden_states)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "16bf14fe",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:18.743542Z",
     "iopub.status.busy": "2026-04-17T09:26:18.743337Z",
     "iopub.status.idle": "2026-04-17T09:26:20.747525Z",
     "shell.execute_reply": "2026-04-17T09:26:20.746733Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0: output shape = (32,)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1: output shape = (32,)\n",
      "Step 2: output shape = (32,)\n",
      "Step 3: output shape = (32,)\n",
      "Step 4: output shape = (32,)\n"
     ]
    }
   ],
   "source": [
    "# Run a few forward steps\n",
    "for step in range(5):\n",
    "    out = algo(jnp.ones(10))\n",
    "    print(f\"Step {step}: output shape = {out.shape}\")"
   ]
  },
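  {
   "cell_type": "markdown",
   "id": "c3a1f0e4",
   "metadata": {},
   "source": [
    "The effect of clipping can be seen on a scalar toy recurrence. This is a plain-Python illustration, not the braintrace implementation: `run_trace` is a hypothetical helper, and the constant Jacobian of 1.1 is chosen only to force growth.\n",
    "\n",
    "```python\n",
    "def run_trace(jacobian, inputs, clip_value=None):\n",
    "    # Iterate eps^t = jacobian * eps^{t-1} + x^t, optionally clipping each step.\n",
    "    eps = 0.0\n",
    "    history = []\n",
    "    for x in inputs:\n",
    "        eps = jacobian * eps + x\n",
    "        if clip_value is not None:\n",
    "            eps = max(-clip_value, min(clip_value, eps))\n",
    "        history.append(eps)\n",
    "    return history\n",
    "\n",
    "\n",
    "inputs = [1.0] * 50\n",
    "unclipped = run_trace(1.1, inputs)                # |D| > 1: geometric growth\n",
    "clipped = run_trace(1.1, inputs, clip_value=5.0)  # bounded by clip_value\n",
    "\n",
    "print(f'unclipped final: {unclipped[-1]:.1f}')\n",
    "print(f'clipped final:   {clipped[-1]:.1f}')\n",
    "```\n",
    "\n",
    "Clipping bounds the trace but also biases it, so `clip_value` is worth treating as a tuned hyperparameter."
   ]
  },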
  {
   "cell_type": "markdown",
   "id": "263909d1",
   "metadata": {},
   "source": [
    "## Understanding the Update Flow\n",
    "\n",
    "At each time step, the algorithm performs the following sequence:\n",
    "\n",
    "1. **Forward pass**: The graph executor runs the model to produce the output, new hidden states, and new other states.\n",
    "\n",
    "2. **Jacobian computation**: The executor also computes:\n",
    "   - **h2w Jacobians**: the transition Jacobian $\\partial h^t / \\partial y^t$ and the input $x^t$ of the weight operation, which together determine how the hidden state depends on the weight.\n",
    "   - **h2h Jacobians**: $\\partial h^t / \\partial h^{t-1}$ (the recurrent Jacobian, diagonal approximation).\n",
    "\n",
    "3. **Trace update**: `_update_etrace_data()` uses the Jacobians and the previous traces to compute the new eligibility traces via the chosen recurrence equation.\n",
    "\n",
    "4. **Gradient computation (on backward pass)**: When `jax.grad` or `brainstate.transform.grad` is applied, the `custom_vjp` backward pass calls `_solve_weight_gradients()`, which combines the traces with the loss-to-hidden gradient $\\partial \\mathcal{L} / \\partial h^t$ to produce parameter gradients.\n",
    "\n",
    "The key insight is that `ETraceVjpAlgorithm.update()` wraps the forward pass + trace update inside a `jax.custom_vjp` function. This means:\n",
    "- **Forward**: The model runs normally, and traces are updated as a side effect.\n",
    "- **Backward**: Instead of backpropagating through the entire recurrent history (as BPTT would), the custom VJP uses the eligibility traces to directly produce weight gradients at the current time step.\n",
    "\n",
    "This is what makes online learning possible: gradients are computed on the fly without storing the full computation history.\n",
    "\n",
    "```\n",
    "                 forward pass\n",
    "               ┌──────────────────────────────────────────┐\n",
    "  x^t     ───▶ │                                          │\n",
    "  h^{t-1} ───▶ │  graph_executor.solve_h2w_h2h_jacobian() │\n",
    "  weights ───▶ │  ──▶ output, h^t, h2w_jac, h2h_jac       │\n",
    "               └────────────────────┬─────────────────────┘\n",
    "                                    │\n",
    "                                    ▼\n",
    "                 trace update\n",
    "               ┌──────────────────────────────────────────┐\n",
    "  ε^{t-1} ───▶ │                                          │\n",
    "  h2w_jac ───▶ │  _update_etrace_data()  ──▶ ε^t          │\n",
    "  h2h_jac ───▶ │                                          │\n",
    "               └────────────────────┬─────────────────────┘\n",
    "                                    │\n",
    "                 (on backward)      ▼\n",
    "               ┌──────────────────────────────────────────┐\n",
    "  dL/dh^t ───▶ │                                          │\n",
    "  ε^t     ───▶ │  _solve_weight_gradients()  ──▶ dL/dθ    │\n",
    "               └──────────────────────────────────────────┘\n",
    "```"
   ]
  },
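  {
   "cell_type": "markdown",
   "id": "c3a1f0e5",
   "metadata": {},
   "source": [
    "The forward/backward split described above can be reproduced in miniature with `jax.custom_vjp`. The sketch below is a toy under stated assumptions, not braintrace's actual `update()`: `online_step` is a hypothetical function, the trace is passed in as a fixed array instead of being updated, and the backward rule simply multiplies that trace by the incoming loss gradient.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "@jax.custom_vjp\n",
    "def online_step(w, x, trace):\n",
    "    # Forward computation of a toy layer; the trace does not affect the output.\n",
    "    return jnp.tanh(x @ w)\n",
    "\n",
    "\n",
    "def online_step_fwd(w, x, trace):\n",
    "    # Save what the backward rule needs: the input and the trace.\n",
    "    return online_step(w, x, trace), (x, trace)\n",
    "\n",
    "\n",
    "def online_step_bwd(residuals, dl_dh):\n",
    "    x, trace = residuals\n",
    "    # The weight gradient comes from the trace, not from differentiating tanh(x @ w).\n",
    "    dl_dw = trace * dl_dh[None, :]\n",
    "    # No gradient flows to the input or to the trace in this sketch.\n",
    "    return dl_dw, jnp.zeros_like(x), jnp.zeros_like(trace)\n",
    "\n",
    "\n",
    "online_step.defvjp(online_step_fwd, online_step_bwd)\n",
    "\n",
    "w = jnp.ones((3, 2)) * 0.1\n",
    "x = jnp.array([1.0, 2.0, 3.0])\n",
    "trace = jnp.arange(6.0).reshape(3, 2)  # stand-in for an eligibility trace\n",
    "\n",
    "grad_w = jax.grad(lambda w_: online_step(w_, x, trace).sum())(w)\n",
    "print(grad_w)\n",
    "```\n",
    "\n",
    "Here `jax.grad` returns the trace itself, because the loss gradient with respect to the output is a vector of ones; the derivative of `tanh` never enters."
   ]
  },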
  {
   "cell_type": "markdown",
   "id": "12491994",
   "metadata": {},
   "source": [
    "## Accessing Eligibility Traces\n",
    "\n",
    "After running the algorithm for a few steps, you can inspect the eligibility traces for any weight parameter. This is useful for debugging and analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b733466c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:20.749823Z",
     "iopub.status.busy": "2026-04-17T09:26:20.749574Z",
     "iopub.status.idle": "2026-04-17T09:26:21.028615Z",
     "shell.execute_reply": "2026-04-17T09:26:21.027685Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight ('W', 'weight'), trace key: (('W', 'weight'), 130753575862208, 'hidden_group_0')\n",
      "  trace shape: (42, 32, 1)\n",
      "  trace abs max: 5.605252\n"
     ]
    }
   ],
   "source": [
    "# Create a fresh model and algorithm\n",
    "model2 = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)\n",
    "brainstate.nn.init_all_states(model2)\n",
    "\n",
    "algo2 = braintrace.D_RTRL(model2)\n",
    "algo2.compile_graph(jnp.zeros(10))\n",
    "\n",
    "# Run a few steps to build up traces\n",
    "for _ in range(5):\n",
    "    algo2(jnp.ones(10))\n",
    "\n",
    "# Inspect eligibility traces for each weight parameter\n",
    "weights = model2.states(brainstate.ParamState)\n",
    "for path, w in weights.items():\n",
    "    traces = algo2.get_etrace_of(w)\n",
    "    for key, trace_val in traces.items():\n",
    "        print(f\"Weight {path}, trace key: {key}\")\n",
    "        print(f\"  trace shape: {trace_val.shape}\")\n",
    "        print(f\"  trace abs max: {jnp.abs(trace_val).max():.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c9e392c",
   "metadata": {},
   "source": [
    "## Customising the Gradient Solve\n",
    "\n",
    "The `_update_etrace_data()` override above shapes how traces are *built up*. To shape how the final weight gradients are *read out* of those traces, override `_solve_weight_gradients()`. A common use is global gradient norm clipping so that one outlier weight cannot blow up the optimizer step.\n",
    "\n",
    "The override pattern is the same as before -- delegate to `super()._solve_weight_gradients(...)` for the standard contraction, then transform the dict it returns. The signature is:\n",
    "\n",
    "```python\n",
    "def _solve_weight_gradients(\n",
    "    self,\n",
    "    running_index: int,\n",
    "    etrace_h2w_at_t,                # dict[(weight_id, hidden_id), pytree]\n",
    "    dl_to_hidden_groups,            # Sequence[jax.Array]: dL/dh per group\n",
    "    weight_vals,                    # dict[weight_id, pytree]: current weights\n",
    "    dl_to_nonetws_at_t,             # dict[path, pytree]: non-ETP grads\n",
    "    dl_to_etws_at_t,                # Optional[dict[path, pytree]]: ETP shortcut grads\n",
    ") -> dict[Path, PyTree]:\n",
    "    ...\n",
    "```\n",
    "\n",
    "The returned dict maps every parameter path (ETP and non-ETP alike) to its gradient pytree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d4fec171",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:21.030999Z",
     "iopub.status.busy": "2026-04-17T09:26:21.030786Z",
     "iopub.status.idle": "2026-04-17T09:26:21.042842Z",
     "shell.execute_reply": "2026-04-17T09:26:21.041444Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GradClippedDRTRL compiled OK; max_norm = 2.0\n"
     ]
    }
   ],
   "source": [
    "class GradClippedDRTRL(braintrace.D_RTRL):\n",
    "    \"\"\"D-RTRL that clips the global gradient norm after solving.\"\"\"\n",
    "\n",
    "    def __init__(self, model, max_norm: float = 1.0, **kwargs):\n",
    "        super().__init__(model, **kwargs)\n",
    "        self.max_norm = max_norm\n",
    "\n",
    "    def _solve_weight_gradients(\n",
    "        self,\n",
    "        running_index,\n",
    "        etrace_h2w_at_t,\n",
    "        dl_to_hidden_groups,\n",
    "        weight_vals,\n",
    "        dl_to_nonetws_at_t,\n",
    "        dl_to_etws_at_t,\n",
    "    ):\n",
    "        # Standard contraction: traces x dL/dh -> per-weight gradients.\n",
    "        grads = super()._solve_weight_gradients(\n",
    "            running_index,\n",
    "            etrace_h2w_at_t,\n",
    "            dl_to_hidden_groups,\n",
    "            weight_vals,\n",
    "            dl_to_nonetws_at_t,\n",
    "            dl_to_etws_at_t,\n",
    "        )\n",
    "        # Compute global L2 norm across all leaves of the gradient dict.\n",
    "        sq = jax.tree.reduce(\n",
    "            lambda acc, g: acc + jnp.sum(g * g),\n",
    "            grads,\n",
    "            initializer=jnp.zeros((), dtype=jnp.float32),\n",
    "        )\n",
    "        norm = jnp.sqrt(sq)\n",
    "        scale = jnp.minimum(1.0, self.max_norm / (norm + 1e-12))\n",
    "        return jax.tree.map(lambda g: g * scale, grads)\n",
    "\n",
    "\n",
    "# Smoke test: check that the new algorithm instantiates and compiles.\n",
    "clipped_model = braintrace.nn.ValinaRNNCell(in_size=10, out_size=32)\n",
    "brainstate.nn.init_all_states(clipped_model)\n",
    "clipped_algo = GradClippedDRTRL(clipped_model, max_norm=2.0)\n",
    "clipped_algo.compile_graph(jnp.zeros(10))\n",
    "print(\"GradClippedDRTRL compiled OK; max_norm =\", clipped_algo.max_norm)"
   ]
  },
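  {
   "cell_type": "markdown",
   "id": "c3a1f0e6",
   "metadata": {},
   "source": [
    "The scaling rule used in `GradClippedDRTRL` can be checked by hand on a small gradient dict. The sketch below uses NumPy and a flat dict with made-up keys `'W'` and `'b'`; `clip_by_global_norm` is a hypothetical standalone helper that mirrors the arithmetic of the override, not a braintrace function.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def clip_by_global_norm(grads, max_norm, eps=1e-12):\n",
    "    # Rescale a flat dict of arrays so that its joint L2 norm is at most max_norm.\n",
    "    norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads.values()))\n",
    "    scale = min(1.0, max_norm / (norm + eps))\n",
    "    return {name: g * scale for name, g in grads.items()}, norm\n",
    "\n",
    "\n",
    "grads = {'W': np.array([3.0, 0.0]), 'b': np.array([4.0])}  # joint norm = 5\n",
    "clipped, norm = clip_by_global_norm(grads, max_norm=2.0)\n",
    "new_norm = np.sqrt(sum(float(np.sum(g * g)) for g in clipped.values()))\n",
    "print(norm, new_norm)\n",
    "\n",
    "# Gradients already inside the budget are returned unchanged.\n",
    "small, _ = clip_by_global_norm(grads, max_norm=10.0)\n",
    "```\n",
    "\n",
    "Because every leaf is multiplied by the same factor, the direction of the overall gradient is preserved and only its length changes."
   ]
  },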
  {
   "cell_type": "markdown",
   "id": "aa7cc82e",
   "metadata": {},
   "source": [
    "## Resetting Trace State Between Episodes\n",
    "\n",
    "Eligibility traces are stored on the algorithm as `EligibilityTrace` instances (a thin subclass of `brainstate.ShortTermState`). Between epochs, sequences, or evaluation runs you typically want to zero them so the next sequence starts from a clean state. Both `D_RTRL` and `ES_D_RTRL` expose `reset_state(batch_size=None)` for this.\n",
    "\n",
    "`reset_state` does two things: it resets the algorithm's `running_index` counter to 0, and it zeros every `EligibilityTrace` (re-broadcasting to the requested `batch_size` if given). It does **not** touch the model's hidden states -- call `brainstate.nn.reset_all_states(model)` for those. Override `reset_state` only when your custom algorithm carries extra state (e.g. a momentum accumulator) that must also be cleared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "09573e15",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:26:21.046189Z",
     "iopub.status.busy": "2026-04-17T09:26:21.045921Z",
     "iopub.status.idle": "2026-04-17T09:26:21.204357Z",
     "shell.execute_reply": "2026-04-17T09:26:21.203518Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trace abs-max before reset: 2.913480\n",
      "Trace abs-max after  reset: 0.000000\n",
      "running_index after  reset: 0\n"
     ]
    }
   ],
   "source": [
    "# Run a few steps so the traces are non-zero, reset, then verify.\n",
    "for _ in range(3):\n",
    "    clipped_algo(jnp.ones(10))\n",
    "\n",
    "# Show that traces are non-zero before reset.\n",
    "weights = clipped_model.states(brainstate.ParamState)\n",
    "sample_weight = next(iter(weights.values()))\n",
    "sample_trace = next(iter(clipped_algo.get_etrace_of(sample_weight).values()))\n",
    "print(f\"Trace abs-max before reset: {jnp.abs(sample_trace).max():.6f}\")\n",
    "\n",
    "# Reset and verify the traces are now zero (and running_index == 0).\n",
    "clipped_algo.reset_state()\n",
    "sample_trace = next(iter(clipped_algo.get_etrace_of(sample_weight).values()))\n",
    "print(f\"Trace abs-max after  reset: {jnp.abs(sample_trace).max():.6f}\")\n",
    "print(f\"running_index after  reset: {int(clipped_algo.running_index.value)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e700441f",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This tutorial covered the architecture and extension points for developing custom online learning algorithms in braintrace.\n",
    "\n",
    "**Key takeaways:**\n",
    "\n",
    "- **Extend `ParamDimVjpAlgorithm`** for custom algorithms that maintain full-dimensional traces (like D-RTRL). Extend **`IODimVjpAlgorithm`** for custom algorithms that use factorized traces (like ES-D-RTRL).\n",
    "\n",
    "- **Override `_update_etrace_data()`** to implement custom trace dynamics (e.g., clipping, normalization, decay schedules).\n",
    "\n",
    "- **Override `_solve_weight_gradients()`** to transform the gradient dict produced by the standard contraction -- e.g. global gradient-norm clipping (`GradClippedDRTRL` above), per-layer scaling, or momentum.\n",
    "\n",
    "- **The graph compilation infrastructure is shared** across all algorithms. You do not need to re-implement the model tracing, Jacobian computation, or state management — only the trace update and gradient computation.\n",
    "\n",
    "- **`reset_state(batch_size=None)`** zeros every `EligibilityTrace` and resets `running_index` to 0; override it only when your algorithm carries extra state that must also be cleared. Override **`init_etrace_state()`** if your algorithm needs new trace-storage shapes at compile time.\n",
    "\n",
    "- Use **`get_etrace_of(weight)`** to inspect trace values at any point during training, which is valuable for debugging and research analysis."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
