Yogi

class braintools.optim.Yogi(lr=0.001, betas=(0.9, 0.999), eps=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

Yogi optimizer (improvement over Adam).

Yogi is an adaptive learning rate optimizer that addresses some of the limitations of Adam by controlling the increase of the effective learning rate. It uses additive updates instead of multiplicative updates for the second moment estimate, which prevents the effective learning rate from increasing too rapidly.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance.

  • betas (Tuple[float, float]) – Coefficients for computing running averages of gradient and its square. The first value (beta1) controls the exponential decay rate for the first moment, and the second value (beta2) controls the decay rate for the second moment.

  • eps (float) – Term added to the denominator for numerical stability. Note: Yogi uses a larger default epsilon (1e-3) than Adam (1e-8) for better stability.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are rescaled so that their global L2 norm does not exceed this value.

  • grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
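The two clipping options and the L2 penalty can be illustrated on a list of NumPy arrays standing in for a gradient pytree. This is a minimal sketch, not the braintools implementation: the helper names `clip_by_value`, `clip_by_global_norm` and `add_l2_penalty` are hypothetical, weight decay is assumed to be added to the gradient (as the "L2 penalty" description suggests), and the order in which braintools applies these steps is not specified here.

```python
import numpy as np


def clip_by_value(grads, max_value):
    """Clip every gradient element-wise to [-max_value, max_value]."""
    return [np.clip(g, -max_value, max_value) for g in grads]


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # The scale factor is 1 when the norm is already within bounds,
    # so small gradients pass through unchanged.
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    return [g * scale for g in grads]


def add_l2_penalty(grads, params, weight_decay):
    """Coupled L2 penalty: the gradient of (weight_decay / 2) * ||theta||^2 is weight_decay * theta."""
    return [g + weight_decay * p for g, p in zip(grads, params)]


grads = [np.array([3.0, -4.0]), np.array([12.0])]
params = [np.array([1.0, 1.0]), np.array([2.0])]

print(clip_by_value(grads, 5.0))        # only the 12.0 entry is clipped, to 5.0
print(clip_by_global_norm(grads, 6.5))  # global norm is 13, so every entry is halved
print(add_l2_penalty(grads, params, 0.01))
```

Note the difference in behavior: value clipping can change the direction of the overall gradient, whereas global-norm clipping rescales all entries by one shared factor and preserves the direction.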

Notes

Yogi modifies Adam’s second moment update to use an additive approach. The key difference from Adam is in the second moment computation:

First moment (same as Adam):

\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\]

Second moment (Yogi’s modification):

\[v_t = v_{t-1} - (1 - \beta_2) \text{sign}(v_{t-1} - g_t^2) \odot g_t^2\]
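To see the difference concretely, the Adam and Yogi second-moment rules can be compared for a single scalar step. The NumPy sketch below uses β₂ = 0.999; the function names are illustrative only. It covers both directions: a gradient much smaller than the current estimate, and one much larger.

```python
import numpy as np


def adam_second_moment(v_prev, g, beta2=0.999):
    # Adam: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2,
    # i.e. v moves toward g_t^2 by (1 - beta2) * (v_{t-1} - g_t^2).
    return beta2 * v_prev + (1.0 - beta2) * g ** 2


def yogi_second_moment(v_prev, g, beta2=0.999):
    # Yogi: sign(v_{t-1} - g_t^2) picks the direction, and the size of
    # the change, (1 - beta2) * g_t^2, depends only on the current gradient.
    return v_prev - (1.0 - beta2) * np.sign(v_prev - g ** 2) * g ** 2


# Case 1: gradient much smaller than the running estimate (v shrinks).
print(adam_second_moment(1.0, 0.1), yogi_second_moment(1.0, 0.1))    # ~0.99901 vs ~0.99999
# Case 2: gradient much larger than the running estimate (v grows).
print(adam_second_moment(0.01, 1.0), yogi_second_moment(0.01, 1.0))  # ~0.01099 vs ~0.011
```

When the gradient is small, Adam pulls v down by about 0.001 (that is, (1 - β₂)(v - g²)), while Yogi moves it by only 1e-5 (that is, (1 - β₂)g²). When the gradient is large, both rules increase v by similar amounts.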

Bias correction:

\[\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}\]
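As a quick check of what bias correction does, consider the first step t = 1. Assume (this is not stated above) that both moments start at zero, m_0 = v_0 = 0. Because sign(0 - g_1^2) ⊙ g_1^2 = -g_1^2 element-wise, both corrected estimates reduce to the raw gradient statistics:

\[m_1 = (1 - \beta_1) g_1 \quad\Rightarrow\quad \hat{m}_1 = \frac{(1 - \beta_1) g_1}{1 - \beta_1} = g_1\]

\[v_1 = 0 - (1 - \beta_2)\,\text{sign}(0 - g_1^2) \odot g_1^2 = (1 - \beta_2)\, g_1^2 \quad\Rightarrow\quad \hat{v}_1 = g_1^2\]

Without the correction, v_1 would underestimate g_1^2 by the factor 1 - β_2, which is a factor of 1000 for β_2 = 0.999.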

Parameter update:

\[\theta_{t+1} = \theta_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}\]
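The four equations above can be combined into one update step. The following NumPy sketch is a hypothetical re-implementation for illustration only: the class name `YogiSketch` is invented, zero-initialized moment buffers are an assumption, and clipping and weight decay are omitted. It is not the braintools source.

```python
import numpy as np


class YogiSketch:
    """Hypothetical NumPy version of the Yogi equations (not the braintools code).

    Assumes both moment buffers start at zero.
    """

    def __init__(self, lr=1e-3, betas=(0.9, 0.999), eps=1e-3):
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.m = None
        self.v = None
        self.t = 0

    def step(self, theta, g):
        if self.m is None:
            self.m = np.zeros_like(theta)
            self.v = np.zeros_like(theta)
        self.t += 1
        # First moment: exponential moving average of the gradient (same as Adam).
        self.m = self.beta1 * self.m + (1 - self.beta1) * g
        # Second moment: Yogi's sign-based additive update.
        g2 = g ** 2
        self.v = self.v - (1 - self.beta2) * np.sign(self.v - g2) * g2
        # Bias correction.
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        # Parameter update.
        return theta - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


# Usage: minimize f(theta) = 0.5 * ||theta||^2, whose gradient is theta itself.
opt = YogiSketch(lr=0.1)
theta = np.array([5.0, -3.0])
for _ in range(500):
    theta = opt.step(theta, g=theta)
print(theta)  # should be close to [0, 0]
```

Under the zero-initialization assumption, the first step moves each coordinate by α·g/(|g| + ε), which is roughly α in magnitude. This is why Adam-style optimizers take sign-like first steps whose size is largely independent of the gradient scale.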

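The sign-based update changes the second moment by at most \((1 - \beta_2) g_t^2\) per step, an amount set by the current squared gradient alone. Written in the same form, Adam's update is \(v_t = v_{t-1} - (1 - \beta_2)(v_{t-1} - g_t^2)\). When the gradient magnitude suddenly drops, Adam therefore shrinks \(v_t\) in proportion to the gap \(v_{t-1} - g_t^2\), so the effective learning rate \(\alpha / (\sqrt{\hat{v}_t} + \epsilon)\) can rise quickly. In the same situation Yogi lets \(v_t\) decrease only slowly, which limits how fast the effective learning rate grows.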
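A small numeric experiment makes this contrast concrete. Suppose, hypothetically, that the second-moment estimate has reached 1.0 after a phase of large gradients, and that the gradient then drops to 0.01 for 1000 steps. The sketch iterates both second-moment rules and compares α/(√v + ε). Bias correction is omitted because, at a given step t, it rescales both estimates by the same factor 1 - β₂^t and so does not change which effective step is larger.

```python
import numpy as np

beta2, lr, eps = 0.999, 1e-3, 1e-3
v_adam = v_yogi = 1.0  # hypothetical estimate after a phase of large gradients
g = 0.01               # gradients suddenly become small

for _ in range(1000):
    v_adam = beta2 * v_adam + (1 - beta2) * g ** 2
    v_yogi = v_yogi - (1 - beta2) * np.sign(v_yogi - g ** 2) * g ** 2

# Per-coordinate effective step size alpha / (sqrt(v) + eps), without bias correction.
eff_adam = lr / (np.sqrt(v_adam) + eps)
eff_yogi = lr / (np.sqrt(v_yogi) + eps)
print(v_adam, eff_adam)
print(v_yogi, eff_yogi)
```

With these numbers, Adam's estimate decays to roughly 0.37 and its effective step grows to about 1.6e-3. Yogi's estimate stays at about 0.9999, so its effective step remains near 1.0e-3.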

Key advantages of Yogi over Adam:

  • More stable convergence in some scenarios

  • Limits how quickly the effective learning rate can grow when gradients shrink

  • Better handles changing gradient magnitudes

  • Often achieves better generalization

  • Particularly effective for problems with sparse gradients

Examples

Basic usage with default parameters:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize Yogi optimizer
>>> optimizer = braintools.optim.Yogi(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With custom beta values:

>>> # Adjust momentum parameters
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     betas=(0.9, 0.99)  # Faster second moment decay
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With larger epsilon for increased stability:

>>> # Yogi is less sensitive to epsilon than Adam
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     eps=1e-2  # Even larger epsilon for more stability
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With learning rate scheduler:

>>> # Combine with exponential decay
>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.01, gamma=0.95)
>>> optimizer = braintools.optim.Yogi(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
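For intuition about how the scheduled rate evolves, the sketch below assumes the common PyTorch-style exponential-decay definition, lr_k = base_lr · gamma^k after k scheduler steps. Whether braintools' `ExponentialLR` advances once per epoch or once per optimizer update is not stated here, so treat the step index k as an assumption. `exponential_lr` is a hypothetical helper, not part of braintools.

```python
def exponential_lr(base_lr: float, gamma: float, k: int) -> float:
    """Learning rate after k scheduler steps under exponential decay (hypothetical helper)."""
    return base_lr * gamma ** k


# Same base_lr and gamma as the scheduler example above.
for k in (0, 1, 10, 50):
    print(k, exponential_lr(0.01, 0.95, k))
```

With base_lr=0.01 and gamma=0.95, the rate falls below half its initial value after 14 steps (0.95^14 ≈ 0.49).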

With weight decay for regularization:

>>> # Add L2 regularization
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     weight_decay=0.01
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Complete training example for NLP tasks:

>>> import jax.numpy as jnp
>>>
>>> # Setup language model
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Embedding(10000, 256),
...     brainstate.nn.LSTM(256, 512),
...     brainstate.nn.Linear(512, 10000)
... )
>>>
>>> # Yogi works well for NLP with sparse gradients
>>> optimizer = braintools.optim.Yogi(
...     lr=0.001,
...     betas=(0.9, 0.999),
...     eps=1e-3
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> def train_step(tokens, targets):
...     def loss_fn():
...         logits = model(tokens)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, targets)
...         )
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()
>>>
>>> # Train
>>> tokens = jnp.ones((32, 50), dtype=jnp.int32)
>>> targets = jnp.zeros((32, 10000))
>>> # loss = train_step(tokens, targets)

See also

  • Adam – Standard adaptive moment estimation

  • AdamW – Adam with decoupled weight decay

  • RAdam – Rectified Adam with variance rectification

default_tx()

Create default gradient transformation with clipping and weight decay.