ReduceLROnPlateau
- class braintools.optim.ReduceLROnPlateau(base_lr=0.001, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, last_epoch=0)
Reduce the learning rate when a monitored metric has stopped improving (adaptive LR based on validation metrics).
ReduceLROnPlateau monitors a validation metric (like loss or accuracy) and reduces the learning rate when the metric stops improving for a specified number of epochs (patience). This is useful when you don’t know in advance when to reduce the learning rate, letting the training dynamics determine the schedule.
- Parameters:
base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
mode (str) – Whether to minimize or maximize the monitored metric.
    'min': Reduce lr when the metric stops decreasing (e.g., for loss)
    'max': Reduce lr when the metric stops increasing (e.g., for accuracy)
    Default: 'min'.
factor (float) – Factor by which to reduce the learning rate: new_lr = lr * factor. Must be in the range (0, 1). Default: 0.1.
patience (int) – Number of epochs with no improvement after which the learning rate will be reduced. For example, if patience=5, the first 5 epochs with no improvement are tolerated, and the lr is reduced on the 6th epoch. Default: 10.
threshold (float) – Threshold for measuring improvement. Only changes greater than threshold count as an improvement. Default: 1e-4.
threshold_mode (str) – How to compute the threshold for improvement.
    'rel': dynamic threshold = best * (1 ± threshold)
    'abs': static threshold = best ± threshold
    Default: 'rel'.
cooldown (int) – Number of epochs to wait before resuming normal operation after the lr has been reduced. During cooldown, no further lr reductions occur. Default: 0.
min_lr (float | List[float]) – Minimum learning rate(s). The lr will not be reduced below this value. Default: 0.
eps (float) – Minimal decay applied to lr. If the difference between the new and old lr is smaller than eps, the update is ignored. Default: 1e-8.
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0.
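The interaction of factor, min_lr, and eps described above can be illustrated with a small pure-Python sketch. It is not the braintools implementation; the helper name reduce_lrs and its handling of per-group lists are assumptions made for illustration only.

```python
def reduce_lrs(lrs, factor, min_lr=0.0, eps=1e-8):
    """Apply one reduction to each parameter group's learning rate.

    Hypothetical helper (not part of braintools): it mirrors the documented
    rules new_lr = lr * factor, clamped at min_lr, skipped if the change < eps.
    """
    if not 0.0 < factor < 1.0:
        raise ValueError("factor must be in the range (0, 1)")
    # Broadcast a scalar min_lr to every parameter group.
    if isinstance(min_lr, (int, float)):
        floors = [float(min_lr)] * len(lrs)
    else:
        floors = list(min_lr)
    if len(floors) != len(lrs):
        raise ValueError("min_lr must have one entry per learning rate")

    new_lrs = []
    for lr, floor in zip(lrs, floors):
        candidate = max(lr * factor, floor)  # never go below the floor
        # Ignore updates that would change the lr by eps or less.
        new_lrs.append(candidate if lr - candidate > eps else lr)
    return new_lrs


# Two parameter groups with different floors.
print(reduce_lrs([0.1, 0.01], factor=0.5, min_lr=[0.001, 0.008]))
```

In this sketch the second group is clamped at its own floor, while the first is simply halved.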
Notes
The scheduler reduces the learning rate when the monitored metric plateaus:
\[\begin{split}\eta_{t+1} = \begin{cases} \max(\eta_t \cdot \text{factor}, \eta_{\min}) & \text{if plateau detected} \\ \eta_t & \text{otherwise} \end{cases}\end{split}\]
A plateau is detected when the metric fails to improve for more than patience consecutive epochs.
For mode='min', improvement is defined as:
\[\text{metric}_t < \text{best} \cdot (1 - \text{threshold}) \quad \text{(relative)}\]
or
\[\text{metric}_t < \text{best} - \text{threshold} \quad \text{(absolute)}\]
Key characteristics:
Adaptive schedule based on training progress
No need to pre-specify decay epochs
Ideal when optimal schedule is unknown
Commonly used for validation-based training
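The improvement test from the Notes can be written as a short pure-Python function. This is a sketch under stated assumptions rather than the library's code: the function name is_improvement is hypothetical, and the mode='max' branch simply mirrors the "best * (1 ± threshold)" and "best ± threshold" rules from the parameter descriptions.

```python
def is_improvement(metric, best, mode="min", threshold=1e-4, threshold_mode="rel"):
    """Return True if `metric` beats `best` by more than the threshold.

    Hypothetical helper for illustration; not a braintools function.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    if threshold_mode not in ("rel", "abs"):
        raise ValueError("threshold_mode must be 'rel' or 'abs'")

    if mode == "min":
        # Relative: best * (1 - threshold); absolute: best - threshold.
        target = best * (1.0 - threshold) if threshold_mode == "rel" else best - threshold
        return metric < target
    # mode == "max": relative: best * (1 + threshold); absolute: best + threshold.
    target = best * (1.0 + threshold) if threshold_mode == "rel" else best + threshold
    return metric > target


# A loss that drops from 1.0 to 0.99995 is within the default relative threshold.
print(is_improvement(0.99995, best=1.0))
# A drop to 0.9 clearly counts as an improvement.
print(is_improvement(0.9, best=1.0))
```

Note that a relative threshold scales with the magnitude of best, so for metrics near zero the absolute mode may be the easier one to reason about.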
Common configurations:
Conservative: patience=10, factor=0.5
Moderate: patience=5, factor=0.1
Aggressive: patience=3, factor=0.1
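To compare these presets, it can help to compute how quickly each one can shrink the learning rate. The sketch below assumes that every reduction needs patience + 1 consecutive non-improving epochs (as in the patience description) and that cooldown epochs are simply added between reductions; both helper names are hypothetical and the counting rule is an assumption, not a statement about braintools internals.

```python
PRESETS = {
    "conservative": {"patience": 10, "factor": 0.5},
    "moderate": {"patience": 5, "factor": 0.1},
    "aggressive": {"patience": 3, "factor": 0.1},
}


def stale_epochs_for_reductions(patience, n, cooldown=0):
    """Fewest consecutive non-improving epochs needed to trigger n reductions.

    Assumption: each reduction fires on the (patience + 1)-th stale epoch,
    and each cooldown period delays the next count by `cooldown` epochs.
    """
    return n * (patience + 1) + max(n - 1, 0) * cooldown


def lr_after_reductions(base_lr, factor, n, min_lr=0.0):
    """Learning rate after n reductions, clamped at min_lr."""
    return max(base_lr * factor ** n, min_lr)


for name, cfg in PRESETS.items():
    epochs = stale_epochs_for_reductions(cfg["patience"], n=3)
    lr = lr_after_reductions(0.1, cfg["factor"], n=3)
    print(f"{name}: 3 reductions need >= {epochs} stale epochs, lr = {lr:.6g}")
```

Under these assumptions the moderate and aggressive presets reach the same learning rate after the same number of reductions; they differ only in how soon each reduction can happen.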
When to use:
When you don’t know the optimal training schedule
For validation-driven training
When training dynamics are unpredictable
When you want to avoid hand-tuning the decay schedule
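The overall behaviour described in this section (patience, cooldown, min_lr, and eps working together) can be sketched as a small state machine in plain Python. This is an illustrative model only: the class name PlateauSketch and its attributes are hypothetical, the counter logic follows the common plateau-scheduler pattern rather than the braintools source, and it tracks a single learning rate instead of parameter groups.

```python
import math


class PlateauSketch:
    """Toy single-lr plateau scheduler (illustration, not braintools code)."""

    def __init__(self, base_lr=1e-3, mode="min", factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode="rel", cooldown=0,
                 min_lr=0.0, eps=1e-8):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        if threshold_mode not in ("rel", "abs"):
            raise ValueError("threshold_mode must be 'rel' or 'abs'")
        if not 0.0 < factor < 1.0:
            raise ValueError("factor must be in the range (0, 1)")
        self.lr = float(base_lr)
        self.mode, self.factor, self.patience = mode, factor, patience
        self.threshold, self.threshold_mode = threshold, threshold_mode
        self.cooldown, self.min_lr, self.eps = cooldown, min_lr, eps
        self.best = math.inf if mode == "min" else -math.inf
        self.num_bad_epochs = 0
        self.cooldown_counter = 0

    def _improved(self, metric):
        sign = -1.0 if self.mode == "min" else 1.0
        if self.threshold_mode == "rel":
            target = self.best * (1.0 + sign * self.threshold)
        else:
            target = self.best + sign * self.threshold
        return metric < target if self.mode == "min" else metric > target

    def step(self, metric):
        """Record one epoch's metric and return the (possibly reduced) lr."""
        if self._improved(metric):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.cooldown_counter > 0:
            # Stale epochs during cooldown do not count toward patience.
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        if self.num_bad_epochs > self.patience:
            new_lr = max(self.lr * self.factor, self.min_lr)
            if self.lr - new_lr > self.eps:
                self.lr = new_lr
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(base_lr=0.1, factor=0.5, patience=2)
lrs = [sched.step(m) for m in [1.0, 1.0, 1.0, 1.0]]
print(lrs)
```

In the usage lines, the first metric sets best, two stale epochs are tolerated because patience=2, and the third stale epoch halves the learning rate.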
Examples
Basic usage with validation loss:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5,
...     min_lr=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training
...     optimizer.step(grads)
...
...     # Validation
...     val_loss = validate(model, val_loader)
...
...     # Update learning rate based on validation loss
...     scheduler.step(val_loss)
...
...     print(f"Epoch {epoch}: lr={optimizer.current_lr:.6f}, val_loss={val_loss:.4f}")
With validation accuracy (maximize mode):
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='max',  # Maximize accuracy
...     factor=0.1,
...     patience=10,
...     threshold=0.01
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     optimizer.step(grads)
...     val_acc = evaluate_accuracy(model, val_loader)
...     scheduler.step(val_acc)
Conservative schedule for stable training:
>>> # Reduce lr by half when no improvement for 10 epochs
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=10,
...     threshold=1e-3,
...     cooldown=5  # Wait 5 epochs after reduction
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Aggressive schedule for quick adaptation:
>>> # Reduce lr by 90% when no improvement for 3 epochs
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.1,
...     patience=3,
...     threshold=1e-4,
...     min_lr=1e-6
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With absolute threshold mode:
>>> # Use absolute threshold for improvement
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5,
...     threshold=0.001,
...     threshold_mode='abs'  # Absolute improvement threshold
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Complete training loop with early stopping:
>>> import jax.numpy as jnp
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.5,
...     patience=10,
...     min_lr=1e-5
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> best_loss = float('inf')
>>> patience_counter = 0
>>> early_stop_patience = 20
>>>
>>> for epoch in range(200):
...     # Training
...     optimizer.step(grads)
...
...     # Validation
...     val_loss = validate(model, val_loader)
...
...     # Update learning rate
...     old_lr = optimizer.current_lr
...     scheduler.step(val_loss)
...     if optimizer.current_lr < old_lr:
...         print(f"Epoch {epoch}: Reduced LR to {optimizer.current_lr:.6f}")
...
...     # Early stopping
...     if val_loss < best_loss:
...         best_loss = val_loss
...         patience_counter = 0
...         # Save best model
...     else:
...         patience_counter += 1
...         if patience_counter >= early_stop_patience:
...             print(f"Early stopping at epoch {epoch}")
...             break
State persistence:
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     val_loss = train_and_validate(model, optimizer)
...     scheduler.step(val_loss)
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'model': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'scheduler': scheduler.state_dict(),
...     'best_metric': scheduler.best
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> new_scheduler.best = checkpoint['best_metric']
Multiple metrics monitoring:
>>> # Monitor different metrics for different purposes
>>> val_scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> optimizer = braintools.optim.Adam(lr=val_scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)
...     val_loss = validate(model, val_loader)
...
...     # Use validation loss for lr scheduling
...     val_scheduler.step(val_loss)
...
...     # Could also track other metrics separately
...     val_acc = evaluate_accuracy(model, val_loader)
...     print(f"Epoch {epoch}: lr={optimizer.current_lr:.6f}, "
...           f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
See also
StepLR: Fixed step-based learning rate decay
ExponentialLR: Exponential decay
CosineAnnealingLR: Cosine annealing schedule
OneCycleLR: One cycle learning rate policy
References