ExponentialDecayLR

class braintools.optim.ExponentialDecayLR(base_lr=0.001, decay_steps=1000, decay_rate=0.96, transition_begin=0, staircase=False, end_value=None, last_epoch=0)

Exponential decay learning rate scheduler with step-based control.

ExponentialDecayLR implements optax's exponential_decay schedule and offers finer-grained control than ExponentialLR. It supports a configurable decay interval (decay_steps), staircase mode, a delayed start, and bounded decay, which makes it suitable for step-level (rather than epoch-level) learning rate control.

Parameters:
  • base_lr (float | List[float]) – Initial learning rate(s). Can be a single float or a list of floats for multiple parameter groups. Default: 1e-3.

  • decay_steps (int) – Number of steps over which one factor of decay_rate is applied. Must be positive. Default: 1000.

  • decay_rate (float) – The decay rate. Must not be zero. Values < 1 create decay, values > 1 create growth. Typical values: 0.96-0.99 for slow decay, 0.9-0.95 for moderate decay. Default: 0.96.

  • transition_begin (int) – Number of steps to wait before starting decay. The learning rate is held at base_lr for this many steps. Default: 0.

  • staircase (bool) – If True, decay happens at discrete intervals (step-wise). If False, decay is continuous. Default: False.

  • end_value (float | None) – Optional bound for the decayed value. When decay_rate < 1, acts as a lower bound. When decay_rate > 1, acts as an upper bound. Default: None (no bound).

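  • last_epoch (int) – The index of the last epoch. For this step-based scheduler it serves as the step counter (the examples below read it as scheduler.last_epoch.value). Default: 0.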
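
The direction of the end_value bound described above can be illustrated with a small standalone sketch. The helper bounded_lr is hypothetical and is not part of braintools; it assumes the bound is applied as a simple max (when decay_rate < 1) or min (when decay_rate > 1) on the continuous schedule.

```python
def bounded_lr(step, base_lr, decay_steps, decay_rate, end_value):
    """Hypothetical helper: continuous exponential schedule with a bound."""
    lr = base_lr * decay_rate ** (step / decay_steps)
    # Assumption: end_value is a floor for decay and a ceiling for growth.
    if decay_rate < 1:
        return max(lr, end_value)
    return min(lr, end_value)


steps = (0, 100, 200, 300)
# decay_rate < 1: values shrink until they hit the lower bound 0.02
decaying = [bounded_lr(s, 0.1, 100, 0.5, end_value=0.02) for s in steps]
# decay_rate > 1: values grow until they hit the upper bound 0.05
growing = [bounded_lr(s, 0.01, 100, 2.0, end_value=0.05) for s in steps]

print([f"{lr:.4f}" for lr in decaying])
print([f"{lr:.4f}" for lr in growing])
```

In both lists the last entry is the clipped value, since the unclipped schedule would have passed end_value by step 300.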

Notes

The learning rate is computed based on step count. When step >= transition_begin:

Continuous mode (staircase=False):

\[ \begin{aligned} \text{rate\_factor} &= \frac{\text{step} - \text{transition\_begin}}{\text{decay\_steps}} \\ \eta &= \text{base\_lr} \times \text{decay\_rate}^{\text{rate\_factor}} \end{aligned} \]

Staircase mode (staircase=True):

\[ \begin{aligned} \text{rate\_factor} &= \left\lfloor \frac{\text{step} - \text{transition\_begin}}{\text{decay\_steps}} \right\rfloor \\ \eta &= \text{base\_lr} \times \text{decay\_rate}^{\text{rate\_factor}} \end{aligned} \]

Before transition_begin steps, the learning rate is held constant at base_lr.
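
The two modes can be written out as a minimal pure-Python sketch. The function exponential_decay_lr below is a hypothetical helper that mirrors the formulas in this section; it is not the library implementation, and the end_value clipping rule is an assumption based on the parameter description.

```python
import math


def exponential_decay_lr(step, base_lr=0.001, decay_steps=1000, decay_rate=0.96,
                         transition_begin=0, staircase=False, end_value=None):
    """Hypothetical reference for the schedule formulas (not library code)."""
    elapsed = step - transition_begin
    if elapsed <= 0:
        # Before the transition begins, the rate is held at base_lr.
        return base_lr
    rate_factor = elapsed / decay_steps
    if staircase:
        # Staircase mode only changes the rate at multiples of decay_steps.
        rate_factor = math.floor(rate_factor)
    lr = base_lr * decay_rate ** rate_factor
    if end_value is not None:
        # Assumed clipping: lower bound when decaying, upper bound when growing.
        lr = max(lr, end_value) if decay_rate < 1 else min(lr, end_value)
    return lr


print(f"{exponential_decay_lr(1000, base_lr=0.1):.6f}")
print(f"{exponential_decay_lr(1500, base_lr=0.1, decay_rate=0.5, staircase=True):.6f}")
print(f"{exponential_decay_lr(3000, base_lr=0.01, decay_rate=0.95, transition_begin=2000):.6f}")
```

The three calls correspond to the continuous, staircase, and delayed-start settings used in the examples on this page.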

Key differences from ExponentialLR:

  • Step-based instead of epoch-based control

  • Configurable transition period (decay_steps)

  • Optional delayed start (transition_begin)

  • Staircase mode for discrete decay steps

  • Bounded decay with end_value

When to use:

  • When you need step-level (not epoch-level) learning rate control

  • For fine-grained decay schedules

  • When you want to delay decay start

  • For bounded decay with minimum/maximum values

  • In scenarios requiring staircase (discrete) decay

Examples

Basic continuous exponential decay:

>>> import braintools
>>> import brainstate
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> # Decay by 0.96 every 1000 steps
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1000,
...     decay_rate=0.96
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # `grads` is assumed to hold gradients computed elsewhere in the training loop
>>> for step in range(5000):
...     optimizer.step(grads)
...     scheduler.step()
...     if step % 1000 == 0:
...         print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
Step 0: lr = 0.100000
Step 1000: lr = 0.096000  # 0.1 * 0.96^1
Step 2000: lr = 0.092160  # 0.1 * 0.96^2
Step 3000: lr = 0.088474  # 0.1 * 0.96^3
Step 4000: lr = 0.084935  # 0.1 * 0.96^4

Staircase mode (discrete decay steps):

>>> # Decay every 1000 steps with staircase mode
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1000,
...     decay_rate=0.5,
...     staircase=True
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in [0, 500, 1000, 1500, 2000, 2500, 3000]:
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
Step 0: lr = 0.100000
Step 500: lr = 0.100000   # Still in first interval
Step 1000: lr = 0.050000  # Drops at step 1000
Step 1500: lr = 0.050000  # Constant until step 2000
Step 2000: lr = 0.025000  # Drops at step 2000
Step 2500: lr = 0.025000  # Constant until step 3000
Step 3000: lr = 0.012500  # Drops at step 3000
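
To see how staircase mode differs from continuous mode between the drop points, the following standalone sketch evaluates both formulas from the Notes section side by side. The function compare_modes is a hypothetical name used only for this illustration and does not call braintools.

```python
import math


def compare_modes(step, base_lr=0.1, decay_steps=1000, decay_rate=0.5):
    """Hypothetical helper: return (continuous_lr, staircase_lr) for a step."""
    continuous = base_lr * decay_rate ** (step / decay_steps)
    stair = base_lr * decay_rate ** math.floor(step / decay_steps)
    return continuous, stair


for step in (0, 500, 1000, 1500, 2000):
    continuous, stair = compare_modes(step)
    print(f"Step {step}: continuous = {continuous:.6f}, staircase = {stair:.6f}")
```

The two modes agree at exact multiples of decay_steps; in between, the continuous value keeps falling while the staircase value stays flat.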

Delayed decay start:

>>> # Hold LR constant for 2000 steps, then start decay
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.01,
...     decay_steps=1000,
...     decay_rate=0.95,
...     transition_begin=2000
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in [0, 1000, 2000, 3000, 4000]:
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
Step 0: lr = 0.010000     # Held constant
Step 1000: lr = 0.010000  # Held constant
Step 2000: lr = 0.010000  # Decay starts here
Step 3000: lr = 0.009500  # 0.01 * 0.95^1
Step 4000: lr = 0.009025  # 0.01 * 0.95^2

Bounded decay with end_value:

>>> # Decay but don't go below 0.001
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=500,
...     decay_rate=0.9,
...     end_value=0.001
... )
>>> optimizer = braintools.optim.SGD(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for step in range(0, 5000, 500):
...     for _ in range(step - scheduler.last_epoch.value):
...         scheduler.step()
...     print(f"Step {step}: lr = {optimizer.current_lr:.6f}")
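# LR decays and is never allowed below end_value. With these settings the
# LR only falls to about 0.0387 (0.1 * 0.9^9) within the 5000 steps shown;
# the 0.001 floor would be reached after roughly 21,850 steps.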
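
When choosing end_value, it can help to estimate the step at which the bound becomes active. Solving the continuous-mode formula for the step gives the sketch below; steps_until_bound is a hypothetical helper, and it assumes continuous mode with end_value lying on the side the schedule moves toward.

```python
import math


def steps_until_bound(base_lr, decay_steps, decay_rate, end_value, transition_begin=0):
    """Hypothetical helper: step at which the continuous schedule reaches end_value."""
    # Solve base_lr * decay_rate ** ((step - transition_begin) / decay_steps) == end_value
    rate_factor = math.log(end_value / base_lr) / math.log(decay_rate)
    return transition_begin + decay_steps * rate_factor


# Settings from the bounded-decay example above
reach = steps_until_bound(base_lr=0.1, decay_steps=500, decay_rate=0.9, end_value=0.001)
print(f"end_value reached after about {reach:.0f} steps")
```

If the estimate is far beyond the planned training length, the bound has no effect on the run and a larger end_value or a smaller decay_rate may be worth considering.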

Fine-tuning with slow decay:

>>> # Very gentle step-based decay for fine-tuning
>>> scheduler = braintools.optim.ExponentialDecayLR(
...     base_lr=1e-4,
...     decay_steps=100,
...     decay_rate=0.99,
...     transition_begin=500
... )
>>> optimizer = braintools.optim.Adam(lr=scheduler, weight_decay=1e-5)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Fine-tune for many steps (`grads` is assumed to be computed elsewhere)
>>> for step in range(10000):
...     optimizer.step(grads)
...     scheduler.step()

Comparison with ExponentialLR:

>>> # ExponentialLR: epoch-based, simple gamma decay
>>> exp_lr = braintools.optim.ExponentialLR(base_lr=0.1, gamma=0.95)
>>>
>>> # ExponentialDecayLR: step-based, configurable transition
>>> exp_decay_lr = braintools.optim.ExponentialDecayLR(
...     base_lr=0.1,
...     decay_steps=1,
...     decay_rate=0.95
... )
>>> # These are equivalent when decay_steps=1 and called every epoch
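
The equivalence can be checked numerically from the formulas alone. The sketch below assumes ExponentialLR follows the usual rule base_lr * gamma ** epoch; both functions are hypothetical stand-ins and do not call braintools.

```python
import math

base_lr, gamma = 0.1, 0.95


def exponential_lr(epoch):
    # Assumed form of ExponentialLR: one factor of gamma per epoch.
    return base_lr * gamma ** epoch


def exponential_decay_lr(step, decay_steps=1, decay_rate=gamma):
    # Continuous-mode formula from the Notes section with transition_begin=0.
    return base_lr * decay_rate ** (step / decay_steps)


# With decay_steps=1, one scheduler step plays the role of one epoch.
matches = [math.isclose(exponential_lr(n), exponential_decay_lr(n)) for n in range(10)]
print(all(matches))
```

With decay_steps larger than 1 the two schedules diverge, because ExponentialDecayLR then spreads each factor of decay_rate over several steps.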

See also

ExponentialLR

Simple exponential learning rate decay (epoch-based)

StepLR

Step-wise learning rate decay

CosineAnnealingLR

Cosine annealing schedule

get_lr()

Calculate learning rate.