ChainedScheduler
- class braintools.optim.ChainedScheduler(schedulers)
Chain multiple schedulers together by applying them simultaneously.
ChainedScheduler allows you to apply multiple learning rate schedulers at the same time. All schedulers are stepped together at each epoch, and their effects are combined multiplicatively. This is particularly useful for implementing complex learning rate schedules like warmup followed by decay.
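To make "combined multiplicatively" concrete, here is a minimal pure-Python sketch. It does not use braintools and is not the library's implementation. It assumes each scheduler can be viewed as a function from epoch to a multiplicative factor on a base rate, and `combine`, `ramp_5`, and `halve_every_10` are hypothetical helpers introduced only for illustration.

```python
from math import prod
from typing import Callable, Sequence

Factor = Callable[[int], float]  # maps an epoch to a multiplicative factor


def combine(base_lr: float, factors: Sequence[Factor]) -> Callable[[int], float]:
    """Return lr(epoch) = base_lr * product of every factor at that epoch."""
    def lr(epoch: int) -> float:
        return base_lr * prod(f(epoch) for f in factors)
    return lr


def ramp_5(epoch: int) -> float:
    # Hypothetical warmup: 0.2, 0.4, ..., 1.0 over the first five epochs.
    return min(epoch + 1, 5) / 5


def halve_every_10(epoch: int) -> float:
    # Hypothetical decay: multiply by 0.5 every 10 epochs.
    return 0.5 ** (epoch // 10)


lr = combine(0.1, [ramp_5, halve_every_10])
schedule = [lr(e) for e in (0, 4, 10, 20)]
```

Under this reading the order of the factors does not change the result, because multiplication is commutative.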
- Parameters:
schedulers (List[LRScheduler]) – List of scheduler instances to chain together. All schedulers must operate on the same optimizer. The schedulers will be stepped in the order provided.
Notes
When multiple schedulers are chained:
Each scheduler computes its own learning rate adjustment
All schedulers are stepped simultaneously
The final learning rate is determined by the last scheduler in the chain
State management is handled individually for each scheduler
Key characteristics:
Enables complex multi-phase learning rate schedules
Common pattern: warmup + decay
All schedulers share the same epoch counter
Useful for combining complementary scheduling strategies
Common patterns:
Warmup + StepLR: Gradual increase followed by step decay
Warmup + CosineAnnealing: Linear warmup then smooth cosine decay
Multiple decay stages: ConstantLR + MultiStepLR
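The "Warmup + CosineAnnealing" pattern has no example on this page, so the following pure-Python sketch shows the shape such a schedule would take if the two effects are multiplied together. It is an illustration under stated assumptions, not braintools code: `warmup_factor`, `cosine_factor`, `lr_at`, and the `t_max` parameter are hypothetical names, and the library's own cosine scheduler may use a different formula or a nonzero minimum rate.

```python
import math


def warmup_factor(epoch, start_factor=0.1, end_factor=1.0, total_iters=5):
    # Linear interpolation from start_factor to end_factor, then constant.
    t = min(epoch, total_iters) / total_iters
    return start_factor + (end_factor - start_factor) * t


def cosine_factor(epoch, t_max=100):
    # Half-cosine from 1.0 at epoch 0 down to 0.0 at epoch t_max (assumed form).
    t = min(epoch, t_max) / t_max
    return 0.5 * (1.0 + math.cos(math.pi * t))


def lr_at(epoch, base_lr=0.01):
    # Assumed multiplicative combination of the two effects.
    return base_lr * warmup_factor(epoch) * cosine_factor(epoch)


lrs = [lr_at(e) for e in range(101)]
peak_epoch = max(range(101), key=lambda e: lrs[e])
```

With these values the rate rises during the five warmup epochs, peaks when warmup ends, and then decays smoothly toward zero at `t_max`.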
Examples
Warmup followed by step decay:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Create individual schedulers
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>>
>>> # Chain them together
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop (`grads` stands for the gradients computed in each epoch; not shown)
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
# Epochs 0-4: warmup from 0.001 to 0.01
# Epochs 5-29: lr = 0.01
# Epochs 30-59: lr = 0.001 (first decay)
# Epochs 60+: lr = 0.0001 (second decay)
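The learning rates quoted in the comments of this example can be checked by hand. The sketch below is pure Python and does not call braintools. It assumes the warmup factor is a linear interpolation over `total_iters` epochs, that step decay multiplies by `gamma` every `step_size` epochs, and that the two effects are multiplied. `linear_factor`, `step_lr`, and `chained_lr` are hypothetical helpers.

```python
def linear_factor(epoch, start_factor=0.1, end_factor=1.0, total_iters=5):
    # Assumed closed form: interpolate linearly, then hold at end_factor.
    t = min(epoch, total_iters) / total_iters
    return start_factor + (end_factor - start_factor) * t


def step_lr(epoch, base_lr=0.01, step_size=30, gamma=0.1):
    # Assumed closed form: multiply by gamma once per completed step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


def chained_lr(epoch):
    # Assumed multiplicative combination of warmup and decay.
    return linear_factor(epoch) * step_lr(epoch)


# Learning rate at the epochs where the schedule changes phase.
table = {epoch: chained_lr(epoch) for epoch in (0, 5, 29, 30, 59, 60, 89)}
```

Under these assumptions the rate starts at 0.001, reaches 0.01 once warmup completes, and drops by a factor of ten at epochs 30 and 60, which agrees with the comments in the example.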
Constant warmup + multi-step decay:
>>> # Start with reduced lr, then schedule decays
>>> warmup = braintools.optim.ConstantLR(factor=0.1, total_iters=5)
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 80],
...     gamma=0.1
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # `grads` stands for the gradients computed in each epoch (not shown)
>>> for epoch in range(100):
...     optimizer.step(grads)
...     scheduler.step()
Multiple warmup phases:
>>> # Two-stage warmup
>>> warmup1 = braintools.optim.ConstantLR(
...     base_lr=0.01,
...     factor=0.01,
...     total_iters=3
... )
>>> warmup2 = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=7
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup1, warmup2])
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
# Epochs 0-2: lr = 0.0001 (constant low)
# Epochs 3-9: lr increases from ~0.001 to 0.01 (linear)
# Epochs 10+: lr = 0.01 (normal)
Saving and loading chained scheduler state:
>>> warmup = braintools.optim.LinearLR(start_factor=0.1, end_factor=1.0, total_iters=5)
>>> decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save state
>>> checkpoint = {'scheduler': scheduler.state_dict(), 'epoch': 50}
>>>
>>> # Later, resume training
>>> new_warmup = braintools.optim.LinearLR(start_factor=0.1, end_factor=1.0, total_iters=5)
>>> new_decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>> new_scheduler = braintools.optim.ChainedScheduler([new_warmup, new_decay])
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from epoch 50
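The Notes say that state is managed individually for each scheduler. One way to picture what a save and restore round trip involves is the toy model below. It is a sketch only: `ToyScheduler` and `ToyChain` are hypothetical classes, and the dictionary layout is an assumption, not the actual format returned by braintools' `state_dict()`.

```python
class ToyScheduler:
    """Hypothetical stand-in for one scheduler: it only tracks its epoch."""

    def __init__(self):
        self.last_epoch = 0

    def step(self):
        self.last_epoch += 1

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


class ToyChain:
    """Hypothetical chain: delegates stepping and state to each child."""

    def __init__(self, schedulers):
        self.schedulers = list(schedulers)

    def step(self):
        # Step every child, in the order provided.
        for s in self.schedulers:
            s.step()

    def state_dict(self):
        # One entry per child, in the same order as the chain.
        return {"schedulers": [s.state_dict() for s in self.schedulers]}

    def load_state_dict(self, state):
        saved = state["schedulers"]
        if len(saved) != len(self.schedulers):
            raise ValueError("checkpoint has a different number of schedulers")
        for s, child_state in zip(self.schedulers, saved):
            s.load_state_dict(child_state)


chain = ToyChain([ToyScheduler(), ToyScheduler()])
for _ in range(50):
    chain.step()
checkpoint = chain.state_dict()

# Rebuild a chain with the same structure, then restore the saved state.
restored = ToyChain([ToyScheduler(), ToyScheduler()])
restored.load_state_dict(checkpoint)
```

The practical point carried over to the real example is that the restored chain is rebuilt with the same schedulers in the same order before the saved state is loaded.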
ImageNet-style training schedule:
>>> # Standard ImageNet: warmup + step decay
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.01,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60],
...     gamma=0.1
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(
...     lr=scheduler,
...     momentum=0.9,
...     weight_decay=1e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # `grads` stands for the gradients computed in each epoch (not shown)
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
Fine-tuning with conservative start:
>>> # Conservative warmup for transfer learning
>>> warmup = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.1,
...     total_iters=3
... )
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.001,
...     milestones=[10, 20],
...     gamma=0.5
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # `finetune_epoch` and `finetune_loader` are user-defined (not shown)
>>> for epoch in range(30):
...     finetune_epoch(model, optimizer, finetune_loader)
...     scheduler.step()
See also
SequentialLR: Switch between different schedulers at specific milestones
LinearLR: Linear learning rate warmup/cooldown
StepLR: Step learning rate decay
MultiStepLR: Multi-step learning rate decay