Nadam

class braintools.optim.Nadam(lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, momentum_decay=0.004, grad_clip_norm=None, grad_clip_value=None)

Nadam optimizer - Adam with Nesterov accelerated gradient.

Nadam (Nesterov-accelerated Adaptive Moment Estimation) combines Adam with Nesterov momentum. It keeps Adam’s per-parameter adaptive learning rates and adds Nesterov’s look-ahead gradient step, which on many problems gives faster convergence than standard Adam.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or LRScheduler instance.

  • betas (Tuple[float, float]) – Coefficients (beta1, beta2) for computing running averages of gradient and its square. beta1 is the exponential decay rate for the first moment estimates, beta2 is for the second moment estimates.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • momentum_decay (float) – Momentum schedule decay rate for Nadam.

  • grad_clip_norm (float | None) – Maximum gradient norm for clipping.

  • grad_clip_value (float | None) – Maximum gradient value for clipping.
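The role of momentum_decay can be illustrated with the schedule used by other Nadam implementations. The sketch below follows the formula documented for PyTorch's NAdam, which has a parameter of the same name and default; whether braintools applies exactly this schedule is an assumption, and the helper name `momentum_schedule` is hypothetical.

```python
def momentum_schedule(step, beta1=0.9, momentum_decay=0.004):
    """Momentum coefficient mu_t at 1-based `step`.

    Assumed PyTorch-style schedule:
        mu_t = beta1 * (1 - 0.5 * 0.96 ** (step * momentum_decay))
    """
    return beta1 * (1.0 - 0.5 * 0.96 ** (step * momentum_decay))


# The coefficient starts near beta1 / 2 and rises towards beta1 as training proceeds.
for step in (1, 100, 10_000):
    print(step, round(momentum_schedule(step), 4))
```

Under this schedule a larger momentum_decay makes the coefficient approach beta1 sooner, while momentum_decay=0 pins it at beta1 / 2.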

Notes

The Nadam update combines Adam’s adaptive learning rate with Nesterov momentum:

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &= \frac{\beta_1 m_t}{1 - \beta_1^{t+1}} + \frac{(1 - \beta_1) g_t}{1 - \beta_1^t} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
\]

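Here \(g_t\) is the gradient at step \(t\), \(\alpha\) is the learning rate (lr), and \(\epsilon\) is eps. The key difference from Adam is in the bias-corrected first moment estimate \(\hat{m}_t\), which incorporates a look-ahead step similar to Nesterov momentum.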
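The update can be traced by hand with a minimal pure-Python sketch for a single scalar parameter. It assumes the common formulation in which the momentum term of \(\hat{m}_t\) is scaled by \(\beta_1\), and it ignores weight decay, clipping, and the momentum_decay schedule. The function name `nadam_step` is hypothetical and is not part of braintools.

```python
import math


def nadam_step(theta, grad, m, v, t, lr=0.002, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Nadam update for a scalar parameter; `t` is the 1-based step count."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Look-ahead first moment: momentum bias-corrected for step t + 1,
    # plus the current gradient bias-corrected for step t.
    m_hat = beta1 * m / (1.0 - beta1 ** (t + 1)) + (1.0 - beta1) * grad / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Minimise f(theta) = theta**2, whose gradient is 2 * theta.
theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    theta, m, v = nadam_step(theta, 2.0 * theta, m, v, t)
    print(t, theta)
```

With zero initial moments, the first Adam step has magnitude of about lr. Under this formulation the first Nadam step is about lr * (1 + beta1 / (1 + beta1)), roughly 1.47 * lr for beta1 = 0.9, which is the look-ahead effect in its simplest form.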

Examples

Basic Nadam usage:

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.Nadam(lr=0.002)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nadam with custom betas:

>>> optimizer = braintools.optim.Nadam(lr=0.002, betas=(0.9, 0.99))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nadam with weight decay:

>>> optimizer = braintools.optim.Nadam(lr=0.002, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nadam with learning rate scheduler:

>>> scheduler = braintools.optim.ExponentialLR(gamma=0.95)
>>> optimizer = braintools.optim.Nadam(lr=scheduler, betas=(0.9, 0.999))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training code: compute `grads` for the registered weights here
...     optimizer.step(grads)
...     scheduler.step()

Nadam with gradient clipping:

>>> optimizer = braintools.optim.Nadam(
...     lr=0.002,
...     grad_clip_norm=1.0,
...     grad_clip_value=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Comparison with Adam on identical models (Nadam often converges faster, though this is problem-dependent):

>>> # Compare convergence
>>> model1 = brainstate.nn.Linear(10, 5)
>>> model2 = brainstate.nn.Linear(10, 5)
>>>
>>> adam = braintools.optim.Adam(lr=0.002)
>>> nadam = braintools.optim.Nadam(lr=0.002)
>>>
>>> adam.register_trainable_weights(model1.states(brainstate.ParamState))
>>> nadam.register_trainable_weights(model2.states(brainstate.ParamState))
>>> # Train both on the same data and compare the loss curves;
>>> # Nadam typically shows faster initial convergence

See also

  • Adam – Standard Adam optimizer

  • Adamax – Adam variant with infinity norm

  • RAdam – Rectified Adam with variance adaptation

  • SGD – Stochastic gradient descent with Nesterov momentum option

default_tx()

Create default gradient transformation with clipping and weight decay.
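As a conceptual illustration of what "clipping and weight decay" usually mean before the Nadam update, the following pure-Python sketch preprocesses a flat list of gradients. The function names, the order of operations (norm clip, then value clip, then weight decay), and the coupled L2 form of the decay are all assumptions made for illustration; the actual transformation built by default_tx may differ.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their joint L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / norm) if norm > 0.0 else 1.0
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Clamp each gradient entry to the interval [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


def add_weight_decay(grads, params, weight_decay):
    """L2 penalty: add `weight_decay * param` to each gradient entry."""
    return [g + weight_decay * p for g, p in zip(grads, params)]


def preprocess(grads, params, grad_clip_norm=None, grad_clip_value=None, weight_decay=0.0):
    """Hypothetical pipeline: norm clip, then value clip, then weight decay."""
    if grad_clip_norm is not None:
        grads = clip_by_global_norm(grads, grad_clip_norm)
    if grad_clip_value is not None:
        grads = clip_by_value(grads, grad_clip_value)
    if weight_decay > 0.0:
        grads = add_weight_decay(grads, params, weight_decay)
    return grads


# Gradients [3, 4] have global norm 5, so norm clipping rescales them first.
print(preprocess([3.0, 4.0], [1.0, -1.0],
                 grad_clip_norm=1.0, grad_clip_value=0.5, weight_decay=0.01))
```

In this sketch the processed gradients would then be passed to the Nadam update in place of the raw ones. Leaving both clipping arguments at None and weight_decay at 0.0 returns the gradients unchanged.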