WarmupScheduler
- class braintools.optim.WarmupScheduler(base_lr=0.001, warmup_epochs=5, warmup_start_lr=0.0, last_epoch=0)
Warmup learning rate scheduler - Linearly increases learning rate during warmup phase.
WarmupScheduler gradually increases the learning rate from a small initial value (warmup_start_lr) to the base learning rate over a specified number of warmup epochs. After the warmup period, the learning rate stays constant at the base learning rate. This is commonly used at the beginning of training to stabilize the optimization.
- Parameters:
base_lr (float | List[float]) – Target learning rate(s) after warmup. Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
warmup_epochs (int) – Number of epochs for the warmup phase. The learning rate increases linearly from warmup_start_lr to base_lr over this many epochs. Default: 5.
warmup_start_lr (float) – Initial learning rate at the start of warmup. Default: 0.0.
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0, per the class signature (starts from the beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \begin{cases} \eta_{\text{start}} + (\eta_{\text{base}} - \eta_{\text{start}}) \cdot \frac{t}{T_{\text{warmup}}} & \text{if } t < T_{\text{warmup}} \\ \eta_{\text{base}} & \text{otherwise} \end{cases}\]
where \(\eta_{\text{start}}\) is warmup_start_lr, \(\eta_{\text{base}}\) is base_lr, \(T_{\text{warmup}}\) is warmup_epochs, and \(t\) is the current epoch.
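The formula above can be checked by hand with a few lines of plain Python. This is a minimal sketch of the piecewise rule only, not the braintools implementation; the function name warmup_lr is made up for illustration.

```python
def warmup_lr(t, base_lr=1e-3, warmup_epochs=5, warmup_start_lr=0.0):
    """Learning rate at epoch t: linear warmup, then constant (illustrative helper)."""
    if t < warmup_epochs:
        # Linear interpolation from warmup_start_lr toward base_lr.
        return warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    # Once the warmup period is over, the rate is held at base_lr.
    return base_lr


if __name__ == "__main__":
    # Warmup from 0 to 0.1 over 10 epochs, sampled at a few epochs.
    for t in (0, 5, 9, 10, 20):
        print(t, warmup_lr(t, base_lr=0.1, warmup_epochs=10))
```

Note that under this rule the rate first equals base_lr at epoch t = warmup_epochs, not one epoch earlier.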
Key characteristics:
Linear warmup from small initial lr to target lr
Prevents instability from large initial gradients
Especially important for large batch training
Learning rate remains constant after warmup period
Common warmup configurations:
Short warmup: 5-10 epochs for standard training
Medium warmup: 10-20 epochs for large batch training
Long warmup: 30-50 epochs for very large batches or transformers
Start lr: Usually 0, or between 0.01 and 0.1 times base_lr
When to use:
Training with large batch sizes (>256)
Training transformer models (BERT, GPT, ViT)
When model shows initial training instability
Fine-tuning with aggressive learning rates
Examples
Basic warmup:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Warmup from 0 to 0.1 over 10 epochs
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # grads is a placeholder for gradients computed elsewhere in the training loop
>>> for epoch in range(50):
...     optimizer.step(grads)
...     scheduler.step()
# Epochs 0-9: lr increases linearly from 0 to 0.1
# Epochs 10+: lr stays at 0.1
Warmup with non-zero start:
>>> # Start from 10% of target lr
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.01,
...     warmup_epochs=5,
...     warmup_start_lr=0.001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(30):
...     optimizer.step(grads)
...     scheduler.step()
Large batch training:
>>> # Warmup for large batch size (1024+)
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.4,  # Linear scaling rule: 0.1 * (batch_size / 256)
...     warmup_epochs=20,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(
...     lr=scheduler,
...     momentum=0.9,
...     weight_decay=1e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     train_epoch(model, optimizer, large_batch_loader)
...     scheduler.step()
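The base_lr=0.4 in the large-batch example comes from the linear scaling rule quoted in its comment. A small sketch of that arithmetic, assuming a reference rate of 0.1 at batch size 256 as in the comment; the helper name scaled_lr is hypothetical.

```python
def scaled_lr(reference_lr, batch_size, reference_batch_size=256):
    """Scale a reference learning rate in proportion to the batch size."""
    return reference_lr * (batch_size / reference_batch_size)


if __name__ == "__main__":
    # Batch size 1024 is 4x the reference batch, so the rate is 4x the reference rate.
    print(scaled_lr(0.1, 1024))
```

The larger the scaled rate, the more a warmup phase helps, since the first updates would otherwise be taken at the full scaled rate.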
Transformer training warmup:
>>> # BERT-style warmup
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.0001,
...     warmup_epochs=10000,  # Often in steps/iterations
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in range(100000):
...     train_step(model, optimizer, batch)
...     scheduler.step()
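When scheduler.step() is called once per iteration, as in the loop above, warmup_epochs effectively counts iterations. A sketch of converting a warmup length given in epochs into iterations follows; the helper name and the dataset and batch sizes are made-up illustration values.

```python
import math


def warmup_steps(warmup_epochs, dataset_size, batch_size):
    """Number of optimizer iterations covered by a warmup given in epochs."""
    # The last, possibly partial, batch still counts as one iteration.
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    return warmup_epochs * steps_per_epoch


if __name__ == "__main__":
    # Hypothetical setup: 50,000 samples, batch size 256, 5 warmup epochs.
    print(warmup_steps(5, dataset_size=50_000, batch_size=256))
```

The resulting count would be passed as warmup_epochs when the scheduler is stepped per iteration.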
Warmup followed by decay (using ChainedScheduler):
>>> # Warmup then step decay
>>> warmup = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=5,
...     warmup_start_lr=0.0
... )
>>> decay = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
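How ChainedScheduler combines its members is not specified in this section. The sketch below shows one plausible reading, assumed here for illustration only: the warmup value is multiplied by a step-decay factor of gamma raised to the number of completed step_size periods. It is plain Python with a hypothetical function name, not the library's code.

```python
def warmup_then_step_lr(t, base_lr=0.1, warmup_epochs=5, warmup_start_lr=0.0,
                        step_size=30, gamma=0.1):
    """Warmup value times a step-decay factor (assumed composition rule)."""
    if t < warmup_epochs:
        warm = warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    else:
        warm = base_lr
    # Decay by gamma once for every completed block of step_size epochs.
    decay = gamma ** (t // step_size)
    return warm * decay


if __name__ == "__main__":
    for t in (0, 4, 5, 29, 30, 60, 89):
        print(t, warmup_then_step_lr(t))
```

Under this assumption the rate ramps up over epochs 0-4, holds at 0.1 until epoch 29, then drops by a factor of 10 at epochs 30 and 60.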
Short warmup for fine-tuning:
>>> # Gentle warmup for transfer learning
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.0001,
...     warmup_epochs=3,
...     warmup_start_lr=0.00001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(20):
...     finetune_epoch(model, optimizer, finetune_loader)
...     scheduler.step()
Vision Transformer training:
>>> # ViT warmup schedule
>>> warmup = braintools.optim.WarmupScheduler(
...     base_lr=0.001,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> cosine = braintools.optim.CosineAnnealingLR(
...     base_lr=0.001,
...     T_max=290,
...     eta_min=1e-6
... )
>>> # Use sequentially: warmup first, then cosine
>>> optimizer = braintools.optim.AdamW(lr=warmup, weight_decay=0.05)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Warmup phase
>>> for epoch in range(10):
...     optimizer.step(grads)
...     warmup.step()
>>>
>>> # Switch to cosine after warmup
>>> cosine.attach_optimizer(optimizer)
>>> for epoch in range(290):
...     optimizer.step(grads)
...     cosine.step()
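The two-phase schedule above can be written as a single closed-form function of the global epoch. This sketch assumes CosineAnnealingLR follows the standard cosine annealing formula, eta_min + (base_lr - eta_min) * (1 + cos(pi * k / T_max)) / 2, with k counted from the end of warmup; the function name is hypothetical and the code is independent of braintools.

```python
import math


def warmup_then_cosine_lr(t, base_lr=1e-3, warmup_epochs=10, warmup_start_lr=0.0,
                          t_max=290, eta_min=1e-6):
    """Linear warmup for warmup_epochs, then cosine annealing over t_max epochs."""
    if t < warmup_epochs:
        return warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    # Epochs elapsed since the end of warmup, capped at the annealing horizon.
    k = min(t - warmup_epochs, t_max)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * k / t_max)) / 2


if __name__ == "__main__":
    for t in (0, 5, 10, 155, 300):
        print(t, warmup_then_cosine_lr(t))
```

The schedule is continuous at the hand-off: the warmup ends at base_lr and the cosine phase starts from the same value.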
State persistence:
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(5):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 5,
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
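The reason the scheduler state must be saved is that the warmup rate depends on the epoch counter: a fresh scheduler would restart the ramp from warmup_start_lr. The toy class below illustrates this. It is a hypothetical stand-in, not the braintools class, and the contents of the real state_dict() are not documented in this section.

```python
class ToyWarmup:
    """Hypothetical minimal warmup scheduler used only to illustrate resuming."""

    def __init__(self, base_lr, warmup_epochs, warmup_start_lr=0.0):
        self.base_lr = base_lr
        self.warmup_epochs = warmup_epochs
        self.warmup_start_lr = warmup_start_lr
        self.epoch = 0

    def step(self):
        self.epoch += 1

    def current_lr(self):
        if self.epoch < self.warmup_epochs:
            span = self.base_lr - self.warmup_start_lr
            return self.warmup_start_lr + span * self.epoch / self.warmup_epochs
        return self.base_lr

    def state_dict(self):
        # Assumed layout: only the epoch counter needs to be persisted here.
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]


def resume_demo():
    """Step 5 epochs, save, restore into a fresh scheduler, and compare rates."""
    sched = ToyWarmup(base_lr=0.1, warmup_epochs=10)
    for _ in range(5):
        sched.step()
    restored = ToyWarmup(base_lr=0.1, warmup_epochs=10)
    restored.load_state_dict(sched.state_dict())
    return sched.current_lr(), restored.current_lr()


if __name__ == "__main__":
    print(resume_demo())
```

Without the load_state_dict call, the restored toy scheduler would report the starting rate instead of the mid-warmup rate.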
Comparison with LinearLR:
>>> # WarmupScheduler: lr increases then stays constant
>>> warmup_sched = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>>
>>> # LinearLR: lr is scaled linearly between two factors, so it can warm up or cool down
>>> linear_sched = braintools.optim.LinearLR(
...     start_factor=0.0,
...     end_factor=1.0,
...     total_iters=10
... )
>>> # Both achieve similar warmup, but LinearLR is more flexible
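The relationship between the two schedulers can be sketched numerically. The code below assumes LinearLR multiplies a base rate by a factor interpolated linearly from start_factor to end_factor over total_iters, then holds the final factor; that behaviour is an assumption about LinearLR, and both function names are hypothetical.

```python
def warmup_schedule(t, base_lr, warmup_epochs, warmup_start_lr=0.0):
    """Linear warmup from warmup_start_lr to base_lr, then constant."""
    if t < warmup_epochs:
        return warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    return base_lr


def linear_factor_schedule(t, base_lr, start_factor, end_factor, total_iters):
    """Base rate times a linearly interpolated factor (assumed LinearLR rule)."""
    frac = min(t, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * frac)


if __name__ == "__main__":
    for t in range(0, 13, 3):
        w = warmup_schedule(t, base_lr=0.1, warmup_epochs=10)
        up = linear_factor_schedule(t, 0.1, start_factor=0.0, end_factor=1.0, total_iters=10)
        down = linear_factor_schedule(t, 0.1, start_factor=1.0, end_factor=0.1, total_iters=10)
        print(t, w, up, down)
```

With start_factor=0.0 and end_factor=1.0 the factor form reproduces the warmup values; swapping the factors yields a cooldown, which the warmup rule cannot express.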
See also
LinearLR – More flexible linear scaling (can warm up or cool down)
ConstantLR – Constant factor multiplication
ChainedScheduler – Combine warmup with other schedules
References