{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f6a7b8",
   "metadata": {},
   "source": [
    "# Core Concepts of BrainTrace"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5f6a7b8c9",
   "metadata": {},
   "source": [
    "Welcome to **BrainTrace**! This notebook introduces the core concepts you need to understand before using the library for online learning in recurrent and spiking neural networks.\n",
    "\n",
    "BrainTrace is built on [JAX](https://github.com/google/jax) and [brainstate](https://brainstate.readthedocs.io/), providing memory-efficient online learning through eligibility trace propagation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3d4e5f6a7b8c9d0",
   "metadata": {},
   "source": [
    "## 1. What is Online Learning?\n",
    "\n",
    "Training recurrent neural networks (RNNs) typically relies on **Backpropagation Through Time (BPTT)**. BPTT unrolls the full computation graph over all time steps before computing gradients:\n",
    "\n",
    "- **BPTT** stores the entire computation graph across $T$ time steps, requiring $O(T)$ memory.\n",
    "- As sequence length grows, memory usage becomes a bottleneck.\n",
    "\n",
    "**Online learning** takes a different approach:\n",
    "\n",
    "- Weights are updated at **each time step** using **eligibility traces** that summarize the gradient history.\n",
    "- Memory cost is $O(1)$ per time step (independent of sequence length).\n",
    "- Eligibility traces accumulate the information needed for gradient computation incrementally.\n",
    "\n",
    "BrainTrace implements online learning via **JAX custom primitives**. Instead of relying on string-matching or special parameter wrappers, BrainTrace identifies which operations participate in online learning by their **primitive type** at the JAX IR level. This gives a clean, composable, and JIT-friendly design."
   ]
  },
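  {
   "cell_type": "markdown",
   "id": "0a1b2c3d4e5f6071",
   "metadata": {},
   "source": [
    "The contrast between BPTT and online learning can be made concrete in plain NumPy. The example below is a minimal sketch, not BrainTrace code. It assumes a toy layer of leaky linear units, $h_t = a \\odot h_{t-1} + W x_t$, with a per-neuron decay `a` and the loss $L = \\sum_t \\frac{1}{2} \\lVert h_t - y_t \\rVert^2$. The hypothetical helper `bptt_grad` stores every $h_t$ and runs a backward pass. The hypothetical helper `online_grad` keeps only the current state and one trace `E` (same shape as `W`) and never revisits earlier steps.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_rec, T = 3, 4, 50\n",
    "a = rng.uniform(0.5, 0.95, size=n_rec)  # per-neuron decay (element-wise recurrence)\n",
    "W = rng.normal(size=(n_rec, n_in))      # weights we want gradients for\n",
    "xs = rng.normal(size=(T, n_in))\n",
    "ys = rng.normal(size=(T, n_rec))\n",
    "\n",
    "def bptt_grad(W):  # hypothetical helper: reference gradient via BPTT\n",
    "    # Forward pass: store the whole history of hidden states (O(T) memory)\n",
    "    hs, h = [], np.zeros(n_rec)\n",
    "    for x in xs:\n",
    "        h = a * h + W @ x\n",
    "        hs.append(h)\n",
    "    # Backward pass: lam = dL/dh_t, propagated from the last step to the first\n",
    "    grad, lam = np.zeros_like(W), np.zeros(n_rec)\n",
    "    for t in reversed(range(T)):\n",
    "        lam = (hs[t] - ys[t]) + a * lam\n",
    "        grad += np.outer(lam, xs[t])\n",
    "    return grad\n",
    "\n",
    "def online_grad(W):  # hypothetical helper: same gradient via an eligibility trace\n",
    "    # Only the current state and one trace are kept (memory independent of T)\n",
    "    h, E, grad = np.zeros(n_rec), np.zeros_like(W), np.zeros_like(W)\n",
    "    for x, y in zip(xs, ys):\n",
    "        h = a * h + W @ x\n",
    "        E = a[:, None] * E + x[None, :]  # E[i, j] = dh_t[i] / dW[i, j]\n",
    "        grad += (h - y)[:, None] * E     # dL_t/dh_t times the trace\n",
    "    return grad\n",
    "\n",
    "print(np.allclose(bptt_grad(W), online_grad(W)))  # True\n",
    "```\n",
    "\n",
    "Because the recurrence is element-wise, the trace is exact and both gradients agree. `E` has one entry per weight. As a result, the memory no longer depends on $T$ but still scales with the number of parameters. This is the $O(\\theta)$ cost listed for D-RTRL in Section 6."
   ]
  },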
  {
   "cell_type": "markdown",
   "id": "d4e5f6a7b8c9d0e1",
   "metadata": {},
   "source": [
    "## 2. Architecture Overview\n",
    "\n",
    "BrainTrace is organized as a 4-layer system. Each layer builds on the one below it:\n",
    "\n",
    "```\n",
    "+--------------------------------------------------------------+\n",
    "|  Algorithms    braintrace.D_RTRL / braintrace.ES_D_RTRL      |\n",
    "|                trace update + custom_vjp for jax.grad         |\n",
    "+--------------------------------------------------------------+\n",
    "|  Executor      ETraceGraphExecutor                            |\n",
    "|                forward pass + Jacobian computation            |\n",
    "+--------------------------------------------------------------+\n",
    "|  Compiler      compile_etrace_graph()                         |\n",
    "|                jaxpr walk -> find primitives -> connect to     |\n",
    "|                hidden states                                  |\n",
    "+--------------------------------------------------------------+\n",
    "|  Primitives    braintrace.matmul / element_wise / conv        |\n",
    "|  & Functions   JAX custom primitives (thin markers)           |\n",
    "+--------------------------------------------------------------+\n",
    "```\n",
    "\n",
    "**How it works:**\n",
    "\n",
    "1. **Primitives & Functions** (bottom layer): You call `braintrace.matmul(x, w)` in your model. Under the hood, this binds a JAX custom primitive that acts as a *marker* — the actual computation is standard JAX (`x @ w`).\n",
    "\n",
    "2. **Compiler**: When you call `compile_graph()`, BrainTrace walks the JAX intermediate representation (jaxpr), finds all ETP primitives, and connects each one to its associated hidden states and parameters.\n",
    "\n",
    "3. **Executor**: During the forward pass, the executor computes the model output *and* the Jacobians needed for eligibility trace updates.\n",
    "\n",
    "4. **Algorithms** (top layer): `D_RTRL` or `ES_D_RTRL` use the executor outputs to maintain eligibility traces and provide correct gradients via `custom_vjp`."
   ]
  },
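  {
   "cell_type": "markdown",
   "id": "1b2c3d4e5f607182",
   "metadata": {},
   "source": [
    "The compiler's first step, walking a jaxpr and selecting operations by primitive type, can be reproduced with public JAX APIs. The sketch below is illustrative, not BrainTrace's compiler. It uses JAX's built-in `dot_general` primitive as a stand-in for an ETP primitive, and `rnn_step` and `find_matmuls` are hypothetical names. It only reports which arguments feed each matching equation; connecting the matches to hidden states is not shown.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax\n",
    "\n",
    "def rnn_step(h, x, w_in, w_rec):  # hypothetical model step\n",
    "    # Each lax.dot call becomes one dot_general equation in the jaxpr\n",
    "    return jnp.tanh(lax.dot(x, w_in) + lax.dot(h, w_rec))\n",
    "\n",
    "def find_matmuls(fn, arg_names, *example_args):  # hypothetical helper\n",
    "    # Trace fn to a jaxpr; the example arguments only provide shapes and dtypes\n",
    "    jaxpr = jax.make_jaxpr(fn)(*example_args).jaxpr\n",
    "    names = dict(zip(jaxpr.invars, arg_names))\n",
    "    found = []\n",
    "    for eqn in jaxpr.eqns:\n",
    "        # Select equations by primitive type, not by parameter names\n",
    "        if eqn.primitive.name == 'dot_general':\n",
    "            found.append([names.get(v, 'intermediate') for v in eqn.invars])\n",
    "    return found\n",
    "\n",
    "matches = find_matmuls(\n",
    "    rnn_step, ['h', 'x', 'w_in', 'w_rec'],\n",
    "    jnp.zeros(4), jnp.zeros(3), jnp.zeros((3, 4)), jnp.zeros((4, 4)),\n",
    ")\n",
    "print(matches)  # [['x', 'w_in'], ['h', 'w_rec']]\n",
    "```\n",
    "\n",
    "The argument names are attached only for reporting. The search itself keys on `eqn.primitive.name`, mirroring how BrainTrace identifies its own primitives by type rather than by name."
   ]
  },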
  {
   "cell_type": "markdown",
   "id": "e5f6a7b8c9d0e1f2",
   "metadata": {},
   "source": [
    "## 3. Key Concept: Primitive-Based Parameter Selection\n",
    "\n",
    "The central design idea of BrainTrace is that **the operation you use determines whether a parameter participates in online learning**:\n",
    "\n",
    "| What you write | Effect |\n",
    "|---|---|\n",
    "| `braintrace.matmul(x, w)` | `w` is **included** in online learning (eligibility traces are maintained) |\n",
    "| `x @ w` (regular JAX matmul) | `w` is **excluded** from online learning (only instantaneous gradients) |\n",
    "\n",
    "There is **no need for special parameter classes**. All weights are plain `brainstate.ParamState`. The choice of operation is what matters.\n",
    "\n",
    "Here is a concrete example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f6a7b8c9d0e1f2a3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:43.456851Z",
     "iopub.status.busy": "2026-04-17T09:24:43.456699Z",
     "iopub.status.idle": "2026-04-17T09:24:45.629050Z",
     "shell.execute_reply": "2026-04-17T09:24:45.628092Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a7b8c9d0e1f2a3b4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:45.631615Z",
     "iopub.status.busy": "2026-04-17T09:24:45.631220Z",
     "iopub.status.idle": "2026-04-17T09:24:45.635810Z",
     "shell.execute_reply": "2026-04-17T09:24:45.635234Z"
    }
   },
   "outputs": [],
   "source": [
    "class SimpleRNN(brainstate.nn.Module):\n",
    "    def __init__(self, n_in, n_rec, n_out):\n",
    "        super().__init__()\n",
    "        self.w_in = brainstate.ParamState(brainstate.random.randn(n_in, n_rec) * 0.01)\n",
    "        self.w_rec = brainstate.ParamState(brainstate.random.randn(n_rec, n_rec) * 0.01)\n",
    "        self.w_out = brainstate.ParamState(brainstate.random.randn(n_rec, n_out) * 0.01)\n",
    "        self.h = brainstate.ShortTermState(jnp.zeros(n_rec))\n",
    "\n",
    "    def update(self, x):\n",
    "        # Regular matmul: w_in excluded from online learning\n",
    "        inp = x @ self.w_in.value\n",
    "\n",
    "        # ETP matmul: w_rec included in online learning\n",
    "        rec = braintrace.matmul(self.h.value, self.w_rec.value)\n",
    "\n",
    "        self.h.value = jax.nn.tanh(inp + rec)\n",
    "\n",
    "        # Regular matmul: w_out excluded from online learning\n",
    "        return self.h.value @ self.w_out.value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c9d0e1f2a3b4c5",
   "metadata": {},
   "source": [
    "In the model above:\n",
    "\n",
    "- `w_in` and `w_out` use standard `x @ w` — they receive only instantaneous gradients (no temporal credit assignment through eligibility traces).\n",
    "- `w_rec` uses `braintrace.matmul(h, w_rec)` — the compiler will automatically maintain eligibility traces for this weight, enabling gradient computation that accounts for temporal dependencies.\n",
    "\n",
    "All three weights are the same type (`brainstate.ParamState`). The **operation** is the only difference."
   ]
  },
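  {
   "cell_type": "markdown",
   "id": "2c3d4e5f60718293",
   "metadata": {},
   "source": [
    "The difference between the two kinds of gradient can be checked numerically on a one-neuron version of `SimpleRNN`. This NumPy sketch is a simplified illustration, not BrainTrace's implementation. With a single tanh unit, the trace for `w_rec` is exact, so its gradient matches a finite-difference estimate. The instantaneous gradient for `w_in`, by contrast, ignores how `w_in` influenced earlier states. `loss`, `online_grads`, and `finite_diff` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "xs, ys = rng.normal(size=20), rng.normal(size=20)\n",
    "\n",
    "def loss(w_in, w_rec):  # hypothetical helper\n",
    "    # L = sum_t 0.5 * (h_t - y_t)^2 with h_t = tanh(w_in * x_t + w_rec * h_{t-1})\n",
    "    h, total = 0.0, 0.0\n",
    "    for x, y in zip(xs, ys):\n",
    "        h = np.tanh(w_in * x + w_rec * h)\n",
    "        total += 0.5 * (h - y) ** 2\n",
    "    return total\n",
    "\n",
    "def online_grads(w_in, w_rec):  # hypothetical helper\n",
    "    h, e_rec, g_in, g_rec = 0.0, 0.0, 0.0, 0.0\n",
    "    for x, y in zip(xs, ys):\n",
    "        h_prev = h\n",
    "        h = np.tanh(w_in * x + w_rec * h_prev)\n",
    "        d = 1.0 - h ** 2                      # tanh derivative at this step\n",
    "        e_rec = d * (h_prev + w_rec * e_rec)  # trace: dh_t/dw_rec, carried forward\n",
    "        err = h - y                           # dL_t/dh_t\n",
    "        g_rec += err * e_rec                  # traced weight (like braintrace.matmul)\n",
    "        g_in += err * d * x                   # plain weight (like x @ w): current step only\n",
    "    return g_in, g_rec\n",
    "\n",
    "def finite_diff(w_in, w_rec, eps=1e-6):  # hypothetical helper: central differences\n",
    "    d_in = (loss(w_in + eps, w_rec) - loss(w_in - eps, w_rec)) / (2 * eps)\n",
    "    d_rec = (loss(w_in, w_rec + eps) - loss(w_in, w_rec - eps)) / (2 * eps)\n",
    "    return d_in, d_rec\n",
    "\n",
    "g_in, g_rec = online_grads(0.8, 0.5)\n",
    "fd_in, fd_rec = finite_diff(0.8, 0.5)\n",
    "print(abs(g_rec - fd_rec))  # near zero: the trace recovers the full gradient\n",
    "print(abs(g_in - fd_in))    # clearly nonzero: temporal credit for w_in is missing\n",
    "```"
   ]
  },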
  {
   "cell_type": "markdown",
   "id": "c9d0e1f2a3b4c5d6",
   "metadata": {},
   "source": [
    "## 4. Using `braintrace.nn` Modules\n",
    "\n",
    "While you can use primitives directly (as shown above), BrainTrace provides **pre-built layers** in the `braintrace.nn` module that already use ETP primitives internally. These are drop-in replacements for standard `brainstate.nn` layers:\n",
    "\n",
    "| Module | Description |\n",
    "|---|---|\n",
    "| `braintrace.nn.Linear` | Dense linear layer using `braintrace.matmul` |\n",
    "| `braintrace.nn.SignedWLinear` | Linear layer with sign-constrained weights (E/I networks) |\n",
    "| `braintrace.nn.ScaledWSLinear` | Weight-standardized linear layer |\n",
    "| `braintrace.nn.SparseLinear` | Linear layer with sparse connectivity (uses `sparse_matmul`) |\n",
    "| `braintrace.nn.LoRA` | Low-rank adapter layer (uses `lora_matmul`) |\n",
    "| `braintrace.nn.Conv1d` / `Conv2d` / `Conv3d` | Convolutional layers using `braintrace.conv` |\n",
    "| `braintrace.nn.GRUCell` / `LSTMCell` / `ValinaRNNCell` | Recurrent cells with ETP-aware gates |\n",
    "| `braintrace.nn.LeakyRateReadout` | Rate-coded SNN readout |\n",
    "| `braintrace.nn.BatchNorm1d` / `LayerNorm` | Normalisation layers |\n",
    "\n",
    "For the matching low-level API, the user-facing primitive functions are:\n",
    "\n",
    "- `braintrace.matmul(x, w, bias=None)` -- dense matrix multiplication\n",
    "- `braintrace.element_wise(weight, fn=...)` -- element-wise weight ops (gating, learnable thresholds)\n",
    "- `braintrace.conv(x, kernel, bias, ...)` -- convolution\n",
    "- `braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None)` -- sparse matmul\n",
    "- `braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None)` -- LoRA decomposition\n",
    "\n",
    "Use the `braintrace.nn` layers when you can; reach for the primitive functions when you need a custom layer that participates in online learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0e1f2a3b4c5d6e7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:45.637508Z",
     "iopub.status.busy": "2026-04-17T09:24:45.637379Z",
     "iopub.status.idle": "2026-04-17T09:24:45.640904Z",
     "shell.execute_reply": "2026-04-17T09:24:45.640285Z"
    }
   },
   "outputs": [],
   "source": [
    "class GRUNet(brainstate.nn.Module):\n",
    "    def __init__(self, n_in, n_rec, n_out):\n",
    "        super().__init__()\n",
    "        self.rnn = braintrace.nn.GRUCell(n_in, n_rec)\n",
    "        self.readout = braintrace.nn.Linear(n_rec, n_out)\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.readout(self.rnn(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1f2a3b4c5d6e7f8",
   "metadata": {},
   "source": [
    "The `GRUCell` internally uses `braintrace.matmul` for its weight operations, so all its recurrent parameters automatically participate in online learning. The `Linear` readout also uses ETP primitives, but the compiler will detect that it is not connected to any hidden state and handle it appropriately."
   ]
  },
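  {
   "cell_type": "markdown",
   "id": "3d4e5f6071829304",
   "metadata": {},
   "source": [
    "The compile step in Section 5 warns that the GRU's reset-gate weight `Wr` reaches the hidden state only through another trainable ETP primitive. Writing out one GRU step shows why. The NumPy sketch below assumes a common GRU formulation with three linear maps on concatenated vectors. This is consistent with the `in_size=(74,)` (10 inputs + 64 hidden units) in the model printout there. `gru_step` is a hypothetical helper, and braintrace's exact equations may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(v):\n",
    "    return 1.0 / (1.0 + np.exp(-v))\n",
    "\n",
    "def gru_step(x, h, params):  # hypothetical helper, assumed GRU formulation\n",
    "    xh = np.concatenate([x, h])                    # 10 + 64 = 74 inputs per gate\n",
    "    z = sigmoid(xh @ params['Wz'] + params['bz'])  # update gate\n",
    "    r = sigmoid(xh @ params['Wr'] + params['br'])  # reset gate\n",
    "    # r reaches the new state only through another weighted op (Wh)\n",
    "    cand = np.tanh(np.concatenate([x, r * h]) @ params['Wh'] + params['bh'])\n",
    "    # Wz and Wh reach the new state through element-wise ops only\n",
    "    return (1.0 - z) * h + z * cand\n",
    "\n",
    "n_in, n_rec = 10, 64\n",
    "rng = np.random.default_rng(0)\n",
    "params = {k: 0.1 * rng.normal(size=(n_in + n_rec, n_rec)) for k in ('Wz', 'Wr', 'Wh')}\n",
    "params.update({k: np.zeros(n_rec) for k in ('bz', 'br', 'bh')})\n",
    "\n",
    "h = gru_step(np.ones(n_in), np.zeros(n_rec), params)\n",
    "print(h.shape)  # (64,)\n",
    "```\n",
    "\n",
    "Under the invariant quoted in the warning, only weights whose outputs flow into the hidden state without passing through another trainable ETP primitive keep eligibility traces. That is why `Wz` and `Wh` qualify here and `Wr` does not."
   ]
  },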
  {
   "cell_type": "markdown",
   "id": "f2a3b4c5d6e7f8a9",
   "metadata": {},
   "source": [
    "## 5. Online Learning in 3 Steps\n",
    "\n",
    "Using BrainTrace for online learning follows a simple three-step workflow:\n",
    "\n",
    "1. **Define the model** using `braintrace.nn` modules or manual ETP primitives.\n",
    "2. **Wrap with an algorithm and compile** — choose `D_RTRL` or `ES_D_RTRL` and call `compile_graph()`.\n",
    "3. **Train with standard JAX gradient computation** — eligibility traces are updated inside the wrapped model call.\n",
    "\n",
    "Here is the complete workflow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a3b4c5d6e7f8a9b0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:45.642700Z",
     "iopub.status.busy": "2026-04-17T09:24:45.642477Z",
     "iopub.status.idle": "2026-04-17T09:24:47.457511Z",
     "shell.execute_reply": "2026-04-17T09:24:47.456529Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GRUNet(\n",
       "  rnn=GRUCell(\n",
       "    in_size=(10,),\n",
       "    out_size=(64,),\n",
       "    state_initializer=ZeroInit(unit=1),\n",
       "    activation=<function tanh at 0x78140b136980>,\n",
       "    Wz=Linear(\n",
       "      in_size=(74,),\n",
       "      out_size=(64,),\n",
       "      w_mask=None,\n",
       "      weight=ParamState(\n",
       "        value={\n",
       "          'bias': ShapedArray(float32[64], weak_type=True),\n",
       "          'weight': ShapedArray(float32[74,64])\n",
       "        }\n",
       "      )\n",
       "    ),\n",
       "    Wr=Linear(\n",
       "      in_size=(74,),\n",
       "      out_size=(64,),\n",
       "      w_mask=None,\n",
       "      weight=ParamState(\n",
       "        value={\n",
       "          'bias': ShapedArray(float32[64], weak_type=True),\n",
       "          'weight': ShapedArray(float32[74,64])\n",
       "        }\n",
       "      )\n",
       "    ),\n",
       "    Wh=Linear(\n",
       "      in_size=(74,),\n",
       "      out_size=(64,),\n",
       "      w_mask=None,\n",
       "      weight=ParamState(\n",
       "        value={\n",
       "          'bias': ShapedArray(float32[64], weak_type=True),\n",
       "          'weight': ShapedArray(float32[74,64])\n",
       "        }\n",
       "      )\n",
       "    ),\n",
       "    h=HiddenState(\n",
       "      value=ShapedArray(float32[64], weak_type=True)\n",
       "    )\n",
       "  ),\n",
       "  readout=Linear(\n",
       "    in_size=(64,),\n",
       "    out_size=(10,),\n",
       "    w_mask=None,\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[10]),\n",
       "        'weight': ShapedArray(float32[64,10])\n",
       "      }\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Step 1: Define model\n",
    "model = GRUNet(10, 64, 10)\n",
    "brainstate.nn.init_all_states(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4c5d6e7f8a9b0c1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:47.459310Z",
     "iopub.status.busy": "2026-04-17T09:24:47.459059Z",
     "iopub.status.idle": "2026-04-17T09:24:47.517887Z",
     "shell.execute_reply": "2026-04-17T09:24:47.516359Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('rnn', 'Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mv). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.\n",
      "  _emit_no_relation_diag(\n",
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mv (weight=('readout', 'weight')) has no connected hidden states. It will be treated as a non-temporal parameter.\n",
      "  _emit_no_relation_diag(\n"
     ]
    }
   ],
   "source": [
    "# Step 2: Wrap with D-RTRL and compile\n",
    "trainer = braintrace.D_RTRL(model)\n",
    "trainer.compile_graph(jnp.zeros(10))  # provide an example input for shape inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5d6e7f8a9b0c1d2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:24:47.520581Z",
     "iopub.status.busy": "2026-04-17T09:24:47.520313Z",
     "iopub.status.idle": "2026-04-17T09:24:47.341332Z",
     "shell.execute_reply": "2026-04-17T09:24:47.340466Z"
    }
   },
   "outputs": [],
   "source": [
    "# Step 3: Use standard JAX gradient computation\n",
    "# The eligibility traces are updated inside trainer(x)\n",
    "weights = model.states(brainstate.ParamState)\n",
    "\n",
    "def loss_fn(x):\n",
    "    out = trainer(x)\n",
    "    return jnp.mean(out ** 2)\n",
    "\n",
    "grad_fn = brainstate.transform.grad(loss_fn, weights)\n",
    "grads = grad_fn(jnp.ones(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6e7f8a9b0c1d2e3",
   "metadata": {},
   "source": [
    "**What happens under the hood:**\n",
    "\n",
    "- `compile_graph()` traces the model through JAX, identifies all ETP primitives, and builds the eligibility trace computation graph.\n",
    "- Each call to `trainer(x)` runs the model forward pass *and* updates all eligibility traces.\n",
    "- When you compute `grad(loss_fn, weights)`, the algorithm uses `custom_vjp` to provide gradients that incorporate the eligibility trace information — giving you temporally-aware gradients with $O(1)$ memory per step."
   ]
  },
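  {
   "cell_type": "markdown",
   "id": "4e5f607182930415",
   "metadata": {},
   "source": [
    "The `custom_vjp` idea can be illustrated on a single leaky unit in plain JAX. This is a minimal sketch, not BrainTrace's implementation. `leaky_step`, `online_grad`, and `bptt_loss` are hypothetical names. The unit $h_t = A h_{t-1} + w x_t$ is chosen because its trace $e_t = A e_{t-1} + x_t$ equals $\\partial h_t / \\partial w$ exactly. The backward rule of `leaky_step` returns `g * trace` for the weight and zero cotangents for everything else. As a result, `jax.grad` of each per-step loss gives an online gradient without unrolling earlier steps.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "A = 0.9  # fixed leak factor of the toy unit\n",
    "\n",
    "@jax.custom_vjp\n",
    "def leaky_step(w, h_prev, x, trace):  # hypothetical one-step update\n",
    "    # Forward: h_t = A * h_{t-1} + w * x_t (trace is only used by the backward rule)\n",
    "    return A * h_prev + w * x\n",
    "\n",
    "def leaky_step_fwd(w, h_prev, x, trace):\n",
    "    return leaky_step(w, h_prev, x, trace), (h_prev, x, trace)\n",
    "\n",
    "def leaky_step_bwd(res, g):\n",
    "    h_prev, x, trace = res\n",
    "    # dL/dw = g * trace; zero cotangents keep gradients from flowing into the past\n",
    "    return g * trace, jnp.zeros_like(h_prev), jnp.zeros_like(x), jnp.zeros_like(trace)\n",
    "\n",
    "leaky_step.defvjp(leaky_step_fwd, leaky_step_bwd)\n",
    "\n",
    "def online_grad(w, xs, ys):  # hypothetical helper: one jax.grad call per time step\n",
    "    h, trace, total = 0.0, 0.0, 0.0\n",
    "    for x, y in zip(xs, ys):\n",
    "        trace = A * trace + x  # e_t = A * e_{t-1} + x_t\n",
    "        step_loss = lambda w_: 0.5 * (leaky_step(w_, h, x, trace) - y) ** 2\n",
    "        total += jax.grad(step_loss)(w)\n",
    "        h = leaky_step(w, h, x, trace)\n",
    "    return total\n",
    "\n",
    "def bptt_loss(w, xs, ys):  # hypothetical helper: plain ops, jax.grad unrolls all steps\n",
    "    h, loss = 0.0, 0.0\n",
    "    for x, y in zip(xs, ys):\n",
    "        h = A * h + w * x\n",
    "        loss += 0.5 * (h - y) ** 2\n",
    "    return loss\n",
    "\n",
    "xs = jax.random.normal(jax.random.PRNGKey(0), (30,))\n",
    "ys = jax.random.normal(jax.random.PRNGKey(1), (30,))\n",
    "print(online_grad(0.5, xs, ys), jax.grad(bptt_loss)(0.5, xs, ys))\n",
    "```\n",
    "\n",
    "The two printed values should agree up to float32 rounding, even though `online_grad` never keeps more than the current state and trace."
   ]
  },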
  {
   "cell_type": "markdown",
   "id": "e7f8a9b0c1d2e3f4",
   "metadata": {},
   "source": [
    "## 6. Available Algorithms\n",
    "\n",
    "BrainTrace provides two main online learning algorithms:\n",
    "\n",
    "| Algorithm | Class | Memory Complexity | Compute Complexity | Best For |\n",
    "|---|---|---|---|---|\n",
    "| **D-RTRL** | `braintrace.D_RTRL` | $O(B \\cdot \\theta)$ | $O(B \\cdot I \\cdot O)$ per layer | RNNs (GRU, LSTM) with moderate hidden sizes |\n",
    "| **ES-D-RTRL** | `braintrace.ES_D_RTRL` | $O(B \\cdot N)$ | $O(B \\cdot N)$ per layer | Large-scale SNNs where $N$ is the number of neurons |\n",
    "\n",
    "Where $B$ is the batch size, $\\theta$ is the number of parameters, $I$ and $O$ are input/output dimensions, and $N$ is the number of neurons.\n",
    "\n",
    "### When to use which?\n",
    "\n",
    "- **`D_RTRL`** (also `ParamDimVjpAlgorithm`): Use for **rate-based RNNs** (GRU, LSTM, etc.) where you need accurate temporal gradient propagation. Its $O(\\theta)$ memory cost scales with parameter count, which is acceptable for typical RNN hidden sizes.\n",
    "\n",
    "- **`ES_D_RTRL`** (also `IODimVjpAlgorithm`): Use for **spiking neural networks** (SNNs) or very large recurrent networks. It achieves $O(N)$ complexity by exploiting the element-wise nature of neuronal dynamics, making it much more efficient for networks with many neurons.\n",
    "\n",
    "Both algorithms are used in the same way — just swap the class name:\n",
    "\n",
    "```python\n",
    "# D-RTRL for RNNs\n",
    "trainer = braintrace.D_RTRL(model)\n",
    "\n",
    "# ES-D-RTRL for SNNs\n",
    "trainer = braintrace.ES_D_RTRL(model)\n",
    "```"
   ]
  },
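  {
   "cell_type": "markdown",
   "id": "5f60718293041526",
   "metadata": {},
   "source": [
    "The memory column can be made concrete by counting trace entries for one dense layer of leaky linear units, $h_t = \\alpha h_{t-1} + W x_t$. The NumPy sketch below is illustrative and assumes that all units share one decay `alpha`. This case is chosen because the weight-sized trace then factorizes exactly into an input trace `x_bar`. `ES_D_RTRL`'s actual approximation for general neuron dynamics is not reproduced here. A D-RTRL-style trace stores one value per weight, $O(I \\cdot O)$, while the factorized form stores only $O(I)$ values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_out, T = 200, 100, 40\n",
    "alpha = 0.9  # assumption: one decay shared by all neurons\n",
    "xs = rng.normal(size=(T, n_in))\n",
    "\n",
    "E = np.zeros((n_out, n_in))  # D-RTRL-style trace: E[i, j] = dh[i]/dW[i, j]\n",
    "x_bar = np.zeros(n_in)       # factorized trace: low-pass filtered input only\n",
    "\n",
    "for x in xs:\n",
    "    E = alpha * E + x[None, :]  # O(n_out * n_in) memory and work per step\n",
    "    x_bar = alpha * x_bar + x   # O(n_in) memory and work per step\n",
    "\n",
    "# With a shared decay every row of E equals x_bar, so an outer product rebuilds E\n",
    "print(np.allclose(E, np.outer(np.ones(n_out), x_bar)))  # True\n",
    "print(E.size, x_bar.size)  # 20000 200\n",
    "```"
   ]
  },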
  {
   "cell_type": "markdown",
   "id": "f8a9b0c1d2e3f4a5",
   "metadata": {},
   "source": [
    "## 7. Summary\n",
    "\n",
    "Here is a quick recap of the core concepts:\n",
    "\n",
    "| Concept | Description |\n",
    "|---|---|\n",
    "| **Online learning** | Update weights at each time step using eligibility traces, achieving $O(1)$ memory per step |\n",
    "| **ETP primitives** | `braintrace.matmul`, `braintrace.element_wise`, `braintrace.conv` — JAX custom primitives that mark operations for online learning |\n",
    "| **Primitive-based selection** | Use an ETP primitive to include a weight; use regular JAX ops to exclude it |\n",
    "| **`braintrace.nn`** | Pre-built layers (Linear, GRUCell, LSTMCell, Conv) that use ETP primitives internally |\n",
    "| **Compile step** | `trainer.compile_graph(example_input)` — analyzes the jaxpr to build the trace computation graph |\n",
    "| **D_RTRL** | $O(\\theta)$ algorithm for RNNs |\n",
    "| **ES_D_RTRL** | $O(N)$ algorithm for SNNs |"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b0c1d2e3f4a5b6",
   "metadata": {},
   "source": [
    "## 8. Next Steps\n",
    "\n",
    "Now that you understand the core concepts, explore the following tutorials:\n",
    "\n",
    "- **[RNN Online Learning](./rnn_online_learning-en.ipynb)**: A complete example of training a GRU on the copying task using `D_RTRL`, including comparison with BPTT.\n",
    "- **[SNN Online Learning](./snn_online_learning-en.ipynb)**: Training spiking neural networks with `ES_D_RTRL` on neuromorphic datasets.\n",
    "- **[ETP Primitives Deep Dive](../tutorial/etraceop-en.ipynb)**: Detailed guide on using and extending ETP primitives for custom operations.\n",
    "- **[Batching](../tutorial/batching-en.ipynb)**: How to handle batched inputs with online learning.\n",
    "- **[Visualizing the Computation Graph](../tutorial/show_graph-en.ipynb)**: Inspect the compiled eligibility trace graph for debugging."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
