WarmupCosineSchedule

class braintools.optim.WarmupCosineSchedule(base_lr=0.001, warmup_steps=1000, total_steps=10000, warmup_start_lr=0.0, eta_min=0.0, last_epoch=0)

Warmup + Cosine annealing schedule for smooth training transitions.

WarmupCosineSchedule combines linear warmup with cosine annealing to create a smooth learning rate schedule that’s particularly effective for training deep neural networks from scratch. The schedule starts with a low learning rate, linearly increases to the base rate during warmup, then follows a cosine decay to a minimum value.

This scheduler is widely used in:

  • Vision Transformers (ViT) and other transformer architectures

  • Self-supervised learning (SimCLR, BYOL, MAE)

  • Large-scale distributed training

  • Fine-tuning pre-trained models

Parameters:
  • base_lr (float | List[float]) – Peak learning rate(s) reached at the end of warmup. This is the maximum learning rate in the schedule. Can be a single float or list for multiple parameter groups. Default: 1e-3.

  • warmup_steps (int) – Number of steps for the linear warmup phase. During this phase, the learning rate linearly increases from warmup_start_lr to base_lr. Default: 1000.

  • total_steps (int) – Total number of training steps. The cosine annealing phase spans from warmup_steps to total_steps. Default: 10000.

  • warmup_start_lr (float) – Starting learning rate for the warmup phase. Set to 0 for linear warmup from zero, or a small value (e.g., 1e-6) for stability. Default: 0.0.

  • eta_min (float) – Minimum learning rate at the end of cosine annealing. The learning rate will not go below this value. Default: 0.0.

  • last_epoch (int) – The index of the last epoch. Used when resuming training. Default: 0.

Notes

Mathematical Formulation:

The learning rate schedule consists of two phases:

  1. Linear Warmup Phase (step < warmup_steps):

\[\eta_t = \eta_{start} + \frac{t}{T_{warmup}} \cdot (\eta_{base} - \eta_{start})\]

  2. Cosine Annealing Phase (step >= warmup_steps):

\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{base} - \eta_{min}) \cdot \left(1 + \cos\left(\pi \cdot \frac{t - T_{warmup}}{T_{total} - T_{warmup}}\right)\right)\]

where:

  • \(t\) is the current step

  • \(T_{warmup}\) is the number of warmup steps (warmup_steps)

  • \(T_{total}\) is the total number of steps (total_steps)

  • \(\eta_{start}\) is the learning rate at the start of warmup (warmup_start_lr)

  • \(\eta_{base}\) is the peak learning rate (base_lr)

  • \(\eta_{min}\) is the minimum learning rate (eta_min)
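
As a rough illustration, the two phases above can be written as a plain-Python function. This is a sketch of the formulas only, not the library's implementation: the name `warmup_cosine_lr` is hypothetical, and clamping the cosine progress to 1 once `step` exceeds `total_steps` is an assumption about behaviour past the end of training.

```python
import math


def warmup_cosine_lr(step, base_lr=1e-3, warmup_steps=1000, total_steps=10000,
                     warmup_start_lr=0.0, eta_min=0.0):
    """Evaluate the warmup + cosine formulas at a single step (illustrative only)."""
    if step < warmup_steps:
        # Linear warmup: interpolate from warmup_start_lr to base_lr.
        return warmup_start_lr + (step / warmup_steps) * (base_lr - warmup_start_lr)
    # Cosine annealing: progress runs from 0 at warmup_steps to 1 at total_steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(progress, 1.0)  # assumption: hold eta_min after total_steps
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 500, 1000, 5500, 10000):
    print(step, warmup_cosine_lr(step))
```

With the default arguments the rate is 0 at step 0, half the peak midway through warmup, the full `base_lr` at step 1000, roughly half the peak at the midpoint of the cosine phase, and `eta_min` at step 10000. The sketch requires `total_steps > warmup_steps`.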

Benefits of Warmup:

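  1. Stability: Prevents divergence early in training, when a large learning rate would be applied to untrained weights

  2. Gradient Statistics: Lets adaptive optimizers such as Adam accumulate moment estimates before the full learning rate is applied

  3. Weight Initialization: Gives randomly initialized weights time to settle before large updates are applied

  4. Large Batch Training: Helps keep training stable when the learning rate is scaled up for large batches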

JIT Compatibility:

This implementation is fully JIT-compatible through the use of JAX operations and conditional selection with jnp.where.
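
To show what selection with jnp.where looks like in practice, here is a minimal JAX sketch. It is not the library's source: the function name `lr_at` and the hard-coded hyperparameters are hypothetical. Both branches are computed for every step and `jnp.where` picks one, which avoids Python control flow on a traced value.

```python
import jax
import jax.numpy as jnp


@jax.jit
def lr_at(step):
    # Hypothetical hyperparameters, matching the class defaults.
    base_lr, warmup_steps, total_steps = 1e-3, 1000, 10000
    warmup_start_lr, eta_min = 0.0, 0.0

    t = jnp.asarray(step, dtype=jnp.float32)
    # Branch 1: linear warmup.
    warmup_lr = warmup_start_lr + (t / warmup_steps) * (base_lr - warmup_start_lr)
    # Branch 2: cosine annealing, with progress clipped to [0, 1].
    progress = jnp.clip((t - warmup_steps) / (total_steps - warmup_steps), 0.0, 1.0)
    cosine_lr = eta_min + 0.5 * (base_lr - eta_min) * (1.0 + jnp.cos(jnp.pi * progress))
    # Select per element instead of branching in Python.
    return jnp.where(t < warmup_steps, warmup_lr, cosine_lr)


print(lr_at(500), lr_at(1000), lr_at(10000))
```

Because the selection is elementwise, the same function also accepts an array of steps, for example `lr_at(jnp.arange(10000))`.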

Examples

Basic Vision Transformer training:

>>> import braintools
>>> import brainstate
>>>
>>> # Standard ViT training schedule
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,           # Peak learning rate
...     warmup_steps=10000,      # 10k warmup steps
...     total_steps=100000,      # 100k total steps
...     warmup_start_lr=1e-6,    # Start from small LR
...     eta_min=1e-5             # End at small LR
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.05)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in range(100000):
...     loss = train_step(...)
...     scheduler.step()
...     if step % 1000 == 0:
...         lr = scheduler.get_lr()[0]
...         print(f"Step {step}: LR = {lr:.6f}")

Self-supervised learning (SimCLR/BYOL style):

>>> # Self-supervised learning benefits from longer warmup
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.3,             # High LR for contrastive learning
...     warmup_steps=4000,       # ~10 epochs warmup
...     total_steps=40000,       # 100 epochs total
...     warmup_start_lr=0.0,     # Start from zero
...     eta_min=0.0              # Decay to zero
... )
>>>
>>> # Scale learning rate with batch size (linear scaling rule)
>>> batch_size = 4096
>>> base_batch_size = 256
>>> scaled_lr = 0.3 * (batch_size / base_batch_size)
>>> scheduler.base_lrs = [scaled_lr]
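
The linear scaling rule used above is simple arithmetic, and a small helper makes it explicit. The function name `scale_lr` and its argument names are hypothetical and not part of braintools.

```python
def scale_lr(reference_lr, batch_size, reference_batch_size=256):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return reference_lr * batch_size / reference_batch_size


scaled_lr = scale_lr(0.3, batch_size=4096, reference_batch_size=256)
print(scaled_lr)  # 0.3 * 16
```

Computing the scaled value first and passing it as `base_lr` when constructing the scheduler is an alternative to assigning `base_lrs` afterwards.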

Fine-tuning pre-trained models:

>>> # Shorter warmup for fine-tuning
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=5e-5,            # Lower LR for fine-tuning
...     warmup_steps=500,        # Quick warmup
...     total_steps=10000,       # Shorter total training
...     warmup_start_lr=0.0,
...     eta_min=1e-6
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.01)

Distributed training with large batches:

>>> # Large batch training needs careful warmup
>>> world_size = 8  # 8 GPUs
>>> batch_per_gpu = 64
>>> total_batch = batch_per_gpu * world_size  # 512
>>>
>>> # Linear scaling rule with warmup
>>> base_lr = 0.001
>>> scaled_lr = base_lr * (total_batch / 128)  # Scale from base batch 128
>>>
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=scaled_lr,
...     warmup_steps=2000,       # Longer warmup for large batch
...     total_steps=50000,
...     warmup_start_lr=base_lr / 100,  # Start at 1% of base
...     eta_min=scaled_lr * 0.01
... )

MAE-style masked autoencoder training:

>>> # MAE-style schedule: 40 warmup epochs out of 1600
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=1.5e-4,          # MAE base LR (the paper scales it by batch_size / 256)
...     warmup_steps=40 * 312,   # 40 epochs of warmup
...     total_steps=1600 * 312,  # 1600 epochs total
...     warmup_start_lr=0.0,
...     eta_min=0.0
... )
>>>
>>> # Combined with specific optimizer settings
>>> optimizer = braintools.optim.AdamW(
...     lr=scheduler,
...     betas=(0.9, 0.95),  # MAE-specific betas
...     weight_decay=0.05
... )
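
The factor 312 in the example above is consistent with the number of optimizer steps per epoch on ImageNet-1k (1,281,167 training images) at batch size 4096. That reading is an assumption, since the example does not say where the number comes from. A sketch of the epoch-to-step conversion, with a hypothetical helper name:

```python
def epochs_to_steps(epochs, dataset_size, batch_size):
    """Convert an epoch count to optimizer steps, dropping the last partial batch."""
    steps_per_epoch = dataset_size // batch_size
    return epochs * steps_per_epoch


imagenet_train_size = 1_281_167  # assumed dataset size (ImageNet-1k training split)
warmup_steps = epochs_to_steps(40, imagenet_train_size, 4096)
total_steps = epochs_to_steps(1600, imagenet_train_size, 4096)
print(warmup_steps, total_steps)
```

For a different dataset or batch size, recompute both values so that the warmup still covers the intended number of epochs.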

BERT-style transformer training:

>>> # BERT uses fraction-based warmup
>>> total_steps = 1000000
>>> warmup_fraction = 0.1  # 10% warmup
>>>
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=1e-4,
...     warmup_steps=int(total_steps * warmup_fraction),
...     total_steps=total_steps,
...     warmup_start_lr=0.0,
...     eta_min=1e-5
... )

Monitoring warmup progression:

>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000,
...     warmup_start_lr=1e-5,
...     eta_min=1e-4
... )
>>>
>>> for step in range(10000):
...     lr = scheduler.get_lr()[0]  # Read the LR for this step before advancing
...     scheduler.step()
...
...     # Track phase transitions
...     if step == 0:
...         print(f"Starting warmup from LR = {lr:.6f}")
...     elif step == scheduler.warmup_steps - 1:
...         print(f"Ending warmup at LR = {lr:.6f}")
...     elif step == scheduler.warmup_steps:
...         print(f"Starting cosine decay from LR = {lr:.6f}")
...     elif step == scheduler.total_steps - 1:
...         print(f"Training complete at LR = {lr:.6f}")

Custom warmup strategies:

>>> # Aggressive warmup (reach peak quickly)
>>> fast_warmup = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.01,
...     warmup_steps=100,        # Very short warmup
...     total_steps=10000,
...     warmup_start_lr=0.001,   # Start at 10% of peak
...     eta_min=0.0001
... )
>>>
>>> # Conservative warmup (gradual increase)
>>> slow_warmup = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.01,
...     warmup_steps=5000,       # 50% of training for warmup
...     total_steps=10000,
...     warmup_start_lr=1e-6,    # Start very low
...     eta_min=0.001
... )

Combining with gradient accumulation:

>>> # Gradient accumulation affects effective batch size
>>> accumulation_steps = 4
>>> per_step_batch = 32
>>> effective_batch = per_step_batch * accumulation_steps  # 128
>>>
>>> # Adjust learning rate and warmup accordingly
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001 * (effective_batch / 32),
...     warmup_steps=2000 // accumulation_steps,  # Adjust for accumulation
...     total_steps=50000 // accumulation_steps,
...     warmup_start_lr=1e-5,
...     eta_min=1e-5
... )
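
Dividing the step counts by `accumulation_steps` only makes sense if the schedule advances once per optimizer update rather than once per micro-batch. The loop below sketches that bookkeeping in plain Python. It assumes the 50000 in the example counts micro-batches, and the gradient and optimizer work is left as comments.

```python
accumulation_steps = 4
micro_batches = 50_000
total_steps = micro_batches // accumulation_steps  # value passed to the scheduler

scheduler_steps = 0
for i in range(micro_batches):
    # ... accumulate gradients for micro-batch i here ...
    if (i + 1) % accumulation_steps == 0:
        # ... apply the optimizer update here, then call scheduler.step() ...
        scheduler_steps += 1

print(scheduler_steps, total_steps)
```

Both counters end at the same value, so the cosine phase reaches `eta_min` on the final optimizer update.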

State persistence for checkpointing:

>>> # Save scheduler state
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000
... )
>>> # ... train for some steps ...
>>> checkpoint = {
...     'step': current_step,
...     'scheduler': scheduler.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'model': model.state_dict()
... }
>>> save(checkpoint, 'checkpoint.pkl')  # save() is a placeholder for your serialization routine
>>>
>>> # Resume training
>>> scheduler = braintools.optim.WarmupCosineSchedule(
...     base_lr=0.001,
...     warmup_steps=1000,
...     total_steps=10000
... )
>>> checkpoint = load('checkpoint.pkl')  # load() is a placeholder for your deserialization routine
>>> scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from saved step

See also

CosineAnnealingLR

Pure cosine annealing without warmup

LinearLR

Linear learning rate schedule (can be used for warmup)

OneCycleLR

Another schedule combining warmup with annealing

PolynomialLR

Polynomial decay schedule

get_lr()

Calculate the current learning rate(s), one per parameter group.