{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 5: Advanced Optimizers and Techniques\n",
    "\n",
    "**Difficulty**: Advanced  \n",
    "**Duration**: 40-50 minutes  \n",
    "**Prerequisites**: Completion of Tutorials [3](03_optax_getting_started.ipynb) and [4](04_learning_rate_scheduling.ipynb)\n",
    "\n",
    "## Learning Objectives\n",
    "- Use specialized optimizers for specific scenarios\n",
    "- Use quasi-Newton (approximate second-order) methods such as L-BFGS\n",
    "- Know where to find gradient-free optimizers (Nevergrad and SciPy integrations)\n",
    "- Understand how memory-efficient optimizers reduce optimizer state\n",
    "\n",
    "## Topics Covered\n",
    "1. **Specialized gradient-based optimizers**\n",
    "   - Lion: Memory-efficient optimizer\n",
    "   - Adafactor: Factorized second moments\n",
    "   - Lookahead: k-step forward optimization\n",
    "   - RAdam: Rectified Adam\n",
    "\n",
    "2. **Large-scale training optimizers**\n",
    "   - LAMB: Layer-wise adaptive large batch\n",
    "   - LARS: Layer-wise adaptive rate scaling\n",
    "   - SM3: Memory-efficient for large models\n",
    "\n",
    "3. **Alternative optimization paradigms**\n",
    "   - LBFGS: Quasi-Newton method\n",
    "   - Rprop: Resilient backpropagation\n",
    "   - Yogi: Additive adaptive methods\n",
    "\n",
    "4. **Gradient-free optimization**\n",
    "   - NevergradOptimizer integration\n",
    "   - ScipyOptimizer for constrained problems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import brainstate\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "import braintools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setting up Test Models and Data\n",
    "\n",
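    "We'll create three model architectures (a CNN, a simplified Transformer block, and an RNN) to exercise different optimizer characteristics.\n",
    "\n",
    "The labels in the synthetic datasets below are drawn independently of the inputs, so validation accuracy is expected to stay near chance (about 10% for 10 classes). The experiments compare how each optimizer drives down the *training* loss, not how well it generalizes."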
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(brainstate.nn.Module):\n",
    "    \"\"\"Simplified Transformer block with a mean-pooled classification head.\"\"\"\n",
    "\n",
    "    def __init__(self, dim=512, num_heads=8, mlp_ratio=4.0, num_classes=10):\n",
    "        super().__init__()\n",
    "        self.dim = dim\n",
    "        self.num_heads = num_heads\n",
    "\n",
    "        # Multi-head attention components\n",
    "        self.qkv = brainstate.nn.Linear(dim, dim * 3)\n",
    "        self.proj = brainstate.nn.Linear(dim, dim)\n",
    "\n",
    "        # MLP components\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.fc1 = brainstate.nn.Linear(dim, mlp_hidden_dim)\n",
    "        self.fc2 = brainstate.nn.Linear(mlp_hidden_dim, dim)\n",
    "\n",
    "        # Layer norms\n",
    "        self.norm1 = brainstate.nn.LayerNorm(dim)\n",
    "        self.norm2 = brainstate.nn.LayerNorm(dim)\n",
    "\n",
    "        # Classification head: the loss expects (batch, num_classes) logits\n",
    "        self.head = brainstate.nn.Linear(dim, num_classes)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # x shape: (batch, seq_len, dim)\n",
    "        # Simplified attention (without actual attention computation)\n",
    "        residual = x\n",
    "        x = self.norm1(x)\n",
    "\n",
    "        # QKV projection\n",
    "        qkv = self.qkv(x)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=-1)\n",
    "\n",
    "        # Simplified attention output (q and k are unused; just use v for demonstration)\n",
    "        attn_output = self.proj(v)\n",
    "        x = residual + attn_output\n",
    "\n",
    "        # MLP block\n",
    "        residual = x\n",
    "        x = self.norm2(x)\n",
    "        x = self.fc1(x)\n",
    "        x = jax.nn.gelu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = residual + x\n",
    "\n",
    "        # Mean-pool over the sequence axis, then map to class logits\n",
    "        x = jnp.mean(x, axis=1)\n",
    "        return self.head(x)\n",
    "\n",
    "\n",
    "class CNNModel(brainstate.nn.Module):\n",
    "    \"\"\"CNN for testing memory-efficient optimizers.\"\"\"\n",
    "\n",
    "    def __init__(self, in_size=(32, 32, 3), num_classes=10):\n",
    "        super().__init__()\n",
    "        # Conv layers\n",
    "        self.conv1 = brainstate.nn.Conv2d(in_size, 64, kernel_size=3, padding=1)\n",
    "        self.pool1 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv1.out_size)\n",
    "        self.conv2 = brainstate.nn.Conv2d(self.pool1.out_size, 128, kernel_size=3, padding=1)\n",
    "        self.pool2 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv2.out_size)\n",
    "        self.conv3 = brainstate.nn.Conv2d(self.pool2.out_size, 256, kernel_size=3, padding=1)\n",
    "        self.pool3 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv3.out_size)\n",
    "\n",
    "        # Dense layers\n",
    "        self.fc1 = brainstate.nn.Linear(int(np.prod(self.pool3.out_size)), 512)\n",
    "        self.fc2 = brainstate.nn.Linear(512, num_classes)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # Reshape if needed\n",
    "        if len(x.shape) == 2:\n",
    "            x = x.reshape(-1, 32, 32, 3)\n",
    "\n",
    "        # Conv blocks\n",
    "        x = self.conv1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.pool1(x)\n",
    "\n",
    "        x = self.conv2(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.pool2(x)\n",
    "\n",
    "        x = self.conv3(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.pool3(x)\n",
    "\n",
    "        # Flatten and FC layers\n",
    "        x = x.reshape(x.shape[0], -1)\n",
    "        x = self.fc1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.fc2(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
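    "class SimpleRNN(brainstate.nn.Module):\n",
    "    \"\"\"Simple Elman (tanh) RNN for testing gradient stability.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim=10, hidden_dim=128, output_dim=10):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        # Input-to-hidden and hidden-to-hidden projections\n",
    "        self.w_ih = brainstate.nn.Linear(input_dim, hidden_dim)\n",
    "        self.w_hh = brainstate.nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.fc = brainstate.nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # x shape: (batch, seq_len, features)\n",
    "        # The hidden state is a local array rather than a brainstate state, so the\n",
    "        # model needs no state initialization and accepts any batch size.\n",
    "        h = jnp.zeros((x.shape[0], self.hidden_dim))\n",
    "        for t in range(x.shape[1]):  # unrolled over the time axis\n",
    "            h = jnp.tanh(self.w_ih(x[:, t]) + self.w_hh(h))\n",
    "        # Classify from the hidden state at the last time step\n",
    "        return self.fc(h)"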
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_synthetic_data(data_type='vision', n_samples=1000, seed=42):\n",
    "    \"\"\"Create synthetic data for different model types.\"\"\"\n",
    "    with brainstate.random.seed_context(seed):\n",
    "        if data_type == 'vision':\n",
    "            # Image-like data (32x32x3)\n",
    "            X = brainstate.random.normal(size=(n_samples, 32, 32, 3)) * 0.5\n",
    "            y = brainstate.random.randint(0, 10, size=(n_samples,))\n",
    "        elif data_type == 'transformer':\n",
    "            # Sequence data for transformer (seq_len=64, dim=512)\n",
    "            X = brainstate.random.normal(size=(n_samples, 64, 512)) * 0.1\n",
    "            y = brainstate.random.randint(0, 10, size=(n_samples,))\n",
    "        elif data_type == 'sequence':\n",
    "            # Sequence data for RNN (seq_len=20, features=10)\n",
    "            X = brainstate.random.normal(size=(n_samples, 20, 10)) * 0.5\n",
    "            y = brainstate.random.randint(0, 10, size=(n_samples,))\n",
    "        else:\n",
    "            # Default: flat features\n",
    "            X = brainstate.random.normal(size=(n_samples, 784)) * 0.5\n",
    "            y = brainstate.random.randint(0, 10, size=(n_samples,))\n",
    "\n",
    "    return X, y\n",
    "\n",
    "\n",
    "# Create datasets\n",
    "X_vision, y_vision = create_synthetic_data('vision', n_samples=2000)\n",
    "X_transformer, y_transformer = create_synthetic_data('transformer', n_samples=1000)\n",
    "X_sequence, y_sequence = create_synthetic_data('sequence', n_samples=2000)\n",
    "\n",
    "print(f\"Vision data shape: {X_vision.shape}\")\n",
    "print(f\"Transformer data shape: {X_transformer.shape}\")\n",
    "print(f\"Sequence data shape: {X_sequence.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Gradient Computation and Training Infrastructure\n",
    "\n",
    "Following the style of the previous tutorials, we define a loss-and-gradient helper and a generic training loop that every optimizer below reuses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_loss_and_grads(model, X, y, param_states, loss_type='classification'):\n",
    "    \"\"\"Compute loss and gradients following braintools style.\"\"\"\n",
    "\n",
    "    def loss_fn():\n",
    "        # Forward pass\n",
    "        outputs = model(X)\n",
    "\n",
    "        if loss_type == 'classification':\n",
    "            # Cross-entropy loss\n",
    "            log_probs = jax.nn.log_softmax(outputs, axis=-1)\n",
    "            one_hot = jax.nn.one_hot(y, num_classes=10)\n",
    "            loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))\n",
    "        else:\n",
    "            # MSE loss for regression\n",
    "            loss = jnp.mean((outputs - y) ** 2)\n",
    "\n",
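    "        # Add L2 regularization. A ParamState may hold a dict of arrays\n",
    "        # (e.g. Linear stores {'weight': ..., 'bias': ...}), so sum over its leaves.\n",
    "        l2_reg = 1e-4\n",
    "        for state in param_states.values():\n",
    "            for leaf in jax.tree_util.tree_leaves(state.value):\n",
    "                loss = loss + l2_reg * jnp.sum(leaf ** 2)\n",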
    "\n",
    "        return loss\n",
    "\n",
    "    # Compute loss and gradients\n",
    "    loss = loss_fn()\n",
    "    grads = brainstate.transform.grad(loss_fn, grad_states=param_states)()\n",
    "\n",
    "    # Compute accuracy for classification\n",
    "    if loss_type == 'classification':\n",
    "        outputs = model(X)\n",
    "        predictions = jnp.argmax(outputs, axis=-1)\n",
    "        accuracy = jnp.mean(predictions == y)\n",
    "    else:\n",
    "        accuracy = -loss  # Use negative loss as metric for regression\n",
    "\n",
    "    return loss, grads, accuracy\n",
    "\n",
    "\n",
    "def train_with_optimizer(\n",
    "    model: brainstate.nn.Module,\n",
    "    optimizer: braintools.optim.OptaxOptimizer,\n",
    "    X_train, y_train,\n",
    "    X_val, y_val,\n",
    "    n_epochs=30,\n",
    "    batch_size=64,\n",
    "    verbose=True\n",
    "):\n",
    "    \"\"\"Generic training function for any optimizer.\"\"\"\n",
    "\n",
    "    # Get parameter states\n",
    "    param_states = braintools.optim.UniqueStateManager(\n",
    "        model.states(brainstate.ParamState)\n",
    "    ).to_pytree()\n",
    "\n",
    "    # Register parameters with optimizer\n",
    "    optimizer.register_trainable_weights(param_states)\n",
    "\n",
    "    @brainstate.transform.jit\n",
    "    def train_step(X_batch, y_batch):\n",
    "        loss, grads, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)\n",
    "        optimizer.update(grads)\n",
    "        return loss, acc\n",
    "\n",
    "    @brainstate.transform.jit\n",
    "    def eval_step(X_batch, y_batch):\n",
    "        loss, _, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)\n",
    "        return loss, acc\n",
    "\n",
    "    history = {\n",
    "        'train_loss': [],\n",
    "        'train_acc': [],\n",
    "        'val_loss': [],\n",
    "        'val_acc': [],\n",
    "        'epoch_time': []\n",
    "    }\n",
    "\n",
    "    n_batches = len(X_train) // batch_size\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        epoch_start = time.time()\n",
    "\n",
    "        # Shuffle data\n",
    "        perm = brainstate.random.permutation(len(X_train))\n",
    "        X_train_shuffled = X_train[perm]\n",
    "        y_train_shuffled = y_train[perm]\n",
    "\n",
    "        train_losses = []\n",
    "        train_accs = []\n",
    "\n",
    "        for batch_idx in range(n_batches):\n",
    "            start_idx = batch_idx * batch_size\n",
    "            end_idx = start_idx + batch_size\n",
    "            X_batch = X_train_shuffled[start_idx:end_idx]\n",
    "            y_batch = y_train_shuffled[start_idx:end_idx]\n",
    "\n",
    "            loss, acc = train_step(X_batch, y_batch)\n",
    "            train_losses.append(float(loss))\n",
    "            train_accs.append(float(acc))\n",
    "\n",
    "        # Validation\n",
    "        val_loss, val_acc = eval_step(X_val[:500], y_val[:500])  # Use subset for speed\n",
    "\n",
    "        # Update learning rate if scheduler is attached\n",
    "        optimizer.lr.step()\n",
    "\n",
    "        # Record metrics\n",
    "        history['train_loss'].append(np.mean(train_losses))\n",
    "        history['train_acc'].append(np.mean(train_accs))\n",
    "        history['val_loss'].append(float(val_loss))\n",
    "        history['val_acc'].append(float(val_acc))\n",
    "        history['epoch_time'].append(time.time() - epoch_start)\n",
    "\n",
    "        if verbose and (epoch + 1) % 10 == 0:\n",
    "            print(f\"Epoch {epoch + 1}/{n_epochs} - \"\n",
    "                  f\"Loss: {history['train_loss'][-1]:.4f}, \"\n",
    "                  f\"Acc: {history['train_acc'][-1]:.4f}, \"\n",
    "                  f\"Val Loss: {history['val_loss'][-1]:.4f}, \"\n",
    "                  f\"Val Acc: {history['val_acc'][-1]:.4f}\")\n",
    "\n",
    "    return history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Specialized Gradient-Based Optimizers\n",
    "\n",
    "Let's explore advanced optimizers designed for specific scenarios."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Lion Optimizer - Memory Efficient\n",
    "\n",
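    "Lion (EvoLved Sign Momentum) was discovered by symbolic program search. It keeps a single momentum buffer per parameter (Adam keeps two) and updates each weight by the *sign* of an interpolation between the momentum and the current gradient, so every coordinate moves by the same magnitude `lr` (plus decoupled weight decay). Because sign updates have a larger norm than Adam updates, Lion is usually run with a learning rate 3-10x smaller than AdamW."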
   ]
  },
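  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Lion update described above fits in a few lines. The sketch below is a plain-JAX illustration for a single array, written for this tutorial rather than taken from braintools: `lion_step` is a hypothetical helper, and its defaults mirror the hyperparameters used in the next cell. With a zero initial momentum, every coordinate with a nonzero gradient moves by exactly `lr`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): one Lion step for a single array.\n",
    "# The only optimizer state is the momentum buffer m.\n",
    "def lion_step(w, g, m, lr=3e-4, beta1=0.9, beta2=0.99, weight_decay=1e-4):\n",
    "    # Interpolate momentum and gradient with beta1, then keep only the sign\n",
    "    update = jnp.sign(beta1 * m + (1.0 - beta1) * g)\n",
    "    # Sign update plus decoupled (AdamW-style) weight decay\n",
    "    w_new = w - lr * (update + weight_decay * w)\n",
    "    # The momentum buffer itself is tracked with beta2\n",
    "    m_new = beta2 * m + (1.0 - beta2) * g\n",
    "    return w_new, m_new\n",
    "\n",
    "\n",
    "w = jnp.array([1.0, -2.0, 0.5])\n",
    "g = jnp.array([0.3, -0.1, 0.0])\n",
    "m = jnp.zeros_like(w)\n",
    "w_new, m_new = lion_step(w, g, m, lr=0.1, weight_decay=0.0)\n",
    "print(w_new - w)  # nonzero-gradient coordinates move by exactly lr\n",
    "print(m_new)"
   ]
  },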
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lion optimizer\n",
    "model_lion = CNNModel()\n",
    "\n",
    "lion_optimizer = braintools.optim.Lion(\n",
    "    lr=3e-4,  # Lion typically uses smaller learning rates\n",
    "    betas=(0.9, 0.99),\n",
    "    weight_decay=1e-4\n",
    ")\n",
    "\n",
    "print(\"Training with Lion optimizer (memory-efficient)...\")\n",
    "history_lion = train_with_optimizer(\n",
    "    model_lion, lion_optimizer,\n",
    "    X_vision[:1000], y_vision[:1000],\n",
    "    X_vision[1000:1500], y_vision[1000:1500],\n",
    "    n_epochs=30, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Adafactor - Factorized Second Moments\n",
    "\n",
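    "For a matrix parameter of shape `(n, m)`, Adafactor does not store the full `n x m` second-moment matrix that Adam keeps. It stores moving averages of the row and column statistics of the squared gradient (`n + m` values) and reconstructs the second-moment estimate as their rank-1 outer product. `decay_rate` is the exponent of the paper's increasing decay schedule `beta2_t = 1 - t^(-decay_rate)`, and `clip_threshold` caps the root-mean-square of each update."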
   ]
  },
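  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The factored second-moment estimate described above can be sketched in plain JAX for one matrix parameter. This is a simplified illustration written for this tutorial: `factored_second_moment` is a hypothetical helper, not braintools API, and it covers only the second-moment estimate, omitting Adafactor's update clipping and relative step sizes. When the squared gradient is rank 1, the reconstruction is exact, which the example uses as a check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): Adafactor-style factored\n",
    "# second-moment estimate for one (n, m) matrix parameter.\n",
    "def factored_second_moment(grad, row, col, step, decay_rate=0.8, eps=1e-30):\n",
    "    # Increasing decay schedule from the Adafactor paper: beta2_t = 1 - t^(-decay_rate)\n",
    "    beta2 = 1.0 - step ** (-decay_rate)\n",
    "    sq = grad ** 2 + eps\n",
    "    # Moving averages of the row means (n values) and column means (m values)\n",
    "    row = beta2 * row + (1.0 - beta2) * sq.mean(axis=1)\n",
    "    col = beta2 * col + (1.0 - beta2) * sq.mean(axis=0)\n",
    "    # Rank-1 reconstruction of the full (n, m) second-moment matrix\n",
    "    v_hat = jnp.outer(row, col) / row.mean()\n",
    "    return v_hat, row, col\n",
    "\n",
    "\n",
    "a = jnp.array([1.0, 2.0, 3.0])\n",
    "b = jnp.array([0.5, -1.0])\n",
    "grad = jnp.outer(a, b)  # its square is a rank-1 matrix\n",
    "row, col = jnp.zeros(3), jnp.zeros(2)\n",
    "v_hat, row, col = factored_second_moment(grad, row, col, step=1)\n",
    "# An Adafactor-style update direction would then be grad / sqrt(v_hat)\n",
    "print('stored values:', row.size + col.size, 'instead of', grad.size)\n",
    "print(v_hat)"
   ]
  },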
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adafactor optimizer\n",
    "model_adafactor = TransformerBlock()\n",
    "\n",
    "adafactor_optimizer = braintools.optim.Adafactor(\n",
    "    lr=1e-3,\n",
    "    decay_rate=0.8,\n",
    "    factored=True,  # Enable factorization for memory efficiency\n",
    "    clip_threshold=1.0\n",
    ")\n",
    "\n",
    "print(\"Training with Adafactor (factorized second moments)...\")\n",
    "history_adafactor = train_with_optimizer(\n",
    "    model_adafactor, adafactor_optimizer,\n",
    "    X_transformer[:500], y_transformer[:500],\n",
    "    X_transformer[500:700], y_transformer[500:700],\n",
    "    n_epochs=30, batch_size=16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Lookahead Optimizer - k-step Forward\n",
    "\n",
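    "Lookahead keeps two copies of the weights. The *fast* weights are updated at every step by an inner optimizer (here SGD with momentum). Every `k` steps (`sync_period`), the *slow* weights move a fraction `alpha` of the way toward the fast weights, and the fast weights are reset to the new slow weights. The inner optimizer can explore aggressively while the slow weights average out its oscillations."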
   ]
  },
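  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slow/fast weight mechanics of Lookahead can be sketched in plain JAX. To keep it short, the hypothetical `lookahead_minimize` helper (not braintools API) uses plain gradient descent as the inner optimizer instead of the momentum SGD used in the next cell, and minimizes the toy objective `0.5 * ||w||^2`. After every sync, the fast and slow weights coincide."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): Lookahead around plain gradient descent.\n",
    "def lookahead_minimize(w0, grad_fn, inner_lr=0.1, sync_period=5, alpha=0.5, n_steps=20):\n",
    "    slow = w0\n",
    "    fast = w0\n",
    "    for step in range(1, n_steps + 1):\n",
    "        # Inner (fast) optimizer step\n",
    "        fast = fast - inner_lr * grad_fn(fast)\n",
    "        if step % sync_period == 0:\n",
    "            # Slow weights move a fraction alpha toward the fast weights...\n",
    "            slow = slow + alpha * (fast - slow)\n",
    "            # ...and the fast weights restart from the new slow weights\n",
    "            fast = slow\n",
    "    return slow, fast\n",
    "\n",
    "\n",
    "def toy_loss(w):\n",
    "    # 0.5 * ||w||^2, whose gradient is w itself\n",
    "    return 0.5 * jnp.sum(w ** 2)\n",
    "\n",
    "\n",
    "w0 = jnp.array([1.0, -2.0])\n",
    "slow, fast = lookahead_minimize(w0, jax.grad(toy_loss))\n",
    "print(slow, fast)"
   ]
  },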
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookahead optimizer wrapping SGD\n",
    "model_lookahead = CNNModel()\n",
    "\n",
    "# Base optimizer\n",
    "base_optimizer = braintools.optim.SGD(lr=0.1, momentum=0.9)\n",
    "\n",
    "# Wrap with Lookahead\n",
    "lookahead_optimizer = braintools.optim.Lookahead(\n",
    "    base_optimizer,\n",
    "    sync_period=5,  # Update slow weights every 5 steps\n",
    "    alpha=0.5  # Interpolation factor\n",
    ")\n",
    "\n",
    "print(\"Training with Lookahead optimizer (k-step forward)...\")\n",
    "history_lookahead = train_with_optimizer(\n",
    "    model_lookahead, lookahead_optimizer,\n",
    "    X_vision[:1000], y_vision[:1000],\n",
    "    X_vision[1000:1500], y_vision[1000:1500],\n",
    "    n_epochs=30, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 RAdam - Rectified Adam\n",
    "\n",
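    "In the first steps of training, Adam's second-moment estimate is averaged over only a few gradients, so the adaptive learning rate has a large variance. RAdam computes `rho_t`, the length of the approximated simple moving average behind that estimate. While `rho_t` is below a small threshold, RAdam falls back to momentum SGD without the adaptive denominator; afterwards it applies the Adam step multiplied by a rectification factor `r_t` that corrects the variance. This acts as a built-in learning-rate warmup."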
   ]
  },
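  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rectification logic can be sketched in plain JAX for a single array. This is an illustration, not the braintools implementation: `radam_rho` and `radam_step` are hypothetical helpers, and the threshold of 5 follows the PyTorch/Optax convention (the RAdam paper uses 4). With `beta2 = 0.999`, `rho_t` is roughly `t` in the early steps, so the first rectified step is `t = 6`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helpers (not braintools API) sketching RAdam for one array.\n",
    "def radam_rho(t, beta2=0.999):\n",
    "    # Length of the approximated simple moving average at step t\n",
    "    rho_inf = 2.0 / (1.0 - beta2) - 1.0\n",
    "    return rho_inf - 2.0 * t * beta2 ** t / (1.0 - beta2 ** t)\n",
    "\n",
    "\n",
    "def radam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, threshold=5.0):\n",
    "    m = beta1 * m + (1.0 - beta1) * g\n",
    "    v = beta2 * v + (1.0 - beta2) * g ** 2\n",
    "    m_hat = m / (1.0 - beta1 ** t)\n",
    "    rho_inf = 2.0 / (1.0 - beta2) - 1.0\n",
    "    rho_t = radam_rho(t, beta2)\n",
    "    if rho_t > threshold:\n",
    "        # Variance is tractable: Adam step scaled by the rectification term r_t\n",
    "        v_hat = jnp.sqrt(v / (1.0 - beta2 ** t))\n",
    "        r_t = ((rho_t - 4) * (rho_t - 2) * rho_inf\n",
    "               / ((rho_inf - 4) * (rho_inf - 2) * rho_t)) ** 0.5\n",
    "        w = w - lr * r_t * m_hat / (v_hat + eps)\n",
    "    else:\n",
    "        # Early steps: momentum SGD without the adaptive denominator\n",
    "        w = w - lr * m_hat\n",
    "    return w, m, v\n",
    "\n",
    "\n",
    "first_rectified = next(t for t in range(1, 100) if radam_rho(t) > 5.0)\n",
    "print('first rectified step:', first_rectified)\n",
    "\n",
    "w, m, v = jnp.array([1.0]), jnp.zeros(1), jnp.zeros(1)\n",
    "w, m, v = radam_step(w, jnp.array([0.5]), m, v, t=1)\n",
    "print(w)  # un-adapted momentum step: 1 - lr * 0.5"
   ]
  },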
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RAdam optimizer\n",
    "model_radam = SimpleRNN()\n",
    "\n",
    "radam_optimizer = braintools.optim.RAdam(\n",
    "    lr=1e-3,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-8,\n",
    "    weight_decay=1e-4\n",
    ")\n",
    "\n",
    "print(\"Training with RAdam (Rectified Adam)...\")\n",
    "history_radam = train_with_optimizer(\n",
    "    model_radam, radam_optimizer,\n",
    "    X_sequence[:1000], y_sequence[:1000],\n",
    "    X_sequence[1000:1500], y_sequence[1000:1500],\n",
    "    n_epochs=30, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Large-Scale Training Optimizers\n",
    "\n",
    "These optimizers are designed for training with large batch sizes and distributed settings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 LAMB - Layer-wise Adaptive Large Batch\n",
    "\n",
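    "LAMB computes an Adam-style update (including decoupled weight decay) for each parameter tensor and rescales it by the layer-wise *trust ratio* `||w|| / ||update||`. The step size of each layer is therefore proportional to the norm of its weights rather than to the scale of its raw gradients, which keeps training stable when the batch size, and with it the learning rate, is scaled up."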
   ]
  },
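  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layer-wise trust ratio can be sketched in plain JAX for a single layer. `lamb_step` is a hypothetical helper written for this tutorial, not the braintools implementation, and it applies the trust ratio to `||w||` directly (the LAMB paper also allows a clipped version). Whatever the gradient scale, the step norm equals `lr * ||w||`, which the example checks on two layers with very different weight and gradient scales."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): one LAMB step for one layer.\n",
    "def lamb_step(w, g, m, v, t, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.01):\n",
    "    m = beta1 * m + (1.0 - beta1) * g\n",
    "    v = beta2 * v + (1.0 - beta2) * g ** 2\n",
    "    m_hat = m / (1.0 - beta1 ** t)\n",
    "    v_hat = v / (1.0 - beta2 ** t)\n",
    "    # Adam direction plus decoupled weight decay\n",
    "    update = m_hat / (jnp.sqrt(v_hat) + eps) + weight_decay * w\n",
    "    w_norm = jnp.linalg.norm(w)\n",
    "    u_norm = jnp.linalg.norm(update)\n",
    "    # Layer-wise trust ratio ||w|| / ||update|| (1.0 if either norm is zero)\n",
    "    trust = jnp.where((w_norm > 0) & (u_norm > 0), w_norm / u_norm, 1.0)\n",
    "    return w - lr * trust * update, m, v\n",
    "\n",
    "\n",
    "layers = {\n",
    "    'small': (jnp.full(4, 0.01), jnp.full(4, 5.0)),   # small weights, large grads\n",
    "    'large': (jnp.full(4, 10.0), jnp.full(4, 1e-3)),  # large weights, small grads\n",
    "}\n",
    "for name, (w, g) in layers.items():\n",
    "    w_new, _, _ = lamb_step(w, g, jnp.zeros(4), jnp.zeros(4), t=1)\n",
    "    ratio = jnp.linalg.norm(w_new - w) / jnp.linalg.norm(w)\n",
    "    print(name, 'relative step:', float(ratio))"
   ]
  },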
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LAMB optimizer for large batch training\n",
    "model_lamb = TransformerBlock()\n",
    "\n",
    "lamb_optimizer = braintools.optim.Lamb(\n",
    "    lr=2e-3,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-6,\n",
    "    weight_decay=0.01,\n",
    "    grad_clip_value=10.0  # Gradient clipping\n",
    ")\n",
    "\n",
    "print(\"Training with LAMB (Large Batch optimizer)...\")\n",
    "# Simulate large batch by using larger batch size\n",
    "history_lamb = train_with_optimizer(\n",
    "    model_lamb, lamb_optimizer,\n",
    "    X_transformer[:800], y_transformer[:800],\n",
    "    X_transformer[800:], y_transformer[800:],\n",
    "    n_epochs=30, batch_size=128  # Large batch size\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 LARS - Layer-wise Adaptive Rate Scaling\n",
    "\n",
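    "LARS computes a *local learning rate* for each layer, `trust_coefficient * ||w|| / (||g|| + weight_decay * ||w||)`, where `g` is the layer's gradient. The weight-decayed gradient is scaled by this local rate and by the global `lr`, and then fed into SGD momentum. Layers whose gradients are large relative to their weights therefore take proportionally smaller steps."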
   ]
  },
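  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The local learning rate can be sketched in plain JAX for a single layer, following the update rule of the original LARS paper. `lars_step` is a hypothetical helper written for this tutorial, not the braintools implementation; the small `eps` in the denominator is an added safeguard against division by zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): one LARS step for one layer.\n",
    "def lars_step(w, g, velocity, lr=0.1, momentum=0.9, weight_decay=1e-4,\n",
    "              trust_coefficient=0.001, eps=1e-8):\n",
    "    w_norm = jnp.linalg.norm(w)\n",
    "    g_norm = jnp.linalg.norm(g)\n",
    "    # Local learning rate: trust_coefficient * ||w|| / (||g|| + weight_decay * ||w||)\n",
    "    local_lr = trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)\n",
    "    # Momentum on the weight-decayed gradient, scaled by global and local rates\n",
    "    velocity = momentum * velocity + lr * local_lr * (g + weight_decay * w)\n",
    "    return w - velocity, velocity, local_lr\n",
    "\n",
    "\n",
    "w = jnp.ones(4)       # ||w|| = 2\n",
    "g = jnp.full(4, 0.5)  # ||g|| = 1\n",
    "w_new, velocity, local_lr = lars_step(w, g, jnp.zeros(4), weight_decay=0.0, eps=0.0)\n",
    "print('local lr:', float(local_lr))  # trust_coefficient * 2 / 1\n",
    "print(w_new)"
   ]
  },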
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LARS optimizer\n",
    "model_lars = CNNModel()\n",
    "\n",
    "lars_optimizer = braintools.optim.Lars(\n",
    "    lr=0.1,\n",
    "    momentum=0.9,\n",
    "    weight_decay=1e-4,\n",
    "    trust_coefficient=0.001,  # LARS-specific parameter\n",
    "    eps=1e-8\n",
    ")\n",
    "\n",
    "print(\"Training with LARS (Layer-wise Adaptive Rate Scaling)...\")\n",
    "history_lars = train_with_optimizer(\n",
    "    model_lars, lars_optimizer,\n",
    "    X_vision[:1000], y_vision[:1000],\n",
    "    X_vision[1000:1500], y_vision[1000:1500],\n",
    "    n_epochs=30, batch_size=128\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 SM3 - Memory-Efficient for Large Models\n",
    "\n",
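    "SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) approximates Adagrad-style per-parameter accumulators with far less memory. It covers each parameter tensor with a few index sets, for a matrix typically its rows and its columns, and keeps one accumulator per set. A parameter's effective accumulator is the minimum over the sets that contain it, so an `n x m` matrix needs `n + m` accumulators instead of `n * m`."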
   ]
  },
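  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a matrix covered by its rows and columns, the SM3 accumulators (the SM3-II variant from the paper) can be sketched in plain JAX. `sm3_step` is a hypothetical helper written for this tutorial; it omits the momentum term that the `SM3(momentum=0.9)` call in the next cell uses. A useful property is that the SM3 accumulator never falls below the full Adagrad sum of squared gradients, which the example checks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): one SM3-II step for an (n, m) matrix,\n",
    "# with one accumulator per row and one per column (n + m values in total).\n",
    "def sm3_step(w, g, row_acc, col_acc, lr=0.1, eps=1e-8):\n",
    "    # Per-entry accumulator: min over the covering row/column, plus the new g^2\n",
    "    nu = jnp.minimum(row_acc[:, None], col_acc[None, :]) + g ** 2\n",
    "    # Each row/column keeps the max over the entries it covers\n",
    "    row_acc = nu.max(axis=1)\n",
    "    col_acc = nu.max(axis=0)\n",
    "    w = w - lr * g / (jnp.sqrt(nu) + eps)\n",
    "    return w, row_acc, col_acc, nu\n",
    "\n",
    "\n",
    "w = jnp.zeros((2, 3))\n",
    "grads = [jnp.array([[1.0, 2.0, 0.5], [3.0, 4.0, 1.0]]),\n",
    "         jnp.array([[0.5, -1.0, 2.0], [1.0, 0.0, -3.0]])]\n",
    "row_acc, col_acc = jnp.zeros(2), jnp.zeros(3)\n",
    "adagrad_sum = jnp.zeros((2, 3))\n",
    "for g in grads:\n",
    "    w, row_acc, col_acc, nu = sm3_step(w, g, row_acc, col_acc)\n",
    "    adagrad_sum = adagrad_sum + g ** 2\n",
    "print('row accumulators:', row_acc, 'column accumulators:', col_acc)\n",
    "print('SM3 >= Adagrad everywhere:', bool(jnp.all(nu >= adagrad_sum)))"
   ]
  },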
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SM3 optimizer for memory efficiency\n",
    "model_sm3 = TransformerBlock()\n",
    "\n",
    "sm3_optimizer = braintools.optim.SM3(\n",
    "    lr=1e-3,\n",
    "    momentum=0.9,\n",
    "    eps=1e-8\n",
    ")\n",
    "\n",
    "print(\"Training with SM3 (Memory-efficient optimizer)...\")\n",
    "history_sm3 = train_with_optimizer(\n",
    "    model_sm3, sm3_optimizer,\n",
    "    X_transformer[:500], y_transformer[:500],\n",
    "    X_transformer[500:700], y_transformer[500:700],\n",
    "    n_epochs=30, batch_size=16\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Alternative Optimization Paradigms\n",
    "\n",
    "These optimizers use different principles than standard gradient descent."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 L-BFGS - Quasi-Newton Method\n",
    "\n",
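    "L-BFGS is a quasi-Newton method. Instead of forming the Hessian, it stores the last `m` pairs of parameter differences `s_k` and gradient differences `y_k` and uses them (via the two-loop recursion) to multiply the gradient by an implicit inverse-Hessian approximation. A line search then chooses the step length along that direction. Because the curvature pairs assume a deterministic objective, L-BFGS is normally run with full-batch gradients."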
   ]
  },
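  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two-loop recursion at the heart of L-BFGS can be sketched in plain JAX. `lbfgs_direction` is a hypothetical helper written for this tutorial, not the braintools implementation, and it leaves out the line search. On the quadratic `0.5 * x^T A x` with a diagonal `A`, curvature pairs taken along the coordinate axes make the resulting direction equal to the exact Newton direction `A^{-1} grad`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): L-BFGS two-loop recursion.\n",
    "# Returns H @ grad, where H is the implicit inverse-Hessian approximation built from\n",
    "# the stored pairs s_k = x_{k+1} - x_k and y_k = grad_{k+1} - grad_k (oldest first).\n",
    "def lbfgs_direction(grad, s_list, y_list):\n",
    "    rhos = [1.0 / jnp.dot(y, s) for s, y in zip(s_list, y_list)]\n",
    "    q = grad\n",
    "    alphas = []\n",
    "    # First loop: newest pair to oldest\n",
    "    for s, y, rho in reversed(list(zip(s_list, y_list, rhos))):\n",
    "        alpha = rho * jnp.dot(s, q)\n",
    "        q = q - alpha * y\n",
    "        alphas.append(alpha)\n",
    "    # Initial scaling gamma = s^T y / y^T y from the most recent pair\n",
    "    if s_list:\n",
    "        gamma = jnp.dot(s_list[-1], y_list[-1]) / jnp.dot(y_list[-1], y_list[-1])\n",
    "    else:\n",
    "        gamma = 1.0\n",
    "    r = gamma * q\n",
    "    # Second loop: oldest pair to newest (alphas were stored newest first)\n",
    "    for s, y, rho, alpha in zip(s_list, y_list, rhos, reversed(alphas)):\n",
    "        beta = rho * jnp.dot(y, r)\n",
    "        r = r + (alpha - beta) * s\n",
    "    return r\n",
    "\n",
    "\n",
    "A = jnp.diag(jnp.array([1.0, 10.0]))\n",
    "s_list = [jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])]\n",
    "y_list = [A @ s for s in s_list]  # for a quadratic, y_k = A s_k\n",
    "grad = jnp.array([1.0, 1.0])\n",
    "direction = lbfgs_direction(grad, s_list, y_list)\n",
    "print(direction)  # Newton direction A^{-1} grad"
   ]
  },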
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# L-BFGS optimizer (Note: requires special handling)\n",
    "from brainstate.nn import Linear\n",
    "\n",
    "\n",
    "class SimpleMLP(brainstate.nn.Module):\n",
    "    \"\"\"Simple MLP for L-BFGS testing.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = Linear(32 * 32 * 3, 128)  # flattened 32x32x3 images\n",
    "        self.fc2 = Linear(128, 10)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        x = x.reshape(x.shape[0], -1)\n",
    "        x = self.fc1(x)\n",
    "        x = jax.nn.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "model_lbfgs = SimpleMLP()\n",
    "\n",
    "# L-BFGS builds its curvature estimate from consecutive gradients, so we train full-batch\n",
    "lbfgs_optimizer = braintools.optim.LBFGS(\n",
    "    lr=1.0,\n",
    "    memory_size=10,\n",
    "    line_search_fn='zoom'\n",
    ")\n",
    "\n",
    "print(\"Training with L-BFGS (Quasi-Newton method)...\")\n",
    "# Use a small subset so that full-batch training stays cheap\n",
    "X_small = X_vision[:200].reshape(200, -1)\n",
    "y_small = y_vision[:200]\n",
    "X_val_small = X_vision[1000:1100].reshape(100, -1)\n",
    "y_val_small = y_vision[1000:1100]\n",
    "\n",
    "history_lbfgs = train_with_optimizer(\n",
    "    model_lbfgs, lbfgs_optimizer,\n",
    "    X_small, y_small,\n",
    "    X_val_small, y_val_small,\n",
    "    n_epochs=20, batch_size=200  # Full batch\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Rprop - Resilient Backpropagation\n",
    "\n",
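    "Rprop ignores gradient magnitudes. Each parameter has its own step size, which is multiplied by `etas[1]` (> 1) while the gradient keeps its sign and by `etas[0]` (< 1) when the sign flips, and is clipped to the range given by `step_sizes`; the parameter then moves by that step size against the sign of its gradient. Because sign flips between mini-batches are partly noise, Rprop was designed for full-batch training and tends to degrade with small batches."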
   ]
  },
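  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sign-based step adaptation can be sketched in plain JAX. `rprop_step` is a hypothetical helper written for this tutorial, implementing the iRprop- variant (after a sign flip the step shrinks and that coordinate is not updated); it is not necessarily identical to the braintools implementation. Scaling the gradient by a large factor leaves the update unchanged, which is what makes Rprop insensitive to gradient magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helper (not braintools API): one iRprop- step.\n",
    "def rprop_step(w, g, prev_g, step, etas=(0.5, 1.2), step_sizes=(1e-6, 50.0)):\n",
    "    eta_minus, eta_plus = etas\n",
    "    step_min, step_max = step_sizes\n",
    "    agreement = g * prev_g\n",
    "    # Same sign: grow the step; sign flip: shrink it; otherwise keep it\n",
    "    factor = jnp.where(agreement > 0, eta_plus, jnp.where(agreement < 0, eta_minus, 1.0))\n",
    "    step = jnp.clip(step * factor, step_min, step_max)\n",
    "    # After a sign flip, skip this coordinate and store a zero gradient\n",
    "    g = jnp.where(agreement < 0, 0.0, g)\n",
    "    w = w - jnp.sign(g) * step\n",
    "    return w, g, step\n",
    "\n",
    "\n",
    "w = jnp.zeros(2)\n",
    "prev_g = jnp.zeros(2)\n",
    "step = jnp.full(2, 0.1)  # initial per-parameter step size\n",
    "w, prev_g, step = rprop_step(w, jnp.array([1.0, -1.0]), prev_g, step)\n",
    "w, prev_g, step = rprop_step(w, jnp.array([1.0, 1.0]), prev_g, step)\n",
    "# Coordinate 0 kept its sign (step grew); coordinate 1 flipped (step shrank, no move)\n",
    "print(w, step)\n",
    "\n",
    "# Only the sign of the gradient matters, not its magnitude\n",
    "w_a, _, _ = rprop_step(jnp.zeros(2), jnp.array([0.01, -2.0]), jnp.zeros(2), jnp.full(2, 0.1))\n",
    "w_b, _, _ = rprop_step(jnp.zeros(2), jnp.array([10.0, -2000.0]), jnp.zeros(2), jnp.full(2, 0.1))\n",
    "print(bool(jnp.allclose(w_a, w_b)))"
   ]
  },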
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rprop optimizer\n",
    "model_rprop = SimpleMLP()\n",
    "\n",
    "rprop_optimizer = braintools.optim.Rprop(\n",
    "    lr=1e-3,\n",
    "    etas=(0.5, 1.2),  # Step size adaptation factors\n",
    "    step_sizes=(1e-6, 50)  # Min and max step sizes\n",
    ")\n",
    "\n",
    "print(\"Training with Rprop (Resilient Backpropagation)...\")\n",
    "history_rprop = train_with_optimizer(\n",
    "    model_rprop, rprop_optimizer,\n",
    "    X_small, y_small,\n",
    "    X_val_small, y_val_small,\n",
    "    n_epochs=30, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 Yogi - Additive Adaptive Methods\n",
    "\n",
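    "Yogi modifies Adam's second-moment update. Adam uses `v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2`, which shrinks `v` by roughly a fraction `1 - beta2` per step once gradients become small, so the effective learning rate `lr / sqrt(v_t)` can grow quickly. Yogi uses the additive update `v_t = v_{t-1} - (1 - beta2) * sign(v_{t-1} - g_t^2) * g_t^2`, so `v` changes by at most `(1 - beta2) * g_t^2` per step and the effective learning rate grows more gradually."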
   ]
  },
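  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between the two second-moment updates can be checked numerically. The helpers below are hypothetical, written for this tutorial; `yogi_step` omits bias correction for brevity, and its initial `v = 1e-6` is an assumed small positive value, so it is not the braintools implementation. When gradients suddenly become small, Adam's `v` drops noticeably in one step while Yogi's barely moves; after a gradient spike, both increase by about the same amount."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "# Hypothetical helpers (not braintools API) comparing second-moment updates.\n",
    "def adam_v(v, g, beta2=0.999):\n",
    "    # Multiplicative (exponential moving average) update\n",
    "    return beta2 * v + (1.0 - beta2) * g ** 2\n",
    "\n",
    "\n",
    "def yogi_v(v, g, beta2=0.999):\n",
    "    # Additive update: v changes by at most (1 - beta2) * g^2 per step\n",
    "    g2 = g ** 2\n",
    "    return v - (1.0 - beta2) * jnp.sign(v - g2) * g2\n",
    "\n",
    "\n",
    "def yogi_step(w, g, m, v, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-3):\n",
    "    m = beta1 * m + (1.0 - beta1) * g\n",
    "    v = yogi_v(v, g, beta2)\n",
    "    return w - lr * m / (jnp.sqrt(v) + eps), m, v\n",
    "\n",
    "\n",
    "# Gradients suddenly become small: Adam's v drops, Yogi's v barely moves\n",
    "v_big, small_g = jnp.array(1.0), jnp.array(0.01)\n",
    "print('Adam v:', float(adam_v(v_big, small_g)), 'Yogi v:', float(yogi_v(v_big, small_g)))\n",
    "\n",
    "# Gradient spike: both increase by about (1 - beta2) * g^2\n",
    "v_small, big_g = jnp.array(1e-4), jnp.array(1.0)\n",
    "print('Adam v:', float(adam_v(v_small, big_g)), 'Yogi v:', float(yogi_v(v_small, big_g)))\n",
    "\n",
    "# One full Yogi step from a small initial v\n",
    "w, m, v = yogi_step(jnp.array([1.0]), jnp.array([0.5]), jnp.zeros(1), jnp.full(1, 1e-6))\n",
    "print(w)"
   ]
  },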
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Yogi optimizer\n",
    "model_yogi = CNNModel()\n",
    "\n",
    "yogi_optimizer = braintools.optim.Yogi(\n",
    "    lr=1e-2,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-3  # Yogi typically uses larger epsilon\n",
    ")\n",
    "\n",
    "print(\"Training with Yogi (Additive adaptive method)...\")\n",
    "history_yogi = train_with_optimizer(\n",
    "    model_yogi, yogi_optimizer,\n",
    "    X_vision[:1000], y_vision[:1000],\n",
    "    X_vision[1000:1500], y_vision[1000:1500],\n",
    "    n_epochs=30, batch_size=32\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Comparing Optimizer Performance\n",
    "\n",
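    "Let's visualize and compare the performance of different optimizer categories.\n",
    "\n",
    "Keep in mind that the optimizers within each group were trained on different models and datasets (Lion and Yogi on the CNN, Adafactor on the Transformer block, RAdam on the RNN; LAMB and SM3 on the Transformer block, LARS on the CNN), so the curves show how each optimizer behaves on its own task rather than a head-to-head comparison. The memory panel is a rough estimate, not a measurement."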
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_optimizer_comparison(histories, names, title=\"Optimizer Comparison\"):\n",
    "    \"\"\"Create comprehensive comparison plots.\"\"\"\n",
    "\n",
    "    fig = plt.figure(figsize=(16, 10))\n",
    "    gs = GridSpec(3, 3, figure=fig)\n",
    "\n",
    "    # Define color palette\n",
    "    colors = plt.cm.tab10(np.linspace(0, 1, len(histories)))\n",
    "\n",
    "    # Training loss\n",
    "    ax1 = fig.add_subplot(gs[0, 0])\n",
    "    for hist, name, color in zip(histories, names, colors):\n",
    "        ax1.plot(hist['train_loss'], label=name, color=color, linewidth=2)\n",
    "    ax1.set_xlabel('Epoch')\n",
    "    ax1.set_ylabel('Training Loss')\n",
    "    ax1.set_title('Training Loss')\n",
    "    ax1.legend(fontsize=8)\n",
    "    ax1.grid(True, alpha=0.3)\n",
    "\n",
    "    # Validation loss\n",
    "    ax2 = fig.add_subplot(gs[0, 1])\n",
    "    for hist, name, color in zip(histories, names, colors):\n",
    "        ax2.plot(hist['val_loss'], label=name, color=color, linewidth=2)\n",
    "    ax2.set_xlabel('Epoch')\n",
    "    ax2.set_ylabel('Validation Loss')\n",
    "    ax2.set_title('Validation Loss')\n",
    "    ax2.legend(fontsize=8)\n",
    "    ax2.grid(True, alpha=0.3)\n",
    "\n",
    "    # Training accuracy\n",
    "    ax3 = fig.add_subplot(gs[0, 2])\n",
    "    for hist, name, color in zip(histories, names, colors):\n",
    "        ax3.plot(hist['train_acc'], label=name, color=color, linewidth=2)\n",
    "    ax3.set_xlabel('Epoch')\n",
    "    ax3.set_ylabel('Training Accuracy')\n",
    "    ax3.set_title('Training Accuracy')\n",
    "    ax3.legend(fontsize=8)\n",
    "    ax3.grid(True, alpha=0.3)\n",
    "\n",
    "    # Convergence speed (loss reduction)\n",
    "    ax4 = fig.add_subplot(gs[1, 0])\n",
    "    for hist, name, color in zip(histories, names, colors):\n",
    "        loss_reduction = np.array(hist['train_loss']) / hist['train_loss'][0]\n",
    "        ax4.plot(loss_reduction, label=name, color=color, linewidth=2)\n",
    "    ax4.set_xlabel('Epoch')\n",
    "    ax4.set_ylabel('Loss Reduction Ratio')\n",
    "    ax4.set_title('Convergence Speed')\n",
    "    ax4.legend(fontsize=8)\n",
    "    ax4.grid(True, alpha=0.3)\n",
    "\n",
    "    # Training time per epoch\n",
    "    ax5 = fig.add_subplot(gs[1, 1])\n",
    "    # Skip epoch 1, whose time includes JIT compilation\n",
    "    avg_times = [np.mean(hist['epoch_time'][1:]) for hist in histories]\n",
    "    bars = ax5.bar(range(len(names)), avg_times, color=colors)\n",
    "    ax5.set_xticks(range(len(names)))\n",
    "    ax5.set_xticklabels(names, rotation=45, ha='right')\n",
    "    ax5.set_ylabel('Average Time per Epoch (s)')\n",
    "    ax5.set_title('Training Efficiency')\n",
    "    ax5.grid(True, alpha=0.3, axis='y')\n",
    "\n",
    "    # Final performance comparison\n",
    "    ax6 = fig.add_subplot(gs[1, 2])\n",
    "    final_train_loss = [hist['train_loss'][-1] for hist in histories]\n",
    "    final_val_loss = [hist['val_loss'][-1] for hist in histories]\n",
    "\n",
    "    x = np.arange(len(names))\n",
    "    width = 0.35\n",
    "\n",
    "    bars1 = ax6.bar(x - width / 2, final_train_loss, width, label='Train Loss', color='steelblue')\n",
    "    bars2 = ax6.bar(x + width / 2, final_val_loss, width, label='Val Loss', color='coral')\n",
    "\n",
    "    ax6.set_xticks(x)\n",
    "    ax6.set_xticklabels(names, rotation=45, ha='right')\n",
    "    ax6.set_ylabel('Final Loss')\n",
    "    ax6.set_title('Final Performance')\n",
    "    ax6.legend()\n",
    "    ax6.grid(True, alpha=0.3, axis='y')\n",
    "\n",
    "    # Training stability: variance of the training loss within a sliding window\n",
    "    ax7 = fig.add_subplot(gs[2, 0])\n",
    "    for hist, name, color in zip(histories, names, colors):\n",
    "        window = 5\n",
    "        loss_array = np.array(hist['train_loss'])\n",
    "        if len(loss_array) >= window:\n",
    "            # Variance of each window around its own mean (requires NumPy >= 1.20)\n",
    "            windows = np.lib.stride_tricks.sliding_window_view(loss_array, window)\n",
    "            rolling_var = windows.var(axis=1)\n",
    "            ax7.plot(rolling_var, label=name, color=color, linewidth=2)\n",
    "    ax7.set_xlabel('Epoch')\n",
    "    ax7.set_ylabel('Loss Variance')\n",
    "    ax7.set_title('Training Stability')\n",
    "    ax7.legend(fontsize=8)\n",
    "    ax7.grid(True, alpha=0.3)\n",
    "\n",
    "    # Optimizer-state memory (rough estimate, excluding the parameters themselves)\n",
    "    ax8 = fig.add_subplot(gs[2, 1:])  # Span two columns\n",
    "\n",
    "    # State buffers per parameter, normalized so that Adam (m and v) = 1.0\n",
    "    memory_factors = {\n",
    "        'Lion': 0.5,  # One momentum buffer\n",
    "        'Adafactor': 0.1,  # Factored row/column second moments, assumes no momentum (approximate)\n",
    "        'SM3': 0.5,  # Momentum buffer plus small row/column accumulators\n",
    "        'Rprop': 1.0,  # Previous gradient and per-parameter step size\n",
    "        'SGD': 0.5,  # Momentum buffer\n",
    "        'Adam': 1.0,  # Baseline (first and second moments)\n",
    "        'RAdam': 1.0,  # Same state as Adam\n",
    "        'Yogi': 1.0,  # Same state as Adam\n",
    "        'Lookahead': 1.0,  # Slow weights plus the inner SGD momentum buffer\n",
    "        'LAMB': 1.0,  # Same state as Adam; trust ratios are computed on the fly\n",
    "        'LARS': 0.5,  # Momentum buffer; local learning rates are computed on the fly\n",
    "        'L-BFGS': 10.0,  # memory_size=10 pairs of (s, y) vectors\n",
    "    }\n",
    "\n",
    "    mem_values = [memory_factors.get(name, 1.0) for name in names]\n",
    "    bars = ax8.barh(range(len(names)), mem_values, color=colors)\n",
    "    ax8.set_yticks(range(len(names)))\n",
    "    ax8.set_yticklabels(names)\n",
    "    ax8.set_xlabel('Optimizer State (relative to Adam)')\n",
    "    ax8.set_title('Memory Efficiency Comparison')\n",
    "    ax8.grid(True, alpha=0.3, axis='x')\n",
    "\n",
    "    plt.suptitle(title, fontsize=16, fontweight='bold')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Compare specialized optimizers\n",
    "specialized_histories = [history_lion, history_adafactor, history_radam, history_yogi]\n",
    "specialized_names = ['Lion', 'Adafactor', 'RAdam', 'Yogi']\n",
    "\n",
    "plot_optimizer_comparison(\n",
    "    specialized_histories,\n",
    "    specialized_names,\n",
    "    \"Specialized Gradient-Based Optimizers\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare large-scale optimizers\n",
    "largescale_histories = [history_lamb, history_lars, history_sm3]\n",
    "largescale_names = ['LAMB', 'LARS', 'SM3']\n",
    "\n",
    "plot_optimizer_comparison(\n",
    "    largescale_histories,\n",
    "    largescale_names,\n",
    "    \"Large-Scale Training Optimizers\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Gradient-Free Optimization\n",
    "\n",
    "For gradient-free optimization, braintools provides integration with specialized libraries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.1 Nevergrad Integration\n",
    "\n",
    "Nevergrad provides a wide range of gradient-free optimization algorithms. See the [Nevergrad tutorial](01_nevergrad_optimizer.ipynb) for details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7.2 SciPy Optimization\n",
    "\n",
    "SciPy provides classical optimization algorithms, including constrained optimization. See the [SciPy tutorial](02_scipy_optimizer.ipynb) for details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary and Best Practices\n",
    "\n",
    "**Key Takeaways**\n",
    "\n",
    "1. **Memory-Efficient Optimizers**\n",
    "   - **Lion**: One state buffer per parameter (half of Adam); useful when optimizer memory is the bottleneck\n",
    "   - **Adafactor**: Good balance of memory and performance\n",
    "   - **SM3**: Sublinear accumulator memory for large matrix and embedding parameters\n",
    "\n",
    "2. **Large-Scale Training**\n",
    "   - **LAMB/LARS**: Designed for large-batch training\n",
    "   - Layer-wise trust ratios keep training stable when batch size and learning rate are scaled up together\n",
    "   - Commonly used in data-parallel distributed training, where the global batch is large\n",
    "\n",
    "3. **Stability and Robustness**\n",
    "   - **RAdam**: Rectified variance for stability\n",
    "   - **Lookahead**: Reduces variance through averaging\n",
    "   - **Yogi**: Additive second-moment updates that limit how fast the effective learning rate grows\n",
    "\n",
    "4. **Alternative Paradigms**\n",
    "   - **L-BFGS**: Effective on small, deterministic (full-batch) problems thanks to curvature information\n",
    "   - **Rprop**: Insensitive to gradient magnitude; best with full-batch gradients\n",
    "   - **Gradient-free**: When gradients are unavailable or unreliable\n",
    "\n",
    "**When to Use Advanced Optimizers**\n",
    "\n",
    "| Scenario | Recommended Optimizer | Reason |\n",
    "|----------|----------------------|--------|\n",
    "| Large Language Models | Lion, Adafactor | Memory efficiency |\n",
    "| Distributed Training | LAMB, LARS | Large batch handling |\n",
    "| Noisy Gradients | RAdam, Lookahead | Stability |\n",
    "| Small Dataset, Full Batch | L-BFGS | Fast convergence from curvature information |\n",
    "| Research/Experimentation | Yogi, Custom | Novel behaviors |\n",
    "| Constrained Optimization | ScipyOptimizer | Built-in constraints |\n",
    "| Black-box Optimization | NevergradOptimizer | No gradients needed |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "1. **Memory Comparison**: Train the same large model with Adam, Lion, and Adafactor. Monitor and compare memory usage.\n",
    "\n",
    "2. **Large Batch Scaling**: Test how well different optimizers handle increasing batch sizes from 32 to 1024.\n",
    "\n",
    "3. **Stability Analysis**: Add artificial noise to gradients and compare optimizer robustness.\n",
    "\n",
    "4. **Hybrid Approach**: Implement a training schedule that switches optimizers (e.g., Adam → L-BFGS for fine-tuning).\n",
    "\n",
    "5. **Custom Optimizer**: Create your own optimizer by combining ideas from different methods.\n",
    "\n",
    "6. **Constraint Satisfaction**: Use ScipyOptimizer to solve a constrained optimization problem in neural network training.\n",
    "\n",
    "7. **Hyperparameter Optimization**: Use NevergradOptimizer to tune the hyperparameters of another optimizer."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
