SequentialLR

class braintools.optim.SequentialLR(schedulers, milestones, last_epoch=0)

Sequential learning rate scheduler that chains multiple schedulers based on epoch milestones.

SequentialLR allows you to chain multiple learning rate schedulers, with each scheduler being active during specific epoch ranges defined by milestones. This is particularly useful for complex training strategies that require different learning rate policies at different stages of training.

Parameters:
  • schedulers (List[LRScheduler]) – List of schedulers to be sequentially applied. The number of schedulers should be len(milestones) + 1.

  • milestones (List[int]) – List of epoch indices that define when to switch schedulers. Must be in ascending order. For n milestones, you need n+1 schedulers.

  • last_epoch (int) – The index of the last epoch. Default: 0.
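The two constraints above (ascending milestones, and one more scheduler than milestones) can be checked before constructing the scheduler. The helper below is a hypothetical sketch and not part of braintools; this page does not say whether SequentialLR performs an equivalent check itself, and treating duplicate milestones as invalid is an assumption.

```python
from typing import Sequence


def check_sequential_args(schedulers: Sequence[object], milestones: Sequence[int]) -> None:
    """Hypothetical pre-flight check; not a braintools function."""
    if len(schedulers) != len(milestones) + 1:
        raise ValueError(
            f"expected {len(milestones) + 1} schedulers for "
            f"{len(milestones)} milestones, got {len(schedulers)}"
        )
    # Assumption: milestones must be strictly increasing (no duplicates).
    for earlier, later in zip(milestones, milestones[1:]):
        if later <= earlier:
            raise ValueError(f"milestones must be ascending, got {list(milestones)}")


# Placeholder strings stand in for real scheduler objects.
check_sequential_args(["warmup", "main", "fine_tune"], [5, 80])  # passes silently
```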

Notes

Scheduler Switching Logic:

Given milestones [m1, m2, …, mn] and schedulers [s0, s1, …, sn]:

  • Epochs [0, m1): uses scheduler s0

  • Epochs [m1, m2): uses scheduler s1

  • In general, epochs [mi, mi+1): uses scheduler si

  • Epochs [mn, ∞): uses scheduler sn

JIT Compatibility:

This implementation is JIT-compatible through the use of JAX operations for scheduler selection. The scheduler index is computed using jnp.searchsorted for efficient milestone-based switching.
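The lookup described above can be illustrated in plain Python. `bisect.bisect_right` is the standard-library counterpart of a right-sided sorted search such as `jnp.searchsorted(milestones, epoch, side="right")`. That the library searches from the right is an assumption inferred from the half-open intervals listed under Scheduler Switching Logic, and `active_scheduler_index` is a hypothetical name used only for this sketch.

```python
from bisect import bisect_right
from typing import Sequence


def active_scheduler_index(epoch: int, milestones: Sequence[int]) -> int:
    """Index of the scheduler active at `epoch` for half-open milestone intervals."""
    # bisect_right counts the milestones that are <= epoch, so an epoch equal
    # to a milestone already belongs to the next scheduler.
    return bisect_right(milestones, epoch)


milestones = [5, 80]
for epoch in (0, 4, 5, 79, 80, 200):
    print(epoch, active_scheduler_index(epoch, milestones))
```

Because the result is a single integer computed from sorted data, the same lookup can be expressed with array operations and no Python branching, which is what makes it suitable for tracing under JIT.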

Important Considerations:

  1. Initialize each scheduler with a base_lr that matches the learning rate you intend at its transition point; otherwise the learning rate may jump at the milestone.

  2. The last_epoch parameter of individual schedulers is managed internally.

  3. Saving and loading state preserves the state of every scheduler in the chain.
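To make the first consideration concrete, the sketch below checks that a linear warmup ends at the learning rate where an exponential decay begins. The two formulas are the conventional ones for these schedules and have not been verified against the braintools implementations. The sketch also assumes the second scheduler counts epochs from its milestone; this page does not state whether sub-schedulers see the global epoch or a phase-local one.

```python
import math


def linear_warmup(epoch, base_lr, start_factor, end_factor, total_iters):
    # Assumed formula: the factor moves linearly from start_factor to
    # end_factor over total_iters epochs, then stays at end_factor.
    progress = min(epoch, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


def exponential_decay(epoch, base_lr, gamma):
    # Assumed formula: base_lr * gamma ** epoch.
    return base_lr * gamma ** epoch


milestone = 5
lr_end_of_warmup = linear_warmup(milestone, 0.1, 0.01, 1.0, 5)
# Assumption: the decay phase starts counting at 0 when it takes over.
lr_start_of_decay = exponential_decay(0, 0.1, 0.95)

print(lr_end_of_warmup, lr_start_of_decay)
print(math.isclose(lr_end_of_warmup, lr_start_of_decay))
```

If the two values differ, changing the base_lr of the later scheduler is usually the simplest way to remove the jump.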

Examples

Basic usage with warmup and decay:

>>> import braintools
>>> import brainstate
>>>
>>> # Warmup for 5 epochs, then exponential decay
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.1,
...     start_factor=0.01,
...     end_factor=1.0,
...     total_iters=5
... )
>>> decay = braintools.optim.ExponentialLR(
...     base_lr=0.1,
...     gamma=0.95
... )
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, decay],
...     milestones=[5]
... )
>>>
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     train(...)
...     scheduler.step(epoch)

Three-phase training (warmup → training → fine-tuning):

>>> # Phase 1: Warmup (epochs 0-4)
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.001,
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>>
>>> # Phase 2: Main training (epochs 5-79)
>>> main_training = braintools.optim.CosineAnnealingLR(
...     base_lr=0.001,
...     T_max=75,
...     eta_min=0.0001
... )
>>>
>>> # Phase 3: Fine-tuning (epochs 80+)
>>> fine_tuning = braintools.optim.ConstantLR(
...     base_lr=0.0001,
...     factor=1.0
... )
>>>
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, main_training, fine_tuning],
...     milestones=[5, 80]
... )
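The phase boundaries in this example line up numerically, which can be checked with the standard cosine annealing formula. This is a sketch under two assumptions: that CosineAnnealingLR follows the usual formula shown in the comment, and that the phase counts its own steps from 0 to T_max. Neither is confirmed on this page.

```python
import math


def cosine_annealing(t, base_lr, T_max, eta_min):
    # Standard form (assumed, not verified against braintools):
    # eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2


# Phase 2 spans 80 - 5 = 75 epochs, matching T_max=75.
lr_phase2_start = cosine_annealing(0, 0.001, 75, 0.0001)
lr_phase2_end = cosine_annealing(75, 0.001, 75, 0.0001)

print(math.isclose(lr_phase2_start, 0.001))   # continues from the warmup's final rate
print(math.isclose(lr_phase2_end, 0.0001))    # meets the fine-tuning base_lr
```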

Complex schedule for transformer training:

>>> # Milestones here count optimizer steps, so step the scheduler once per step
>>> # 1. Linear warmup
>>> warmup = braintools.optim.LinearLR(
...     base_lr=0.0005,
...     start_factor=0.0,
...     end_factor=1.0,
...     total_iters=4000  # 4000 steps
... )
>>>
>>> # 2. Constant learning rate
>>> constant = braintools.optim.ConstantLR(
...     base_lr=0.0005,
...     factor=1.0
... )
>>>
>>> # 3. Cosine decay to near zero
>>> cosine_decay = braintools.optim.CosineAnnealingLR(
...     base_lr=0.0005,
...     T_max=20000,
...     eta_min=1e-6
... )
>>>
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup, constant, cosine_decay],
...     milestones=[4000, 10000]
... )

State persistence across training sessions:

>>> # Save scheduler state
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[scheduler1, scheduler2],
...     milestones=[50]
... )
>>> # ... train for some epochs ...
>>> checkpoint = {
...     'epoch': epoch,
...     'scheduler': scheduler.state_dict(),
...     'optimizer': optimizer.state_dict(),
... }
>>> # save() is a placeholder for your own serialization routine
>>> save(checkpoint, 'checkpoint.pkl')
>>>
>>> # Resume training
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[scheduler1, scheduler2],
...     milestones=[50]
... )
>>> scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue training from saved epoch

Using with different optimizers:

>>> # Works with any optimizer
>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[warmup_sched, main_sched],
...     milestones=[10]
... )
>>>
>>> # With SGD
>>> sgd_opt = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>>
>>> # With Adam
>>> adam_opt = braintools.optim.Adam(lr=scheduler)
>>>
>>> # With LAMB for large batch training
>>> lamb_opt = braintools.optim.LAMB(lr=scheduler)

Monitoring scheduler transitions:

>>> scheduler = braintools.optim.SequentialLR(
...     schedulers=[sched1, sched2, sched3],
...     milestones=[10, 20]
... )
>>>
>>> for epoch in range(30):
...     scheduler.step(epoch)
...     current_lr = scheduler.get_lr()
...     active_scheduler = scheduler.current_scheduler_idx
...     print(f"Epoch {epoch}: LR={current_lr[0]:.6f}, "
...           f"Active scheduler: {active_scheduler}")
...
...     # Detect transitions
...     if epoch in scheduler.milestones:
...         print(f"  -> Switching to scheduler {active_scheduler}")

See also

  • ChainedScheduler – Applies multiple schedulers simultaneously

  • LinearLR – Linear learning rate schedule (good for warmup)

  • CosineAnnealingLR – Cosine annealing schedule

  • ExponentialLR – Exponential decay schedule

  • StepLR – Step-wise decay schedule

get_lr()

Get learning rate from the current scheduler.

load_state_dict(state_dict)

Load scheduler state from dictionary.

state_dict()

Return scheduler state as dictionary.

step(epoch=None)

Update learning rate.