MomentumNesterov

class braintools.optim.MomentumNesterov(lr=0.001, momentum=0.9, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

Nesterov Momentum optimizer.

Implements Nesterov’s accelerated gradient method, which applies the gradient at a lookahead point obtained by extrapolating along the momentum direction rather than at the current parameters. This often converges faster than standard momentum.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or LRScheduler instance.

  • momentum (float) – Momentum factor \(\mu\). The fraction of the previous velocity carried over to the next step.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • grad_clip_norm (float | None) – Maximum gradient norm for clipping.

  • grad_clip_value (float | None) – Maximum gradient value for clipping.
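
The weight decay and clipping arguments can be pictured with a small pure-Python sketch. It assumes the common conventions: L2 weight decay adds `weight_decay * theta` to the gradient, norm clipping rescales the whole gradient, and value clipping clamps each element. The helper names are hypothetical, and whether braintools applies these steps in exactly this form or order is an assumption, not something this page confirms.

```python
import math

def clip_by_norm(grad, max_norm):
    """Rescale grad so that its L2 norm is at most max_norm (assumed convention)."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]

def clip_by_value(grad, max_value):
    """Clamp every element of grad to the interval [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grad]

def add_weight_decay(grad, params, weight_decay):
    """L2 penalty: add weight_decay * theta to the gradient of each parameter."""
    return [g + weight_decay * p for g, p in zip(grad, params)]

grad = [3.0, 4.0]  # L2 norm is 5
clipped_norm = clip_by_norm(grad, max_norm=1.0)       # direction kept, norm reduced to 1
clipped_value = clip_by_value(grad, max_value=3.5)    # only the second element is clamped
decayed = add_weight_decay(grad, params=[10.0, -10.0], weight_decay=0.1)
```

Norm clipping preserves the gradient direction, while value clipping can change it; that difference is usually the deciding factor between `grad_clip_norm` and `grad_clip_value`.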

Notes

The Nesterov momentum update is computed as:

\[
\begin{aligned}
v_{t+1} &= \mu v_t + g_t \\
\theta_{t+1} &= \theta_t - \alpha (\mu v_{t+1} + g_t)
\end{aligned}
\]

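where \(\mu\) is the momentum factor, \(g_t\) is the gradient at step \(t\), \(\alpha\) is the learning rate, \(v_t\) is the velocity, and \(\theta_t\) are the parameters.

This is the common reformulation of Nesterov’s method that avoids evaluating the gradient at a separate lookahead point. The gradient \(g_t\) enters the parameter update twice, once through the updated velocity \(v_{t+1}\) and once directly, which corresponds to first making a momentum step and then correcting with the gradient at the resulting position. This provides the “lookahead” effect.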
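
The two update equations can be traced by hand with a minimal scalar sketch. It is a direct transcription of the formulas above in plain Python, not the library's implementation; the function name `nesterov_step` and the quadratic test objective are illustrative assumptions.

```python
def nesterov_step(theta, v, grad, lr=0.01, momentum=0.9):
    """One Nesterov momentum update for a scalar parameter.

    v_new     = momentum * v + grad
    theta_new = theta - lr * (momentum * v_new + grad)
    """
    v_new = momentum * v + grad
    theta_new = theta - lr * (momentum * v_new + grad)
    return theta_new, v_new

# Minimise f(theta) = 0.5 * theta**2, whose gradient is simply theta.
theta, v = 1.0, 0.0
for _ in range(200):
    theta, v = nesterov_step(theta, v, grad=theta, lr=0.1, momentum=0.9)
# theta is now very close to the minimiser 0
```

Note that the velocity accumulates raw gradients and the learning rate is applied only in the parameter update, exactly as in the equations above.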

Examples

Basic Nesterov Momentum optimizer:

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.MomentumNesterov(lr=0.01, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nesterov Momentum with weight decay:

>>> optimizer = braintools.optim.MomentumNesterov(lr=0.01, momentum=0.9, weight_decay=0.0001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nesterov Momentum with gradient clipping:

>>> optimizer = braintools.optim.MomentumNesterov(
...     lr=0.01,
...     momentum=0.9,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Nesterov Momentum with learning rate scheduling:

>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.01, gamma=0.95)
>>> optimizer = braintools.optim.MomentumNesterov(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training code here: compute `grads` for the registered weights
...     optimizer.step(grads)
...     scheduler.step()

See also

  • Momentum – Standard momentum optimizer

  • SGD – Stochastic gradient descent with optional momentum and Nesterov

  • Adam – Adam optimizer with adaptive learning rates

default_tx()

Create the gradient transformation specific to Nesterov Momentum.