PiecewiseConstantSchedule
class braintools.optim.PiecewiseConstantSchedule(base_lr=0.001, boundaries=None, values=None, last_epoch=0)
Piecewise constant learning rate schedule with step-wise transitions.
PiecewiseConstantSchedule implements a learning rate schedule where the learning rate remains constant within specified intervals and changes abruptly at predefined boundaries. This creates a step function that’s useful for stage-based training where different phases require different learning rates.
This scheduler is particularly effective for:
- Multi-stage training pipelines
- Transfer learning with progressive unfreezing
- Curriculum learning
- Reproducing the fixed schedules of specific research papers
- Budget-constrained training with predetermined phases
Parameters:
- base_lr (float | List[float]) – Base learning rate(s) that are scaled by values. Can be a single float or a list with one entry per parameter group. It serves as the reference that is multiplied by the value of the current phase. Default: 1e-3.
- boundaries (List[int]) – Step indices at which the learning rate changes. Must be sorted in ascending order. The schedule has len(boundaries) + 1 distinct phases. Default (when None): [1000, 2000].
- values (List[float]) – Multiplicative factors applied to base_lr in each phase. Must have exactly len(boundaries) + 1 elements. values[0] applies before boundaries[0], values[i] applies from boundaries[i-1] (inclusive) to boundaries[i] (exclusive), and the last value applies from the last boundary onward. Default (when None): [1.0, 0.1, 0.01].
- last_epoch (int) – The index of the last epoch. Used when resuming training. Default: 0.
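The constraints on these arguments can be made concrete with a small validation sketch. It assumes that passing None selects the documented defaults; the helper name resolve_piecewise_args is hypothetical and not part of braintools, whose actual checks may differ.

```python
from typing import List, Optional, Sequence, Tuple, Union


def resolve_piecewise_args(
    base_lr: Union[float, Sequence[float]] = 1e-3,
    boundaries: Optional[Sequence[int]] = None,
    values: Optional[Sequence[float]] = None,
) -> Tuple[List[float], List[int], List[float]]:
    """Hypothetical helper: apply the documented defaults and check the constraints."""
    # Defaults documented above, assumed to apply when None is passed.
    boundaries = [1000, 2000] if boundaries is None else list(boundaries)
    values = [1.0, 0.1, 0.01] if values is None else list(values)
    # Boundaries must not decrease (equal neighbours are tolerated in this sketch).
    if any(a > b for a, b in zip(boundaries, boundaries[1:])):
        raise ValueError(f"boundaries must be sorted in ascending order: {boundaries}")
    # len(boundaries) + 1 phases need exactly one multiplier each.
    if len(values) != len(boundaries) + 1:
        raise ValueError(f"expected {len(boundaries) + 1} values, got {len(values)}")
    # A scalar base_lr stands for a single parameter group.
    if isinstance(base_lr, (int, float)):
        base_lrs = [float(base_lr)]
    else:
        base_lrs = [float(lr) for lr in base_lr]
    return base_lrs, boundaries, values


# Two parameter groups sharing one set of boundaries and multipliers.
base_lrs, bounds, vals = resolve_piecewise_args([0.1, 0.01], [30, 60], [1.0, 0.1, 0.01])
second_phase_lrs = [lr * vals[1] for lr in base_lrs]
```

With a list-valued base_lr, every group moves through the same phases; only the absolute learning rates differ.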
Notes
Mathematical Formulation:
The learning rate at step t is defined as:
\[\eta_t = \eta_{base} \times v_i\]
where \(v_i\) is determined by:
\[\begin{split}v_i = \begin{cases} \text{values}[0] & \text{if } t < \text{boundaries}[0] \\ \text{values}[1] & \text{if } \text{boundaries}[0] \leq t < \text{boundaries}[1] \\ \vdots & \vdots \\ \text{values}[n] & \text{if } t \geq \text{boundaries}[n-1] \end{cases}\end{split}\]
with \(n = \text{len(boundaries)}\).
Schedule Structure:
Given boundaries [b1, b2, …, bn] and values [v0, v1, …, vn]:
Steps [0, b1): learning_rate = base_lr × v0
Steps [b1, b2): learning_rate = base_lr × v1
Steps [b2, b3): learning_rate = base_lr × v2
…
Steps [bn, ∞): learning_rate = base_lr × vn
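The formula and the table above can be checked with a minimal pure-Python sketch. It assumes the half-open intervals shown in the table, so the phase index is the number of boundaries less than or equal to t, which is exactly what bisect.bisect_right returns. The function name is hypothetical and this is a reference evaluation, not the braintools implementation.

```python
from bisect import bisect_right
from typing import Sequence


def piecewise_constant_lr(step: int, base_lr: float,
                          boundaries: Sequence[int], values: Sequence[float]) -> float:
    """Reference evaluation of eta_t = base_lr * v_i (hypothetical helper)."""
    # Phase index i = number of boundaries <= step. A step equal to a boundary
    # therefore already belongs to the next phase, matching the half-open
    # intervals in the schedule table.
    i = bisect_right(boundaries, step)
    return base_lr * values[i]


# Steps on either side of each boundary of a [30, 60, 80] schedule.
for t in (0, 29, 30, 59, 60, 79, 80, 89):
    print(t, piecewise_constant_lr(t, 0.1, [30, 60, 80], [1.0, 0.1, 0.01, 0.001]))
```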
JIT Compatibility:
This implementation is JIT-compatible through the use of jnp.searchsorted for efficient boundary-based value selection without Python conditionals.
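A searchsorted-based lookup can be sketched with NumPy as follows. It assumes side="right", so that a step equal to a boundary falls into the next phase as in the table above; which side braintools actually passes is not stated here. The function name is hypothetical. Because jax.numpy.searchsorted takes the same side= argument and indexing with a traced integer array is supported, swapping numpy for jax.numpy should yield a function that can be wrapped in jax.jit.

```python
import numpy as np  # jax.numpy.searchsorted accepts the same side= argument


def piecewise_constant_lr_array(steps, base_lr, boundaries, values):
    """Evaluate the schedule for many steps at once, with no Python branching on steps."""
    boundaries = np.asarray(boundaries)
    values = np.asarray(values)
    # side="right" gives, for each step, the number of boundaries <= step,
    # i.e. the phase index i used in the formula above.
    idx = np.searchsorted(boundaries, steps, side="right")
    return base_lr * values[idx]


steps = np.arange(0, 100, 10)  # 0, 10, ..., 90
lrs = piecewise_constant_lr_array(steps, 0.1, [30, 60, 80], [1.0, 0.1, 0.01, 0.001])
print(lrs)
```

Selecting the value by index rather than by an if/elif chain is what keeps the computation traceable: the control flow no longer depends on the value of the step.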
Examples
Classic ImageNet training schedule:
>>> import braintools
>>> import brainstate
>>>
>>> # ResNet-50 on ImageNet schedule
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[30, 60, 80],        # Epochs at which to decrease the LR
...     values=[1.0, 0.1, 0.01, 0.001]  # LR multipliers
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Train for 90 epochs total
>>> for epoch in range(90):
...     lr = scheduler.get_lr()[0]  # LR used during this epoch
...     train_epoch(...)
...     print(f"Epoch {epoch}: LR = {lr:.6f}")
...     scheduler.step()            # Advance to the next epoch's LR
...     # LR: 0.1 (epochs 0-29), 0.01 (30-59), 0.001 (60-79), 0.0001 (80-89)
Transfer learning with progressive unfreezing:
>>> # Unfreeze layers progressively with different LRs
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=1e-3,
...     boundaries=[5, 10, 15],       # Unfreeze stages
...     values=[0.01, 0.1, 0.5, 1.0]  # Gradual increase
... )
>>>
>>> # Stage 1 (epochs 0-4): Only train head, very low LR
>>> # Stage 2 (epochs 5-9): Unfreeze top layers
>>> # Stage 3 (epochs 10-14): Unfreeze middle layers
>>> # Stage 4 (epochs 15+): Full model training
>>>
>>> for epoch in range(20):
...     if epoch == 5:
...         unfreeze_top_layers(model)
...     elif epoch == 10:
...         unfreeze_middle_layers(model)
...     elif epoch == 15:
...         unfreeze_all_layers(model)
...
...     train_epoch(...)
...     scheduler.step()
Budget-aware training schedule:
>>> # Training with computational budget constraints
>>> # Fast initial training, then careful fine-tuning
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.01,
...     boundaries=[100, 500, 800],
...     values=[10.0, 1.0, 0.1, 0.01]  # Aggressive start, careful end
... )
>>>
>>> # Steps 0-99: Fast exploration (LR=0.1)
>>> # Steps 100-499: Normal training (LR=0.01)
>>> # Steps 500-799: Fine-tuning (LR=0.001)
>>> # Steps 800+: Final refinement (LR=0.0001)
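For budget planning it can help to list how many steps each phase actually receives within a fixed total. A minimal sketch follows; the helper phase_table and the 1000-step budget are hypothetical choices for illustration, and phases that start after the budget ends are simply dropped.

```python
from typing import List, Sequence, Tuple


def phase_table(base_lr: float, boundaries: Sequence[int], values: Sequence[float],
                total_steps: int) -> List[Tuple[int, int, float]]:
    """Hypothetical helper: list (start, end, lr) for each phase within total_steps."""
    starts = [0, *boundaries]
    ends = [*boundaries, total_steps]
    table = []
    for start, end, v in zip(starts, ends, values):
        end = min(end, total_steps)  # clip the phase to the budget
        if start < end:              # skip phases that begin after the budget runs out
            table.append((start, end, base_lr * v))
    return table


for start, end, lr in phase_table(0.01, [100, 500, 800], [10.0, 1.0, 0.1, 0.01], 1000):
    print(f"steps [{start}, {end}): {end - start} steps at LR={lr:g}")
```

If the budget shrinks below a boundary, the later low-LR phases vanish entirely, which is worth checking before committing to a schedule.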
Multi-phase curriculum learning:
>>> # Different learning rates for different curriculum stages
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=1e-3,
...     boundaries=[1000, 3000, 6000, 9000],
...     values=[0.1, 0.5, 1.0, 0.5, 0.1]
... )
>>>
>>> curriculum_difficulties = [0.2, 0.4, 0.6, 0.8, 1.0]
>>>
>>> for step in range(10000):
...     # Determine curriculum difficulty
...     stage = sum(step >= b for b in scheduler.boundaries)
...     difficulty = curriculum_difficulties[stage]
...
...     # Train with appropriate difficulty
...     batch = get_curriculum_batch(difficulty)
...     train_step(batch)
...     scheduler.step()
Reproducing paper schedules:
>>> # WideResNet schedule from paper
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[60, 120, 160],  # Specific to WideResNet
...     values=[1.0, 0.2, 0.04, 0.008]
... )
>>>
>>> # CIFAR training schedule from "Bag of Tricks"
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[150, 250],
...     values=[1.0, 0.1, 0.01]
... )
Step-based (not epoch-based) scheduling:
>>> # Define boundaries in terms of training steps
>>> steps_per_epoch = len(train_loader)
>>> epoch_boundaries = [30, 60, 80]  # Desired epoch boundaries
>>> step_boundaries = [e * steps_per_epoch for e in epoch_boundaries]
>>>
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=step_boundaries,
...     values=[1.0, 0.1, 0.01, 0.001]
... )
>>>
>>> global_step = 0
>>> for epoch in range(90):
...     for batch in train_loader:
...         train_step(batch)
...         scheduler.step(global_step)
...         global_step += 1
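The conversion above relies on len(train_loader) for the number of steps per epoch. When only the dataset size and batch size are known, the count depends on whether the final partial batch is kept (ceiling) or dropped (floor). A minimal sketch, with a hypothetical helper and illustrative CIFAR-sized numbers:

```python
import math
from typing import List, Sequence


def epoch_to_step_boundaries(epoch_boundaries: Sequence[int], num_samples: int,
                             batch_size: int, drop_last: bool = False) -> List[int]:
    """Hypothetical helper: convert epoch boundaries to global-step boundaries."""
    # Steps per epoch depend on whether the last partial batch is dropped.
    if drop_last:
        steps_per_epoch = num_samples // batch_size
    else:
        steps_per_epoch = math.ceil(num_samples / batch_size)
    return [e * steps_per_epoch for e in epoch_boundaries]


# 50,000 training samples with batch size 128.
print(epoch_to_step_boundaries([30, 60, 80], 50_000, 128))                  # [11730, 23460, 31280]
print(epoch_to_step_boundaries([30, 60, 80], 50_000, 128, drop_last=True))  # [11700, 23400, 31200]
```

An off-by-one in steps_per_epoch shifts every boundary by a multiple of the epoch index, so later boundaries drift the most.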
Combining with warmup:
>>> # Add warmup to piecewise schedule
>>> warmup_steps = 500
>>> main_boundaries = [5000, 10000, 15000]
>>>
>>> # Combine warmup with main schedule
>>> all_boundaries = [warmup_steps] + main_boundaries
>>> all_values = [0.01, 1.0, 0.1, 0.01, 0.001]  # Low start for warmup
>>>
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.01,
...     boundaries=all_boundaries,
...     values=all_values
... )
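The warmup in the example above is a single constant low-LR phase. A piecewise-constant schedule can also approximate a gradual warmup as a staircase of several short phases. The helper below is a hypothetical sketch of how such boundaries and values could be generated; with num_stairs=1 it reproduces the lists in the example. For a truly linear ramp, chaining schedulers (SequentialLR, listed under See also) is the usual alternative.

```python
from typing import List, Sequence, Tuple


def staircase_warmup(warmup_steps: int, num_stairs: int, start_factor: float,
                     main_boundaries: Sequence[int], main_values: Sequence[float]
                     ) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: prepend a step-wise warmup to a piecewise schedule.

    The warmup is split into num_stairs stairs whose factors rise linearly
    from start_factor toward main_values[0].
    """
    if not 1 <= num_stairs <= warmup_steps:
        raise ValueError("need 1 <= num_stairs <= warmup_steps")
    if main_boundaries and main_boundaries[0] <= warmup_steps:
        raise ValueError("main boundaries must start after the warmup")
    stair_len = warmup_steps // num_stairs
    target = main_values[0]
    # Boundaries between stairs, then the end of warmup (start of the main schedule).
    warm_bounds = [stair_len * k for k in range(1, num_stairs)] + [warmup_steps]
    # Factors rise linearly from start_factor; the last stair stays below target.
    warm_values = [start_factor + (target - start_factor) * k / num_stairs
                   for k in range(num_stairs)]
    return warm_bounds + list(main_boundaries), warm_values + list(main_values)


bounds, vals = staircase_warmup(500, 5, 0.01, [5000, 10000, 15000], [1.0, 0.1, 0.01, 0.001])
```

The resulting lists satisfy len(values) == len(boundaries) + 1 and can be passed directly as boundaries and values.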
Dynamic monitoring and adjustment:
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000, 3000],
...     values=[1.0, 0.5, 0.1, 0.01]
... )
>>>
>>> for step in range(4000):
...     old_lr = scheduler.get_lr()[0]
...     scheduler.step()
...     new_lr = scheduler.get_lr()[0]
...
...     # Detect LR changes
...     if old_lr != new_lr:
...         print(f"Step {step}: LR changed from {old_lr:.6f} to {new_lr:.6f}")
...         # Optionally reset momentum or other optimizer states
...         reset_optimizer_momentum(optimizer)
...
...     if step % 100 == 0:
...         print(f"Step {step}: LR = {new_lr:.6f}")
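Because the schedule is fully determined by boundaries and values, the change points can also be computed up front instead of comparing get_lr() before and after every step. A minimal sketch with a hypothetical helper; it compares neighbouring values exactly, so a boundary whose adjacent values are equal is not reported.

```python
from typing import List, Sequence


def lr_change_steps(boundaries: Sequence[int], values: Sequence[float]) -> List[int]:
    """Hypothetical helper: boundaries at which the LR actually changes value."""
    # A boundary only changes the LR if the values on either side differ.
    return [b for b, before, after in zip(boundaries, values, values[1:])
            if before != after]


change_steps = set(lr_change_steps([1000, 2000, 3000], [1.0, 0.5, 0.1, 0.01]))
```

Inside a training loop, a membership test on change_steps can replace the before/after comparison, provided the loop counter matches the scheduler's own step count.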
Research experimentation with multiple schedules:
>>> # Compare different decay strategies
>>> schedules = {
...     'aggressive': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[10, 20],
...         values=[1.0, 0.01, 0.0001]
...     ),
...     'conservative': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[30, 60],
...         values=[1.0, 0.5, 0.1]
...     ),
...     'multi_stage': braintools.optim.PiecewiseConstantSchedule(
...         base_lr=0.1,
...         boundaries=[10, 20, 30, 40],
...         values=[1.0, 0.8, 0.4, 0.1, 0.01]
...     )
... }
>>>
>>> # Run experiments
>>> for name, scheduler in schedules.items():
...     print(f"Testing schedule: {name}")
...     model = create_model()
...     optimizer = braintools.optim.SGD(lr=scheduler)
...     results = train_model(model, optimizer)
...     log_results(name, results)
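Before running full experiments, one coarse way to compare these schedules is the cumulative learning rate (the sum of the LR over all steps) within a fixed horizon. The sketch below uses a hypothetical helper and an illustrative 50-epoch horizon; the statistic only summarizes how much step size each schedule spends and says nothing about final accuracy.

```python
from typing import Sequence


def cumulative_lr(base_lr: float, boundaries: Sequence[int], values: Sequence[float],
                  num_steps: int) -> float:
    """Hypothetical helper: sum of the LR over steps 0, 1, ..., num_steps - 1."""
    total, start = 0.0, 0
    # Walk the phases in order; each contributes (phase length) * (phase LR).
    for end, v in zip([*boundaries, num_steps], values):
        end = min(end, num_steps)  # clip phases to the horizon
        if end > start:
            total += (end - start) * base_lr * v
        start = max(start, end)
    return total


configs = {
    "aggressive": ([10, 20], [1.0, 0.01, 0.0001]),
    "conservative": ([30, 60], [1.0, 0.5, 0.1]),
    "multi_stage": ([10, 20, 30, 40], [1.0, 0.8, 0.4, 0.1, 0.01]),
}
for name, (bounds, vals) in configs.items():
    print(f"{name}: {cumulative_lr(0.1, bounds, vals, 50):.4f}")
```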
State persistence and checkpointing:
>>> # Save and restore schedule state
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000],
...     values=[1.0, 0.1, 0.01]
... )
>>>
>>> # Train for some steps
>>> for step in range(1500):
...     train_step(...)
...     scheduler.step()
>>>
>>> # Save checkpoint
>>> checkpoint = {
...     'step': 1500,
...     'scheduler_state': scheduler.state_dict(),
...     'model_state': model.state_dict()
... }
>>> save(checkpoint, 'checkpoint.pkl')
>>>
>>> # Later: restore and continue
>>> scheduler = braintools.optim.PiecewiseConstantSchedule(
...     base_lr=0.1,
...     boundaries=[1000, 2000],
...     values=[1.0, 0.1, 0.01]
... )
>>> scheduler.load_state_dict(checkpoint['scheduler_state'])
>>> # Continue from step 1500 with correct LR
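Restoring works because the learning rate of a piecewise-constant schedule is a pure function of the step counter: once base_lr, boundaries and values are re-supplied at construction, the counter is the only state that must survive. The toy class below is a hypothetical stand-in written to illustrate this point; it is not the braintools implementation, whose state_dict contents may differ.

```python
import pickle
from bisect import bisect_right


class ToyPiecewiseSchedule:
    """Hypothetical stand-in, not braintools: LR is a pure function of last_epoch."""

    def __init__(self, base_lr, boundaries, values, last_epoch=0):
        self.base_lr = base_lr
        self.boundaries = list(boundaries)
        self.values = list(values)
        self.last_epoch = last_epoch

    def step(self):
        self.last_epoch += 1

    def current_lr(self):
        return self.base_lr * self.values[bisect_right(self.boundaries, self.last_epoch)]

    def state_dict(self):
        # The step counter is the only mutable state worth saving.
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


sched = ToyPiecewiseSchedule(0.1, [1000, 2000], [1.0, 0.1, 0.01])
for _ in range(1500):
    sched.step()
blob = pickle.dumps(sched.state_dict())

# Rebuild with identical arguments, then restore the counter.
restored = ToyPiecewiseSchedule(0.1, [1000, 2000], [1.0, 0.1, 0.01])
restored.load_state_dict(pickle.loads(blob))
```

Rebuilding with different boundaries and then loading the old counter would silently change the schedule, which is why the example reconstructs the scheduler with the same arguments.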
See also
StepLR – Exponential decay at regular intervals
MultiStepLR – Similar concept with multiplicative decay
CosineAnnealingLR – Smooth transitions instead of step changes
SequentialLR – Chain multiple schedulers sequentially