AdaBelief


class braintools.optim.AdaBelief(lr=0.001, betas=(0.9, 0.999), eps=1e-16, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

AdaBelief optimizer: adapts the step size according to the “belief” in the observed gradient direction.

AdaBelief is an adaptive learning rate optimizer that scales each step according to its “belief” in the observed gradient. Unlike Adam, which adapts based on the magnitude of recent gradients, AdaBelief adapts based on the variance of the prediction error, i.e. the difference between the gradient and its momentum (the running average of past gradients).

The key insight is that when the observed gradient is close to the momentum (high belief), the optimizer can take a larger step; when the two diverge (low belief), it takes a smaller one. The method's authors report that this combines the fast convergence of adaptive methods such as Adam with generalization comparable to SGD, along with improved training stability.
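The recursion for \(S_t\) given in the Notes makes this concrete. The plain-Python sketch below (not braintools code; the helper `adabelief_moments` is hypothetical) feeds two scalar gradient streams with the same mean through the AdaBelief moment updates, omitting \(\epsilon\): a steady stream that always agrees with the momentum, and one that oscillates around it. The oscillating stream accumulates a much larger \(S_t\), so its effective step \(\hat{M}_t / \sqrt{\hat{S}_t}\) comes out smaller.

```python
import math


def adabelief_moments(grads, beta1=0.9, beta2=0.999):
    """Run the AdaBelief moment recursions over a scalar gradient stream.

    Returns the bias-corrected (m_hat, s_hat) after the last step.
    The epsilon terms are omitted to keep the illustration minimal.
    """
    m, s = 0.0, 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g             # first moment M_t
        s = beta2 * s + (1 - beta2) * (g - m) ** 2  # belief term S_t
    t = len(grads)
    return m / (1 - beta1 ** t), s / (1 - beta2 ** t)


steady = [1.0] * 200        # gradient always agrees with the momentum
noisy = [2.0, 0.0] * 100    # same mean (1.0), but oscillating around it

for name, stream in [("steady", steady), ("noisy", noisy)]:
    m_hat, s_hat = adabelief_moments(stream)
    # Effective step per unit learning rate: m_hat / sqrt(s_hat)
    print(f"{name}: s_hat={s_hat:.4f}, step/lr={m_hat / math.sqrt(s_hat):.2f}")
```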

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or an LRScheduler instance. If a float is provided, it is automatically converted to a ConstantLR scheduler.

  • betas (Tuple[float, float]) – Coefficients (beta1, beta2) used for computing running averages of the gradient and of the squared prediction error. beta1 is the exponential decay rate for the first moment (momentum); beta2 is the exponential decay rate for the second moment, the “belief” term \(S_t\).

  • eps (float) – Term added for numerical stability, both inside the second-moment update and to the denominator of the parameter update (see Notes). The default of 1e-16 is much smaller than the 1e-8 commonly used with Adam.

  • weight_decay (float) – Weight decay coefficient (L2 penalty). When greater than 0, L2 regularization is applied to the parameters.

  • grad_clip_norm (float | None) – Maximum gradient norm for gradient clipping. If None, no gradient norm clipping is applied.

  • grad_clip_value (float | None) – Maximum absolute gradient value for element-wise gradient clipping. If None, no gradient value clipping is applied.

Notes

The AdaBelief update rules are:

\[
\begin{aligned}
M_t &= \beta_1 M_{t-1} + (1 - \beta_1) G_t \\
S_t &= \beta_2 S_{t-1} + (1 - \beta_2) (G_t - M_t)^2 + \epsilon \\
\hat{M}_t &= \frac{M_t}{1 - \beta_1^t} \\
\hat{S}_t &= \frac{S_t}{1 - \beta_2^t} \\
\theta_{t+1} &= \theta_t - \alpha \frac{\hat{M}_t}{\sqrt{\hat{S}_t} + \epsilon}
\end{aligned}
\]

where:

  • \(G_t\) is the gradient at step \(t\)

  • \(M_t\) is the first moment (exponential moving average of gradients)

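  • \(S_t\) is the “belief” term: the exponential moving average of the squared prediction error, an estimate of the gradient's variance around \(M_t\) (a small \(S_t\) means high belief)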

  • \((G_t - M_t)^2\) measures the deviation between gradient and momentum

  • \(\hat{M}_t, \hat{S}_t\) are bias-corrected estimates

  • \(\theta_t\) are the parameters at step \(t\)

  • \(\alpha\) is the learning rate
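As a reference point, the update rules above can be transcribed directly for a single scalar parameter. This is a minimal sketch, not the braintools implementation (which operates on whole parameter trees and adds clipping and weight decay); the function `adabelief_step` and the toy objective \(f(\theta) = (\theta - 3)^2\) are illustrative assumptions.

```python
import math


def adabelief_step(theta, grad, m, s, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-16):
    """One AdaBelief update for a scalar parameter, following the Notes.

    t is the 1-based step count used for bias correction.
    Returns the updated (theta, m, s).
    """
    m = beta1 * m + (1 - beta1) * grad                     # M_t
    s = beta2 * s + (1 - beta2) * (grad - m) ** 2 + eps    # S_t (eps inside)
    m_hat = m / (1 - beta1 ** t)                           # bias-corrected M_t
    s_hat = s / (1 - beta2 ** t)                           # bias-corrected S_t
    theta = theta - lr * m_hat / (math.sqrt(s_hat) + eps)  # parameter update
    return theta, m, s


# Minimize f(theta) = (theta - 3)^2, whose gradient is 2 * (theta - 3).
theta, m, s = 0.0, 0.0, 0.0
for t in range(1, 501):
    theta, m, s = adabelief_step(theta, 2.0 * (theta - 3.0), m, s, t, lr=0.1)
print(theta)
```

On the first step the bias corrections undo the zero initialization, so \(\hat{M}_1 = G_1\) and \(\hat{S}_1 \approx \beta_1^2 G_1^2\): the first update has size about \(\alpha / \beta_1\) regardless of the gradient's scale.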

The key difference from Adam is the second moment estimation:

  • Adam: \(V_t = \beta_2 V_{t-1} + (1 - \beta_2) G_t^2\) (gradient magnitude)

  • AdaBelief: \(S_t = \beta_2 S_{t-1} + (1 - \beta_2) (G_t - M_t)^2\) (gradient variance; the small \(\epsilon\) term is omitted here)
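The difference is easiest to see on a large but perfectly steady gradient. Adam's \(V_t\) grows with \(G_t^2\), so its step stays around \(\alpha\) however reliable the gradient is; AdaBelief's \(S_t\) only records the brief transient while \(M_t\) warms up, so it stays small and the step grows. The plain-Python helper `second_moments` below is hypothetical and omits bias correction and \(\epsilon\).

```python
def second_moments(grads, beta1=0.9, beta2=0.999):
    """Return the final Adam V_t and AdaBelief S_t for a scalar gradient stream.

    Bias correction and epsilon are omitted; all recursions start at zero.
    """
    m = v = s = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g             # shared first moment M_t
        v = beta2 * v + (1 - beta2) * g ** 2        # Adam: tracks gradient magnitude
        s = beta2 * s + (1 - beta2) * (g - m) ** 2  # AdaBelief: tracks deviation from M_t
    return v, s


# A large but perfectly consistent gradient.
v, s = second_moments([5.0] * 500)
print(f"Adam V_t = {v:.3f}, AdaBelief S_t = {s:.4f}")
```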

Advantages reported for AdaBelief:

  • Better generalization: Step sizes depend on how far the gradient deviates from its momentum, not on its raw magnitude

  • Fast convergence: Takes larger steps when the gradient agrees with the momentum (small \(S_t\))

  • Stable training: Takes smaller steps when the gradient fluctuates around the momentum (large \(S_t\))

  • Robust defaults: Intended to work well without extensive hyperparameter tuning

  • Works across domains: Reported to be effective for image, language, and RL tasks

AdaBelief is particularly well-suited for:

  • Training deep neural networks with complex loss landscapes

  • Problems where Adam overfits or converges slowly

  • Transfer learning and fine-tuning tasks

  • Reinforcement learning with noisy gradients

  • Training with small batch sizes


Examples

Basic usage:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize AdaBelief
>>> optimizer = braintools.optim.AdaBelief(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Custom betas for different momentum and variance decay:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Faster momentum decay, slower variance decay
>>> optimizer = braintools.optim.AdaBelief(lr=0.001, betas=(0.8, 0.999))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
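Lowering beta1 from 0.9 to 0.8 shortens the memory of the momentum. A common rule of thumb is that an exponential moving average with decay \(\beta\) averages over roughly the last \(1 / (1 - \beta)\) steps and gives the value from \(k\) steps ago a weight of \((1 - \beta)\beta^k\). The plain-Python helpers below are hypothetical and only illustrate this arithmetic.

```python
def ema_horizon(beta):
    """Approximate number of recent steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)


def weight_of_past_gradient(beta, k):
    """Weight an EMA with decay `beta` assigns to the value from k steps ago."""
    return (1.0 - beta) * beta ** k


for beta in (0.9, 0.8, 0.999):
    print(f"beta={beta}: horizon ~{ema_horizon(beta):.0f} steps, "
          f"weight 10 steps back = {weight_of_past_gradient(beta, 10):.4f}")
```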

With learning rate scheduler for gradual decay:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Learning rate decays every 30 epochs
>>> scheduler = braintools.optim.StepLR(base_lr=0.001, step_size=30, gamma=0.5)
>>> optimizer = braintools.optim.AdaBelief(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
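Assuming braintools.optim.StepLR follows the usual step-decay convention (the same closed form as PyTorch's StepLR), the learning rate at a given epoch is `base_lr * gamma ** (epoch // step_size)`. The hypothetical helper below reproduces that schedule for the settings above so the decay points are easy to check; it does not call braintools.

```python
def step_lr(epoch, base_lr=0.001, step_size=30, gamma=0.5):
    """Conventional step decay: multiply base_lr by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 29, 30, 59, 60, 90):
    print(epoch, step_lr(epoch))
```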

With weight decay for regularization:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Add L2 regularization
>>> optimizer = braintools.optim.AdaBelief(lr=0.001, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
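The parameter description calls weight_decay an L2 penalty. In its classic (coupled) form, this adds \(\lambda \theta\) to each gradient before the optimizer's adaptive scaling, which is the gradient of \(\tfrac{\lambda}{2}\lVert\theta\rVert^2\). Whether braintools applies it this way or in the decoupled style of AdamW is an assumption here, not something stated above; the helper `add_l2_penalty` is hypothetical and works on flat lists of floats for simplicity.

```python
def add_l2_penalty(grads, params, weight_decay=0.01):
    """Coupled L2 regularization: add weight_decay * param to each gradient.

    Equivalent to adding (weight_decay / 2) * ||params||^2 to the loss.
    """
    return [g + weight_decay * p for g, p in zip(grads, params)]


params = [1.0, -2.0, 0.5]
grads = [0.1, 0.0, -0.3]
print(add_l2_penalty(grads, params))
```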

With gradient clipping for stable training:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Clip gradients by global norm
>>> optimizer = braintools.optim.AdaBelief(
...     lr=0.001,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
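The two clipping options behave differently: grad_clip_norm rescales all gradients together when their global L2 norm exceeds the threshold, preserving the update direction, while grad_clip_value clamps each element independently, which can change the direction. The plain-Python sketch below illustrates both on a flat list of floats; real parameter trees would be flattened first, and the helper names are hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients together so their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Clamp each gradient element independently to [-max_value, max_value]."""
    return [min(max(g, -max_value), max_value) for g in grads]


grads = [3.0, 4.0]                       # global L2 norm = 5
print(clip_by_global_norm(grads, 1.0))   # direction kept, norm rescaled to 1
print(clip_by_value(grads, 1.0))         # each element clamped independently
```

For the gradient [3.0, 4.0], norm clipping to 1.0 yields approximately [0.6, 0.8], while value clipping yields [1.0, 1.0], which points in a different direction.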

Complete configuration for deep learning:

>>> import brainstate
>>> import braintools
>>>
>>> # Large model
>>> model = brainstate.nn.Linear(1000, 500)
>>>
>>> # Learning rate schedule
>>> scheduler = braintools.optim.StepLR(base_lr=0.001, step_size=20, gamma=0.5)
>>>
>>> # Complete AdaBelief configuration
>>> optimizer = braintools.optim.AdaBelief(
...     lr=scheduler,
...     betas=(0.9, 0.999),
...     eps=1e-16,
...     weight_decay=0.0001,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

See also

Adam

Standard adaptive moment estimation

Yogi

Adam variant with additive second moment updates

RAdam

Rectified Adam with warmup

AdamW

Adam with decoupled weight decay

default_tx()

Create the default gradient transformation, including gradient clipping and weight decay.