AdamW

class braintools.optim.AdamW(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, grad_clip_norm=None, grad_clip_value=None)

AdamW optimizer with decoupled weight decay regularization.

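AdamW modifies the standard Adam algorithm by decoupling weight decay from the gradient-based update. The decay is applied to the parameters directly instead of being added to the gradient, where Adam's adaptive denominator would rescale it. This decoupling has been shown to improve generalization.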

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or LRScheduler instance.

  • betas (Tuple[float, float]) – Coefficients (beta1, beta2) for the running averages of the gradient and of its square.

  • eps (float) – Term added to the denominator for numerical stability.

  • weight_decay (float) – Weight decay coefficient (decoupled from gradient).

  • grad_clip_norm (float | None) – Maximum gradient norm for clipping. If None, norm clipping is disabled.

  • grad_clip_value (float | None) – Maximum absolute gradient value for clipping. If None, value clipping is disabled.
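The two clipping options can be pictured with the NumPy sketch below. It assumes that grad_clip_norm rescales all gradients together by their global L2 norm (as optax.clip_by_global_norm does) and that grad_clip_value clamps each element into [-grad_clip_value, grad_clip_value]. The helpers clip_by_global_norm and clip_by_value are hypothetical stand-ins, not braintools functions, and the order in which braintools applies the two clips is not stated here.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: rescale all gradient arrays together so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    if global_norm <= max_norm:
        return grads
    scale = max_norm / global_norm
    return {k: g * scale for k, g in grads.items()}

def clip_by_value(grads, max_value):
    """Hypothetical helper: clamp every gradient element into [-max_value, max_value]."""
    return {k: np.clip(g, -max_value, max_value) for k, g in grads.items()}

grads = {"weight": np.array([3.0, 0.0]), "bias": np.array([4.0])}
clipped_norm = clip_by_global_norm(grads, max_norm=1.0)  # global norm 5 -> 1, every entry scaled by 0.2
clipped_value = clip_by_value(grads, max_value=2.0)      # entries above 2 are clamped to 2
```

Norm clipping preserves the direction of the full gradient, while value clipping can change it because only the large entries are clamped.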

Notes

Unlike Adam, where weight decay is added to the gradient before the adaptive rescaling, AdamW applies weight decay directly to the parameters:

\[\theta_t = \theta_{t-1} - \alpha (\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1})\]

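where \(\alpha\) is the learning rate, \(\lambda\) is the weight decay coefficient, and \(\hat{m}_t\) and \(\hat{v}_t\) are the bias-corrected running averages of the gradient and of its square.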
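The difference between the two kinds of decay shows up clearly on a first step with zero loss gradient. The NumPy sketch below uses the usual Adam moments \(m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\) and \(v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2\) with bias correction. It is a standalone illustration of the formula above, not braintools' implementation, and first_step is a hypothetical helper.

```python
import numpy as np

def first_step(theta, grad, lr=0.1, weight_decay=0.01, b1=0.9, b2=0.999, eps=1e-8, decoupled=True):
    """One Adam-family step starting from zero moment estimates (t = 1).

    decoupled=True  -> AdamW: lr * weight_decay * theta is applied outside the adaptive ratio.
    decoupled=False -> Adam + L2: weight_decay * theta is added to the gradient first.
    """
    g = grad if decoupled else grad + weight_decay * theta
    m = (1 - b1) * g          # first moment after one step
    v = (1 - b2) * g ** 2     # second moment after one step
    m_hat = m / (1 - b1)      # bias correction with t = 1
    v_hat = v / (1 - b2)
    update = m_hat / (np.sqrt(v_hat) + eps)
    if decoupled:
        update = update + weight_decay * theta
    return theta - lr * update

theta = np.array([10.0, 0.1])
zero_grad = np.zeros_like(theta)  # isolate the effect of weight decay
adamw = first_step(theta, zero_grad, decoupled=True)     # ≈ [9.99, 0.0999]: each weight shrinks by 0.1 %
adam_l2 = first_step(theta, zero_grad, decoupled=False)  # ≈ [9.9, 0.0]: each weight moves by about lr
```

With L2 regularization, the adaptive denominator normalizes the decay term, so a large and a small weight both move by roughly the learning rate. AdamW keeps the shrinkage proportional to each weight.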

Examples

Basic AdamW usage:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.AdamW(lr=0.001, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

AdamW with scheduler:

>>> scheduler = braintools.optim.CosineAnnealingLR(base_lr=0.001, T_max=100)
>>> optimizer = braintools.optim.AdamW(lr=scheduler, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
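The schedule passed as lr can be previewed with the sketch below. It assumes braintools.optim.CosineAnnealingLR follows the same closed form as PyTorch's CosineAnnealingLR, \(\eta_t = \eta_{min} + (\eta_{base} - \eta_{min})(1 + \cos(\pi t / T_{max}))/2\), with a minimum rate eta_min that defaults to 0. The function cosine_annealing_lr is hypothetical and not part of braintools.

```python
import math

def cosine_annealing_lr(step, base_lr=0.001, T_max=100, eta_min=0.0):
    """Hypothetical closed-form cosine annealing: base_lr at step 0, eta_min at step T_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / T_max)) / 2

# Learning rate at the start, midpoint, and end of the schedule used above.
lrs = [cosine_annealing_lr(t) for t in (0, 50, 100)]  # ≈ [0.001, 0.0005, 0.0]
```

Because the weight decay term in the Notes formula is multiplied by \(\alpha\), the effective per-step decay \(\alpha\lambda\) follows the same schedule.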

See also

  • Adam – Standard Adam optimizer.

  • SGD – Stochastic gradient descent.

default_tx()

Create the AdamW-specific gradient transformation.
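Optax builds its own optax.adamw as a chain of scale_by_adam, add_decayed_weights and scale_by_learning_rate. Assuming default_tx returns something along those lines (possibly preceded by clipping when grad_clip_norm or grad_clip_value is set), the NumPy sketch below mimics that (init, update) structure. Every function in it is a hypothetical stand-in, not the braintools or Optax API.

```python
import numpy as np

# Hypothetical Optax-style transformations: each is an (init, update) pair.
def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
    def init(params):
        return {"t": 0, "m": np.zeros_like(params), "v": np.zeros_like(params)}
    def update(grads, state, params):
        t = state["t"] + 1
        m = b1 * state["m"] + (1 - b1) * grads       # running average of the gradient
        v = b2 * state["v"] + (1 - b2) * grads ** 2  # running average of its square
        m_hat, v_hat = m / (1 - b1 ** t), v / (1 - b2 ** t)
        return m_hat / (np.sqrt(v_hat) + eps), {"t": t, "m": m, "v": v}
    return init, update

def add_decayed_weights(weight_decay):
    def init(params):
        return None
    def update(updates, state, params):
        return updates + weight_decay * params, state  # decoupled decay term
    return init, update

def scale(factor):
    def init(params):
        return None
    def update(updates, state, params):
        return factor * updates, state
    return init, update

def chain(*transforms):
    def init(params):
        return [t_init(params) for t_init, _ in transforms]
    def update(updates, states, params):
        new_states = []
        for (_, t_update), s in zip(transforms, states):
            updates, s = t_update(updates, s, params)
            new_states.append(s)
        return updates, new_states
    return init, update

lr, wd = 0.1, 0.01
init, update = chain(scale_by_adam(), add_decayed_weights(wd), scale(-lr))
params = np.array([10.0, 0.1])
state = init(params)
updates, state = update(np.array([1.0, -2.0]), state, params)
params = params + updates  # ≈ [9.89, 0.1999]
```

Applying the decay after scale_by_adam and before the learning-rate scaling is what reproduces the term \(\alpha \lambda \theta_{t-1}\) in the Notes formula.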