SequentialLR
- class braintools.optim.SequentialLR(schedulers, milestones, last_epoch=0)
Sequential learning rate scheduler - Chains multiple schedulers based on epoch milestones.
SequentialLR allows you to chain multiple learning rate schedulers, with each scheduler being active during specific epoch ranges defined by milestones. This is particularly useful for complex training strategies that require different learning rate policies at different stages of training.
- Parameters:
schedulers (List[LRScheduler]) – List of schedulers to be sequentially applied. The number of schedulers should be len(milestones) + 1.
milestones (List[int]) – List of epoch indices that define when to switch schedulers. Must be in ascending order. For n milestones, you need n+1 schedulers.
last_epoch (int) – The index of the last epoch. Default: 0.
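The two constraints on the arguments (one more scheduler than milestones, and milestones in ascending order) can be checked before constructing the scheduler. The helper below is a minimal sketch; `check_sequential_args` is a hypothetical name and is not part of braintools, and whether the library performs an equivalent check itself is not stated here.

```python
def check_sequential_args(schedulers, milestones):
    """Hypothetical pre-flight check mirroring the documented constraints."""
    # n milestones split training into n + 1 ranges, one scheduler per range.
    if len(schedulers) != len(milestones) + 1:
        raise ValueError(
            f"expected {len(milestones) + 1} schedulers for "
            f"{len(milestones)} milestones, got {len(schedulers)}"
        )
    # Milestones must be given in ascending order.
    if list(milestones) != sorted(milestones):
        raise ValueError("milestones must be in ascending order")


# Placeholder strings stand in for real scheduler objects.
check_sequential_args(["warmup", "decay"], [5])           # passes silently
check_sequential_args(["warmup", "main", "fine"], [5, 80])  # passes silently
```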
Notes
Scheduler Switching Logic:
Given milestones [m1, m2, …, mn] and schedulers [s0, s1, …, sn]:
Epochs [0, m1): uses scheduler s0
Epochs [m1, m2): uses scheduler s1
…
Epochs [mn, ∞): uses scheduler sn
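The half-open ranges above mean that a milestone epoch already belongs to the next scheduler. A minimal pure-Python sketch of that mapping, using the standard library's `bisect_right`; the helper name `active_scheduler_index` is hypothetical and only illustrates the rule, not the library's internals.

```python
from bisect import bisect_right


def active_scheduler_index(epoch, milestones):
    """Return i such that scheduler s_i is active at `epoch` (illustrative helper)."""
    # bisect_right counts the milestones that are <= epoch, so an epoch equal
    # to a milestone is assigned to the scheduler that starts at that milestone.
    return bisect_right(milestones, epoch)


milestones = [5, 80]
indices = [active_scheduler_index(e, milestones) for e in (0, 4, 5, 79, 80, 200)]
print(indices)  # [0, 0, 1, 1, 2, 2]
```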
JIT Compatibility:
This implementation is JIT-compatible through the use of JAX operations for scheduler selection. The scheduler index is computed using jnp.searchsorted for efficient milestone-based switching.
Important Considerations:
Each scheduler should be initialized with the appropriate base_lr that matches your intended learning rate at the transition point.
The last_epoch parameter of individual schedulers is managed internally.
When saving/loading state, all schedulers’ states are preserved.
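The milestone lookup described under JIT Compatibility can be sketched with `jnp.searchsorted` inside a jitted function. This is an illustration under stated assumptions, not the braintools implementation: `warmup`, `decay`, and `lr_at` are hypothetical stand-ins for real schedulers, and counting the decay from the milestone rather than from epoch 0 is an assumption made for this example.

```python
import jax
import jax.numpy as jnp

milestones = jnp.array([5])  # one milestone, so two schedules


def warmup(epoch):
    # Hypothetical linear warmup: factor goes from 0.01 to 1.0 over 5 epochs.
    return 0.1 * (0.01 + (1.0 - 0.01) * jnp.minimum(epoch, 5) / 5)


def decay(epoch):
    # Hypothetical exponential decay, counted from the milestone (assumption).
    return 0.1 * 0.95 ** jnp.maximum(epoch - 5, 0)


@jax.jit
def lr_at(epoch):
    # side='right' makes an epoch equal to a milestone select the next schedule,
    # matching the half-open ranges [m_i, m_{i+1}).
    idx = jnp.searchsorted(milestones, epoch, side="right")
    # Evaluate every candidate and index into the result; no Python branching
    # on traced values, so the function stays jit-compatible.
    candidates = jnp.stack([warmup(epoch), decay(epoch)])
    return candidates[idx]


for epoch in (0, 4, 5, 6):
    print(epoch, float(lr_at(epoch)))
```

Evaluating all candidates and indexing is the simplest jit-friendly form; with many expensive schedules, `jax.lax.switch` would avoid computing the inactive ones.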
Examples
Basic usage with warmup and decay:
>>> import braintools
>>> import brainstate
>>>
>>> # Warmup for 5 epochs, then exponential decay
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.1,
...     start_factor=0.01,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.ExponentialLR(
...     base_lr=0.1,
...     gamma=0.95
... )
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, decay],
...     milestones=[5]
... )
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     train(...)
...     scheduler.step(epoch)
Three-phase training (warmup → training → fine-tuning):
>>> # Phase 1: Warmup (epochs 0-4)
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.001,
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>>
>>> # Phase 2: Main training (epochs 5-79)
>>> main_training = braintools.optim.CosineAnnealingLR(
...     base_lr=0.001,
...     T_max=75,
...     eta_min=0.0001
... )
>>>
>>> # Phase 3: Fine-tuning (epochs 80+)
>>> fine_tuning = braintools.optim.ConstantLR(
...     base_lr=0.0001,
...     factor=1.0
... )
>>>
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, main_training, fine_tuning],
...     milestones=[5, 80]
... )
Complex schedule for transformer training:
>>> # Transformer training schedule
>>> # 1. Linear warmup
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.0005,
...     start_factor=0.0,
...     end_factor=1.0,
...     total_iters=4000  # 4000 steps
... )
>>>
>>> # 2. Constant learning rate
>>> constant = braintools.optim.ConstantLR(
...     base_lr=0.0005,
...     factor=1.0
... )
>>>
>>> # 3. Cosine decay to near zero
>>> cosine_decay = braintools.optim.CosineAnnealingLR(
...     base_lr=0.0005,
...     T_max=20000,
...     eta_min=1e-6
... )
>>>
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, constant, cosine_decay],
...     milestones=[4000, 10000]
... )
State persistence across training sessions:
>>> # Save scheduler state
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[scheduler1, scheduler2],
...     milestones=[50]
... )
>>> # ... train for some epochs ...
>>> checkpoint = {
...     'epoch': epoch,
...     'scheduler': scheduler.state_dict(),
...     'optimizer': optimizer.state_dict(),
... }
>>> save(checkpoint, 'checkpoint.pkl')
>>>
>>> # Resume training
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[scheduler1, scheduler2],
...     milestones=[50]
... )
>>> scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue training from saved epoch
Using with different optimizers:
>>> # Works with any optimizer
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup_sched, main_sched],
...     milestones=[10]
... )
>>>
>>> # With SGD
>>> sgd_opt = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>>
>>> # With Adam
>>> adam_opt = braintools.optim.Adam(lr=scheduler)
>>>
>>> # With LAMB for large batch training
>>> lamb_opt = braintools.optim.LAMB(lr=scheduler)
Monitoring scheduler transitions:
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[sched1, sched2, sched3],
...     milestones=[10, 20]
... )
>>>
>>> for epoch in range(30):
...     scheduler.step(epoch)
...     current_lr = scheduler.get_lr()
...     active_scheduler = scheduler.current_scheduler_idx
...     print(f"Epoch {epoch}: LR={current_lr[0]:.6f}, "
...           f"Active scheduler: {active_scheduler}")
...
...     # Detect transitions
...     if epoch in scheduler.milestones:
...         print(f"  -> Switching to scheduler {active_scheduler}")
See also
ChainedScheduler: Applies multiple schedulers simultaneously
LinearLR: Linear learning rate schedule (good for warmup)
CosineAnnealingLR: Cosine annealing schedule
ExponentialLR: Exponential decay schedule
StepLR: Step-wise decay schedule