StepLR
- class braintools.optim.StepLR(base_lr=0.001, step_size=30, gamma=0.1, last_epoch=0)
Step learning rate scheduler: decays the learning rate by gamma every step_size epochs.
StepLR holds the learning rate constant within each interval of step_size epochs and multiplies it by gamma at each interval boundary, creating a staircase decay pattern. This is one of the most commonly used learning rate schedules for training deep neural networks.
- Parameters:
  - base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
  - step_size (int) – Period of learning rate decay in epochs. The learning rate is multiplied by gamma every step_size epochs. Default: 30.
  - gamma (float) – Multiplicative factor of learning rate decay. Must be in the range (0, 1]. Default: 0.1.
  - last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0 (starts from the beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \eta_0 \cdot \gamma^{\lfloor t / \text{step_size} \rfloor}\]
where \(\eta_0\) is the initial learning rate (base_lr), and \(\lfloor \cdot \rfloor\) denotes the floor function.
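The formula can be checked with a few lines of plain Python. This is a minimal sketch of the closed-form expression only, not the braintools implementation; the function name `step_lr` is hypothetical and introduced here for illustration.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Closed-form StepLR value: base_lr * gamma ** floor(epoch / step_size).

    Hypothetical helper for illustration; not part of braintools.
    """
    # Integer floor division implements the floor in the formula.
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Evaluate the formula just before and at each interval boundary.
    for epoch in (0, 29, 30, 59, 60, 89):
        print(f"epoch {epoch:2d}: lr = {step_lr(0.1, 30, 0.1, epoch):.6f}")
```

With base_lr=0.1, step_size=30, and gamma=0.1, the value stays at 0.1 through epoch 29 and drops by a factor of ten at epochs 30 and 60.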
Key characteristics:
- Creates discrete “steps” in the learning rate schedule
- Widely used for training image classification models
- Simple to tune with only two hyperparameters
- Works well when combined with momentum-based optimizers
Common step_size values:
- ImageNet training: step_size=30, total_epochs=90 (decay at epochs 30, 60)
- CIFAR training: step_size=50, total_epochs=150 (decay at epochs 50, 100)
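The decay epochs quoted for these settings follow directly from the formula: a decay happens at every positive multiple of step_size that falls before the end of training. A small sketch, assuming epochs are counted from 0 and using a hypothetical helper `decay_epochs` that is not part of braintools:

```python
def decay_epochs(step_size: int, total_epochs: int) -> list[int]:
    """Epochs in [1, total_epochs) at which floor(epoch / step_size) increases.

    Hypothetical helper for illustration; not part of braintools.
    """
    return list(range(step_size, total_epochs, step_size))


print(decay_epochs(30, 90))   # [30, 60]
print(decay_epochs(50, 150))  # [50, 100]
```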
Examples
Basic usage with SGD:
>>> import braintools
>>> import brainstate
>>>
>>> # Create model and scheduler
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop
>>> for epoch in range(90):
...     # ... training code ...
...     scheduler.step()
...     if epoch in [0, 29, 30, 59, 60, 89]:
...         print(f"Epoch {epoch}: lr = {optimizer.current_lr:.6f}")
Epoch 0: lr = 0.100000
Epoch 29: lr = 0.100000
Epoch 30: lr = 0.010000  # First decay
Epoch 59: lr = 0.010000
Epoch 60: lr = 0.001000  # Second decay
Epoch 89: lr = 0.001000
Using with Adam optimizer:
>>> scheduler = braintools.optim.StepLR(base_lr=0.001, step_size=10, gamma=0.5)
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(25):
...     # Training step
...     scheduler.step()
...     # lr decays: 0.001 -> 0.0005 (epoch 10) -> 0.00025 (epoch 20)
Custom decay schedule:
>>> # Aggressive decay every 5 epochs
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=5, gamma=0.5)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # After 15 epochs: lr = 0.1 * 0.5^3 = 0.0125
Saving and loading scheduler state:
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'model': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Later, resume training
>>> new_scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> # Continue from epoch 50
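To make the save and resume pattern concrete without depending on braintools internals, the sketch below uses a toy stand-in class. Everything in it is an assumption for illustration: `ToyStepLR`, its `current_lr()` method, and the contents of its state dictionary are hypothetical and may not match what `braintools.optim.StepLR.state_dict()` actually stores. The point is only that a restored scheduler should report the same learning rate as the one that was saved.

```python
class ToyStepLR:
    """Stand-in scheduler for illustration only (not braintools.optim.StepLR)."""

    def __init__(self, base_lr, step_size, gamma):
        self.base_lr = base_lr
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = 0  # number of completed step() calls

    def step(self):
        self.last_epoch += 1

    def current_lr(self):
        # Same closed form as the formula in Notes.
        return self.base_lr * self.gamma ** (self.last_epoch // self.step_size)

    def state_dict(self):
        # Assumption: the epoch counter is the only state that changes in this toy.
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


# Train for 50 epochs, then snapshot the scheduler state.
old = ToyStepLR(base_lr=0.1, step_size=30, gamma=0.1)
for _ in range(50):
    old.step()
saved = old.state_dict()

# Rebuild with the same hyperparameters and restore the counter.
new = ToyStepLR(base_lr=0.1, step_size=30, gamma=0.1)
new.load_state_dict(saved)

# The restored scheduler resumes at epoch 50 with the already-decayed rate.
assert new.last_epoch == 50
assert abs(new.current_lr() - old.current_lr()) < 1e-12
```

A similar equality check against the real scheduler after `load_state_dict` is a cheap way to catch a checkpoint that was restored with mismatched hyperparameters.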
Multiple parameter groups:
>>> # Different learning rates for different layers
>>> scheduler = braintools.optim.StepLR(
...     base_lr=[0.1, 0.01],  # Different base lr for each group
...     step_size=30,
...     gamma=0.1
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> # Both groups decay by gamma every step_size epochs
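As the last comment says, every group is scaled by the same factor, so the ratio between group learning rates never changes. A plain-Python sketch of that behavior, assuming the formula from Notes is applied independently to each base rate; `step_lr_groups` is a hypothetical helper, not a braintools function:

```python
def step_lr_groups(base_lrs, step_size, gamma, epoch):
    """Apply the StepLR formula to each base learning rate.

    Hypothetical helper for illustration; not part of braintools.
    """
    # One shared decay factor for all groups at a given epoch.
    factor = gamma ** (epoch // step_size)
    return [lr * factor for lr in base_lrs]


for epoch in (0, 30, 60):
    lrs = step_lr_groups([0.1, 0.01], step_size=30, gamma=0.1, epoch=epoch)
    print(epoch, [f"{lr:.6f}" for lr in lrs])
```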
Complete training example:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> def train_epoch(model, optimizer, data):
...     def loss_fn(params):
...         # Compute loss from params and data (placeholder: `loss` is not defined here)
...         return loss
...     grads = jax.grad(loss_fn)(model.states(brainstate.ParamState))
...     optimizer.update(grads)
>>>
>>> # `train_data` is a placeholder for your own dataset
>>> for epoch in range(90):
...     train_epoch(model, optimizer, train_data)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")
See also
- MultiStepLR: Decay learning rate at specific milestone epochs
- ExponentialLR: Exponential decay of learning rate
- CosineAnnealingLR: Cosine annealing schedule
References