Lookahead

class braintools.optim.Lookahead(base_optimizer, sync_period=5, alpha=0.5, lr=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

Lookahead optimizer wrapper.

Lookahead is a meta-optimizer that wraps a standard optimizer and maintains two sets of weights: “fast weights”, which the base optimizer updates at every step, and “slow weights”, which are periodically moved toward the fast weights, after which the fast weights are reset to the slow weights. The aim is to reduce the variance of the optimization trajectory and improve training stability.

Parameters:
  • base_optimizer (GradientTransformation) – The base optimizer to wrap (e.g., SGD or Adam); it performs the fast weight updates.

  • sync_period (int) – Number of fast weight update steps between synchronizations with the slow weights. Denoted k in the original Lookahead paper (Zhang et al., 2019). Typical values are 5 to 10.

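  • alpha (float) – Slow weights step size, also called the ‘slow step size’. Controls how far the slow weights move toward the fast weights at each synchronization. Range: [0, 1]; 0 leaves the slow weights unchanged (the fast weights are reset to them at every synchronization), and 1 sets them equal to the fast weights, which is equivalent to running the base optimizer alone.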

  • lr (float | LRScheduler) – Learning rate for the base optimizer. Can be a float (converted to ConstantLR) or any LRScheduler instance.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.

  • grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
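The two clipping options behave differently on the same gradients. The sketch below reproduces the semantics stated for grad_clip_norm and grad_clip_value in plain Python on a hypothetical two-element gradient. The helper names `clip_by_global_norm` and `clip_by_value` are illustrative only, not braintools functions, and a flat list stands in for all parameter leaves combined.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: rescale all entries so their joint L2 norm is <= max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)  # already within the limit: unchanged
    scale = max_norm / global_norm  # one common factor for every entry
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Hypothetical helper: clamp each entry independently to [-max_value, max_value]."""
    return [min(max(g, -max_value), max_value) for g in grads]


grads = [3.0, -4.0]  # global L2 norm = 5
print(clip_by_global_norm(grads, max_norm=1.0))  # approximately [0.6, -0.8]
print(clip_by_value(grads, max_value=3.5))       # [3.0, -3.5]
```

Global-norm clipping rescales every entry by the same factor and therefore preserves the gradient's direction; value clipping acts entry by entry and can change it (here the 3 : -4 ratio becomes 3 : -3.5).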

Notes

Lookahead maintains two sets of parameters:

  • Fast weights \(\theta_f\): Updated by the base optimizer every step

  • Slow weights \(\theta_s\): Updated periodically every k steps

The update procedure is:

  1. Fast weight update (every step):

\[\theta_f^{t+1} = \text{BaseOptimizer}(\theta_f^t, g_t)\]

  2. Slow weight update (every k steps), after which the fast weights are reset to the new slow weights:

\[\begin{aligned}\theta_s^{t+k} &= \theta_s^t + \alpha (\theta_f^{t+k} - \theta_s^t)\\ \theta_f^{t+k} &\leftarrow \theta_s^{t+k}\end{aligned}\]

where:

  • \(k\) is the sync_period

  • \(\alpha\) is the slow step size (alpha parameter)

  • \(g_t\) is the gradient at step t
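The update procedure above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming plain SGD as the base optimizer and a single scalar parameter; the function name `lookahead_sgd` and the toy quadratic objective are hypothetical and not part of braintools, and this is not the library's implementation.

```python
def lookahead_sgd(theta0, grad_fn, lr, sync_period, alpha, num_steps):
    """Toy Lookahead on one scalar parameter, with plain SGD as the base optimizer.

    Returns (slow_weights, fast_weights) after `num_steps` fast updates.
    """
    theta_s = theta0  # slow weights
    theta_f = theta0  # fast weights start synchronized with the slow weights
    for step in range(1, num_steps + 1):
        # Step 1: fast weight update by the base optimizer (SGD here).
        theta_f = theta_f - lr * grad_fn(theta_f)
        # Step 2: every `sync_period` (k) steps, interpolate the slow weights, then reset.
        if step % sync_period == 0:
            theta_s = theta_s + alpha * (theta_f - theta_s)
            theta_f = theta_s
    return theta_s, theta_f


def grad_quadratic(theta):
    # Gradient of f(theta) = theta**2 / 2.
    return theta


slow, fast = lookahead_sgd(1.0, grad_quadratic, lr=0.1, sync_period=5, alpha=0.5, num_steps=5)
# Five SGD steps take theta_f from 1.0 to 0.9**5 = 0.59049; with alpha = 0.5 the
# slow weights move halfway: 1.0 + 0.5 * (0.59049 - 1.0) = 0.795245.
print(slow, fast)
```

Setting `alpha=1.0` in this sketch reproduces plain SGD, and `alpha=0.0` keeps the slow weights at `theta0`, matching the range described for alpha.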

Benefits of Lookahead:

  • Reduces variance in the optimization trajectory

  • Often generalizes better than the base optimizer alone, although the gain depends on the task and hyperparameters

  • Provides a form of implicit regularization

  • Works with any base optimizer (SGD, Adam, etc.)

  • Low overhead: one extra copy of the parameters (the slow weights) and one interpolation every k steps

The slow weights act as an “anchor” that limits how far the fast weights can drift in potentially suboptimal directions, which often leads to more stable and sometimes faster convergence.
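A deliberately unstable toy case makes the anchoring idea concrete. In this hypothetical, pure-Python sketch, plain SGD on f(θ) = θ²/2 uses a step size so large that each step maps θ to -0.9θ, so the iterate overshoots the minimum at 0 and oscillates. One Lookahead synchronization with α = 0.5 then averages the start and end of that oscillating segment. This is a single hand-picked example, not evidence for the general claims above.

```python
def sgd_step(theta, lr):
    # One SGD step on f(theta) = theta**2 / 2, whose gradient is theta.
    return theta - lr * theta


lr, k, alpha = 1.9, 5, 0.5  # lr = 1.9 maps theta to -0.9 * theta, so the iterate oscillates

theta_slow = 1.0
theta_fast = theta_slow
for _ in range(k):
    theta_fast = sgd_step(theta_fast, lr)  # identical to running plain SGD for k steps

# One Lookahead synchronization: move the slow weights halfway toward the fast weights.
theta_slow = theta_slow + alpha * (theta_fast - theta_slow)

print(f"plain SGD after {k} steps: {theta_fast:+.6f}")  # (-0.9)**5 = -0.59049
print(f"Lookahead slow weights:    {theta_slow:+.6f}")  # 1 + 0.5 * (-0.59049 - 1) = 0.204755
```

After one period, the slow weights sit about 0.20 from the minimum, while the plain SGD iterate is about 0.59 away on the opposite side.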

Examples

Basic usage with SGD as base optimizer:

>>> import brainstate
>>> import braintools
>>> import optax
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Create base optimizer (SGD)
>>> base_opt = optax.sgd(learning_rate=0.1)
>>>
>>> # Wrap with Lookahead
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With Adam as base optimizer:

>>> # Lookahead with Adam as the base optimizer
>>> base_opt = optax.adam(learning_rate=0.001)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=6,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Custom synchronization period:

>>> # Longer sync period for more exploration
>>> base_opt = optax.sgd(learning_rate=0.1, momentum=0.9)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=10,  # Synchronize every 10 steps
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Adjusting slow weights step size:

>>> # Smaller alpha for more conservative slow weight updates
>>> base_opt = optax.adam(learning_rate=0.001)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.3  # More conservative than default 0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With learning rate scheduler:

>>> # Combine with scheduler for dynamic learning rate
>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.1, gamma=0.95)
>>> base_opt = optax.sgd(learning_rate=0.1)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     lr=scheduler,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Complete training example:

>>> import jax.numpy as jnp
>>>
>>> # Setup
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Linear(784, 128),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(128, 10)
... )
>>>
>>> # Lookahead with SGD + momentum
>>> base_opt = optax.sgd(learning_rate=0.1, momentum=0.9)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop (assumes an iterable `data_loader` of (inputs, labels) batches)
>>> for epoch in range(10):
...     for batch_x, batch_y in data_loader:
...         def loss_fn():
...             logits = model(batch_x)
...             return jnp.mean(
...                 braintools.metric.softmax_cross_entropy(logits, batch_y)
...             )
...
...         grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...         optimizer.update(grads)

See also

  • SGD – Stochastic gradient descent base optimizer

  • Adam – Adaptive moment estimation base optimizer

  • RAdam – Rectified Adam, often paired with Lookahead

default_tx()

Create the default gradient transformation, which applies gradient clipping and weight decay.