ConstantLR
- class braintools.optim.ConstantLR(base_lr=0.001, factor=0.3333333333333333, total_iters=5, last_epoch=0)
Constant learning rate scheduler - Multiplies learning rate by a constant factor.
ConstantLR multiplies the base learning rate by a constant factor for a specified number of epochs (total_iters), then returns to the original base learning rate. This is useful for implementing warmup phases or temporary learning rate adjustments.
- Parameters:
base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
factor (float) – Multiplicative factor applied to base_lr for the first total_iters epochs. Must be in range (0, 1]. Default: 1/3.
total_iters (int) – Number of epochs to apply the factor. After total_iters epochs, the learning rate returns to base_lr. Default: 5.
last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0, as in the class signature (starts from the beginning).
Notes
The learning rate at epoch \(t\) is computed as:
\[\eta_t = \begin{cases} \eta_0 \cdot \text{factor} & \text{if } t < \text{total_iters} \\ \eta_0 & \text{otherwise} \end{cases}\]
where \(\eta_0\) is the base learning rate.
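The piecewise rule above can be checked with a few lines of plain Python. This is only a sketch of the formula: the helper name `constant_lr` is hypothetical and is not part of braintools, and the range check simply mirrors the documented constraint on factor.

```python
def constant_lr(t: int, base_lr: float = 1e-3, factor: float = 1 / 3,
                total_iters: int = 5) -> float:
    """Closed-form value of the schedule at epoch t (illustrative helper only)."""
    if not 0 < factor <= 1:
        raise ValueError("factor must lie in (0, 1]")
    # Reduced rate during the first total_iters epochs, base rate afterwards.
    return base_lr * factor if t < total_iters else base_lr


# Same settings as the basic warmup example: factor=0.5 for 10 epochs.
schedule = [constant_lr(t, base_lr=0.001, factor=0.5, total_iters=10)
            for t in range(12)]
for t, lr in enumerate(schedule):
    print(f"epoch {t}: lr = {lr}")
```

The step change happens exactly at t = total_iters; there is no intermediate value.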
Key characteristics:
Simple two-phase learning rate schedule
Commonly used for warmup at a constant, reduced learning rate
Automatically returns to base_lr after warmup period
No gradual transition (step change at total_iters)
Comparison with LinearLR:
ConstantLR: Instant jump from (factor * base_lr) to base_lr at total_iters
LinearLR: Smooth linear transition from start_factor to end_factor
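To make the contrast concrete, the sketch below tabulates both multipliers side by side. It is a standalone illustration, not braintools code: the function names are hypothetical, and the linear interpolation between start_factor and end_factor is an assumption modeled on the usual linear warmup rule, so the actual LinearLR implementation may differ in detail.

```python
def constant_factor(t: int, factor: float = 0.1, total_iters: int = 5) -> float:
    """Multiplier of a constant warmup: a single step change at total_iters."""
    return factor if t < total_iters else 1.0


def linear_factor(t: int, start_factor: float = 0.1, end_factor: float = 1.0,
                  total_iters: int = 5) -> float:
    """Multiplier of a linear warmup (assumed form): interpolate, then hold."""
    progress = min(t, total_iters) / total_iters
    return start_factor + (end_factor - start_factor) * progress


base_lr = 0.1
for t in range(7):
    print(f"epoch {t}: constant = {base_lr * constant_factor(t):.4f}, "
          f"linear = {base_lr * linear_factor(t):.4f}")
```

Both schedules reach base_lr at total_iters; only the path taken before that point differs.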
Examples
Basic constant warmup:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Use 0.5 * base_lr for first 10 epochs, then full base_lr
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.5,
...     total_iters=10
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Epochs 0-9: lr = 0.0005
>>> # Epochs 10+: lr = 0.001
Default warmup configuration:
>>> # Default: lr = base_lr/3 for 5 epochs, then lr = base_lr
>>> scheduler = braintools.optim.ConstantLR()
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # grads: gradients computed elsewhere in the training step (not shown)
>>> for epoch in range(10):
...     optimizer.step(grads)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")
# First 5 epochs: lr ≈ 0.000333
# Remaining epochs: lr = 0.001
Short warmup for fine-tuning:
>>> # Use 20% of base_lr for first 3 epochs
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.0001,
...     factor=0.2,
...     total_iters=3
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Epochs 0-2: lr = 0.00002
>>> # Epochs 3+: lr = 0.0001
Combining with StepLR:
>>> # Warmup, then step decay
>>> warmup = braintools.optim.ConstantLR(
...     base_lr=0.1,
...     factor=0.1,
...     total_iters=5
... )
>>> decay = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # grads: gradients computed elsewhere in the training step (not shown)
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
# Epochs 0-4: lr = 0.01 (warmup)
# Epochs 5-29: lr = 0.1 (after warmup)
# Epochs 30-59: lr = 0.01 (first decay)
# Epochs 60+: lr = 0.001 (second decay)
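The learning rates listed in the comments of this example can be reproduced by hand. The sketch below assumes the chained schedule behaves as the product of each scheduler's multiplier applied to a shared base_lr; that composition rule is an assumption made for illustration, and the helper names are hypothetical rather than part of braintools.

```python
def constant_factor(t: int, factor: float = 0.1, total_iters: int = 5) -> float:
    """Warmup multiplier: reduced for the first total_iters epochs."""
    return factor if t < total_iters else 1.0


def step_factor(t: int, step_size: int = 30, gamma: float = 0.1) -> float:
    """Step-decay multiplier: scale by gamma once every step_size epochs."""
    return gamma ** (t // step_size)


def chained_lr(t: int, base_lr: float = 0.1) -> float:
    """Assumed composition: both multipliers applied to the same base_lr."""
    return base_lr * constant_factor(t) * step_factor(t)


for t in (0, 4, 5, 29, 30, 59, 60, 89):
    print(f"epoch {t}: lr = {chained_lr(t):.4f}")
```

Under that assumption the four phases (warmup, full rate, first decay, second decay) fall at epochs 0-4, 5-29, 30-59, and 60 onward.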
Conservative start for transfer learning:
>>> # Start with very low lr for stability
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.01,
...     total_iters=10
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # First 10 epochs: lr = 0.00001 (conservative)
>>> # Remaining epochs: lr = 0.001 (normal training)
Multiple parameter groups:
>>> # Different base_lr for different layers
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=[0.1, 0.01],
...     factor=0.1,
...     total_iters=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> # Both groups use factor=0.1 for first 5 epochs
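With a list of base learning rates, the same factor scales every group. A minimal sketch of that behavior, assuming the factor is applied elementwise to the list (the helper `constant_lr_groups` is hypothetical, not a braintools function):

```python
from typing import List


def constant_lr_groups(t: int, base_lrs: List[float], factor: float,
                       total_iters: int) -> List[float]:
    """Apply one shared multiplier to every group's base learning rate."""
    scale = factor if t < total_iters else 1.0
    return [lr * scale for lr in base_lrs]


base_lrs = [0.1, 0.01]
for t in (0, 4, 5):
    print(f"epoch {t}: lrs = {constant_lr_groups(t, base_lrs, 0.1, 5)}")
```

The ratio between the groups stays fixed throughout; only the common scale changes at total_iters.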
Complete training workflow:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.01,
...     factor=0.1,
...     total_iters=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # train_loader and compute_loss are user-supplied placeholders (not defined here)
>>> for epoch in range(50):
...     # Training step
...     for batch in train_loader:
...         loss = compute_loss(model, batch)
...         grads = jax.grad(compute_loss)(model.states(brainstate.ParamState))
...         optimizer.update(grads)
...
...     scheduler.step()
...     if epoch in [0, 4, 5, 10]:
...         print(f"Epoch {epoch}: lr = {optimizer.current_lr}")
See also
LinearLR: Linearly scale learning rate (smooth transition)
WarmupScheduler: Alternative warmup implementation
ChainedScheduler: Combine multiple schedulers
References