{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60001",
   "metadata": {},
   "source": [
    "# RNN Online Learning with BrainTrace\n",
    "\n",
    "**Train a GRU network on the copying task using D-RTRL**\n",
    "\n",
    "This quickstart tutorial demonstrates how to train a Gated Recurrent Unit (GRU) network using online learning with `braintrace`. We will:\n",
    "\n",
    "1. Define the **copying task**, a standard benchmark for testing sequential memory in RNNs.\n",
    "2. Build a GRU model using `braintrace.nn` components.\n",
    "3. Train the model with **D-RTRL** (Decoupled Real-Time Recurrent Learning), an online learning algorithm that computes approximate gradients without storing the full computation graph.\n",
    "4. Compare the online learning approach with standard **Backpropagation Through Time (BPTT)**.\n",
    "\n",
    "Online learning is especially useful when:\n",
    "- Memory is limited and storing the full unrolled computation graph is prohibitive.\n",
    "- You need to update parameters on-the-fly as data arrives.\n",
    "- You want biologically plausible learning rules for recurrent networks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60002",
   "metadata": {},
   "source": [
    "## 1. Setup\n",
    "\n",
    "First, we import the required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1b2c3d4e5f60003",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:51.589484Z",
     "iopub.status.busy": "2026-04-17T09:24:51.589226Z",
     "iopub.status.idle": "2026-04-17T09:24:53.703161Z",
     "shell.execute_reply": "2026-04-17T09:24:53.702000Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintools\n",
    "import braintrace\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60004",
   "metadata": {},
   "source": [
    "## 2. The Copying Task\n",
    "\n",
    "The copying task is a classic benchmark for evaluating whether an RNN can memorize and recall information over a delay period.\n",
    "\n",
    "**How it works:**\n",
    "\n",
    "1. The model receives a sequence of 10 random digits (values 1-8) encoded as one-hot vectors.\n",
    "2. This is followed by a delay period of `time_lag` steps filled with the blank symbol (value 0, also one-hot encoded), the \"wait\" phase.\n",
    "3. A special trigger symbol (value 9), presented for 10 steps, signals the model to reproduce the original 10 digits, one per step.\n",
    "\n",
    "```\n",
    "Input:    [3 7 1 5 2 8 4 6 1 3] [0 0 ... 0 0] [9 9 9 9 9 9 9 9 9 9]\n",
    "                 memorize          wait/delay         recall trigger\n",
    "\n",
    "Target:   [3 7 1 5 2 8 4 6 1 3]\n",
    "```\n",
    "\n",
    "The longer the delay (`time_lag`), the harder the task. The model must retain information in its hidden state across the entire delay period."
   ]
  },
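  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60104",
   "metadata": {},
   "source": [
    "The sequence length implied by this layout can be worked out directly. The symbols $T$, $B$ and $t_{\\text{lag}}$ are shorthand introduced here for illustration ($t_{\\text{lag}}$ stands for `time_lag`, $B$ for the batch size):\n",
    "\n",
    "$$\n",
    "T = \\underbrace{10}_{\\text{memorize}} + \\underbrace{t_{\\text{lag}}}_{\\text{delay}} + \\underbrace{10}_{\\text{recall}}, \\qquad t_{\\text{lag}} = 50 \\;\\Rightarrow\\; T = 70\n",
    "$$\n",
    "\n",
    "With ten symbols in total (blank 0, digits 1-8, trigger 9), one input batch then has shape $(B, 70, 10)$ and the matching target has shape $(B, 10)$."
   ]
  },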
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1b2c3d4e5f60005",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:53.706971Z",
     "iopub.status.busy": "2026-04-17T09:24:53.706416Z",
     "iopub.status.idle": "2026-04-17T09:24:53.711918Z",
     "shell.execute_reply": "2026-04-17T09:24:53.710993Z"
    }
   },
   "outputs": [],
   "source": [
    "class CopyDataset:\n",
    "    \"\"\"Data generator for the copying task.\n",
    "\n",
    "    Args:\n",
    "        time_lag: Number of delay steps between memorization and recall.\n",
    "        batch_size: Number of samples per batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_lag: int, batch_size: int):\n",
    "        self.seq_length = time_lag + 20\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def __iter__(self):\n",
    "        while True:\n",
    "            ids = np.zeros([self.batch_size, self.seq_length], dtype=int)\n",
    "            # First 10 positions: random digits 1-8\n",
    "            ids[..., :10] = np.random.randint(1, 9, (self.batch_size, 10))\n",
    "            # Last 10 positions: trigger symbol (9)\n",
    "            ids[..., -10:] = np.ones([self.batch_size, 10], dtype=int) * 9\n",
    "            # One-hot encode the input sequence\n",
    "            x = np.zeros([self.batch_size, self.seq_length, 10])\n",
    "            for i in range(self.batch_size):\n",
    "                x[i, range(self.seq_length), ids[i]] = 1\n",
    "            # Target: the original 10 digits to recall\n",
    "            yield x, ids[..., :10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60006",
   "metadata": {},
   "source": [
    "## 3. Model Definition\n",
    "\n",
    "We define a GRU network using `braintrace.nn.GRUCell` for the recurrent layer and `braintrace.nn.Linear` for the output layer. These modules are designed to work with `braintrace`'s online learning algorithms: they expose the internal structure needed for eligibility trace computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1b2c3d4e5f60007",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:53.714001Z",
     "iopub.status.busy": "2026-04-17T09:24:53.713764Z",
     "iopub.status.idle": "2026-04-17T09:24:53.718119Z",
     "shell.execute_reply": "2026-04-17T09:24:53.717335Z"
    }
   },
   "outputs": [],
   "source": [
    "class GRUNet(brainstate.nn.Module):\n",
    "    \"\"\"A multi-layer GRU network with a linear readout.\n",
    "\n",
    "    Args:\n",
    "        n_in: Input feature dimension.\n",
    "        n_rec: Hidden state dimension.\n",
    "        n_out: Output dimension.\n",
    "        n_layer: Number of stacked GRU layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_in, n_rec, n_out, n_layer=1):\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        for _ in range(n_layer):\n",
    "            layers.append(braintrace.nn.GRUCell(n_in, n_rec))\n",
    "            n_in = n_rec\n",
    "        self.rnn = brainstate.nn.Sequential(*layers)\n",
    "        self.readout = braintrace.nn.Linear(n_rec, n_out)\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.readout(self.rnn(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60008",
   "metadata": {},
   "source": [
    "## 4. Online Training with D-RTRL\n",
    "\n",
    "D-RTRL (Decoupled Real-Time Recurrent Learning) is an online learning algorithm provided by `braintrace`. Unlike BPTT, which requires storing the entire computation graph across all time steps, D-RTRL computes gradients incrementally at each time step using **eligibility traces**.\n",
    "\n",
    "The key steps in the online training loop are:\n",
    "\n",
    "1. **Initialize states**: Reset the model's hidden states and compile the eligibility trace graph.\n",
    "2. **Warm-up phase**: Run the model forward without learning over the first `n_seq + 10` steps (the memorize and delay phases), so that the hidden states and eligibility traces reflect the input history before any gradient is computed.\n",
    "3. **Learning phase**: At each of the final 10 (recall) steps, compute the gradient of the current loss with respect to the parameters and add it to a running sum.\n",
    "4. **Parameter update**: After processing the full sequence, apply the accumulated gradients to update the parameters.\n",
    "\n",
    "The `D_RTRL` class wraps the model and handles the eligibility trace bookkeeping automatically."
   ]
  },
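  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60108",
   "metadata": {},
   "source": [
    "For orientation, the recursion below is the exact RTRL update that eligibility-trace methods start from. It is a sketch in generic notation: $h_t$, $x_t$, $\\theta$, $\\epsilon_t$ and $L_t$ are symbols introduced here, not `braintrace` names, and the specific approximation D-RTRL applies to keep the trace tractable is not spelled out. Assume a hidden state $h_t = f(h_{t-1}, x_t; \\theta)$ and a per-step loss $L_t$ that depends on the recurrent parameters $\\theta$ only through $h_t$:\n",
    "\n",
    "$$\n",
    "\\epsilon_t = \\frac{\\partial h_t}{\\partial h_{t-1}} \\epsilon_{t-1} + \\frac{\\partial h_t}{\\partial \\theta}, \\qquad \\epsilon_0 = 0, \\qquad \\frac{dL_t}{d\\theta} = \\frac{\\partial L_t}{\\partial h_t} \\epsilon_t\n",
    "$$\n",
    "\n",
    "Here $\\epsilon_t = dh_t / d\\theta$ is the eligibility trace and $\\partial h_t / \\partial \\theta$ is the immediate, single-step partial derivative. Because $\\epsilon_t$ is carried forward in time, no past activations need to be stored, and summing $dL_t / d\\theta$ over the recall steps gives the accumulated gradient. The exact trace has one entry per (hidden unit, parameter) pair for every sample, which is why practical algorithms approximate it."
   ]
  },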
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a1b2c3d4e5f60009",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:53.720262Z",
     "iopub.status.busy": "2026-04-17T09:24:53.720048Z",
     "iopub.status.idle": "2026-04-17T09:24:53.726750Z",
     "shell.execute_reply": "2026-04-17T09:24:53.725676Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_online(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3):\n",
    "    \"\"\"Train a GRU on the copying task using D-RTRL online learning.\n",
    "\n",
    "    Args:\n",
    "        n_epochs: Number of training iterations.\n",
    "        n_seq: Length of the delay period in the copying task.\n",
    "        batch_size: Number of sequences per training batch.\n",
    "        lr: Learning rate for the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        List of loss values over training.\n",
    "    \"\"\"\n",
    "    model = GRUNet(10, 128, 10)\n",
    "    opt = braintools.optim.Adam(lr)\n",
    "    weights = model.states().subset(brainstate.ParamState)\n",
    "    opt.register_trainable_weights(weights)\n",
    "\n",
    "    @brainstate.transform.jit\n",
    "    def train_step(inputs, targets):\n",
    "        # Create the online learning algorithm wrapper\n",
    "        algo = braintrace.D_RTRL(model)\n",
    "\n",
    "        # Initialize hidden states and compile eligibility trace graph\n",
    "        # using vmap to handle the batch dimension\n",
    "        @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])\n",
    "        def init():\n",
    "            brainstate.nn.init_all_states(model)\n",
    "            algo.compile_graph(inputs[0, 0])\n",
    "\n",
    "        init()\n",
    "        algo = brainstate.nn.Vmap(algo, vmap_states='new')\n",
    "\n",
    "        def etrace_loss(inp, tar):\n",
    "            out = algo(inp)\n",
    "            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(out, tar).mean()\n",
    "            return loss, out\n",
    "\n",
    "        def step(prev_grads, x):\n",
    "            inp, tar = x\n",
    "            f_grad = brainstate.transform.grad(\n",
    "                etrace_loss, weights, has_aux=True, return_value=True\n",
    "            )\n",
    "            cur_grads, loss, out = f_grad(inp, tar)\n",
    "            next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)\n",
    "            return next_grads, loss\n",
    "\n",
    "        # Warm-up: run the model forward to stabilize hidden states\n",
    "        # and eligibility traces before computing learning gradients\n",
    "        n_sim = n_seq + 10\n",
    "        brainstate.transform.for_loop(lambda inp: algo(inp), inputs[:n_sim])\n",
    "\n",
    "        # Learning phase: accumulate gradients over the recall period\n",
    "        grads = jax.tree.map(jnp.zeros_like, {k: v.value for k, v in weights.items()})\n",
    "        grads, losses = brainstate.transform.scan(step, grads, (inputs[n_sim:], targets))\n",
    "        opt.update(grads)\n",
    "        return losses.mean()\n",
    "\n",
    "    # Training loop\n",
    "    dataloader = CopyDataset(n_seq, batch_size)\n",
    "    losses = []\n",
    "    for i, (x, y) in enumerate(dataloader):\n",
    "        if i >= n_epochs:\n",
    "            break\n",
    "        # Transpose from (batch, time, features) to (time, batch, features)\n",
    "        x = jnp.asarray(np.transpose(x, (1, 0, 2)))\n",
    "        y = jnp.asarray(np.transpose(y, (1, 0)))\n",
    "        loss = train_step(x, y)\n",
    "        losses.append(float(loss))\n",
    "        if i % 50 == 0:\n",
    "            print(f\"Step {i}, Loss: {loss:.4f}\")\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60010",
   "metadata": {},
   "source": [
    "## 5. BPTT Baseline (for Comparison)\n",
    "\n",
    "To appreciate the advantages of online learning, we also implement a standard BPTT trainer. BPTT unrolls the full computation graph across all time steps, computes the loss, and backpropagates through the entire sequence. This requires storing all intermediate activations, resulting in memory usage that scales linearly with sequence length."
   ]
  },
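  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60110",
   "metadata": {},
   "source": [
    "As a rough worked example of this scaling, take the default settings used in this notebook (`n_seq=50`, `batch_size=64`, 128 hidden units). The count below considers only the hidden state at each step and ignores gate activations and other intermediates, which is a simplifying assumption, so the real footprint is larger. The symbols $T$, $B$ and $H$ are shorthand introduced here for sequence length, batch size and hidden size:\n",
    "\n",
    "$$\n",
    "T = 50 + 20 = 70, \\qquad T \\times B \\times H = 70 \\times 64 \\times 128 = 573{,}440 \\text{ stored values}\n",
    "$$\n",
    "\n",
    "Doubling the delay to 100 steps gives $T = 120$ and $120 \\times 64 \\times 128 = 983{,}040$ values, whereas the hidden state itself stays at $B \\times H = 8{,}192$ values regardless of $T$."
   ]
  },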
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a1b2c3d4e5f60011",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:53.729133Z",
     "iopub.status.busy": "2026-04-17T09:24:53.728882Z",
     "iopub.status.idle": "2026-04-17T09:24:53.736663Z",
     "shell.execute_reply": "2026-04-17T09:24:53.735899Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_bptt(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3):\n",
    "    \"\"\"Train a GRU on the copying task using BPTT.\n",
    "\n",
    "    Args:\n",
    "        n_epochs: Number of training iterations.\n",
    "        n_seq: Length of the delay period in the copying task.\n",
    "        batch_size: Number of sequences per training batch.\n",
    "        lr: Learning rate for the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        List of loss values over training.\n",
    "    \"\"\"\n",
    "    model = GRUNet(10, 128, 10)\n",
    "    opt = braintools.optim.Adam(lr)\n",
    "    weights = model.states().subset(brainstate.ParamState)\n",
    "    opt.register_trainable_weights(weights)\n",
    "\n",
    "    @brainstate.transform.jit\n",
    "    def train_step(inputs, targets):\n",
    "        # Initialize hidden states with vmap for batch dimension\n",
    "        @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])\n",
    "        def init():\n",
    "            brainstate.nn.init_all_states(model)\n",
    "\n",
    "        init()\n",
    "        vmapped_model = brainstate.nn.Vmap(model, vmap_states='new')\n",
    "\n",
    "        def run_step(inp, tar):\n",
    "            out = vmapped_model(inp)\n",
    "            loss = braintools.metric.softmax_cross_entropy_with_integer_labels(out, tar).mean()\n",
    "            return out, loss\n",
    "\n",
    "        def bptt_forward():\n",
    "            # Warm-up: run forward without computing loss\n",
    "            n_sim = n_seq + 10\n",
    "            brainstate.transform.for_loop(vmapped_model, inputs[:n_sim])\n",
    "            # Compute loss over the recall period\n",
    "            outs, losses = brainstate.transform.for_loop(run_step, inputs[n_sim:], targets)\n",
    "            return losses.mean(), outs\n",
    "\n",
    "        # Backpropagate through time to get gradients\n",
    "        grads, loss, outs = brainstate.transform.grad(\n",
    "            bptt_forward, weights, has_aux=True, return_value=True\n",
    "        )()\n",
    "        opt.update(grads)\n",
    "        return loss\n",
    "\n",
    "    # Training loop\n",
    "    dataloader = CopyDataset(n_seq, batch_size)\n",
    "    losses = []\n",
    "    for i, (x, y) in enumerate(dataloader):\n",
    "        if i >= n_epochs:\n",
    "            break\n",
    "        x = jnp.asarray(np.transpose(x, (1, 0, 2)))\n",
    "        y = jnp.asarray(np.transpose(y, (1, 0)))\n",
    "        loss = train_step(x, y)\n",
    "        losses.append(float(loss))\n",
    "        if i % 50 == 0:\n",
    "            print(f\"Step {i}, Loss: {loss:.4f}\")\n",
    "    return losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60012",
   "metadata": {},
   "source": [
    "## 6. Run Training\n",
    "\n",
    "Let us train both the online (D-RTRL) and offline (BPTT) models. We use a delay of 50 time steps, a moderately difficult setting for the copying task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a1b2c3d4e5f60013",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:53.738559Z",
     "iopub.status.busy": "2026-04-17T09:24:53.738362Z",
     "iopub.status.idle": "2026-04-17T09:25:49.499504Z",
     "shell.execute_reply": "2026-04-17T09:25:49.498333Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('rnn', 'layers', 0, 'Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mv). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.\n",
      "  _emit_no_relation_diag(\n",
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('readout', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.\n",
      "  _emit_no_relation_diag(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0, Loss: 2.2919\n",
      "Step 50, Loss: 2.0887\n",
      "Step 100, Loss: 2.0790\n",
      "Step 150, Loss: 2.0802\n"
     ]
    }
   ],
   "source": [
    "online_losses = train_online(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a1b2c3d4e5f60014",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:25:49.503534Z",
     "iopub.status.busy": "2026-04-17T09:25:49.503131Z",
     "iopub.status.idle": "2026-04-17T09:25:55.356660Z",
     "shell.execute_reply": "2026-04-17T09:25:55.355570Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0, Loss: 2.2838\n",
      "Step 50, Loss: 2.0851\n",
      "Step 100, Loss: 2.0819\n",
      "Step 150, Loss: 2.0830\n"
     ]
    }
   ],
   "source": [
    "bptt_losses = train_bptt(n_epochs=200, n_seq=50, batch_size=64, lr=1e-3)"
   ]
  },
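  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60114",
   "metadata": {},
   "source": [
    "A useful reference point when reading these numbers is the cross-entropy of a predictor that ignores its input. The values below follow from the task definition alone (10 output classes, targets drawn uniformly from the 8 digits 1-8); they are a back-of-the-envelope check, not something reported by `braintrace`:\n",
    "\n",
    "$$\n",
    "L_{\\text{uniform over 10 classes}} = \\ln 10 \\approx 2.303, \\qquad L_{\\text{uniform over 8 digits}} = \\ln 8 \\approx 2.079\n",
    "$$\n",
    "\n",
    "Both runs start near $\\ln 10$ (2.29 and 2.28 at step 0) and level off near $\\ln 8$ (about 2.08 by step 150). This suggests that within 200 steps both models have learned which classes can appear in the target but not yet which digit belongs at each position, so longer training or a shorter delay would likely be needed to see actual recall."
   ]
  },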
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60015",
   "metadata": {},
   "source": [
    "## 7. Visualization\n",
    "\n",
    "Plot the training loss curves to compare online learning (D-RTRL) with BPTT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a1b2c3d4e5f60016",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:25:55.359740Z",
     "iopub.status.busy": "2026-04-17T09:25:55.359410Z",
     "iopub.status.idle": "2026-04-17T09:25:55.407215Z",
     "shell.execute_reply": "2026-04-17T09:25:55.406483Z"
    }
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 4))\n",
    "plt.plot(online_losses, label='D-RTRL (Online)')\n",
    "plt.plot(bptt_losses, label='BPTT (Offline)')\n",
    "plt.xlabel('Training Step')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('GRU on Copying Task: Online vs. Offline Learning')\n",
    "plt.legend()\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60017",
   "metadata": {},
   "source": [
    "## 8. Summary\n",
    "\n",
    "In this tutorial, we demonstrated how to use `braintrace` for online learning of a GRU network on the copying task.\n",
    "\n",
    "**Key takeaways:**\n",
    "\n",
    "- **D-RTRL** provides approximate online gradients with `O(B * theta)` memory complexity, where `B` is the batch size and `theta` is the number of parameters. Unlike BPTT, it does not need to store the full unrolled computation graph, so its memory footprint does not grow with sequence length.\n",
    "- The online training loop uses `braintrace.D_RTRL` to automatically manage eligibility traces. You only need to:\n",
    "  1. Wrap your model with `D_RTRL`.\n",
    "  2. Call `compile_graph()` to set up the trace computation.\n",
    "  3. Use standard `brainstate.transform.grad` to compute per-step gradients.\n",
    "- Online learning works with standard JAX gradient APIs and transformations (`jit`, `vmap`, `scan`).\n",
    "- `braintrace` is particularly effective for RNN models with gating mechanisms (GRU, LSTM), where the internal dynamics naturally support eligibility trace propagation.\n",
    "\n",
    "For more details, see:\n",
    "- [Key Concepts](./concepts-en.ipynb) for the theoretical background.\n",
    "- [SNN Online Learning](./snn_online_learning-en.ipynb) for applying the same approach to spiking neural networks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
