Yogi
- class braintools.optim.Yogi(lr=0.001, betas=(0.9, 0.999), eps=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)
Yogi optimizer (improvement over Adam).
Yogi is an adaptive learning-rate optimizer that addresses a limitation of Adam: when gradients become small, Adam's effective learning rate can increase too quickly. For the second-moment estimate, Yogi uses an additive, sign-controlled update instead of Adam's multiplicative (exponential moving average) update, which limits how fast the effective learning rate can grow.
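Writing both second-moment rules in the same incremental form makes the "multiplicative versus additive" distinction concrete. The notation follows the Notes: \(g_t\) is the gradient at step \(t\), \(v_t\) the second-moment estimate, and \(\beta_2\) its coefficient.

```latex
\begin{aligned}
\text{Adam:}\quad v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
                       = v_{t-1} - (1 - \beta_2)\,(v_{t-1} - g_t^2) \\
\text{Yogi:}\quad v_t &= v_{t-1} - (1 - \beta_2)\,\text{sign}(v_{t-1} - g_t^2)\, g_t^2
\end{aligned}
```

For \(v_{t-1} \neq g_t^2\), one update changes \(v\) by \((1 - \beta_2)\,|v_{t-1} - g_t^2|\) under Adam but by \((1 - \beta_2)\, g_t^2\) under Yogi. When gradients suddenly shrink (\(g_t^2 \ll v_{t-1}\)), Adam's rule reduces to \(v_t \approx \beta_2 v_{t-1}\), a multiplicative decay, while Yogi subtracts only the small amount \((1 - \beta_2)\, g_t^2\).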
- Parameters:
  - lr (float | LRScheduler) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance.
  - betas (Tuple[float, float]) – Coefficients for computing running averages of the gradient and its square. The first value (beta1) controls the exponential decay rate of the first moment; the second value (beta2) controls the rate of the second-moment update.
  - eps (float) – Term added to the denominator for numerical stability. Note: Yogi uses a larger default epsilon (1e-3) than Adam (1e-8) for better stability.
  - weight_decay (float) – Weight decay (L2 penalty) coefficient.
  - grad_clip_norm (float | None) – Maximum global norm for gradient clipping. If specified, gradients are scaled down whenever their global norm exceeds this value.
  - grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
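The two clipping options behave differently: norm clipping rescales all gradients by one common factor, preserving their direction, while value clipping clamps each entry independently. The sketch below shows the standard form of both operations (the same semantics as, for example, Optax's global-norm clipping); `clip_by_global_norm` and `clip_by_value` are hypothetical helpers, not braintools functions, and the exact braintools internals are an assumption.

```python
import jax
import jax.numpy as jnp


def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: rescale all leaves so the global L2 norm is at most max_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # Factor is 1 when the norm is already within the limit; the tiny constant avoids 0/0.
    scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def clip_by_value(grads, max_value):
    """Hypothetical helper: clamp every gradient entry to [-max_value, max_value]."""
    return jax.tree_util.tree_map(lambda g: jnp.clip(g, -max_value, max_value), grads)


grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}  # global norm = 5
print(clip_by_global_norm(grads, max_norm=1.0)["w"])  # approximately [0.6, 0.8]
print(clip_by_value(grads, max_value=3.5)["w"])       # [3.0, 3.5]
```

Note that norm clipping keeps the 3:4 ratio between the entries of `w`, whereas value clipping changes it.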
Notes
Yogi keeps Adam's first-moment update and bias correction and changes only the second-moment update, which becomes additive:
First moment (same as Adam):
\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\]
Second moment (Yogi's modification):
\[v_t = v_{t-1} - (1 - \beta_2) \text{sign}(v_{t-1} - g_t^2) \odot g_t^2\]
Bias correction:
\[\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}\]
Parameter update:
\[\theta_{t+1} = \theta_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}\]
Here \(g_t\) is the gradient at step \(t\), \(\alpha\) is the learning rate lr, and all operations are element-wise. Each second-moment change has magnitude \((1 - \beta_2) g_t^2\) instead of Adam's \((1 - \beta_2) |v_{t-1} - g_t^2|\), so \(v_t\) shrinks only slowly when the gradient magnitude drops. As a result, the effective learning rate \(\alpha / (\sqrt{\hat{v}_t} + \epsilon)\) rises gradually instead of jumping up, as it can with Adam's multiplicative update.
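The update equations translate almost line for line into code. This is a minimal JAX sketch of one Yogi step for a single array, not the braintools implementation: the function name `yogi_step` is hypothetical, weight decay and gradient clipping are omitted, and the moments are assumed to start at zero (other implementations may initialize them differently).

```python
import jax.numpy as jnp


def yogi_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-3):
    """Hypothetical helper: one Yogi update; `t` is the 1-based step count."""
    # First moment: identical to Adam.
    m = beta1 * m + (1.0 - beta1) * grad
    # Second moment: additive update whose direction is set by sign(v - g^2).
    g2 = grad ** 2
    v = v - (1.0 - beta2) * jnp.sign(v - g2) * g2
    # Bias correction.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Parameter update.
    theta = theta - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return theta, m, v


# Minimize f(theta) = sum(theta**2), whose gradient is 2 * theta.
theta = jnp.array([1.0, -2.0])
m = jnp.zeros_like(theta)  # assumption: zero-initialized moments
v = jnp.zeros_like(theta)
for t in range(1, 501):
    theta, m, v = yogi_step(theta, 2.0 * theta, m, v, t, lr=0.05)
print(theta)  # both entries end up much closer to 0 than the start [1, -2]
```

Starting from zero, the first step gives \(v_1 = (1 - \beta_2) g_1^2\), exactly as in Adam; the two rules only diverge afterwards.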
Reported advantages of Yogi over Adam:
- More stable convergence in some scenarios
- Limits how quickly the effective learning rate can grow
- Better handles changing gradient magnitudes
- Can achieve better generalization on some tasks
- Particularly effective for problems with sparse gradients
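The point about changing gradient magnitudes can be checked numerically. This plain-JAX sketch (the helper names `adam_v` and `yogi_v` are hypothetical) starts both second-moment estimates at 1, as if built up by large gradients, and then feeds 1000 steps of a gradient that is 100 times smaller.

```python
import jax.numpy as jnp

beta2 = 0.999


def adam_v(v, g):
    # Adam: exponential moving average of g^2 (multiplicative decay of v).
    return beta2 * v + (1.0 - beta2) * g ** 2


def yogi_v(v, g):
    # Yogi: change of size (1 - beta2) * g^2 in the direction of g^2 - v.
    return v - (1.0 - beta2) * jnp.sign(v - g ** 2) * g ** 2


v_adam = v_yogi = jnp.array(1.0)  # second moment from earlier large gradients
small_g = jnp.array(0.01)         # gradient magnitude drops by 100x
for _ in range(1000):
    v_adam = adam_v(v_adam, small_g)
    v_yogi = yogi_v(v_yogi, small_g)

print(float(v_adam), float(v_yogi))  # roughly 0.368 and 0.9999
```

Adam's estimate decays by about \(\beta_2^{1000} \approx e^{-1}\), so its step size grows by roughly \(1/\sqrt{0.37} \approx 1.6\) times over these 1000 steps, while Yogi's estimate, and therefore its step size, barely changes.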
Examples
Basic usage with default parameters:
>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize Yogi optimizer
>>> optimizer = braintools.optim.Yogi(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With custom beta values:
>>> # Adjust momentum parameters
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     betas=(0.9, 0.99)  # Faster second moment decay
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With larger epsilon for increased stability:
>>> # Yogi is less sensitive to epsilon than Adam
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     eps=1e-2  # Even larger epsilon for more stability
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
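One way to see why a larger eps adds stability: on the first step with zero-initialized moments (an assumption, common in Adam-style implementations), bias correction gives \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\), so the step is \(\alpha g_1 / (|g_1| + \epsilon)\). For a tiny gradient, a very small eps still produces a nearly full-size step of about lr, while a larger eps scales it down. The sketch is plain JAX arithmetic with no braintools calls.

```python
import jax.numpy as jnp

lr = 1e-3
g = jnp.array(1e-4)  # a very small gradient on the first step

# With m_hat = g and v_hat = g**2, the step is lr * g / (sqrt(g**2) + eps).
steps = {}
for eps in (1e-8, 1e-3, 1e-2):
    steps[eps] = float(lr * g / (jnp.sqrt(g ** 2) + eps))
    print(f"eps={eps:g}: step={steps[eps]:.2e}")
```

With eps=1e-8 the step is essentially lr, the same as for a large gradient; eps=1e-3 shrinks it by about an order of magnitude, and eps=1e-2 by about two.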
With learning rate scheduler:
>>> # Combine with exponential decay
>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.01, gamma=0.95)
>>> optimizer = braintools.optim.Yogi(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
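Assuming ExponentialLR follows the usual PyTorch-style rule \(\text{lr}_k = \text{base\_lr} \cdot \gamma^k\) after \(k\) scheduler steps, the settings above halve the learning rate roughly every 13.5 steps. The helper `exponential_lr` is hypothetical and only reproduces that formula; when braintools advances the schedule (per epoch or per iteration) is not covered here.

```python
import math

base_lr, gamma = 0.01, 0.95


def exponential_lr(step: int) -> float:
    """Hypothetical helper: lr after `step` decays, assuming lr = base_lr * gamma**step."""
    return base_lr * gamma ** step


for step in (0, 10, 50):
    print(step, exponential_lr(step))

# Number of steps needed to halve the learning rate.
half_life = math.log(2) / -math.log(gamma)
print(round(half_life, 1))  # 13.5
```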
With weight decay for regularization:
>>> # Add L2 regularization
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     weight_decay=0.01
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
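Reading weight_decay as an L2 penalty means the gradient gains a term \(\lambda\theta\) before it enters the moment estimates, because \(\nabla_\theta \tfrac{\lambda}{2}\lVert\theta\rVert^2 = \lambda\theta\). Whether braintools applies it exactly this way (coupled) rather than as decoupled, AdamW-style decay is an assumption here. The JAX sketch only verifies the gradient identity, using a stand-in loss.

```python
import jax
import jax.numpy as jnp

weight_decay = 0.01
theta = jnp.array([1.0, -2.0, 3.0])


def data_loss(p):
    # Stand-in loss; any differentiable function works here.
    return jnp.sum(jnp.sin(p))


def regularized_loss(p):
    # Coupled L2 penalty: (weight_decay / 2) * ||p||^2.
    return data_loss(p) + 0.5 * weight_decay * jnp.sum(p ** 2)


g_reg = jax.grad(regularized_loss)(theta)
g_manual = jax.grad(data_loss)(theta) + weight_decay * theta
print(bool(jnp.allclose(g_reg, g_manual)))  # True
```

With the coupled form, the decay term is also divided by \(\sqrt{\hat{v}_t} + \epsilon\), so its effective strength varies per parameter.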
Complete training example for NLP tasks:
>>> import jax.numpy as jnp
>>>
>>> # Set up language model
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Embedding(10000, 256),
...     brainstate.nn.LSTM(256, 512),
...     brainstate.nn.Linear(512, 10000)
... )
>>>
>>> # Yogi works well for NLP with sparse gradients
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     betas=(0.9, 0.999),
...     eps=1e-3
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> def train_step(tokens, targets):
...     def loss_fn():
...         logits = model(tokens)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, targets)
...         )
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()  # loss evaluated after the parameter update
>>>
>>> # Train
>>> tokens = jnp.ones((32, 50), dtype=jnp.int32)
>>> # Placeholder targets shaped like per-token logits: (batch, seq_len, vocab)
>>> targets = jnp.zeros((32, 50, 10000))
>>> # loss = train_step(tokens, targets)
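The sparse gradients mentioned above come from layers like the embedding: in each step, only the rows for tokens that appear in the batch receive a nonzero gradient. This minimal JAX illustration uses plain array indexing as a stand-in for brainstate.nn.Embedding (an assumption; the real layer may differ internally), with a toy table and loss.

```python
import jax
import jax.numpy as jnp

vocab_size, dim = 10, 4
table = jnp.ones((vocab_size, dim))          # toy embedding table
tokens = jnp.array([[1, 3, 3], [7, 1, 0]])   # batch of token ids


def loss(table):
    # Embedding lookup by indexing, followed by a toy loss.
    return jnp.sum(table[tokens] ** 2)


grad = jax.grad(loss)(table)
rows_with_grad = jnp.nonzero(jnp.any(grad != 0, axis=1))[0]
print(rows_with_grad)  # [0 1 3 7]
```

Rows for unused tokens get a zero gradient on this step, so their second-moment estimates see long runs of small or zero gradients, the regime where Yogi's slower decrease of \(v\) matters.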