Lamb


class braintools.optim.Lamb(lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

LAMB optimizer (Layer-wise Adaptive Moments optimizer for Batch training).

LAMB is designed for large-batch training. For each layer, it rescales the Adam-style update by the ratio of the layer's weight norm to the norm of that update (the trust ratio). This lets it train with very large batch sizes while maintaining performance comparable to small-batch training.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance. Note: LAMB can often use higher learning rates than Adam due to its layer-wise normalization.

  • betas (Tuple[float, float]) – Coefficients for computing running averages of the gradient and its square. The first value (beta1) controls the exponential decay rate for the first moment, and the second value (beta2) controls the decay rate for the second moment.

  • eps (float) – Term added to the denominator for numerical stability.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient. LAMB applies weight decay adaptively based on the trust ratio.

  • grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.

  • grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
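The two clipping options bound different things. The NumPy sketch below illustrates the behavior described above and is not the library's code: global-norm clipping rescales all gradients by one common factor when their combined L2 norm exceeds the threshold, while value clipping clips each element independently. The helper names `clip_by_global_norm` and `clip_by_value` are hypothetical.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients by a common factor so that their global L2 norm
    (computed over every array together) does not exceed max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    if global_norm <= max_norm:
        return dict(grads)
    scale = max_norm / global_norm
    return {name: g * scale for name, g in grads.items()}

def clip_by_value(grads, max_value):
    """Clip each gradient element to the interval [-max_value, max_value]."""
    return {name: np.clip(g, -max_value, max_value) for name, g in grads.items()}

grads = {"w": np.array([3.0, 0.0]), "b": np.array([4.0])}  # global norm = 5
by_norm = clip_by_global_norm(grads, max_norm=1.0)   # direction kept, global norm becomes 1
by_value = clip_by_value(grads, max_value=1.0)       # each element clipped on its own
print(by_norm)
print(by_value)
```

Global-norm clipping preserves the direction of the full gradient, whereas value clipping can change it (here the ratio between `w` and `b` changes from 3:4 to 1:1).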

Notes

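LAMB extends Adam with layer-wise adaptation of the learning rate. The key innovation is the trust ratio mechanism, which normalizes each layer's update using the norm of that layer's parameters and the norm of its Adam-style update direction r_t. With g_t denoting a layer's gradient at step t, the Adam quantities are: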

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
r_t &= \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
\]
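These equations can be traced for a single parameter tensor with a short NumPy sketch. It illustrates the formulas only and is not braintools' implementation; the function name `adam_direction` is hypothetical, and `t` is the 1-based step count.

```python
import numpy as np

def adam_direction(g, m, v, t, beta1=0.9, beta2=0.999, eps=1e-6):
    """One step of the Adam moment updates and the bias-corrected ratio r_t."""
    m = beta1 * m + (1 - beta1) * g          # first moment m_t
    v = beta2 * v + (1 - beta2) * g ** 2     # second moment v_t
    m_hat = m / (1 - beta1 ** t)             # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)             # bias-corrected second moment
    r = m_hat / (np.sqrt(v_hat) + eps)       # update direction r_t
    return r, m, v

g = np.array([0.5, -2.0, 0.0])
r, m, v = adam_direction(g, np.zeros_like(g), np.zeros_like(g), t=1)
# At t = 1, m_hat equals g and v_hat equals g**2, so r is close to sign(g)
print(r)
```

The bias correction matters most in the first steps: without it, m_t and v_t would start biased toward zero because they are initialized at zero.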

The trust ratio is computed as:

\[
\text{trust\_ratio} =
\begin{cases}
\dfrac{\|w_t\|}{\|r_t\|} & \text{if } \|w_t\| > 0 \text{ and } \|r_t\| > 0 \\
1 & \text{otherwise}
\end{cases}
\]
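The case split guards against division by zero, for example when a layer's weights are all zero or its update vanishes. A minimal NumPy sketch of the formula follows; `trust_ratio` is a hypothetical helper name, and `w` and `r` stand for one layer's parameter tensor and its direction r_t.

```python
import numpy as np

def trust_ratio(w, r):
    """||w|| / ||r|| for one layer, falling back to 1 when either norm is zero."""
    w_norm = np.linalg.norm(w)
    r_norm = np.linalg.norm(r)
    if w_norm > 0 and r_norm > 0:
        return w_norm / r_norm
    return 1.0

print(trust_ratio(np.array([3.0, 4.0]), np.array([0.0, 0.5])))  # 5 / 0.5 = 10.0
print(trust_ratio(np.zeros(3), np.ones(3)))                       # zero weights, falls back to 1.0
```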

Final update, where α is the learning rate:

\[w_{t+1} = w_t - \alpha \cdot \text{trust\_ratio} \cdot r_t\]
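Putting the pieces together, the sketch below applies one LAMB step to two layers of very different scale, computing the trust ratio separately for each. It follows the formulas in these notes (no weight decay) and uses NumPy with a hypothetical `lamb_step` helper; it is an illustration, not braintools' implementation. Because trust_ratio · r_t has the same norm as w_t, each layer moves by a fraction α of its own weight norm whenever both norms are nonzero.

```python
import numpy as np

def lamb_step(params, grads, state, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-6):
    """Apply one LAMB update to a dict of layers; state holds (m, v) per layer."""
    new_params, new_state = {}, {}
    for name, w in params.items():
        g = grads[name]
        m, v = state[name]
        # Adam moments and bias-corrected direction r_t
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        r = m_hat / (np.sqrt(v_hat) + eps)
        # Layer-wise trust ratio with the zero-norm fallback
        w_norm, r_norm = np.linalg.norm(w), np.linalg.norm(r)
        ratio = w_norm / r_norm if (w_norm > 0 and r_norm > 0) else 1.0
        new_params[name] = w - lr * ratio * r
        new_state[name] = (m, v)
    return new_params, new_state

rng = np.random.default_rng(0)
params = {"layer1": rng.normal(size=(4, 3)), "layer2": 100.0 * rng.normal(size=(3,))}
grads = {name: rng.normal(size=w.shape) for name, w in params.items()}
state = {name: (np.zeros_like(w), np.zeros_like(w)) for name, w in params.items()}

new_params, state = lamb_step(params, grads, state, t=1, lr=0.01)
for name in params:
    step = np.linalg.norm(new_params[name] - params[name])
    # Relative step size is lr for every layer, whatever the layer's scale
    print(name, step / np.linalg.norm(params[name]))
```

Both layers print a relative step of about 0.01 even though `layer2` is roughly 100 times larger, which is the layer-wise normalization that makes large learning rates tolerable.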

LAMB is particularly effective for:

  • Training with batch sizes of 32K or larger

  • BERT and other transformer models

  • Distributed training across multiple GPUs/TPUs

  • Achieving linear scaling of learning rate with batch size


Examples

Basic usage with float learning rate:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize LAMB optimizer for large batch training
>>> optimizer = braintools.optim.Lamb(lr=0.002)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Using a larger learning rate for big-batch training:

>>> # The trust ratio lets LAMB tolerate larger learning rates
>>> optimizer = braintools.optim.Lamb(lr=0.01, betas=(0.9, 0.999))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With learning rate scheduler for warmup and decay:

>>> # Polynomial decay with warmup (common for BERT training)
>>> scheduler = braintools.optim.PolynomialLR(
...     base_lr=0.002,
...     warmup_steps=1000,
...     total_steps=10000,
...     power=1.0
... )
>>> optimizer = braintools.optim.Lamb(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With weight decay for regularization:

>>> # LAMB applies weight decay adaptively
>>> optimizer = braintools.optim.Lamb(
...     lr=0.002,
...     weight_decay=0.01
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
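In the original LAMB paper, the decay term is added to the Adam direction before the trust ratio is taken, so the decay is rescaled layer by layer along with the rest of the update. Whether `braintools.optim.Lamb` uses exactly this ordering is an assumption here; the NumPy sketch below (hypothetical `lamb_update_with_decay` helper) shows only the paper's variant for a single layer.

```python
import numpy as np

def lamb_update_with_decay(w, r, lr=0.002, weight_decay=0.01):
    """Single-layer LAMB update with decoupled weight decay (paper's variant).

    `r` is the bias-corrected Adam direction m_hat / (sqrt(v_hat) + eps).
    """
    u = r + weight_decay * w                      # decay added before the trust ratio
    w_norm, u_norm = np.linalg.norm(w), np.linalg.norm(u)
    ratio = w_norm / u_norm if (w_norm > 0 and u_norm > 0) else 1.0
    return w - lr * ratio * u

w = np.array([1.0, -2.0, 2.0])   # ||w|| = 3
r = np.array([0.5, 0.5, -0.5])
w_new = lamb_update_with_decay(w, r)
# The step norm is still lr * ||w||; the decay term changes the direction, not the size
print(np.linalg.norm(w_new - w))  # approximately 0.006
```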

Using gradient clipping for stability:

>>> # Clip gradients for training stability
>>> optimizer = braintools.optim.Lamb(
...     lr=0.002,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Complete training example with a large batch:

>>> import jax.numpy as jnp
>>>
>>> # Setup for large batch training
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Linear(784, 256),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(256, 10)
... )
>>>
>>> # LAMB with settings for large batch
>>> optimizer = braintools.optim.Lamb(
...     lr=0.01,  # Higher lr due to normalization
...     betas=(0.9, 0.999),
...     weight_decay=0.01
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Simulate one large-batch training step
>>> def train_step(batch_x, batch_y):
...     def loss_fn():
...         logits = model(batch_x)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, batch_y)
...         )
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     loss = loss_fn()  # loss before the parameter update
...     optimizer.update(grads)
...     return loss
>>>
>>> # Batch of 1024 examples, all labeled class 0 (one-hot targets)
>>> x = jnp.ones((1024, 784))
>>> y = jnp.zeros((1024, 10)).at[:, 0].set(1.0)
>>> loss = train_step(x, y)

See also

Adam

Standard Adam optimizer without layer-wise adaptation

Lars

Layer-wise Adaptive Rate Scaling

AdamW

Adam with decoupled weight decay

default_tx()

Create default gradient transformation with clipping and weight decay.