ChainedScheduler

class braintools.optim.ChainedScheduler(schedulers)

Chain multiple schedulers together and apply them simultaneously.

ChainedScheduler allows you to apply multiple learning rate schedulers at the same time. All schedulers are stepped together at each epoch, and their effects are combined multiplicatively. This is particularly useful for implementing complex learning rate schedules like warmup followed by decay.
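The multiplicative rule can be modelled in a few lines of plain Python. This is a sketch, not braintools code. It assumes that each scheduler reduces to a factor function of the shared epoch. The names `chained_lr`, `warmup` and `decay` are hypothetical. Under this rule, the combined rate does not depend on the order of the list.

```python
from functools import reduce
from typing import Callable, Sequence

# A factor maps an epoch index to a multiplier on the base learning rate.
Factor = Callable[[int], float]


def chained_lr(base_lr: float, factors: Sequence[Factor], epoch: int) -> float:
    """Combine factors multiplicatively; every factor sees the same epoch."""
    return base_lr * reduce(lambda acc, factor: acc * factor(epoch), factors, 1.0)


# Hypothetical factors: linear warmup over 5 epochs, halving every 10 epochs.
def warmup(epoch: int) -> float:
    return 0.2 + 0.8 * min(epoch, 5) / 5


def decay(epoch: int) -> float:
    return 0.5 ** (epoch // 10)


for epoch in (0, 5, 10, 20):
    forward = chained_lr(0.1, [warmup, decay], epoch)
    backward = chained_lr(0.1, [decay, warmup], epoch)
    print(epoch, round(forward, 6), round(backward, 6))
```

Multiplication is the design choice that lets a warmup and a decay coexist. Each scheduler only needs to describe its own relative change.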

Parameters:

schedulers (List[LRScheduler]) – List of scheduler instances to chain together. All schedulers must operate on the same optimizer. The schedulers will be stepped in the order provided.

Notes

When multiple schedulers are chained:

  • Each scheduler computes its own learning rate adjustment

  • All schedulers are stepped together in a single call to step(), in list order

  • The final learning rate is determined by the last scheduler in the chain

  • State management is handled individually for each scheduler
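The following toy model makes these stepping rules concrete in plain Python. `ToyScheduler` and `ToyChain` are hypothetical classes, not part of braintools. The sketch shows only one thing: a single `step()` call on the chain steps every scheduler once, in list order, so their epoch counters stay in sync.

```python
from typing import List


class ToyScheduler:
    """Hypothetical stand-in for one scheduler with its own epoch counter."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.last_epoch = 0
        self.log = log  # shared list, used only to record the stepping order

    def step(self) -> None:
        self.last_epoch += 1
        self.log.append(f"{self.name}@{self.last_epoch}")


class ToyChain:
    """Hypothetical chain: one step() call steps every scheduler in list order."""

    def __init__(self, schedulers: List[ToyScheduler]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()


log: List[str] = []
chain = ToyChain([ToyScheduler("warmup", log), ToyScheduler("decay", log)])
for _ in range(2):
    chain.step()

print(log)  # ['warmup@1', 'decay@1', 'warmup@2', 'decay@2']
print([s.last_epoch for s in chain.schedulers])  # [2, 2]
```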

Key characteristics:

  • Enables complex multi-phase learning rate schedules

  • Common pattern: warmup + decay

  • All schedulers share the same epoch counter

  • Useful for combining complementary scheduling strategies

Common patterns:

  • Warmup + StepLR: Gradual increase followed by step decay

  • Warmup + CosineAnnealing: Linear warmup then smooth cosine decay

  • Constant warmup + multi-step decay: ConstantLR + MultiStepLR
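The Warmup + CosineAnnealing pattern has no example on this page, so here is a plain-Python sketch of its shape. It assumes PyTorch-style closed forms for both factors, combined multiplicatively:

- a linear warmup from 0.1 to 1.0 over 5 epochs;
- a cosine curve with `t_max=50` and `eta_min=0`.

The function names and all numbers are illustrative.

```python
import math


def warmup_factor(epoch: int, start: float = 0.1, end: float = 1.0, total_iters: int = 5) -> float:
    # Assumed LinearLR-style factor: interpolate start -> end, then hold at end.
    return start + (end - start) * min(epoch, total_iters) / total_iters


def cosine_lr(epoch: int, base_lr: float = 0.1, t_max: int = 50, eta_min: float = 0.0) -> float:
    # Assumed CosineAnnealingLR-style closed form.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


schedule = [warmup_factor(e) * cosine_lr(e) for e in range(51)]
peak_epoch = max(range(len(schedule)), key=schedule.__getitem__)
print(f"epoch 0: {schedule[0]:.4f}, peak at epoch {peak_epoch}, epoch 50: {schedule[50]:.4f}")
```

The cosine curve also starts at epoch 0, so the warmup overlaps its first few epochs. As a result, the peak (here at epoch 5) sits slightly below `base_lr`.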

Examples

Warmup followed by step decay:

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Create individual schedulers
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>>
>>> # Chain them together
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
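>>> # Training loop (`grads` stands for the gradients of the registered
>>> # weights, computed earlier in each training step)
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
>>> # Epochs 0-4:   linear warmup from 0.001 toward 0.01
>>> # Epochs 5-29:  lr = 0.01
>>> # Epochs 30-59: lr = 0.001 (first decay)
>>> # Epochs 60+:   lr = 0.0001 (second decay)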
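The epoch ranges in this example can be checked with a plain-Python model of the two factors. The model makes three assumptions:

- PyTorch-style closed forms for `LinearLR` and `StepLR`;
- the multiplicative combination described at the top of this page;
- a base rate of 0.01, taken from the StepLR.

It is not the braintools implementation.

```python
def warmup_step_lr(epoch: int, base_lr: float = 0.01) -> float:
    # Assumed LinearLR(start_factor=0.1, end_factor=1.0, total_iters=5) factor.
    warmup = 0.1 + 0.9 * min(epoch, 5) / 5
    # Assumed StepLR(step_size=30, gamma=0.1) factor.
    decay = 0.1 ** (epoch // 30)
    return base_lr * warmup * decay


for epoch in (0, 4, 5, 29, 30, 59, 60, 89):
    print(f"epoch {epoch:2d}: lr = {warmup_step_lr(epoch):.4g}")
```

Under these assumptions the warmup reaches its full value at epoch 5. Epoch 4 therefore still runs at 0.0082 rather than 0.01.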

Constant warmup + multi-step decay:

>>> # Start with reduced lr, then schedule decays
>>> warmup = braintools.optim.ConstantLR(factor=0.1, total_iters=5)
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 80],
...     gamma=0.1
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)
...     scheduler.step()
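The following plain-Python sketch models the resulting schedule and lists the epochs where the learning rate changes. It assumes PyTorch-style behaviour for both schedulers, combined multiplicatively:

- `ConstantLR` scales the rate by `factor` until `total_iters`;
- `MultiStepLR` multiplies it by `gamma` at each milestone.

The function name is hypothetical.

```python
import math
from bisect import bisect_right


def constant_multistep_lr(epoch: int, base_lr: float = 0.1) -> float:
    # Assumed ConstantLR(factor=0.1, total_iters=5): scale by 0.1 before epoch 5.
    constant = 0.1 if epoch < 5 else 1.0
    # Assumed MultiStepLR(milestones=[30, 60, 80], gamma=0.1): one factor of
    # gamma for every milestone already reached.
    decay = 0.1 ** bisect_right([30, 60, 80], epoch)
    return base_lr * constant * decay


lrs = [constant_multistep_lr(e) for e in range(100)]
# Record (epoch, new lr) wherever the rate differs from the previous epoch.
changes = [(e, lrs[e]) for e in range(1, 100) if not math.isclose(lrs[e], lrs[e - 1])]
for epoch, lr in changes:
    print(f"epoch {epoch}: lr -> {lr:.4g}")
```

Under these assumptions the run starts at 0.01 for epochs 0-4 and jumps to 0.1 at epoch 5. It then drops tenfold at epochs 30, 60 and 80.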

Multiple warmup phases:

>>> # Two-stage warmup
>>> warmup1 = braintools.optim.ConstantLR(
...     base_lr=0.01,
...     factor=0.01,
...     total_iters=3
... )
>>> warmup2 = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=7
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup1, warmup2])
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
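>>> # With multiplicative chaining, both schedulers count from epoch 0:
>>> # Epochs 0-2:  lr ≈ 1e-5 to 3.6e-5 (factor 0.01 times the linear ramp)
>>> # Epochs 3-6:  lr ≈ 0.0049 to 0.0087 (linear ramp only)
>>> # Epochs 7+:   lr = 0.01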
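Under the multiplicative combination described at the top of this page, both warmup schedulers count from epoch 0. The two phases therefore overlap rather than run back to back. The plain-Python sketch below shows the resulting values. It assumes PyTorch-style closed forms for `ConstantLR` and `LinearLR`, and the function name is hypothetical. For strictly consecutive phases, SequentialLR (listed under See also) switches between schedulers at milestones.

```python
def two_stage_lr(epoch: int, base_lr: float = 0.01) -> float:
    # Assumed ConstantLR(factor=0.01, total_iters=3): scale by 0.01 before epoch 3.
    stage1 = 0.01 if epoch < 3 else 1.0
    # Assumed LinearLR(start_factor=0.1, end_factor=1.0, total_iters=7),
    # which also starts counting at epoch 0.
    stage2 = 0.1 + 0.9 * min(epoch, 7) / 7
    return base_lr * stage1 * stage2


for epoch in range(9):
    print(f"epoch {epoch}: lr = {two_stage_lr(epoch):.3g}")
```

Under these assumptions the rate behaves as follows:

- epochs 0-2: about 1e-5 to 3.6e-5;
- epoch 3: about 0.0049, when the constant factor expires;
- epoch 7 onward: 0.01.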

Saving and loading chained scheduler state:

>>> warmup = braintools.optim.LinearLR(start_factor=0.1, end_factor=1.0, total_iters=5)
>>> decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save state
>>> checkpoint = {'scheduler': scheduler.state_dict(), 'epoch': 50}
>>>
>>> # Later, resume training
>>> new_warmup = braintools.optim.LinearLR(start_factor=0.1, end_factor=1.0, total_iters=5)
>>> new_decay = braintools.optim.StepLR(base_lr=0.01, step_size=30, gamma=0.1)
>>> new_scheduler = braintools.optim.ChainedScheduler([new_warmup, new_decay])
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from epoch 50
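The save-and-restore round trip can be illustrated with a toy chain. Its state holds one entry per scheduler, stored in list order. `ToyScheduler` and `ToyChain` are hypothetical stand-ins, not braintools classes. The sketch assumes only that each scheduler's state is a plain dictionary, and it passes the checkpoint through JSON as a file would. States are matched by position, so the chain must be rebuilt with the same schedulers in the same order.

```python
import json
from typing import Any, Dict, List


class ToyScheduler:
    """Hypothetical scheduler that only tracks its epoch counter."""

    def __init__(self) -> None:
        self.last_epoch = 0

    def step(self) -> None:
        self.last_epoch += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.last_epoch = state["last_epoch"]


class ToyChain:
    """Hypothetical chain that saves one state entry per scheduler, in order."""

    def __init__(self, schedulers: List[ToyScheduler]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        for s in self.schedulers:
            s.step()

    def state_dict(self) -> Dict[str, Any]:
        return {"schedulers": [s.state_dict() for s in self.schedulers]}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        if len(state["schedulers"]) != len(self.schedulers):
            raise ValueError("checkpoint has a different number of schedulers")
        for s, s_state in zip(self.schedulers, state["schedulers"]):
            s.load_state_dict(s_state)


chain = ToyChain([ToyScheduler(), ToyScheduler()])
for _ in range(50):
    chain.step()

# Round-trip through JSON, as a checkpoint file would.
saved = json.dumps({"scheduler": chain.state_dict(), "epoch": 50})

restored = ToyChain([ToyScheduler(), ToyScheduler()])
restored.load_state_dict(json.loads(saved)["scheduler"])
print([s.last_epoch for s in restored.schedulers])  # [50, 50]
```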

ImageNet-style training schedule:

>>> # Standard ImageNet: warmup + step decay
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.01,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60],
...     gamma=0.1
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(
...     lr=scheduler,
...     momentum=0.9,
...     weight_decay=1e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
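As a sanity check on this recipe, the following plain-Python model computes the 90-epoch schedule. It assumes PyTorch-style closed forms for `LinearLR` and `MultiStepLR`, combined multiplicatively. It is not the braintools implementation, and the function name is hypothetical.

```python
from bisect import bisect_right


def imagenet_lr(epoch: int, base_lr: float = 0.1) -> float:
    # Assumed LinearLR(start_factor=0.01, end_factor=1.0, total_iters=5) factor.
    warmup = 0.01 + 0.99 * min(epoch, 5) / 5
    # Assumed MultiStepLR(milestones=[30, 60], gamma=0.1) factor.
    decay = 0.1 ** bisect_right([30, 60], epoch)
    return base_lr * warmup * decay


schedule = [imagenet_lr(e) for e in range(90)]
peak = max(schedule)
print(f"start {schedule[0]:.4g}, peak {peak:.4g} from epoch {schedule.index(peak)}, "
      f"end {schedule[-1]:.4g}")
```

Under these assumptions training starts at 0.001 and reaches the base rate of 0.1 at epoch 5. It ends at 0.001 after the two tenfold decays.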

Fine-tuning with conservative start:

>>> # Conservative warmup for transfer learning
>>> warmup = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.1,
...     total_iters=3
... )
>>> decay = braintools.optim.MultiStepLR(
...     base_lr=0.001,
...     milestones=[10, 20],
...     gamma=0.5
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(30):
...     finetune_epoch(model, optimizer, finetune_loader)  # user-defined fine-tuning epoch
...     scheduler.step()
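The following plain-Python tally counts how many of the 30 epochs run at each learning rate. It assumes PyTorch-style `ConstantLR` and `MultiStepLR` closed forms combined multiplicatively. It is an illustration, not the braintools implementation, and the function name is hypothetical.

```python
from bisect import bisect_right
from collections import Counter


def finetune_lr(epoch: int, base_lr: float = 0.001) -> float:
    # Assumed ConstantLR(factor=0.1, total_iters=3): scale by 0.1 before epoch 3.
    constant = 0.1 if epoch < 3 else 1.0
    # Assumed MultiStepLR(milestones=[10, 20], gamma=0.5) factor.
    decay = 0.5 ** bisect_right([10, 20], epoch)
    return base_lr * constant * decay


# Epochs spent at each learning rate (rounded to hide floating-point noise).
epochs_per_lr = Counter(round(finetune_lr(e), 8) for e in range(30))
for lr, count in sorted(epochs_per_lr.items(), reverse=True):
    print(f"lr {lr:g}: {count} epochs")
```

Under these assumptions the epochs split as follows:

- the first 3 epochs run at one tenth of the base rate;
- the next 7 epochs run at 0.001;
- 10 epochs each follow at 0.0005 and 0.00025.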

See also

SequentialLR

Switch between different schedulers at specific milestones

LinearLR

Linear learning rate warmup/cooldown

StepLR

Step learning rate decay

MultiStepLR

Multi-step learning rate decay


attach_optimizer(optimizer)

Attach the optimizer to every scheduler in the chain.

get_lr()

Calculate the current learning rate.

load_state_dict(state_dict)

Load the scheduler state from a dictionary.

state_dict()

Return the scheduler state as a dictionary.

step(*args, **kwargs)

Update the learning rate by stepping every scheduler in the chain.