{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4",
   "metadata": {},
   "source": [
    "# ETP Primitives Deep Dive"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2c3d4e5",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "ETP (Eligibility Trace Propagation) primitives are JAX custom primitives that **mark weight operations** in the computational graph. They replace the old `ETraceOp` / JIT-name-matching system with a cleaner, more robust approach.\n",
    "\n",
    "Key design principles:\n",
    "\n",
    "- **Type identity, not string matching.** The compiler identifies ETP primitives by checking `eqn.primitive in ETP_PRIMITIVES` — a set-membership test on the primitive object itself. This is more reliable than the old approach of matching JIT function names.\n",
    "\n",
    "- **All JAX rules are auto-derived.** Each primitive's `impl` delegates to standard JAX ops (e.g., `x @ w`, `jax.lax.conv_general_dilated`). The `register_primitive()` helper automatically derives `abstract_eval`, MLIR lowering, JVP, transpose, and batching rules from the implementation. No hand-written derivative formulas are needed.\n",
    "\n",
    "- **Only ETP-specific rules need hand-writing.** Four global dictionaries capture the online-learning-specific logic:\n",
    "  - `ETP_RULES_YW_TO_W` — D-RTRL trace propagation (the $\\mathbf{D}^t \\boldsymbol{\\epsilon}^{t-1}$ term)\n",
    "  - `ETP_RULES_XY_TO_DW` — instantaneous hidden-to-weight Jacobian (the $\\operatorname{diag}(\\mathbf{D}_f^t) \\otimes \\mathbf{x}^t$ term)\n",
    "  - `ETP_RULES_INIT_DRTRL` — D-RTRL parameter-dimension trace initialiser\n",
    "  - `ETP_RULES_INIT_PP` — pp-prop / ES-D-RTRL output-dimension df-trace initialiser\n",
    "\n",
    "- **Primitive-based parameter selection.** A parameter participates in ETP if and only if it flows through an ETP primitive (`braintrace.matmul`, `braintrace.element_wise`, etc.). Parameters used with regular JAX ops are automatically excluded — no special parameter class is needed.\n",
    "\n",
    "- **N-trainable-inputs per primitive (dict rule API).** A single primitive may declare several trainable inputs at once — e.g. `{weight, bias}` for Linear, `{B, A, bias}` for LoRA. The four ETP rules consume and return `Dict[str, Array]` so the executor routes gradients to every owning `ParamState` in one pass."
   ]
  },
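  {
   "cell_type": "markdown",
   "id": "gr-add-01",
   "metadata": {},
   "source": [
    "The type-identity test can be illustrated without any `braintrace` internals. The sketch below rests on stated assumptions: `MARKED` is a hypothetical set that plays the role of `ETP_PRIMITIVES`, and it holds a built-in JAX primitive (`jax.lax.dot_general_p`) instead of a real ETP primitive. It only shows how a membership test on `eqn.primitive` picks equations out of a jaxpr.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical stand-in for ETP_PRIMITIVES: a set of primitive objects.\n",
    "MARKED = {jax.lax.dot_general_p}\n",
    "\n",
    "\n",
    "def f(x, w):\n",
    "    # lax.dot binds dot_general_p; lax.tanh binds tanh_p.\n",
    "    return jax.lax.tanh(jax.lax.dot(x, w))\n",
    "\n",
    "\n",
    "closed = jax.make_jaxpr(f)(jnp.ones((4, 3)), jnp.ones((3, 5)))\n",
    "\n",
    "# Membership is decided by the primitive object, not by any name string.\n",
    "flags = [(eqn.primitive.name, eqn.primitive in MARKED) for eqn in closed.jaxpr.eqns]\n",
    "print(flags)\n",
    "```\n",
    "\n",
    "Only the `dot_general` equation is flagged; renaming `f` or wrapping it in `jax.jit` would not change which primitive objects appear in the set."
   ]
  },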
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c3d4e5f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:53.778568Z",
     "iopub.status.busy": "2026-04-18T06:22:53.778250Z",
     "iopub.status.idle": "2026-04-18T06:22:56.070543Z",
     "shell.execute_reply": "2026-04-18T06:22:56.069694Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import braintrace"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4e5f6a7",
   "metadata": {},
   "source": [
    "## The Five Primitive Functions\n",
    "\n",
    "`braintrace` provides five user-facing ETP primitive functions:\n",
    "\n",
    "| Function | Underlying primitives | Purpose |\n",
    "|---|---|---|\n",
    "| `braintrace.matmul` | `etp_mm_p` (batched) / `etp_mv_p` (unbatched) | Dense matrix multiplication |\n",
    "| `braintrace.element_wise` | `etp_elemwise_p` | Element-wise (diagonal) weight ops |\n",
    "| `braintrace.conv` | `etp_conv_p` | Convolution |\n",
    "| `braintrace.sparse_matmul` | `etp_sp_mm_p` / `etp_sp_mv_p` | Sparse matrix multiplication |\n",
    "| `braintrace.lora_matmul` | `etp_lora_mm_p` / `etp_lora_mv_p` | LoRA (Low-Rank Adaptation) matmul |\n",
    "\n",
    "Each function auto-dispatches between batched and unbatched variants based on input dimensionality."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5f6a7b8",
   "metadata": {},
   "source": [
    "### 1. `braintrace.matmul(x, weight, bias=None)` -- Dense Matrix Multiplication\n",
    "\n",
    "Computes $y = x \\, @ \\, w \\; (+ b)$.\n",
    "\n",
    "Auto-dispatches based on `x.ndim`:\n",
    "- `x.ndim >= 2` --> `etp_mm_p` (batched): expects `x` of shape `(batch, in_features)`\n",
    "- `x.ndim == 1` --> `etp_mv_p` (unbatched): expects `x` of shape `(in_features,)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f6a7b8c9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.073864Z",
     "iopub.status.busy": "2026-04-18T06:22:56.073334Z",
     "iopub.status.idle": "2026-04-18T06:22:56.252956Z",
     "shell.execute_reply": "2026-04-18T06:22:56.252077Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batched output shape: (4, 5)\n",
      "Unbatched output shape: (5,)\n"
     ]
    }
   ],
   "source": [
    "# Batched matmul: x has shape (batch, in_features)\n",
    "x_batched = jnp.ones((4, 3))    # batch=4, in_features=3\n",
    "w = jnp.ones((3, 5))            # in_features=3, out_features=5\n",
    "\n",
    "y_batched = braintrace.matmul(x_batched, w)\n",
    "print(\"Batched output shape:\", y_batched.shape)   # (4, 5)\n",
    "\n",
    "# Unbatched matmul: x has shape (in_features,)\n",
    "x_single = jnp.ones((3,))       # in_features=3\n",
    "\n",
    "y_single = braintrace.matmul(x_single, w)\n",
    "print(\"Unbatched output shape:\", y_single.shape)   # (5,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a7b8c9d0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.255033Z",
     "iopub.status.busy": "2026-04-18T06:22:56.254794Z",
     "iopub.status.idle": "2026-04-18T06:22:56.307499Z",
     "shell.execute_reply": "2026-04-18T06:22:56.306626Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With bias: (4, 5)\n"
     ]
    }
   ],
   "source": [
    "# With bias\n",
    "b = jnp.zeros((5,))\n",
    "\n",
    "y_with_bias = braintrace.matmul(x_batched, w, bias=b)\n",
    "print(\"With bias:\", y_with_bias.shape)              # (4, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c9d0e1",
   "metadata": {},
   "source": [
    "### 2. `braintrace.element_wise(weight, fn=lambda w: w)` — Element-wise Operation\n",
    "\n",
    "Applies `fn` to the weight and passes the result through a marker primitive. The operation is treated as *diagonal* in the hidden-state space.\n",
    "\n",
    "$$y = \\texttt{fn}(w)$$\n",
    "\n",
    "`fn` defaults to the identity (`lambda w: w`); supply any JAX-differentiable function when you want a non-trivial transformation.\n",
    "\n",
    "Common use cases:\n",
    "- Gating mechanisms in RNNs (learnable gate biases)\n",
    "- Learnable time constants or thresholds in spiking neural networks\n",
    "- Any parameter that enters the computation element-wise\n",
    "\n",
    "Note: `etp_elemwise_p` is the only primitive registered with `gradient_enabled=True`. The compiler *descends into* it when walking ``y -> h``, so it does not act as a tail boundary for upstream ETP weights. See the *gradient_enabled Flag* section below for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9d0e1f2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.311150Z",
     "iopub.status.busy": "2026-04-18T06:22:56.310960Z",
     "iopub.status.idle": "2026-04-18T06:22:56.379757Z",
     "shell.execute_reply": "2026-04-18T06:22:56.378983Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Identity: [ 0.5 -0.3  0.8  0.1]\n",
      "Sigmoid: [0.62245935 0.4255575  0.6899745  0.5249792 ]\n",
      "Abs: [0.5 0.3 0.8 0.1]\n"
     ]
    }
   ],
   "source": [
    "# Identity (default fn): just marks the weight for ETP\n",
    "w_gate = jnp.array([0.5, -0.3, 0.8, 0.1])\n",
    "\n",
    "y_identity = braintrace.element_wise(w_gate)\n",
    "print(\"Identity:\", y_identity)\n",
    "\n",
    "# With a transformation function\n",
    "y_sigmoid = braintrace.element_wise(w_gate, fn=jax.nn.sigmoid)\n",
    "print(\"Sigmoid:\", y_sigmoid)\n",
    "\n",
    "# With absolute value (e.g., enforcing positive time constants)\n",
    "y_abs = braintrace.element_wise(w_gate, fn=jnp.abs)\n",
    "print(\"Abs:\", y_abs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0e1f2a3",
   "metadata": {},
   "source": [
    "### 3. `braintrace.conv(x, kernel, bias=None, *, strides, padding, ...)` -- Convolution\n",
    "\n",
    "ETP-aware convolution that wraps `jax.lax.conv_general_dilated`. Computes:\n",
    "\n",
    "$$y = \\text{conv}(x, \\text{kernel}) \\; (+ b)$$\n",
    "\n",
    "**Important**: Always expects a batch dimension on `x`.\n",
    "\n",
    "Supports all parameters of `jax.lax.conv_general_dilated`: `strides`, `padding`, `lhs_dilation`, `rhs_dilation`, `feature_group_count`, `batch_group_count`, and `dimension_numbers`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e1f2a3b4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.381550Z",
     "iopub.status.busy": "2026-04-18T06:22:56.381398Z",
     "iopub.status.idle": "2026-04-18T06:22:56.479299Z",
     "shell.execute_reply": "2026-04-18T06:22:56.478311Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conv1D output shape: (2, 16, 8)\n"
     ]
    }
   ],
   "source": [
    "# 1D convolution example\n",
    "# x: (batch, spatial, channels) with dimension_numbers\n",
    "x_1d = jnp.ones((2, 16, 3))         # batch=2, length=16, in_channels=3\n",
    "kernel_1d = jnp.ones((4, 3, 8))     # kernel_size=4, in_channels=3, out_channels=8\n",
    "\n",
    "y_conv = braintrace.conv(\n",
    "    x_1d, kernel_1d,\n",
    "    strides=(1,),\n",
    "    padding='SAME',\n",
    "    dimension_numbers=('NWC', 'WIO', 'NWC'),\n",
    ")\n",
    "print(\"Conv1D output shape:\", y_conv.shape)  # (2, 16, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f2a3b4c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.481773Z",
     "iopub.status.busy": "2026-04-18T06:22:56.481499Z",
     "iopub.status.idle": "2026-04-18T06:22:56.587670Z",
     "shell.execute_reply": "2026-04-18T06:22:56.587045Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conv2D output shape: (2, 32, 32, 16)\n"
     ]
    }
   ],
   "source": [
    "# 2D convolution example\n",
    "x_2d = jnp.ones((2, 32, 32, 3))          # batch=2, H=32, W=32, in_channels=3\n",
    "kernel_2d = jnp.ones((3, 3, 3, 16))      # kH=3, kW=3, in_channels=3, out_channels=16\n",
    "\n",
    "y_conv2d = braintrace.conv(\n",
    "    x_2d, kernel_2d,\n",
    "    strides=(1, 1),\n",
    "    padding='SAME',\n",
    "    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),\n",
    ")\n",
    "print(\"Conv2D output shape:\", y_conv2d.shape)  # (2, 32, 32, 16)"
   ]
  },
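  {
   "cell_type": "markdown",
   "id": "gr-add-02",
   "metadata": {},
   "source": [
    "Because `braintrace.conv` wraps `jax.lax.conv_general_dilated`, output shapes follow that function's rules. The minimal sketch below calls the underlying JAX op directly (its keyword is `window_strides`, whereas the wrapper above takes `strides`) and checks that with `'SAME'` padding the output length is `ceil(length / stride)`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "x = jnp.ones((2, 16, 3))       # (batch, length, in_channels)\n",
    "kernel = jnp.ones((4, 3, 8))   # (kernel_size, in_channels, out_channels)\n",
    "\n",
    "shapes = {}\n",
    "for stride in (1, 2, 3):\n",
    "    y = jax.lax.conv_general_dilated(\n",
    "        x, kernel,\n",
    "        window_strides=(stride,),\n",
    "        padding='SAME',\n",
    "        dimension_numbers=('NWC', 'WIO', 'NWC'),\n",
    "    )\n",
    "    shapes[stride] = y.shape\n",
    "    # 'SAME' padding: output length = ceil(input length / stride)\n",
    "    assert y.shape == (2, math.ceil(16 / stride), 8)\n",
    "\n",
    "print(shapes)\n",
    "```\n",
    "\n",
    "The stride-1 case reproduces the `(2, 16, 8)` shape of the 1D example above."
   ]
  },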
  {
   "cell_type": "markdown",
   "id": "a3b4c5d6",
   "metadata": {},
   "source": [
    "### 4. `braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None)` -- Sparse Matmul\n",
    "\n",
    "ETP-aware sparse matrix multiplication. Computes:\n",
    "\n",
    "$$y = x \\, @ \\, \\text{sparse}(w) \\; (+ b)$$\n",
    "\n",
    "The `sparse_mat` argument provides the sparse structure (indices), while `weight_data` contains only the non-zero values. This is useful for models with sparse connectivity patterns, such as biologically plausible neural networks or graph neural networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b4c5d6e7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:56.589616Z",
     "iopub.status.busy": "2026-04-18T06:22:56.589449Z",
     "iopub.status.idle": "2026-04-18T06:22:57.899681Z",
     "shell.execute_reply": "2026-04-18T06:22:57.898623Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sparse matmul output shape: (4, 50)\n"
     ]
    }
   ],
   "source": [
    "import saiunit as u\n",
    "from saiunit import sparse as ss\n",
    "\n",
    "# Create a sparse connectivity matrix\n",
    "dense_w = jnp.where(\n",
    "    jax.random.uniform(jax.random.PRNGKey(0), (50, 50)) < 0.1,\n",
    "    jax.random.normal(jax.random.PRNGKey(1), (50, 50)),\n",
    "    0.0\n",
    ")\n",
    "sparse_mat = ss.CSR.fromdense(dense_w)\n",
    "\n",
    "# The learnable parameter is just the non-zero data\n",
    "weight_data = sparse_mat.data\n",
    "\n",
    "x_sp = jnp.ones((4, 50))  # batch=4, features=50\n",
    "y_sp = braintrace.sparse_matmul(x_sp, weight_data, sparse_mat=sparse_mat)\n",
    "print(\"Sparse matmul output shape:\", y_sp.shape)  # (4, 50)"
   ]
  },
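  {
   "cell_type": "markdown",
   "id": "gr-add-03",
   "metadata": {},
   "source": [
    "The split between fixed structure and trainable values can be sketched in plain JAX. This is not how `braintrace.sparse_matmul` or `saiunit.sparse.CSR` are implemented; `rows`, `cols`, and `to_dense` are hypothetical names, used only to show that the gradient has one entry per stored non-zero.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Fixed structure of a 2x3 matrix with three non-zeros (hypothetical example).\n",
    "rows = jnp.array([0, 1, 1])\n",
    "cols = jnp.array([1, 0, 2])\n",
    "data = jnp.array([2.0, 1.0, 3.0])   # the only trainable values\n",
    "\n",
    "\n",
    "def to_dense(values):\n",
    "    # Scatter the non-zero values into their fixed positions.\n",
    "    return jnp.zeros((2, 3)).at[rows, cols].set(values)\n",
    "\n",
    "\n",
    "x = jnp.ones((4, 2))                 # batch=4, in_features=2\n",
    "y = x @ to_dense(data)\n",
    "\n",
    "# Differentiating w.r.t. `data` gives one gradient entry per non-zero.\n",
    "g = jax.grad(lambda v: jnp.sum(x @ to_dense(v)))(data)\n",
    "print(y.shape, g.shape)\n",
    "```\n",
    "\n",
    "Positions that are structurally zero never receive a gradient, so the connectivity pattern stays fixed during learning."
   ]
  },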
  {
   "cell_type": "markdown",
   "id": "c5d6e7f8",
   "metadata": {},
   "source": [
    "### 5. `braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None)` -- LoRA Matmul\n",
    "\n",
    "Low-Rank Adaptation matmul. Computes:\n",
    "\n",
    "$$y = \\alpha \\cdot x \\, @ \\, B \\, @ \\, A \\; (+ b)$$\n",
    "\n",
    "where $B \\in \\mathbb{R}^{\\text{in} \\times \\text{rank}}$ and $A \\in \\mathbb{R}^{\\text{rank} \\times \\text{out}}$ are low-rank factors. This is useful for parameter-efficient fine-tuning of large models, where only the low-rank factors are trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6e7f8a9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:57.901964Z",
     "iopub.status.busy": "2026-04-18T06:22:57.901698Z",
     "iopub.status.idle": "2026-04-18T06:22:58.464640Z",
     "shell.execute_reply": "2026-04-18T06:22:58.463547Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LoRA output shape: (8, 32)\n",
      "LoRA output (first sample): [ 1.2993842e-04  2.9886682e-03 -5.7171844e-04  1.7000248e-03\n",
      "  2.8961061e-03 -1.4734608e-03  1.5826351e-03  1.1692893e-04\n",
      " -5.4561032e-04 -1.3646147e-03  1.0654312e-04  2.0160272e-03\n",
      "  3.1371990e-03  1.8710974e-03  2.9124888e-03 -8.2047645e-04\n",
      "  9.7736251e-04 -1.1875220e-03 -2.1796541e-03 -5.8191392e-05\n",
      " -1.5415263e-03 -1.3381819e-03 -2.1044153e-03 -8.4472442e-04\n",
      "  1.6430757e-05 -2.9564608e-04  4.5550600e-04 -1.6011632e-03\n",
      " -1.4516843e-03 -1.3209208e-03  3.2850087e-04 -6.0276774e-04]\n"
     ]
    }
   ],
   "source": [
    "in_features, out_features, rank = 64, 32, 4\n",
    "\n",
    "B = jax.random.normal(jax.random.PRNGKey(0), (in_features, rank)) * 0.01\n",
    "A = jax.random.normal(jax.random.PRNGKey(1), (rank, out_features)) * 0.01\n",
    "\n",
    "x_lora = jnp.ones((8, in_features))  # batch=8\n",
    "\n",
    "y_lora = braintrace.lora_matmul(x_lora, B, A, alpha=2.0)\n",
    "print(\"LoRA output shape:\", y_lora.shape)  # (8, 32)\n",
    "print(\"LoRA output (first sample):\", y_lora[0])"
   ]
  },
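  {
   "cell_type": "markdown",
   "id": "gr-add-04",
   "metadata": {},
   "source": [
    "As a sanity check on the formula, the sketch below uses plain `jax.numpy` (no ETP primitive involved) to confirm that the factored product equals a dense matmul with the matrix $\\alpha B A$, and to compare parameter counts for the shapes used above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "in_features, out_features, rank, alpha = 64, 32, 4, 2.0\n",
    "key_b, key_a = jax.random.split(jax.random.PRNGKey(0))\n",
    "B = jax.random.normal(key_b, (in_features, rank)) * 0.01\n",
    "A = jax.random.normal(key_a, (rank, out_features)) * 0.01\n",
    "x = jnp.ones((8, in_features))\n",
    "\n",
    "# Factored form: two thin matmuls, never materialising the (in, out) matrix.\n",
    "y_factored = alpha * ((x @ B) @ A)\n",
    "# Equivalent dense form: one matmul with the low-rank matrix alpha * B @ A.\n",
    "y_dense = x @ (alpha * (B @ A))\n",
    "\n",
    "n_lora = B.size + A.size              # 64*4 + 4*32 = 384\n",
    "n_dense = in_features * out_features  # 64*32 = 2048\n",
    "print(bool(jnp.allclose(y_factored, y_dense, atol=1e-6)), n_lora, n_dense)\n",
    "```\n",
    "\n",
    "At rank 4 the two factors hold 384 trainable values, against 2048 for a full `(64, 32)` weight matrix."
   ]
  },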
  {
   "cell_type": "markdown",
   "id": "4edae675",
   "metadata": {},
   "source": [
    "## Physical Units (`saiunit` / Quantity) Support\n",
    "\n",
    "Every user-facing ETP function in `braintrace` — `matmul`, `conv`, `element_wise`, `sparse_matmul`, `lora_matmul` — accepts `saiunit.Quantity` inputs transparently. The wrapper\n",
    "\n",
    "1. splits each quantity into a plain-array *mantissa* and a *unit*,\n",
    "2. binds the primitive on the mantissas only,\n",
    "3. re-attaches the combined unit to the result with `u.maybe_decimal`.\n",
    "\n",
    "This keeps the primitives themselves unit-free (JAX sees only arrays), while users write physical quantities naturally. Bias is re-scaled into the combined `x × weight` unit before bind so the addition is dimensionally valid."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1b2c54cf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.466411Z",
     "iopub.status.busy": "2026-04-18T06:22:58.466259Z",
     "iopub.status.idle": "2026-04-18T06:22:58.509627Z",
     "shell.execute_reply": "2026-04-18T06:22:58.508968Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output: [[3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3.]\n",
      " [3. 3. 3. 3. 3.]] A\n",
      "Unit : A\n"
     ]
    }
   ],
   "source": [
    "import saiunit as u\n",
    "\n",
    "# Quantity-valued inputs pass through unchanged.\n",
    "x_q = jnp.ones((4, 3)) * u.volt          # shape (4, 3), unit = V\n",
    "w_q = jnp.ones((3, 5)) * u.siemens       # shape (3, 5), unit = S\n",
    "b_q = jnp.zeros((5,)) * u.amp             # must match V * S = A\n",
    "\n",
    "y_q = braintrace.matmul(x_q, w_q, bias=b_q)\n",
    "print(\"Output:\", y_q)\n",
    "print(\"Unit :\", u.get_unit(y_q))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7f8a9b0",
   "metadata": {},
   "source": [
    "## JAX Compatibility\n",
    "\n",
    "All ETP primitives are fully compatible with JAX transformations. Since `register_primitive()` auto-derives JIT, grad, vmap, and JVP rules from the implementation, they work seamlessly with the standard JAX API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f8a9b0c1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.512275Z",
     "iopub.status.busy": "2026-04-18T06:22:58.511982Z",
     "iopub.status.idle": "2026-04-18T06:22:58.539485Z",
     "shell.execute_reply": "2026-04-18T06:22:58.538856Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JIT output shape: (4, 5)\n"
     ]
    }
   ],
   "source": [
    "x = jnp.ones((4, 3))\n",
    "w = jnp.ones((3, 5))\n",
    "\n",
    "# ---- JIT compilation ----\n",
    "y_jit = jax.jit(braintrace.matmul)(x, w)\n",
    "print(\"JIT output shape:\", y_jit.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a9b0c1d2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.541162Z",
     "iopub.status.busy": "2026-04-18T06:22:58.541001Z",
     "iopub.status.idle": "2026-04-18T06:22:58.713922Z",
     "shell.execute_reply": "2026-04-18T06:22:58.713110Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient shape: (3, 5)\n",
      "Gradient values:\n",
      " [[4. 4. 4. 4. 4.]\n",
      " [4. 4. 4. 4. 4.]\n",
      " [4. 4. 4. 4. 4.]]\n"
     ]
    }
   ],
   "source": [
    "# ---- Gradient computation ----\n",
    "grad_fn = jax.grad(lambda w: jnp.sum(braintrace.matmul(x, w)))\n",
    "dw = grad_fn(w)\n",
    "print(\"Gradient shape:\", dw.shape)\n",
    "print(\"Gradient values:\\n\", dw)"
   ]
  },
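  {
   "cell_type": "markdown",
   "id": "gr-add-05",
   "metadata": {},
   "source": [
    "The constant value 4 in the gradient above can be verified by hand. With $L = \\sum_{b,k} y_{bk}$ and $y_{bk} = \\sum_{m} x_{bm} W_{mk}$,\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial W_{ij}} = \\sum_{b,k} x_{bi} \\, \\delta_{kj} = \\sum_{b=1}^{4} x_{bi} = 4,$$\n",
    "\n",
    "because `x` is all ones with a batch size of 4. Every entry of the $3 \\times 5$ gradient is therefore the same, which is what ordinary reverse-mode differentiation of `x @ w` would give."
   ]
  },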
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b0c1d2e3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.715971Z",
     "iopub.status.busy": "2026-04-18T06:22:58.715797Z",
     "iopub.status.idle": "2026-04-18T06:22:58.796595Z",
     "shell.execute_reply": "2026-04-18T06:22:58.795935Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vmap output shape: (8, 4, 5)\n"
     ]
    }
   ],
   "source": [
    "# ---- Vectorized mapping (vmap) ----\n",
    "# vmap over a batch of inputs, each of shape (4, 3)\n",
    "xs = jnp.ones((8, 4, 3))  # 8 different batches\n",
    "vmap_fn = jax.vmap(lambda x_i: braintrace.matmul(x_i, w))\n",
    "ys = vmap_fn(xs)\n",
    "print(\"vmap output shape:\", ys.shape)  # (8, 4, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c1d2e3f4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.798370Z",
     "iopub.status.busy": "2026-04-18T06:22:58.798129Z",
     "iopub.status.idle": "2026-04-18T06:22:58.866413Z",
     "shell.execute_reply": "2026-04-18T06:22:58.865537Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JVP primal shape: (4, 5)\n",
      "JVP tangent shape: (4, 5)\n"
     ]
    }
   ],
   "source": [
    "# ---- JVP (forward-mode differentiation) ----\n",
    "primals = (x, w)\n",
    "tangents = (jnp.ones_like(x), jnp.ones_like(w))\n",
    "\n",
    "y_primal, y_tangent = jax.jvp(braintrace.matmul, primals, tangents)\n",
    "print(\"JVP primal shape:\", y_primal.shape)\n",
    "print(\"JVP tangent shape:\", y_tangent.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d2e3f4a5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.868117Z",
     "iopub.status.busy": "2026-04-18T06:22:58.867904Z",
     "iopub.status.idle": "2026-04-18T06:22:58.929017Z",
     "shell.execute_reply": "2026-04-18T06:22:58.928061Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Per-sample gradients shape: (8, 3, 5)\n"
     ]
    }
   ],
   "source": [
    "# ---- Composability: JIT + grad + vmap ----\n",
    "@jax.jit\n",
    "def batched_grad(xs, w):\n",
    "    \"\"\"Compute per-sample gradients w.r.t. the weight.\"\"\"\n",
    "    def single_grad(x_i):\n",
    "        return jax.grad(lambda w_: jnp.sum(braintrace.matmul(x_i, w_)))(w)\n",
    "    return jax.vmap(single_grad)(xs)\n",
    "\n",
    "xs = jnp.ones((8, 4, 3))\n",
    "per_sample_grads = batched_grad(xs, w)\n",
    "print(\"Per-sample gradients shape:\", per_sample_grads.shape)  # (8, 3, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3f4a5b6",
   "metadata": {},
   "source": [
    "## Argument Conventions\n",
    "\n",
    "Every ETP primitive follows specific conventions for its input variables (`invars`) and static parameters. Understanding these conventions is essential when working with the compiler or adding custom primitives.\n",
    "\n",
    "### Invar layout\n",
    "\n",
    "| Primitive | `invars[0]` | `invars[1]` | `invars[2]` | `invars[3]` | Static params |\n",
    "|---|---|---|---|---|---|\n",
    "| `etp_mm_p` / `etp_mv_p` | input `x` | weight `W` | bias `b` (opt) | — | `has_bias` |\n",
    "| `etp_elemwise_p` | processed `y` | — | — | — | (none) |\n",
    "| `etp_conv_p` | input `x` | kernel `W` | bias `b` (opt) | — | `has_bias`, `strides`, `padding`, `lhs_dilation`, `rhs_dilation`, `feature_group_count`, `batch_group_count`, `dimension_numbers` |\n",
    "| `etp_sp_mm_p` / `etp_sp_mv_p` | input `x` | weight data | bias `b` (opt) | — | `sparse_mat`, `has_bias` |\n",
    "| `etp_lora_mm_p` / `etp_lora_mv_p` | input `x` | matrix `B` | matrix `A` | bias `b` (opt) | `alpha`, `has_bias` |\n",
    "\n",
    "### `trainable_invars_fn` — the N-trainable-input contract\n",
    "\n",
    "Instead of hard-coding a single `weight_invar_index`, each primitive exposes a function\n",
    "\n",
    "```python\n",
    "trainable_invars_fn: Callable[[dict], Dict[str, int]]\n",
    "```\n",
    "\n",
    "which maps the equation's static params onto ``{trainable_name: invar_index}``. The compiler calls it at analysis time to discover *every* trainable input and to route gradients to the owning `ParamState` pytree leaf.\n",
    "\n",
    "Built-in examples:\n",
    "\n",
    "| Primitive | `has_bias=False` | `has_bias=True` |\n",
    "|---|---|---|\n",
    "| `etp_mm_p` / `etp_mv_p` | `{'weight': 1}` | `{'weight': 1, 'bias': 2}` |\n",
    "| `etp_conv_p` | `{'weight': 1}` | `{'weight': 1, 'bias': 2}` |\n",
    "| `etp_sp_mm_p` / `etp_sp_mv_p` | `{'weight': 1}` | `{'weight': 1, 'bias': 2}` |\n",
    "| `etp_lora_mm_p` / `etp_lora_mv_p` | `{'lora_b': 1, 'lora_a': 2}` | `{'lora_b': 1, 'lora_a': 2, 'bias': 3}` |\n",
    "| `etp_elemwise_p` | `{'weight': 0}` | — |\n",
    "\n",
    "Notes:\n",
    "- The `has_bias` flag is a static parameter (not a traced value) that controls whether the optional bias argument is present.\n",
    "- For convolution, all `jax.lax.conv_general_dilated` parameters are passed as static params.\n",
    "- `x_invar_index` points to the non-trainable input; `etp_elemwise_p` sets it to `None` because the op has no separate input."
   ]
  },
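  {
   "cell_type": "markdown",
   "id": "gr-add-06",
   "metadata": {},
   "source": [
    "To make the contract concrete, here is a hypothetical re-creation of the LoRA row of the table in plain Python. Neither `lora_trainable_invars` nor `pick_trainables` is a `braintrace` function; they only illustrate how a name-to-index mapping can be used to select the trainable inputs of one equation.\n",
    "\n",
    "```python\n",
    "from typing import Dict\n",
    "\n",
    "\n",
    "def lora_trainable_invars(params: dict) -> Dict[str, int]:\n",
    "    # Hypothetical mapping that mirrors the LoRA row of the table above.\n",
    "    names = {'lora_b': 1, 'lora_a': 2}\n",
    "    if params.get('has_bias', False):\n",
    "        names['bias'] = 3\n",
    "    return names\n",
    "\n",
    "\n",
    "def pick_trainables(invars: list, params: dict) -> dict:\n",
    "    # Map each trainable name to the equation input at its index.\n",
    "    return {name: invars[i] for name, i in lora_trainable_invars(params).items()}\n",
    "\n",
    "\n",
    "# Strings stand in for the jaxpr variables of one equation.\n",
    "invars = ['x', 'B', 'A', 'b']\n",
    "print(pick_trainables(invars[:3], {'alpha': 2.0, 'has_bias': False}))\n",
    "print(pick_trainables(invars, {'alpha': 2.0, 'has_bias': True}))\n",
    "```\n",
    "\n",
    "Index 0 (the input `x`) never appears in the result, which is how the non-trainable input stays out of the gradient routing."
   ]
  },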
  {
   "cell_type": "markdown",
   "id": "f4a5b6c7",
   "metadata": {},
   "source": [
    "## Rule Registries (dict API)\n",
    "\n",
    "ETP uses **four** global dictionaries to store operation-specific rules. These are the *only* things that need hand-writing — all standard JAX rules are auto-derived from the implementation function.\n",
    "\n",
    "All four rules operate on ``Dict[str, Array]`` (keyed by the names returned by `trainable_invars_fn`) — *except* `init_pp`, which returns a single output-shaped array because pp-prop factorises the trace as $\\boldsymbol{\\epsilon}_f \\otimes \\boldsymbol{\\epsilon}_x$ and only needs one df-tensor per primitive output.\n",
    "\n",
    "### `ETP_RULES_YW_TO_W` — D-RTRL trace propagation\n",
    "\n",
    "```python\n",
    "yw_to_w(hidden_dim: Array, trace: Dict[str, Array], **static_params) -> Dict[str, Array]\n",
    "```\n",
    "\n",
    "Propagates the hidden-state cotangent $\\partial h/\\partial y$ through the $y \\to W$ chain factor of the D-RTRL term $\\mathbf{D}^t \\boldsymbol{\\epsilon}^{t-1}$. Applied per stored trace key.\n",
    "\n",
    "### `ETP_RULES_XY_TO_DW` — instantaneous hidden-to-weight Jacobian\n",
    "\n",
    "```python\n",
    "xy_to_dw(x: Array, hidden_dim: Array, weights: Dict[str, Array], **static_params) -> Dict[str, Array]\n",
    "```\n",
    "\n",
    "Returns $\\partial h / \\partial W$ for every trainable key. This supplies the $\\operatorname{diag}(\\mathbf{D}_f^t) \\otimes \\mathbf{x}^t$ term in D-RTRL and the solve-time pullback in ES-D-RTRL. Typical implementation: a single fused `jax.vjp` over a dict-valued forward function.\n",
    "\n",
    "### `ETP_RULES_INIT_DRTRL` — D-RTRL trace initialiser\n",
    "\n",
    "```python\n",
    "init_drtrl(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Dict[str, Array]\n",
    "```\n",
    "\n",
    "Returns a zero-filled `Dict[str, Array]` shaped to hold the **parameter-dimension** trace used by `D_RTRL` / `ParamDimVjpAlgorithm`. One leaf per trainable key.\n",
    "\n",
    "### `ETP_RULES_INIT_PP` — pp-prop / ES-D-RTRL df-trace initialiser\n",
    "\n",
    "```python\n",
    "init_pp(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Array\n",
    "```\n",
    "\n",
    "Returns a single zero-filled array shaped to hold the **output-dimension** df trace used by `ES_D_RTRL` / `IODimVjpAlgorithm`. The matching $\\boldsymbol{\\epsilon}_x$ factor is managed separately by the executor's x-trace dictionary.\n",
    "\n",
    "The two `INIT_*` registries exist because the two algorithm families factorise the trace differently. Both are required for a primitive that should support both algorithms."
   ]
  },
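  {
   "cell_type": "markdown",
   "id": "gr-add-07",
   "metadata": {},
   "source": [
    "The \"single fused `jax.vjp` over a dict-valued forward function\" pattern mentioned for `xy_to_dw` can be sketched for a plain matmul with bias. `xy_to_dw_sketch` is a hypothetical function, not the built-in rule; in particular it sums over the batch axis as an ordinary VJP does, and the library's own rules may treat that axis differently.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def xy_to_dw_sketch(x, hidden_dim, weights, *, has_bias=False):\n",
    "    # Forward function over the dict of trainable inputs.\n",
    "    def forward(ws):\n",
    "        y = x @ ws['weight']\n",
    "        if has_bias:\n",
    "            y = y + ws['bias']\n",
    "        return y\n",
    "\n",
    "    # One VJP call yields cotangents for every key at once.\n",
    "    _, pullback = jax.vjp(forward, weights)\n",
    "    (dw,) = pullback(hidden_dim)\n",
    "    return dw\n",
    "\n",
    "\n",
    "x = jnp.ones((4, 3))\n",
    "weights = {'weight': jnp.ones((3, 5)), 'bias': jnp.zeros((5,))}\n",
    "hidden_dim = jnp.ones((4, 5))   # cotangent with the same shape as y\n",
    "\n",
    "dw = xy_to_dw_sketch(x, hidden_dim, weights, has_bias=True)\n",
    "print({k: v.shape for k, v in dw.items()})\n",
    "```\n",
    "\n",
    "The returned dict has the same keys and leaf shapes as `weights`, which is the property the dict rule API relies on when routing results back to each trainable input."
   ]
  },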
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a5b6c7d8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.931264Z",
     "iopub.status.busy": "2026-04-18T06:22:58.931098Z",
     "iopub.status.idle": "2026-04-18T06:22:58.937585Z",
     "shell.execute_reply": "2026-04-18T06:22:58.936469Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All ETP primitives:\n",
      "  etp_conv [batched]\n",
      "  etp_elemwise\n",
      "  etp_lora_mm [batched]\n",
      "  etp_lora_mv\n",
      "  etp_mm [batched]\n",
      "  etp_mv\n",
      "  etp_sp_mm [batched]\n",
      "  etp_sp_mv\n",
      "\n",
      "Trace propagation rules (ETP_RULES_YW_TO_W):\n",
      "  etp_conv\n",
      "  etp_elemwise\n",
      "  etp_lora_mm\n",
      "  etp_lora_mv\n",
      "  etp_mm\n",
      "  etp_mv\n",
      "  etp_sp_mm\n",
      "  etp_sp_mv\n",
      "\n",
      "Weight gradient rules (ETP_RULES_XY_TO_DW):\n",
      "  etp_conv\n",
      "  etp_elemwise\n",
      "  etp_lora_mm\n",
      "  etp_lora_mv\n",
      "  etp_mm\n",
      "  etp_mv\n",
      "  etp_sp_mm\n",
      "  etp_sp_mv\n",
      "\n",
      "D-RTRL init rules (ETP_RULES_INIT_DRTRL):\n",
      "  etp_conv\n",
      "  etp_elemwise\n",
      "  etp_lora_mm\n",
      "  etp_lora_mv\n",
      "  etp_mm\n",
      "  etp_mv\n",
      "  etp_sp_mm\n",
      "  etp_sp_mv\n",
      "\n",
      "pp_prop init rules (ETP_RULES_INIT_PP):\n",
      "  etp_conv\n",
      "  etp_elemwise\n",
      "  etp_lora_mm\n",
      "  etp_lora_mv\n",
      "  etp_mm\n",
      "  etp_mv\n",
      "  etp_sp_mm\n",
      "  etp_sp_mv\n"
     ]
    }
   ],
   "source": [
    "from braintrace._etrace_op import (\n",
    "    ETP_RULES_YW_TO_W,\n",
    "    ETP_RULES_XY_TO_DW,\n",
    "    ETP_RULES_INIT_DRTRL,\n",
    "    ETP_RULES_INIT_PP,\n",
    "    ETP_PRIMITIVES,\n",
    "    BATCHED_PRIMITIVES,\n",
    ")\n",
    "\n",
    "print(\"All ETP primitives:\")\n",
    "for p in sorted(ETP_PRIMITIVES, key=lambda p: p.name):\n",
    "    batched_tag = \" [batched]\" if p in BATCHED_PRIMITIVES else \"\"\n",
    "    print(f\"  {p.name}{batched_tag}\")\n",
    "\n",
    "print(\"\\nTrace propagation rules (ETP_RULES_YW_TO_W):\")\n",
    "for p in sorted(ETP_RULES_YW_TO_W.keys(), key=lambda p: p.name):\n",
    "    print(f\"  {p.name}\")\n",
    "\n",
    "print(\"\\nWeight gradient rules (ETP_RULES_XY_TO_DW):\")\n",
    "for p in sorted(ETP_RULES_XY_TO_DW.keys(), key=lambda p: p.name):\n",
    "    print(f\"  {p.name}\")\n",
    "\n",
    "print(\"\\nD-RTRL init rules (ETP_RULES_INIT_DRTRL):\")\n",
    "for p in sorted(ETP_RULES_INIT_DRTRL.keys(), key=lambda p: p.name):\n",
    "    print(f\"  {p.name}\")\n",
    "\n",
    "print(\"\\npp_prop init rules (ETP_RULES_INIT_PP):\")\n",
    "for p in sorted(ETP_RULES_INIT_PP.keys(), key=lambda p: p.name):\n",
    "    print(f\"  {p.name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c7d8e9",
   "metadata": {},
   "source": [
    "## Adding a Custom Primitive\n",
    "\n",
    "Adding a new ETP primitive takes only a few steps. Here we create a **scaled matrix multiplication with an optional bias** as an example:\n",
    "\n",
    "$$y = \\text{scale} \\cdot (x \\, @ \\, W) \\; (+ b).$$\n",
    "\n",
    "The example exercises the whole dict rule API: both the `weight` and `bias` branches are wired end-to-end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c7d8e9f0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.939408Z",
     "iopub.status.busy": "2026-04-18T06:22:58.939228Z",
     "iopub.status.idle": "2026-04-18T06:22:58.943089Z",
     "shell.execute_reply": "2026-04-18T06:22:58.942503Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Primitive registered: etp_scaled_mm\n",
      "Type: ETPPrimitive\n"
     ]
    }
   ],
   "source": [
    "import braintrace\n",
    "from braintrace import register_primitive\n",
    "\n",
    "\n",
    "# Step 1: Define the implementation.\n",
    "# Plain JAX function — no special annotations needed.\n",
    "def _scaled_matmul_impl(*args, scale=1.0, has_bias=False):\n",
    "    x, w = args[0], args[1]\n",
    "    y = scale * (x @ w)\n",
    "    if has_bias:\n",
    "        y = y + args[2]\n",
    "    return y\n",
    "\n",
    "\n",
    "# Step 2: Register as an ETP primitive.\n",
    "# register_primitive() returns an ``ETPPrimitive`` and auto-derives all\n",
    "# standard JAX rules (abstract_eval, lowering, JVP, transpose, batching).\n",
    "scaled_mm_p = register_primitive('etp_scaled_mm', _scaled_matmul_impl, batched=True)\n",
    "\n",
    "print(\"Primitive registered:\", scaled_mm_p)\n",
    "print(\"Type:\", type(scaled_mm_p).__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d8e9f0a1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.945237Z",
     "iopub.status.busy": "2026-04-18T06:22:58.945089Z",
     "iopub.status.idle": "2026-04-18T06:22:58.951267Z",
     "shell.execute_reply": "2026-04-18T06:22:58.950682Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "yw_to_w registered:   True\n",
      "xy_to_dw registered:  True\n",
      "init_drtrl registered: True\n",
      "init_pp registered:   True\n"
     ]
    }
   ],
   "source": [
    "# Step 3: Register the four ETP-specific rules (dict API).\n",
    "# Each rule accepts / returns ``Dict[str, Array]`` keyed by the names\n",
    "# in ``trainable_invars_fn`` — here ``'weight'`` and (optionally) ``'bias'``.\n",
    "\n",
    "\n",
    "def _scaled_trainable_invars(params):\n",
    "    \"\"\"Tell the compiler which invars are trainable.\"\"\"\n",
    "    base = {'weight': 1}\n",
    "    if params.get('has_bias', False):\n",
    "        base['bias'] = 2\n",
    "    return base\n",
    "\n",
    "\n",
    "def _scaled_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):\n",
    "    # y = scale * (x @ w) + b\n",
    "    #   -> ∂y/∂w carries a factor `scale`, so the weight trace is multiplied by\n",
    "    #      `scale * hidden_dim`, broadcast over the input axis (singleton at axis=-2).\n",
    "    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}\n",
    "    if has_bias:\n",
    "        out['bias'] = trace['bias'] * hidden_dim\n",
    "    return out\n",
    "\n",
    "\n",
    "def _scaled_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):\n",
    "    # Single fused VJP over a dict-valued forward function — returns\n",
    "    # gradients for both 'weight' and 'bias' in one pass.\n",
    "    def _fwd(w_dict):\n",
    "        y = scale * (x @ w_dict['weight'])\n",
    "        if has_bias:\n",
    "            y = y + w_dict['bias']\n",
    "        return y\n",
    "    _, vjp_fn = jax.vjp(_fwd, weights)\n",
    "    return vjp_fn(hidden_dim)[0]\n",
    "\n",
    "\n",
    "def _scaled_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):\n",
    "    \"\"\"D-RTRL parameter-dim trace: one leaf per trainable key.\"\"\"\n",
    "    batch = x_var.aval.shape[0]\n",
    "    out = {\n",
    "        'weight': jnp.zeros(\n",
    "            (batch, *weight_vars['weight'].aval.shape, num_hidden_state)\n",
    "        )\n",
    "    }\n",
    "    if 'bias' in weight_vars:\n",
    "        out['bias'] = jnp.zeros(\n",
    "            (batch, *weight_vars['bias'].aval.shape, num_hidden_state)\n",
    "        )\n",
    "    return out\n",
    "\n",
    "\n",
    "def _scaled_init_pp(x_var, y_var, weight_vars, num_hidden_state):\n",
    "    \"\"\"pp-prop df trace: single array shaped like the output.\"\"\"\n",
    "    return jnp.zeros(\n",
    "        (*y_var.aval.shape, num_hidden_state),\n",
    "        dtype=y_var.aval.dtype,\n",
    "    )\n",
    "\n",
    "\n",
    "scaled_mm_p.register_etp_rules(\n",
    "    yw_to_w=_scaled_yw_to_w,\n",
    "    xy_to_dw=_scaled_xy_to_dw,\n",
    "    init_drtrl=_scaled_init_drtrl,\n",
    "    init_pp=_scaled_init_pp,\n",
    ")\n",
    "\n",
    "# Shorthand: every ``register_*`` method also exists as a standalone call,\n",
    "# and ``ETPPrimitiveSpec`` (below) bundles the whole thing in one record.\n",
    "\n",
    "print(\"yw_to_w registered:  \", scaled_mm_p in ETP_RULES_YW_TO_W)\n",
    "print(\"xy_to_dw registered: \", scaled_mm_p in ETP_RULES_XY_TO_DW)\n",
    "print(\"init_drtrl registered:\", scaled_mm_p in ETP_RULES_INIT_DRTRL)\n",
    "print(\"init_pp registered:  \", scaled_mm_p in ETP_RULES_INIT_PP)"
   ]
  },
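  {
   "cell_type": "markdown",
   "id": "e1f2a3b4",
   "metadata": {},
   "source": [
    "As a rough check on the rules registered above, the derivatives of the scaled matmul can be written out by hand. This is a sketch that assumes a batched input `x` of shape `(batch, in)` and that `hidden_dim` plays the role of a cotangent $g$ with the same shape as $y$; the index names $n$ (batch), $i$ (input) and $j$ (output) are introduced here for illustration only.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "y_{nj} &= \\text{scale} \\cdot \\sum_i x_{ni} W_{ij} + b_j, \\\\\n",
    "\\frac{\\partial y_{nj}}{\\partial W_{ij}} &= \\text{scale} \\cdot x_{ni}, \\qquad \\frac{\\partial y_{nj}}{\\partial b_j} = 1, \\\\\n",
    "\\Delta W_{ij} &= \\text{scale} \\cdot \\sum_n x_{ni} \\, g_{nj}, \\qquad \\Delta b_j = \\sum_n g_{nj}.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The last line is what the fused `jax.vjp` call in `_scaled_xy_to_dw` evaluates for those shapes. The factor `scale` appears only in the weight branch, which matches `_scaled_yw_to_w`, where the bias trace is multiplied by `hidden_dim` alone."
   ]
  },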
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e9f0a1b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:58.953466Z",
     "iopub.status.busy": "2026-04-18T06:22:58.953219Z",
     "iopub.status.idle": "2026-04-18T06:22:59.139119Z",
     "shell.execute_reply": "2026-04-18T06:22:59.138200Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output shape : (4, 5)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Matches 2·xw : True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With bias    : [6.1 6.1 6.1 6.1 6.1]\n"
     ]
    }
   ],
   "source": [
    "# Step 4: Use the custom primitive via ``primitive.bind()``.\n",
    "\n",
    "x = jnp.ones((4, 3))\n",
    "w = jnp.ones((3, 5))\n",
    "\n",
    "y = scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)\n",
    "y_expected = 2.0 * (x @ w)\n",
    "\n",
    "print(\"Output shape :\", y.shape)\n",
    "print(\"Matches 2·xw :\", bool(jnp.allclose(y, y_expected)))\n",
    "\n",
    "# With bias:\n",
    "b = jnp.full((5,), 0.1)\n",
    "y_bias = scaled_mm_p.bind(x, w, b, scale=2.0, has_bias=True)\n",
    "print(\"With bias    :\", y_bias[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "f0a1b2c3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:59.141017Z",
     "iopub.status.busy": "2026-04-18T06:22:59.140842Z",
     "iopub.status.idle": "2026-04-18T06:22:59.259288Z",
     "shell.execute_reply": "2026-04-18T06:22:59.258589Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JIT works: True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Grad shape: (3, 5)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vmap output shape: (8, 4, 5)\n"
     ]
    }
   ],
   "source": [
    "# All JAX transformations work automatically for the custom primitive.\n",
    "\n",
    "# JIT\n",
    "y_jit = jax.jit(lambda x, w: scaled_mm_p.bind(x, w, scale=2.0, has_bias=False))(x, w)\n",
    "print(\"JIT works:\", bool(jnp.allclose(y_jit, y_expected)))\n",
    "\n",
    "# Grad\n",
    "dw = jax.grad(lambda w: jnp.sum(scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)))(w)\n",
    "print(\"Grad shape:\", dw.shape)\n",
    "\n",
    "# Vmap\n",
    "xs = jnp.ones((8, 4, 3))\n",
    "ys = jax.vmap(lambda xi: scaled_mm_p.bind(xi, w, scale=2.0, has_bias=False))(xs)\n",
    "print(\"Vmap output shape:\", ys.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a63fbcca",
   "metadata": {},
   "source": [
    "> **Compiler integration.** The class-style registration above is enough for direct `primitive.bind()` use and for JIT, grad, vmap, and JVP. For the primitive to be discovered by the *ETP compiler* (`compile_etrace_graph`, `D_RTRL`, `ES_D_RTRL`), you must also publish a `trainable_invars_fn` and the invar layout via an `ETPPrimitiveSpec`; see the next section. Note that `_scaled_trainable_invars` above is defined but not passed to any registration call; the next section supplies an identical function, `_spec_trainable_invars`, to the spec."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b60b171",
   "metadata": {},
   "source": [
    "## Spec-based Registration\n",
    "\n",
    "When a primitive is intended for the ETP compiler, use the **spec** form. An `ETPPrimitiveSpec` bundles the implementation, all four rules, the invar/outvar layout, and the `trainable_invars_fn` into one frozen dataclass. Passing it to `register_primitive_spec` wires everything up and records the spec in `ETP_PRIMITIVE_SPECS` so the compiler can query it via `get_primitive_spec`.\n",
    "\n",
    "Spec fields:\n",
    "\n",
    "| Field | Purpose |\n",
    "|---|---|\n",
    "| `name` | Primitive name |\n",
    "| `impl` | Plain JAX forward function |\n",
    "| `yw_to_w`, `xy_to_dw`, `init_drtrl`, `init_pp` | The four ETP rules |\n",
    "| `trainable_invars_fn` | ``params -> {trainable_name: invar_index}`` — required |\n",
    "| `x_invar_index` | Position of the non-trainable input, or `None` for identity-like ops |\n",
    "| `y_outvar_index` | Position of `y` in `eqn.outvars` (default `0`) |\n",
    "| `batched` | Batched-input primitive? |\n",
    "| `gradient_enabled` | Compiler traverses this primitive when walking `y → h` (default `False`; set only for identity-like ops) |\n",
    "\n",
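    "The spec form registers the same implementation and rules as the class-based form, and additionally records the invar layout and `trainable_invars_fn` that the compiler needs. Use the class-based form only when the primitive is called directly and never compiled into an ETP graph."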
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "44161966",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:59.261052Z",
     "iopub.status.busy": "2026-04-18T06:22:59.260875Z",
     "iopub.status.idle": "2026-04-18T06:22:59.268492Z",
     "shell.execute_reply": "2026-04-18T06:22:59.267403Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable (no bias): {'weight': 1}\n",
      "trainable (bias)   : {'weight': 1, 'bias': 2}\n",
      "spec bind output  : (4, 5)\n"
     ]
    }
   ],
   "source": [
    "import braintrace\n",
    "\n",
    "\n",
    "def _spec_impl(*args, scale=1.0, has_bias=False):\n",
    "    x, w = args[0], args[1]\n",
    "    y = scale * (x @ w)\n",
    "    if has_bias:\n",
    "        y = y + args[2]\n",
    "    return y\n",
    "\n",
    "\n",
    "def _spec_trainable_invars(params):\n",
    "    base = {'weight': 1}\n",
    "    if params.get('has_bias', False):\n",
    "        base['bias'] = 2\n",
    "    return base\n",
    "\n",
    "\n",
    "def _spec_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):\n",
    "    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}\n",
    "    if has_bias:\n",
    "        out['bias'] = trace['bias'] * hidden_dim\n",
    "    return out\n",
    "\n",
    "\n",
    "def _spec_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):\n",
    "    def _fwd(w_dict):\n",
    "        y = scale * (x @ w_dict['weight'])\n",
    "        if has_bias:\n",
    "            y = y + w_dict['bias']\n",
    "        return y\n",
    "    _, vjp_fn = jax.vjp(_fwd, weights)\n",
    "    return vjp_fn(hidden_dim)[0]\n",
    "\n",
    "\n",
    "def _spec_init_drtrl(x_var, y_var, weight_vars, n):\n",
    "    batch = x_var.aval.shape[0]\n",
    "    out = {\n",
    "        'weight': jnp.zeros(\n",
    "            (batch, *weight_vars['weight'].aval.shape, n)\n",
    "        )\n",
    "    }\n",
    "    if 'bias' in weight_vars:\n",
    "        out['bias'] = jnp.zeros(\n",
    "            (batch, *weight_vars['bias'].aval.shape, n)\n",
    "        )\n",
    "    return out\n",
    "\n",
    "\n",
    "def _spec_init_pp(x_var, y_var, weight_vars, n):\n",
    "    return jnp.zeros((*y_var.aval.shape, n), dtype=y_var.aval.dtype)\n",
    "\n",
    "\n",
    "spec = braintrace.ETPPrimitiveSpec(\n",
    "    name='etp_spec_demo',\n",
    "    impl=_spec_impl,\n",
    "    yw_to_w=_spec_yw_to_w,\n",
    "    xy_to_dw=_spec_xy_to_dw,\n",
    "    init_drtrl=_spec_init_drtrl,\n",
    "    init_pp=_spec_init_pp,\n",
    "    trainable_invars_fn=_spec_trainable_invars,\n",
    "    x_invar_index=0,   # ``invars[0]`` is the input; trainable invars are\n",
    "                       # fully described by ``trainable_invars_fn`` above.\n",
    "    batched=True,\n",
    ")\n",
    "\n",
    "spec_p = braintrace.register_primitive_spec(spec)\n",
    "\n",
    "# The spec is recoverable from the primitive at any time:\n",
    "assert braintrace.get_primitive_spec(spec_p) is spec\n",
    "\n",
    "# Quick sanity check of the invar layout:\n",
    "print(\"trainable (no bias):\", spec.resolve_trainable_invars({'has_bias': False}))\n",
    "print(\"trainable (bias)   :\", spec.resolve_trainable_invars({'has_bias': True}))\n",
    "\n",
    "# And the primitive binds normally:\n",
    "y_spec = spec_p.bind(x, w, scale=3.0, has_bias=False)\n",
    "print(\"spec bind output  :\", y_spec.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af6deb1d",
   "metadata": {},
   "source": [
    "## The `gradient_enabled` Flag\n",
    "\n",
    "`register_primitive()` accepts a `gradient_enabled` keyword (default `False`). It controls how the compiler treats this primitive when it sits on the path from an upstream weight op's output `y` to a hidden state `h`.\n",
    "\n",
    "| `gradient_enabled` | Compiler behaviour | Example |\n",
    "|---|---|---|\n",
    "| `False` (default) | Treats the primitive as a **tail boundary**. A preceding ETP weight whose only path to a hidden state passes through this primitive is **excluded** from ETP, because per-primitive ETP rules cannot express weight-then-weight composition. | All trainable matmul/conv/sparse/LoRA primitives use this. |\n",
    "| `True` | The primitive is **identity-like** and may sit on the tail of the `y -> h` walk. Its presence does not exclude an upstream ETP weight. | Only `etp_elemwise_p` -- intended for gating biases, learnable thresholds, etc. |\n",
    "\n",
    "Use `gradient_enabled=True` only when the primitive's `xy_to_dw` rule is itself an identity-like passthrough; mark all genuinely *trainable* ops with the default. The \"weight -> weight -> hidden\" exclusion is what makes per-primitive ETP rules sound -- see ``advanced/limitations.ipynb`` for a worked example with ``GRUCell`` (3 Linears, only 2 ETP relations)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9391990a",
   "metadata": {},
   "source": [
    "## Integrating a Primitive with Online Learning\n",
    "\n",
    "Marking a weight operation with a `braintrace.*` primitive is the *only* thing a model has to do to opt that parameter into online learning. The compiler then walks the jaxpr, finds every ETP primitive, connects it to the downstream hidden states, and builds the eligibility-trace machinery for either `D_RTRL` (parameter-dim trace) or `ES_D_RTRL` / `pp_prop` (IO-dim trace).\n",
    "\n",
    "**Rule of thumb**\n",
    "\n",
    "| Goal | Use |\n",
    "|---|---|\n",
    "| Include a parameter in online learning | `braintrace.matmul(x, W)` (or `conv`, `sparse_matmul`, `lora_matmul`, `element_wise`) |\n",
    "| Exclude a parameter from online learning | regular JAX op: `x @ W`, `lax.conv_general_dilated`, … |\n",
    "\n",
    "The short example below wires a vanilla RNN into `D_RTRL`: only the recurrent weight is marked with `braintrace.matmul`, so only it receives an eligibility trace. The input weight uses plain `@`, so the compiler ignores it; it would have to be trained by another route such as BPTT, which this example does not show."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "5a9cdb2b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-18T06:22:59.270891Z",
     "iopub.status.busy": "2026-04-18T06:22:59.270644Z",
     "iopub.status.idle": "2026-04-18T06:22:57.323492Z",
     "shell.execute_reply": "2026-04-18T06:22:57.322941Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled ETP relations: 1\n",
      "   primitive = etp_mm   trainable keys = ['weight']\n"
     ]
    }
   ],
   "source": [
    "import brainstate\n",
    "\n",
    "\n",
    "class TinyRNN(brainstate.nn.Module):\n",
    "    def __init__(self, in_dim=4, hid_dim=6):\n",
    "        super().__init__()\n",
    "        self.in_dim = in_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        # Recurrent weight: ETP-enabled (online learning via D-RTRL).\n",
    "        self.W_rec = brainstate.ParamState(\n",
    "            0.1 * jax.random.normal(jax.random.PRNGKey(0), (hid_dim, hid_dim))\n",
    "        )\n",
    "        # Input weight: plain matmul, learned via BPTT instead.\n",
    "        self.W_in = brainstate.ParamState(\n",
    "            0.1 * jax.random.normal(jax.random.PRNGKey(1), (in_dim, hid_dim))\n",
    "        )\n",
    "\n",
    "    def init_state(self, batch_size=None, **kwargs):\n",
    "        # ``HiddenState`` is what the ETP compiler traces through.\n",
    "        self.h = brainstate.HiddenState(\n",
    "            jnp.zeros((batch_size or 1, self.hid_dim))\n",
    "        )\n",
    "\n",
    "    def update(self, x):\n",
    "        # W_in is NOT marked -> excluded from ETP.\n",
    "        input_drive = x @ self.W_in.value\n",
    "        # W_rec IS marked -> included in ETP.\n",
    "        rec_drive = braintrace.matmul(self.h.value, self.W_rec.value)\n",
    "        self.h.value = jax.nn.tanh(input_drive + rec_drive)\n",
    "        return self.h.value\n",
    "\n",
    "\n",
    "model = TinyRNN(in_dim=4, hid_dim=6)\n",
    "brainstate.nn.init_all_states(model, batch_size=2)\n",
    "\n",
    "# Wrap the model in a D-RTRL algorithm and compile the ETP graph.\n",
    "alg = braintrace.D_RTRL(model)\n",
    "alg.compile_graph(jnp.zeros((2, model.in_dim)))\n",
    "\n",
    "print(\"Compiled ETP relations:\", len(alg.graph.hidden_param_op_relations))\n",
    "for rel in alg.graph.hidden_param_op_relations:\n",
    "    print(\"   primitive =\", rel.primitive.name,\n",
    "          \"  trainable keys =\", list(rel.trainable_vars.keys()))"
   ]
  },
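  {
   "cell_type": "markdown",
   "id": "f2a3b4c5",
   "metadata": {},
   "source": [
    "For the `TinyRNN` above, the instantaneous hidden-to-weight Jacobian of the marked matmul can be written out by hand. This is a sketch that assumes $\\mathbf{D}_f^t$ denotes the derivative of the new hidden state with respect to the matmul output $y^t$; the symbols $u^t$ (external input) and $a^t$ (pre-activation) are introduced here for illustration only.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "y^t &= h^{t-1} W_{\\text{rec}}, \\qquad a^t = u^t W_{\\text{in}} + y^t, \\qquad h^t = \\tanh(a^t), \\\\\n",
    "\\frac{\\partial h_j^t}{\\partial y_j^t} &= 1 - (h_j^t)^2, \\\\\n",
    "\\left. \\frac{\\partial h_j^t}{\\partial W_{\\text{rec},ij}} \\right|_{h^{t-1}\\ \\text{fixed}} &= \\bigl(1 - (h_j^t)^2\\bigr) \\, h_i^{t-1}.\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Under that reading the result has the form $\\operatorname{diag}(\\mathbf{D}_f^t) \\otimes \\mathbf{x}^t$ with $\\mathbf{D}_f^t = 1 - (h^t)^2$ and the op input $\\mathbf{x}^t = h^{t-1}$. No analogous term is built for $W_{\\text{in}}$, because `x @ self.W_in.value` never passes through an ETP primitive."
   ]
  },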
  {
   "cell_type": "markdown",
   "id": "a1b2c3d5",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "ETP primitives provide a clean, extensible foundation for online learning in recurrent networks:\n",
    "\n",
    "- **8 built-in primitives** cover the most common use cases: dense matmul (mm/mv), element-wise ops, convolution, sparse matmul (mm/mv), and LoRA matmul (mm/mv).\n",
    "\n",
    "- **Dict rule API** — every primitive declares its full set of trainable inputs via `trainable_invars_fn`, and the four ETP rules consume and return `Dict[str, Array]`. A single primitive can own several `ParamState` objects (e.g. weight + bias, or $B + A + b$ in LoRA) and the executor routes gradients to each in one pass.\n",
    "\n",
    "- **Custom primitives can be added in a few dozen lines**: implement the forward function, call `register_primitive` (class style) or build an `ETPPrimitiveSpec` (compiler-ready), then hand-write the four ETP rules.\n",
    "\n",
    "- **All JAX transformations (JIT, grad, vmap, JVP) work automatically** — only the four online-learning-specific rules need hand-writing.\n",
    "\n",
    "- **Parameter selection is primitive-based**: every `brainstate.ParamState` is eligible for ETP, and participation depends on whether a `braintrace.*` ETP primitive consumed it and on whether its output reaches a hidden state without passing through another primitive registered with `gradient_enabled=False`. Use `gradient_enabled=True` exclusively for identity-like ops such as `etp_elemwise_p`.\n",
    "\n",
    "- **Saiunit quantities** are handled transparently by every user-facing wrapper.\n",
    "\n",
    "Where to look for the math:\n",
    "\n",
    "| Rule | Algorithm term | Source with derivation |\n",
    "|---|---|---|\n",
    "| `xy_to_dw` | $\\operatorname{diag}(\\mathbf{D}_f^t) \\otimes \\mathbf{x}^t$ | docstrings in `braintrace/_etrace_op/{dense,conv,elemwise,sparse,lora}.py` |\n",
    "| `yw_to_w` | $\\mathbf{D}^t \\boldsymbol{\\epsilon}^{t-1}$ ($y \\to W$ link) | same files |\n",
    "| `init_drtrl` | param-dim trace shape | same files |\n",
    "| `init_pp` | output-dim df-trace shape | same files |\n",
    "\n",
    "Further reading: `advanced/limitations.ipynb` explains the non-parametric-tail invariant and walks through `GRUCell` (3 Linears, only 2 ETP relations)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
