PolynomialLR
- class braintools.optim.PolynomialLR(base_lr=0.001, total_iters=5, power=1.0, last_epoch=0)
Polynomial learning rate scheduler: decays the learning rate using a polynomial function.
PolynomialLR decreases the learning rate according to a polynomial decay schedule. The base learning rate is multiplied by the decay factor (1 - t/T)^power, where t is the current epoch and T is total_iters. This gives a smooth decay whose shape is controlled by the power parameter.
- Parameters:
base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
total_iters (int) – Number of epochs over which to decay the learning rate. After total_iters epochs, the learning rate is 0. Default: 5.
power (float) – The power of the polynomial. Controls the shape of the decay curve. Default: 1.0.
power=1.0: Linear decay
power>1.0: Faster initial decay, slower later
power<1.0: Slower initial decay, faster later
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0, matching the class signature (starts from the beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \eta_0 \cdot \left(1 - \frac{\min(t, T)}{T}\right)^p\]
where \(\eta_0\) is the initial learning rate (base_lr), \(T\) is total_iters, \(t\) is the current epoch, and \(p\) is the power parameter.
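The formula can be checked with a few lines of plain Python. This is a minimal sketch of the closed-form expression only, not the library's implementation; the function name `polynomial_lr` is hypothetical.

```python
def polynomial_lr(base_lr: float, t: int, total_iters: int, power: float = 1.0) -> float:
    """Closed-form polynomial decay: base_lr * (1 - min(t, T) / T) ** power."""
    if total_iters <= 0:
        raise ValueError("total_iters must be positive")
    # Clamp t so the rate stays at 0 once total_iters has been reached.
    clamped = min(t, total_iters)
    return base_lr * (1.0 - clamped / total_iters) ** power


# Linear decay (power=1.0) from 0.1 over 100 epochs, including one epoch past the end.
for t in (0, 50, 100, 150):
    print(t, polynomial_lr(0.1, t, 100, power=1.0))
```

The `min(t, T)` clamp is what keeps the rate at zero, rather than going negative or complex, if the scheduler is stepped beyond total_iters.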
Key characteristics:
Smooth polynomial decay to zero (or near-zero)
Decay shape controlled by power parameter
Learning rate reaches 0 at total_iters
Commonly used in semantic segmentation and detection tasks
Power parameter effects:
power=0.5: Square root decay (slow initial, fast final)
power=1.0: Linear decay (constant rate)
power=2.0: Quadratic decay (fast initial, slow final)
power=3.0: Cubic decay (very fast initial, very slow final)
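The effect of the exponent can be checked numerically from the decay factor (1 - t/T)^power. The sketch below tabulates the fraction of base_lr that remains at three checkpoints; `decay_factor` and `table` are hypothetical names used only for this illustration.

```python
def decay_factor(t: int, total_iters: int, power: float) -> float:
    """Multiplicative factor (1 - min(t, T) / T) ** power applied to base_lr."""
    return (1.0 - min(t, total_iters) / total_iters) ** power


total_iters = 100
checkpoints = (25, 50, 75)

# One row per power: the fraction of base_lr remaining at each checkpoint.
table = {
    power: [round(decay_factor(t, total_iters, power), 4) for t in checkpoints]
    for power in (0.5, 1.0, 2.0, 3.0)
}

for power, factors in table.items():
    print(f"power={power}: {factors}")
```

With these settings, about 87% of the base rate remains at epoch 25 for power=0.5, against about 56% for power=2.0, so larger powers drop faster early and flatten out near the end.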
When to use:
Training semantic segmentation models (DeepLab, FCN)
Object detection training (YOLO, RetinaNet)
When you want smooth decay to very low learning rates
Tasks that benefit from extended low-lr fine-tuning
Examples
Basic linear decay (power=1.0):
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Linear decay from 0.1 to 0 over 100 epochs
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.1,
...     total_iters=100,
...     power=1.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)  # grads: gradients computed for this step (not shown)
...     scheduler.step()
...     # lr decreases linearly: 0.1, 0.099, 0.098, ..., 0.001, 0
Quadratic decay (power=2.0):
>>> # Faster initial decay, slower later decay
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.1,
...     total_iters=100,
...     power=2.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)
...     scheduler.step()
...     # lr: epoch 25 ≈ 0.056, epoch 50 ≈ 0.025, epoch 75 ≈ 0.006
Square root decay (power=0.5):
>>> # Slower initial decay, faster later decay
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.01,
...     total_iters=50,
...     power=0.5
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     optimizer.step(grads)
...     scheduler.step()
Semantic segmentation training (DeepLab style):
>>> # Common setup for semantic segmentation
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.007,
...     total_iters=30000,  # Iterations, not epochs
...     power=0.9
... )
>>> optimizer = braintools.optim.SGD(
...     lr=scheduler,
...     momentum=0.9,
...     weight_decay=5e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for iteration in range(30000):
...     train_step(model, optimizer, batch)
...     scheduler.step()
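When the scheduler is stepped once per batch, as in the loop above, total_iters should count optimizer steps rather than epochs. A minimal sketch of that bookkeeping follows; `num_epochs`, `dataset_size`, and `batch_size` are hypothetical values chosen for illustration.

```python
import math

num_epochs = 30        # hypothetical training length in epochs
dataset_size = 10_000  # hypothetical number of training samples
batch_size = 32        # hypothetical batch size

# Steps per epoch, counting a final partial batch as one step.
steps_per_epoch = math.ceil(dataset_size / batch_size)

# Value to pass as total_iters when scheduler.step() runs once per batch.
total_iters = num_epochs * steps_per_epoch
print(steps_per_epoch, total_iters)
```

If total_iters were set to the epoch count instead, the learning rate would reach zero after only that many batches.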
Short training with steep decay:
>>> # Quick decay for fine-tuning
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.001,
...     total_iters=10,
...     power=1.0
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(10):
...     finetune_epoch(model, optimizer, finetune_loader)
...     scheduler.step()
With warmup:
>>> # Warmup followed by polynomial decay
>>> warmup = braintools.optim.LinearLR(
...     start_factor=0.1,
...     end_factor=1.0,
...     total_iters=5
... )
>>> poly_decay = braintools.optim.PolynomialLR(
...     base_lr=0.01,
...     total_iters=95,
...     power=0.9
... )
>>> scheduler = braintools.optim.ChainedScheduler([warmup, poly_decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)
...     scheduler.step()
State persistence:
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.1,
...     total_iters=100,
...     power=2.0
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'scheduler': scheduler.state_dict(),
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.1,
...     total_iters=100,
...     power=2.0
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
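Because the schedule is a closed-form function of the epoch counter, resuming amounts to restoring that counter alongside the same hyperparameters. The class below is a hypothetical, pure-Python stand-in written to illustrate the idea; its name, attributes, and state layout are assumptions and do not describe what braintools stores in state_dict().

```python
class ToyPolynomialLR:
    """Hypothetical stand-in for a polynomial scheduler (not the braintools class)."""

    def __init__(self, base_lr: float, total_iters: int, power: float = 1.0):
        self.base_lr = base_lr
        self.total_iters = total_iters
        self.power = power
        self.epoch = 0  # number of step() calls so far

    def step(self) -> None:
        self.epoch += 1

    def get_lr(self) -> float:
        t = min(self.epoch, self.total_iters)
        return self.base_lr * (1.0 - t / self.total_iters) ** self.power

    def state_dict(self) -> dict:
        # Assumed state layout: only the counter changes during training.
        return {"epoch": self.epoch}

    def load_state_dict(self, state: dict) -> None:
        self.epoch = state["epoch"]


# Train for 50 epochs, then rebuild the scheduler and restore its counter.
original = ToyPolynomialLR(base_lr=0.1, total_iters=100, power=2.0)
for _ in range(50):
    original.step()

resumed = ToyPolynomialLR(base_lr=0.1, total_iters=100, power=2.0)
resumed.load_state_dict(original.state_dict())
print(original.get_lr(), resumed.get_lr())
```

In this sketch the resumed scheduler must be constructed with the same base_lr, total_iters, and power as the original, since only the counter is saved.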
Comparison of power values:
>>> # Visualize different power values
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>>
>>> powers = [0.5, 1.0, 2.0, 3.0]
>>> total_iters = 100
>>> base_lr = 0.1
>>>
>>> for power in powers:
...     scheduler = braintools.optim.PolynomialLR(
...         base_lr=base_lr,
...         total_iters=total_iters,
...         power=power
...     )
...     lrs = []
...     for _ in range(total_iters):
...         lrs.append(scheduler.current_lrs.value[0])
...         scheduler.step()
...     plt.plot(lrs, label=f'power={power}')
>>>
>>> plt.xlabel('Epoch')
>>> plt.ylabel('Learning Rate')
>>> plt.legend()
>>> plt.title('Polynomial LR Decay with Different Powers')
>>> plt.show()
See also
LinearLR: Linear learning rate scaling (special case with power=1.0)
ExponentialLR: Exponential decay
CosineAnnealingLR: Cosine annealing schedule
References