PiecewiseConstantSchedule

class braintools.optim.PiecewiseConstantSchedule(base_lr=0.001, boundaries=None, values=None, last_epoch=0)

Piecewise constant learning rate schedule with step-wise transitions.

PiecewiseConstantSchedule implements a learning rate schedule where the learning rate remains constant within specified intervals and changes abruptly at predefined boundaries. This creates a step function that’s useful for stage-based training where different phases require different learning rates.

This scheduler is particularly effective for:

  • Multi-stage training pipelines

  • Transfer learning with progressive unfreezing

  • Training with curriculum learning

  • Reproducing specific research papers with fixed schedules

  • Budget-constrained training with predetermined phases

Parameters:
  • base_lr (float | List[float]) – Base learning rate(s). The learning rate in each phase is base_lr multiplied by the corresponding entry of values. Can be a single float, or a list with one entry per parameter group. Default: 1e-3.

  • boundaries (List[int]) – Step indices at which the learning rate changes. Must be sorted in ascending order. The schedule has len(boundaries) + 1 distinct phases. Default: None, in which case [1000, 2000] is used.

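  • values (List[float]) – Multiplicative factors applied to the base learning rate, one per phase. Must have exactly len(boundaries) + 1 elements. values[0] applies before boundaries[0], values[i] applies to steps t with boundaries[i-1] ≤ t < boundaries[i], and the last value applies from the final boundary onward. Default: None, in which case [1.0, 0.1, 0.01] is used.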

  • last_epoch (int) – The index of the last epoch. Used when resuming training. Default: 0.
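The constraints on boundaries and values stated above can be checked before a schedule is constructed. The following is a minimal sketch in plain Python. `check_schedule_args` is a hypothetical helper, not part of braintools, and whether the library performs the same checks (or requires strictly increasing boundaries) is an assumption.

```python
from typing import Sequence


def check_schedule_args(boundaries: Sequence[int], values: Sequence[float]) -> None:
    """Hypothetical helper: validate the documented constraints."""
    # Boundaries must be sorted in ascending order (strictness is assumed here).
    if any(b2 <= b1 for b1, b2 in zip(boundaries, boundaries[1:])):
        raise ValueError(f"boundaries must be strictly ascending, got {list(boundaries)}")
    # n boundaries split training into n + 1 phases, one value per phase.
    if len(values) != len(boundaries) + 1:
        raise ValueError(
            f"expected {len(boundaries) + 1} values for {len(boundaries)} boundaries, "
            f"got {len(values)}"
        )


# The documented defaults satisfy both constraints, so this call returns silently.
check_schedule_args([1000, 2000], [1.0, 0.1, 0.01])
```

Running such a check up front turns a silent indexing mistake (for example, three boundaries with three values) into an immediate error.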

Notes

Mathematical Formulation:

The learning rate at step t is defined as:

\[\eta_t = \eta_{base} \times v_i\]

where \(v_i\) is determined by:

\[\begin{split}v_i = \begin{cases} \text{values}[0] & \text{if } t < \text{boundaries}[0] \\ \text{values}[1] & \text{if } \text{boundaries}[0] \leq t < \text{boundaries}[1] \\ \vdots & \vdots \\ \text{values}[n] & \text{if } t \geq \text{boundaries}[n-1] \end{cases}\end{split}\]
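The case analysis above amounts to counting how many boundaries are less than or equal to t and using that count as the index into values. A minimal pure-Python sketch of this, using bisect.bisect_right from the standard library. `piecewise_lr` is an illustrative name and is not the library's implementation.

```python
from bisect import bisect_right
from typing import Sequence


def piecewise_lr(t: int, base_lr: float,
                 boundaries: Sequence[int], values: Sequence[float]) -> float:
    """Illustrative only: eta_t = base_lr * values[i]."""
    # bisect_right returns the number of boundaries <= t, which is the index i
    # satisfying boundaries[i-1] <= t < boundaries[i].
    i = bisect_right(boundaries, t)
    return base_lr * values[i]


# Documented defaults: base_lr=1e-3, boundaries=[1000, 2000], values=[1.0, 0.1, 0.01]
for t in (0, 999, 1000, 1999, 2000):
    print(t, piecewise_lr(t, 1e-3, [1000, 2000], [1.0, 0.1, 0.01]))
```

Note that a step equal to a boundary already belongs to the next phase, matching the ≤ on the left side of each case.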

Schedule Structure:

Given boundaries [b1, b2, …, bn] and values [v0, v1, …, vn]:

  • Steps [0, b1): learning_rate = base_lr × v0

  • Steps [b1, b2): learning_rate = base_lr × v1

  • Steps [b2, b3): learning_rate = base_lr × v2

  • …

  • Steps [bn, ∞): learning_rate = base_lr × vn
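The same structure can be listed programmatically for any concrete schedule. A short sketch follows. `describe_phases` is a hypothetical helper written for this illustration, and the numbers are taken from the ImageNet example in this page.

```python
from typing import List, Optional, Sequence, Tuple


def describe_phases(base_lr: float, boundaries: Sequence[int],
                    values: Sequence[float]) -> List[Tuple[int, Optional[int], float]]:
    """Hypothetical helper: (start, end, lr) per phase; end=None means open-ended."""
    starts = [0, *boundaries]
    ends = [*boundaries, None]
    return [(s, e, base_lr * v) for s, e, v in zip(starts, ends, values)]


phases = describe_phases(0.1, [30, 60, 80], [1.0, 0.1, 0.01, 0.001])
for start, end, lr in phases:
    stop = "inf" if end is None else end
    print(f"steps [{start}, {stop}): lr = {lr:.4f}")
```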

JIT Compatibility:

The implementation selects the active value with jnp.searchsorted rather than Python conditionals, so the schedule can be traced and compiled under JIT.
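To illustrate why a searchsorted lookup is JIT-friendly, here is a minimal JAX sketch of a branch-free schedule function. It is not the library's source: `lr_at` is a hypothetical name, and the choice of side='right' is an assumption made so that the result agrees with the formula above.

```python
import jax
import jax.numpy as jnp

boundaries = jnp.array([1000, 2000])
values = jnp.array([1.0, 0.1, 0.01])
base_lr = 1e-3


@jax.jit
def lr_at(step):
    # side='right' counts the boundaries <= step, so a step equal to a
    # boundary already falls into the next phase (assumed convention).
    idx = jnp.searchsorted(boundaries, step, side='right')
    # Indexing with a traced integer is fine under jit; no Python `if` needed.
    return base_lr * values[idx]


print(lr_at(1500))
```

A Python `if step < 1000:` on a traced value would fail during tracing, whereas the index lookup compiles to a single gather.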

Examples

Classic ImageNet training schedule:

>>> import braintools
>>> import brainstate
>>>
>>> # ResNet50 on ImageNet schedule
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[30, 60, 80],  # Epochs to decrease LR
...     values=[1.0, 0.1, 0.01, 0.001]  # LR multipliers
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for 90 epochs total
>>> for epoch in range(90):
...     lr = scheduler.get_lr()[0]  # LR in effect for this epoch
...     print(f"Epoch {epoch}: LR = {lr:.6f}")
...     train_epoch(...)
...     scheduler.step()  # advance the schedule after the epoch finishes
...     # LR: 0.1 (epochs 0-29), 0.01 (30-59), 0.001 (60-79), 0.0001 (80-89)

Transfer learning with progressive unfreezing:

>>> # Unfreeze layers progressively with different LRs
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=1e-3,
...     boundaries=[5, 10, 15],  # Unfreeze stages
...     values=[0.01, 0.1, 0.5, 1.0]  # Gradual increase
... )
>>>
>>> # Stage 1 (0-4): Only train head, very low LR
>>> # Stage 2 (5-9): Unfreeze top layers
>>> # Stage 3 (10-14): Unfreeze middle layers
>>> # Stage 4 (15+): Full model training
>>>
>>> for epoch in range(20):
...     if epoch == 5:
...         unfreeze_top_layers(model)
...     elif epoch == 10:
...         unfreeze_middle_layers(model)
...     elif epoch == 15:
...         unfreeze_all_layers(model)
...
...     train_epoch(...)
...     scheduler.step()

Budget-aware training schedule:

>>> # Training with computational budget constraints
>>> # Fast initial training, then careful fine-tuning
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.01,
...     boundaries=[100, 500, 800],
...     values=[10.0, 1.0, 0.1, 0.01]  # Aggressive start, careful end
... )
>>>
>>> # Steps 0-99: Fast exploration (LR=0.1)
>>> # Steps 100-499: Normal training (LR=0.01)
>>> # Steps 500-799: Fine-tuning (LR=0.001)
>>> # Steps 800+: Final refinement (LR=0.0001)

Multi-phase curriculum learning:

>>> # Different learning rates for different curriculum stages
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=1e-3,
...     boundaries=[1000, 3000, 6000, 9000],
...     values=[0.1, 0.5, 1.0, 0.5, 0.1]
... )
>>>
>>> curriculum_difficulties = [0.2, 0.4, 0.6, 0.8, 1.0]
>>>
>>> for step in range(10000):
...     # Determine curriculum difficulty
...     stage = sum(step >= b for b in scheduler.boundaries)
...     difficulty = curriculum_difficulties[stage]
...
...     # Train with appropriate difficulty
...     batch = get_curriculum_batch(difficulty)
...     train_step(batch)
...     scheduler.step()
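The stage index computed in the loop above, sum(step >= b for b in boundaries), is the number of boundaries already reached. For long boundary lists the same index can be obtained with a binary search. A small standalone sketch, in which `stage_of` is a hypothetical helper and the boundaries are copied from the example as a plain list:

```python
from bisect import bisect_right

boundaries = [1000, 3000, 6000, 9000]
curriculum_difficulties = [0.2, 0.4, 0.6, 0.8, 1.0]


def stage_of(step: int) -> int:
    # Number of boundaries <= step; equal to sum(step >= b for b in boundaries),
    # but found in O(log n) on the sorted list.
    return bisect_right(boundaries, step)


for step in (0, 999, 1000, 5999, 9000):
    print(step, stage_of(step), curriculum_difficulties[stage_of(step)])
```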

Reproducing paper schedules:

>>> # WideResNet schedule from paper
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[60, 120, 160],  # Specific to WideResNet
...     values=[1.0, 0.2, 0.04, 0.008]
... )
>>>
>>> # CIFAR training schedule from "Bag of Tricks"
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[150, 250],
...     values=[1.0, 0.1, 0.01]
... )

Step-based (not epoch-based) scheduling:

>>> # Define boundaries in terms of training steps
>>> steps_per_epoch = len(train_loader)
>>> epoch_boundaries = [30, 60, 80]  # Desired epoch boundaries
>>> step_boundaries = [e * steps_per_epoch for e in epoch_boundaries]
>>>
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=step_boundaries,
...     values=[1.0, 0.1, 0.01, 0.001]
... )
>>>
>>> global_step = 0
>>> for epoch in range(90):
...     for batch in train_loader:
...         train_step(batch)
...         scheduler.step(global_step)
...         global_step += 1
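The epoch-to-step conversion above is plain multiplication. A worked instance, assuming a hypothetical steps_per_epoch of 500 (in practice this would be len(train_loader)):

```python
steps_per_epoch = 500  # hypothetical value chosen for illustration
epoch_boundaries = [30, 60, 80]
step_boundaries = [e * steps_per_epoch for e in epoch_boundaries]
print(step_boundaries)  # [15000, 30000, 40000]

# Epochs are 0-indexed, so the first step of epoch 30 is global step
# 30 * 500 = 15000, which is exactly the first boundary: the drop
# coincides with the start of that epoch.
first_step_of_epoch_30 = 30 * steps_per_epoch
assert first_step_of_epoch_30 == step_boundaries[0]
```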

Combining with warmup:

>>> # Add warmup to piecewise schedule
>>> warmup_steps = 500
>>> main_boundaries = [5000, 10000, 15000]
>>>
>>> # Combine warmup with main schedule
>>> all_boundaries = [warmup_steps] + main_boundaries
>>> all_values = [0.01, 1.0, 0.1, 0.01, 0.001]  # Low start for warmup
>>>
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.01,
...     boundaries=all_boundaries,
...     values=all_values
... )
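Because the schedule is piecewise constant, the warmup defined above is a single low plateau rather than a gradual ramp. The effective learning rate of each phase can be tabulated in plain Python as a sanity check. This is only arithmetic on the lists from the example and does not call braintools.

```python
warmup_steps = 500
main_boundaries = [5000, 10000, 15000]
all_boundaries = [warmup_steps] + main_boundaries
all_values = [0.01, 1.0, 0.1, 0.01, 0.001]
base_lr = 0.01

# One value per phase: len(values) must equal len(boundaries) + 1.
assert len(all_values) == len(all_boundaries) + 1

effective_lrs = [base_lr * v for v in all_values]
starts = [0] + all_boundaries
for start, lr in zip(starts, effective_lrs):
    print(f"from step {start}: lr = {lr:.0e}")
```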

Dynamic monitoring and adjustment:

>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000, 3000],
...     values=[1.0, 0.5, 0.1, 0.01]
... )
>>>
>>> for step in range(4000):
...     old_lr = scheduler.get_lr()[0]
...     scheduler.step()
...     new_lr = scheduler.get_lr()[0]
...
...     # Detect LR changes
...     if old_lr != new_lr:
...         print(f"Step {step}: LR changed from {old_lr:.6f} to {new_lr:.6f}")
...         # Optionally reset momentum or other optimizer states
...         reset_optimizer_momentum(optimizer)
...
...     if step % 100 == 0:
...         print(f"Step {step}: LR = {new_lr:.6f}")

Research experimentation with multiple schedules:

>>> # Compare different decay strategies
>>> schedules = {
...     'aggressive': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[10, 20],
...         values=[1.0, 0.01, 0.0001]
...     ),
...     'conservative': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[30, 60],
...         values=[1.0, 0.5, 0.1]
...     ),
...     'multi_stage': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[10, 20, 30, 40],
...         values=[1.0, 0.8, 0.4, 0.1, 0.01]
...     )
... }
>>>
>>> # Run experiments
>>> for name, scheduler in schedules.items():
...     print(f"Testing schedule: {name}")
...     model = create_model()
...     optimizer = braintools.optim.SGD(lr=scheduler)
...     results = train_model(model, optimizer)
...     log_results(name, results)

State persistence and checkpointing:

>>> # Save and restore schedule state
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000],
...     values=[1.0, 0.1, 0.01]
... )
>>>
>>> # Train for some steps
>>> for step in range(1500):
...     train_step(...)
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'step': 1500,
...     'scheduler_state': scheduler.state_dict(),
...     'model_state': model.state_dict()
... }
>>> save(checkpoint, 'checkpoint.pkl')
>>>
>>> # Later: restore and continue
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000],
...     values=[1.0, 0.1, 0.01]
... )
>>> scheduler.load_state_dict(checkpoint['scheduler_state'])
>>> # Continue from step 1500 with correct LR
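To make the resume behaviour concrete, the sketch below uses a toy stand-in class whose only state is the step counter. It is not braintools' class: the name `ToyPiecewiseSchedule`, the `current_lr` method, and the layout of the state dict are all assumptions, and the real state_dict may hold more or different keys.

```python
from bisect import bisect_right


class ToyPiecewiseSchedule:
    """Toy stand-in, NOT the braintools class: only the step counter is stateful."""

    def __init__(self, base_lr, boundaries, values):
        self.base_lr = base_lr
        self.boundaries = list(boundaries)
        self.values = list(values)
        self.last_epoch = 0

    def step(self):
        self.last_epoch += 1

    def current_lr(self):
        idx = bisect_right(self.boundaries, self.last_epoch)
        return self.base_lr * self.values[idx]

    def state_dict(self):
        # Assumed layout for illustration only.
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


sched = ToyPiecewiseSchedule(0.1, [1000, 2000], [1.0, 0.1, 0.01])
for _ in range(1500):
    sched.step()
saved = sched.state_dict()

# A freshly constructed schedule starts at step 0; loading the saved state
# moves it back into the second phase.
restored = ToyPiecewiseSchedule(0.1, [1000, 2000], [1.0, 0.1, 0.01])
restored.load_state_dict(saved)
print(restored.last_epoch, restored.current_lr())
```

Since the learning rate is a pure function of the step counter, restoring that counter is enough to land in the correct phase.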

See also

  • StepLR – Exponential decay at regular intervals

  • MultiStepLR – Similar concept with multiplicative decay

  • CosineAnnealingLR – Smooth transitions instead of step changes

  • SequentialLR – Chain multiple schedulers sequentially

get_lr()

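Calculate the current learning rate(s). In the examples above the result is indexed per parameter group, e.g. get_lr()[0] for the first group.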