{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loops and Conditionals\n",
    "\n",
    "This tutorial covers state-aware control flow primitives in `brainstate.transform`. These APIs provide JAX-compatible loops and conditionals while safely handling `State` objects.\n",
    "\n",
    "We'll explore three categories of control flow:\n",
    "\n",
    "1. **Loop Transformations**: `scan`, `checkpointed_scan`, `for_loop`, `checkpointed_for_loop`\n",
    "2. **While Loops**: `while_loop`, `bounded_while_loop`\n",
    "3. **Conditional Control Flow**: `cond`, `switch`, `ifelse`\n",
    "\n",
    "Each API tracks the `State` objects that a loop body or branch reads and writes. This lets stateful modules run inside loops and conditionals while the compiled computation remains a pure, functional JAX program."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.458509Z",
     "start_time": "2025-10-11T07:36:27.453344Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate.transform import (\n",
    "    scan,\n",
    "    checkpointed_scan,\n",
    "    for_loop,\n",
    "    checkpointed_for_loop,\n",
    "    while_loop,\n",
    "    bounded_while_loop,\n",
    "    cond,\n",
    "    switch,\n",
    "    ifelse,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.469729Z",
     "start_time": "2025-10-11T07:36:27.466242Z"
    }
   },
   "outputs": [],
   "source": [
    "# Import ProgressBar\n",
    "from brainstate.transform import ProgressBar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Loop Transformations\n",
    "\n",
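    "Loop transformations iterate over sequences while tracking any `State` objects that the loop body reads or writes. The body is traced once and compiled into a single JAX loop primitive instead of being unrolled in Python, so compilation time stays roughly constant as the iteration count grows (unless you request unrolling)."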
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 `scan`: Stateful Scanning with Carry\n",
    "\n",
    "`scan` is the fundamental loop primitive that:\n",
    "- Iterates over a sequence along the leading axis\n",
    "- Maintains a \"carry\" value that threads through iterations\n",
    "- Collects outputs at each step\n",
    "- Tracks `State` objects: values written to a `State` inside `f` persist after the loop returns\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "scan(\n",
    "    f: Callable[[Carry, X], Tuple[Carry, Y]],\n",
    "    init: Carry,\n",
    "    xs: X,\n",
    "    length: int | None = None,\n",
    "    reverse: bool = False,\n",
    "    unroll: int | bool = 1,\n",
    "    pbar: ProgressBar | int | None = None,\n",
    ") -> Tuple[Carry, Y]\n",
    "```\n",
    "\n",
    "**Parameters:**\n",
    "- `f`: Function of type `(carry, x) -> (new_carry, output)`\n",
    "- `init`: Initial carry value\n",
    "- `xs`: Sequence to iterate over (along axis 0)\n",
    "- `length`: Optional iteration count; inferred from the leading axis of `xs` if not provided (required when `xs` is `None`)\n",
    "- `reverse`: If True, iterate in reverse order\n",
    "- `unroll`: Number of iterations to unroll (1=no unrolling, True=full unrolling)\n",
    "- `pbar`: Optional progress bar"
   ]
  },
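  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make these semantics concrete, here is a pure-Python reference sketch of what `scan` computes, including `reverse=True`. The names `scan_reference` and `cumsum_step` are hypothetical and used only for illustration. The real `scan` traces `f` once and compiles it into a single loop primitive instead of running a Python loop. It also handles `State` objects, which this sketch ignores.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def scan_reference(f, init, xs, reverse=False):\n",
    "    # Hypothetical reference: thread the carry through a plain Python loop.\n",
    "    carry = init\n",
    "    ys = [None] * len(xs)\n",
    "    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))\n",
    "    for i in order:\n",
    "        carry, ys[i] = f(carry, xs[i])\n",
    "    # Outputs are stacked in the original order of xs, even when reverse=True.\n",
    "    return carry, np.stack(ys)\n",
    "\n",
    "\n",
    "def cumsum_step(carry, x):\n",
    "    new_carry = carry + x\n",
    "    return new_carry, new_carry\n",
    "\n",
    "\n",
    "xs = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "final, sums = scan_reference(cumsum_step, 0.0, xs)\n",
    "final_rev, sums_rev = scan_reference(cumsum_step, 0.0, xs, reverse=True)\n",
    "print(final, sums)\n",
    "print(final_rev, sums_rev)\n",
    "```\n",
    "\n",
    "With `reverse=True` the loop starts at the last element, but each output is still stored at the index of the input that produced it. As a result, `sums_rev[i]` equals the sum of `xs[i:]`."
   ]
  },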
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.508259Z",
     "start_time": "2025-10-11T07:36:27.474255Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input sequence: [1. 2. 3. 4. 5.]\n",
      "Final sum: 15.0\n",
      "Cumulative sums: [ 1.  3.  6. 10. 15.]\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Basic scan with carry\n",
    "def cumsum_body(carry, x):\n",
    "    \"\"\"Accumulate sum and return both new carry and current sum.\"\"\"\n",
    "    new_carry = carry + x\n",
    "    return new_carry, new_carry\n",
    "\n",
    "\n",
    "xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "final_sum, cumulative_sums = scan(cumsum_body, init=0.0, xs=xs)\n",
    "\n",
    "print(\"Input sequence:\", xs)\n",
    "print(\"Final sum:\", final_sum)\n",
    "print(\"Cumulative sums:\", cumulative_sums)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.558109Z",
     "start_time": "2025-10-11T07:36:27.516082Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data: [2. 4. 4. 4. 5. 5. 7. 9.]\n",
      "\n",
      "Running mean: [2.        3.        3.3333333 3.5       3.8       4.        4.428571\n",
      " 5.       ]\n",
      "Running variance: [0.         1.         0.8888889  0.75       0.96000004 1.\n",
      " 1.9591838  4.        ]\n",
      "\n",
      "Final statistics:\n",
      "  Count: 8\n",
      "  Mean: 5.0\n",
      "  Variance: 4.0\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Scan with stateful computation\n",
    "class RunningStats(brainstate.nn.Module):\n",
    "    \"\"\"Maintain running mean and variance.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.count = brainstate.ShortTermState(jnp.array(0))\n",
    "        self.mean = brainstate.ShortTermState(jnp.array(0.0))\n",
    "        self.m2 = brainstate.ShortTermState(jnp.array(0.0))  # sum of squared differences\n",
    "\n",
    "    def update(self, x):\n",
    "        \"\"\"Update statistics with new value using Welford's algorithm.\"\"\"\n",
    "        self.count.value = self.count.value + 1\n",
    "        delta = x - self.mean.value\n",
    "        self.mean.value = self.mean.value + delta / self.count.value\n",
    "        delta2 = x - self.mean.value\n",
    "        self.m2.value = self.m2.value + delta * delta2\n",
    "\n",
    "        variance = self.m2.value / self.count.value\n",
    "        return {'mean': self.mean.value, 'var': variance}\n",
    "\n",
    "\n",
    "stats = RunningStats()\n",
    "\n",
    "\n",
    "def stats_body(carry, x):\n",
    "    result = stats.update(x)\n",
    "    return carry, result\n",
    "\n",
    "\n",
    "data = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])\n",
    "_, history = scan(stats_body, init=None, xs=data)\n",
    "\n",
    "print(\"Data:\", data)\n",
    "print(\"\\nRunning mean:\", history['mean'])\n",
    "print(\"Running variance:\", history['var'])\n",
    "print(\"\\nFinal statistics:\")\n",
    "print(f\"  Count: {stats.count.value}\")\n",
    "print(f\"  Mean: {stats.mean.value}\")\n",
    "print(f\"  Variance: {stats.m2.value / stats.count.value}\")"
   ]
  },
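  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, here is the same Welford update written with plain `jax.lax.scan`. This sketch threads the statistics through the carry by hand. The `RunningStats` module above avoids this bookkeeping by keeping `count`, `mean`, and `m2` in `ShortTermState` objects that `scan` tracks for you. The function name `welford_step` is hypothetical.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def welford_step(carry, x):\n",
    "    # The carry holds (count, mean, m2); m2 is the sum of squared deviations.\n",
    "    count, mean, m2 = carry\n",
    "    count = count + 1\n",
    "    delta = x - mean\n",
    "    mean = mean + delta / count\n",
    "    m2 = m2 + delta * (x - mean)\n",
    "    # Emit the running mean and population variance at every step.\n",
    "    return (count, mean, m2), (mean, m2 / count)\n",
    "\n",
    "\n",
    "data = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])\n",
    "init = (jnp.array(0), jnp.array(0.0), jnp.array(0.0))\n",
    "(count, mean, m2), (running_mean, running_var) = jax.lax.scan(welford_step, init, data)\n",
    "print(count, mean, m2 / count)\n",
    "```\n",
    "\n",
    "Every piece of state must appear in both the input and the output carry with matching shapes and dtypes. In the stateful version above, the `State` objects play that role implicitly."
   ]
  },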
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.627141Z",
     "start_time": "2025-10-11T07:36:27.564309Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: [1. 2. 3. 4. 5.]\n",
      "Forward cumsum: [ 1.  3.  6. 10. 15.]\n",
      "Backward cumsum: [15. 14. 12.  9.  5.]\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Reverse scan\n",
    "def cumsum_step(carry, x):\n",
    "    new_carry = carry + x\n",
    "    return new_carry, new_carry\n",
    "\n",
    "\n",
    "xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
    "_, forward_sums = scan(cumsum_step, 0.0, xs, reverse=False)\n",
    "# With reverse=True, iteration starts at xs[-1], but outputs stay aligned with xs:\n",
    "# backward_sums[i] == sum(xs[i:])\n",
    "_, backward_sums = scan(cumsum_step, 0.0, xs, reverse=True)\n",
    "\n",
    "print(\"Input:\", xs)\n",
    "print(\"Forward cumsum:\", forward_sums)\n",
    "print(\"Backward cumsum:\", backward_sums)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Progress Bar with `scan`\n",
    "\n",
    "The `pbar` parameter enables progress tracking during long-running scans. You can:\n",
    "- Pass an integer `N` for a quick setup that refreshes the bar every `N` iterations\n",
    "- Pass a `ProgressBar` instance to control the update frequency (`freq`) and the description (`desc`)\n",
    "- Set `desc` to a static string, or to a `(format_string, fn)` tuple whose `fn` returns the fields used to fill the format string at each update"
   ]
  },
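  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the loop body is compiled, a progress bar cannot simply run Python code on every step. One common JAX pattern is to send a host callback with `jax.debug.callback` only on every `freq`-th iteration. This is shown as an assumption about how such a feature can work, not as a description of `ProgressBar` internals. The names `report` and `make_body` are hypothetical.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "fired = []  # iteration indices at which the host callback ran\n",
    "\n",
    "\n",
    "def report(i):\n",
    "    # Runs on the host; a real progress bar would refresh its display here.\n",
    "    fired.append(int(i))\n",
    "\n",
    "\n",
    "def make_body(freq):\n",
    "    def body(carry, x):\n",
    "        i, total = carry\n",
    "        # Take the callback branch only on every `freq`-th iteration.\n",
    "        jax.lax.cond(i % freq == 0,\n",
    "                     lambda: jax.debug.callback(report, i),\n",
    "                     lambda: None)\n",
    "        return (i + 1, total + x), total + x\n",
    "    return body\n",
    "\n",
    "\n",
    "xs = jnp.arange(10.0)\n",
    "init = (jnp.array(0), jnp.array(0.0))\n",
    "(steps, total), running = jax.lax.scan(make_body(4), init, xs)\n",
    "jax.effects_barrier()  # wait until all pending host callbacks have run\n",
    "print(sorted(fired), int(steps), float(total))\n",
    "```\n",
    "\n",
    "Triggering the callback only every `freq` steps limits device-to-host round trips, which is the usual reason to keep progress updates coarse."
   ]
  },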
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.702814Z",
     "start_time": "2025-10-11T07:36:27.635100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Simple progress bar (update every 20 iterations) ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0852be231c60421ebbf4b09c979f3444",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Final result: -4.0076361074170563e-07\n"
     ]
    }
   ],
   "source": [
    "# Example 4: Progress bar with scan - simple integer freq\n",
    "print(\"\\n=== Simple progress bar (update every 20 iterations) ===\")\n",
    "\n",
    "\n",
    "def expensive_computation(carry, x):\n",
    "    \"\"\"Simulate expensive computation.\"\"\"\n",
    "    # Some computation\n",
    "    result = carry + jnp.sin(x) * jnp.cos(x)\n",
    "    return result, result\n",
    "\n",
    "\n",
    "# Create long sequence\n",
    "long_sequence = jnp.linspace(0, 10 * jnp.pi, 100)\n",
    "\n",
    "# Use integer for simple progress bar (updates every 20 iterations)\n",
    "final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=20)\n",
    "print(f\"\\nFinal result: {final}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.773098Z",
     "start_time": "2025-10-11T07:36:27.710435Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Custom progress bar with ProgressBar ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "974cc1dfce874846b325d26c67507daa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Completed! Final result: -4.0076361074170563e-07\n"
     ]
    }
   ],
   "source": [
    "# Example 5: Progress bar with custom ProgressBar instance\n",
    "print(\"\\n=== Custom progress bar with ProgressBar ===\")\n",
    "\n",
    "# Create ProgressBar with custom settings\n",
    "pbar = ProgressBar(freq=10, desc=\"Processing sequence\")\n",
    "\n",
    "final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=pbar)\n",
    "print(f\"\\nCompleted! Final result: {final}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.865101Z",
     "start_time": "2025-10-11T07:36:27.780505Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Dynamic progress bar with loop state ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebe0f81f749c4f9ab52684b73857659f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Optimization completed!\n",
      "Final parameters: -0.04794257506728172\n",
      "Final loss: 1.41942298412323\n",
      "Best loss achieved: 3.858334093820304e-05\n"
     ]
    }
   ],
   "source": [
    "# Example 6: Dynamic progress bar description based on loop state\n",
    "print(\"\\n=== Dynamic progress bar with loop state ===\")\n",
    "\n",
    "\n",
    "class OptimizationTracker(brainstate.nn.Module):\n",
    "    \"\"\"Track optimization progress.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.best_loss = brainstate.ShortTermState(jnp.array(float('inf')))\n",
    "\n",
    "    def step(self, params, x):\n",
    "        # Compute loss\n",
    "        loss = jnp.sum((params - x) ** 2)\n",
    "        # Update best\n",
    "        self.best_loss.value = jnp.minimum(self.best_loss.value, loss)\n",
    "        # Gradient step with learning rate 0.1, since d/dp (p - x)^2 = 2 * (p - x)\n",
    "        new_params = params - 0.1 * 2 * (params - x)\n",
    "        return new_params, loss\n",
    "\n",
    "\n",
    "tracker = OptimizationTracker()\n",
    "\n",
    "\n",
    "def scan_body_with_tracking(params, x):\n",
    "    return tracker.step(params, x)\n",
    "\n",
    "\n",
    "# Define dynamic description\n",
    "def format_progress(data):\n",
    "    \"\"\"Format progress with current loss and best loss.\"\"\"\n",
    "    return {\n",
    "        \"iter\": data[\"i\"],\n",
    "        \"loss\": data[\"y\"],\n",
    "        \"best\": tracker.best_loss.value\n",
    "    }\n",
    "\n",
    "\n",
    "pbar_dynamic = ProgressBar(\n",
    "    freq=15,\n",
    "    desc=(\"Iter {iter:3d} | Loss: {loss:.4f} | Best: {best:.4f}\", format_progress)\n",
    ")\n",
    "\n",
    "targets = jax.random.normal(jax.random.PRNGKey(42), (100,))\n",
    "init_params = jnp.array(0.0)\n",
    "\n",
    "final_params, loss_history = scan(\n",
    "    scan_body_with_tracking,\n",
    "    init=init_params,\n",
    "    xs=targets,\n",
    "    pbar=pbar_dynamic\n",
    ")\n",
    "\n",
    "print(f\"\\nOptimization completed!\")\n",
    "print(f\"Final parameters: {final_params}\")\n",
    "print(f\"Final loss: {loss_history[-1]}\")\n",
    "print(f\"Best loss achieved: {tracker.best_loss.value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 `checkpointed_scan`: Memory-Efficient Scanning\n",
    "\n",
    "`checkpointed_scan` is a memory-optimized version of `scan` that uses gradient checkpointing (rematerialization). It is useful when:\n",
    "- Sequences are long enough that storing every intermediate activation for backpropagation would exhaust memory\n",
    "- You are willing to trade extra computation during the backward pass for lower memory use\n",
    "\n",
    "The forward pass stores only a subset of the intermediate carries as checkpoints. The backward pass then recomputes the values between checkpoints when it needs them. Without differentiation, it computes the same results as `scan`.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "checkpointed_scan(\n",
    "    f: Callable[[Carry, X], Tuple[Carry, Y]],\n",
    "    init: Carry,\n",
    "    xs: X,\n",
    "    length: Optional[int] = None,\n",
    "    base: int = 16,\n",
    "    pbar: Optional[ProgressBar | int] = None,\n",
    ") -> Tuple[Carry, Y]\n",
    "```\n",
    "\n",
    "**Key parameter:**\n",
    "- `base`: Checkpointing base (default 16). The implementation uses a hierarchical checkpointing scheme: the loop is nested `k` levels deep so that `max_steps = base^k`. Smaller values generally reduce memory use but increase recomputation during the backward pass.\n",
    "\n",
    "**Memory during gradient computation:**\n",
    "- Regular `scan`: stores **all** intermediate carries, so memory grows as O(n) for sequence length n\n",
    "- `checkpointed_scan`: stores only checkpoints, roughly O(base * log_base(n)) carries across the nesting levels\n",
    "- Backward pass: recomputes the values between checkpoints as needed, which adds roughly one extra forward pass per nesting level"
   ]
  },
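  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `max_steps = base^k` concrete, the hypothetical helper below (not part of brainstate) finds the smallest nesting depth `k` whose capacity `base**k` covers a sequence of `n_steps`. It assumes the depth is chosen that way.\n",
    "\n",
    "```python\n",
    "def nesting_levels(n_steps, base):\n",
    "    # Hypothetical helper: smallest k such that base**k >= n_steps.\n",
    "    k, capacity = 0, 1\n",
    "    while capacity < n_steps:\n",
    "        capacity *= base\n",
    "        k += 1\n",
    "    return k, capacity\n",
    "\n",
    "\n",
    "for n, b in [(100, 8), (100, 16), (500, 10), (500, 2)]:\n",
    "    k, cap = nesting_levels(n, b)\n",
    "    print(f'n={n} base={b}: k={k} levels, capacity base**k = {cap}')\n",
    "```\n",
    "\n",
    "A smaller base needs more levels to cover the same sequence; for example, 500 steps need 3 levels with base 10 but 9 levels with base 2. The additional levels are where the extra recomputation comes from."
   ]
  },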
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:27.985824Z",
     "start_time": "2025-10-11T07:36:27.884202Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequence length: 100\n",
      "Hidden size: 32\n",
      "Final hidden shape: (32,)\n",
      "All hiddens shape: (100, 32)\n",
      "\n",
      "Checkpointing configuration:\n",
      "  Base: 8 (controls the depth of the hierarchical checkpointing)\n",
      "  Forward pass: only a subset of intermediate carries is kept as checkpoints\n",
      "  During backprop: activations between checkpoints are recomputed as needed\n"
     ]
    }
   ],
   "source": [
    "# Example: Memory-efficient scan for gradient computation\n",
    "class RecurrentCell(brainstate.nn.Module):\n",
    "    \"\"\"Simple recurrent cell with hidden state.\"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.weight = brainstate.ParamState(jax.random.normal(\n",
    "            jax.random.PRNGKey(0), (hidden_size, hidden_size)\n",
    "        ))\n",
    "\n",
    "    def step(self, hidden, x):\n",
    "        \"\"\"Single recurrent step.\"\"\"\n",
    "        new_hidden = jnp.tanh(jnp.dot(self.weight.value, hidden) + x)\n",
    "        return new_hidden\n",
    "\n",
    "\n",
    "# Create a cell and input sequence\n",
    "cell = RecurrentCell(hidden_size=32)\n",
    "sequence_length = 100\n",
    "inputs = jax.random.normal(jax.random.PRNGKey(1), (sequence_length, 32))\n",
    "\n",
    "\n",
    "def rnn_body(hidden, x):\n",
    "    new_hidden = cell.step(hidden, x)\n",
    "    return new_hidden, new_hidden\n",
    "\n",
    "\n",
    "# Use checkpointed scan for memory efficiency during gradient computation\n",
    "init_hidden = jnp.zeros(32)\n",
    "final_hidden, all_hiddens = checkpointed_scan(\n",
    "    rnn_body,\n",
    "    init=init_hidden,\n",
    "    xs=inputs,\n",
    "    base=8  # Checkpointing base: smaller = less memory, more recomputation\n",
    ")\n",
    "\n",
    "print(f\"Sequence length: {sequence_length}\")\n",
    "print(f\"Hidden size: {cell.hidden_size}\")\n",
    "print(f\"Final hidden shape: {final_hidden.shape}\")\n",
    "print(f\"All hiddens shape: {all_hiddens.shape}\")\n",
    "print(\"\\nCheckpointing configuration:\")\n",
    "print(\"  Base: 8 (controls the depth of the hierarchical checkpointing)\")\n",
    "print(\"  Forward pass: only a subset of intermediate carries is kept as checkpoints\")\n",
    "print(\"  During backprop: activations between checkpoints are recomputed as needed\")"
   ]
  },
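  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example above runs only the forward pass, so checkpointing has no visible effect there: the memory savings appear only when you differentiate through the loop. The sketch below illustrates the underlying idea in plain JAX by applying `jax.checkpoint` (rematerialization) to a `jax.lax.scan` body. It checkpoints every step individually, which is a simpler scheme than the hierarchical one `checkpointed_scan` uses. The function `loss` is hypothetical.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss(w, inputs, use_checkpoint):\n",
    "    def body(h, x):\n",
    "        h = jnp.tanh(w @ h + x)\n",
    "        return h, h\n",
    "\n",
    "    if use_checkpoint:\n",
    "        # Save only each step's inputs; recompute its intermediates during backprop.\n",
    "        body = jax.checkpoint(body)\n",
    "    h_final, _ = jax.lax.scan(body, jnp.zeros(w.shape[0]), inputs)\n",
    "    return jnp.sum(h_final ** 2)\n",
    "\n",
    "\n",
    "w = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (8, 8))\n",
    "inputs = jax.random.normal(jax.random.PRNGKey(1), (50, 8))\n",
    "\n",
    "g_plain = jax.grad(loss)(w, inputs, False)\n",
    "g_remat = jax.grad(loss)(w, inputs, True)\n",
    "print(float(jnp.max(jnp.abs(g_plain - g_remat))))\n",
    "```\n",
    "\n",
    "The two gradients agree up to floating-point error. Only the memory and compute profile of the backward pass differs."
   ]
  },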
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Progress Bar with `checkpointed_scan`\n",
    "\n",
    "`checkpointed_scan` also supports progress bars, which is especially useful for very long sequences where you want to monitor progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.123314Z",
     "start_time": "2025-10-11T07:36:27.992066Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Checkpointed scan with progress bar ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7912e3dc35f4a68943cdd4c2b4d3dc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Processed 500 operations\n",
      "Final state: 493.9846496582031\n",
      "Results shape: (500,)\n"
     ]
    }
   ],
   "source": [
    "# Example: Progress bar with checkpointed_scan\n",
    "print(\"\\n=== Checkpointed scan with progress bar ===\")\n",
    "\n",
    "\n",
    "class LongRunningComputation(brainstate.nn.Module):\n",
    "    \"\"\"Simulate a long-running computation.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total_ops = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def process(self, state, x):\n",
    "        self.total_ops.value = self.total_ops.value + 1\n",
    "        # Some computation\n",
    "        new_state = state + jnp.tanh(x)\n",
    "        output = jnp.sin(new_state) * jnp.cos(x)\n",
    "        return new_state, output\n",
    "\n",
    "\n",
    "long_comp = LongRunningComputation()\n",
    "\n",
    "\n",
    "def body(state, x):\n",
    "    return long_comp.process(state, x)\n",
    "\n",
    "\n",
    "# Long sequence\n",
    "very_long_sequence = jnp.linspace(0, 20 * jnp.pi, 500)\n",
    "\n",
    "# Progress bar that updates every 50 iterations\n",
    "pbar_checkpointed = ProgressBar(\n",
    "    freq=50,\n",
    "    desc=\"Checkpointed scan progress\"\n",
    ")\n",
    "\n",
    "final_state, results = checkpointed_scan(\n",
    "    body,\n",
    "    init=0.0,\n",
    "    xs=very_long_sequence,\n",
    "    base=10,\n",
    "    pbar=pbar_checkpointed\n",
    ")\n",
    "\n",
    "print(f\"\\nProcessed {long_comp.total_ops.value} operations\")\n",
    "print(f\"Final state: {final_state}\")\n",
    "print(f\"Results shape: {results.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 `for_loop`: Simplified Loop Without Carry\n",
    "\n",
    "`for_loop` provides a simpler interface when you don't need an explicit carry value. It:\n",
    "- Accepts variadic arguments that are sliced along axis 0\n",
    "- Collects the return value of `f` from every iteration and stacks the results along axis 0\n",
    "- Internally uses `scan` with `None` as the carry\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "for_loop(\n",
    "    f: Callable[..., Y],\n",
    "    *xs,\n",
    "    length: Optional[int] = None,\n",
    "    reverse: bool = False,\n",
    "    unroll: int | bool = 1,\n",
    "    pbar: Optional[ProgressBar | int] = None\n",
    ") -> Y\n",
    "```\n",
    "\n",
    "**Key differences from scan:**\n",
    "- Function signature is `(*xs) -> output` instead of `(carry, x) -> (carry, output)`\n",
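    "- No carry value to manage; state that must persist across iterations can live in `State` objects instead (see Example 2 below)\n",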
    "- **Important**: The return value at **each iteration** is collected and stacked along axis 0 to form the final output. This means if your function returns a scalar at each step, `for_loop` returns a 1D array; if it returns a vector of shape `(d,)`, the output will be shape `(n, d)` where `n` is the number of iterations."
   ]
  },
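  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `for_loop` is described as a `scan` with `None` as the carry, the idea can be sketched in plain JAX with `jax.lax.scan`. The names `for_loop_sketch` and `scaled_vector` are hypothetical, and the sketch ignores `State` handling and progress bars. It also demonstrates the stacking rule: a per-step output of shape `(3,)` becomes a final output of shape `(n, 3)`.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def for_loop_sketch(f, *xs):\n",
    "    # Hypothetical stand-in for for_loop: a scan whose carry is None.\n",
    "    def body(carry, x_t):\n",
    "        # x_t is a tuple holding the i-th slice of every input array.\n",
    "        return None, f(*x_t)\n",
    "\n",
    "    _, ys = jax.lax.scan(body, None, xs)\n",
    "    return ys\n",
    "\n",
    "\n",
    "def scaled_vector(a, b):\n",
    "    # Returns a vector of shape (3,) at every step.\n",
    "    return a * b * jnp.ones(3)\n",
    "\n",
    "\n",
    "a = jnp.arange(4.0)\n",
    "b = jnp.full(4, 2.0)\n",
    "out = for_loop_sketch(scaled_vector, a, b)\n",
    "print(out.shape)\n",
    "```"
   ]
  },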
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.173030Z",
     "start_time": "2025-10-11T07:36:28.135321Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: [1. 2. 3. 4.]\n",
      "y: [2. 3. 4. 5.]\n",
      "z: [0.5 1.  1.5 2. ]\n",
      "x * y + z: [ 2.5  7.  13.5 22. ]\n",
      "\n",
      "Notice: for_loop collected 4 outputs (one per iteration)\n",
      "Each element results[i] = xs[i] * ys[i] + zs[i]\n",
      "Output shape: (4,) (stacked along axis 0)\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Understanding output collection in for_loop\n",
    "def compute(x, y, z):\n",
    "    \"\"\"Combine three inputs.\"\"\"\n",
    "    return x * y + z\n",
    "\n",
    "\n",
    "xs = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "ys = jnp.array([2.0, 3.0, 4.0, 5.0])\n",
    "zs = jnp.array([0.5, 1.0, 1.5, 2.0])\n",
    "\n",
    "# for_loop collects the output from EACH iteration\n",
    "results = for_loop(compute, xs, ys, zs)\n",
    "\n",
    "print(\"x:\", xs)\n",
    "print(\"y:\", ys)\n",
    "print(\"z:\", zs)\n",
    "print(\"x * y + z:\", results)\n",
    "print(f\"\\nNotice: for_loop collected {len(results)} outputs (one per iteration)\")\n",
    "print(f\"Each element results[i] = xs[i] * ys[i] + zs[i]\")\n",
    "print(f\"Output shape: {results.shape} (stacked along axis 0)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.231614Z",
     "start_time": "2025-10-11T07:36:28.196557Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data: [1. 2. 3. 4. 5. 6.]\n",
      "Running averages: [1.  1.5 2.  2.5 3.  3.5]\n",
      "\n",
      "Final state: total=21.0, count=6\n",
      "Final average: 3.5\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Stateful for_loop\n",
    "class Accumulator(brainstate.nn.Module):\n",
    "    \"\"\"Simple accumulator that tracks total and count.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total = brainstate.ShortTermState(jnp.array(0.0))\n",
    "        self.count = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def process(self, x):\n",
    "        self.total.value = self.total.value + x\n",
    "        self.count.value = self.count.value + 1\n",
    "        return self.total.value / self.count.value  # running average\n",
    "\n",
    "\n",
    "acc = Accumulator()\n",
    "\n",
    "data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])\n",
    "running_averages = for_loop(acc.process, data)\n",
    "\n",
    "print(\"Data:\", data)\n",
    "print(\"Running averages:\", running_averages)\n",
    "print(f\"\\nFinal state: total={acc.total.value}, count={acc.count.value}\")\n",
    "print(f\"Final average: {acc.total.value / acc.count.value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Progress Bar with `for_loop`\n",
    "\n",
    "`for_loop` also supports progress bars. This is particularly useful when processing large batches of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.321947Z",
     "start_time": "2025-10-11T07:36:28.239620Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== For loop with progress bar ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e80536280c8048f1a5dfafd523275826",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Processed 200 items\n",
      "Sum of inputs: 9.437446594238281\n",
      "Processed data shape: (200,)\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Progress bar with for_loop - simple case\n",
    "print(\"\\n=== For loop with progress bar ===\")\n",
    "\n",
    "\n",
    "class DataProcessor(brainstate.nn.Module):\n",
    "    \"\"\"Process data with progress tracking.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.processed_count = brainstate.ShortTermState(jnp.array(0))\n",
    "        self.sum_val = brainstate.ShortTermState(jnp.array(0.0))\n",
    "\n",
    "    def process_item(self, x):\n",
    "        self.processed_count.value = self.processed_count.value + 1\n",
    "        self.sum_val.value = self.sum_val.value + x\n",
    "        # Simulate some processing\n",
    "        result = jax.nn.sigmoid(x)  # numerically stable sigmoid\n",
    "        return result\n",
    "\n",
    "\n",
    "processor = DataProcessor()\n",
    "\n",
    "# Create dataset\n",
    "dataset = jax.random.normal(jax.random.PRNGKey(123), (200,))\n",
    "\n",
    "# Use simple integer for progress updates\n",
    "processed = for_loop(processor.process_item, dataset, pbar=25)\n",
    "\n",
    "print(f\"\\nProcessed {processor.processed_count.value} items\")\n",
    "print(f\"Sum of inputs: {processor.sum_val.value}\")\n",
    "print(f\"Processed data shape: {processed.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.437240Z",
     "start_time": "2025-10-11T07:36:28.342956Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== For loop with dynamic progress description ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db59883248a44c56a65e11e13eab574e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/150 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Final statistics:\n",
      "  Mean: 0.9223809242248535\n",
      "  Variance: 3.488231897354126\n",
      "  Count: 150\n"
     ]
    }
   ],
   "source": [
    "# Example 4: For loop with dynamic progress description\n",
    "print(\"\\n=== For loop with dynamic progress description ===\")\n",
    "\n",
    "\n",
    "class BatchProcessor(brainstate.nn.Module):\n",
    "    \"\"\"Process batches with statistics.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.mean = brainstate.ShortTermState(jnp.array(0.0))\n",
    "        self.variance = brainstate.ShortTermState(jnp.array(0.0))  # sum of squared deviations (M2); variance = M2 / count\n",
    "        self.count = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def update(self, x):\n",
    "        self.count.value = self.count.value + 1\n",
    "        delta = x - self.mean.value\n",
    "        self.mean.value = self.mean.value + delta / self.count.value\n",
    "        self.variance.value = self.variance.value + delta * (x - self.mean.value)\n",
    "        return x ** 2\n",
    "\n",
    "\n",
    "batch_proc = BatchProcessor()\n",
    "\n",
    "\n",
    "def format_batch_progress(data):\n",
    "    \"\"\"Show current statistics.\"\"\"\n",
    "    return {\n",
    "        \"n\": data[\"i\"],\n",
    "        \"mean\": batch_proc.mean.value,\n",
    "        \"var\": batch_proc.variance.value / jnp.maximum(batch_proc.count.value, 1)\n",
    "    }\n",
    "\n",
    "\n",
    "pbar_batch = ProgressBar(\n",
    "    freq=20,\n",
    "    desc=(\"Batch {n:3d} | Mean: {mean:+.3f} | Var: {var:.3f}\", format_batch_progress)\n",
    ")\n",
    "\n",
    "batch_data = jax.random.normal(jax.random.PRNGKey(456), (150,)) * 2.0 + 1.0\n",
    "\n",
    "squared = for_loop(batch_proc.update, batch_data, pbar=pbar_batch)\n",
    "\n",
    "print(f\"\\nFinal statistics:\")\n",
    "print(f\"  Mean: {batch_proc.mean.value}\")\n",
    "print(f\"  Variance: {batch_proc.variance.value / batch_proc.count.value}\")\n",
    "print(f\"  Count: {batch_proc.count.value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 `checkpointed_for_loop`: Memory-Efficient For Loop\n",
    "\n",
    "The checkpointed version of `for_loop` combines the simplicity of `for_loop` with the memory efficiency of checkpointing **during gradient computation**.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "checkpointed_for_loop(\n",
    "    f: Callable[..., Y],\n",
    "    *xs: X,\n",
    "    length: Optional[int] = None,\n",
    "    base: int = 16,\n",
    "    pbar: Optional[ProgressBar | int] = None,\n",
    ") -> Y\n",
    "```\n",
    "\n",
    "**Memory efficiency during gradient computation:**\n",
    "- Like `checkpointed_scan`, this variant significantly reduces memory usage during backpropagation\n",
    "- Essential for training models with very long sequences where storing all intermediate activations would cause out-of-memory errors\n",
    "- The `base` parameter controls the memory/computation tradeoff: smaller base = less memory but more recomputation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.535539Z",
     "start_time": "2025-10-11T07:36:28.448245Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Signal length: 200\n",
      "Smoothed signal shape: (200,)\n",
      "Original signal range: [-1.401, 1.408]\n",
      "Smoothed signal range: [-1.131, 1.200]\n"
     ]
    }
   ],
   "source": [
    "# Example: Processing long sequence with state\n",
    "class ExpMovingAverage(brainstate.nn.Module):\n",
    "    \"\"\"Exponential moving average.\"\"\"\n",
    "\n",
    "    def __init__(self, alpha=0.1):\n",
    "        super().__init__()\n",
    "        self.alpha = alpha\n",
    "        self.ema = brainstate.ShortTermState(jnp.array(0.0))\n",
    "        self.initialized = brainstate.ShortTermState(jnp.array(False))\n",
    "\n",
    "    def update(self, x):\n",
    "        # Initialize with first value\n",
    "        self.ema.value = jnp.where(\n",
    "            self.initialized.value,\n",
    "            self.alpha * x + (1 - self.alpha) * self.ema.value,\n",
    "            x\n",
    "        )\n",
    "        self.initialized.value = True\n",
    "        return self.ema.value\n",
    "\n",
    "\n",
    "ema = ExpMovingAverage(alpha=0.3)\n",
    "\n",
    "# Generate noisy signal\n",
    "signal = jnp.sin(jnp.linspace(0, 4 * jnp.pi, 200)) + 0.2 * brainstate.random.normal(size=(200,))\n",
    "\n",
    "# Process with checkpointed for_loop\n",
    "smoothed = checkpointed_for_loop(ema.update, signal, base=10)\n",
    "\n",
    "print(f\"Signal length: {len(signal)}\")\n",
    "print(f\"Smoothed signal shape: {smoothed.shape}\")\n",
    "print(f\"Original signal range: [{signal.min():.3f}, {signal.max():.3f}]\")\n",
    "print(f\"Smoothed signal range: [{smoothed.min():.3f}, {smoothed.max():.3f}]\")"
   ]
  },
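  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`checkpointed_for_loop` is meant to save memory in the backward pass without changing the result. Here is a minimal sketch that checks this, assuming a stateless per-step function. `per_step` and the input grid are illustrative choices, not part of the library. Gradients through `checkpointed_for_loop` should agree, up to floating-point tolerance, with the gradients through a plain `for_loop` and with the analytic derivative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from brainstate.transform import for_loop, checkpointed_for_loop\n",
    "\n",
    "\n",
    "# Illustrative stateless per-step function (an assumption for this sketch)\n",
    "def per_step(x):\n",
    "    return jnp.sin(x) * jnp.exp(-0.1 * x)\n",
    "\n",
    "\n",
    "def loss_plain(xs):\n",
    "    return jnp.sum(for_loop(per_step, xs))\n",
    "\n",
    "\n",
    "def loss_ckpt(xs):\n",
    "    # Same computation, but intermediates are checkpointed for the backward pass\n",
    "    return jnp.sum(checkpointed_for_loop(per_step, xs, base=8))\n",
    "\n",
    "\n",
    "xs_grad = jnp.linspace(0.0, 3.0, 64)\n",
    "g_plain = jax.grad(loss_plain)(xs_grad)\n",
    "g_ckpt = jax.grad(loss_ckpt)(xs_grad)\n",
    "\n",
    "# Analytic derivative of sin(x) * exp(-0.1 x), for comparison\n",
    "g_exact = (jnp.cos(xs_grad) - 0.1 * jnp.sin(xs_grad)) * jnp.exp(-0.1 * xs_grad)\n",
    "\n",
    "print('max |plain - checkpointed|:', jnp.max(jnp.abs(g_plain - g_ckpt)))\n",
    "print('max |checkpointed - exact|:', jnp.max(jnp.abs(g_ckpt - g_exact)))"
   ]
  },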
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Progress Bar with `checkpointed_for_loop`\n",
    "\n",
    "`checkpointed_for_loop` supports progress bars to help track processing of very long sequences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.668982Z",
     "start_time": "2025-10-11T07:36:28.544045Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Checkpointed for loop with progress bar ===\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "606fe45ed64a424cb5d069316cb814eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Stream processed!\n",
      "Final running average: 0.050981856882572174\n",
      "Smoothed stream shape: (1000,)\n",
      "First 5 values: [ 0.03755957 -0.01195641  0.0011875   0.03787947  0.04933156]\n",
      "Last 5 values: [ 0.10296391  0.07546234 -0.00500105  0.06068815  0.05098186]\n"
     ]
    }
   ],
   "source": [
    "# Example: Progress bar with checkpointed_for_loop\n",
    "print(\"\\n=== Checkpointed for loop with progress bar ===\")\n",
    "\n",
    "\n",
    "class StreamProcessor(brainstate.nn.Module):\n",
    "    \"\"\"Process streaming data.\"\"\"\n",
    "\n",
    "    def __init__(self, momentum=0.9):\n",
    "        super().__init__()\n",
    "        self.momentum = momentum\n",
    "        self.running_avg = brainstate.ShortTermState(jnp.array(0.0))\n",
    "\n",
    "    def process(self, x):\n",
    "        # Update exponential moving average\n",
    "        self.running_avg.value = (\n",
    "            self.momentum * self.running_avg.value + (1 - self.momentum) * x\n",
    "        )\n",
    "        return self.running_avg.value\n",
    "\n",
    "\n",
    "stream_proc = StreamProcessor(momentum=0.95)\n",
    "\n",
    "# Generate long data stream\n",
    "data_stream = jax.random.normal(jax.random.PRNGKey(789), (1000,))\n",
    "\n",
    "# Progress bar with count parameter (updates exactly 10 times)\n",
    "pbar_stream = ProgressBar(\n",
    "    count=10,\n",
    "    desc=\"Processing data stream\"\n",
    ")\n",
    "\n",
    "smoothed_stream = checkpointed_for_loop(\n",
    "    stream_proc.process,\n",
    "    data_stream,\n",
    "    base=20,\n",
    "    pbar=pbar_stream\n",
    ")\n",
    "\n",
    "print(f\"\\nStream processed!\")\n",
    "print(f\"Final running average: {stream_proc.running_avg.value}\")\n",
    "print(f\"Smoothed stream shape: {smoothed_stream.shape}\")\n",
    "print(f\"First 5 values: {smoothed_stream[:5]}\")\n",
    "print(f\"Last 5 values: {smoothed_stream[-5:]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.5 Comparison: `scan` vs `for_loop`\n",
    "\n",
    "When to use each:\n",
    "\n",
    "**Use `scan` when:**\n",
    "- You need to thread a carry value through iterations\n",
    "- Implementing recurrent patterns (RNNs, state machines)\n",
    "- You want explicit control over the accumulator\n",
    "\n",
    "**Use `for_loop` when:**\n",
    "- No explicit carry value is needed (running quantities can live in `State` objects instead)\n",
    "- Processing items one at a time, with side effects expressed as state updates\n",
    "- Simpler, more Pythonic syntax is preferred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.751592Z",
     "start_time": "2025-10-11T07:36:28.693469Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Powers of 2 (first 10 values):\n",
      "scan result:     [  1   2   4   8  16  32  64 128 256 512]\n",
      "for_loop result: [  1   2   4   8  16  32  64 128 256 512]\n"
     ]
    }
   ],
   "source": [
    "# Comparison example: Computing powers of 2\n",
    "\n",
    "# Using scan: carry explicitly tracks the power\n",
    "def scan_version(n):\n",
    "    def body(carry, _):\n",
    "        return carry * 2, carry\n",
    "\n",
    "    _, powers = scan(body, init=1, xs=jnp.arange(n))\n",
    "    return powers\n",
    "\n",
    "\n",
    "# Using for_loop with state: state tracks the power\n",
    "class PowerTracker(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.current = brainstate.ShortTermState(jnp.array(1))\n",
    "\n",
    "    def next_power(self, _):\n",
    "        result = self.current.value\n",
    "        self.current.value = self.current.value * 2\n",
    "        return result\n",
    "\n",
    "\n",
    "def forloop_version(n):\n",
    "    tracker = PowerTracker()\n",
    "    return for_loop(tracker.next_power, jnp.arange(n))\n",
    "\n",
    "\n",
    "n = 10\n",
    "print(f\"Powers of 2 (first {n} values):\")\n",
    "print(\"scan result:    \", scan_version(n))\n",
    "print(\"for_loop result:\", forloop_version(n))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. While Loops\n",
    "\n",
    "While loops provide conditional iteration where the number of iterations is not known in advance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 `while_loop`: Dynamic Conditional Iteration\n",
    "\n",
    "`while_loop` executes a body function repeatedly while a condition remains true. This is the stateful version of `jax.lax.while_loop`.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "while_loop(\n",
    "    cond_fun: Callable[[T], BooleanNumeric],\n",
    "    body_fun: Callable[[T], T],\n",
    "    init_val: T\n",
    ") -> T\n",
    "```\n",
    "\n",
    "**Parameters:**\n",
    "- `cond_fun`: Function that returns True to continue looping\n",
    "- `body_fun`: Function that updates the loop value\n",
    "- `init_val`: Initial loop value\n",
    "\n",
    "**Important constraints:**\n",
    "- `cond_fun` cannot modify state (read-only)\n",
    "- Loop value must maintain fixed shape and dtype\n",
    "- Not reverse-mode differentiable (use `bounded_while_loop` instead)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.781109Z",
     "start_time": "2025-10-11T07:36:28.757599Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First power of 2 above 1000: 1024\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Simple while loop - find the smallest power of 2 that is >= threshold\n",
    "def find_power_of_2_above(threshold):\n",
    "    def cond_fn(val):\n",
    "        return val < threshold\n",
    "\n",
    "    def body(val):\n",
    "        return val * 2\n",
    "\n",
    "    return while_loop(cond_fn, body, init_val=1)\n",
    "\n",
    "\n",
    "threshold = 1000\n",
    "result = find_power_of_2_above(threshold)\n",
    "print(f\"First power of 2 above {threshold}: {result}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.819016Z",
     "start_time": "2025-10-11T07:36:28.789114Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing sqrt(2)...\n",
      "Result: 1.4142135381698608\n",
      "Actual sqrt(2): 1.4142135381698608\n",
      "Error: 0.0\n",
      "Iterations: 4\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Stateful while loop - iterative refinement\n",
    "class IterativeRefiner(brainstate.nn.Module):\n",
    "    \"\"\"Iteratively refine an estimate using Newton's method.\"\"\"\n",
    "\n",
    "    def __init__(self, target):\n",
    "        super().__init__()\n",
    "        self.target = target\n",
    "        self.iterations = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def refine(self, x):\n",
    "        \"\"\"Newton's method step for computing sqrt(target).\"\"\"\n",
    "        self.iterations.value = self.iterations.value + 1\n",
    "        return 0.5 * (x + self.target / x)\n",
    "\n",
    "\n",
    "# Compute square root of 2 using Newton's method\n",
    "refiner = IterativeRefiner(target=2.0)\n",
    "\n",
    "\n",
    "def cond_f(x):\n",
    "    # Continue until error is small enough\n",
    "    return jnp.abs(x * x - refiner.target) > 1e-6\n",
    "\n",
    "\n",
    "def body(x):\n",
    "    return refiner.refine(x)\n",
    "\n",
    "\n",
    "result = while_loop(cond_f, body, init_val=1.0)\n",
    "\n",
    "print(f\"Computing sqrt(2)...\")\n",
    "print(f\"Result: {result}\")\n",
    "print(f\"Actual sqrt(2): {jnp.sqrt(2.0)}\")\n",
    "print(f\"Error: {jnp.abs(result - jnp.sqrt(2.0))}\")\n",
    "print(f\"Iterations: {refiner.iterations.value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.866573Z",
     "start_time": "2025-10-11T07:36:28.827403Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collatz sequence starting from 27:\n",
      "  Converged to: 1\n",
      "  Steps taken: 111\n",
      "  Maximum value reached: 9232\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Complex loop value (pytree)\n",
    "class Collatz(brainstate.nn.Module):\n",
    "    \"\"\"Track Collatz sequence statistics.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.max_value = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def step(self, n):\n",
    "        self.max_value.value = jnp.maximum(self.max_value.value, n)\n",
    "        return jnp.where(n % 2 == 0, n // 2, 3 * n + 1)\n",
    "\n",
    "\n",
    "collatz = Collatz()\n",
    "\n",
    "\n",
    "def collatz_cond(state):\n",
    "    n, steps = state\n",
    "    return n > 1\n",
    "\n",
    "\n",
    "def collatz_body(state):\n",
    "    n, steps = state\n",
    "    return collatz.step(n), steps + 1\n",
    "\n",
    "\n",
    "start_value = 27\n",
    "final_n, total_steps = while_loop(\n",
    "    collatz_cond,\n",
    "    collatz_body,\n",
    "    init_val=(start_value, 0)\n",
    ")\n",
    "\n",
    "print(f\"Collatz sequence starting from {start_value}:\")\n",
    "print(f\"  Converged to: {final_n}\")\n",
    "print(f\"  Steps taken: {total_steps}\")\n",
    "print(f\"  Maximum value reached: {collatz.max_value.value}\")"
   ]
  },
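  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constraint list for `while_loop` says it is not reverse-mode differentiable. This small sketch illustrates that claim. It assumes a pure (stateless) loop and assumes that `while_loop` behaves like `jax.lax.while_loop`, raising an ordinary Python exception under `jax.grad`. `double_until` is a hypothetical helper. Starting from 1.0, the loop doubles four times (1, 2, 4, 8, 16), so forward mode (`jax.jvp`) should give a tangent of 16 (= 2^4), while reverse mode (`jax.grad`) should fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from brainstate.transform import while_loop\n",
    "\n",
    "\n",
    "def double_until(x):\n",
    "    # Hypothetical helper: keep doubling x while it stays below 10.0\n",
    "    return while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x)\n",
    "\n",
    "\n",
    "# Forward mode works: 1 -> 2 -> 4 -> 8 -> 16, so the tangent is 16\n",
    "primal_out, tangent_out = jax.jvp(double_until, (1.0,), (1.0,))\n",
    "print('jvp primal:', primal_out, '| jvp tangent:', tangent_out)\n",
    "\n",
    "# Reverse mode is expected to fail, as it does for jax.lax.while_loop\n",
    "try:\n",
    "    jax.grad(double_until)(1.0)\n",
    "    grad_failed = False\n",
    "except Exception as err:  # JAX typically raises a ValueError here\n",
    "    grad_failed = True\n",
    "    print('jax.grad failed with', type(err).__name__)"
   ]
  },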
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 `bounded_while_loop`: While Loop with Maximum Steps\n",
    "\n",
    "`bounded_while_loop` adds a maximum iteration limit to while loops. The bound is useful for:\n",
    "- Preventing infinite loops\n",
    "- Enabling reverse-mode differentiation (unlike `while_loop`)\n",
    "- Keeping compilation cost predictable, since it depends only on `max_steps` and `base` (see below)\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "bounded_while_loop(\n",
    "    cond_fun: Callable[[T], BooleanNumeric],\n",
    "    body_fun: Callable[[T], T],\n",
    "    init_val: T,\n",
    "    *,\n",
    "    max_steps: int,\n",
    "    base: int = 16,\n",
    ") -> T\n",
    "```\n",
    "\n",
    "**Key parameters:**\n",
    "- `max_steps`: Maximum number of iterations before termination\n",
    "- `base`: Compilation/runtime tradeoff (default=16)\n",
    "  - Larger base = faster compilation, slightly slower runtime\n",
    "  - Smaller base = slower compilation, faster runtime\n",
    "  - Compile time scales with `math.ceil(math.log(max_steps, base))`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:28.926870Z",
     "start_time": "2025-10-11T07:36:28.876485Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Minimizing f(x) = (x - 3)^2\n",
      "Starting from x = 0.0\n",
      "Final x: 3.001088857650757\n",
      "Target x: 3.0\n",
      "Error: 0.001088857650756836\n",
      "Iterations used: 252 / 100\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Gradient descent with bounded iterations\n",
    "class GradientDescent(brainstate.nn.Module):\n",
    "    \"\"\"Simple gradient descent optimizer.\"\"\"\n",
    "\n",
    "    def __init__(self, learning_rate=0.1):\n",
    "        super().__init__()\n",
    "        self.lr = learning_rate\n",
    "        self.steps = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def step(self, x):\n",
    "        # Gradient of f(x) = (x - 3)^2\n",
    "        grad = 2 * (x - 3.0)\n",
    "        self.steps.value = self.steps.value + 1\n",
    "        return x - self.lr * grad\n",
    "\n",
    "\n",
    "optimizer = GradientDescent(learning_rate=0.1)\n",
    "\n",
    "\n",
    "def not_converged(x):\n",
    "    # Keep iterating while x is still far from the optimum\n",
    "    return jnp.abs(x - 3.0) > 1e-4\n",
    "\n",
    "\n",
    "result = bounded_while_loop(\n",
    "    not_converged,\n",
    "    optimizer.step,\n",
    "    init_val=0.0,\n",
    "    max_steps=100\n",
    ")\n",
    "\n",
    "print(f\"Minimizing f(x) = (x - 3)^2\")\n",
    "print(f\"Starting from x = 0.0\")\n",
    "print(f\"Final x: {result}\")\n",
    "print(f\"Target x: 3.0\")\n",
    "print(f\"Error: {jnp.abs(result - 3.0)}\")\n",
    "print(f\"Iterations used: {optimizer.steps.value} / 100\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.124899Z",
     "start_time": "2025-10-11T07:36:28.934876Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Base  2: result=206, iterations=50, recursion_depth≈7\n",
      "Base  8: result=3746, iterations=50, recursion_depth≈3\n",
      "Base 16: result=3346, iterations=50, recursion_depth≈2\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Comparing different base values\n",
    "class Counter(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.count = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def increment(self, x):\n",
    "        self.count.value = self.count.value + 1\n",
    "        return x + 1\n",
    "\n",
    "\n",
    "def compare_base_values():\n",
    "    max_steps = 100\n",
    "\n",
    "    for base in [2, 8, 16]:\n",
    "        counter = Counter()\n",
    "\n",
    "        result = bounded_while_loop(\n",
    "            lambda x: x < 50,\n",
    "            counter.increment,\n",
    "            init_val=0,\n",
    "            max_steps=max_steps,\n",
    "            base=base\n",
    "        )\n",
    "\n",
    "        recursion_depth = jnp.ceil(jnp.log(max_steps) / jnp.log(base))\n",
    "        print(f\"Base {base:2d}: result={result}, iterations={counter.count.value}, \"\n",
    "              f\"recursion_depth≈{int(recursion_depth)}\")\n",
    "\n",
    "\n",
    "compare_base_values()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.245463Z",
     "start_time": "2025-10-11T07:36:29.131905Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: 0.0\n",
      "Output: 4085.0\n",
      "Gradient: 0.0\n",
      "\n",
      "bounded_while_loop is differentiable!\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Differentiable bounded_while_loop\n",
    "def smooth_threshold(x, threshold=5.0, lr=0.5, max_steps=20):\n",
    "    \"\"\"Smoothly approach threshold using gradient descent.\"\"\"\n",
    "\n",
    "    def cond_fn(val):\n",
    "        return val < threshold - 0.1\n",
    "\n",
    "    def body(val):\n",
    "        # Gradient of loss = (val - threshold)^2\n",
    "        grad = 2 * (val - threshold)\n",
    "        return val - lr * grad\n",
    "\n",
    "    return bounded_while_loop(cond_fn, body, x, max_steps=max_steps)\n",
    "\n",
    "\n",
    "# Compute gradient\n",
    "x = 0.0\n",
    "value, grad = jax.value_and_grad(smooth_threshold)(x)\n",
    "\n",
    "print(f\"Input: {x}\")\n",
    "print(f\"Output: {value}\")\n",
    "print(f\"Gradient: {grad}\")\n",
    "print(f\"\\nbounded_while_loop is differentiable!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Comparison: `while_loop` vs `bounded_while_loop`\n",
    "\n",
    "**Use `while_loop` when:**\n",
    "- Number of iterations is truly unknown\n",
    "- Not computing gradients\n",
    "- Want standard JAX while loop semantics\n",
    "\n",
    "**Use `bounded_while_loop` when:**\n",
    "- Need gradient computation\n",
    "- Want safety against infinite loops\n",
    "- Can provide reasonable upper bound on iterations\n",
    "- Need predictable compilation characteristics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Conditional Control Flow\n",
    "\n",
    "Conditional primitives express branching logic inside traced code while tracking the `State` reads and writes made in each branch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 `cond`: Binary Conditional (If/Else)\n",
    "\n",
    "`cond` selectively executes one of two branches based on a boolean predicate. This is the stateful version of `jax.lax.cond`.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "cond(\n",
    "    pred,\n",
    "    true_fun: Callable,\n",
    "    false_fun: Callable,\n",
    "    *operands\n",
    ")\n",
    "```\n",
    "\n",
    "**Parameters:**\n",
    "- `pred`: Boolean scalar (or numeric, where non-zero is True)\n",
    "- `true_fun`: Function called when `pred` is True\n",
    "- `false_fun`: Function called when `pred` is False\n",
    "- `*operands`: Arguments passed to the selected function\n",
    "\n",
    "**Key properties:**\n",
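    "- Both branches are traced, but only the selected branch is executed at runtime (under `jax.vmap` with a batched predicate, both branches run and the results are selected)\n",
    "- Both branches must return the same pytree structure, with matching shapes and dtypes\n",
    "- State reads and writes in either branch are tracked; only the updates from the branch that is taken take effect"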
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.310518Z",
     "start_time": "2025-10-11T07:36:29.254470Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cond(-5.0 >= 0): 5.0\n",
      "cond(3.0 >= 0): 9.0\n",
      "cond(0.0 >= 0): 0.0\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Simple conditional\n",
    "def positive_branch(x):\n",
    "    return x ** 2\n",
    "\n",
    "\n",
    "def negative_branch(x):\n",
    "    return -x\n",
    "\n",
    "\n",
    "for value in [-5.0, 3.0, 0.0]:\n",
    "    result = cond(value >= 0, positive_branch, negative_branch, value)\n",
    "    print(f\"cond({value} >= 0): {result}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.419593Z",
     "start_time": "2025-10-11T07:36:29.317580Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Values: [ 1. -2.  3. -4.  5.]\n",
      "Results: [ 2. -1.  6. -2. 10.]\n",
      "\n",
      "Branch statistics:\n",
      "  True branch taken: 3 times\n",
      "  False branch taken: 2 times\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Stateful conditional\n",
    "class BranchTracker(brainstate.nn.Module):\n",
    "    \"\"\"Track which branches were taken.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.true_count = brainstate.ShortTermState(jnp.array(0))\n",
    "        self.false_count = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def true_branch(self, x):\n",
    "        self.true_count.value = self.true_count.value + 1\n",
    "        return x * 2\n",
    "\n",
    "    def false_branch(self, x):\n",
    "        self.false_count.value = self.false_count.value + 1\n",
    "        return x / 2\n",
    "\n",
    "\n",
    "tracker = BranchTracker()\n",
    "\n",
    "# Test multiple values\n",
    "values = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])\n",
    "results = []\n",
    "\n",
    "for v in values:\n",
    "    result = cond(v > 0, tracker.true_branch, tracker.false_branch, v)\n",
    "    results.append(result)\n",
    "\n",
    "print(\"Values:\", values)\n",
    "print(\"Results:\", jnp.array(results))\n",
    "print(f\"\\nBranch statistics:\")\n",
    "print(f\"  True branch taken: {tracker.true_count.value} times\")\n",
    "print(f\"  False branch taken: {tracker.false_count.value} times\")"
   ]
  },
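  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example 2 calls `cond` from an ordinary Python loop with concrete predicates. The sketch below applies the same pattern inside `for_loop`, where the predicate is a traced value and both branches are compiled. `SignCounter` is a hypothetical module. The sketch assumes, in line with the key properties listed for `cond`, that a state written by only one branch keeps its previous value whenever the other branch is selected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "from brainstate.transform import cond, for_loop\n",
    "\n",
    "\n",
    "class SignCounter(brainstate.nn.Module):\n",
    "    \"\"\"Hypothetical module: returns |x| and counts positive inputs.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.n_positive = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "    def step(self, x):\n",
    "        def on_positive(x):\n",
    "            # Only this branch writes to the state\n",
    "            self.n_positive.value = self.n_positive.value + 1\n",
    "            return x\n",
    "\n",
    "        def on_non_positive(x):\n",
    "            return -x\n",
    "\n",
    "        # x is a traced scalar here, so both branches are compiled\n",
    "        return cond(x > 0, on_positive, on_non_positive, x)\n",
    "\n",
    "\n",
    "sign_counter = SignCounter()\n",
    "abs_values = for_loop(sign_counter.step, jnp.array([1.0, -2.0, 3.0, -4.0, 5.0]))\n",
    "print('abs values:', abs_values)\n",
    "print('positive inputs:', sign_counter.n_positive.value)"
   ]
  },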
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.437451Z",
     "start_time": "2025-10-11T07:36:29.428084Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Values: [10.  2. -3. -8.  7. -1.]\n",
      "Classifications: ['large_positive', 'small_positive', 'small_negative', 'large_negative', 'large_positive', 'small_negative']\n",
      "\n",
      "Category counts:\n",
      "  large_positive: 2\n",
      "  small_positive: 1\n",
      "  small_negative: 2\n",
      "  large_negative: 1\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Nested conditionals\n",
    "class Classifier(brainstate.nn.Module):\n",
    "    \"\"\"Classify numbers into categories.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.classification_counts = brainstate.ShortTermState({\n",
    "            'large_positive': jnp.array(0),\n",
    "            'small_positive': jnp.array(0),\n",
    "            'small_negative': jnp.array(0),\n",
    "            'large_negative': jnp.array(0),\n",
    "        })\n",
    "\n",
    "    def classify_positive(self, x):\n",
    "        def large(x):\n",
    "            counts = self.classification_counts.value\n",
    "            counts['large_positive'] = counts['large_positive'] + 1\n",
    "            self.classification_counts.value = counts\n",
    "            return 'large_positive'\n",
    "\n",
    "        def small(x):\n",
    "            counts = self.classification_counts.value\n",
    "            counts['small_positive'] = counts['small_positive'] + 1\n",
    "            self.classification_counts.value = counts\n",
    "            return 'small_positive'\n",
    "\n",
    "        return cond(x > 5.0, large, small, x)\n",
    "\n",
    "    def classify_negative(self, x):\n",
    "        def small(x):\n",
    "            counts = self.classification_counts.value\n",
    "            counts['small_negative'] = counts['small_negative'] + 1\n",
    "            self.classification_counts.value = counts\n",
    "            return 'small_negative'\n",
    "\n",
    "        def large(x):\n",
    "            counts = self.classification_counts.value\n",
    "            counts['large_negative'] = counts['large_negative'] + 1\n",
    "            self.classification_counts.value = counts\n",
    "            return 'large_negative'\n",
    "\n",
    "        return cond(x > -5.0, small, large, x)\n",
    "\n",
    "    def classify(self, x):\n",
    "        return cond(x >= 0, self.classify_positive, self.classify_negative, x)\n",
    "\n",
    "\n",
    "classifier = Classifier()\n",
    "\n",
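    "# The branches return Python strings, which are not valid JAX values.\n",
    "# Disabling JIT makes `cond` evaluate eagerly, so this demo can return them.\n",
    "with jax.disable_jit():\n",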
    "    test_values = jnp.array([10.0, 2.0, -3.0, -8.0, 7.0, -1.0])\n",
    "    classifications = [classifier.classify(v) for v in test_values]\n",
    "\n",
    "    print(\"Values:\", test_values)\n",
    "    print(\"Classifications:\", classifications)\n",
    "    print(\"\\nCategory counts:\")\n",
    "    for category, count in classifier.classification_counts.value.items():\n",
    "        print(f\"  {category}: {count}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 `switch`: Multi-Way Branching\n",
    "\n",
    "`switch` generalizes `cond` to multiple branches, similar to a switch/case statement.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "switch(\n",
    "    index,\n",
    "    branches: Sequence[Callable],\n",
    "    *operands\n",
    ")\n",
    "```\n",
    "\n",
    "**Parameters:**\n",
    "- `index`: Integer scalar selecting which branch to execute\n",
    "- `branches`: Sequence of callables (at least 1)\n",
    "- `*operands`: Arguments passed to the selected branch\n",
    "\n",
    "**Index handling:**\n",
    "- Out-of-bounds indices are clamped to `[0, len(branches) - 1]`\n",
    "- Negative indices are clamped to 0 (they do **not** wrap around as in Python indexing)\n",
    "- Indices >= `len(branches)` are clamped to `len(branches) - 1`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.586429Z",
     "start_time": "2025-10-11T07:36:29.452943Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Operation 0 on 5.0: 6.0\n",
      "Operation 1 on 5.0: 10.0\n",
      "Operation 2 on 5.0: 25.0\n",
      "Operation 3 on 5.0: -5.0\n",
      "\n",
      "Out of bounds (index=4): -5.0\n",
      "Out of bounds (index=-1): 6.0\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Simple multi-way branch\n",
    "def operation_0(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "def operation_1(x):\n",
    "    return x * 2\n",
    "\n",
    "\n",
    "def operation_2(x):\n",
    "    return x ** 2\n",
    "\n",
    "\n",
    "def operation_3(x):\n",
    "    return -x\n",
    "\n",
    "\n",
    "operations = [operation_0, operation_1, operation_2, operation_3]\n",
    "\n",
    "x = 5.0\n",
    "for i in range(len(operations)):\n",
    "    result = switch(i, operations, x)\n",
    "    print(f\"Operation {i} on {x}: {result}\")\n",
    "\n",
    "# Test clamping\n",
    "print(f\"\\nOut of bounds (index={len(operations)}): {switch(len(operations), operations, x)}\")\n",
    "print(f\"Out of bounds (index={-1}): {switch(-1, operations, x)}\")"
   ]
  },
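  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example 1 passes a concrete Python `int` as the index. Inside a compiled loop, the index is usually a traced array. This sketch works under that assumption. It uses an illustrative list of three branches and indices chosen to exercise clamping: 5 should map to the last branch and -3 to the first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "from brainstate.transform import for_loop, switch\n",
    "\n",
    "# Illustrative branches, in the same style as Example 1\n",
    "branches = [lambda x: x + 1, lambda x: x * 2, lambda x: x ** 2]\n",
    "\n",
    "\n",
    "def apply_op(index, x):\n",
    "    # Inside for_loop, index is a traced int32 scalar; out-of-range values are clamped\n",
    "    return switch(index, branches, x)\n",
    "\n",
    "\n",
    "indices = jnp.array([0, 1, 2, 5, -3], dtype=jnp.int32)\n",
    "inputs = jnp.full((5,), 3.0)\n",
    "# for_loop passes one element of each sequence to apply_op per iteration\n",
    "outputs = for_loop(apply_op, indices, inputs)\n",
    "print('outputs:', outputs)"
   ]
  },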
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.854768Z",
     "start_time": "2025-10-11T07:36:29.597956Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: 2.0\n",
      "\n",
      "ReLU      : 2.0000\n",
      "Sigmoid   : 0.8808\n",
      "Tanh      : 0.9640\n",
      "Softplus  : 2.1269\n",
      "Identity  : 2.0000\n",
      "\n",
      "Usage counts: [1 1 1 1 1]\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Stateful switch - activation function selector\n",
    "class ActivationSelector(brainstate.nn.Module):\n",
    "    \"\"\"Select and apply different activation functions.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.usage_counts = brainstate.ShortTermState(jnp.zeros(5, dtype=jnp.int32))\n",
    "\n",
    "    def _track_usage(self, index):\n",
    "        counts = self.usage_counts.value\n",
    "        counts = counts.at[index].add(1)\n",
    "        self.usage_counts.value = counts\n",
    "\n",
    "    def relu(self, x):\n",
    "        self._track_usage(0)\n",
    "        return jnp.maximum(0, x)\n",
    "\n",
    "    def sigmoid(self, x):\n",
    "        self._track_usage(1)\n",
    "        return 1 / (1 + jnp.exp(-x))\n",
    "\n",
    "    def tanh(self, x):\n",
    "        self._track_usage(2)\n",
    "        return jnp.tanh(x)\n",
    "\n",
    "    def softplus(self, x):\n",
    "        self._track_usage(3)\n",
    "        return jnp.logaddexp(0.0, x)  # log(1 + exp(x)) without overflow for large x\n",
    "\n",
    "    def identity(self, x):\n",
    "        self._track_usage(4)\n",
    "        return x\n",
    "\n",
    "    def apply(self, index, x):\n",
    "        return switch(\n",
    "            index,\n",
    "            [self.relu, self.sigmoid, self.tanh, self.softplus, self.identity],\n",
    "            x\n",
    "        )\n",
    "\n",
    "\n",
    "selector = ActivationSelector()\n",
    "activation_names = ['ReLU', 'Sigmoid', 'Tanh', 'Softplus', 'Identity']\n",
    "\n",
    "# Test all activations\n",
    "test_input = 2.0\n",
    "print(f\"Input: {test_input}\\n\")\n",
    "\n",
    "for i in range(len(activation_names)):\n",
    "    result = selector.apply(i, test_input)\n",
    "    print(f\"{activation_names[i]:10s}: {result:.4f}\")\n",
    "\n",
    "print(f\"\\nUsage counts: {selector.usage_counts.value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.870984Z",
     "start_time": "2025-10-11T07:36:29.859033Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simulation results:\n",
      "Step 0: state=  1.0, policy=aggressive  , action= 2.00, reward=1.00\n",
      "Step 1: state= -0.5, policy=conservative, action=-0.25, reward=0.25\n",
      "Step 2: state=  2.0, policy=aggressive  , action= 4.00, reward=2.00\n",
      "Step 3: state= -1.5, policy=conservative, action=-0.75, reward=0.75\n",
      "Step 4: state=  0.8, policy=random      , action= 0.80, reward=0.24\n",
      "\n",
      "Total reward: 4.24\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Dynamic policy selection\n",
    "class PolicySelector(brainstate.nn.Module):\n",
    "    \"\"\"Select different action policies based on state.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.total_reward = brainstate.ShortTermState(jnp.array(0.0))\n",
    "\n",
    "    def aggressive_policy(self, state):\n",
    "        action = state * 2.0\n",
    "        reward = jnp.abs(action) * 0.5\n",
    "        self.total_reward.value = self.total_reward.value + reward\n",
    "        return {'action': action, 'reward': reward, 'policy': 'aggressive'}\n",
    "\n",
    "    def conservative_policy(self, state):\n",
    "        action = state * 0.5\n",
    "        reward = jnp.abs(action) * 1.0\n",
    "        self.total_reward.value = self.total_reward.value + reward\n",
    "        return {'action': action, 'reward': reward, 'policy': 'conservative'}\n",
    "\n",
    "    def random_policy(self, state):\n",
    "        action = state * 1.0\n",
    "        reward = jnp.abs(action) * 0.3\n",
    "        self.total_reward.value = self.total_reward.value + reward\n",
    "        return {'action': action, 'reward': reward, 'policy': 'random'}\n",
    "\n",
    "    def select_and_act(self, policy_index, state):\n",
    "        return switch(\n",
    "            policy_index,\n",
    "            [self.aggressive_policy, self.conservative_policy, self.random_policy],\n",
    "            state\n",
    "        )\n",
    "\n",
    "\n",
    "policy_selector = PolicySelector()\n",
    "\n",
    "# Simulate decision-making over time\n",
    "states = jnp.array([1.0, -0.5, 2.0, -1.5, 0.8])\n",
    "policies = jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32)  # policy choices\n",
    "\n",
    "with jax.disable_jit():\n",
    "    print(\"Simulation results:\")\n",
    "    for i, (policy_idx, state) in enumerate(zip(policies, states)):\n",
    "        result = policy_selector.select_and_act(policy_idx, state)\n",
    "        print(f\"Step {i}: state={state:5.1f}, policy={result['policy']:12s}, \"\n",
    "              f\"action={result['action']:5.2f}, reward={result['reward']:.2f}\")\n",
    "\n",
    "    print(f\"\\nTotal reward: {policy_selector.total_reward.value:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 `ifelse`: Multi-Condition If/Elif/Else\n",
    "\n",
    "`ifelse` provides a high-level interface for multi-condition branching, similar to Python's if/elif/else.\n",
    "\n",
    "**Function signature:**\n",
    "```python\n",
    "ifelse(\n",
    "    conditions,\n",
    "    branches,\n",
    "    *operands,\n",
    "    check_cond: bool = True\n",
    ")\n",
    "```\n",
    "\n",
    "**Parameters:**\n",
    "- `conditions`: Sequence of boolean predicates (should be mutually exclusive)\n",
    "- `branches`: Sequence of callables (same length as conditions)\n",
    "- `*operands`: Arguments passed to the selected branch\n",
    "- `check_cond`: If True, verify exactly one condition is True\n",
    "\n",
    "**Common pattern:**\n",
    "Make the last condition `True` to create a default/else branch:\n",
    "```python\n",
    "ifelse(\n",
    "    [x > 10, x > 5, True],  # last condition is always True\n",
    "    [large_fn, medium_fn, small_fn],\n",
    "    x\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:29.907384Z",
     "start_time": "2025-10-11T07:36:29.882767Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 15.0 -> large\n",
      "  7.0 -> medium\n",
      "  2.0 -> small\n",
      " 10.5 -> large\n",
      "  5.0 -> small\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Simple if/elif/else\n",
    "def classify_number(x):\n",
    "    def large():\n",
    "        return \"large\"\n",
    "\n",
    "    def medium():\n",
    "        return \"medium\"\n",
    "\n",
    "    def small():\n",
    "        return \"small\"\n",
    "\n",
    "    return ifelse(\n",
    "        [x > 10, jnp.logical_and(x > 5, x <= 10), x <= 5],  # True acts as 'else'\n",
    "        [large, medium, small]\n",
    "    )\n",
    "\n",
    "\n",
    "with jax.disable_jit():\n",
    "    for value in [15.0, 7.0, 2.0, 10.5, 5.0]:\n",
    "        category = classify_number(value)\n",
    "        print(f\"{value:5.1f} -> {category}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:30.278333Z",
     "start_time": "2025-10-11T07:36:29.914735Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Grade distribution:\n",
      "  A: ***\n",
      "  B: ***\n",
      "  C: *\n",
      "  D: *\n",
      "  F: **\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Stateful grade calculator\n",
    "class GradeCalculator(brainstate.nn.Module):\n",
    "    \"\"\"Calculate letter grades and track statistics.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.grade_counts = brainstate.ShortTermState({\n",
    "            'A': jnp.array(0),\n",
    "            'B': jnp.array(0),\n",
    "            'C': jnp.array(0),\n",
    "            'D': jnp.array(0),\n",
    "            'F': jnp.array(0),\n",
    "        })\n",
    "\n",
    "    def _record_grade(self, letter):\n",
    "        counts = self.grade_counts.value\n",
    "        counts[letter] = counts[letter] + 1\n",
    "        self.grade_counts.value = counts\n",
    "\n",
    "    def grade_A(self):\n",
    "        return self._record_grade('A')\n",
    "\n",
    "    def grade_B(self):\n",
    "        return self._record_grade('B')\n",
    "\n",
    "    def grade_C(self):\n",
    "        return self._record_grade('C')\n",
    "\n",
    "    def grade_D(self):\n",
    "        return self._record_grade('D')\n",
    "\n",
    "    def grade_F(self):\n",
    "        return self._record_grade('F')\n",
    "\n",
    "    def calculate_grade(self, score):\n",
    "        return ifelse(\n",
    "            [\n",
    "                score >= 90,\n",
    "                jnp.logical_and(score >= 80, score < 90),\n",
    "                jnp.logical_and(score >= 70, score < 80),\n",
    "                jnp.logical_and(score >= 60, score < 70),\n",
    "                score < 60\n",
    "            ],\n",
    "            [\n",
    "                self.grade_A,\n",
    "                self.grade_B,\n",
    "                self.grade_C,\n",
    "                self.grade_D,\n",
    "                self.grade_F\n",
    "            ]\n",
    "        )\n",
    "\n",
    "\n",
    "calculator = GradeCalculator()\n",
    "\n",
    "# Process student scores\n",
    "scores = jnp.array([95, 87, 76, 82, 59, 91, 68, 45, 88, 93])\n",
    "grades = [calculator.calculate_grade(score) for score in scores]\n",
    "\n",
    "print(\"\\nGrade distribution:\")\n",
    "for letter, count in calculator.grade_counts.value.items():\n",
    "    print(f\"  {letter}: {'*' * int(count)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This tutorial covered all control flow primitives in `brainstate.transform`:\n",
    "\n",
    "### Loop Transformations\n",
    "- **`scan`**: Fundamental loop with carry and outputs\n",
    "  - Use for: Recurrent patterns, accumulation, sequential processing\n",
    "  - Collects outputs at each iteration\n",
    "  - Key params: `reverse`, `unroll`, `pbar`\n",
    "- **`checkpointed_scan`**: Memory-efficient scan with gradient checkpointing\n",
    "  - Use for: Long sequences, memory constraints **during gradient computation**\n",
    "  - **Key benefit**: Stores only O(log_base(n)) checkpoints instead of O(n) activations during backpropagation\n",
    "  - Trades computation (recomputation during backward pass) for memory savings\n",
    "  - Key param: `base` (checkpointing granularity)\n",
    "- **`for_loop`**: Simplified loop without explicit carry\n",
    "  - Use for: Simple iteration, state updates\n",
    "  - **Important**: Return value at **each timestep is saved and stacked** into the final output array\n",
    "  - Variadic inputs, no carry management\n",
    "  - Output shape: stacks results along axis 0 (e.g., scalar→1D, vector→2D)\n",
    "- **`checkpointed_for_loop`**: Memory-efficient for loop with gradient checkpointing\n",
    "  - Combines simplicity of for_loop with memory efficiency **during gradient computation**\n",
    "  - Essential for training with very long sequences\n",
    "  - Same memory benefits as `checkpointed_scan`\n",
    "\n",
    "### While Loops\n",
    "- **`while_loop`**: Dynamic iteration with condition\n",
    "  - Use for: Unknown iteration count, no gradients needed\n",
    "  - Constraint: `cond_fun` must be read-only\n",
    "- **`bounded_while_loop`**: While loop with maximum steps\n",
    "  - Use for: Gradients, safety, predictable compilation\n",
    "  - Key params: `max_steps`, `base`\n",
    "\n",
    "### Conditional Control Flow\n",
    "- **`cond`**: Binary conditional (if/else)\n",
    "  - Use for: Two-way decisions\n",
    "  - Lazy evaluation, state-safe\n",
    "- **`switch`**: Multi-way branching (switch/case)\n",
    "  - Use for: Multiple branches with integer index\n",
    "  - Index clamping for safety\n",
    "- **`ifelse`**: Multi-condition branching (if/elif/else)\n",
    "  - Use for: Complex conditions, default branches\n",
    "  - Use `True` for else branch\n",
    "\n",
    "### Key Principles\n",
    "1. **State Safety**: All APIs properly track state reads and writes\n",
    "2. **Lazy Evaluation**: Conditionals only execute selected branches\n",
    "3. **JAX Compatibility**: Compile to efficient JAX primitives\n",
    "4. **Output Collection**: `for_loop` and `scan` collect outputs at each iteration into the final result\n",
    "5. **Memory Efficiency**: Checkpointed variants save memory **during gradient computation** by storing only checkpoints and recomputing intermediate activations during backpropagation\n",
    "6. **Differentiability**: Most APIs support gradients (except `while_loop`); checkpointed variants are essential for long sequences\n",
    "\n",
    "These primitives enable complex control flow while maintaining BrainState's stateful programming model and JAX's performance benefits."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
