{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60001",
   "metadata": {},
   "source": [
    "# Limitations & Workarounds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60002",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "**braintrace** analyzes the model's [Jaxpr](https://jax.readthedocs.io/en/latest/jaxpr.html) (JAX's intermediate representation) at compile time to automatically derive eligibility trace update rules. This compilation process walks through the traced computation graph to identify the relationships between hidden states, parameters, and the operations that connect them.\n",
    "\n",
    "However, some JAX operations create **sub-Jaxprs** -- separate, nested computation graphs -- that the braintrace compiler cannot traverse. When such operations appear inside the model's `update()` method, the compiler loses visibility into the computation and cannot correctly construct the eligibility trace graph.\n",
    "\n",
    "Understanding these limitations helps you design models that are fully compatible with braintrace's online learning compilation. This tutorial covers the known limitations and provides practical workarounds for each."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60003",
   "metadata": {},
   "source": [
    "## Unsupported JAX Primitives Inside the Model\n",
    "\n",
    "The following JAX control flow primitives are **NOT supported** inside the model's `update()` method:\n",
    "\n",
    "| Primitive | Description | Why it fails |\n",
    "|---|---|---|\n",
    "| `jax.lax.cond` | Conditional execution (if/else) | Creates two branch sub-Jaxprs |\n",
    "| `jax.lax.scan` | Loop with carry state | Creates a body sub-Jaxpr |\n",
    "| `jax.lax.while_loop` | General loops | Creates cond + body sub-Jaxprs |\n",
    "| `jax.vmap` | Vectorized map (nested inside model) | Creates a mapped sub-Jaxpr |\n",
    "\n",
    "Each of these constructs introduces a sub-Jaxpr that the braintrace compiler cannot analyze. When the compiler encounters one of these primitives during graph construction, it will raise a `NotSupportedError` or `CompilationError`.\n",
    "\n",
    "**Important note:** These primitives can still be used *outside* of the model's `update()` method. For example, using `jax.lax.scan` to unroll the model over time steps is perfectly fine -- the restriction only applies to operations *within* the traced computation that connects hidden states to parameters."
   ]
  },
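  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40001",
   "metadata": {},
   "source": [
    "To see what a sub-Jaxpr looks like in practice, the sketch below traces small functions with `jax.make_jaxpr`. It then reports which top-level equations carry a nested Jaxpr in their parameters, using only public tracing attributes (`eqn.primitive.name`, `eqn.params`). The helper `sub_jaxpr_primitives` is hypothetical and not part of braintrace. The duck-typed detection of nested Jaxprs is an assumption and may need adjusting for other JAX versions.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def _nested_jaxprs(value):\n",
    "    # Yield any Jaxpr stored in an equation parameter. ClosedJaxpr objects expose\n",
    "    # `.jaxpr`; raw Jaxpr objects expose `.eqns` (duck typing avoids relying on\n",
    "    # version-specific import paths for these classes).\n",
    "    if isinstance(value, (tuple, list)):\n",
    "        for item in value:\n",
    "            yield from _nested_jaxprs(item)\n",
    "    elif hasattr(value, 'jaxpr') and hasattr(value.jaxpr, 'eqns'):\n",
    "        yield value.jaxpr\n",
    "    elif hasattr(value, 'eqns'):\n",
    "        yield value\n",
    "\n",
    "\n",
    "def sub_jaxpr_primitives(fn, *args):\n",
    "    # Hypothetical helper: names of top-level primitives that hold a sub-Jaxpr.\n",
    "    closed = jax.make_jaxpr(fn)(*args)\n",
    "    return [\n",
    "        eqn.primitive.name\n",
    "        for eqn in closed.jaxpr.eqns\n",
    "        if any(list(_nested_jaxprs(v)) for v in eqn.params.values())\n",
    "    ]\n",
    "\n",
    "\n",
    "def with_cond(h, x):\n",
    "    return jax.lax.cond(jnp.sum(x) > 0, lambda: jnp.tanh(h + x), lambda: h)\n",
    "\n",
    "\n",
    "def with_scan(h, x):\n",
    "    return jax.lax.scan(lambda c, xi: (c + xi, c), h, jnp.stack([x, x]))[0]\n",
    "\n",
    "\n",
    "def with_while(h, x):\n",
    "    return jax.lax.while_loop(lambda c: c[0] < 3.0, lambda c: c + x, h)\n",
    "\n",
    "\n",
    "def flat(h, x):\n",
    "    return jnp.tanh(h + x)\n",
    "\n",
    "\n",
    "h, x = jnp.ones(4), jnp.ones(4)\n",
    "for fn in (with_cond, with_scan, with_while, flat):\n",
    "    print(fn.__name__, sub_jaxpr_primitives(fn, h, x))\n",
    "```\n",
    "\n",
    "Only `flat` should come back with an empty list. Each of the other three functions should report its control-flow primitive (`cond`, `scan`, `while`)."
   ]
  },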
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60004",
   "metadata": {},
   "source": [
    "## Example of Unsupported Code\n",
    "\n",
    "The following model uses `jax.lax.cond` inside its `update()` method. This will cause a compilation error because the conditional branches create sub-Jaxprs that the compiler cannot traverse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a1b2c3d4e5f60005",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:28:37.459038Z",
     "iopub.status.busy": "2026-04-17T09:28:37.458791Z",
     "iopub.status.idle": "2026-04-17T09:28:39.763395Z",
     "shell.execute_reply": "2026-04-17T09:28:39.762595Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace\n",
    "\n",
    "\n",
    "# THIS WILL NOT WORK: using jax.lax.cond inside update()\n",
    "class BadModel(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(jnp.ones((10, 10)))\n",
    "        self.h = brainstate.HiddenState(jnp.zeros(10))\n",
    "\n",
    "    def update(self, x):\n",
    "        # BAD: jax.lax.cond creates a sub-Jaxpr that the compiler cannot analyze\n",
    "        self.h.value = jax.lax.cond(\n",
    "            jnp.sum(x) > 0,\n",
    "            lambda: jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x),\n",
    "            lambda: self.h.value,\n",
    "        )\n",
    "        return self.h.value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60006",
   "metadata": {},
   "source": [
    "The compiler fails because when it traces `update()`, it sees a `cond` primitive whose true/false branches are opaque sub-Jaxprs. The `braintrace.matmul` call is hidden inside one of those branches, so the compiler cannot discover the relationship between `self.w` and `self.h`."
   ]
  },
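  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40002",
   "metadata": {},
   "source": [
    "The hiding effect can be observed directly with plain JAX. The sketch below rewrites the body of `BadModel.update()` as a pure function. It uses `h @ w` in place of `braintrace.matmul`, an assumption made so the example runs without braintrace. It then compares the primitives visible at the top level of the Jaxpr with those inside the `cond` branches.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def bad_step(w, h, x):\n",
    "    # Same structure as BadModel.update(), with `h @ w` standing in for braintrace.matmul.\n",
    "    return jax.lax.cond(\n",
    "        jnp.sum(x) > 0,\n",
    "        lambda: jnp.tanh(h @ w + x),\n",
    "        lambda: h,\n",
    "    )\n",
    "\n",
    "\n",
    "w, h, x = jnp.ones((10, 10)), jnp.zeros(10), jnp.ones(10)\n",
    "closed = jax.make_jaxpr(bad_step)(w, h, x)\n",
    "\n",
    "top_level = [eqn.primitive.name for eqn in closed.jaxpr.eqns]\n",
    "cond_eqn = next(eqn for eqn in closed.jaxpr.eqns if eqn.primitive.name == 'cond')\n",
    "inside_branches = [\n",
    "    eqn.primitive.name\n",
    "    for branch in cond_eqn.params['branches']  # one ClosedJaxpr per branch\n",
    "    for eqn in branch.jaxpr.eqns\n",
    "]\n",
    "\n",
    "print('top level:', top_level)              # contains 'cond' but no 'dot_general'\n",
    "print('inside branches:', inside_branches)  # the matmul ('dot_general') lives here\n",
    "```"
   ]
  },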
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60007",
   "metadata": {},
   "source": [
    "## Workarounds for Conditional Logic\n",
    "\n",
    "When you need branch-like behaviour without a `cond` primitive, the goal is to choose between values **without** producing a sub-Jaxpr that the compiler will see in the hidden-state path.\n",
    "\n",
    "### Strategy 1: `jax.lax.select`\n",
    "\n",
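    "`jax.lax.select(predicate, on_true, on_false)` is the lowest-level branch-free selection operator. It compiles directly to the `select_n` primitive -- no `jit`, no `cond`, no sub-Jaxpr. Use it whenever the body of `update()` needs to pick between two precomputed values. Unlike `cond`, both candidate values are computed unconditionally before the selection. As a result, the ETP primitive that produces the new hidden state stays in the top-level Jaxpr, where the compiler can see it.\n",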
    "\n",
    "> **Note**: in current JAX versions, `jnp.where` is wrapped in a `jit` of `_where` and the compiler treats that as a forbidden sub-Jaxpr when the result feeds a hidden state. Prefer `jax.lax.select` inside `update()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1b2c3d4e5f60008",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:28:39.765925Z",
     "iopub.status.busy": "2026-04-17T09:28:39.765559Z",
     "iopub.status.idle": "2026-04-17T09:28:39.874383Z",
     "shell.execute_reply": "2026-04-17T09:28:39.873722Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compilation successful.\n"
     ]
    }
   ],
   "source": [
    "# CORRECT: use jax.lax.select (no sub-Jaxpr) instead of jnp.where (which now jits).\n",
    "class GoodModel(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(jnp.ones((10, 10)))\n",
    "        self.h = brainstate.HiddenState(jnp.zeros(10))\n",
    "\n",
    "    def update(self, x):\n",
    "        new_h = jax.nn.tanh(braintrace.matmul(self.h.value, self.w.value) + x)\n",
    "        # jax.lax.select compiles to a single select_n primitive -- the\n",
    "        # compiler can trace right through it.\n",
    "        self.h.value = jax.lax.select(jnp.sum(x) > 0, new_h, self.h.value)\n",
    "        return self.h.value\n",
    "\n",
    "\n",
    "model = GoodModel()\n",
    "brainstate.nn.init_all_states(model)\n",
    "algo = braintrace.D_RTRL(model)\n",
    "algo.compile_graph(jnp.zeros(10))  # works\n",
    "print(\"Compilation successful.\")"
   ]
  },
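  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40003",
   "metadata": {},
   "source": [
    "As a cross-check of Strategy 1, the sketch below compares a `cond`-based step with the `jax.lax.select` rewrite in plain JAX. Here too, `h @ w` stands in for `braintrace.matmul` so that braintrace is not needed. Both versions produce the same values for either sign of the predicate, but only the `select` version exposes the matmul at the top level of the Jaxpr.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def cond_step(w, h, x):\n",
    "    return jax.lax.cond(jnp.sum(x) > 0, lambda: jnp.tanh(h @ w + x), lambda: h)\n",
    "\n",
    "\n",
    "def select_step(w, h, x):\n",
    "    new_h = jnp.tanh(h @ w + x)  # always computed: select has no lazy branches\n",
    "    return jax.lax.select(jnp.sum(x) > 0, new_h, h)\n",
    "\n",
    "\n",
    "w = 0.1 * jnp.ones((10, 10))\n",
    "h = jnp.linspace(-1.0, 1.0, 10)\n",
    "for x in (jnp.ones(10), -jnp.ones(10)):  # positive and negative predicate\n",
    "    assert jnp.allclose(cond_step(w, h, x), select_step(w, h, x))\n",
    "\n",
    "names = [eqn.primitive.name for eqn in jax.make_jaxpr(select_step)(w, h, x).jaxpr.eqns]\n",
    "print(names)  # includes 'dot_general' and 'select_n'; no 'cond'\n",
    "```"
   ]
  },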
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60009",
   "metadata": {},
   "source": [
    "### Strategy 2: Multiplication by a mask\n",
    "\n",
    "For gating-style conditional logic, you can multiply by a binary mask instead of branching. This is particularly natural for spiking neural networks where spike masks are already available.\n",
    "\n",
    "```python\n",
    "# Instead of: jax.lax.cond(spike, lambda: reset_value, lambda: current_value)\n",
    "# Use:        current_value * (1 - spike) + reset_value * spike\n",
    "```"
   ]
  },
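  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40004",
   "metadata": {},
   "source": [
    "The following is a quick numerical check of the mask identity in plain JAX. The threshold `1.0` and reset value `0.0` are illustrative assumptions, not braintrace defaults. Because `spike` holds only 0s and 1s, `v * (1 - spike) + v_reset * spike` returns `v_reset` where a spike occurred and `v` elsewhere. This is the same result an element-wise `jax.lax.select` produces.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "v = jnp.array([0.5, 1.2, 0.9, 1.0])  # membrane potentials (illustrative values)\n",
    "v_th, v_reset = 1.0, 0.0             # assumed threshold and reset value\n",
    "\n",
    "spike = (v >= v_th).astype(v.dtype)  # binary mask: 1.0 where the neuron fires\n",
    "masked = v * (1.0 - spike) + v_reset * spike\n",
    "selected = jax.lax.select(v >= v_th, jnp.full_like(v, v_reset), v)\n",
    "\n",
    "print(masked)    # values: 0.5, 0.0, 0.9, 0.0\n",
    "print(selected)  # same values as `masked`\n",
    "```"
   ]
  },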
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60010",
   "metadata": {},
   "source": [
    "## Shape Compatibility Requirements\n",
    "\n",
    "The braintrace compiler requires that the output of an ETP primitive (e.g., `braintrace.matmul`) be **shape-compatible** with the target hidden state. \"Compatible\" means the shapes must match exactly or be broadcastable to each other.\n",
    "\n",
    "The compiler checks this during relation construction: after identifying an ETP primitive and its associated weight, it traces forward through the Jaxpr to find reachable hidden-state output variables and filters by shape compatibility.\n",
    "\n",
    "If the output of an ETP primitive passes through a shape-changing operation (such as slicing, indexing, or reshaping to an incompatible shape) before reaching the hidden state, the compiler will not be able to establish the connection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1b2c3d4e5f60011",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:28:39.876530Z",
     "iopub.status.busy": "2026-04-17T09:28:39.876381Z",
     "iopub.status.idle": "2026-04-17T09:28:39.880790Z",
     "shell.execute_reply": "2026-04-17T09:28:39.880051Z"
    }
   },
   "outputs": [],
   "source": [
    "# Shape mismatch example -- the weight won't be connected to the hidden state\n",
    "class ShapeMismatch(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = brainstate.ParamState(jnp.ones((10, 20)))  # outputs dim 20\n",
    "        self.h = brainstate.HiddenState(jnp.zeros(10))       # hidden dim 10\n",
    "\n",
    "    def update(self, x):\n",
    "        # The output shape (20,) doesn't match hidden shape (10,)\n",
    "        # This weight won't be connected to the hidden state\n",
    "        y = braintrace.matmul(x, self.w.value)\n",
    "        self.h.value = y[:10]  # slicing breaks the connection\n",
    "        return self.h.value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60012",
   "metadata": {},
   "source": [
    "In this example, `braintrace.matmul(x, self.w.value)` produces a vector of dimension 20, but the hidden state `self.h` has dimension 10. The slicing operation `y[:10]` is not a simple broadcast -- it fundamentally changes the shape, breaking the connection between the weight and the hidden state in the compiled graph.\n",
    "\n",
    "**Fix:** Ensure that the weight matrix dimensions produce outputs that match the hidden state dimensions directly:\n",
    "\n",
    "```python\n",
    "self.w = brainstate.ParamState(jnp.ones((10, 10)))  # outputs dim 10 to match hidden dim 10\n",
    "```"
   ]
  },
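  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40005",
   "metadata": {},
   "source": [
    "The compatibility rule stated above (equal shapes, or shapes that broadcast together) can be expressed with NumPy's broadcasting check, whose rules JAX follows. The helper below is a hypothetical illustration of that rule only. braintrace's internal check may differ in details, for example in how batch dimensions are handled.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def shapes_compatible(etp_out_shape, hidden_shape):\n",
    "    # Hypothetical sketch: True if the shapes match or broadcast together.\n",
    "    try:\n",
    "        np.broadcast_shapes(etp_out_shape, hidden_shape)\n",
    "        return True\n",
    "    except ValueError:\n",
    "        return False\n",
    "\n",
    "\n",
    "print(shapes_compatible((10,), (10,)))  # fixed model: weight (10, 10) gives output (10,)\n",
    "print(shapes_compatible((20,), (10,)))  # ShapeMismatch: output (20,) vs hidden (10,)\n",
    "print(shapes_compatible((1,), (10,)))   # broadcastable, so it passes the rule\n",
    "```"
   ]
  },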
  {
   "cell_type": "markdown",
   "id": "2cb80a69",
   "metadata": {},
   "source": [
    "## The \"Weight -> Weight -> Hidden\" Invariant\n",
    "\n",
    "ETP rules are *local* to a single primitive: each primitive's `xy_to_dw` rule assumes its `x` is externally-supplied data, and its `yw_to_w` rule assumes the path from this primitive's output `y` to a hidden state `h` contains **no other trainable ETP weights**. If a primitive `W1`'s output flows through another non-gradient-enabled ETP primitive `W2` before reaching `h`, the per-primitive rules cannot soundly account for `W1` -- that would either double-count the contribution from `W2` or treat `W2`'s `x` as raw data when it is actually a function of `W1`.\n",
    "\n",
    "The compiler enforces this by **excluding** `W1` from the relation list whenever its only path to `h` passes through another non-gradient-enabled ETP primitive. The diagnostic kind is `RELATION_EXCLUDED_WEIGHT_TO_WEIGHT` and a `UserWarning` is emitted at compile time. The excluded weight is still trainable -- but only via BPTT, not via online learning.\n",
    "\n",
    "The classic example is `braintrace.nn.GRUCell`. It has three internal `Linear` layers (`Wz`, `Wr`, `Wh`), but the compiler records only **two** ETP relations:\n",
    "\n",
    "- `Wz` -- output flows directly into the new hidden state. **Included.**\n",
    "- `Wh` -- output flows directly into the new hidden state. **Included.**\n",
    "- `Wr` -- output is consumed by `Wh`'s matmul (it gates `r * old_h`). **Excluded** with a `RELATION_EXCLUDED_WEIGHT_TO_WEIGHT` warning.\n",
    "\n",
    "This is correct: `Wr`'s contribution to `dL/dh` is already implicit in `Wh`'s gradient (because `Wh`'s input depends on `Wr`), so adding `Wr` separately would double-count. To learn `Wr` online with this architecture, you would need to bundle `Wr` and `Wh` together -- something per-primitive ETP cannot express.\n",
    "\n",
    "### When to use `gradient_enabled=True`\n",
    "\n",
    "The single exception is `etp_elemwise_p` -- the only built-in primitive registered with `gradient_enabled=True`. Element-wise ops (gating biases, learnable thresholds, learnable time constants) are identity-like enough that they may sit on the tail of the `y -> h` walk without breaking the per-primitive assumption: an upstream ETP weight whose output passes through an element-wise op is still recorded as a relation.\n",
    "\n",
    "When registering a custom primitive, leave `gradient_enabled` at its default `False`. Set it to `True` only if your primitive is genuinely identity-like (single weight, no input multiplication, the `xy_to_dw` rule is essentially a passthrough). Setting `gradient_enabled=True` on a *trainable* op -- one with both an `x` and a `w` -- silently re-enables the unsound double-counting and will produce wrong gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "15710888",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:28:39.883674Z",
     "iopub.status.busy": "2026-04-17T09:28:39.883294Z",
     "iopub.status.idle": "2026-04-17T09:28:40.791323Z",
     "shell.execute_reply": "2026-04-17T09:28:40.790378Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GRUCell has 3 ParamStates\n",
      "but the compiler recorded only 2 ETP relations:\n",
      "  - ('Wz', 'weight')  (primitive: etp_mm)\n",
      "  - ('Wh', 'weight')  (primitive: etp_mm)\n",
      "\n",
      "Weight->weight exclusions: 1\n",
      "  - ('Wr', 'weight'): ETP primitive etp_mm (weight=('Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mm). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/d/codes/projects/braintrace/braintrace/_etrace_compiler/hid_param_op.py:772: UserWarning: ETP primitive etp_mm (weight=('Wr', 'weight')) reaches a hidden state only through another trainable ETP primitive (etp_mm). Per the non-parametric-tail invariant this weight is excluded from ETP; learn it by BPTT or rewire the architecture so its output flows directly into a hidden state.\n",
      "  _emit_no_relation_diag(\n"
     ]
    }
   ],
   "source": [
    "# Concrete demonstration: GRUCell yields 2 ETP relations, not 3.\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintrace\n",
    "\n",
    "cell = braintrace.nn.GRUCell(in_size=4, out_size=8)\n",
    "brainstate.nn.init_all_states(cell, batch_size=2)\n",
    "\n",
    "graph = braintrace.compile_etrace_graph(cell, jnp.zeros((2, 4)))\n",
    "\n",
    "print(f\"GRUCell has {len(list(cell.states(brainstate.ParamState)))} ParamStates\")\n",
    "print(f\"but the compiler recorded only {len(graph.hidden_param_op_relations)} ETP relations:\")\n",
    "for r in graph.hidden_param_op_relations:\n",
    "    print(f\"  - {r.weight_path}  (primitive: {r.primitive.name})\")\n",
    "\n",
    "# The third weight (Wr) shows up as a RELATION_EXCLUDED_WEIGHT_TO_WEIGHT diagnostic.\n",
    "from braintrace import DiagnosticKind\n",
    "excluded = [\n",
    "    d for d in graph.diagnostics\n",
    "    if d.kind == DiagnosticKind.RELATION_EXCLUDED_WEIGHT_TO_WEIGHT\n",
    "]\n",
    "print(f\"\\nWeight->weight exclusions: {len(excluded)}\")\n",
    "for d in excluded:\n",
    "    print(f\"  - {d.weight_path}: {d.message}\")"
   ]
  },
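  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40006",
   "metadata": {},
   "source": [
    "You can check with ordinary autodiff why `Wr` is reachable only through `Wh`. The sketch below is a simplified GRU step written in plain JAX. It is an illustration under assumptions, not the source of `braintrace.nn.GRUCell`, whose gate layout may differ. Zeroing the rows of `Wh` that multiply `r * h` removes every path from `Wr` to the new hidden state, and the gradient with respect to `Wr` then becomes exactly zero.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_in, n_h = 3, 4\n",
    "\n",
    "\n",
    "def gru_step(params, h, x):\n",
    "    # Simplified GRU update (illustrative sketch, not braintrace's implementation).\n",
    "    xh = jnp.concatenate([x, h])\n",
    "    z = jax.nn.sigmoid(xh @ params['Wz'])  # update gate: reaches h_new directly\n",
    "    r = jax.nn.sigmoid(xh @ params['Wr'])  # reset gate: only feeds Wh's input\n",
    "    h_tilde = jnp.tanh(jnp.concatenate([x, r * h]) @ params['Wh'])\n",
    "    return (1.0 - z) * h + z * h_tilde\n",
    "\n",
    "\n",
    "keys = jax.random.split(jax.random.PRNGKey(0), 5)\n",
    "params = {name: jax.random.normal(k, (n_in + n_h, n_h))\n",
    "          for name, k in zip(['Wz', 'Wr', 'Wh'], keys[:3])}\n",
    "h = jax.random.normal(keys[3], (n_h,))\n",
    "x = jax.random.normal(keys[4], (n_in,))\n",
    "\n",
    "\n",
    "def grad_wrt_wr(p):\n",
    "    def loss(q):\n",
    "        return jnp.sum(gru_step(q, h, x))\n",
    "    return jax.grad(loss)(p)['Wr']\n",
    "\n",
    "\n",
    "g_full = grad_wrt_wr(params)\n",
    "# Cut the r * h -> Wh path by zeroing the recurrent rows of Wh.\n",
    "cut = dict(params, Wh=params['Wh'].at[n_in:].set(0.0))\n",
    "g_cut = grad_wrt_wr(cut)\n",
    "\n",
    "print(jnp.abs(g_full).max() > 0)  # Wr matters while Wh sees r * h\n",
    "print(jnp.all(g_cut == 0))        # no path from Wr to h_new once Wh ignores r * h\n",
    "```"
   ]
  },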
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60013",
   "metadata": {},
   "source": [
    "## Performance Considerations\n",
    "\n",
    "Different online learning algorithms in braintrace have different memory and computational requirements. Choosing the right algorithm is important for scaling to larger models.\n",
    "\n",
    "### Memory complexity comparison\n",
    "\n",
    "| Algorithm | Memory per weight | Total memory | Description |\n",
    "|---|---|---|---|\n",
    "| `D_RTRL` | O(B \\* weight\\_size \\* hidden\\_size) | O(B \\* \\|theta\\| \\* H) | Full eligibility traces |\n",
    "| `ES_D_RTRL` | O(B \\* (in\\_size + out\\_size) \\* hidden\\_size) | O(B \\* (I+O) \\* H) | Factored eligibility traces |\n",
    "| BPTT | n/a (no eligibility traces) | O(B \\* T \\* A) | Stores all activations over time |\n",
    "\n",
    "Where B = batch size, H = hidden state dimension, T = sequence length, A = activation size per time step, I = input size, O = output size, \\|theta\\| = number of trainable parameters.\n",
    "\n",
    "### Key tradeoffs\n",
    "\n",
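    "- **D_RTRL** keeps a full parameter-dimension eligibility trace and gives more accurate online gradients than ES_D_RTRL, but it can be memory-intensive for large weight matrices. The eligibility trace for each weight matrix has shape `(weight_size, hidden_size)`, so its size grows with the product of the weight size and the hidden-state size.\n",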
    "\n",
    "- **ES_D_RTRL** (factored / IO-dimension algorithm) trades gradient accuracy for memory efficiency. Instead of storing the full eligibility trace, it factors the trace into input-dimension and output-dimension components, reducing memory from O(weight\\_size \\* hidden\\_size) to O((in\\_size + out\\_size) \\* hidden\\_size).\n",
    "\n",
    "- **BPTT** (Backpropagation Through Time) stores all intermediate activations over the unrolled time steps. Memory grows linearly with sequence length T, which can be prohibitive for long sequences.\n",
    "\n",
    "### Recommendations for large models\n",
    "\n",
    "- Use **ES_D_RTRL** instead of D_RTRL when weight matrices are large\n",
    "- Reduce hidden state dimensions where possible\n",
    "- Use sparse operations (`braintrace.sparse_matmul`) to reduce the number of parameters\n",
    "- Consider using `braintrace.lora_matmul` for low-rank weight updates"
   ]
  },
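  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40007",
   "metadata": {},
   "source": [
    "To get a feel for the gap between the first two rows of the table, the snippet below plugs one illustrative layer size into their big-O expressions. The numbers are order-of-magnitude estimates that assume one float32 per trace entry and ignore constant factors. They are not measurements of braintrace.\n",
    "\n",
    "```python\n",
    "def d_rtrl_entries(batch, in_size, out_size, hidden):\n",
    "    # O(B * weight_size * hidden_size), with weight_size = in_size * out_size\n",
    "    return batch * (in_size * out_size) * hidden\n",
    "\n",
    "\n",
    "def es_d_rtrl_entries(batch, in_size, out_size, hidden):\n",
    "    # O(B * (in_size + out_size) * hidden_size)\n",
    "    return batch * (in_size + out_size) * hidden\n",
    "\n",
    "\n",
    "B, I, O, H = 32, 512, 512, 512  # illustrative sizes\n",
    "full = d_rtrl_entries(B, I, O, H)\n",
    "factored = es_d_rtrl_entries(B, I, O, H)\n",
    "\n",
    "mib = 2 ** 20\n",
    "print(f'D_RTRL:    {full:,} entries, ~{full * 4 / mib:,.0f} MiB as float32')\n",
    "print(f'ES_D_RTRL: {factored:,} entries, ~{factored * 4 / mib:,.0f} MiB as float32')\n",
    "print(f'ratio:     {full // factored}x')\n",
    "```\n",
    "\n",
    "For square layers the ratio is `I * O / (I + O)`, so the advantage of the factored trace widens as layers grow."
   ]
  },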
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60014",
   "metadata": {},
   "source": [
    "## Compilation Time\n",
    "\n",
    "The braintrace compiler performs several steps when `compile_graph()` is called:\n",
    "\n",
    "1. **Jaxpr tracing**: JAX traces the model's `update()` method to produce a Jaxpr\n",
    "2. **Relation discovery**: The compiler walks the Jaxpr to find ETP primitives, trace weight origins, and connect them to hidden states\n",
    "3. **Graph construction**: The eligibility trace computation graph is built from the discovered relations\n",
    "\n",
    "This compilation can be slow for complex models, especially on the first call. However:\n",
    "\n",
    "- **Subsequent calls with the same input shapes reuse the compiled graph.** The compilation result is cached, so you only pay the cost once.\n",
    "- **`compile_graph()` should be called once before the training loop**, not inside it. Calling it repeatedly with the same shapes is harmless (it detects the cache hit), but calling it inside a loop adds unnecessary overhead.\n",
    "\n",
    "```python\n",
    "# Good: compile once, then run many steps.\n",
    "# (example_input, input_data and num_steps stand in for your own data.)\n",
    "algo = braintrace.D_RTRL(model)\n",
    "algo.compile_graph(example_input)\n",
    "\n",
    "for step in range(num_steps):\n",
    "    output = algo(input_data[step])  # uses cached compilation\n",
    "```"
   ]
  },
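  {
   "cell_type": "markdown",
   "id": "e7a1c0d2b3f40008",
   "metadata": {},
   "source": [
    "The same compile-once advice applies to plain `jax.jit`, which also caches compiled code by input shape and dtype. The sketch below shows that JAX analogue, not braintrace's own cache. A Python-side counter increments only while the function is being traced, so it reveals when a new compilation happens.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "trace_count = 0\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def step(h, x):\n",
    "    global trace_count\n",
    "    trace_count += 1  # Python side effect: runs only during tracing\n",
    "    return jnp.tanh(h + x)\n",
    "\n",
    "\n",
    "h = jnp.zeros(4)\n",
    "for t in range(5):\n",
    "    h = step(h, t * jnp.ones(4))\n",
    "print(trace_count)  # 1: all five calls share shapes/dtypes, so one trace\n",
    "\n",
    "step(jnp.zeros(8), jnp.zeros(8))\n",
    "print(trace_count)  # 2: a new input shape triggers a retrace\n",
    "```"
   ]
  },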
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60015",
   "metadata": {},
   "source": [
    "## What CAN Be Used Inside `update()`\n",
    "\n",
    "The braintrace compiler works with all standard JAX mathematical operations that do not create sub-Jaxprs. These include:\n",
    "\n",
    "**Standard math operations:**\n",
    "- `jnp.add`, `jnp.subtract`, `jnp.multiply`, `jnp.divide`\n",
    "- Element-wise operators: `+`, `-`, `*`, `/`\n",
    "\n",
    "**Matrix operations:**\n",
    "- `@` (matrix multiply operator)\n",
    "- `jnp.dot`, `jnp.matmul`, `jnp.einsum`\n",
    "\n",
    "**Activation functions:**\n",
    "- `jax.nn.tanh`, `jax.nn.relu`, `jax.nn.sigmoid`, `jax.nn.softmax`\n",
    "- `jax.nn.silu`, `jax.nn.gelu`, `jax.nn.leaky_relu`\n",
    "\n",
    "**Shape manipulation:**\n",
    "- `jnp.reshape`, `jnp.transpose`, `jnp.concatenate`\n",
    "- `jnp.expand_dims`, `jnp.squeeze`\n",
    "\n",
    "**Selection and masking:**\n",
    "- `jax.lax.select(predicate, on_true, on_false)` (preferred over `jnp.where` inside `update()`; see *Workarounds* above)\n",
    "- `jnp.clip`, `jnp.maximum`, `jnp.minimum`\n",
    "\n",
    "**Gradient control:**\n",
    "- `jax.lax.stop_gradient` -- useful for detaching parts of the computation\n",
    "\n",
    "**braintrace ETP primitives:**\n",
    "- `braintrace.matmul` -- matrix multiplication with ETP tracking\n",
    "- `braintrace.element_wise` -- element-wise parameter operations with ETP tracking\n",
    "- `braintrace.conv` -- convolution with ETP tracking\n",
    "- `braintrace.sparse_matmul` -- sparse matrix multiplication with ETP tracking\n",
    "- `braintrace.lora_matmul` -- LoRA-style low-rank multiplication with ETP tracking\n",
    "\n",
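    "In general, if a JAX operation compiles to a flat sequence of primitives in the Jaxpr (no nested sub-Jaxprs), it is compatible with braintrace. Some library functions are themselves wrapped in `jit` or `custom_jvp` (for example, `jax.nn.relu` is defined with `jax.custom_jvp`). If compilation fails on a model built only from the operations above, a useful first diagnostic is to inspect the output of `jax.make_jaxpr` for nested Jaxprs."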
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60016",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "The key limitations and their workarounds are:\n",
    "\n",
    "1. **Avoid `cond`, `scan`, `while_loop`, and nested `vmap` inside the model's `update()` method.** These create sub-Jaxprs that the compiler cannot traverse. Use them freely outside the model (e.g., for time-step unrolling).\n",
    "\n",
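    "2. **Use `jax.lax.select` and masks as alternatives to conditional logic.** Branch-free selection compiles to a single `select_n` primitive, and mask arithmetic uses only element-wise ops, so neither hides the hidden-state path in a sub-Jaxpr. Avoid `jnp.where` inside `update()`, since current JAX versions wrap it in a `jit`.\n",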
    "\n",
    "3. **Ensure shape compatibility between ETP primitive outputs and hidden states.** The compiler filters connections by shape -- if shapes don't match or broadcast, the connection won't be established.\n",
    "\n",
    "4. **Per-primitive ETP rules are local.** A weight whose only path to a hidden state passes through another trainable ETP primitive is excluded with a `RELATION_EXCLUDED_WEIGHT_TO_WEIGHT` warning -- it must be learned via BPTT or the architecture must be rewired. `etp_elemwise_p` (the only `gradient_enabled=True` built-in) is the sole exception.\n",
    "\n",
    "5. **Choose the right algorithm based on memory/accuracy tradeoffs.** Use `D_RTRL` for more accurate online gradients with moderate model sizes, and `ES_D_RTRL` for memory-efficient, more approximate gradients with larger models.\n",
    "\n",
    "6. **Call `compile_graph()` once before training**, not inside the training loop. The compiled graph is cached and reused for inputs of the same shape.\n",
    "\n",
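    "7. **The compiler works with standard JAX operations that lower to a flat sequence of primitives.** Avoiding the unsupported control flow primitives listed above is necessary for compilation. The shape-compatibility and weight-to-weight rules still determine which weights receive online updates."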
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
