ConstantLR

class braintools.optim.ConstantLR(base_lr=0.001, factor=0.3333333333333333, total_iters=5, last_epoch=0)

Constant-factor learning rate scheduler: multiplies the learning rate by a constant factor for a fixed number of epochs.

ConstantLR multiplies the base learning rate by a constant factor for a specified number of epochs (total_iters), then returns to the original base learning rate. This is useful for implementing warmup phases or temporary learning rate adjustments.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • factor (float) – Multiplicative factor applied to base_lr for the first total_iters epochs. Must be in range (0, 1]. Default: 1/3.

  • total_iters (int) – Number of epochs to apply the factor. After total_iters epochs, the learning rate returns to base_lr. Default: 5.

  • last_epoch (int) – The index of the last epoch. Used for resuming training. Default: 0 (starts from the beginning).

Notes

The learning rate at epoch \(t\) is computed as:

\[\begin{split}\eta_t = \begin{cases} \eta_0 \cdot \text{factor} & \text{if } t < \text{total\_iters} \\ \eta_0 & \text{otherwise} \end{cases}\end{split}\]

where \(\eta_0\) is the base learning rate.
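The piecewise rule above can be sketched in plain Python. This is an illustrative stand-alone function, not the braintools implementation; the name `constant_lr` is hypothetical, and the range check on `factor` simply mirrors the documented constraint (0, 1].

```python
def constant_lr(epoch, base_lr=1e-3, factor=1 / 3, total_iters=5):
    """Return base_lr * factor while epoch < total_iters, else base_lr.

    Hypothetical helper that mirrors the formula in the Notes section.
    """
    if not 0.0 < factor <= 1.0:
        raise ValueError("factor must be in the range (0, 1]")
    return base_lr * factor if epoch < total_iters else base_lr


# Learning rates for epochs 0..11 with a 10-epoch reduced phase.
schedule = [
    constant_lr(t, base_lr=0.001, factor=0.5, total_iters=10)
    for t in range(12)
]
```

Here the first ten entries of `schedule` are the reduced rate and the remaining entries equal `base_lr`, with a single step change at `total_iters`.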

Key characteristics:

  • Simple two-phase learning rate schedule

  • Commonly used for warmup with constant reduced lr

  • Automatically returns to base_lr after warmup period

  • No gradual transition (step change at total_iters)

Comparison with LinearLR:

  • ConstantLR: Instant jump from (factor * base_lr) to base_lr at total_iters

  • LinearLR: Smooth linear transition from start_factor to end_factor
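To make the contrast concrete, the following sketch evaluates both shapes side by side in plain Python. The helper names are hypothetical, and the linear interpolation formula is an assumption about how a LinearLR-style schedule behaves (a straight line from `start_factor` to `end_factor` over `total_iters` epochs); consult the LinearLR page for the exact definition.

```python
def constant_factor(t, factor, total_iters):
    # Step change: reduced factor before total_iters, 1.0 afterwards.
    return factor if t < total_iters else 1.0


def linear_factor(t, start_factor, end_factor, total_iters):
    # Assumed linear interpolation, clamped once t reaches total_iters.
    progress = min(t, total_iters) / total_iters
    return start_factor + (end_factor - start_factor) * progress


base_lr = 0.1
constant = [base_lr * constant_factor(t, 0.1, 5) for t in range(7)]
linear = [base_lr * linear_factor(t, 0.1, 1.0, 5) for t in range(7)]

for t, (c, l) in enumerate(zip(constant, linear)):
    print(f"epoch {t}: constant={c:.4f} linear={l:.4f}")
```

The constant schedule holds the reduced rate and then jumps at epoch 5, while the linear one climbs by an equal increment each epoch and reaches the same final rate.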

Examples

Basic constant warmup:

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Use 0.5 * base_lr for first 10 epochs, then full base_lr
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.5,
...     total_iters=10
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Epochs 0-9:  lr = 0.0005
>>> # Epochs 10+:  lr = 0.001

Default warmup configuration:

>>> # Default: lr = base_lr/3 for 5 epochs, then lr = base_lr
>>> scheduler = braintools.optim.ConstantLR()
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(10):
...     optimizer.step(grads)
...     scheduler.step()
...     print(f"Epoch {epoch}: lr = {optimizer.current_lr}")
# First 5 epochs: lr ≈ 0.000333
# Remaining epochs: lr = 0.001

Short warmup for fine-tuning:

>>> # Use 20% of base_lr for first 3 epochs
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.0001,
...     factor=0.2,
...     total_iters=3
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Epochs 0-2:  lr = 0.00002
>>> # Epochs 3+:   lr = 0.0001

Combining with StepLR:

>>> # Warmup, then step decay
>>> warmup = braintools.optim.ConstantLR(
...     base_lr=0.1,
...     factor=0.1,
...     total_iters=5
... )
>>> decay = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> scheduler = braintools.optim.ChainedScheduler([warmup, decay])
>>>
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(90):
...     optimizer.step(grads)
...     scheduler.step()
# Epochs 0-4:   lr = 0.01 (warmup)
# Epochs 5-29:  lr = 0.1  (after warmup)
# Epochs 30-59: lr = 0.01 (first decay)
# Epochs 60+:   lr = 0.001 (second decay)
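The learning rates listed in the comments above are consistent with the two schedules being combined multiplicatively. The sketch below reproduces them in plain Python under that assumption; it does not call braintools, and the helper names are hypothetical.

```python
def warmup_factor(epoch, factor=0.1, total_iters=5):
    # ConstantLR-style factor: reduced for the first total_iters epochs.
    return factor if epoch < total_iters else 1.0


def step_factor(epoch, step_size=30, gamma=0.1):
    # StepLR-style factor: multiply by gamma every step_size epochs.
    return gamma ** (epoch // step_size)


def chained_lr(epoch, base_lr=0.1):
    # Assumption: chaining multiplies the individual factors together.
    return base_lr * warmup_factor(epoch) * step_factor(epoch)


for epoch in (0, 4, 5, 29, 30, 59, 60, 89):
    print(f"epoch {epoch}: lr = {chained_lr(epoch):.4f}")
```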

Conservative start for transfer learning:

>>> # Start with very low lr for stability
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.001,
...     factor=0.01,
...     total_iters=10
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # First 10 epochs: lr = 0.00001 (conservative)
>>> # Remaining epochs: lr = 0.001 (normal training)

Multiple parameter groups:

>>> # Different base_lr for different layers
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=[0.1, 0.01],
...     factor=0.1,
...     total_iters=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> # Both groups use factor=0.1 for first 5 epochs
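As a rough illustration of what sharing one factor across groups means, the sketch below applies the same scale to every entry of a list of base learning rates. It is plain Python with a hypothetical helper name, not the braintools internals.

```python
def constant_lr_groups(epoch, base_lrs, factor, total_iters):
    """Apply one shared factor to each group's base learning rate."""
    scale = factor if epoch < total_iters else 1.0
    return [lr * scale for lr in base_lrs]


base_lrs = [0.1, 0.01]
during_warmup = constant_lr_groups(0, base_lrs, factor=0.1, total_iters=5)
after_warmup = constant_lr_groups(5, base_lrs, factor=0.1, total_iters=5)
```

During the first five epochs each group runs at one tenth of its own base rate; from epoch 5 onward each group returns to its own base rate, so the ratio between groups is preserved throughout.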

Complete training workflow:

>>> import jax  # jax.grad is used in the training loop below
>>> # train_loader and compute_loss are user-defined placeholders
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> scheduler = braintools.optim.ConstantLR(
...     base_lr=0.01,
...     factor=0.1,
...     total_iters=5
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(50):
...     # Training step
...     for batch in train_loader:
...         loss = compute_loss(model, batch)
...         grads = jax.grad(compute_loss)(model.states(brainstate.ParamState))
...         optimizer.update(grads)
...
...     scheduler.step()
...     if epoch in [0, 4, 5, 10]:
...         print(f"Epoch {epoch}: lr = {optimizer.current_lr}")

See also

  • LinearLR – Linearly scale learning rate (smooth transition)

  • WarmupScheduler – Alternative warmup implementation

  • ChainedScheduler – Combine multiple schedulers

get_lr()

Calculate learning rate.