Lion


class braintools.optim.Lion(lr=0.0001, betas=(0.9, 0.99), weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

Lion (EvoLved Sign Momentum) optimizer - Discovered through program search.

Lion is an optimizer discovered through large-scale evolutionary program search. It keeps a single momentum buffer and uses only the sign of an interpolation between that momentum and the current gradient as the parameter update, which makes it memory-efficient and computationally simple. Despite its simplicity, Lion often achieves competitive or superior performance compared to Adam while storing half as much optimizer state.

The key idea of Lion is the sign operation: every coordinate with a nonzero update direction moves by the same magnitude (the learning rate), which provides implicit adaptive learning rates and a strong regularization effect. Lion typically requires smaller learning rates (3-10× smaller than Adam) but larger weight decay values.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or an LRScheduler instance. If a float is provided, it is automatically converted to a ConstantLR scheduler. Note: Lion typically requires 3-10× smaller learning rates than Adam.

  • betas (Tuple[float, float]) – Coefficients (beta1, beta2). beta1 weights the momentum in the interpolation between momentum and gradient whose sign forms the update; beta2 is the decay rate used to track the momentum itself. This differs from Adam, where the two betas are the exponential decay rates of the first and second moment estimates.

  • weight_decay (float) – Weight decay coefficient (L2 penalty). Lion typically uses larger weight decay values than Adam (3-10× larger) due to its implicit regularization.

  • grad_clip_norm (float | None) – Maximum gradient norm for gradient clipping. If None, no gradient norm clipping is applied.

  • grad_clip_value (float | None) – Maximum absolute gradient value for element-wise gradient clipping. If None, no gradient value clipping is applied.
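The two clipping options can be illustrated with a minimal NumPy sketch. It assumes that grad_clip_norm rescales all gradients by their joint (global) L2 norm and that grad_clip_value clamps each element to [-value, value]; the helper names clip_by_global_norm and clip_by_value are hypothetical and are not part of braintools, whose internal implementation may differ.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    # Scale is 1.0 when the gradients are already within the limit.
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads]

def clip_by_value(grads, max_value):
    """Clamp every gradient element to the range [-max_value, max_value]."""
    return [np.clip(g, -max_value, max_value) for g in grads]

# Two gradient arrays whose joint L2 norm is sqrt(9 + 16 + 144) = 13.
grads = [np.array([3.0, 4.0]), np.array([12.0])]

norm_clipped = clip_by_global_norm(grads, max_norm=1.0)   # joint norm becomes ~1
value_clipped = clip_by_value(grads, max_value=5.0)       # only 12.0 is clamped
```

Norm clipping preserves the direction of the overall gradient, while value clipping changes only the elements that exceed the threshold.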

Notes

The Lion update rules are:

\[ \begin{align}\begin{aligned}C_t = \beta_1 M_{t-1} + (1 - \beta_1) G_t\\\theta_{t+1} = \theta_t - \alpha \cdot \text{sign}(C_t)\\M_t = \beta_2 M_{t-1} + (1 - \beta_2) G_t\end{aligned}\end{align} \]

where:

  • \(G_t\) is the gradient at step t

  • \(M_t\) is the momentum (exponential moving average of gradients)

  • \(C_t\) is the interpolation between momentum and current gradient

  • \(\text{sign}(\cdot)\) is the element-wise sign function

  • \(\alpha\) is the learning rate
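As a rough illustration, the three update rules above can be written as a single NumPy function. This is a sketch of the equations as stated here, not the braintools implementation; the function name lion_step is hypothetical, and weight decay and gradient clipping are deliberately left out.

```python
import numpy as np

def lion_step(theta, m, grad, lr=1e-4, beta1=0.9, beta2=0.99):
    """Apply one Lion update to parameters theta with momentum m."""
    c = beta1 * m + (1.0 - beta1) * grad        # C_t: interpolation used for the update
    theta_new = theta - lr * np.sign(c)         # theta_{t+1}: sign-based parameter step
    m_new = beta2 * m + (1.0 - beta2) * grad    # M_t: momentum tracked with beta2
    return theta_new, m_new

theta = np.array([1.0, -2.0, 0.5])
m = np.zeros_like(theta)
grad = np.array([0.3, -40.0, 0.0])   # very different gradient magnitudes

theta_new, m_new = lion_step(theta, m, grad, lr=1e-4)
# The first two parameters each move by exactly lr, despite the
# ~100x difference in gradient magnitude; the third has a zero
# update direction and stays where it is.
```

Note that the update uses the old momentum \(M_{t-1}\) with beta1, and the momentum is refreshed afterwards with beta2, so the two coefficients play different roles.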

Key differences from Adam:

  • Sign-based updates: Uses the sign of the interpolated momentum rather than a magnitude-scaled gradient step

  • Simpler computation: No square root or division operations

  • Less memory: Only stores momentum (not second moment)

  • Different hyperparameters: Smaller lr, larger weight decay

  • Implicit adaptive learning: Sign operation provides adaptation

Key advantages of Lion:

  • Memory efficient: Only 1 state per parameter (vs 2 for Adam)

  • Computationally simple: No expensive operations (sqrt, division)

  • Strong regularization: Sign operation provides implicit regularization

  • Better generalization: Often achieves lower validation loss than Adam

  • Robust: Works well across different architectures and tasks
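The memory claim can be made concrete with a small back-of-the-envelope calculation. The sketch below assumes float32 optimizer state (4 bytes per value) and a 1000 → 500 linear layer with a bias vector; the helper optimizer_state_bytes is hypothetical and only counts per-parameter state buffers, not the parameters or gradients themselves.

```python
def optimizer_state_bytes(num_params, states_per_param, bytes_per_value=4):
    """Bytes needed for optimizer state, assuming one buffer per state."""
    return num_params * states_per_param * bytes_per_value

# Weights plus bias of a 1000 -> 500 linear layer (assumed to have a bias).
num_params = 1000 * 500 + 500

lion_bytes = optimizer_state_bytes(num_params, states_per_param=1)  # momentum only
adam_bytes = optimizer_state_bytes(num_params, states_per_param=2)  # first + second moment

print(f"Lion state: {lion_bytes / 1e6:.3f} MB")
print(f"Adam state: {adam_bytes / 1e6:.3f} MB")
```

The ratio is always 1:2 under these assumptions, so the absolute saving grows linearly with model size.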

Lion is particularly well-suited for:

  • Training large language models (LLMs) and vision transformers

  • Memory-constrained environments

  • Tasks requiring strong generalization

  • Replacing Adam/AdamW with better efficiency

Hyperparameter recommendations (relative to Adam):

  • Learning rate: Use 3-10× smaller (e.g., Adam lr=1e-3 → Lion lr=1e-4)

  • Weight decay: Use 3-10× larger (e.g., Adam wd=0.01 → Lion wd=0.1)

  • Batch size: The same batch size as Adam can be used
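The recommendations above amount to simple arithmetic on an existing Adam configuration. The following sketch uses a hypothetical helper, adam_to_lion_hparams, which is not part of braintools; it applies one common factor in the suggested 3-10× range and should be read as a starting point for tuning rather than a fixed rule.

```python
def adam_to_lion_hparams(adam_lr, adam_weight_decay, factor=10.0):
    """Derive starting Lion hyperparameters from Adam ones.

    The learning rate is divided by `factor` and the weight decay is
    multiplied by it, following the 3-10x rule of thumb.
    """
    if not 3.0 <= factor <= 10.0:
        raise ValueError("factor is expected to lie in the range [3, 10]")
    return {
        "lr": adam_lr / factor,
        "weight_decay": adam_weight_decay * factor,
    }

# Adam lr=1e-3, wd=0.01  ->  Lion lr=1e-4, wd=0.1 (with factor=10)
lion_hparams = adam_to_lion_hparams(adam_lr=1e-3, adam_weight_decay=0.01)
```

Because the same factor is applied in opposite directions, the product of learning rate and weight decay is unchanged by the conversion.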

Examples

Basic usage with small learning rate:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize Lion with small lr (3-10x smaller than Adam)
>>> optimizer = braintools.optim.Lion(lr=1e-4)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With larger weight decay (recommended for Lion):

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Lion typically uses larger weight decay than Adam
>>> optimizer = braintools.optim.Lion(lr=1e-4, weight_decay=0.1)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With custom betas:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Custom interpolation coefficients
>>> optimizer = braintools.optim.Lion(lr=1e-4, betas=(0.95, 0.98))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With learning rate scheduler:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Step decay schedule
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1e-4,
...     step_size=100,
...     gamma=0.5
... )
>>> optimizer = braintools.optim.Lion(lr=scheduler, weight_decay=0.1)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
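To see what the schedule in this example does to the learning rate, here is a small sketch of step decay. It assumes the conventional rule lr = base_lr * gamma ** (step // step_size); the function step_decay_lr is hypothetical, and how braintools' StepLR counts steps internally may differ.

```python
def step_decay_lr(base_lr, step_size, gamma, step):
    """Learning rate under conventional step decay (assumed semantics)."""
    return base_lr * gamma ** (step // step_size)

# Same settings as the example: base_lr=1e-4, step_size=100, gamma=0.5
lrs = {
    step: step_decay_lr(base_lr=1e-4, step_size=100, gamma=0.5, step=step)
    for step in (0, 99, 100, 250)
}
for step, lr in lrs.items():
    print(f"step {step:>3}: lr = {lr:.2e}")
```

Under this assumption the learning rate stays at 1e-4 for the first 100 steps and halves at every subsequent multiple of 100.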

With gradient clipping for stable training:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Clip gradients for stability
>>> optimizer = braintools.optim.Lion(
...     lr=1e-4,
...     weight_decay=0.1,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Complete configuration for large model training:

>>> import brainstate
>>> import braintools
>>>
>>> # A single large linear layer, standing in for a large model
>>> model = brainstate.nn.Linear(1000, 500)
>>>
>>> # Learning rate decay schedule
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1e-4,
...     step_size=100,
...     gamma=0.9
... )
>>>
>>> # Complete Lion configuration
>>> optimizer = braintools.optim.Lion(
...     lr=scheduler,
...     betas=(0.9, 0.99),
...     weight_decay=0.1,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

See also

Adam

Standard adaptive moment estimation

AdamW

Adam with decoupled weight decay

SGD

Stochastic gradient descent with momentum

default_tx()

Create default gradient transformation with clipping and weight decay.