ReduceLROnPlateau

class braintools.optim.ReduceLROnPlateau(base_lr=0.001, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, last_epoch=0)

Reduce the learning rate when a monitored metric has stopped improving, so that the schedule adapts to validation performance.

ReduceLROnPlateau monitors a validation metric (such as loss or accuracy) and reduces the learning rate when the metric has not improved for a specified number of epochs (patience). This is useful when you do not know in advance when to reduce the learning rate, because the training dynamics determine the schedule.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • mode (str) –

    Whether to minimize or maximize the monitored metric.

    • 'min': reduce the lr when the metric stops decreasing (e.g., for loss)

    • 'max': reduce the lr when the metric stops increasing (e.g., for accuracy)

    Default: 'min'.

  • factor (float) – Factor by which to reduce the learning rate. new_lr = lr * factor. Must be in range (0, 1). Default: 0.1.

  • patience (int) – Number of consecutive epochs with no improvement after which the learning rate will be reduced. For example, if patience=5, the first 5 epochs with no improvement are tolerated, and the lr is reduced on the 6th consecutive epoch without improvement. Default: 10.

  • threshold (float) – Minimum change over the best value seen so far that counts as an improvement. Whether it is applied as a relative or an absolute margin is set by threshold_mode. Default: 1e-4.

  • threshold_mode (str) –

    How to compute the threshold for improvement.

    • 'rel': dynamic threshold = best * (1 - threshold) in 'min' mode, best * (1 + threshold) in 'max' mode

    • 'abs': static threshold = best - threshold in 'min' mode, best + threshold in 'max' mode

    Default: 'rel'.

  • cooldown (int) – Number of epochs to wait before resuming normal operation after lr has been reduced. During cooldown, no further lr reductions occur. Default: 0.

  • min_lr (float | List[float]) – Minimum learning rate(s). The lr will not be reduced below this value. Default: 0.

  • eps (float) – Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is ignored. Default: 1e-8.

  • last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0.
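The interaction between patience and cooldown is easiest to see on a metric that never improves. The following plain-Python sketch assumes PyTorch-style counting. A reduction fires when the number of consecutive non-improving epochs exceeds patience, the counter is reset after each reduction, and bad epochs are ignored while a cooldown is active. The helper `reduction_epochs` is hypothetical and is not part of braintools.

```python
def reduction_epochs(n_epochs, patience, cooldown=0):
    """Epochs (0-based) at which the lr would be reduced if the metric
    improves at epoch 0 and never again afterwards."""
    num_bad_epochs = 0
    cooldown_counter = 0
    reductions = []
    for epoch in range(n_epochs):
        if epoch > 0:  # epoch 0 sets `best`; every later epoch is non-improving
            num_bad_epochs += 1
        if cooldown_counter > 0:  # bad epochs during cooldown are ignored
            cooldown_counter -= 1
            num_bad_epochs = 0
        if num_bad_epochs > patience:  # more than `patience` bad epochs in a row
            reductions.append(epoch)
            cooldown_counter = cooldown
            num_bad_epochs = 0
    return reductions


print(reduction_epochs(20, patience=5))              # [6, 12, 18]
print(reduction_epochs(30, patience=5, cooldown=2))  # [6, 14, 22]
```

With patience=5 the first reduction lands on epoch 6, which is the sixth consecutive non-improving epoch. A cooldown of 2 adds two extra epochs between consecutive reductions.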

Notes

The scheduler reduces the learning rate when the monitored metric plateaus:

\[\begin{split}\eta_{t+1} = \begin{cases} \max(\eta_t \cdot \text{factor}, \eta_{\min}) & \text{if plateau detected} \\ \eta_t & \text{otherwise} \end{cases}\end{split}\]

A plateau is detected once the metric has failed to improve for more than patience consecutive epochs (outside any cooldown period).
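The update rule fits in a few lines, including the min_lr floor and the eps guard from the parameter list. The plain-Python sketch below applies one reduction per detected plateau. The function name `reduce_lr` is hypothetical.

```python
def reduce_lr(lr, factor, min_lr=0.0, eps=1e-8):
    """One plateau reduction: max(lr * factor, min_lr), skipped if the change is <= eps."""
    new_lr = max(lr * factor, min_lr)
    # Ignore updates that would change the lr by eps or less (e.g. once min_lr is reached).
    if lr - new_lr > eps:
        return new_lr
    return lr


# Eight consecutive plateaus with base_lr=0.1, factor=0.5, min_lr=0.001.
lr = 0.1
history = [lr]
for _ in range(8):
    lr = reduce_lr(lr, factor=0.5, min_lr=0.001)
    history.append(lr)
print(history)
```

The lr halves until 0.1 * 0.5**7 would fall below min_lr. From that point it is clamped to 0.001, and further plateaus leave it unchanged.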

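For mode='min', improvement is defined as:

\[\text{metric}_t < \text{best} \cdot (1 - \text{threshold}) \quad \text{(relative)}\]

or

\[\text{metric}_t < \text{best} - \text{threshold} \quad \text{(absolute)}\]

For mode='max', the inequalities are reversed:

\[\text{metric}_t > \text{best} \cdot (1 + \text{threshold}) \quad \text{(relative)}\]

or

\[\text{metric}_t > \text{best} + \text{threshold} \quad \text{(absolute)}\]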
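The four combinations of mode and threshold_mode can be expressed as a single predicate. The plain-Python sketch below does this. The function name `is_improvement` is hypothetical, and the 'max' branches assume the + side of the ± in the threshold_mode description.

```python
def is_improvement(metric, best, mode="min", threshold=1e-4, threshold_mode="rel"):
    """Return True if `metric` beats `best` by more than the threshold."""
    if mode == "min":
        bound = best * (1 - threshold) if threshold_mode == "rel" else best - threshold
        return metric < bound
    if mode == "max":
        bound = best * (1 + threshold) if threshold_mode == "rel" else best + threshold
        return metric > bound
    raise ValueError(f"unknown mode: {mode!r}")


# best = 0.5 with the default relative threshold 1e-4: the bound is 0.49995.
print(is_improvement(0.49996, 0.5))  # False: the decrease is too small
print(is_improvement(0.4999, 0.5))   # True
# Absolute threshold 1e-3: the loss must drop below 0.499.
print(is_improvement(0.4995, 0.5, threshold=1e-3, threshold_mode="abs"))  # False
```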

Key characteristics:

  • Adapts the schedule to the observed metric instead of a fixed epoch count

  • No need to pre-specify decay epochs

  • Well suited when the optimal schedule is unknown

  • Commonly used for validation-based training

Common configurations:

  • Conservative: patience=10, factor=0.5

  • Moderate: patience=5, factor=0.1

  • Aggressive: patience=3, factor=0.1
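One way to compare these configurations is to count two things: how many reductions it takes to go from base_lr down to a floor, and how long the scheduler waits before its first reduction (patience + 1 flat epochs, per the patience description). The helper below is a hypothetical back-of-the-envelope sketch. The values base_lr=0.1 and min_lr=1e-4 are illustrative assumptions.

```python
import math


def reductions_to_floor(base_lr, min_lr, factor):
    """Number of plateau reductions k before base_lr * factor**k reaches min_lr."""
    k = math.log(min_lr / base_lr) / math.log(factor)
    # Subtract a tiny tolerance so exact powers (e.g. 0.1 -> 1e-4 with factor 0.1)
    # are not pushed up by floating-point round-off.
    return math.ceil(k - 1e-9)


configs = {
    "Conservative": (10, 0.5),
    "Moderate": (5, 0.1),
    "Aggressive": (3, 0.1),
}
for name, (patience, factor) in configs.items():
    k = reductions_to_floor(base_lr=0.1, min_lr=1e-4, factor=factor)
    print(f"{name}: first reduction after {patience + 1} flat epochs; "
          f"{k} reductions to go from 0.1 to 1e-4")
```

The conservative setting needs more than three times as many plateaus as the other two to reach the same floor (10 versus 3), and each of those plateaus is also longer.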

When to use:

  • When you don’t know the optimal training schedule

  • For validation-driven training

  • When training dynamics are unpredictable

  • To avoid hand-tuning when to decay the learning rate

Examples

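The examples below use placeholders for user code: grads (gradients from a training step), validate, evaluate_accuracy, train_and_validate, and val_loader.

Basic usage with validation loss: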

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5,
...     min_lr=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training
...     optimizer.step(grads)
...
...     # Validation
...     val_loss = validate(model, val_loader)
...
...     # Update learning rate based on validation loss
...     scheduler.step(val_loss)
...
...     print(f"Epoch {epoch}: lr={optimizer.current_lr:.6f}, val_loss={val_loss:.4f}")
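Because grads, validate, and val_loader are placeholders, the loop above does not run on its own. The self-contained sketch below re-implements the plateau logic from the Notes in plain Python. It drives that logic with a synthetic validation-loss curve, using the same configuration as the example (base_lr=0.1, factor=0.5, patience=5, min_lr=0.001). The class `PlateauSketch` and the loss curve are hypothetical. The sketch follows PyTorch-style semantics and is not the braintools implementation, whose step is JIT-compatible.

```python
class PlateauSketch:
    """Plain-Python model of plateau-based lr reduction (not the braintools code)."""

    def __init__(self, base_lr, mode="min", factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode="rel", cooldown=0,
                 min_lr=0.0, eps=1e-8):
        self.lr = base_lr
        self.mode, self.factor, self.patience = mode, factor, patience
        self.threshold, self.threshold_mode = threshold, threshold_mode
        self.cooldown, self.min_lr, self.eps = cooldown, min_lr, eps
        self.best = float("inf") if mode == "min" else float("-inf")
        self.num_bad_epochs = 0
        self.cooldown_counter = 0

    def _is_better(self, metric):
        t, best = self.threshold, self.best
        if self.mode == "min":
            bound = best * (1 - t) if self.threshold_mode == "rel" else best - t
            return metric < bound
        bound = best * (1 + t) if self.threshold_mode == "rel" else best + t
        return metric > bound

    def step(self, metric):
        """Record one epoch's metric and return the (possibly reduced) lr."""
        if self._is_better(metric):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.cooldown_counter > 0:  # bad epochs during cooldown are ignored
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0
        if self.num_bad_epochs > self.patience:  # plateau detected
            new_lr = max(self.lr * self.factor, self.min_lr)
            if self.lr - new_lr > self.eps:
                self.lr = new_lr
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
        return self.lr


# Synthetic validation loss: decreases until epoch 10, then stays flat at 0.5.
losses = [1.0 - 0.05 * e for e in range(10)] + [0.5] * 30
sched = PlateauSketch(base_lr=0.1, mode="min", factor=0.5, patience=5, min_lr=0.001)
lrs = [sched.step(loss) for loss in losses]
reduced_at = [e for e in range(1, len(lrs)) if lrs[e] < lrs[e - 1]]
print("lr reduced at epochs:", reduced_at)
print("final lr:", lrs[-1])
```

Once the loss flattens at epoch 10, the lr is halved every patience + 1 = 6 epochs (at epochs 16, 22, 28 and 34).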

With validation accuracy (maximize mode):

>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='max',  # Maximize accuracy
...     factor=0.1,
...     patience=10,
...     threshold=0.01
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(200):
...     optimizer.step(grads)
...     val_acc = evaluate_accuracy(model, val_loader)
...     scheduler.step(val_acc)
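With mode='max' and the default threshold_mode='rel', threshold=0.01 means a new accuracy must exceed best * (1 + threshold). In other words, it must beat the best value by more than 1% of that value before it resets the patience counter. A quick arithmetic check in plain Python follows. The helper name `required_accuracy` is hypothetical.

```python
def required_accuracy(best, threshold=0.01):
    """Value a new accuracy must exceed to count as an improvement (mode='max', 'rel')."""
    return best * (1 + threshold)


for best in (0.50, 0.80, 0.95):
    print(f"best={best:.2f} -> need > {required_accuracy(best):.4f}")
```

Near the top of the accuracy range, a 1% relative margin is a sizeable absolute gain (about 0.0095 at 0.95). As a result, small late-training gains are treated as a plateau. A smaller threshold or threshold_mode='abs' relaxes this.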

Conservative schedule for stable training:

>>> # Halve the lr after more than 10 epochs without improvement
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=10,
...     threshold=1e-3,
...     cooldown=5  # Wait 5 epochs after reduction
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9, weight_decay=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Aggressive schedule for quick adaptation:

>>> # Cut the lr by 90% after more than 3 epochs without improvement
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.1,
...     patience=3,
...     threshold=1e-4,
...     min_lr=1e-6
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With absolute threshold mode:

>>> # Use absolute threshold for improvement
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5,
...     threshold=0.001,
...     threshold_mode='abs'  # Absolute improvement threshold
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
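The choice between 'rel' and 'abs' matters most when the loss changes scale during training. With threshold_mode='abs', the required decrease is fixed, so as a fraction of the current best it grows as the loss shrinks. With 'rel', that fraction stays constant. The plain-Python comparison below uses the mode='min' formulas from the Notes. The helper name `required_drop` is hypothetical.

```python
def required_drop(best, threshold, threshold_mode):
    """Amount the loss must fall below `best` to count as an improvement (mode='min')."""
    if threshold_mode == "rel":
        return best * threshold
    return threshold


for best in (1.0, 0.1, 0.01):
    rel = required_drop(best, 1e-3, "rel")
    abs_ = required_drop(best, 1e-3, "abs")
    print(f"best={best:<5} rel needs a {rel / best:.1%} drop, abs needs a {abs_ / best:.1%} drop")
```

At best=0.01, an absolute threshold of 0.001 demands a 10% relative drop. Late in training, therefore, the scheduler may declare plateaus sooner than it would with a relative threshold.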

Complete training loop with early stopping:

>>> import jax.numpy as jnp
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.5,
...     patience=10,
...     min_lr=1e-5
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> best_loss = float('inf')
>>> patience_counter = 0
>>> early_stop_patience = 20
>>>
>>> for epoch in range(200):
...     # Training
...     optimizer.step(grads)
...
...     # Validation
...     val_loss = validate(model, val_loader)
...
...     # Update learning rate
...     old_lr = optimizer.current_lr
...     scheduler.step(val_loss)
...     if optimizer.current_lr < old_lr:
...         print(f"Epoch {epoch}: Reduced LR to {optimizer.current_lr:.6f}")
...
...     # Early stopping
...     if val_loss < best_loss:
...         best_loss = val_loss
...         patience_counter = 0
...         # Save best model
...     else:
...         patience_counter += 1
...         if patience_counter >= early_stop_patience:
...             print(f"Early stopping at epoch {epoch}")
...             break
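Early stopping and the plateau scheduler count stalled epochs independently. The plain-Python sketch below counts how many lr reductions fire before early stopping when the loss never improves again. It assumes the patience counting described for the patience parameter and no cooldown. The helper `count_reductions_before_stop` is hypothetical.

```python
def count_reductions_before_stop(sched_patience, stop_patience):
    """Plateau reductions that fire before early stopping, if the loss never improves again."""
    sched_bad = 0  # scheduler's bad-epoch counter (reset after each reduction)
    stop_bad = 0   # early-stopping counter (never reset while the loss is flat)
    reductions = 0
    while True:
        # One more epoch without improvement.
        sched_bad += 1
        stop_bad += 1
        if sched_bad > sched_patience:  # corresponds to scheduler.step(val_loss)
            reductions += 1
            sched_bad = 0
        if stop_bad >= stop_patience:  # corresponds to the early-stopping check
            return reductions


print(count_reductions_before_stop(10, 20))  # 1
```

With the values above (patience=10, early_stop_patience=20), a stalled run gets exactly one reduced learning rate before it stops. If early_stop_patience is not larger than the scheduler's patience, training stops before the scheduler ever lowers the learning rate. For example, count_reductions_before_stop(10, 10) returns 0.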

State persistence:

>>> scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for some epochs
>>> for epoch in range(50):
...     val_loss = train_and_validate(model, optimizer)
...     scheduler.step(val_loss)
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'epoch': 50,
...     'model': model.state_dict(),
...     'optimizer': optimizer.state_dict(),
...     'scheduler': scheduler.state_dict(),
...     'best_metric': scheduler.best
... }
>>>
>>> # Resume training
>>> new_scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.1,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> new_scheduler.load_state_dict(checkpoint['scheduler'])
>>> new_scheduler.best = checkpoint['best_metric']
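Whatever the braintools state_dict() contains (its exact fields are not documented here), restoring best alone is not enough to resume exactly. A resumed run matches an uninterrupted one only if the bad-epoch counter and the current learning rate are also restored. The plain-Python illustration below uses a hypothetical state object, `PlateauState`, which is not a braintools class. It covers min mode with a relative threshold and no cooldown.

```python
from dataclasses import asdict, dataclass


@dataclass
class PlateauState:
    """Hypothetical scheduler state (min mode, relative threshold, no cooldown)."""
    lr: float
    factor: float = 0.5
    patience: int = 5
    threshold: float = 1e-4
    best: float = float("inf")
    num_bad_epochs: int = 0

    def step(self, metric):
        if metric < self.best * (1 - self.threshold):
            self.best, self.num_bad_epochs = metric, 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


losses = [1.0, 0.8, 0.7] + [0.7] * 12

# Reference: one uninterrupted run.
full = PlateauState(lr=0.1)
full_lrs = [full.step(x) for x in losses]

# Run 5 epochs, checkpoint every field, then resume in a fresh object.
part = PlateauState(lr=0.1)
first = [part.step(x) for x in losses[:5]]
checkpoint = asdict(part)  # plain dict: lr, best, num_bad_epochs, ...
resumed = PlateauState(**checkpoint)
resumed_lrs = first + [resumed.step(x) for x in losses[5:]]

# Restoring only `best` resets the bad-epoch count, delaying the next reduction.
naive = PlateauState(lr=0.1, best=checkpoint["best"])
naive_lrs = first + [naive.step(x) for x in losses[5:]]

print(resumed_lrs == full_lrs)  # True
print(naive_lrs == full_lrs)    # False
```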

Scheduling on one metric while logging others:

>>> # Drive the lr schedule with validation loss only
>>> val_scheduler = braintools.optim.ReduceLROnPlateau(
...     base_lr=0.01,
...     mode='min',
...     factor=0.5,
...     patience=5
... )
>>> optimizer = braintools.optim.Adam(lr=val_scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     optimizer.step(grads)
...     val_loss = validate(model, val_loader)
...
...     # Use validation loss for lr scheduling
...     val_scheduler.step(val_loss)
...
...     # Other metrics can be logged without affecting the schedule
...     val_acc = evaluate_accuracy(model, val_loader)
...     print(f"Epoch {epoch}: lr={optimizer.current_lr:.6f}, "
...           f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

See also

  • StepLR – Fixed step-based learning rate decay

  • ExponentialLR – Exponential decay

  • CosineAnnealingLR – Cosine annealing schedule

  • OneCycleLR – One cycle learning rate policy

get_lr()

Calculate the learning rate.

step(metric, epoch=None)

Update the scheduler with a new value of the monitored metric (JIT-compatible).

Parameters:
  • metric (float) – The metric value to monitor.

  • epoch (int | None) – Optional epoch number.