CosineAnnealingLR
- class braintools.optim.CosineAnnealingLR(base_lr=0.001, T_max=50, eta_min=0, last_epoch=0)
Cosine annealing learning rate scheduler: smoothly anneals the learning rate along a cosine curve.
CosineAnnealingLR adjusts the learning rate following a cosine curve, starting from the initial learning rate and decreasing to a minimum value (eta_min) over T_max epochs. This provides a smooth, gradual decay that is popular for training deep neural networks.
- Parameters:
base_lr (float|List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
T_max (int) – Maximum number of epochs for one annealing cycle. After T_max epochs, the learning rate reaches eta_min. Default: 50.
eta_min (float) – Minimum learning rate. The learning rate decays from base_lr to eta_min over T_max epochs. Default: 0.
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0, per the class signature (starts from the beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \eta_{\min} + \frac{1}{2}(\eta_0 - \eta_{\min}) \left(1 + \cos\left(\frac{t}{T_{\max}} \pi\right)\right)\]
where \(\eta_0\) is the initial learning rate (base_lr), \(\eta_{\min}\) is the minimum learning rate, and \(T_{\max}\) is the maximum number of epochs.
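The closed form above can be evaluated directly. The following is a minimal sketch in plain Python; the helper name `cosine_annealing_lr` is an assumption made for illustration and is not part of braintools.

```python
import math

def cosine_annealing_lr(t, base_lr=0.1, T_max=100, eta_min=0.0):
    """Evaluate the closed-form cosine annealing rate at epoch t.

    Hypothetical helper for illustration only; it mirrors the formula
    in the Notes and does not call into braintools.
    """
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max))

# Print the rate at a few points along one annealing cycle
for t in (0, 25, 50, 75, 100):
    print(f"t = {t:3d}: lr = {cosine_annealing_lr(t):.6f}")
```

With base_lr=0.1, T_max=100, and eta_min=0, this gives 0.100000, 0.085355, 0.050000, 0.014645, and 0.000000 at t = 0, 25, 50, 75, and 100.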
Key characteristics:
Smooth cosine curve decay (no abrupt changes)
Learning rate starts high, decreases smoothly to eta_min
Most decay happens in the middle epochs
Popular for training vision models (ResNets, ViTs, etc.)
Often combined with warmup for best results
Decay pattern:
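Early epochs (0-25% of T_max): Slow decay (about 15% of the total drop from base_lr to eta_min)
Middle epochs (25-75% of T_max): Fast decay (about 71% of the total drop)
Late epochs (75-100% of T_max): Slow decay approaching eta_min (the remaining 15% or so)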
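The share of the total decay that falls in each phase follows from the cosine formula alone, independent of base_lr and eta_min. The sketch below computes those shares; `decay_fraction` is a hypothetical helper introduced here for illustration.

```python
import math

def decay_fraction(progress):
    """Fraction of the drop (base_lr - eta_min) completed at progress = t / T_max.

    Derived from the formula in the Notes: base_lr - eta_t, divided by
    (base_lr - eta_min), equals 0.5 * (1 - cos(pi * progress)).
    """
    return 0.5 * (1 - math.cos(math.pi * progress))

early = decay_fraction(0.25) - decay_fraction(0.0)
middle = decay_fraction(0.75) - decay_fraction(0.25)
late = decay_fraction(1.0) - decay_fraction(0.75)

print(f"early: {early:.1%}, middle: {middle:.1%}, late: {late:.1%}")
# prints: early: 14.6%, middle: 70.7%, late: 14.6%
```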
When to use:
Training image classification models
When you want smooth learning rate transitions
Long training runs (100+ epochs)
Combined with warmup for transformer models
Examples
Basic cosine annealing:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Anneal from 0.1 to 0 over 100 epochs
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=100,
...     eta_min=0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Report the rate for this epoch before advancing the schedule
...     if epoch % 25 == 0:
...         print(f"Epoch {epoch}: lr = {optimizer.current_lr:.6f}")
...     optimizer.step(grads)  # grads: gradients computed elsewhere (not shown)
...     scheduler.step()
Epoch 0: lr = 0.100000
Epoch 25: lr = 0.085355  # Slow decay early
Epoch 50: lr = 0.050000  # Fast decay middle
Epoch 75: lr = 0.014645  # Slow decay late
With non-zero minimum learning rate:
>>> # Anneal from 0.01 to 0.0001 over 50 epochs
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.01,
...     T_max=50,
...     eta_min=0.0001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     optimizer.step(grads)
...     scheduler.step()
Combined with warmup (recommended):
>>> # Warmup for 5 epochs, then cosine decay
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.01,
...     end_factor=1.0,
...     total_iters=5
... )
>>> cosine = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=90,
...     eta_min=0
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, cosine])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(95):
...     optimizer.step(grads)
...     scheduler.step()
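To make the intended shape of a warmup-then-cosine schedule concrete, here is a standalone sketch that computes the rate per epoch in plain Python. It assumes the two phases run one after the other (5 warmup epochs, then 90 cosine epochs). That is one reading of the example above, not a description of how ChainedScheduler composes its members, which this page does not specify. The function name and its parameters are hypothetical.

```python
import math

def warmup_cosine_lr(epoch, base_lr=0.1, warmup_epochs=5, cosine_epochs=90,
                     start_factor=0.01, eta_min=0.0):
    """Sequential linear warmup followed by cosine annealing (illustrative only)."""
    if epoch < warmup_epochs:
        # Linear ramp of the multiplicative factor from start_factor up to 1.0
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # Cosine phase: t counts epochs since the end of warmup, capped at cosine_epochs
    t = min(epoch - warmup_epochs, cosine_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / cosine_epochs))

for epoch in (0, 5, 50, 95):
    print(f"epoch {epoch:2d}: lr = {warmup_cosine_lr(epoch):.6f}")
```

Under these assumptions the rate starts at 0.001, reaches the full 0.1 at epoch 5, passes 0.05 at epoch 50, and ends at 0 at epoch 95.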
CIFAR-10/100 training schedule:
>>> # Standard CIFAR schedule: 200 epochs with cosine decay
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=200,
...     eta_min=0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=5e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     optimizer.step(grads)
...     scheduler.step()
ImageNet training with cosine decay:
>>> # ImageNet: 90 epochs with warmup + cosine
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>> cosine = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=85,
...     eta_min=0
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, cosine])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
Fine-tuning with gentle cosine decay:
>>> # Gentle decay for fine-tuning: min lr = 10% of base lr
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.0001,
...     T_max=30,
...     eta_min=0.00001
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(30):
...     finetune_epoch(model, optimizer, finetune_loader)
...     scheduler.step()
Saving and loading state:
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=100,
...     eta_min=0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'model': model.state_dict(),
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=100,
...     eta_min=0
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from epoch 50 with correct lr
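Resuming works because the cosine rate depends only on the epoch counter and the fixed hyperparameters, so restoring the counter restores the rate. The toy class below sketches that idea. `ToyCosineSchedule` and the contents of its state dictionary are assumptions made for illustration and do not reflect the internals of the braintools scheduler.

```python
import math

class ToyCosineSchedule:
    """Hypothetical stand-in for a cosine scheduler; not the braintools class."""

    def __init__(self, base_lr, T_max, eta_min=0.0):
        self.base_lr = base_lr
        self.T_max = T_max
        self.eta_min = eta_min
        self.last_epoch = 0

    def step(self):
        # Advance the schedule by one epoch
        self.last_epoch += 1

    def current_lr(self):
        cos_term = math.cos(math.pi * self.last_epoch / self.T_max)
        return self.eta_min + 0.5 * (self.base_lr - self.eta_min) * (1 + cos_term)

    def state_dict(self):
        # Only the epoch counter changes during training, so it is all we save
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]

original = ToyCosineSchedule(base_lr=0.1, T_max=100)
for _ in range(50):
    original.step()
saved = original.state_dict()

# A freshly constructed schedule picks up where the original left off
resumed = ToyCosineSchedule(base_lr=0.1, T_max=100)
resumed.load_state_dict(saved)
print(f"resumed at epoch {resumed.last_epoch}, lr = {resumed.current_lr():.6f}")
# prints: resumed at epoch 50, lr = 0.050000
```

The resumed schedule must be constructed with the same base_lr, T_max, and eta_min as the original, since this sketch does not store them in the saved state.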
Vision Transformer training:
>>> # ViT training schedule
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.001,
...     end_factor=1.0,
...     total_iters=10
... )
>>> cosine = braintools.optim.CosineAnnealingLR(
...     base_lr=0.001,
...     T_max=290,
...     eta_min=1e-6
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, cosine])
>>>
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.05)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(300):
...     optimizer.step(grads)
...     scheduler.step()
See also
CosineAnnealingWarmRestarts: Cosine annealing with periodic restarts
ExponentialLR: Exponential learning rate decay
LinearLR: Linear learning rate warmup/cooldown
WarmupCosineSchedule: Integrated warmup + cosine schedule
References