Adam


class braintools.optim.Adam(lr=0.001, betas=(0.9, 0.999), beta1=None, beta2=None, eps=1e-08, weight_decay=0.0, amsgrad=False, grad_clip_norm=None, grad_clip_value=None)

Adam (Adaptive Moment Estimation) optimizer.

Adam is an adaptive learning-rate optimization algorithm that combines ideas from AdaGrad and RMSProp. It maintains exponentially decaying averages of each parameter's gradient (the first moment) and squared gradient (the second moment), and uses them to scale the step size individually for every parameter.

Parameters:
  • lr (float | LRScheduler) – Learning rate, given either as a constant float or as an LRScheduler instance.

  • betas (Tuple[float, float]) – Coefficients (beta1, beta2) used to compute running averages of the gradient and its square: beta1 is the exponential decay rate for the first-moment estimates, and beta2 is the decay rate for the second-moment estimates.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • amsgrad (bool) – Whether to use the AMSGrad variant of Adam.

  • grad_clip_norm (float | None) – Maximum gradient norm for clipping.

  • grad_clip_value (float | None) – Maximum gradient value for clipping.
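This page does not define the two clipping options precisely. The sketch below assumes the usual conventions: grad_clip_norm rescales all gradients together so that their combined (global) L2 norm does not exceed the threshold, and grad_clip_value clamps every gradient element to [-grad_clip_value, grad_clip_value]. Applying norm clipping before value clipping is also an assumption. Gradients are flat lists of floats, and all function names are hypothetical helpers, not braintools APIs.

```python
import math
from typing import List, Optional


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Assumed semantics: rescale all gradients so their joint L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def clip_by_value(grads: List[float], max_value: float) -> List[float]:
    """Assumed semantics: clamp every element into [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


def apply_clipping(grads: List[float],
                   grad_clip_norm: Optional[float] = None,
                   grad_clip_value: Optional[float] = None) -> List[float]:
    """Hypothetical helper: None disables the corresponding step; norm clipping runs first (assumed order)."""
    if grad_clip_norm is not None:
        grads = clip_by_global_norm(grads, grad_clip_norm)
    if grad_clip_value is not None:
        grads = clip_by_value(grads, grad_clip_value)
    return grads


# [3, 4] has L2 norm 5, so it is rescaled to norm 1: approximately [0.6, 0.8]
print(apply_clipping([3.0, 4.0], grad_clip_norm=1.0))
# Elementwise clamping to 0.5 gives [0.5, 0.5]
print(apply_clipping([3.0, 4.0], grad_clip_value=0.5))
```

Norm clipping preserves the gradient's direction, whereas value clipping can change it: [3, 4] and [0.5, 0.5] point in different directions.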

Notes

The Adam update is computed as:

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
\]

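where \(g_t\) is the gradient at step \(t\), \(m_t\) and \(v_t\) are the first and second moment estimates (initialized to zero), \(\hat{m}_t\) and \(\hat{v}_t\) are their bias-corrected versions, \(\beta_1\) and \(\beta_2\) are the entries of betas, \(\alpha\) is the learning rate, \(\epsilon\) is eps, and \(t\) is the time step, starting at 1. The equations show the plain update, without weight decay, AMSGrad, or gradient clipping.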
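A minimal pure-Python sketch of one update step, following the equations above term by term. Two parts are assumptions, since this page does not spell them out: weight_decay is treated as a classic L2 penalty added to the gradient before the moments are updated, and amsgrad=True keeps a running maximum of \(v_t\) and bias-corrects it (a PyTorch-style formulation). adam_step is a hypothetical helper, not the braintools implementation.

```python
import math
from typing import Dict, List


def adam_step(theta: List[float], grad: List[float], state: Dict,
              lr: float = 0.001, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8, weight_decay: float = 0.0,
              amsgrad: bool = False) -> List[float]:
    """Hypothetical helper: one Adam update; `state` carries m, v, v_max and t between calls."""
    t = state["t"] = state.get("t", 0) + 1
    m = state.setdefault("m", [0.0] * len(theta))
    v = state.setdefault("v", [0.0] * len(theta))
    v_max = state.setdefault("v_max", [0.0] * len(theta))
    new_theta = []
    for i, (p, g) in enumerate(zip(theta, grad)):
        g = g + weight_decay * p                   # assumed: L2 penalty folded into g_t
        m[i] = beta1 * m[i] + (1 - beta1) * g      # first moment m_t
        v[i] = beta2 * v[i] + (1 - beta2) * g * g  # second moment v_t
        m_hat = m[i] / (1 - beta1 ** t)            # bias-corrected m_t
        if amsgrad:
            v_max[i] = max(v_max[i], v[i])         # assumed PyTorch-style AMSGrad: running max of v_t
            v_hat = v_max[i] / (1 - beta2 ** t)
        else:
            v_hat = v[i] / (1 - beta2 ** t)        # bias-corrected v_t
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
    return new_theta


# Usage: minimise f(x) = x^2, whose gradient is 2x, starting from x = 1.0.
state: Dict = {}
first = adam_step([1.0], [2.0], state, lr=0.05)
print(1.0 - first[0])  # size of the first step: approximately 0.05, i.e. about lr
x = first
for _ in range(499):
    x = adam_step(x, [2.0 * x[0]], state, lr=0.05)
print(x[0])            # close to the minimum at 0
```

Because the moments start at zero, bias correction gives \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\), so the very first step has a magnitude close to the learning rate regardless of how large the gradient is.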


Examples

Basic Adam optimizer:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.Adam(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Adam with custom beta values:

>>> optimizer = braintools.optim.Adam(lr=0.001, betas=(0.9, 0.99))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Adam with AMSGrad:

>>> optimizer = braintools.optim.Adam(lr=0.001, amsgrad=True)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Adam with learning rate scheduler:

>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.001, gamma=0.95)
>>> optimizer = braintools.optim.Adam(lr=scheduler, betas=(0.9, 0.999))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # ... compute grads (gradients of the loss w.r.t. the registered weights) ...
...     optimizer.step(grads)
...     scheduler.step()
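Assuming ExponentialLR follows the usual exponential-decay rule (the rate is multiplied by gamma on every scheduler.step(), as in PyTorch's scheduler of the same name), the learning rate used at epoch e would be base_lr * gamma ** e. The hypothetical helper below tabulates that assumed rule for the values in the example.

```python
def exponential_lr(base_lr: float, gamma: float, epoch: int) -> float:
    """Hypothetical helper: assumed ExponentialLR rule, base_lr * gamma ** epoch."""
    return base_lr * gamma ** epoch


# Rate at the start, after one and two decays, and in the last epoch of the 100-epoch loop
for epoch in (0, 1, 2, 99):
    print(epoch, exponential_lr(0.001, 0.95, epoch))
```

Under this assumption, the final epoch trains with roughly 6.2e-6, about 0.6% of base_lr, so gamma is best chosen with the total number of epochs in mind.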

Adam with gradient clipping:

>>> optimizer = braintools.optim.Adam(lr=0.001, grad_clip_norm=1.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

See also

  • AdamW – Adam with decoupled weight decay.

  • RAdam – Rectified Adam.

  • Nadam – Adam with Nesterov momentum.

  • SGD – Stochastic gradient descent.
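AdamW differs from this class mainly in how weight decay enters the update. The sketch below compares the first step (\(t = 1\), where bias correction gives \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\)) under two assumptions: the weight_decay of Adam adds \(\lambda\theta\) to the gradient before the moment estimates (the L2 penalty named in the parameter list), while AdamW subtracts \(\alpha\lambda\theta\) from the parameter directly (a common decoupled formulation; exact scaling conventions differ between libraries). first_step is a hypothetical helper.

```python
import math


def first_step(theta: float, grad: float, lr: float, wd: float,
               decoupled: bool, eps: float = 1e-8) -> float:
    """Hypothetical helper: first Adam step (t = 1), where m_hat = g and v_hat = g**2."""
    g = grad if decoupled else grad + wd * theta  # coupled (assumed Adam): L2 term enters the moments
    step = lr * g / (math.sqrt(g * g) + eps)       # alpha * m_hat / (sqrt(v_hat) + eps) at t = 1
    if decoupled:
        step += lr * wd * theta                     # decoupled (assumed AdamW): decay bypasses the moments
    return theta - step


# A parameter with zero loss gradient, lr = 0.001, weight decay 0.01
print(first_step(1.0, 0.0, lr=0.001, wd=0.01, decoupled=False))  # about 0.999
print(first_step(1.0, 0.0, lr=0.001, wd=0.01, decoupled=True))   # about 0.99999
```

With the coupled penalty, the decay term is normalized by its own magnitude, so the parameter shrinks by roughly the full learning rate; the decoupled decay shrinks it by only \(\alpha\lambda\theta\), independent of the gradient statistics.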

default_tx()

Create the Adam-specific gradient transformation.