CosineAnnealingWarmRestarts
- class braintools.optim.CosineAnnealingWarmRestarts(base_lr=0.001, T_0=10, T_mult=1, eta_min=0, last_epoch=0)
Cosine annealing with warm restarts (SGDR: Stochastic Gradient Descent with Warm Restarts).
CosineAnnealingWarmRestarts implements a learning rate schedule in which the learning rate decreases along a cosine curve from its initial value toward a minimum, then periodically restarts at the initial value. The result is a series of cosine waves whose periods can grow after each restart. The large learning rate right after each restart can help the model escape poor local minima and explore different regions of the loss landscape.
- Parameters:
  - base_lr (float | List[float]) – Initial learning rate(s). This is the maximum learning rate at the start of each cosine annealing cycle. Can be a single float or a list with one value per parameter group. Default: 1e-3.
  - T_0 (int) – Number of epochs in the first cycle, i.e. the period before the first restart. Default: 10.
  - T_mult (int) – Factor by which the cycle length grows after each restart. If T_mult=1, all cycles have the same length; if T_mult=2, each cycle is twice as long as the previous one. Default: 1.
  - eta_min (float) – Minimum learning rate. The learning rate never drops below this value during cosine annealing. Default: 0.
  - last_epoch (int) – Index of the last epoch; set it when resuming training. Default: 0.
Notes
Mathematical Formulation:
Within each cosine annealing cycle, the learning rate follows:
\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 + \cos\left(\frac{T_{cur}}{T_i}\pi\right)\right)\]
where:
- \(\eta_{max}\) is the base learning rate (base_lr)
- \(\eta_{min}\) is the minimum learning rate (eta_min)
- \(T_{cur}\) is the number of epochs since the last restart
- \(T_i\) is the length of the current cycle
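The formula can be traced with a short plain-Python sketch. It is an illustrative reimplementation, not the braintools source: the helpers `cycle_position` and `sgdr_lr` are hypothetical, and it assumes 0-based epoch indices and cycle lengths that grow by a factor of T_mult after each restart (the library's own indexing around `last_epoch` may differ).

```python
import math

def cycle_position(epoch, T_0, T_mult):
    """Return (T_cur, T_i) for a 0-based epoch index by walking the cycles."""
    T_cur, T_i = epoch, T_0
    # Subtract whole cycles until `epoch` falls inside the current one
    while T_cur >= T_i:
        T_cur -= T_i
        T_i *= T_mult
    return T_cur, T_i

def sgdr_lr(epoch, base_lr, T_0, T_mult=1, eta_min=0.0):
    """Learning rate at `epoch` following the cosine-with-restarts formula."""
    T_cur, T_i = cycle_position(epoch, T_0, T_mult)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i))

# Cycles of length 10, 20, 40, ...: the LR is back at base_lr at epochs 10 and 30
print([round(sgdr_lr(e, 0.1, 10, 2, 0.001), 4) for e in (0, 5, 9, 10, 30)])
# [0.1, 0.0505, 0.0034, 0.1, 0.1]
```

For several parameter groups (a list base_lr), the same function would simply be applied once per group value.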
Restart Schedule:
The cycle lengths follow a geometric pattern:
- First cycle: \(T_0\) epochs
- Second cycle: \(T_0 \times T_{mult}\) epochs
- Third cycle: \(T_0 \times T_{mult}^2\) epochs
- In general, cycle \(i\) (counting from 0) lasts \(T_0 \times T_{mult}^i\) epochs.
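Summing the cycle lengths gives the epoch \(E_n\) at which the \(n\)-th restart happens, which is handy for checking a schedule by hand. This is a worked derivation from the cycle lengths above, not taken from the library source:

\[
E_n = \sum_{k=0}^{n-1} T_0\, T_{mult}^{k} =
\begin{cases}
n\,T_0, & T_{mult} = 1,\\[4pt]
T_0\,\dfrac{T_{mult}^{\,n} - 1}{T_{mult} - 1}, & T_{mult} > 1.
\end{cases}
\]

For \(T_0 = 10\) and \(T_{mult} = 2\) this gives \(E_1 = 10\), \(E_2 = 30\), \(E_3 = 70\), \(E_4 = 150\). For a given epoch \(t\), the current cycle is the largest \(n\) with \(E_n \le t\); then \(T_i = T_0\,T_{mult}^{\,n}\) and \(T_{cur} = t - E_n\).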
Benefits of Warm Restarts:
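- Escaping local minima: the jump back to a high learning rate at each restart can help the optimizer leave sharp minima.
- Ensemble effect: the model at the end of each cycle (at the minimum learning rate) is a distinct snapshot; saving these snapshots yields an ensemble without extra training runs.
- Fast convergence: each cycle pairs a high-learning-rate phase for rapid progress with a low-learning-rate phase for fine-tuning.
- Exploration: successive cycles can move the parameters into different regions of the parameter space.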
JIT Compatibility:
This implementation is JIT-compatible through the use of jnp.where for conditional updates in the restart logic.
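The branch-free restart update can be sketched as follows. It uses `numpy.where` so it runs without JAX; `jax.numpy.where` accepts the same `(condition, x, y)` arguments, so replacing `np` with `jnp` should give a version that `jax.jit` can trace. The helper `sgdr_step` is hypothetical and only illustrates the pattern, not the exact braintools code.

```python
import numpy as np

def sgdr_step(T_cur, T_i, T_mult):
    """Advance (T_cur, T_i) by one step without Python-level branching."""
    T_cur = T_cur + 1
    restart = T_cur >= T_i                      # boolean array instead of an `if`
    # Both candidate values are computed; `where` selects element-wise
    new_T_cur = np.where(restart, T_cur - T_i, T_cur)
    new_T_i = np.where(restart, T_i * T_mult, T_i)
    return new_T_cur, new_T_i

T_cur, T_i = np.array(0), np.array(10)
for _ in range(30):
    T_cur, T_i = sgdr_step(T_cur, T_i, T_mult=2)
print(int(T_cur), int(T_i))  # 0 40: restarts happened after steps 10 and 30
```

Because the restart decision is data rather than control flow, the same traced function serves every step, which is what makes the update compatible with tracing.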
Examples
Basic usage with fixed-length cycles:
>>> import braintools
>>> import brainstate
>>>
>>> # Restart every 50 epochs with the same cycle length
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=50,
...     T_mult=1,  # Fixed cycle length
...     eta_min=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # The LR restarts at epochs 50, 100, and 150
>>> for epoch in range(200):
...     train_epoch(...)
...     scheduler.step()
Increasing cycle lengths (recommended):
>>> # Cycles of increasing length: 10, 20, 40, 80, ...
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=10,
...     T_mult=2,  # Double the cycle length each time
...     eta_min=0.0001
... )
>>>
>>> # Training schedule:
>>> # Epochs [0, 10): first cycle (10 epochs)
>>> # Epochs [10, 30): second cycle (20 epochs)
>>> # Epochs [30, 70): third cycle (40 epochs)
>>> # And so on...
For transformer training:
>>> # Transformer models can benefit from warm restarts
>>> # T_0 is counted in steps here, so call scheduler.step() once per training step
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.0005,
...     T_0=1000,  # First cycle: 1000 steps
...     T_mult=2,  # Increasing cycle lengths
...     eta_min=1e-6
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Fine-tuning with short cycles:
>>> # Fine-tuning with frequent restarts for exploration
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.001,  # Lower LR for fine-tuning
...     T_0=5,          # Short initial cycle
...     T_mult=1,       # Keep cycles short
...     eta_min=1e-5
... )
>>>
>>> # Frequent restarts produce rapid LR oscillations for exploration
>>> for epoch in range(50):
...     fine_tune_epoch(...)
...     scheduler.step()
Snapshot ensembling:
>>> # Save the model at the end of each cycle for an ensemble
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=25,
...     T_mult=1,
...     eta_min=0
... )
>>>
>>> snapshots = []
>>> for epoch in range(100):
...     train_epoch(...)
...
...     # This epoch was the last of its cycle (LR at its minimum): take a snapshot
...     if scheduler.T_cur.value == scheduler.T_i.value - 1:
...         # Parameter values are immutable JAX arrays, so a value dict is a true copy
...         snapshot = {k: v.value for k, v in model.states(brainstate.ParamState).items()}
...         snapshots.append(snapshot)
...         print(f"Saved snapshot at epoch {epoch}")
...
...     scheduler.step()
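A common way to use the collected snapshots is to average their predicted class probabilities at test time. The sketch below assumes each snapshot has already been evaluated to a softmax output of shape (batch, classes); `snapshot_ensemble` is a hypothetical helper, and averaging probabilities is one choice among several (majority voting is another).

```python
import numpy as np

def snapshot_ensemble(prob_list):
    """Average per-snapshot class probabilities; return (mean_probs, labels)."""
    probs = np.mean(np.stack(prob_list), axis=0)   # shape: (batch, classes)
    return probs, probs.argmax(axis=-1)

# Hypothetical softmax outputs from two snapshots on a batch of two examples
p1 = np.array([[0.9, 0.1], [0.4, 0.6]])
p2 = np.array([[0.8, 0.2], [0.7, 0.3]])
avg, labels = snapshot_ensemble([p1, p2])
print(labels)  # the snapshots disagree on example 2; the average picks class 0
```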
Monitoring restarts:
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=10,
...     T_mult=2,
...     eta_min=0.001
... )
>>>
>>> for epoch in range(100):
...     old_T_cur = scheduler.T_cur.value
...     scheduler.step()
...     current_lr = scheduler.get_lr()[0]
...
...     # A restart resets T_cur, so it decreases
...     if scheduler.T_cur.value < old_T_cur:
...         print(f"Restart at epoch {epoch}! LR reset to {current_lr:.6f}")
...
...     if epoch % 10 == 0:
...         print(f"Epoch {epoch}: LR={current_lr:.6f}, "
...               f"T_cur={scheduler.T_cur.value}, T_i={scheduler.T_i.value}")
Custom restart schedule with T_mult > 1:
>>> # Create a schedule with specific restart points
>>> # Cycles begin at epochs 0, 100, 300, 700, 1500, ...
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.01,
...     T_0=100,   # First cycle: 100 epochs
...     T_mult=2,  # Each cycle doubles
...     eta_min=0.0001
... )
>>>
>>> # Calculate when restarts will occur
>>> def get_restart_epochs(T_0, T_mult, n_restarts):
...     epochs = [0]
...     T_i = T_0
...     for i in range(n_restarts):
...         epochs.append(epochs[-1] + T_i)
...         T_i = T_i * T_mult
...     return epochs
>>>
>>> restart_epochs = get_restart_epochs(100, 2, 5)
>>> print(f"Restarts at epochs: {restart_epochs}")
Restarts at epochs: [0, 100, 300, 700, 1500, 3100]
Combining with other techniques:
>>> # Combine with gradient clipping and weight decay
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=30,
...     T_mult=2,
...     eta_min=0.001
... )
>>>
>>> optimizer = braintools.optim.AdamW(
...     lr=scheduler,
...     weight_decay=1e-4,
...     clip_norm=1.0  # Gradient clipping
... )
State persistence for long training:
>>> # Save and restore scheduler state
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1, T_0=50, T_mult=2
... )
>>>
>>> # Train for some epochs
>>> for epoch in range(75):
...     scheduler.step()
>>>
>>> # Save state (store the values, not the state objects themselves)
>>> state = {
...     'epoch': 75,
...     'scheduler': scheduler.state_dict(),
...     'T_cur': scheduler.T_cur.value,
...     'T_i': scheduler.T_i.value
... }
>>>
>>> # Later: restore and continue
>>> new_scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1, T_0=50, T_mult=2
... )
>>> new_scheduler.load_state_dict(state['scheduler'])
>>> new_scheduler.T_cur.value = state['T_cur']
>>> new_scheduler.T_i.value = state['T_i']
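Why both T_cur and T_i need to be restored can be checked with a pure-Python model of the step rule. The helpers `step`, `lr`, and `run` are hypothetical, not the braintools implementation: resuming from the saved (T_cur, T_i) pair reproduces the uninterrupted trajectory exactly, while restoring T_cur alone would lose track of which cycle length is in effect once T_mult > 1.

```python
import math

def step(T_cur, T_i, T_mult):
    """One scheduler step (hypothetical model): advance T_cur, restart at cycle end."""
    T_cur += 1
    if T_cur >= T_i:
        T_cur, T_i = T_cur - T_i, T_i * T_mult
    return T_cur, T_i

def lr(T_cur, T_i, base_lr=0.1, eta_min=0.0):
    """Cosine-annealed LR for the current position inside the cycle."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i))

def run(state, n_steps, T_mult=2):
    """Advance `n_steps` times from `state`; return final state and LR trace."""
    trace = []
    for _ in range(n_steps):
        state = step(*state, T_mult=T_mult)
        trace.append(lr(*state))
    return state, trace

# Uninterrupted: 150 steps with T_0=50, T_mult=2
_, full = run((0, 50), 150)

# Interrupted: 75 steps, save (T_cur, T_i), then resume from the saved pair
saved, first = run((0, 50), 75)
_, second = run(saved, 75)
print(saved)  # (25, 100): 25 steps into the second, 100-step cycle
```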
See also
CosineAnnealingLR – Standard cosine annealing without restarts
OneCycleLR – One-cycle learning rate policy
CyclicLR – Cyclic learning rates between two bounds
ChainedScheduler – Chain multiple schedulers together
References