ExponentialDecayLR
- class braintools.optim.ExponentialDecayLR(base_lr=0.001, decay_steps=1000, decay_rate=0.96, transition_begin=0, staircase=False, end_value=None, last_epoch=0)
Exponential decay learning rate scheduler with step-based control.
ExponentialDecayLR implements optax’s exponential_decay schedule and gives finer-grained control than ExponentialLR. It supports a configurable transition period, staircase mode, a delayed start, and bounded decay, which makes it suited to step-level (rather than epoch-level) learning rate control.
- Parameters:
base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.
decay_steps (int) – Length of the transition period: the number of steps over which the learning rate is multiplied once by decay_rate. Must be positive. Default: 1000.
decay_rate (float) – Multiplicative factor applied once per decay_steps steps. Must not be zero. Values < 1 decay the learning rate; values > 1 grow it. Typical values: 0.96-0.99 for slow decay, 0.9-0.95 for moderate decay. Default: 0.96.
transition_begin (int) – Number of steps to wait before decay starts. The learning rate is held at base_lr for this many steps. Default: 0.
staircase (bool) – If True, the exponent is floored, so the learning rate drops at discrete intervals of decay_steps. If False, decay is continuous. Default: False.
end_value (float | None) – Optional bound on the decayed value. When decay_rate < 1, it acts as a lower bound; when decay_rate > 1, it acts as an upper bound. Default: None (no bound).
last_epoch (int) – The index of the last epoch (for this scheduler, the step counter). Default: 0.
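Because the learning rate is multiplied by decay_rate once every decay_steps steps, you can solve base_lr * decay_rate**(total_steps / decay_steps) = target_lr for the decay_rate that reaches a chosen target after a given amount of (continuous) decay. A minimal plain-Python sketch; the helper `decay_rate_for_target` is hypothetical and not part of braintools.

```python
def decay_rate_for_target(base_lr: float, target_lr: float,
                          total_steps: int, decay_steps: int) -> float:
    """Hypothetical helper: the decay_rate that takes base_lr to target_lr
    after `total_steps` steps of continuous decay."""
    if base_lr <= 0 or target_lr <= 0:
        raise ValueError("learning rates must be positive")
    if total_steps <= 0 or decay_steps <= 0:
        raise ValueError("step counts must be positive")
    # Solve base_lr * r ** (total_steps / decay_steps) == target_lr for r.
    return (target_lr / base_lr) ** (decay_steps / total_steps)


# Go from 0.1 to 0.01 over 10,000 steps with decay_steps=1000.
rate = decay_rate_for_target(0.1, 0.01, total_steps=10_000, decay_steps=1000)
print(f"decay_rate = {rate:.4f}")  # decay_rate = 0.7943
```

A target above base_lr yields decay_rate > 1, that is, growth.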
Notes
The learning rate is computed from the step count. For step < transition_begin, the learning rate is held constant at base_lr. For step >= transition_begin:
Continuous mode (staircase=False):
\[
\begin{aligned}
\text{rate\_factor} &= \frac{\text{step} - \text{transition\_begin}}{\text{decay\_steps}} \\
\eta &= \text{base\_lr} \times \text{decay\_rate}^{\text{rate\_factor}}
\end{aligned}
\]
Staircase mode (staircase=True):
\[
\begin{aligned}
\text{rate\_factor} &= \left\lfloor \frac{\text{step} - \text{transition\_begin}}{\text{decay\_steps}} \right\rfloor \\
\eta &= \text{base\_lr} \times \text{decay\_rate}^{\text{rate\_factor}}
\end{aligned}
\]
If end_value is set, η is then clipped to it: from below when decay_rate < 1, from above when decay_rate > 1.
Key differences from ExponentialLR:
Step-based instead of epoch-based control
Configurable transition period (decay_steps)
Optional delayed start (transition_begin)
Staircase mode for discrete decay steps
Bounded decay with end_value
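The piecewise definition in the Notes (hold at base_lr, then a continuous or floored exponent, then optional clipping by end_value) can be sketched in plain Python. This is a minimal illustration of the formulas, not braintools' implementation. The function name `exponential_decay_lr` is hypothetical, and the max/min choice for end_value follows the parameter description (lower bound when decaying, upper bound otherwise).

```python
import math
from typing import Optional


def exponential_decay_lr(step: int, base_lr: float, decay_steps: int,
                         decay_rate: float, transition_begin: int = 0,
                         staircase: bool = False,
                         end_value: Optional[float] = None) -> float:
    """Hypothetical helper (not part of braintools): LR at a given step."""
    elapsed = step - transition_begin
    if elapsed <= 0:
        # Until transition_begin the learning rate is held at base_lr.
        return base_lr
    rate_factor = elapsed / decay_steps
    if staircase:
        # Staircase mode floors the exponent, producing discrete drops.
        rate_factor = math.floor(rate_factor)
    lr = base_lr * decay_rate ** rate_factor
    if end_value is not None:
        # Lower bound when decaying (decay_rate < 1), upper bound otherwise.
        lr = max(lr, end_value) if decay_rate < 1 else min(lr, end_value)
    return lr


print(exponential_decay_lr(1500, 0.1, 1000, 0.5, staircase=True))
```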
When to use:
When you need step-level (not epoch-level) learning rate control
For fine-grained decay schedules
When you want to delay decay start
For bounded decay with minimum/maximum values
In scenarios requiring staircase (discrete) decay
Examples
Basic continuous exponential decay:
>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Multiply the LR by 0.96 every 1000 steps
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1000,
...     decay_rate=0.96
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # `grads` is assumed to hold the gradients computed for each training step
>>> for step in range(5000):
...     if step % 1000 == 0:
...         # Report the LR in effect for this step, before advancing the scheduler
...         print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
...     optimizer.step(grads)
...     scheduler.step()
Step 0: lr = 0.100000
Step 1000: lr = 0.096000  # 0.1 * 0.96^1
Step 2000: lr = 0.092160  # 0.1 * 0.96^2
Step 3000: lr = 0.088474  # 0.1 * 0.96^3
Step 4000: lr = 0.084935  # 0.1 * 0.96^4
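In continuous mode the learning rate also changes between the printed checkpoints. A quick plain-Python evaluation of the continuous formula, independent of braintools, reproduces the printed values and adds an intermediate one at step 500:

```python
base_lr, decay_steps, decay_rate = 0.1, 1000, 0.96

# Continuous mode uses the fractional exponent step / decay_steps.
lrs = {step: base_lr * decay_rate ** (step / decay_steps)
       for step in (0, 500, 1000, 2000, 3000, 4000)}
for step, lr in lrs.items():
    print(f"Step {step}: lr = {lr:.6f}")
```

Step 500 gives 0.1 * 0.96**0.5, roughly 0.097980, which lies between the first two checkpoints rather than on either.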
Staircase mode (discrete decay steps):
>>> # Decay every 1000 steps with staircase mode
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1000,
...     decay_rate=0.5,
...     staircase=True
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in [0, 500, 1000, 1500, 2000, 2500, 3000]:
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
Step 0: lr = 0.100000
Step 500: lr = 0.100000  # Still in first interval
Step 1000: lr = 0.050000  # Drops at step 1000
Step 1500: lr = 0.050000  # Constant until step 2000
Step 2000: lr = 0.025000  # Drops at step 2000
Step 2500: lr = 0.025000  # Constant until step 3000
Step 3000: lr = 0.012500  # Drops at step 3000
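The only difference between the two modes is the floor on the exponent. A plain-Python comparison for the settings above (base_lr=0.1, decay_steps=1000, decay_rate=0.5) shows that they coincide at multiples of decay_steps and differ in between:

```python
import math

base_lr, decay_steps, decay_rate = 0.1, 1000, 0.5

rows = {}
for step in (0, 500, 1000, 1500, 2000):
    continuous = base_lr * decay_rate ** (step / decay_steps)
    # Staircase floors the exponent, so the LR only changes at multiples of decay_steps.
    staircase = base_lr * decay_rate ** math.floor(step / decay_steps)
    rows[step] = (continuous, staircase)
    print(f"Step {step}: continuous = {continuous:.6f}, staircase = {staircase:.6f}")
```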
Delayed decay start:
>>> # Hold LR constant for 2000 steps, then start decay
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.01,
...     decay_steps=1000,
...     decay_rate=0.95,
...     transition_begin=2000
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in [0, 1000, 2000, 3000, 4000]:
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
Step 0: lr = 0.010000  # Held constant
Step 1000: lr = 0.010000  # Held constant
Step 2000: lr = 0.010000  # Decay starts here
Step 3000: lr = 0.009500  # 0.01 * 0.95^1
Step 4000: lr = 0.009025  # 0.01 * 0.95^2
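Delaying the start shifts the whole decay curve right by transition_begin steps; nothing else about the curve changes. A plain-Python sketch of that identity with the values from the example above; the helper `lr_at` is hypothetical and implements only the continuous formula from the Notes:

```python
def lr_at(step, base_lr=0.01, decay_steps=1000, decay_rate=0.95, transition_begin=0):
    """Hypothetical helper: continuous-mode LR with an optional delayed start."""
    # Held at base_lr until transition_begin, then decays from there.
    elapsed = max(step - transition_begin, 0)
    return base_lr * decay_rate ** (elapsed / decay_steps)


steps = range(0, 5001, 500)
delayed = [lr_at(s, transition_begin=2000) for s in steps]
shifted = [lr_at(max(s - 2000, 0)) for s in steps]
print(delayed == shifted)  # True
```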
Bounded decay with end_value:
>>> # Decay, but never go below 0.001
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=500,
...     decay_rate=0.9,
...     end_value=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Unbounded, 0.1 * 0.9^(step/500) only falls below 0.001 after about 21,855 steps,
>>> # so the loop must run long enough for the bound to take effect
>>> for step in range(0, 30000, 5000):
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
# LR decays until it reaches end_value (here by step 25000), then stays at 0.001
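Whether end_value matters depends on how long training runs. Setting base_lr * decay_rate**(step / decay_steps) equal to end_value and solving for step gives the first step at which the bound is active. A plain-Python sketch for the example's settings, assuming continuous mode and no transition_begin:

```python
import math

base_lr, decay_steps, decay_rate, end_value = 0.1, 500, 0.9, 0.001

# Solve base_lr * decay_rate ** (step / decay_steps) == end_value for step.
crossover = decay_steps * math.log(end_value / base_lr) / math.log(decay_rate)
first_clipped_step = math.ceil(crossover)
print(first_clipped_step)  # 21855


def bounded_lr(step):
    # Continuous decay, clipped from below because decay_rate < 1.
    return max(base_lr * decay_rate ** (step / decay_steps), end_value)
```

By step 4500 the unbounded rate is still about 0.039, so a run of only a few thousand steps would never reach the bound.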
Fine-tuning with slow decay:
>>> # Very gentle step-based decay for fine-tuning, starting after 500 steps
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=1e-4,
...     decay_steps=100,
...     decay_rate=0.99,
...     transition_begin=500
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Fine-tune for many steps; `grads` is assumed to be computed for each step
>>> for step in range(10000):
...     optimizer.step(grads)
...     scheduler.step()
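Even a gentle decay_rate compounds over a long run. A plain-Python estimate of the learning rate at the end of the 10,000-step loop above, using the continuous formula from the Notes and assuming one scheduler.step() per iteration starting from step 0:

```python
base_lr, decay_steps, decay_rate, transition_begin = 1e-4, 100, 0.99, 500
final_step = 10_000

# Number of decay periods elapsed after the 500-step hold.
periods = (final_step - transition_begin) / decay_steps
final_lr = base_lr * decay_rate ** periods
print(f"after {final_step} steps: lr = {final_lr:.3e} ({final_lr / base_lr:.0%} of base_lr)")
```

With 95 decay periods of 0.99 each, the learning rate ends at under 40% of base_lr.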
Comparison with ExponentialLR:
>>> # ExponentialLR: epoch-based, simple gamma decay
>>> exp_lr = braintools.optim.ExponentialLR(base_lr=0.1, gamma=0.95)
>>>
>>> # ExponentialDecayLR: step-based, configurable transition
>>> exp_decay_lr = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1,
...     decay_rate=0.95
... )
>>> # These are equivalent when decay_steps=1 and step() is called once per epoch
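The same equivalence extends to per-batch stepping. Assume ExponentialLR computes base_lr * gamma**epoch and scheduler.step() is called once per batch. Under the Notes formulas, setting decay_steps to the number of batches per epoch and staircase=True then reproduces the epoch-level decay exactly, while staircase=False matches it only at epoch boundaries. A plain-Python check of that arithmetic; steps_per_epoch=250 is an arbitrary illustrative value:

```python
import math

base_lr, gamma = 0.1, 0.95
steps_per_epoch = 250  # illustrative number of batches per epoch
num_epochs = 10

mismatches = 0
for step in range(num_epochs * steps_per_epoch):
    epoch = step // steps_per_epoch
    epoch_lr = base_lr * gamma ** epoch                               # per-epoch gamma decay
    stair_lr = base_lr * gamma ** math.floor(step / steps_per_epoch)  # staircase, per step
    if abs(epoch_lr - stair_lr) > 1e-15:
        mismatches += 1

# Continuous mode (no floor) agrees with the epoch schedule at epoch boundaries.
boundary_ok = all(
    abs(base_lr * gamma ** (k * steps_per_epoch / steps_per_epoch) - base_lr * gamma ** k) < 1e-15
    for k in range(num_epochs)
)
print(mismatches, boundary_ok)  # 0 True
```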
See also
ExponentialLR: Simple exponential learning rate decay (epoch-based)
StepLR: Step-wise learning rate decay
CosineAnnealingLR: Cosine annealing schedule
References