Lookahead
- class braintools.optim.Lookahead(base_optimizer, sync_period=5, alpha=0.5, lr=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)
Lookahead optimizer wrapper.
Lookahead is a meta-optimizer that wraps a standard optimizer and maintains two sets of weights: “fast weights”, updated by the base optimizer at every step, and “slow weights”, which are moved toward the fast weights once every sync_period steps. After each synchronization the fast weights restart from the slow weights. This approach is intended to reduce variance and improve training stability.
- Parameters:
base_optimizer (GradientTransformation) – The base optimizer to wrap (e.g., SGD, Adam). The base optimizer performs the fast weight updates.
sync_period (int) – Number of fast weight update steps before synchronizing with slow weights. Also known as ‘k’ in the paper. Typical values are 5-10.
alpha (float) – Slow weights step size. Controls how much the slow weights move toward the fast weights during synchronization. Also known as ‘slow step size’. Range: [0, 1], where 0 means no update and 1 means full update.
lr (float|LRScheduler) – Learning rate for the base optimizer. Can be a float (converted to ConstantLR) or any LRScheduler instance.
weight_decay (float) – Weight decay (L2 penalty) coefficient.
grad_clip_norm (float|None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.
grad_clip_value (float|None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
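The two clipping options described for grad_clip_norm and grad_clip_value behave differently, which a small example can make concrete. The following is a minimal plain-Python sketch of the two rules as the parameter descriptions state them, not the braintools implementation; the helper names `clip_by_global_norm` and `clip_by_value` are hypothetical and exist only for this illustration.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale all entries together if their global L2 norm exceeds max_norm.

    Hypothetical helper for illustration; not part of braintools.
    """
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)
    scale = max_norm / global_norm
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Clamp each entry independently to [-max_value, max_value].

    Hypothetical helper for illustration; not part of braintools.
    """
    return [max(-max_value, min(max_value, g)) for g in grads]


grads = [3.0, -4.0]  # global L2 norm is 5.0
by_norm = clip_by_global_norm(grads, max_norm=1.0)   # direction preserved
by_value = clip_by_value(grads, max_value=1.0)       # each entry clamped
```

Under these assumptions, norm clipping shrinks the gradient while keeping its direction, whereas value clipping can change the direction because each entry is clamped on its own.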
Notes
Lookahead maintains two sets of parameters:
Fast weights \(\theta_f\): Updated by the base optimizer every step
Slow weights \(\theta_s\): Updated periodically every k steps
The update procedure is:
Fast weight update (every step):
\[\theta_f^{t+1} = \text{BaseOptimizer}(\theta_f^t, g_t)\]
Slow weight update (every k steps):
\[\begin{aligned}
\theta_s^{t+k} &= \theta_s^t + \alpha (\theta_f^{t+k} - \theta_s^t)\\
\theta_f^{t+k} &= \theta_s^{t+k}
\end{aligned}\]
where:
\(k\) is the sync_period
\(\alpha\) is the slow step size (alpha parameter)
\(g_t\) is the gradient at step t
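The update procedure above can be sketched in a few lines of plain Python. This is an illustrative sketch only, assuming plain SGD as the base optimizer and a single scalar parameter; the function name `lookahead_sgd` and its arguments are hypothetical and do not reflect how braintools implements the wrapper.

```python
def lookahead_sgd(grad_fn, theta0, lr=0.1, sync_period=5, alpha=0.5, num_steps=20):
    """Lookahead around plain SGD for one scalar parameter (illustration only)."""
    slow = theta0  # slow weights, theta_s
    fast = theta0  # fast weights, theta_f
    for step in range(1, num_steps + 1):
        # Fast weight update (every step); the base optimizer here is plain SGD.
        fast = fast - lr * grad_fn(fast)
        if step % sync_period == 0:
            # Slow weight update (every k steps): move toward the fast weights.
            slow = slow + alpha * (fast - slow)
            # Restart the fast weights from the updated slow weights.
            fast = slow
    return fast, slow


def grad_fn(theta):
    # Gradient of the toy objective f(theta) = theta ** 2.
    return 2.0 * theta


fast, slow = lookahead_sgd(grad_fn, theta0=1.0)
```

In this sketch, `alpha=1.0` makes the slow weights simply follow the base optimizer, while `alpha=0.0` keeps them at their initial value and resets the fast weights to it at every synchronization, matching the range described for the alpha parameter.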
Benefits of Lookahead:
Reduces variance in the optimization trajectory
Often achieves better generalization than the base optimizer alone
Provides a form of implicit regularization
Works with any base optimizer (SGD, Adam, etc.)
Minimal computational overhead
The slow weights act as an “anchor” that prevents the fast weights from moving too far in potentially suboptimal directions, leading to more stable and often faster convergence.
References
Examples
Basic usage with SGD as base optimizer:
>>> import brainstate
>>> import braintools
>>> import optax
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Create base optimizer (SGD)
>>> base_opt = optax.sgd(learning_rate=0.1)
>>>
>>> # Wrap with Lookahead
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With Adam as base optimizer:
>>> # Lookahead with Adam performing the fast weight updates
>>> base_opt = optax.adam(learning_rate=0.001)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=6,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Custom synchronization period:
>>> # Longer sync period for more exploration
>>> base_opt = optax.sgd(learning_rate=0.1, momentum=0.9)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=10,  # Synchronize every 10 steps
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Adjusting slow weights step size:
>>> # Smaller alpha for more conservative slow weight updates
>>> base_opt = optax.adam(learning_rate=0.001)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.3  # More conservative than default 0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With learning rate scheduler:
>>> # Combine with scheduler for dynamic learning rate
>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.1, gamma=0.95)
>>> base_opt = optax.sgd(learning_rate=0.1)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     lr=scheduler,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
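To get a feel for the values used in the scheduler example, the sketch below computes an exponential decay in plain Python. It assumes the conventional rule lr_n = base_lr * gamma ** n, where n counts how many times the schedule has been advanced; whether braintools' ExponentialLR follows exactly this rule, and whether it advances per step or per epoch, is an assumption not confirmed by this page. The function name `exponential_lr` is hypothetical.

```python
def exponential_lr(base_lr, gamma, num_updates):
    """Return base_lr * gamma ** n for n = 0 .. num_updates - 1.

    Hypothetical helper illustrating conventional exponential decay;
    it is not the braintools scheduler.
    """
    return [base_lr * gamma ** n for n in range(num_updates)]


# Same base_lr and gamma as in the scheduler example above.
lrs = exponential_lr(base_lr=0.1, gamma=0.95, num_updates=4)
```

Under this assumed rule, each advance of the schedule multiplies the learning rate by 0.95, so the rate starts at 0.1 and decreases monotonically.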
Complete training example:
>>> import jax.numpy as jnp
>>>
>>> # Setup
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Linear(784, 128),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(128, 10)
... )
>>>
>>> # Lookahead with SGD + momentum
>>> base_opt = optax.sgd(learning_rate=0.1, momentum=0.9)
>>> optimizer = braintools.optim.Lookahead(
...     base_optimizer=base_opt,
...     sync_period=5,
...     alpha=0.5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop (data_loader is not defined here; it is assumed to
>>> # yield (batch_x, batch_y) pairs)
>>> for epoch in range(10):
...     for batch_x, batch_y in data_loader:
...         def loss_fn():
...             logits = model(batch_x)
...             return jnp.mean(
...                 braintools.metric.softmax_cross_entropy(logits, batch_y)
...             )
...
...         grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...         optimizer.update(grads)
See also