MultiStepLR

class braintools.optim.MultiStepLR(base_lr=0.001, milestones=(30, 60, 90), gamma=0.1, last_epoch=0)

Multi-step learning rate scheduler: decays the learning rate at specified milestone epochs.

MultiStepLR reduces the learning rate by a factor of gamma at each epoch specified in the milestones list. This provides more flexible control than StepLR, allowing you to schedule learning rate drops at arbitrary points during training.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • milestones (Sequence[int]) – List of epoch indices at which to decay the learning rate. Must be increasing. Default: (30, 60, 90).

  • gamma (float) – Multiplicative factor of learning rate decay. Must be in range (0, 1]. Default: 0.1.

  • last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0, per the class signature (starts from the beginning).
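The constraints on milestones and gamma can be illustrated with a small stand-alone check. This is only a sketch: check_multistep_args is a hypothetical helper, not part of braintools, and it assumes "increasing" means strictly increasing. Whether and how the library itself validates these arguments is not stated here.

```python
def check_multistep_args(milestones, gamma):
    """Hypothetical validator mirroring the documented constraints."""
    milestones = list(milestones)
    # Assumption: "increasing" is read as strictly increasing (no repeats).
    if any(b <= a for a, b in zip(milestones, milestones[1:])):
        raise ValueError("milestones must be strictly increasing")
    # gamma must lie in the half-open interval (0, 1].
    if not (0.0 < gamma <= 1.0):
        raise ValueError("gamma must be in (0, 1]")
    return milestones, gamma


print(check_multistep_args((30, 60, 90), 0.1))
```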

Notes

The learning rate at epoch \(t\) is computed as:

\[\eta_t = \eta_0 \cdot \gamma^{|\{m \in \text{milestones} : m \leq t\}|}\]

where \(\eta_0\) is the initial learning rate (base_lr), and \(|\{m \in \text{milestones} : m \leq t\}|\) counts how many milestones have been reached by epoch \(t\).
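The formula can be evaluated directly in plain Python. The sketch below is not the library's implementation; multistep_lr is a hypothetical name, and it uses bisect_right to count the milestones that satisfy m <= t.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Closed form: base_lr * gamma ** (number of milestones <= epoch)."""
    # bisect_right on a sorted list returns how many entries are <= epoch.
    reached = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** reached


# Evaluate just before and at each milestone of milestones=[30, 80].
for epoch in (0, 29, 30, 79, 80):
    print(epoch, multistep_lr(0.1, [30, 80], 0.1, epoch))
```

Because the count uses m <= t, the decayed rate already applies at the milestone epoch itself (epoch 30 here), not one epoch later.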

Key characteristics:

  • Provides precise control over when learning rate changes occur

  • Ideal when you know specific epochs where model learning plateaus

  • Commonly used in research papers with fixed training schedules

  • Each milestone multiplies the current learning rate by gamma, so the decay compounds across milestones

Common milestone patterns:

  • ImageNet (90 epochs): milestones=[30, 60], gamma=0.1

  • CIFAR (200 epochs): milestones=[60, 120, 160], gamma=0.2

  • Fine-tuning: milestones=[10, 20], gamma=0.5
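To compare these patterns, it can help to look at the total decay once every milestone has passed, which is gamma raised to the number of milestones. The sketch below computes that factor relative to base_lr; total_decay is a hypothetical helper, not a braintools function.

```python
# (milestones, gamma) pairs taken from the patterns listed above.
patterns = {
    "ImageNet": ([30, 60], 0.1),
    "CIFAR": ([60, 120, 160], 0.2),
    "Fine-tuning": ([10, 20], 0.5),
}


def total_decay(milestones, gamma):
    """Overall multiplier applied to base_lr after the last milestone."""
    return gamma ** len(milestones)


for name, (milestones, gamma) in patterns.items():
    factor = total_decay(milestones, gamma)
    print(f"{name}: final lr = base_lr * {factor:.3g}")
```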

Examples

Basic usage with predefined milestones:

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 80],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # lr schedule:
>>> # epochs 0-29:  lr = 0.1
>>> # epochs 30-79: lr = 0.01  (after 1st milestone)
>>> # epochs 80+:   lr = 0.001 (after 2nd milestone)

Using with Adam for fine-tuning:

>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.001,
...     milestones=[10, 20, 30],
...     gamma=0.5
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(40):
...     # Training code
...     scheduler.step()
>>> # lr: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20) -> 0.000125 (epoch 30)

ImageNet-style training schedule:

>>> # Standard ImageNet schedule: 90 epochs with drops at 30 and 60
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     # grads: gradients from your training step (computation not shown)
...     optimizer.step(grads)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")

CIFAR training schedule:

>>> # CIFAR-10/100 schedule: 200 epochs
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[60, 120, 160],
...     gamma=0.2
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=5e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     # grads: gradients from your training step (computation not shown)
...     optimizer.step(grads)
...     scheduler.step()

Custom aggressive decay schedule:

>>> # Frequent drops for quick convergence
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[5, 10, 15, 20, 25],
...     gamma=0.5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # lr rapidly decreases at each milestone
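How quickly this schedule decays can be worked out from the formula in the Notes, without running the scheduler. The following plain-Python sketch lists the learning rate that takes effect at each milestone; it does not call braintools.

```python
base_lr, gamma = 0.1, 0.5
milestones = [5, 10, 15, 20, 25]

# After the k-th milestone the rate is base_lr * gamma ** k.
schedule = [(m, base_lr * gamma ** k) for k, m in enumerate(milestones, start=1)]

for epoch, lr in schedule:
    print(f"from epoch {epoch}: lr = {lr:.6f}")
```

With five halvings, the final rate is 1/32 of base_lr.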

Resuming training with state dict:

>>> # Save training state
>>> scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 90],
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> checkpoint = {'scheduler': scheduler.state_dict(), 'epoch': 50}
>>>
>>> # Resume later
>>> new_scheduler = braintools.optim.MultiStepLR(
...     base_lr=0.1,
...     milestones=[30, 60, 90],
...     gamma=0.1
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continues from epoch 50 with correct lr
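Resuming works because the learning rate depends only on the epoch counter and the fixed hyperparameters. The sketch below illustrates that idea with a toy class; TinyMultiStep and the contents of its state dict are assumptions made for illustration and do not reflect what braintools actually stores.

```python
class TinyMultiStep:
    """Illustrative stand-in, not the braintools implementation."""

    def __init__(self, base_lr, milestones, gamma):
        self.base_lr = base_lr
        self.milestones = sorted(milestones)
        self.gamma = gamma
        self.epoch = 0

    def step(self):
        # Advance the epoch counter by one.
        self.epoch += 1

    def lr(self):
        # Count milestones already reached, then apply the closed form.
        reached = sum(m <= self.epoch for m in self.milestones)
        return self.base_lr * self.gamma ** reached

    def state_dict(self):
        # Assumption: the epoch counter is the only mutable state.
        return {"epoch": self.epoch}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]


old = TinyMultiStep(0.1, [30, 60, 90], 0.1)
for _ in range(50):
    old.step()
checkpoint = old.state_dict()

new = TinyMultiStep(0.1, [30, 60, 90], 0.1)
new.load_state_dict(checkpoint)
print(new.epoch, new.lr() == old.lr())  # prints: 50 True
```

Note that the new scheduler must be constructed with the same base_lr, milestones, and gamma, since in this sketch they are not part of the saved state.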

See also

StepLR

Decay learning rate at regular intervals

ExponentialLR

Exponential decay of learning rate

SequentialLR

Switch between different schedulers at milestones

get_lr()

Calculate learning rate.