WarmupCosineSchedule
- class braintools.optim.WarmupCosineSchedule(base_lr=0.001, warmup_steps=1000, total_steps=10000, warmup_start_lr=0.0, eta_min=0.0, last_epoch=0)
Warmup + Cosine annealing schedule for smooth training transitions.
WarmupCosineSchedule combines linear warmup with cosine annealing to create a smooth learning rate schedule that’s particularly effective for training deep neural networks from scratch. The schedule starts with a low learning rate, linearly increases to the base rate during warmup, then follows a cosine decay to a minimum value.
This scheduler is widely used in:
- Vision Transformers (ViT) and other transformer architectures
- Self-supervised learning (SimCLR, BYOL, MAE)
- Large-scale distributed training
- Fine-tuning pre-trained models
- Parameters:
base_lr (float | List[float]) – Peak learning rate(s) reached at the end of warmup. This is the maximum learning rate in the schedule. Can be a single float or a list for multiple parameter groups. Default: 1e-3.
warmup_steps (int) – Number of steps for the linear warmup phase. During this phase, the learning rate increases linearly from warmup_start_lr to base_lr. Default: 1000.
total_steps (int) – Total number of training steps. The cosine annealing phase spans from warmup_steps to total_steps. Default: 10000.
warmup_start_lr (float) – Starting learning rate for the warmup phase. Set to 0 for linear warmup from zero, or to a small value (e.g., 1e-6) for stability. Default: 0.0.
eta_min (float) – Minimum learning rate at the end of cosine annealing. The learning rate will not go below this value. Default: 0.0.
last_epoch (int) – The index of the last epoch. Used when resuming training. Default: 0.
Notes
Mathematical Formulation:
The learning rate schedule consists of two phases:
Linear Warmup Phase (step < warmup_steps):
\[\eta_t = \eta_{start} + \frac{t}{T_{warmup}} \cdot (\eta_{base} - \eta_{start})\]
Cosine Annealing Phase (step >= warmup_steps):
\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{base} - \eta_{min}) \cdot \left(1 + \cos\left(\pi \cdot \frac{t - T_{warmup}}{T_{total} - T_{warmup}}\right)\right)\]
where:
- \(t\) is the current step
- \(T_{warmup}\) is the number of warmup steps (warmup_steps)
- \(T_{total}\) is the total number of steps (total_steps)
- \(\eta_{start}\) is the learning rate at the start of warmup (warmup_start_lr)
- \(\eta_{base}\) is the peak learning rate (base_lr)
- \(\eta_{min}\) is the minimum learning rate (eta_min)
The two phases meet at \(t = T_{warmup}\), where both expressions evaluate to \(\eta_{base}\).
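The two phases above can be checked numerically with a small pure-Python function. This is only a sketch of the formula, not the library's implementation: the function name `warmup_cosine_lr` is hypothetical, it assumes `total_steps > warmup_steps`, and holding the rate at `eta_min` past `total_steps` is an assumption rather than documented behaviour.

```python
import math


def warmup_cosine_lr(step, base_lr=1e-3, warmup_steps=1000, total_steps=10000,
                     warmup_start_lr=0.0, eta_min=0.0):
    """Evaluate the two-phase formula for a single step (illustrative only)."""
    if step < warmup_steps:
        # Linear warmup: interpolate from warmup_start_lr towards base_lr.
        return warmup_start_lr + (step / warmup_steps) * (base_lr - warmup_start_lr)
    # Cosine annealing: progress runs from 0 at warmup_steps to 1 at total_steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    # Assumption: clamp so the rate stays at eta_min after total_steps.
    progress = min(progress, 1.0)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


# Start of warmup, mid-warmup, peak, mid-decay, end of training.
for step in (0, 500, 1000, 5500, 10000):
    print(step, warmup_cosine_lr(step))
```

With the default arguments, the rate is 0 at step 0, reaches the peak of 1e-3 at step 1000, passes through roughly half the peak at the midpoint of the decay (step 5500), and returns to 0 at step 10000.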
Benefits of Warmup:
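- Stability: Prevents divergence early in training, when a large learning rate would be applied to untrained weights
- Gradient Statistics: Lets adaptive optimizers accumulate gradient statistics before the full learning rate is applied
- Weight Initialization: Gives randomly initialized weights time to settle before large updates are made
- Large Batch Training: Helps keep training stable when the learning rate is scaled up for large batches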
JIT Compatibility:
This implementation is fully JIT-compatible through the use of JAX operations and conditional selection with jnp.where.
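As a rough illustration of what selection with jnp.where looks like, the sketch below computes both phases for every step and picks one element-wise, so no Python `if` is evaluated on a traced value. It is not the library's source: the function name `warmup_cosine_lr_jax` and the clipping of the cosine progress to [0, 1] are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def warmup_cosine_lr_jax(step, base_lr=1e-3, warmup_steps=1000, total_steps=10000,
                         warmup_start_lr=0.0, eta_min=0.0):
    """Branch-free warmup + cosine schedule (illustrative sketch)."""
    step = jnp.asarray(step, dtype=jnp.float32)
    # Both phases are evaluated for every step; jnp.where selects between them.
    warmup_lr = warmup_start_lr + (step / warmup_steps) * (base_lr - warmup_start_lr)
    # Assumption: clip progress so the rate holds at eta_min after total_steps.
    progress = jnp.clip((step - warmup_steps) / (total_steps - warmup_steps), 0.0, 1.0)
    cosine_lr = eta_min + 0.5 * (base_lr - eta_min) * (1.0 + jnp.cos(jnp.pi * progress))
    return jnp.where(step < warmup_steps, warmup_lr, cosine_lr)


# The function traces cleanly under jit and works on whole arrays of steps.
jitted_lr = jax.jit(warmup_cosine_lr_jax)
lrs = jitted_lr(jnp.array([0, 500, 1000, 5500, 10000]))
print(lrs)
```

A Python-level `if step < warmup_steps:` would fail under `jax.jit` when `step` is a traced array, which is why the element-wise select is used instead.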
Examples
Basic Vision Transformer training:
>>> import braintools
>>> import brainstate
>>>
>>> # Standard ViT training schedule
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,          # Peak learning rate
...     warmup_steps=10000,     # 10k warmup steps
...     total_steps=100000,     # 100k total steps
...     warmup_start_lr=1e-6,   # Start from small LR
...     eta_min=1e-5            # End at small LR
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.05)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in range(100000):
...     loss = train_step(...)
...     scheduler.step()
...     if step % 1000 == 0:
...         lr = scheduler.get_lr()[0]
...         print(f"Step {step}: LR = {lr:.6f}")
Self-supervised learning (SimCLR/BYOL style):
>>> # Self-supervised learning benefits from longer warmup
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.3,            # High LR for contrastive learning
...     warmup_steps=4000,      # ~10 epochs warmup
...     total_steps=40000,      # 100 epochs total
...     warmup_start_lr=0.0,    # Start from zero
...     eta_min=0.0             # Decay to zero
... )
>>>
>>> # Scale learning rate with batch size (linear scaling rule)
>>> batch_size = 4096
>>> base_batch_size = 256
>>> scaled_lr = 0.3 * (batch_size / base_batch_size)
>>> scheduler.base_lrs = [scaled_lr]
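The linear scaling rule used in the example above multiplies a reference learning rate by the ratio of the actual batch size to a reference batch size. A minimal sketch of that arithmetic follows; the helper name `linear_scaled_lr` is hypothetical and not part of braintools.

```python
def linear_scaled_lr(reference_lr, batch_size, reference_batch_size=256):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return reference_lr * batch_size / reference_batch_size


# Same numbers as the example above: 0.3 scaled from batch 256 to batch 4096.
scaled_lr = linear_scaled_lr(0.3, 4096)
print(scaled_lr)  # 0.3 * 16 = 4.8
```

A 16x larger batch gives a 16x larger peak rate, which is one reason a warmup phase is usually paired with this rule.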
Fine-tuning pre-trained models:
>>> # Shorter warmup for fine-tuning
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=5e-5,           # Lower LR for fine-tuning
...     warmup_steps=500,       # Quick warmup
...     total_steps=10000,      # Shorter total training
...     warmup_start_lr=0.0,
...     eta_min=1e-6
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.01)
Distributed training with large batches:
>>> # Large batch training needs careful warmup
>>> world_size = 8  # 8 GPUs
>>> batch_per_gpu = 64
>>> total_batch = batch_per_gpu * world_size  # 512
>>>
>>> # Linear scaling rule with warmup
>>> base_lr = 0.001
>>> scaled_lr = base_lr * (total_batch / 128)  # Scale from base batch 128
>>>
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=scaled_lr,
...     warmup_steps=2000,               # Longer warmup for large batch
...     total_steps=50000,
...     warmup_start_lr=base_lr / 100,   # Start at 1% of base
...     eta_min=scaled_lr * 0.01
... )
MAE-style masked autoencoder training:
>>> # MAE uses specific warmup schedule
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=1.5e-4,             # Base LR for batch 4096
...     warmup_steps=40 * 312,      # 40 epochs of warmup
...     total_steps=1600 * 312,     # 1600 epochs total
...     warmup_start_lr=0.0,
...     eta_min=0.0
... )
>>>
>>> # Combined with specific optimizer settings
>>> optimizer = braintools.optim.AdamW(
...     lr=scheduler,
...     betas=(0.9, 0.95),          # MAE-specific betas
...     weight_decay=0.05
... )
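The constant 312 in the example above is not explained in the source. One plausible reading, stated here as an assumption, is that it is the number of optimizer steps per epoch: the ImageNet-1k training set has 1,281,167 images, and integer division by a batch size of 4096 gives 312. The helper `steps_for_epochs` below is hypothetical and only shows how epoch counts convert into the step counts the scheduler expects.

```python
def steps_for_epochs(num_epochs, dataset_size, batch_size):
    """Convert epochs to optimizer steps, dropping the last partial batch."""
    steps_per_epoch = dataset_size // batch_size
    return num_epochs * steps_per_epoch


dataset_size = 1_281_167  # assumption: ImageNet-1k training set size
batch_size = 4096

warmup_steps = steps_for_epochs(40, dataset_size, batch_size)
total_steps = steps_for_epochs(1600, dataset_size, batch_size)
print(warmup_steps, total_steps)  # 12480 499200
```

If the data loader keeps the final partial batch, steps per epoch would be rounded up instead, and both values should be recomputed accordingly.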
BERT-style transformer training:
>>> # BERT uses fraction-based warmup
>>> total_steps = 1000000
>>> warmup_fraction = 0.1  # 10% warmup
>>>
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=1e-4,
...     warmup_steps=int(total_steps * warmup_fraction),
...     total_steps=total_steps,
...     warmup_start_lr=0.0,
...     eta_min=1e-5
... )
Monitoring warmup progression:
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000,
...     warmup_start_lr=1e-5,
...     eta_min=1e-4
... )
>>>
>>> for step in range(10000):
...     scheduler.step()
...     lr = scheduler.get_lr()[0]
...
...     # Track phase transitions
...     if step == 0:
...         print(f"Starting warmup from LR = {lr:.6f}")
...     elif step == scheduler.warmup_steps - 1:
...         print(f"Ending warmup at LR = {lr:.6f}")
...     elif step == scheduler.warmup_steps:
...         print(f"Starting cosine decay from LR = {lr:.6f}")
...     elif step == scheduler.total_steps - 1:
...         print(f"Training complete at LR = {lr:.6f}")
Custom warmup strategies:
>>> # Aggressive warmup (reach peak quickly)
>>> fast_warmup = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.01,
...     warmup_steps=100,         # Very short warmup
...     total_steps=10000,
...     warmup_start_lr=0.001,    # Start at 10% of peak
...     eta_min=0.0001
... )
>>>
>>> # Conservative warmup (gradual increase)
>>> slow_warmup = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.01,
...     warmup_steps=5000,        # 50% of training for warmup
...     total_steps=10000,
...     warmup_start_lr=1e-6,     # Start very low
...     eta_min=0.001
... )
Combining with gradient accumulation:
>>> # Gradient accumulation affects effective batch size
>>> accumulation_steps = 4
>>> per_step_batch = 32
>>> effective_batch = per_step_batch * accumulation_steps  # 128
>>>
>>> # Adjust learning rate and warmup accordingly
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001 * (effective_batch / 32),
...     warmup_steps=2000 // accumulation_steps,  # Adjust for accumulation
...     total_steps=50000 // accumulation_steps,
...     warmup_start_lr=1e-5,
...     eta_min=1e-5
... )
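The division by accumulation_steps in the example above reflects that the scheduler is expected to advance once per optimizer update, not once per micro-batch. The loop below is a minimal sketch of that counting with no real model or optimizer; `micro_batches` is a hypothetical value chosen for illustration.

```python
accumulation_steps = 4
micro_batches = 20  # hypothetical number of forward/backward passes

optimizer_steps = 0
for micro_step in range(1, micro_batches + 1):
    # ... accumulate gradients for this micro-batch here ...
    if micro_step % accumulation_steps == 0:
        # One optimizer update, and one scheduler.step(), per accumulation cycle.
        optimizer_steps += 1

print(optimizer_steps)  # 5

# The same ratio converts the micro-batch budgets used in the example above.
warmup_updates = 2000 // accumulation_steps
total_updates = 50000 // accumulation_steps
print(warmup_updates, total_updates)  # 500 12500
```

If the scheduler were stepped on every micro-batch instead, warmup_steps and total_steps would have to stay in micro-batch units.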
State persistence for checkpointing:
>>> # Save scheduler state
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000
... )
>>> # ... train for some steps ...
>>> checkpoint = {
...     'step': current_step,
...     'scheduler': scheduler.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'model': model.state_dict()
... }
>>> save(checkpoint, 'checkpoint.pkl')
>>>
>>> # Resume training
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000
... )
>>> scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from saved step
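The example above calls a `save` helper that it does not define. One possible way to write and read such a checkpoint is the standard-library pickle module, sketched below. The dictionary contents are hypothetical stand-ins, and whether the real output of scheduler.state_dict() pickles cleanly is an assumption that should be checked against the library.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for the checkpoint dict built in the example above.
checkpoint = {"step": 2500, "scheduler": {"last_epoch": 2500}}

# Write the checkpoint to a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pkl")
with open(path, "wb") as f:
    pickle.dump(checkpoint, f)

# Later, or in a new process: read it back before calling load_state_dict.
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored["step"])
```

After loading, `restored['scheduler']` would be passed to scheduler.load_state_dict, as in the resume step of the example above.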
See also
CosineAnnealingLR – Pure cosine annealing without warmup
LinearLR – Linear learning rate schedule (can be used for warmup)
OneCycleLR – Another schedule combining warmup with annealing
PolynomialLR – Polynomial decay schedule
References