Tutorial 5: Advanced Optimizers and Techniques#
Difficulty: Advanced
Duration: 40-50 minutes
Prerequisites: Completion of Tutorials 3 and 4
Learning Objectives#
Use specialized optimizers for specific scenarios
Implement second-order optimization methods
Apply gradient-free optimization
Understand memory-efficient optimizers
Topics Covered#
Specialized gradient-based optimizers
Lion: Memory-efficient optimizer
Adafactor: Factorized second moments
Lookahead: k-step forward optimization
RAdam: Rectified Adam
Large-scale training optimizers
LAMB: Layer-wise adaptive large batch
LARS: Layer-wise adaptive rate scaling
SM3: Memory-efficient for large models
Alternative optimization paradigms
LBFGS: Quasi-Newton method
Rprop: Resilient backpropagation
Yogi: Additive adaptive methods
Gradient-free optimization
NevergradOptimizer integration
ScipyOptimizer for constrained problems
import time
import brainstate
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec
import braintools
1. Setting up Test Models and Data#
We’ll create different model architectures to test various optimizer characteristics.
class TransformerBlock(brainstate.nn.Module):
    """Simplified Transformer block for testing large-scale optimizers."""

    def __init__(self, dim=512, num_heads=8, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Multi-head attention components
        self.qkv = brainstate.nn.Linear(dim, dim * 3)
        self.proj = brainstate.nn.Linear(dim, dim)
        # MLP components
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.fc1 = brainstate.nn.Linear(dim, mlp_hidden_dim)
        self.fc2 = brainstate.nn.Linear(mlp_hidden_dim, dim)
        # Layer norms
        self.norm1 = brainstate.nn.LayerNorm(dim)
        self.norm2 = brainstate.nn.LayerNorm(dim)

    def __call__(self, x):
        # Simplified attention (without actual attention computation)
        residual = x
        x = self.norm1(x)
        # QKV projection
        qkv = self.qkv(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)
        # Simplified attention output (just use v for demonstration)
        attn_output = self.proj(v)
        x = residual + attn_output
        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.fc1(x)
        x = jax.nn.gelu(x)
        x = self.fc2(x)
        x = residual + x
        return x
class CNNModel(brainstate.nn.Module):
    """CNN for testing memory-efficient optimizers."""

    def __init__(self, in_size, num_classes=10):
        super().__init__()
        # Conv layers
        self.conv1 = brainstate.nn.Conv2d(in_size, 64, kernel_size=3, padding=1)
        self.pool1 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv1.out_size)
        self.conv2 = brainstate.nn.Conv2d(self.pool1.out_size, 128, kernel_size=3, padding=1)
        self.pool2 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv2.out_size)
        self.conv3 = brainstate.nn.Conv2d(self.pool2.out_size, 256, kernel_size=3, padding=1)
        self.pool3 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv3.out_size)
        # Dense layers
        self.fc1 = brainstate.nn.Linear(int(np.prod(self.pool3.out_size)), 512)
        self.fc2 = brainstate.nn.Linear(512, num_classes)

    def __call__(self, x):
        # Reshape if needed
        if len(x.shape) == 2:
            x = x.reshape(-1, 32, 32, 3)
        # Conv blocks
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = jax.nn.relu(x)
        x = self.pool3(x)
        # Flatten and FC layers
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x
class SimpleRNN(brainstate.nn.Module):
    """Simple RNN for testing gradient stability."""

    def __init__(self, input_dim=10, hidden_dim=128, output_dim=10):
        super().__init__()
        self.rnn = brainstate.nn.ValinaRNNCell(input_dim, hidden_dim, num_layers=2)
        self.fc = brainstate.nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        # x shape: (batch, seq_len, features)
        # for_loop scans over the leading axis, so move the time axis to the front
        x = jnp.swapaxes(x, 0, 1)  # (seq_len, batch, features)
        outputs = brainstate.transform.for_loop(self.rnn, x)
        # Use last timestep
        return self.fc(outputs[-1])
def create_synthetic_data(data_type='vision', n_samples=1000, seed=42):
    """Create synthetic data for different model types."""
    with brainstate.random.seed_context(seed):
        if data_type == 'vision':
            # Image-like data (32x32x3)
            X = brainstate.random.normal(size=(n_samples, 32, 32, 3)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'transformer':
            # Sequence data for transformer (seq_len=64, dim=512)
            X = brainstate.random.normal(size=(n_samples, 64, 512)) * 0.1
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'sequence':
            # Sequence data for RNN (seq_len=20, features=10)
            X = brainstate.random.normal(size=(n_samples, 20, 10)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        else:
            # Default: flat features
            X = brainstate.random.normal(size=(n_samples, 784)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
    return X, y
# Create datasets
X_vision, y_vision = create_synthetic_data('vision', n_samples=2000)
X_transformer, y_transformer = create_synthetic_data('transformer', n_samples=1000)
X_sequence, y_sequence = create_synthetic_data('sequence', n_samples=2000)
print(f"Vision data shape: {X_vision.shape}")
print(f"Transformer data shape: {X_transformer.shape}")
print(f"Sequence data shape: {X_sequence.shape}")
2. Gradient Computation and Training Infrastructure#
Following the style from previous tutorials, we’ll set up our gradient computation.
def compute_loss_and_grads(model, X, y, param_states, loss_type='classification'):
    """Compute loss and gradients following braintools style."""

    def loss_fn():
        # Forward pass
        outputs = model(X)
        if loss_type == 'classification':
            # Cross-entropy loss
            log_probs = jax.nn.log_softmax(outputs, axis=-1)
            one_hot = jax.nn.one_hot(y, num_classes=10)
            loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        else:
            # MSE loss for regression
            loss = jnp.mean((outputs - y) ** 2)
        # Add L2 regularization
        l2_reg = 1e-4
        for state in param_states.values():
            loss = loss + l2_reg * jnp.sum(state.value ** 2)
        return loss

    # Compute loss and gradients
    loss = loss_fn()
    grads = brainstate.transform.grad(loss_fn, grad_states=param_states)()
    # Compute accuracy for classification
    if loss_type == 'classification':
        outputs = model(X)
        predictions = jnp.argmax(outputs, axis=-1)
        accuracy = jnp.mean(predictions == y)
    else:
        accuracy = -loss  # Use negative loss as metric for regression
    return loss, grads, accuracy
def train_with_optimizer(
    model: brainstate.nn.Module,
    optimizer: braintools.optim.OptaxOptimizer,
    X_train, y_train,
    X_val, y_val,
    n_epochs=30,
    batch_size=64,
    verbose=True
):
    """Generic training function for any optimizer."""
    # Get parameter states
    param_states = braintools.optim.UniqueStateManager(
        model.states(brainstate.ParamState)
    ).to_pytree()
    # Register parameters with optimizer
    optimizer.register_trainable_weights(param_states)

    @brainstate.transform.jit
    def train_step(X_batch, y_batch):
        loss, grads, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        optimizer.update(grads)
        return loss, acc

    @brainstate.transform.jit
    def eval_step(X_batch, y_batch):
        loss, _, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        return loss, acc

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'epoch_time': []
    }
    n_batches = len(X_train) // batch_size
    for epoch in range(n_epochs):
        epoch_start = time.time()
        # Shuffle data
        perm = brainstate.random.permutation(len(X_train))
        X_train_shuffled = X_train[perm]
        y_train_shuffled = y_train[perm]
        train_losses = []
        train_accs = []
        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            X_batch = X_train_shuffled[start_idx:end_idx]
            y_batch = y_train_shuffled[start_idx:end_idx]
            loss, acc = train_step(X_batch, y_batch)
            train_losses.append(float(loss))
            train_accs.append(float(acc))
        # Validation
        val_loss, val_acc = eval_step(X_val[:500], y_val[:500])  # Use subset for speed
        # Update learning rate if scheduler is attached
        optimizer.lr.step()
        # Record metrics
        history['train_loss'].append(np.mean(train_losses))
        history['train_acc'].append(np.mean(train_accs))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))
        history['epoch_time'].append(time.time() - epoch_start)
        if verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{n_epochs} - "
                  f"Loss: {history['train_loss'][-1]:.4f}, "
                  f"Acc: {history['train_acc'][-1]:.4f}, "
                  f"Val Loss: {history['val_loss'][-1]:.4f}, "
                  f"Val Acc: {history['val_acc'][-1]:.4f}")
    return history
3. Specialized Gradient-Based Optimizers#
Let’s explore advanced optimizers designed for specific scenarios.
3.1 Lion Optimizer - Memory Efficient#
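Lion (EvoLved Sign Momentum) keeps a single momentum buffer per parameter and applies only the sign of the update, so its optimizer state is about half the size of Adam's.
# Lion optimizer
model_lion = CNNModel(in_size=(32, 32, 3))  # CNNModel requires the input size; (H, W, C) matches the vision data
lion_optimizer = braintools.optim.Lion(
    lr=3e-4,  # Lion typically uses a smaller learning rate than Adam
    betas=(0.9, 0.99),
    weight_decay=1e-4
)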
print("Training with Lion optimizer (memory-efficient)...")
history_lion = train_with_optimizer(
model_lion, lion_optimizer,
X_vision[:1000], y_vision[:1000],
X_vision[1000:1500], y_vision[1000:1500],
n_epochs=30, batch_size=32
)
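To make the sign-update idea concrete, here is a minimal NumPy sketch of a single Lion step, following the update rule as it is usually stated for Lion. The function name `lion_step` and its arguments are illustrative and are not part of the braintools API; `braintools.optim.Lion` may differ in details such as how weight decay is applied.

```python
import numpy as np

def lion_step(param, grad, momentum, lr=3e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion update (illustrative sketch, not the braintools implementation)."""
    # Direction: sign of an interpolation between the momentum and the current gradient
    direction = np.sign(beta1 * momentum + (1.0 - beta1) * grad)
    # Every coordinate moves by exactly lr (plus decoupled weight decay)
    new_param = param - lr * (direction + weight_decay * param)
    # The momentum is the only optimizer state that is carried forward
    new_momentum = beta2 * momentum + (1.0 - beta2) * grad
    return new_param, new_momentum

param = np.array([1.0, -2.0, 0.5])
grad = np.array([0.3, -10.0, 0.0])
momentum = np.zeros_like(param)
new_param, new_momentum = lion_step(param, grad, momentum, lr=0.1)
print(new_param)
```

Note that the coordinate with gradient 0.3 and the one with gradient -10.0 move by the same amount: only the sign of the update matters, which is why Lion is usually run with a smaller learning rate than Adam.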
3.2 Adafactor - Factorized Second Moments#
Adafactor reduces memory usage by storing only per-row and per-column statistics of the second-moment estimate for each weight matrix, instead of one value per parameter.
# Adafactor optimizer
model_adafactor = TransformerBlock()
adafactor_optimizer = braintools.optim.Adafactor(
lr=1e-3,
decay_rate=0.8,
factored=True, # Enable factorization for memory efficiency
clip_threshold=1.0
)
print("Training with Adafactor (factorized second moments)...")
history_adafactor = train_with_optimizer(
model_adafactor, adafactor_optimizer,
X_transformer[:500], y_transformer[:500],
X_transformer[500:700], y_transformer[500:700],
n_epochs=30, batch_size=16
)
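The factorization can be illustrated with a short NumPy sketch. It reconstructs a second-moment matrix from its row and column means, which is the core idea behind Adafactor's factored mode. The helper name `factored_second_moment` is hypothetical, and the sketch omits the exponential averaging, update clipping, and relative step sizes that a full implementation would include.

```python
import numpy as np

def factored_second_moment(grad_sq):
    """Rank-1 reconstruction of a squared-gradient matrix from row/column means."""
    row_mean = grad_sq.mean(axis=1)  # shape (n_rows,)
    col_mean = grad_sq.mean(axis=0)  # shape (n_cols,)
    # The outer product, normalised by the overall mean, recovers the matrix
    # exactly when grad_sq has rank 1 and approximately otherwise
    return np.outer(row_mean, col_mean) / row_mean.mean()

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.5, size=4)
b = rng.uniform(0.5, 1.5, size=6)
grad_sq = np.outer(a, b)  # a rank-1 example, chosen so the reconstruction is exact
approx = factored_second_moment(grad_sq)

stored_full = grad_sq.size  # values a full second-moment buffer would hold
stored_factored = grad_sq.shape[0] + grad_sq.shape[1]  # values the factored form holds
print(stored_full, stored_factored)
```

For an n-by-m weight matrix the factored state needs n + m values instead of n * m, which is where the memory saving for large Transformer layers comes from.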
3.3 Lookahead Optimizer - k-step Forward#
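Lookahead maintains two sets of weights: the fast weights take k steps with a base optimizer, then the slow weights move a fraction alpha of the way toward them and the fast weights restart from the slow weights.
# Lookahead optimizer wrapping SGD
model_lookahead = CNNModel(in_size=(32, 32, 3))  # CNNModel requires the input size; (H, W, C) matches the vision data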
# Base optimizer
base_optimizer = braintools.optim.SGD(lr=0.1, momentum=0.9)
# Wrap with Lookahead
lookahead_optimizer = braintools.optim.Lookahead(
base_optimizer,
sync_period=5, # Update slow weights every 5 steps
alpha=0.5 # Interpolation factor
)
print("Training with Lookahead optimizer (k-step forward)...")
history_lookahead = train_with_optimizer(
model_lookahead, lookahead_optimizer,
X_vision[:1000], y_vision[:1000],
X_vision[1000:1500], y_vision[1000:1500],
n_epochs=30, batch_size=32
)
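The interaction between fast and slow weights can be sketched in a few lines of NumPy. This is a minimal illustration that uses plain gradient descent as the inner optimizer; `lookahead_sgd` and `grad_fn` are hypothetical names, and the braintools `Lookahead` wrapper may organise its state differently.

```python
import numpy as np

def lookahead_sgd(grad_fn, init, lr=0.1, sync_period=5, alpha=0.5, n_steps=50):
    """Lookahead around plain gradient descent (illustrative sketch)."""
    slow = np.array(init, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        # Inner optimizer: ordinary gradient step on the fast weights
        fast = fast - lr * grad_fn(fast)
        if step % sync_period == 0:
            # Slow weights move a fraction alpha toward the fast weights
            slow = slow + alpha * (fast - slow)
            # Fast weights restart from the updated slow weights
            fast = slow.copy()
    return slow

# Quadratic bowl f(w) = ||w||^2 with gradient 2w
final = lookahead_sgd(lambda w: 2.0 * w, init=[4.0, -3.0])
print(final)
```

On this quadratic each inner step scales the fast weights by 0.8, so each synchronisation scales the slow weights by 0.5 + 0.5 * 0.8**5.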
3.4 RAdam - Rectified Adam#
RAdam rectifies the variance of the adaptive learning rate to stabilize training.
# RAdam optimizer
model_radam = SimpleRNN()
radam_optimizer = braintools.optim.RAdam(
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-4
)
print("Training with RAdam (Rectified Adam)...")
history_radam = train_with_optimizer(
model_radam, radam_optimizer,
X_sequence[:1000], y_sequence[:1000],
X_sequence[1000:1500], y_sequence[1000:1500],
n_epochs=30, batch_size=32
)
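The rectification itself is a scalar that depends only on the step count and `beta2`. The sketch below computes it using the formula from the RAdam paper, with the paper's threshold of 4 (some implementations use 5). The function name `radam_rectification` is illustrative and not part of braintools.

```python
import math

def radam_rectification(t, beta2=0.999):
    """Return the RAdam rectification factor at step t, or None if the
    adaptive term is skipped because the variance is not yet tractable."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return None  # fall back to an SGD-with-momentum style update
    numerator = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    denominator = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(numerator / denominator)

for t in (1, 5, 100, 10000):
    print(t, radam_rectification(t))
```

During the first few steps the factor is undefined and the adaptive denominator is not used; afterwards it grows toward 1, so RAdam behaves like Adam with a built-in warmup.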
4. Large-Scale Training Optimizers#
These optimizers are designed for training with large batch sizes and distributed settings.
4.1 LAMB - Layer-wise Adaptive Large Batch#
LAMB enables large batch training by adapting the learning rate per layer.
# LAMB optimizer for large batch training
model_lamb = TransformerBlock()
lamb_optimizer = braintools.optim.Lamb(
lr=2e-3,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
grad_clip_value=10.0 # Gradient clipping
)
print("Training with LAMB (Large Batch optimizer)...")
# Simulate large batch by using larger batch size
history_lamb = train_with_optimizer(
model_lamb, lamb_optimizer,
X_transformer[:800], y_transformer[:800],
X_transformer[800:], y_transformer[800:],
n_epochs=30, batch_size=128 # Large batch size
)
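A minimal NumPy sketch of one LAMB step for a single layer shows what "adapting the learning rate per layer" means in practice. It follows the commonly stated form of the algorithm (an Adam direction rescaled by a trust ratio); `lamb_step` is a hypothetical helper, and gradient clipping as configured above is left out.

```python
import numpy as np

def lamb_step(w, g, m, v, t, lr=2e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
    """One LAMB update for a single layer (illustrative sketch)."""
    beta1, beta2 = betas
    # Adam-style moment estimates with bias correction
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w
    # Layer-wise trust ratio: rescale the update to the size of the weights
    w_norm, u_norm = np.linalg.norm(w), np.linalg.norm(update)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return w - lr * trust_ratio * update, m, v

w = np.array([0.5, -1.0, 2.0])
g_small = np.array([1e-3, -2e-3, 5e-4])
zeros = np.zeros_like(w)
w_a, _, _ = lamb_step(w, g_small, zeros, zeros, t=1)
w_b, _, _ = lamb_step(w, 1000 * g_small, zeros, zeros, t=1)
step_a = np.linalg.norm(w_a - w)
step_b = np.linalg.norm(w_b - w)
print(step_a, step_b)
```

Both steps have norm `lr * ||w||` even though the gradients differ by a factor of 1000, which is the property that keeps large-batch training stable across layers with very different gradient scales.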
4.2 LARS - Layer-wise Adaptive Rate Scaling#
LARS adapts the learning rate for each layer based on the ratio of weight and gradient norms.
# LARS optimizer
model_lars = CNNModel(in_size=(32, 32, 3))  # CNNModel requires the input size; (H, W, C) matches the vision data
lars_optimizer = braintools.optim.Lars(
lr=0.1,
momentum=0.9,
weight_decay=1e-4,
trust_coefficient=0.001, # LARS-specific parameter
eps=1e-8
)
print("Training with LARS (Layer-wise Adaptive Rate Scaling)...")
history_lars = train_with_optimizer(
model_lars, lars_optimizer,
X_vision[:1000], y_vision[:1000],
X_vision[1000:1500], y_vision[1000:1500],
n_epochs=30, batch_size=128
)
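The ratio of weight and gradient norms mentioned above can be written down directly. The sketch below computes the layer-wise learning-rate multiplier in the form given in the LARS paper; `lars_local_lr` is an illustrative name, and the exact placement of `eps` and weight decay in braintools may differ.

```python
import numpy as np

def lars_local_lr(weights, grads, trust_coefficient=0.001, weight_decay=1e-4, eps=1e-8):
    """Layer-wise learning-rate multiplier used by LARS (illustrative sketch)."""
    w_norm = np.linalg.norm(weights)
    g_norm = np.linalg.norm(grads)
    return trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)

weights = np.ones(4)
grads = np.array([0.1, 0.2, 0.3, 0.4])

lr_small_grad = lars_local_lr(weights, grads)
lr_large_grad = lars_local_lr(weights, 1000.0 * grads)
ratio = lr_small_grad / lr_large_grad
print(lr_small_grad, lr_large_grad, ratio)
```

A layer whose gradient is 1000 times larger receives a multiplier roughly 1000 times smaller, so the size of the resulting weight change stays proportional to the size of the weights.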
4.3 SM3 - Memory-Efficient for Large Models#
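SM3 approximates Adagrad-style per-parameter learning rates using accumulators shared across the rows and columns of each tensor, so the optimizer state is much smaller than the parameter count.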
# SM3 optimizer for memory efficiency
model_sm3 = TransformerBlock()
sm3_optimizer = braintools.optim.SM3(
lr=1e-3,
momentum=0.9,
eps=1e-8
)
print("Training with SM3 (Memory-efficient optimizer)...")
history_sm3 = train_with_optimizer(
model_sm3, sm3_optimizer,
X_transformer[:500], y_transformer[:500],
X_transformer[500:700], y_transformer[500:700],
n_epochs=30, batch_size=16
)
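As a rough illustration of the shared accumulators, the following NumPy sketch applies one SM3-style step to a weight matrix, keeping one accumulator per row and one per column. It follows the SM3-II variant as described in the SM3 paper, without momentum; `sm3_step` is a hypothetical helper and is not the braintools implementation.

```python
import numpy as np

def sm3_step(param, grad, row_acc, col_acc, lr=0.1, eps=1e-8):
    """One SM3-II style update for a 2-D parameter (illustrative sketch)."""
    # Per-entry estimate: the tighter of the row and column accumulators,
    # plus the new squared gradient
    nu = np.minimum(row_acc[:, None], col_acc[None, :]) + grad ** 2
    new_param = param - lr * grad / (np.sqrt(nu) + eps)
    # Only the row-wise and column-wise maxima are stored for the next step
    return new_param, nu.max(axis=1), nu.max(axis=0)

rng = np.random.default_rng(0)
param = np.ones((3, 4))
grad = rng.normal(size=(3, 4))
row_acc, col_acc = np.zeros(3), np.zeros(4)

new_param, new_row_acc, new_col_acc = sm3_step(param, grad, row_acc, col_acc)
print(new_row_acc.size + new_col_acc.size, "state values for", param.size, "parameters")
```

Here the state holds 3 + 4 values for 12 parameters; the gap widens quickly for the large embedding and projection matrices that SM3 targets.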
5. Alternative Optimization Paradigms#
These optimizers use different principles than standard gradient descent.
5.1 L-BFGS - Quasi-Newton Method#
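L-BFGS approximates the inverse Hessian from a short history of recent parameter and gradient changes, giving second-order-like steps without ever forming the full Hessian matrix.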
# L-BFGS optimizer (Note: requires special handling)
from brainstate.nn import Linear
class SimpleMLP(brainstate.nn.Module):
    """Simple MLP for L-BFGS testing."""

    def __init__(self):
        super().__init__()
        self.fc1 = Linear(784, 128)
        self.fc2 = Linear(128, 10)

    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x
model_lbfgs = SimpleMLP()
# L-BFGS requires full-batch training
lbfgs_optimizer = braintools.optim.LBFGS(
lr=1.0,
memory_size=10,
line_search_fn='zoom'
)
print("Training with L-BFGS (Quasi-Newton method)...")
# Note: L-BFGS typically works better with full-batch
X_small = X_vision[:200].reshape(200, -1)
y_small = y_vision[:200]
X_val_small = X_vision[1000:1100].reshape(100, -1)
y_val_small = y_vision[1000:1100]
history_lbfgs = train_with_optimizer(
model_lbfgs, lbfgs_optimizer,
X_small, y_small,
X_val_small, y_val_small,
n_epochs=20, batch_size=200 # Full batch
)
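The way L-BFGS turns its history into a search direction is the classic two-loop recursion. The NumPy sketch below is a minimal version of that recursion only; `lbfgs_direction` is an illustrative name, and a full optimizer (including the one in braintools) also needs a line search and history management such as the `memory_size` limit.

```python
import numpy as np

def lbfgs_direction(grad, s_history, y_history):
    """Two-loop recursion: approximate -H^{-1} grad from recent (s, y) pairs,
    where s = change in parameters and y = change in gradients."""
    q = grad.copy()
    alphas = []
    # First loop: newest pair to oldest
    for s, y in zip(reversed(s_history), reversed(y_history)):
        rho = 1.0 / np.dot(y, s)
        alpha = rho * np.dot(s, q)
        q = q - alpha * y
        alphas.append((rho, alpha))
    # Initial inverse-Hessian scaling from the most recent pair
    if s_history:
        s, y = s_history[-1], y_history[-1]
        q = q * (np.dot(s, y) / np.dot(y, y))
    # Second loop: oldest pair to newest
    for (rho, alpha), s, y in zip(reversed(alphas), s_history, y_history):
        beta = rho * np.dot(y, q)
        q = q + (alpha - beta) * s
    return -q

# Quadratic with Hessian A: a parameter change s produces a gradient change y = A s
A = np.diag([1.0, 10.0])
s = np.array([1.0, 0.5])
y = A @ s
direction = lbfgs_direction(y, [s], [y])
print(direction)
```

With one stored pair the recursion already satisfies the secant condition, mapping the gradient change `y` back to `-s`; with an empty history it reduces to plain steepest descent.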
5.2 Rprop - Resilient Backpropagation#
Rprop uses only the sign of the gradient and adapts step sizes individually.
# Rprop optimizer
model_rprop = SimpleMLP()
rprop_optimizer = braintools.optim.Rprop(
lr=1e-3,
etas=(0.5, 1.2), # Step size adaptation factors
step_sizes=(1e-6, 50) # Min and max step sizes
)
print("Training with Rprop (Resilient Backpropagation)...")
history_rprop = train_with_optimizer(
model_rprop, rprop_optimizer,
X_small, y_small,
X_val_small, y_val_small,
n_epochs=30, batch_size=32
)
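The per-parameter step adaptation can be sketched in NumPy as follows. This follows the common variant that skips the update for a coordinate whose gradient sign has just flipped; `rprop_step` is a hypothetical helper whose `etas` and `step_sizes` arguments mirror the configuration above, and it is not taken from braintools.

```python
import numpy as np

def rprop_step(param, grad, prev_grad, step, etas=(0.5, 1.2), step_sizes=(1e-6, 50.0)):
    """One Rprop update (illustrative sketch)."""
    eta_minus, eta_plus = etas
    agreement = grad * prev_grad
    # Same sign as last time: grow the step; sign flipped: shrink it
    step = np.where(agreement > 0, step * eta_plus, step)
    step = np.where(agreement < 0, step * eta_minus, step)
    step = np.clip(step, step_sizes[0], step_sizes[1])
    # After a sign flip, skip the update for that coordinate
    grad = np.where(agreement < 0, 0.0, grad)
    new_param = param - step * np.sign(grad)
    return new_param, grad, step

param = np.array([1.0, 1.0])
grad = np.array([0.5, -2.0])
prev_grad = np.array([0.1, 3.0])
step = np.array([0.1, 0.1])
new_param, stored_grad, new_step = rprop_step(param, grad, prev_grad, step)
print(new_param, new_step)
```

The first coordinate keeps its gradient sign, so its step grows from 0.1 to 0.12; the second flips sign, so its step halves and it is not moved this iteration. Gradient magnitudes never enter the update.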
5.3 Yogi - Additive Adaptive Methods#
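Yogi updates its second-moment estimate additively rather than with Adam's multiplicative exponential average, which keeps the effective learning rate from growing too quickly when gradients become small.
# Yogi optimizer
model_yogi = CNNModel(in_size=(32, 32, 3))  # CNNModel requires the input size; (H, W, C) matches the vision data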
yogi_optimizer = braintools.optim.Yogi(
lr=1e-2,
betas=(0.9, 0.999),
eps=1e-3 # Yogi typically uses larger epsilon
)
print("Training with Yogi (Additive adaptive method)...")
history_yogi = train_with_optimizer(
model_yogi, yogi_optimizer,
X_vision[:1000], y_vision[:1000],
X_vision[1000:1500], y_vision[1000:1500],
n_epochs=30, batch_size=32
)
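The difference from Adam is confined to the second-moment update, which the following sketch compares side by side using the rule from the Yogi paper. The function names are illustrative only.

```python
import numpy as np

def adam_second_moment(v, grad, beta2=0.999):
    """Adam: exponential moving average of squared gradients."""
    return beta2 * v + (1.0 - beta2) * grad ** 2

def yogi_second_moment(v, grad, beta2=0.999):
    """Yogi: move v toward grad**2 by an amount proportional to grad**2."""
    g2 = grad ** 2
    return v - (1.0 - beta2) * np.sign(v - g2) * g2

# A large accumulated estimate followed by a small gradient
v, grad = 1.0, 0.1
v_adam = adam_second_moment(v, grad)
v_yogi = yogi_second_moment(v, grad)
print(v_adam, v_yogi)
```

When the gradient suddenly shrinks, Adam's estimate decays in proportion to `v`, while Yogi's decreases only in proportion to the new squared gradient, so Yogi's estimate shrinks more slowly and the effective step size rises more gently.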
6. Comparing Optimizer Performance#
Let’s visualize and compare the performance of different optimizer categories.
def plot_optimizer_comparison(histories, names, title="Optimizer Comparison"):
    """Create comprehensive comparison plots."""
    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(3, 3, figure=fig)
    # Define color palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(histories)))

    # Training loss
    ax1 = fig.add_subplot(gs[0, 0])
    for hist, name, color in zip(histories, names, colors):
        ax1.plot(hist['train_loss'], label=name, color=color, linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Training Loss')
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Validation loss
    ax2 = fig.add_subplot(gs[0, 1])
    for hist, name, color in zip(histories, names, colors):
        ax2.plot(hist['val_loss'], label=name, color=color, linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Validation Loss')
    ax2.set_title('Validation Loss')
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Training accuracy
    ax3 = fig.add_subplot(gs[0, 2])
    for hist, name, color in zip(histories, names, colors):
        ax3.plot(hist['train_acc'], label=name, color=color, linewidth=2)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Training Accuracy')
    ax3.set_title('Training Accuracy')
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

    # Convergence speed (loss reduction)
    ax4 = fig.add_subplot(gs[1, 0])
    for hist, name, color in zip(histories, names, colors):
        loss_reduction = np.array(hist['train_loss']) / hist['train_loss'][0]
        ax4.plot(loss_reduction, label=name, color=color, linewidth=2)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Loss Reduction Ratio')
    ax4.set_title('Convergence Speed')
    ax4.legend(fontsize=8)
    ax4.grid(True, alpha=0.3)

    # Training time per epoch
    ax5 = fig.add_subplot(gs[1, 1])
    avg_times = [np.mean(hist['epoch_time']) for hist in histories]
    bars = ax5.bar(range(len(names)), avg_times, color=colors)
    ax5.set_xticks(range(len(names)))
    ax5.set_xticklabels(names, rotation=45, ha='right')
    ax5.set_ylabel('Average Time per Epoch (s)')
    ax5.set_title('Training Efficiency')
    ax5.grid(True, alpha=0.3, axis='y')

    # Final performance comparison
    ax6 = fig.add_subplot(gs[1, 2])
    final_train_loss = [hist['train_loss'][-1] for hist in histories]
    final_val_loss = [hist['val_loss'][-1] for hist in histories]
    x = np.arange(len(names))
    width = 0.35
    bars1 = ax6.bar(x - width / 2, final_train_loss, width, label='Train Loss', color='steelblue')
    bars2 = ax6.bar(x + width / 2, final_val_loss, width, label='Val Loss', color='coral')
    ax6.set_xticks(x)
    ax6.set_xticklabels(names, rotation=45, ha='right')
    ax6.set_ylabel('Final Loss')
    ax6.set_title('Final Performance')
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')

    # Loss landscape smoothness (variance of loss)
    ax7 = fig.add_subplot(gs[2, 0])
    for hist, name, color in zip(histories, names, colors):
        # Calculate rolling variance
        window = 5
        loss_array = np.array(hist['train_loss'])
        if len(loss_array) >= window:
            rolling_var = np.convolve(
                (loss_array - np.mean(loss_array)) ** 2,
                np.ones(window) / window,
                mode='valid'
            )
            ax7.plot(rolling_var, label=name, color=color, linewidth=2)
    ax7.set_xlabel('Epoch')
    ax7.set_ylabel('Loss Variance')
    ax7.set_title('Training Stability')
    ax7.legend(fontsize=8)
    ax7.grid(True, alpha=0.3)

    # Memory usage estimate (simplified)
    ax8 = fig.add_subplot(gs[2, 1:])  # Span two columns
    # Optimizer memory footprint (relative estimates)
    memory_factors = {
        'Lion': 0.5,  # Sign-based, very memory efficient
        'Adafactor': 0.6,  # Factorized moments
        'SM3': 0.7,  # Sparse second moments
        'Rprop': 0.8,  # Only step sizes
        'SGD': 0.9,  # Momentum only
        'Adam': 1.0,  # Baseline (first and second moments)
        'RAdam': 1.0,  # Same as Adam
        'Yogi': 1.0,  # Similar to Adam
        'Lookahead': 1.5,  # Two sets of weights
        'LAMB': 1.2,  # Layer-wise adaptation
        'LARS': 1.1,  # Layer-wise scaling
        'L-BFGS': 2.0,  # History of gradients
    }
    mem_values = [memory_factors.get(name, 1.0) for name in names]
    bars = ax8.barh(range(len(names)), mem_values, color=colors)
    ax8.set_yticks(range(len(names)))
    ax8.set_yticklabels(names)
    ax8.set_xlabel('Relative Memory Usage')
    ax8.set_title('Memory Efficiency Comparison')
    ax8.grid(True, alpha=0.3, axis='x')

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
# Compare specialized optimizers
specialized_histories = [history_lion, history_adafactor, history_radam, history_yogi]
specialized_names = ['Lion', 'Adafactor', 'RAdam', 'Yogi']
plot_optimizer_comparison(
specialized_histories,
specialized_names,
"Specialized Gradient-Based Optimizers"
)
# Compare large-scale optimizers
largescale_histories = [history_lamb, history_lars, history_sm3]
largescale_names = ['LAMB', 'LARS', 'SM3']
plot_optimizer_comparison(
largescale_histories,
largescale_names,
"Large-Scale Training Optimizers"
)
7. Gradient-Free Optimization#
For gradient-free optimization, braintools provides integration with specialized libraries.
7.1 Nevergrad Integration#
Nevergrad provides a wide range of gradient-free optimization algorithms. Please refer to the Nevergrad tutorial documentation for details.
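To give a feel for what "gradient-free" means, here is a small NumPy sketch of a greedy random search (a simple (1+1) evolution strategy with fixed step size). It uses only loss values, never gradients. This is a generic illustration and does not use the Nevergrad or `NevergradOptimizer` API; the names `random_search` and `sphere` are hypothetical.

```python
import numpy as np

def random_search(loss_fn, init, sigma=0.3, n_iters=500, seed=0):
    """Greedy random search: keep a perturbation only if it lowers the loss."""
    rng = np.random.default_rng(seed)
    best = np.array(init, dtype=float)
    best_loss = loss_fn(best)
    for _ in range(n_iters):
        # Propose a Gaussian perturbation of the current best point
        candidate = best + sigma * rng.normal(size=best.shape)
        candidate_loss = loss_fn(candidate)
        if candidate_loss < best_loss:
            best, best_loss = candidate, candidate_loss
    return best, best_loss

def sphere(w):
    """Toy objective treated as a black box."""
    return float(np.sum(w ** 2))

initial_loss = sphere(np.array([3.0, -2.0]))
best, best_loss = random_search(sphere, init=[3.0, -2.0])
print(initial_loss, best_loss)
```

Libraries such as Nevergrad provide far more sample-efficient strategies, but the contract is the same: the optimizer only ever sees parameter values and the resulting loss.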
7.2 SciPy Optimization#
SciPy provides classical optimization algorithms, including constrained optimization. Please refer to the SciPy tutorial documentation for details.
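As a small standalone illustration of constrained optimization, the sketch below calls `scipy.optimize.minimize` directly with the SLSQP method, an equality constraint, and non-negativity bounds. It does not go through `ScipyOptimizer`, whose interface is not shown in this tutorial; the objective and constraint are toy examples chosen so the answer can be checked by hand.

```python
import numpy as np
from scipy.optimize import minimize

def objective(w):
    """Squared distance to the point (1.0, 2.5)."""
    return (w[0] - 1.0) ** 2 + (w[1] - 2.5) ** 2

# Equality constraint w0 + w1 = 2, plus bounds w0 >= 0 and w1 >= 0
constraints = [{'type': 'eq', 'fun': lambda w: w[0] + w[1] - 2.0}]
bounds = [(0.0, None), (0.0, None)]

result = minimize(
    objective,
    x0=np.array([1.0, 1.0]),
    method='SLSQP',
    bounds=bounds,
    constraints=constraints,
)
print(result.x, result.fun)
```

The solution is the projection of (1.0, 2.5) onto the line w0 + w1 = 2, which is (0.25, 1.75) and lies inside the bounds.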
Summary and Best Practices#
Key Takeaways
Memory-Efficient Optimizers
Lion: Best for very large models with memory constraints
Adafactor: Good balance of memory and performance
SM3: Excellent for sparse models
Large-Scale Training
LAMB/LARS: Essential for large batch training
Enable scaling the learning rate linearly with the batch size
Critical for distributed training
Stability and Robustness
RAdam: Rectified variance for stability
Lookahead: Reduces variance through averaging
Yogi: Additive updates for better convergence
Alternative Paradigms
L-BFGS: Excellent for small datasets with second-order information
Rprop: Robust to gradient noise
Gradient-free: When gradients are unavailable or unreliable
When to Use Advanced Optimizers
| Scenario | Recommended Optimizer | Reason |
|---|---|---|
| Large Language Models | Lion, Adafactor | Memory efficiency |
| Distributed Training | LAMB, LARS | Large batch handling |
| Noisy Gradients | RAdam, Lookahead | Stability |
| Small Dataset | L-BFGS | Fast convergence |
| Research/Experimentation | Yogi, Custom | Novel behaviors |
| Constrained Optimization | ScipyOptimizer | Built-in constraints |
| Black-box Optimization | NevergradOptimizer | No gradients needed |
Exercises#
Memory Comparison: Train the same large model with Adam, Lion, and Adafactor. Monitor and compare memory usage.
Large Batch Scaling: Test how well different optimizers handle increasing batch sizes from 32 to 1024.
Stability Analysis: Add artificial noise to gradients and compare optimizer robustness.
Hybrid Approach: Implement a training schedule that switches optimizers (e.g., Adam → L-BFGS for fine-tuning).
Custom Optimizer: Create your own optimizer by combining ideas from different methods.
Constraint Satisfaction: Use ScipyOptimizer to solve a constrained optimization problem in neural network training.
Hyperparameter Optimization: Use NevergradOptimizer to tune the hyperparameters of another optimizer.