MultiStepLR
- class braintools.optim.MultiStepLR(base_lr=0.001, milestones=(30, 60, 90), gamma=0.1, last_epoch=0)
Multi-step learning rate scheduler: decays the learning rate at specific milestone epochs.
MultiStepLR reduces the learning rate by a factor of gamma at each epoch specified in the milestones list. This provides more flexible control than StepLR, allowing you to schedule learning rate drops at arbitrary points during training.
- Parameters:
base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
milestones (Sequence[int]) – List of epoch indices at which to decay the learning rate. Must be increasing. Default: (30, 60, 90).
gamma (float) – Multiplicative factor of learning rate decay. Must be in range (0, 1]. Default: 0.1.
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: -1 (starts from beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \eta_0 \cdot \gamma^{|\{m \in \text{milestones} : m \leq t\}|}\]
where \(\eta_0\) is the initial learning rate (base_lr), and \(|\{m \in \text{milestones} : m \leq t\}|\) counts how many milestones have been reached by epoch \(t\).
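The formula can be evaluated directly in plain Python. The sketch below is only an illustration of the arithmetic; `multistep_lr` is a hypothetical helper, not part of braintools, and it assumes the milestones are sorted in increasing order.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Closed-form learning rate at `epoch`.

    Hypothetical helper for illustration; not the braintools implementation.
    """
    # bisect_right returns how many milestones m satisfy m <= epoch,
    # which is the exponent in the formula (milestones must be sorted).
    reached = bisect_right(list(milestones), epoch)
    return base_lr * gamma ** reached


# Sample the schedule on both sides of each milestone.
for t in (0, 29, 30, 79, 80):
    print(t, multistep_lr(0.1, [30, 80], 0.1, t))
```

Because the rate depends only on the epoch index, any epoch can be queried without stepping through the earlier ones.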
Key characteristics:
Provides precise control over when learning rate changes occur
Ideal when you know specific epochs where model learning plateaus
Commonly used in research papers with fixed training schedules
Each milestone multiplies the current lr by gamma
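The last point can be read as a recurrence: start from base_lr and multiply the running rate by gamma whenever a milestone epoch is reached. The sketch below checks, under that reading, that the recurrence agrees with the closed-form expression from the Notes. Both functions are hypothetical helpers written for this illustration, not braintools code.

```python
import math


def iterative_schedule(base_lr, milestones, gamma, num_epochs):
    """Build the schedule by scaling the running lr at each milestone."""
    hit = set(milestones)
    lr, schedule = base_lr, []
    for epoch in range(num_epochs):
        if epoch in hit:
            lr *= gamma  # one multiplicative drop per milestone
        schedule.append(lr)
    return schedule


def closed_form(base_lr, milestones, gamma, epoch):
    """base_lr * gamma ** (number of milestones m with m <= epoch)."""
    return base_lr * gamma ** sum(m <= epoch for m in milestones)


milestones = [5, 10, 15, 20, 25]
schedule = iterative_schedule(0.1, milestones, 0.5, 30)
agree = all(
    math.isclose(lr, closed_form(0.1, milestones, 0.5, t))
    for t, lr in enumerate(schedule)
)
print(agree)
```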
Common milestone patterns:
ImageNet (90 epochs): milestones=[30, 60], gamma=0.1
CIFAR (200 epochs): milestones=[60, 120, 160], gamma=0.2
Fine-tuning: milestones=[10, 20], gamma=0.5
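To see what these patterns mean in practice, the sketch below expands each one into its constant-rate stretches. `lr_segments` is a hypothetical helper, not part of braintools, and it assumes strictly increasing, positive milestones. The base_lr of 0.1 for ImageNet and CIFAR matches the examples on this page; the base_lr of 1e-3 and the 30-epoch length for fine-tuning are assumptions made only for this illustration.

```python
def lr_segments(base_lr, milestones, gamma, total_epochs):
    """Return (first_epoch, last_epoch, lr) for each constant-lr stretch.

    Hypothetical helper; assumes strictly increasing milestones > 0.
    """
    # Segment boundaries: start of training, each milestone that falls
    # inside the run, and the end of training.
    bounds = [0] + [m for m in milestones if m < total_epochs] + [total_epochs]
    return [
        (bounds[k], bounds[k + 1] - 1, base_lr * gamma ** k)
        for k in range(len(bounds) - 1)
    ]


patterns = {
    "ImageNet": (0.1, [30, 60], 0.1, 90),
    "CIFAR": (0.1, [60, 120, 160], 0.2, 200),
    "Fine-tuning": (1e-3, [10, 20], 0.5, 30),  # base_lr and length assumed
}
for name, args in patterns.items():
    for first, last, lr in lr_segments(*args):
        print(f"{name}: epochs {first}-{last}: lr = {lr:.6g}")
```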
Examples
Basic usage with predefined milestones:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 80],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # lr schedule:
>>> # epochs 0-29: lr = 0.1
>>> # epochs 30-79: lr = 0.01 (after 1st milestone)
>>> # epochs 80+: lr = 0.001 (after 2nd milestone)
Using with Adam for fine-tuning:
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.001,
...     milestones=[10, 20, 30],
...     gamma=0.5
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(40):
...     # Training code
...     scheduler.step()
...     # lr: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20) -> 0.000125 (epoch 30)
ImageNet-style training schedule:
>>> # Standard ImageNet schedule: 90 epochs with drops at 30 and 60
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     # grads: gradients from your training step (not shown here)
...     optimizer.step(grads)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")
CIFAR training schedule:
>>> # CIFAR-10/100 schedule: 200 epochs
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[60, 120, 160],
...     gamma=0.2
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=5e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     # grads: gradients from your training step (not shown here)
...     optimizer.step(grads)
...     scheduler.step()
Custom aggressive decay schedule:
>>> # Frequent drops for quick convergence
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[5, 10, 15, 20, 25],
...     gamma=0.5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # lr rapidly decreases at each milestone
Resuming training with state dict:
>>> # Save training state
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 90],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> checkpoint = {'scheduler': scheduler.state_dict(), 'epoch': 50}
>>>
>>> # Resume later
>>> new_scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 90],
...     gamma=0.1
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continues from epoch 50 with correct lr
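Why resuming restores the correct rate can be seen with a small stand-in: since the rate is a function of the epoch counter alone, saving and reloading that counter is enough. The class below is a hypothetical toy written for this illustration. It is not the braintools implementation, and the contents of its state dict (a single `epoch` entry) and its starting count of 0 are assumptions.

```python
from bisect import bisect_right


class TinyMultiStep:
    """Stand-in scheduler for illustration only; NOT braintools code."""

    def __init__(self, base_lr, milestones, gamma):
        self.base_lr = base_lr
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.epoch = 0  # assumption: counting starts at epoch 0

    def step(self):
        self.epoch += 1

    @property
    def lr(self):
        # Count milestones m with m <= current epoch, then apply the formula.
        reached = bisect_right(self.milestones, self.epoch)
        return self.base_lr * self.gamma ** reached

    def state_dict(self):
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]


sched = TinyMultiStep(0.1, [30, 60, 90], 0.1)
for _ in range(50):
    sched.step()
saved = sched.state_dict()

# A fresh scheduler with the same hyperparameters picks up where we left off.
resumed = TinyMultiStep(0.1, [30, 60, 90], 0.1)
resumed.load_state_dict(saved)
print(resumed.epoch, resumed.lr)
```

In this toy, one milestone (30) has been passed by epoch 50, so the resumed rate is base_lr multiplied by gamma once.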
See also
StepLR: Decay learning rate at regular intervals
ExponentialLR: Exponential decay of learning rate
SequentialLR: Switch between different schedulers at milestones