CosineAnnealingWarmRestarts

class braintools.optim.CosineAnnealingWarmRestarts(base_lr=0.001, T_0=10, T_mult=1, eta_min=0, last_epoch=0)

Cosine annealing with warm restarts (SGDR: Stochastic Gradient Descent with Warm Restarts).

CosineAnnealingWarmRestarts lowers the learning rate from base_lr toward eta_min along a cosine curve, then resets it to base_lr (a "warm restart") and begins the next cycle. The result is a series of cosine decays whose periods either stay constant (T_mult=1) or grow geometrically (T_mult>1). The periodic jumps back to a high learning rate can help the optimizer leave a poor local minimum and explore other regions of the loss landscape.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). This is the maximum learning rate at the start of each cosine annealing cycle. Can be a single float or a list for multiple parameter groups. Default: 1e-3.

  • T_0 (int) – Length of the first cycle in epochs, i.e., the number of epochs before the first restart. Default: 10.

  • T_mult (int) – Factor by which the period increases after each restart. If T_mult=1, all cycles have the same length. If T_mult=2, each cycle is twice as long as the previous. Default: 1.

  • eta_min (float) – Minimum learning rate. The learning rate will never go below this value during the cosine annealing. Default: 0.

  • last_epoch (int) – The index of the last epoch. Used when resuming training. Default: 0.

Notes

Mathematical Formulation:

Within each cosine annealing cycle, the learning rate follows:

\[\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 + \cos\left(\frac{T_{cur}}{T_i}\pi\right)\right)\]

where:

  • \(\eta_{max}\) is the base learning rate (base_lr)

  • \(\eta_{min}\) is the minimum learning rate (eta_min)

  • \(T_{cur}\) is the number of epochs since the last restart

  • \(T_i\) is the length of the current cycle
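
The formula can be checked numerically with a few lines of plain Python. This is a minimal sketch of the equation above, not braintools' implementation; the helper name `cosine_annealed_lr` is hypothetical.

```python
import math


def cosine_annealed_lr(eta_max, eta_min, T_cur, T_i):
    """LR after T_cur epochs of a cycle of length T_i (hypothetical helper)."""
    cos_term = 1 + math.cos(math.pi * T_cur / T_i)
    return eta_min + 0.5 * (eta_max - eta_min) * cos_term


# base_lr=0.1, eta_min=0.001, 50-epoch cycle: start, midpoint, last epoch
for T_cur in (0, 25, 49):
    print(T_cur, round(cosine_annealed_lr(0.1, 0.001, T_cur, 50), 6))
```

At T_cur = 0 the LR equals base_lr, and at the midpoint it is the average of base_lr and eta_min. Assuming T_cur runs from 0 to T_i - 1 within a cycle, the LR only approaches eta_min, because the restart happens where T_cur would reach T_i.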

Restart Schedule:

The cycle lengths follow the pattern:

  • First cycle: \(T_0\) epochs

  • Second cycle: \(T_0 \times T_{mult}\) epochs

  • Third cycle: \(T_0 \times T_{mult}^2\) epochs

  • And so on…
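
Given this pattern, any global epoch can be mapped to its cycle, its position T_cur, and the cycle length T_i. The sketch below assumes 0-based epochs, one scheduler step per epoch, and an integer T_mult >= 1; `cycle_position` is a hypothetical helper, not part of braintools.

```python
def cycle_position(epoch, T_0, T_mult):
    """Map a global epoch index to (cycle_index, T_cur, T_i).

    Walks through cycles of length T_0, T_0*T_mult, T_0*T_mult**2, ...
    until it reaches the cycle that contains `epoch`.
    """
    if epoch < 0 or T_0 < 1 or T_mult < 1:
        raise ValueError("expected epoch >= 0, T_0 >= 1, T_mult >= 1")
    cycle, T_i, T_cur = 0, T_0, epoch
    while T_cur >= T_i:
        T_cur -= T_i   # skip past the completed cycle
        T_i *= T_mult  # the next cycle is T_mult times longer
        cycle += 1
    return cycle, T_cur, T_i


# T_0=10, T_mult=2: cycles cover epochs [0, 10), [10, 30), [30, 70), ...
for epoch in (0, 9, 10, 29, 30, 69, 70):
    print(epoch, cycle_position(epoch, T_0=10, T_mult=2))
```

A loop is used instead of a logarithm-based closed form to avoid floating-point rounding at cycle boundaries.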

Benefits of Warm Restarts:

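  1. Escape Local Minima: Periodic restarts can help the optimizer escape sharp minima

  2. Ensemble Effect: The model at the end of each cycle (at the lowest LR) is a distinct solution; saving these snapshots yields an ensemble without extra training runs

  3. Fast Convergence: The high LR at the start of each cycle gives rapid progress, and the low LR at the end fine-tunes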

  4. Exploration: Allows exploring different regions of the parameter space

JIT Compatibility:

This implementation is JIT-compatible through the use of jnp.where for conditional updates in the restart logic.
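
To illustrate the pattern this note refers to, here is a sketch of a branchless restart update. It uses NumPy's `np.where`, which selects elementwise with the same semantics as `jax.numpy.where`; it is an assumption about the general technique, not a copy of braintools' code. A Python `if` on a traced value fails under `jax.jit`, which is why the selection is expressed with `where`.

```python
import numpy as np


def restart_update(T_cur, T_i, T_mult):
    """Advance (T_cur, T_i) by one epoch without a Python `if`.

    Both candidate values are computed and `where` picks one, so the
    control flow does not depend on array values (the property jax.jit needs).
    """
    next_T_cur = T_cur + 1
    restart = next_T_cur >= T_i                     # has the cycle ended?
    new_T_cur = np.where(restart, 0, next_T_cur)    # reset position on restart
    new_T_i = np.where(restart, T_i * T_mult, T_i)  # lengthen the next cycle
    return new_T_cur, new_T_i


T_cur, T_i = np.array(0), np.array(3)
restart_epochs = []
for epoch in range(1, 50):
    T_cur, T_i = restart_update(T_cur, T_i, T_mult=2)
    if T_cur == 0:  # plain Python check, used only by this demo loop
        restart_epochs.append(epoch)
print(restart_epochs)  # [3, 9, 21, 45]
```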

Examples

Basic usage with fixed-length cycles:

>>> import braintools
>>> import brainstate
>>>
>>> # Restart every 50 epochs with same cycle length
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=50,
...     T_mult=1,  # Fixed cycle length
...     eta_min=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     train_epoch(...)
...     scheduler.step()
...     # LR will restart at epochs 50, 100, 150

Increasing cycle lengths (recommended):

>>> # Cycles of increasing length: 10, 20, 40, 80, ...
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=10,
...     T_mult=2,  # Double cycle length each time
...     eta_min=0.0001
... )
>>>
>>> # Training schedule:
>>> # Epochs [0, 10): First cycle (10 epochs)
>>> # Epochs [10, 30): Second cycle (20 epochs)
>>> # Epochs [30, 70): Third cycle (40 epochs)
>>> # And so on...

For transformer training:

>>> # Step-based schedule: T_0 counts optimizer steps here,
>>> # so call scheduler.step() after every optimization step
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.0005,
...     T_0=1000,  # First cycle: 1000 steps
...     T_mult=2,   # Increasing cycles
...     eta_min=1e-6
... )
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Fine-tuning with short cycles:

>>> # Fine-tuning with frequent restarts for exploration
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.001,  # Lower LR for fine-tuning
...     T_0=5,          # Short initial cycle
...     T_mult=1,       # Keep cycles short
...     eta_min=1e-5
... )
>>>
>>> # The LR restarts every 5 epochs (10 short cycles over 50 epochs)
>>> for epoch in range(50):
...     fine_tune_epoch(...)
...     scheduler.step()

Snapshot ensembling:

>>> import copy
>>>
>>> # Save the model at the end of each cycle for an ensemble
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=25,
...     T_mult=1,
...     eta_min=0
... )
>>>
>>> snapshots = []
>>> for epoch in range(100):
...     train_epoch(...)
...
...     # Last epoch of the cycle (minimum LR): save before stepping into the restart
...     if scheduler.T_cur.value == scheduler.T_i.value - 1:
...         snapshot = copy.deepcopy(model.state_dict())
...         snapshots.append(snapshot)
...         print(f"Saved snapshot at epoch {epoch}")
...
...     scheduler.step()
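
Which epochs get saved, and how the snapshots might be combined, can be sketched in plain Python and NumPy. Both helpers are hypothetical: `snapshot_epochs` assumes 0-based epochs with one `scheduler.step()` per epoch, and `ensemble_predict` assumes each snapshot has been wrapped as a callable returning an array (for example, class probabilities). Averaging outputs is the usual way to combine snapshots, but it is not something this scheduler does for you.

```python
import numpy as np


def snapshot_epochs(T_0, T_mult, n_epochs):
    """0-based epochs that close a complete cycle (the lowest-LR epochs)."""
    ends, start, T_i = [], 0, T_0
    while start + T_i <= n_epochs:
        ends.append(start + T_i - 1)  # last epoch of this cycle
        start += T_i
        T_i *= T_mult
    return ends


def ensemble_predict(snapshot_fns, x):
    """Average the outputs of the snapshot models on the same input."""
    return np.mean([fn(x) for fn in snapshot_fns], axis=0)


print(snapshot_epochs(T_0=25, T_mult=1, n_epochs=100))  # [24, 49, 74, 99]
```

With T_mult > 1, a training run that stops mid-cycle produces no snapshot for the unfinished cycle, which `snapshot_epochs` reflects by only listing complete cycles.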

Monitoring restarts:

>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=10,
...     T_mult=2,
...     eta_min=0.001
... )
>>>
>>> for epoch in range(100):
...     old_T_cur = scheduler.T_cur.value
...     scheduler.step()
...     current_lr = scheduler.get_lr()[0]
...
...     # Detect restart
...     if scheduler.T_cur.value < old_T_cur:
...         print(f"Restart after epoch {epoch}! LR reset to {current_lr:.6f}")
...
...     if epoch % 10 == 0:
...         print(f"Epoch {epoch}: LR={current_lr:.6f}, "
...               f"T_cur={scheduler.T_cur.value}, T_i={scheduler.T_i.value}")
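
The same restarts can also be found offline, for example to plot a planned schedule before training. The sketch below is a pure-Python reimplementation assuming 0-based epochs and one step per epoch; `sgdr_schedule` is a hypothetical helper, and its restart detection (an LR increase between consecutive epochs) is an alternative to comparing T_cur values.

```python
import math


def sgdr_schedule(base_lr, T_0, T_mult, eta_min, n_epochs):
    """Per-epoch learning rates for SGDR, starting from epoch 0."""
    lrs, T_cur, T_i = [], 0, T_0
    for _ in range(n_epochs):
        lrs.append(eta_min + 0.5 * (base_lr - eta_min)
                   * (1 + math.cos(math.pi * T_cur / T_i)))
        T_cur += 1
        if T_cur >= T_i:  # cycle finished: restart and grow the period
            T_cur, T_i = 0, T_i * T_mult
    return lrs


lrs = sgdr_schedule(base_lr=0.1, T_0=10, T_mult=2, eta_min=0.001, n_epochs=100)
# Within a cycle the LR only decreases, so any increase marks a restart.
restarts = [e for e in range(1, len(lrs)) if lrs[e] > lrs[e - 1]]
print(restarts)  # [10, 30, 70]
```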

Custom restart schedule with T_mult > 1:

>>> # Create a schedule with specific restart points
>>> # Cycles start at epochs 0, 100, 300, 700, 1500, ...
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.01,
...     T_0=100,    # First cycle: 100 epochs
...     T_mult=2,   # Each cycle doubles
...     eta_min=0.0001
... )
>>>
>>> # Calculate when restarts will occur
>>> def get_restart_epochs(T_0, T_mult, n_restarts):
...     epochs = [0]
...     T_i = T_0
...     for _ in range(n_restarts):
...         epochs.append(epochs[-1] + T_i)
...         T_i = T_i * T_mult
...     return epochs
>>>
>>> restart_epochs = get_restart_epochs(100, 2, 5)
>>> print(f"Restarts at epochs: {restart_epochs}")
Restarts at epochs: [0, 100, 300, 700, 1500, 3100]

Combining with other techniques:

>>> # Combine with gradient clipping and weight decay
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1,
...     T_0=30,
...     T_mult=2,
...     eta_min=0.001
... )
>>>
>>> optimizer = braintools.optim.AdamW(
...     lr=scheduler,
...     weight_decay=1e-4,
...     clip_norm=1.0  # Gradient clipping
... )

State persistence for long training:

>>> # Save and restore scheduler state
>>> scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1, T_0=50, T_mult=2
... )
>>>
>>> # Train for some epochs
>>> for epoch in range(75):
...     scheduler.step()
>>>
>>> # Save state (store plain values, not references to the live State objects)
>>> state = {
...     'epoch': 75,
...     'scheduler': scheduler.state_dict(),
...     'T_cur': scheduler.T_cur.value,
...     'T_i': scheduler.T_i.value,
... }
>>>
>>> # Later: restore and continue
>>> new_scheduler = braintools.optim.CosineAnnealingWarmRestarts(
...     base_lr=0.1, T_0=50, T_mult=2
... )
>>> new_scheduler.load_state_dict(state['scheduler'])
>>> new_scheduler.T_cur.value = state['T_cur']
>>> new_scheduler.T_i.value = state['T_i']

See also

  • CosineAnnealingLR – Standard cosine annealing without restarts

  • OneCycleLR – One cycle learning rate policy

  • CyclicLR – Cyclic learning rates between bounds

  • ChainedScheduler – Chain multiple schedulers together

get_lr()

Compute the current learning rate(s), one value per parameter group.

step(epoch=None)

Advance the schedule and update the learning rate.