WarmupScheduler


class braintools.optim.WarmupScheduler(base_lr=0.001, warmup_epochs=5, warmup_start_lr=0.0, last_epoch=0)

Warmup learning rate scheduler: linearly increases the learning rate during the warmup phase.

WarmupScheduler gradually increases the learning rate from a small initial value (warmup_start_lr) to the base learning rate over a specified number of warmup epochs. After the warmup period, the learning rate stays constant at the base learning rate. This is commonly used at the beginning of training to stabilize the optimization.

Parameters:
  • base_lr (float | List[float]) – Target learning rate(s) after warmup. Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • warmup_epochs (int) – Number of epochs for the warmup phase. The learning rate increases linearly from warmup_start_lr and reaches base_lr at epoch warmup_epochs. Default: 5.

  • warmup_start_lr (float) – Initial learning rate at the start of warmup. Default: 0.0.

  • last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0 (starts from the beginning), matching the class signature.

Notes

The learning rate at epoch \(t\) is computed as:

\[\begin{split}\eta_t = \begin{cases} \eta_{\text{start}} + (\eta_{\text{base}} - \eta_{\text{start}}) \cdot \frac{t}{T_{\text{warmup}}} & \text{if } t < T_{\text{warmup}} \\ \eta_{\text{base}} & \text{otherwise} \end{cases}\end{split}\]

where \(\eta_{\text{start}}\) is warmup_start_lr, \(\eta_{\text{base}}\) is base_lr, \(T_{\text{warmup}}\) is warmup_epochs, and \(t\) is the current epoch.
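
As a minimal sketch of the formula above, the following pure-Python function evaluates the schedule without depending on braintools. The name `warmup_lr` is hypothetical and not part of the library; it only mirrors the piecewise definition.

```python
def warmup_lr(t, base_lr=1e-3, warmup_epochs=5, warmup_start_lr=0.0):
    """Hypothetical helper: evaluate the piecewise warmup formula at epoch t."""
    if t < warmup_epochs:
        # Linear interpolation between the start and target learning rates.
        return warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    # After warmup the learning rate stays at the target value.
    return base_lr


# Learning rates for a 10-epoch warmup from 0.0 to 0.1, evaluated at epochs 0-11.
schedule = [warmup_lr(t, base_lr=0.1, warmup_epochs=10) for t in range(12)]
```

Under this formula the last warmup epoch (t = T_warmup - 1) is still one increment below base_lr; the target value is first reached at t = T_warmup.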

Key characteristics:

  • Linear warmup from small initial lr to target lr

  • Prevents instability from large initial gradients

  • Especially important for large batch training

  • Learning rate remains constant after warmup period

Common warmup configurations:

  • Short warmup: 5-10 epochs for standard training

  • Medium warmup: 10-20 epochs for large batch training

  • Long warmup: 30-50 epochs for very large batches or transformers

  • Start lr: Usually 0 or 0.01-0.1 * base_lr

When to use:

  • Training with large batch sizes (>256)

  • Training transformer models (BERT, GPT, ViT)

  • When model shows initial training instability

  • Fine-tuning with aggressive learning rates

Examples

Basic warmup:

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Warmup from 0 to 0.1 over 10 epochs
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     # grads: gradients of the loss, assumed to be computed elsewhere
...     optimizer.step(grads)
...     scheduler.step()
>>> # Epochs 0-9: lr increases linearly from 0.0 to 0.09
>>> # Epoch 10 onward: lr stays at 0.1

Warmup with non-zero start:

>>> # Start from 10% of target lr
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.01,
...     warmup_epochs=5,
...     warmup_start_lr=0.001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(30):
...     optimizer.step(grads)
...     scheduler.step()

Large batch training:

>>> # Warmup for large batch size (1024+)
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.4,  # Linear scaling rule: 0.1 * (batch_size / 256)
...     warmup_epochs=20,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(
...     lr=scheduler,
...     momentum=0.9,
...     weight_decay=1e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     train_epoch(model, optimizer, large_batch_loader)
...     scheduler.step()
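
The linear scaling rule mentioned in the comment above can be sketched as a small helper. The function name `scaled_lr` and its parameter names are hypothetical; the reference values (0.1 at batch size 256) are the ones used in the example.

```python
def scaled_lr(batch_size, reference_lr=0.1, reference_batch_size=256):
    """Hypothetical helper: scale the learning rate in proportion to batch size."""
    return reference_lr * batch_size / reference_batch_size


# Target learning rate for a batch size of 1024, as used for base_lr above.
lr_1024 = scaled_lr(1024)
```

The result would be passed as base_lr, with warmup covering the early epochs in which such a large learning rate is most likely to destabilize training.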

Transformer training warmup:

>>> # BERT-style warmup
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.0001,
...     warmup_epochs=10000,  # Counted in steps here: scheduler.step() runs once per training step
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in range(100000):
...     train_step(model, optimizer, batch)
...     scheduler.step()

Warmup followed by decay (using ChainedScheduler):

>>> # Warmup then step decay
>>> warmup = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=5,
...     warmup_start_lr=0.0
... )
>>> decay = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()

Short warmup for fine-tuning:

>>> # Gentle warmup for transfer learning
>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.0001,
...     warmup_epochs=3,
...     warmup_start_lr=0.00001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(20):
...     finetune_epoch(model, optimizer, finetune_loader)
...     scheduler.step()

Vision Transformer training:

>>> # ViT warmup schedule
>>> warmup = braintools.optim.WarmupScheduler(
...     base_lr=0.001,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> cosine = braintools.optim.CosineAnnealingLR(
...     base_lr=0.001,
...     T_max=290,
...     eta_min=1e-6
... )
>>> # Use sequentially: warmup first, then cosine
>>> optimizer = braintools.optim.AdamW(lr=warmup, weight_decay=0.05)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Warmup phase
>>> for epoch in range(10):
...     optimizer.step(grads)
...     warmup.step()
>>>
>>> # Switch to cosine after warmup
>>> cosine.attach_optimizer(optimizer)
>>> for epoch in range(290):
...     optimizer.step(grads)
...     cosine.step()

State persistence:

>>> scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(5):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 5,
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])

Comparison with LinearLR:

>>> # WarmupScheduler: lr increases then stays constant
>>> warmup_sched = braintools.optim.WarmupScheduler(
...     base_lr=0.1,
...     warmup_epochs=10,
...     warmup_start_lr=0.0
... )
>>>
>>> # LinearLR: lr is scaled linearly from start_factor to end_factor, so it can ramp up or down
>>> linear_sched = braintools.optim.LinearLR(
...     start_factor=0.0,
...     end_factor=1.0,
...     total_iters=10
... )
>>> # Both achieve similar warmup, but LinearLR is more flexible
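
To illustrate why the two configurations above behave similarly, here is a pure-Python comparison. It assumes that a factor-based linear schedule multiplies a base learning rate by a factor interpolated from start_factor to end_factor over total_iters steps; that closed form is an assumption about LinearLR, not taken from its implementation. Both function names are hypothetical.

```python
import math


def warmup_lr(t, base_lr, warmup_epochs, warmup_start_lr):
    """Hypothetical helper: the piecewise warmup formula from the Notes section."""
    if t < warmup_epochs:
        return warmup_start_lr + (base_lr - warmup_start_lr) * t / warmup_epochs
    return base_lr


def linear_factor_lr(t, base_lr, start_factor, end_factor, total_iters):
    """Assumed closed form of a factor-based linear schedule (not the library code)."""
    progress = min(t, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


# Compare the two schedules over 20 epochs with the settings used above.
matches = all(
    math.isclose(
        warmup_lr(t, base_lr=0.1, warmup_epochs=10, warmup_start_lr=0.0),
        linear_factor_lr(t, base_lr=0.1, start_factor=0.0, end_factor=1.0, total_iters=10),
        abs_tol=1e-12,
    )
    for t in range(20)
)
```

Under that assumption the two schedules coincide when start_factor equals warmup_start_lr / base_lr and end_factor is 1.0; a factor-based schedule can additionally ramp downward when end_factor is smaller than start_factor.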

See also

LinearLR

More flexible linear scaling (can warmup or cooldown)

ConstantLR

Constant factor multiplication

ChainedScheduler

Combine warmup with other schedules


get_lr()

Calculate the learning rate for the current epoch.