Tutorial 5: Advanced Optimizers and Techniques

Difficulty: Advanced
Duration: 40-50 minutes
Prerequisites: Completion of Tutorials 3 and 4

Learning Objectives

  • Use specialized optimizers for specific scenarios

  • Implement second-order optimization methods

  • Apply gradient-free optimization

  • Understand memory-efficient optimizers

Topics Covered

  1. Specialized gradient-based optimizers

    • Lion: Memory-efficient optimizer

    • Adafactor: Factorized second moments

    • Lookahead: k-step forward optimization

    • RAdam: Rectified Adam

  2. Large-scale training optimizers

    • LAMB: Layer-wise adaptive large batch

    • LARS: Layer-wise adaptive rate scaling

    • SM3: Memory-efficient for large models

  3. Alternative optimization paradigms

    • LBFGS: Quasi-Newton method

    • Rprop: Resilient backpropagation

    • Yogi: Additive adaptive methods

  4. Gradient-free optimization

    • NevergradOptimizer integration

    • ScipyOptimizer for constrained problems

import time

import brainstate
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

import braintools

1. Setting up Test Models and Data

We’ll create different model architectures to test various optimizer characteristics.

class TransformerBlock(brainstate.nn.Module):
    """Simplified Transformer block with a classification head for testing large-scale optimizers."""

    def __init__(self, dim=512, num_heads=8, mlp_ratio=4.0, num_classes=10):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads  # stored for reference; attention is not actually computed

        # Multi-head attention components
        self.qkv = brainstate.nn.Linear(dim, dim * 3)
        self.proj = brainstate.nn.Linear(dim, dim)

        # MLP components
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.fc1 = brainstate.nn.Linear(dim, mlp_hidden_dim)
        self.fc2 = brainstate.nn.Linear(mlp_hidden_dim, dim)

        # Layer norms
        self.norm1 = brainstate.nn.LayerNorm(dim)
        self.norm2 = brainstate.nn.LayerNorm(dim)

        # Classification head: the training loop expects (batch, num_classes) logits
        self.head = brainstate.nn.Linear(dim, num_classes)

    def __call__(self, x):
        # x shape: (batch, seq_len, dim)
        # Simplified attention (without actual attention computation)
        residual = x
        x = self.norm1(x)

        # QKV projection; q and k are unused in this simplified block
        qkv = self.qkv(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # Simplified attention output (just use v for demonstration)
        attn_output = self.proj(v)
        x = residual + attn_output

        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.fc1(x)
        x = jax.nn.gelu(x)
        x = self.fc2(x)
        x = residual + x

        # Mean-pool over the sequence axis, then project to class logits
        x = jnp.mean(x, axis=1)
        return self.head(x)


class CNNModel(brainstate.nn.Module):
    """CNN for testing memory-efficient optimizers."""

    def __init__(self, in_size=(32, 32, 3), num_classes=10):
        super().__init__()
        # Conv layers; in_size is (height, width, channels) without the batch axis
        self.conv1 = brainstate.nn.Conv2d(in_size, 64, kernel_size=3, padding=1)
        self.pool1 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv1.out_size)
        self.conv2 = brainstate.nn.Conv2d(self.pool1.out_size, 128, kernel_size=3, padding=1)
        self.pool2 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv2.out_size)
        self.conv3 = brainstate.nn.Conv2d(self.pool2.out_size, 256, kernel_size=3, padding=1)
        self.pool3 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv3.out_size)

        # Dense layers
        self.fc1 = brainstate.nn.Linear(int(np.prod(self.pool3.out_size)), 512)
        self.fc2 = brainstate.nn.Linear(512, num_classes)

    def __call__(self, x):
        # Reshape if needed
        if len(x.shape) == 2:
            x = x.reshape(-1, 32, 32, 3)

        # Conv blocks
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = jax.nn.relu(x)
        x = self.pool3(x)

        # Flatten and FC layers
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)

        return x


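class SimpleRNN(brainstate.nn.Module):
    """Simple RNN for testing gradient stability."""

    def __init__(self, input_dim=10, hidden_dim=128, output_dim=10):
        super().__init__()
        # A single vanilla RNN cell; it takes no num_layers argument
        self.rnn = brainstate.nn.ValinaRNNCell(input_dim, hidden_dim)
        self.fc = brainstate.nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        # x shape: (batch, seq_len, features)
        # (Re)initialize the hidden state for the current batch size
        brainstate.nn.init_all_states(self.rnn, x.shape[0])
        # for_loop iterates over the leading axis, so move time first: (seq_len, batch, features)
        outputs = brainstate.transform.for_loop(self.rnn, jnp.transpose(x, (1, 0, 2)))
        # Use the last timestep: (batch, hidden_dim)
        return self.fc(outputs[-1])

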
def create_synthetic_data(data_type='vision', n_samples=1000, seed=42):
    """Create synthetic data for different model types."""
    with brainstate.random.seed_context(seed):
        if data_type == 'vision':
            # Image-like data (32x32x3)
            X = brainstate.random.normal(size=(n_samples, 32, 32, 3)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'transformer':
            # Sequence data for transformer (seq_len=64, dim=512)
            X = brainstate.random.normal(size=(n_samples, 64, 512)) * 0.1
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'sequence':
            # Sequence data for RNN (seq_len=20, features=10)
            X = brainstate.random.normal(size=(n_samples, 20, 10)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        else:
            # Default: flat features
            X = brainstate.random.normal(size=(n_samples, 784)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))

    return X, y


# Create datasets
X_vision, y_vision = create_synthetic_data('vision', n_samples=2000)
X_transformer, y_transformer = create_synthetic_data('transformer', n_samples=1000)
X_sequence, y_sequence = create_synthetic_data('sequence', n_samples=2000)

print(f"Vision data shape: {X_vision.shape}")
print(f"Transformer data shape: {X_transformer.shape}")
print(f"Sequence data shape: {X_sequence.shape}")

2. Gradient Computation and Training Infrastructure

Following the style from previous tutorials, we’ll set up our gradient computation.

def compute_loss_and_grads(model, X, y, param_states, loss_type='classification'):
    """Compute loss and gradients following braintools style."""

    def loss_fn():
        # Forward pass
        outputs = model(X)

        if loss_type == 'classification':
            # Cross-entropy loss
            log_probs = jax.nn.log_softmax(outputs, axis=-1)
            one_hot = jax.nn.one_hot(y, num_classes=10)
            loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        else:
            # MSE loss for regression
            loss = jnp.mean((outputs - y) ** 2)

        # Add L2 regularization
        l2_reg = 1e-4
        for state in param_states.values():
            loss = loss + l2_reg * jnp.sum(state.value ** 2)

        return loss

    # Compute gradients and the loss value in a single forward/backward pass
    grads, loss = brainstate.transform.grad(loss_fn, grad_states=param_states, return_value=True)()

    # Compute accuracy for classification
    if loss_type == 'classification':
        outputs = model(X)
        predictions = jnp.argmax(outputs, axis=-1)
        accuracy = jnp.mean(predictions == y)
    else:
        accuracy = -loss  # Use negative loss as metric for regression

    return loss, grads, accuracy


def train_with_optimizer(
    model: brainstate.nn.Module,
    optimizer: braintools.optim.OptaxOptimizer,
    X_train, y_train,
    X_val, y_val,
    n_epochs=30,
    batch_size=64,
    verbose=True
):
    """Generic training function for any optimizer."""

    # Get parameter states
    param_states = braintools.optim.UniqueStateManager(
        model.states(brainstate.ParamState)
    ).to_pytree()

    # Register parameters with optimizer
    optimizer.register_trainable_weights(param_states)

    @brainstate.transform.jit
    def train_step(X_batch, y_batch):
        loss, grads, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        optimizer.update(grads)
        return loss, acc

    @brainstate.transform.jit
    def eval_step(X_batch, y_batch):
        loss, _, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        return loss, acc

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'epoch_time': []
    }

    n_batches = len(X_train) // batch_size

    for epoch in range(n_epochs):
        epoch_start = time.time()

        # Shuffle data
        perm = brainstate.random.permutation(len(X_train))
        X_train_shuffled = X_train[perm]
        y_train_shuffled = y_train[perm]

        train_losses = []
        train_accs = []

        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            X_batch = X_train_shuffled[start_idx:end_idx]
            y_batch = y_train_shuffled[start_idx:end_idx]

            loss, acc = train_step(X_batch, y_batch)
            train_losses.append(float(loss))
            train_accs.append(float(acc))

        # Validation
        val_loss, val_acc = eval_step(X_val[:500], y_val[:500])  # Use subset for speed

        # Update learning rate if scheduler is attached
        optimizer.lr.step()

        # Record metrics
        history['train_loss'].append(np.mean(train_losses))
        history['train_acc'].append(np.mean(train_accs))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))
        history['epoch_time'].append(time.time() - epoch_start)

        if verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{n_epochs} - "
                  f"Loss: {history['train_loss'][-1]:.4f}, "
                  f"Acc: {history['train_acc'][-1]:.4f}, "
                  f"Val Loss: {history['val_loss'][-1]:.4f}, "
                  f"Val Acc: {history['val_acc'][-1]:.4f}")

    return history

3. Specialized Gradient-Based Optimizers

Let’s explore advanced optimizers designed for specific scenarios.

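3.1 Lion Optimizer - Memory Efficient

Lion (EvoLved Sign Momentum) updates each parameter by the sign of an interpolation between the current gradient and a momentum buffer. Because that momentum buffer is its only optimizer state, Lion needs half the optimizer memory of Adam, which stores both first and second moments.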
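To make the sign update concrete, the NumPy sketch below performs one Lion step on a single parameter array, following the update rule described in the Lion paper. `lion_step` is a hypothetical helper written for illustration. It is not part of braintools, and braintools' `Lion` may differ in details such as how weight decay is applied.

```python
import numpy as np


def lion_step(param, grad, momentum, lr=3e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion update for a single parameter array (sketch of the published rule)."""
    # Interpolate between the stored momentum and the current gradient, keep only the sign
    update = np.sign(beta1 * momentum + (1.0 - beta1) * grad)
    # Every nonzero coordinate moves by exactly lr (plus decoupled weight decay)
    new_param = param - lr * (update + weight_decay * param)
    # The momentum buffer is the only state Lion carries between steps
    new_momentum = beta2 * momentum + (1.0 - beta2) * grad
    return new_param, new_momentum


param = np.array([1.0, -2.0, 0.5])
grad = np.array([0.3, -0.01, 0.0])
momentum = np.zeros_like(param)
param, momentum = lion_step(param, grad, momentum, lr=0.1)
print(param)
```

Both the large gradient (0.3) and the tiny one (-0.01) produce a step of the same size (0.1). This is also why Lion is usually run with a smaller learning rate than Adam.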

# Lion optimizer
model_lion = CNNModel()

lion_optimizer = braintools.optim.Lion(
    lr=3e-4,  # Lion typically uses smaller learning rates
    betas=(0.9, 0.99),
    weight_decay=1e-4
)

print("Training with Lion optimizer (memory-efficient)...")
history_lion = train_with_optimizer(
    model_lion, lion_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

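3.2 Adafactor - Factorized Second Moments

Adafactor reduces memory by factorizing the second-moment estimate. For an n × m weight matrix, it stores running statistics for each row and each column (n + m values) instead of the full n × m matrix of squared-gradient averages, and reconstructs a rank-1 approximation when it needs the full estimate.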
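The factorization can be sketched in NumPy as follows. This sketch assumes a constant decay `beta2`. Adafactor itself uses a step-dependent decay 1 - t^(-decay_rate), and it also adds update clipping and relative step sizes, which are omitted here. `factored_second_moment` is a hypothetical helper, not a braintools function.

```python
import numpy as np


def factored_second_moment(row_acc, col_acc, grad, beta2=0.999, eps=1e-30):
    """Update Adafactor-style row/column statistics and rebuild the full estimate."""
    sq = grad ** 2 + eps
    # Moving averages of row sums and column sums: O(n + m) memory
    row_acc = beta2 * row_acc + (1.0 - beta2) * sq.sum(axis=1)
    col_acc = beta2 * col_acc + (1.0 - beta2) * sq.sum(axis=0)
    # Rank-1 reconstruction of the n x m second-moment matrix
    v_hat = np.outer(row_acc, col_acc) / row_acc.sum()
    return row_acc, col_acc, v_hat


# When the squared gradient is itself rank-1, the reconstruction is exact
a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 4.0])
grad = np.sqrt(np.outer(a, b))
row_acc, col_acc, v_hat = factored_second_moment(np.zeros(3), np.zeros(2), grad, beta2=0.0)
print(v_hat)
```

For TransformerBlock's `fc1` weight (512 × 2048), this means storing 512 + 2048 = 2,560 statistics instead of 1,048,576.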

# Adafactor optimizer
model_adafactor = TransformerBlock()

adafactor_optimizer = braintools.optim.Adafactor(
    lr=1e-3,
    decay_rate=0.8,
    factored=True,  # Enable factorization for memory efficiency
    clip_threshold=1.0
)

print("Training with Adafactor (factorized second moments)...")
history_adafactor = train_with_optimizer(
    model_adafactor, adafactor_optimizer,
    X_transformer[:500], y_transformer[:500],
    X_transformer[500:700], y_transformer[500:700],
    n_epochs=30, batch_size=16
)

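3.3 Lookahead Optimizer - k-step Forward

Lookahead wraps a base ("fast") optimizer and maintains two sets of weights. The fast weights take k steps with the base optimizer (k = sync_period). The slow weights then move a fraction alpha toward the fast weights, and the fast weights are reset to the new slow weights.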
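The interaction between the two sets of weights can be sketched in NumPy as follows. For brevity, this sketch uses plain SGD without momentum as the inner optimizer, while the tutorial wraps SGD with momentum. `lookahead` is a hypothetical helper, not braintools' `Lookahead` class.

```python
import numpy as np


def lookahead(grad_fn, params, inner_lr=0.1, sync_period=5, alpha=0.5, n_outer=20):
    """Lookahead around plain SGD (sketch): k fast steps, then pull the slow weights."""
    slow = params.copy()
    for _ in range(n_outer):
        # Fast weights restart from the current slow weights
        fast = slow.copy()
        # Inner loop: k = sync_period steps of the base optimizer
        for _ in range(sync_period):
            fast = fast - inner_lr * grad_fn(fast)
        # Outer step: move the slow weights a fraction alpha toward the fast weights
        slow = slow + alpha * (fast - slow)
    return slow


# Quadratic bowl f(x) = 0.5 * ||x||^2 with gradient x; the minimum is at the origin
result = lookahead(lambda x: x, np.array([4.0, -2.0]))
print(result)
```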

# Lookahead optimizer wrapping SGD
model_lookahead = CNNModel()

# Base optimizer
base_optimizer = braintools.optim.SGD(lr=0.1, momentum=0.9)

# Wrap with Lookahead
lookahead_optimizer = braintools.optim.Lookahead(
    base_optimizer,
    sync_period=5,  # Update slow weights every 5 steps
    alpha=0.5  # Interpolation factor
)

print("Training with Lookahead optimizer (k-step forward)...")
history_lookahead = train_with_optimizer(
    model_lookahead, lookahead_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

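3.4 RAdam - Rectified Adam

RAdam rectifies the variance of the adaptive learning rate. In the first few steps, the second-moment estimate is based on very few gradients and its variance is large, so RAdam falls back to momentum SGD. After that, it multiplies Adam's step by a rectification factor that grows toward 1 as training proceeds.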
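The rectification factor from the RAdam paper can be computed directly, as in the sketch below. The cutoff below which the adaptive step is skipped is set to 5.0 here as an assumption (the paper uses 4, some libraries use 5). `radam_rectification` is a hypothetical helper.

```python
import numpy as np


def radam_rectification(step, beta2=0.999, threshold=5.0):
    """Return RAdam's rectification factor r_t, or None when the adaptive
    step is skipped and the update falls back to momentum SGD."""
    # Maximum length of the approximated simple moving average
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if rho_t <= threshold:
        return None
    return np.sqrt(((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
                   / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))


for t in [1, 10, 100, 1000, 10000]:
    print(t, radam_rectification(t))
```

With beta2 = 0.999, the factor starts near 0.05 at step 10 and approaches 1 after a few thousand steps. This plays a role similar to a learning-rate warmup.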

# RAdam optimizer
model_radam = SimpleRNN()

radam_optimizer = braintools.optim.RAdam(
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4
)

print("Training with RAdam (Rectified Adam)...")
history_radam = train_with_optimizer(
    model_radam, radam_optimizer,
    X_sequence[:1000], y_sequence[:1000],
    X_sequence[1000:1500], y_sequence[1000:1500],
    n_epochs=30, batch_size=32
)

4. Large-Scale Training Optimizers

These optimizers are designed for training with large batch sizes and distributed settings.

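4.1 LAMB - Layer-wise Adaptive Large Batch

LAMB enables large-batch training by rescaling Adam's update separately for each layer. The update is multiplied by a trust ratio ‖w‖ / ‖update‖, so the size of each layer's step is proportional to the size of that layer's weights.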
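The per-layer trust ratio can be sketched in NumPy as below. This sketch assumes the bias-corrected Adam moments `m_hat` and `v_hat` are already computed. It also uses the identity as the scaling function φ(‖w‖) from the LAMB paper. `lamb_layer_update` is a hypothetical helper, not braintools' `Lamb`.

```python
import numpy as np


def lamb_layer_update(weights, m_hat, v_hat, lr=2e-3, eps=1e-6, weight_decay=0.01):
    """Scale one layer's Adam-style update by LAMB's trust ratio (sketch)."""
    # Adam direction plus decoupled weight decay
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * weights
    w_norm = np.linalg.norm(weights)
    u_norm = np.linalg.norm(update)
    # Trust ratio ||w|| / ||update||; fall back to 1 when either norm is zero
    trust_ratio = w_norm / u_norm if (w_norm > 0 and u_norm > 0) else 1.0
    return weights - lr * trust_ratio * update


weights = np.full(4, 2.0)
m_hat = np.array([0.1, -0.2, 0.3, 0.05])
v_hat = np.array([0.01, 0.04, 0.09, 0.0025])
new_weights = lamb_layer_update(weights, m_hat, v_hat)
```

The resulting step has norm lr · ‖w‖ regardless of how large the raw gradients are. This keeps layers with very different gradient scales stable at large batch sizes.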

# LAMB optimizer for large batch training
model_lamb = TransformerBlock()

lamb_optimizer = braintools.optim.Lamb(
    lr=2e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,
    grad_clip_value=10.0  # Gradient clipping
)

print("Training with LAMB (Large Batch optimizer)...")
# Simulate large batch by using larger batch size
history_lamb = train_with_optimizer(
    model_lamb, lamb_optimizer,
    X_transformer[:800], y_transformer[:800],
    X_transformer[800:], y_transformer[800:],
    n_epochs=30, batch_size=128  # Large batch size
)

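4.2 LARS - Layer-wise Adaptive Rate Scaling

LARS gives each layer a local learning rate based on the ratio of its weight norm to its gradient norm: local_lr = η · ‖w‖ / (‖g‖ + β · ‖w‖), where η is the trust coefficient and β is the weight decay.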
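One LARS step for a single layer can be sketched in NumPy as below, following the formula in the LARS paper (You et al., 2017). Parameter names mirror the tutorial's `Lars` call. `lars_step` is a hypothetical helper, and braintools may handle details such as excluding biases differently.

```python
import numpy as np


def lars_step(weights, grad, velocity, lr=0.1, momentum=0.9,
              weight_decay=1e-4, trust_coefficient=0.001, eps=1e-8):
    """One LARS update for a single layer (sketch)."""
    w_norm = np.linalg.norm(weights)
    g_norm = np.linalg.norm(grad)
    # Layer-wise local learning rate: eta * ||w|| / (||g|| + beta * ||w||)
    local_lr = trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)
    # Momentum SGD on the weight-decayed gradient, scaled by the local rate
    velocity = momentum * velocity + lr * local_lr * (grad + weight_decay * weights)
    return weights - velocity, velocity


w = np.ones(3)
g = np.array([0.1, 0.2, 0.3])
w_small, _ = lars_step(w, g, np.zeros(3), weight_decay=0.0)
w_large, _ = lars_step(w, 100.0 * g, np.zeros(3), weight_decay=0.0)
```

Without weight decay, multiplying the gradient by 100 leaves the step unchanged. Only the gradient's direction matters, and the step size is set by ‖w‖.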

# LARS optimizer
model_lars = CNNModel()

lars_optimizer = braintools.optim.Lars(
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    trust_coefficient=0.001,  # LARS-specific parameter
    eps=1e-8
)

print("Training with LARS (Layer-wise Adaptive Rate Scaling)...")
history_lars = train_with_optimizer(
    model_lars, lars_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=128
)

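4.3 SM3 - Memory-Efficient for Large Models

SM3 approximates Adagrad-style per-parameter accumulators using a small number of "cover set" statistics, for example one accumulator per row and one per column of a weight matrix. Each entry's accumulator is estimated as the minimum of its row and column statistics, so memory grows with n + m rather than n × m.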
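The row/column cover-set variant can be sketched in NumPy as below. This sketch follows the SM3-II rule from the SM3 paper, without the momentum term that the tutorial's `SM3(momentum=0.9)` call adds. `sm3_step` is a hypothetical helper.

```python
import numpy as np


def sm3_step(param, grad, row_acc, col_acc, lr=1e-3, eps=1e-8):
    """One SM3-II step for a 2-D parameter with row/column cover sets (no momentum)."""
    # Per-entry accumulator: the smaller of its row and column bounds plus the new g^2
    nu = np.minimum(row_acc[:, None], col_acc[None, :]) + grad ** 2
    # Keep only O(n + m) statistics: the max of nu along each row and each column
    row_acc = nu.max(axis=1)
    col_acc = nu.max(axis=0)
    # Adagrad-style step using the per-entry estimate
    new_param = param - lr * grad / (np.sqrt(nu) + eps)
    return new_param, row_acc, col_acc


rng = np.random.default_rng(0)
param = np.zeros((4, 6))
row_acc, col_acc = np.zeros(4), np.zeros(6)  # 4 + 6 values instead of 4 * 6
for _ in range(10):
    grad = rng.normal(size=param.shape)
    param, row_acc, col_acc = sm3_step(param, grad, row_acc, col_acc)
```

By construction, min(row_i, col_j) never underestimates the true sum of squared gradients for entry (i, j). As a result, the effective per-entry step is never larger than Adagrad's.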

# SM3 optimizer for memory efficiency
model_sm3 = TransformerBlock()

sm3_optimizer = braintools.optim.SM3(
    lr=1e-3,
    momentum=0.9,
    eps=1e-8
)

print("Training with SM3 (Memory-efficient optimizer)...")
history_sm3 = train_with_optimizer(
    model_sm3, sm3_optimizer,
    X_transformer[:500], y_transformer[:500],
    X_transformer[500:700], y_transformer[500:700],
    n_epochs=30, batch_size=16
)

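5. Alternative Optimization Paradigms

These optimizers depart from plain first-order gradient descent. L-BFGS uses curvature information, Rprop uses only the signs of gradients, and Yogi changes how Adam's second-moment estimate is updated.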

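5.1 L-BFGS - Quasi-Newton Method

L-BFGS (limited-memory BFGS) approximates the inverse Hessian from the last m pairs of parameter differences and gradient differences (here m = memory_size = 10). It uses this approximation to precondition the gradient, without ever forming a full matrix.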
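At the heart of L-BFGS is the standard "two-loop recursion," which turns the stored (s, y) pairs into a search direction. The NumPy sketch below shows only that recursion, without the line search. `lbfgs_direction` is a hypothetical helper, not braintools' `LBFGS` interface.

```python
import numpy as np


def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: approximate -H^{-1} @ grad from stored (s, y) pairs.

    s_hist[i] = x_{i+1} - x_i and y_hist[i] = g_{i+1} - g_i, oldest first.
    """
    q = grad.copy()
    coeffs = []
    # First loop: newest pair to oldest
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / np.dot(y, s)
        alpha = rho * np.dot(s, q)
        q = q - alpha * y
        coeffs.append((rho, alpha))
    # Initial inverse-Hessian scale gamma = s^T y / y^T y from the newest pair
    if s_hist:
        gamma = np.dot(s_hist[-1], y_hist[-1]) / np.dot(y_hist[-1], y_hist[-1])
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: oldest pair to newest
    for (s, y), (rho, alpha) in zip(zip(s_hist, y_hist), reversed(coeffs)):
        beta = rho * np.dot(y, r)
        r = r + s * (alpha - beta)
    return -r


# Quadratic f(x) = 0.5 * x^T A x with A = diag(1, 10); one curvature pair per axis
A = np.diag([1.0, 10.0])
s_hist = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
y_hist = [A @ s for s in s_hist]  # on a quadratic, y = A s
x = np.array([2.0, 2.0])
direction = lbfgs_direction(A @ x, s_hist, y_hist)
print(direction)  # recovers the Newton step -A^{-1} g = -x
```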

# L-BFGS optimizer (Note: requires special handling)
from brainstate.nn import Linear


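class SimpleMLP(brainstate.nn.Module):
    """Simple MLP for L-BFGS testing."""

    def __init__(self, in_dim=32 * 32 * 3, hidden_dim=128, num_classes=10):
        super().__init__()
        # Inputs are flattened 32x32x3 images, so in_dim = 3072 (not 784)
        self.fc1 = Linear(in_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, num_classes)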

    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x


model_lbfgs = SimpleMLP()

# L-BFGS is usually run full-batch: noisy mini-batch gradients corrupt its curvature estimates
lbfgs_optimizer = braintools.optim.LBFGS(
    lr=1.0,
    memory_size=10,
    line_search_fn='zoom'
)

print("Training with L-BFGS (Quasi-Newton method)...")
# Use a small subset so that a single batch covers the whole training set
X_small = X_vision[:200].reshape(200, -1)
y_small = y_vision[:200]
X_val_small = X_vision[1000:1100].reshape(100, -1)
y_val_small = y_vision[1000:1100]

history_lbfgs = train_with_optimizer(
    model_lbfgs, lbfgs_optimizer,
    X_small, y_small,
    X_val_small, y_val_small,
    n_epochs=20, batch_size=200  # Full batch
)

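5.2 Rprop - Resilient Backpropagation

Rprop uses only the sign of each gradient component and keeps a separate step size for every parameter. A step size grows (by η+ = 1.2 here) while its gradient sign stays the same, and shrinks (by η- = 0.5) when the sign flips. The step sizes are clipped to the range given by step_sizes.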
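The per-parameter step-size adaptation can be sketched in NumPy as below. This is a common Rprop variant that skips the update for a coordinate right after its gradient sign flips. Argument names mirror the tutorial's `Rprop` call, and `rprop_step` is a hypothetical helper.

```python
import numpy as np


def rprop_step(param, grad, prev_grad, step, etas=(0.5, 1.2), step_sizes=(1e-6, 50.0)):
    """One Rprop update (variant that zeroes the gradient after a sign flip)."""
    eta_minus, eta_plus = etas
    step_min, step_max = step_sizes
    agreement = np.sign(grad) * np.sign(prev_grad)
    # Same sign as last step: grow the step size; sign flip: shrink it
    step = np.where(agreement > 0, np.minimum(step * eta_plus, step_max), step)
    step = np.where(agreement < 0, np.maximum(step * eta_minus, step_min), step)
    # After a sign flip, skip this coordinate's update for one step
    grad = np.where(agreement < 0, 0.0, grad)
    # Only the sign of the gradient is used; its magnitude is ignored
    new_param = param - np.sign(grad) * step
    return new_param, grad, step
```

Because only signs are used, scaling the gradient by any positive constant does not change the update. This also means Rprop assumes consistent gradient signs from step to step, which noisy mini-batch gradients do not provide.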

# Rprop optimizer
model_rprop = SimpleMLP()

rprop_optimizer = braintools.optim.Rprop(
    lr=1e-3,
    etas=(0.5, 1.2),  # Step size adaptation factors
    step_sizes=(1e-6, 50)  # Min and max step sizes
)

print("Training with Rprop (Resilient Backpropagation)...")
history_rprop = train_with_optimizer(
    model_rprop, rprop_optimizer,
    X_small, y_small,
    X_val_small, y_val_small,
    n_epochs=30, batch_size=32
)

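5.3 Yogi - Additive Adaptive Methods

Yogi modifies how Adam updates its second-moment estimate v. Adam moves v a fixed fraction of the way toward g² (a multiplicative update), so v can shrink quickly and the effective learning rate lr / sqrt(v) can jump. Yogi instead changes v additively by (1 - β2) · g², which limits how fast the effective learning rate can grow.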
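The two second-moment rules can be compared directly. The sketch below implements Adam's update and the Yogi update from the Yogi paper side by side. `second_moment_update` is a hypothetical helper that covers only the v update, not a full optimizer.

```python
import numpy as np


def second_moment_update(v, grad, beta2=0.999, rule="yogi"):
    """Compare Adam's and Yogi's updates of the second-moment estimate v."""
    g2 = np.asarray(grad) ** 2
    if rule == "adam":
        # Multiplicative: v moves a fraction (1 - beta2) of the way toward g^2
        return beta2 * v + (1.0 - beta2) * g2
    if rule == "yogi":
        # Additive: the change has size (1 - beta2) * g^2, independent of v
        return v - (1.0 - beta2) * np.sign(v - g2) * g2
    raise ValueError(f"unknown rule: {rule}")


# Large v, small gradient: Adam shrinks v by 10% at once, Yogi barely moves it
v_adam = second_moment_update(np.array([100.0]), np.array([0.1]), beta2=0.9, rule="adam")
v_yogi = second_moment_update(np.array([100.0]), np.array([0.1]), beta2=0.9, rule="yogi")
print(v_adam, v_yogi)
```

Because Yogi's v decays more slowly after a burst of large gradients, lr / sqrt(v) cannot increase abruptly. This is also why Yogi is usually paired with a larger epsilon, such as the 1e-3 used below.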

# Yogi optimizer
model_yogi = CNNModel()

yogi_optimizer = braintools.optim.Yogi(
    lr=1e-2,
    betas=(0.9, 0.999),
    eps=1e-3  # Yogi typically uses larger epsilon
)

print("Training with Yogi (Additive adaptive method)...")
history_yogi = train_with_optimizer(
    model_yogi, yogi_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

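6. Comparing Optimizer Performance

Let’s visualize and compare the performance of different optimizer categories. The optimizers above were trained on different models and datasets (CNN, Transformer block, RNN), so these plots illustrate qualitative behavior rather than a controlled head-to-head benchmark.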

def plot_optimizer_comparison(histories, names, title="Optimizer Comparison"):
    """Create comprehensive comparison plots."""

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(3, 3, figure=fig)

    # Define color palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(histories)))

    # Training loss
    ax1 = fig.add_subplot(gs[0, 0])
    for hist, name, color in zip(histories, names, colors):
        ax1.plot(hist['train_loss'], label=name, color=color, linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Training Loss')
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Validation loss
    ax2 = fig.add_subplot(gs[0, 1])
    for hist, name, color in zip(histories, names, colors):
        ax2.plot(hist['val_loss'], label=name, color=color, linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Validation Loss')
    ax2.set_title('Validation Loss')
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Training accuracy
    ax3 = fig.add_subplot(gs[0, 2])
    for hist, name, color in zip(histories, names, colors):
        ax3.plot(hist['train_acc'], label=name, color=color, linewidth=2)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Training Accuracy')
    ax3.set_title('Training Accuracy')
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

    # Convergence speed (loss reduction)
    ax4 = fig.add_subplot(gs[1, 0])
    for hist, name, color in zip(histories, names, colors):
        loss_reduction = np.array(hist['train_loss']) / hist['train_loss'][0]
        ax4.plot(loss_reduction, label=name, color=color, linewidth=2)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Loss Reduction Ratio')
    ax4.set_title('Convergence Speed')
    ax4.legend(fontsize=8)
    ax4.grid(True, alpha=0.3)

    # Training time per epoch
    ax5 = fig.add_subplot(gs[1, 1])
    avg_times = [np.mean(hist['epoch_time']) for hist in histories]
    bars = ax5.bar(range(len(names)), avg_times, color=colors)
    ax5.set_xticks(range(len(names)))
    ax5.set_xticklabels(names, rotation=45, ha='right')
    ax5.set_ylabel('Average Time per Epoch (s)')
    ax5.set_title('Training Efficiency')
    ax5.grid(True, alpha=0.3, axis='y')

    # Final performance comparison
    ax6 = fig.add_subplot(gs[1, 2])
    final_train_loss = [hist['train_loss'][-1] for hist in histories]
    final_val_loss = [hist['val_loss'][-1] for hist in histories]

    x = np.arange(len(names))
    width = 0.35

    bars1 = ax6.bar(x - width / 2, final_train_loss, width, label='Train Loss', color='steelblue')
    bars2 = ax6.bar(x + width / 2, final_val_loss, width, label='Val Loss', color='coral')

    ax6.set_xticks(x)
    ax6.set_xticklabels(names, rotation=45, ha='right')
    ax6.set_ylabel('Final Loss')
    ax6.set_title('Final Performance')
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')

    # Loss landscape smoothness (variance of loss)
    ax7 = fig.add_subplot(gs[2, 0])
    for hist, name, color in zip(histories, names, colors):
        # Rolling variance of the training loss over a sliding window
        window = 5
        loss_array = np.array(hist['train_loss'])
        if len(loss_array) >= window:
            rolling_var = np.array([
                np.var(loss_array[i:i + window])
                for i in range(len(loss_array) - window + 1)
            ])
            ax7.plot(rolling_var, label=name, color=color, linewidth=2)
    ax7.set_xlabel('Epoch')
    ax7.set_ylabel('Loss Variance')
    ax7.set_title('Training Stability')
    ax7.legend(fontsize=8)
    ax7.grid(True, alpha=0.3)

    # Memory usage estimate (simplified)
    ax8 = fig.add_subplot(gs[2, 1:])  # Span two columns

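    # Optimizer state per parameter, relative to Adam (two buffers: m and v).
    # Rough estimates; factored and cover-set methods depend on parameter shapes.
    memory_factors = {
        'Lion': 0.5,  # One momentum buffer
        'Adafactor': 0.05,  # Row/column statistics only (assuming no first-moment buffer)
        'SM3': 0.5,  # Momentum buffer plus small cover-set accumulators
        'Rprop': 1.0,  # Per-parameter step size plus previous gradient
        'SGD': 0.5,  # Momentum buffer
        'Adam': 1.0,  # Baseline (first and second moments)
        'RAdam': 1.0,  # Same buffers as Adam
        'Yogi': 1.0,  # Same buffers as Adam
        'Lookahead': 1.0,  # Slow weights plus the inner SGD momentum buffer
        'LAMB': 1.0,  # Same buffers as Adam
        'LARS': 0.5,  # Momentum buffer
        'L-BFGS': 10.0,  # 2 x memory_size stored vectors (memory_size=10)
    }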

    mem_values = [memory_factors.get(name, 1.0) for name in names]
    bars = ax8.barh(range(len(names)), mem_values, color=colors)
    ax8.set_yticks(range(len(names)))
    ax8.set_yticklabels(names)
    ax8.set_xlabel('Relative Memory Usage')
    ax8.set_title('Memory Efficiency Comparison')
    ax8.grid(True, alpha=0.3, axis='x')

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Compare specialized optimizers
specialized_histories = [history_lion, history_adafactor, history_lookahead, history_radam, history_yogi]
specialized_names = ['Lion', 'Adafactor', 'Lookahead', 'RAdam', 'Yogi']

plot_optimizer_comparison(
    specialized_histories,
    specialized_names,
    "Specialized Gradient-Based Optimizers"
)

# Compare large-scale optimizers
largescale_histories = [history_lamb, history_lars, history_sm3]
largescale_names = ['LAMB', 'LARS', 'SM3']

plot_optimizer_comparison(
    largescale_histories,
    largescale_names,
    "Large-Scale Training Optimizers"
)

7. Gradient-Free Optimization

For gradient-free optimization, braintools provides integration with specialized libraries.

7.1 Nevergrad Integration

Nevergrad provides a wide range of gradient-free optimization algorithms. Please refer to the Nevergrad tutorial for details.
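To show what "gradient-free" means in practice, the sketch below implements a minimal (1+1) evolution strategy with the classic one-fifth success rule in NumPy. It only ever evaluates the objective and never computes a gradient. This is an illustrative stand-in, not Nevergrad's API. Nevergrad's optimizers (including its `OnePlusOne`) are considerably more sophisticated, and `one_plus_one_es` is a hypothetical helper.

```python
import numpy as np


def one_plus_one_es(f, x0, sigma=1.0, n_iters=500, seed=0):
    """(1+1) evolution strategy: uses only function values, never gradients."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    fx = f(x)
    for _ in range(n_iters):
        # Propose a random perturbation of the current point
        candidate = x + sigma * rng.normal(size=x.shape)
        f_candidate = f(candidate)
        if f_candidate < fx:
            # Success: accept the candidate and widen the search
            x, fx = candidate, f_candidate
            sigma *= 1.5
        else:
            # Failure: narrow the search (balances at a 1/5 success rate)
            sigma *= 1.5 ** -0.25
    return x, fx


x_best, f_best = one_plus_one_es(lambda z: np.sum((z - 3.0) ** 2), np.zeros(2))
print(x_best, f_best)
```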

7.2 SciPy Optimization

SciPy provides classical optimization algorithms, including constrained optimization. Please refer to the SciPy tutorial for details.
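For a sense of what constrained optimization looks like in SciPy, the sketch below calls `scipy.optimize.minimize` directly with the SLSQP method. It solves a small problem with one equality constraint and one inequality constraint. It does not use braintools' `ScipyOptimizer` wrapper, whose interface is covered in the SciPy tutorial.

```python
import numpy as np
from scipy.optimize import minimize


def objective(z):
    # f(x, y) = x^2 + y^2
    return z[0] ** 2 + z[1] ** 2


constraints = [
    {"type": "eq", "fun": lambda z: z[0] + z[1] - 1.0},  # x + y = 1
    {"type": "ineq", "fun": lambda z: z[0] - 0.7},       # x >= 0.7 (SciPy requires fun >= 0)
]

result = minimize(objective, x0=np.array([2.0, -1.0]), method="SLSQP", constraints=constraints)
print(result.x)
```

Without the inequality constraint, the minimum on the line x + y = 1 would be (0.5, 0.5). The constraint x ≥ 0.7 is active, so the solver returns approximately (0.7, 0.3).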

Summary and Best Practices

Key Takeaways

  1. Memory-Efficient Optimizers

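    • Lion: Keeps a single momentum buffer (half of Adam's optimizer state), which suits very large models under memory constraints

    • Adafactor: Factorized second moments give a good balance of memory and performance

    • SM3: Sublinear accumulator memory through cover sets, well suited to very large models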

  2. Large-Scale Training

    • LAMB/LARS: Keep training stable at very large batch sizes

    • Allow the learning rate to scale up with the batch size where plain SGD or Adam tends to diverge

    • Widely used in large-scale data-parallel (distributed) training

  3. Stability and Robustness

    • RAdam: Rectified variance for stability

    • Lookahead: Reduces variance through averaging

    • Yogi: Additive second-moment updates limit how fast the effective learning rate grows

  4. Alternative Paradigms

    • L-BFGS: Uses curvature information; well suited to small, full-batch problems

    • Rprop: Insensitive to gradient magnitude, but works best with full-batch (low-noise) gradients

    • Gradient-free: When gradients are unavailable or unreliable

When to Use Advanced Optimizers

| Scenario | Recommended Optimizer | Reason |
|---|---|---|
| Large Language Models | Lion, Adafactor | Memory efficiency |
| Distributed Training | LAMB, LARS | Large batch handling |
| Noisy Gradients | RAdam, Lookahead | Stability |
| Small Dataset | L-BFGS | Fast convergence |
| Research/Experimentation | Yogi, Custom | Novel behaviors |
| Constrained Optimization | ScipyOptimizer | Built-in constraints |
| Black-box Optimization | NevergradOptimizer | No gradients needed |

Exercises

  1. Memory Comparison: Train the same large model with Adam, Lion, and Adafactor. Monitor and compare memory usage.

  2. Large Batch Scaling: Test how well different optimizers handle increasing batch sizes from 32 to 1024.

  3. Stability Analysis: Add artificial noise to gradients and compare optimizer robustness.

  4. Hybrid Approach: Implement a training schedule that switches optimizers (e.g., Adam → L-BFGS for fine-tuning).

  5. Custom Optimizer: Create your own optimizer by combining ideas from different methods.

  6. Constraint Satisfaction: Use ScipyOptimizer to solve a constrained optimization problem in neural network training.

  7. Hyperparameter Optimization: Use NevergradOptimizer to tune the hyperparameters of another optimizer.