StepLR

class braintools.optim.StepLR(base_lr=0.001, step_size=30, gamma=0.1, last_epoch=0)

Step learning rate scheduler: decays the learning rate by a factor of gamma every step_size epochs.

StepLR multiplies the learning rate by gamma at regular intervals (every step_size epochs), creating a staircase decay pattern. This is one of the most commonly used learning rate schedules for training deep neural networks.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • step_size (int) – Period of learning rate decay in epochs. The learning rate will be multiplied by gamma every step_size epochs. Default: 30.

  • gamma (float) – Multiplicative factor of learning rate decay. Must be in range (0, 1]. Default: 0.1.

  • last_epoch (int) – The index of the last completed epoch. Used for resuming training. Default: 0 (starts from the beginning).

Notes

The learning rate at epoch \(t\) is computed as:

\[\eta_t = \eta_0 \cdot \gamma^{\lfloor t / \text{step_size} \rfloor}\]

where \(\eta_0\) is the initial learning rate (base_lr), and \(\lfloor \cdot \rfloor\) denotes the floor function.
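
The closed form can be checked with a minimal pure-Python sketch. It does not use braintools; the helper name `step_lr` is hypothetical and only mirrors the formula above.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Staircase decay: base_lr * gamma ** floor(epoch / step_size)."""
    # Integer floor division implements the floor term for non-negative epochs.
    return base_lr * gamma ** (epoch // step_size)


# Evaluate the schedule on both sides of each decay boundary.
for t in (0, 29, 30, 59, 60, 89):
    print(f"t = {t:2d}: lr = {step_lr(0.1, 30, 0.1, t):.6f}")
```

The rate stays constant within each block of step_size epochs and drops by a factor of gamma when \(t\) reaches a multiple of step_size.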

Key characteristics:

  • Creates discrete “steps” in the learning rate schedule

  • Widely used for training image classification models

  • Simple to tune with only two hyperparameters

  • Works well when combined with momentum-based optimizers

Common step_size values:

  • ImageNet training: step_size=30, total_epochs=90 (decay at epochs 30, 60)

  • CIFAR training: step_size=50, total_epochs=150 (decay at epochs 50, 100)
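
The decay epochs listed above follow directly from step_size and the total epoch count. The sketch below derives them in plain Python; `decay_epochs` is a hypothetical helper, not part of braintools.

```python
from typing import List


def decay_epochs(step_size: int, total_epochs: int) -> List[int]:
    """Epochs before total_epochs at which the floor term in the formula increases."""
    return list(range(step_size, total_epochs, step_size))


imagenet = decay_epochs(step_size=30, total_epochs=90)
cifar = decay_epochs(step_size=50, total_epochs=150)
print(imagenet, cifar)
```

With total_epochs equal to three times step_size, the schedule has exactly two decays, so training ends at base_lr * gamma ** 2.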

Examples

Basic usage with SGD:

>>> import braintools
>>> import brainstate
>>>
>>> # Create model and scheduler
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop
>>> for epoch in range(90):
...     # ... training code ...
...     scheduler.step()
...     if epoch in [0, 29, 30, 59, 60, 89]:
...         print(f"Epoch {epoch}: lr = {optimizer.current_lr:.6f}")
Epoch 0: lr = 0.100000
Epoch 29: lr = 0.100000
Epoch 30: lr = 0.010000  # First decay
Epoch 59: lr = 0.010000
Epoch 60: lr = 0.001000  # Second decay
Epoch 89: lr = 0.001000

Using with Adam optimizer:

>>> scheduler = braintools.optim.StepLR(base_lr=0.001, step_size=10, gamma=0.5)
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(25):
...     # Training step
...     scheduler.step()
>>> # lr decays: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20)

Custom decay schedule:

>>> # Aggressive decay every 5 epochs
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=5, gamma=0.5)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # After 15 epochs: lr = 0.1 * 0.5^3 = 0.0125

Saving and loading scheduler state:

>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'model': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Later, resume training
>>> new_scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from epoch 50
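
Why restoring the state is enough to resume can be seen in a toy stand-in. `ToyStepLR` below is a hypothetical class written for illustration, not the braintools implementation, and the single `last_epoch` key in its state dict is an assumption; the real state dict may hold other fields.

```python
class ToyStepLR:
    """Illustrative stand-in (not braintools.optim.StepLR) tracking an epoch counter."""

    def __init__(self, base_lr, step_size, gamma, last_epoch=0):
        self.base_lr = base_lr
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = last_epoch

    def step(self):
        # Advance the epoch counter by one.
        self.last_epoch += 1

    def get_lr(self):
        # Closed-form staircase decay from the Notes section.
        return self.base_lr * self.gamma ** (self.last_epoch // self.step_size)

    def state_dict(self):
        # Assumed minimal state: the epoch counter alone determines the rate.
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


old = ToyStepLR(base_lr=0.1, step_size=30, gamma=0.1)
for _ in range(50):
    old.step()

# A freshly constructed scheduler picks up where the old one stopped.
new = ToyStepLR(base_lr=0.1, step_size=30, gamma=0.1)
new.load_state_dict(old.state_dict())
print(new.last_epoch, new.get_lr() == old.get_lr())
```

Because the rate is a pure function of the epoch counter and the constructor arguments, the resumed scheduler must be built with the same base_lr, step_size, and gamma as the original.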

Multiple parameter groups:

>>> # Different learning rates for different layers
>>> scheduler = braintools.optim.StepLR(
...     base_lr=[0.1, 0.01],  # Different base lr for each group
...     step_size=30,
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> # Both groups decay by gamma every step_size epochs
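
The per-group behaviour can be sketched in plain Python, assuming each group follows the same closed form with its own base rate. `step_lr_groups` is a hypothetical helper, not a braintools function.

```python
from typing import List


def step_lr_groups(base_lrs: List[float], step_size: int, gamma: float, epoch: int) -> List[float]:
    """Apply one shared decay factor to every group's base learning rate."""
    factor = gamma ** (epoch // step_size)
    return [lr * factor for lr in base_lrs]


before = step_lr_groups([0.1, 0.01], step_size=30, gamma=0.1, epoch=29)
after = step_lr_groups([0.1, 0.01], step_size=30, gamma=0.1, epoch=30)
print(before, after)
```

Since every group is scaled by the same factor, the ratio between the group rates (10:1 here) stays fixed throughout training.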

Complete training example:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> def train_epoch(model, optimizer, data):
...     def loss_fn(params):
...         # Compute loss
...         return loss
...     grads = jax.grad(loss_fn)(model.states(brainstate.ParamState))
...     optimizer.update(grads)
>>>
>>> for epoch in range(90):
...     train_epoch(model, optimizer, train_data)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")

See also

MultiStepLR

Decay learning rate at specific milestone epochs

ExponentialLR

Exponential decay of learning rate

CosineAnnealingLR

Cosine annealing schedule

get_lr()

Calculate learning rate.