Lars
- class braintools.optim.Lars(lr=1.0, momentum=0.9, weight_decay=0.0, trust_coefficient=0.001, eps=1e-08, grad_clip_norm=None, grad_clip_value=None)
LARS optimizer (Layer-wise Adaptive Rate Scaling).
LARS adapts the learning rate of each layer using the ratio between the norm of that layer's weights and the norm of its gradient. Because this bounds how much each layer's weights can change relative to their size in a single step, LARS enables stable training with very large batch sizes, where a single global learning rate can make some layers change too rapidly.
- Parameters:
  - lr (float | LRScheduler) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance. Note: LARS typically uses higher base learning rates than SGD because of its layer-wise scaling.
  - momentum (float) – Momentum factor for the moving average of gradients. Higher values smooth the gradient updates more.
  - weight_decay (float) – Weight decay (L2 penalty) coefficient. Applied before the trust ratio computation.
  - trust_coefficient (float) – Trust coefficient (eta) that multiplies the trust ratio. The per-layer step size scales linearly with it, so it controls how strongly the layer-wise adaptation shapes the update.
  - eps (float) – Term added to denominators for numerical stability.
  - grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are rescaled when their global norm exceeds this value.
  - grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
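The two clipping options can be sketched in plain JAX as below. This illustrates the semantics described above and is not braintools' internal code; the helpers `global_norm`, `clip_by_global_norm`, and `clip_by_value` are hypothetical, and the order in which braintools applies the two clips when both are set is not stated here.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    """L2 norm over all leaves of a pytree, treated as one flat vector."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together only if their global norm exceeds max_norm."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def clip_by_value(grads, max_value):
    """Clip every gradient entry to [-max_value, max_value]."""
    return jax.tree_util.tree_map(lambda g: jnp.clip(g, -max_value, max_value), grads)


grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([12.0])}
print(float(global_norm(grads)))                            # global norm is 13
print(float(global_norm(clip_by_global_norm(grads, 1.0))))  # rescaled to about 1
print(clip_by_value(grads, 5.0))                            # "b" is clipped to 5, "w" is unchanged
```

Norm clipping preserves the direction of the full gradient, while value clipping can change it by saturating individual entries.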
Notes
LARS computes a local learning rate for each layer from its trust ratio. Weight decay is first folded into the gradient:
\[g_t = \nabla L(w_t) + \lambda w_t\]
where \(\lambda\) is the weight decay coefficient.
The local learning rate is computed as:
\[\text{local\_lr} = \eta \times \frac{\|w_t\|}{\|g_t\| + \epsilon}\]
The momentum update is then:
\[\begin{aligned} v_t &= \mu v_{t-1} + \text{local\_lr} \times g_t \\ w_{t+1} &= w_t - \alpha \times v_t \end{aligned}\]
where \(\eta\) is the trust coefficient, \(\epsilon\) is eps, \(\mu\) is the momentum, and \(\alpha\) is the global learning rate.
Key properties of LARS:
- Supports large-batch training in which the learning rate is scaled linearly with the batch size
- Particularly effective for training ResNets and other CNNs
- Gives each layer its own effective learning rate
- Prevents the weights of any single layer from changing too rapidly
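To illustrate how layers end up with different effective learning rates, the sketch below evaluates \(\eta \|w\| / (\|g\| + \epsilon)\) for two made-up layers in plain JAX. It assumes each parameter array is one layer; `local_learning_rates`, `_l2`, and the layer names are hypothetical.

```python
import jax
import jax.numpy as jnp


def _l2(x):
    """L2 (Frobenius) norm of an array of any rank."""
    return jnp.sqrt(jnp.sum(jnp.square(x)))


def local_learning_rates(params, grads, trust_coefficient=0.001, eps=1e-8):
    """eta * ||w|| / (||g|| + eps) for every array in a parameter pytree."""
    return jax.tree_util.tree_map(
        lambda w, g: trust_coefficient * _l2(w) / (_l2(g) + eps), params, grads
    )


# Hypothetical layers: "conv" has large weights and tiny gradients,
# "fc" has small weights and large gradients.
params = {"conv": jnp.full((3, 3), 2.0), "fc": jnp.full((4,), 0.5)}
grads = {"conv": jnp.full((3, 3), 0.01), "fc": jnp.full((4,), 1.0)}

rates = local_learning_rates(params, grads)
print({k: float(v) for k, v in rates.items()})  # conv ≈ 0.2, fc ≈ 0.0005
```

Here the two layers share one global learning rate, yet their local rates differ by a factor of 400.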
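The trust ratio mechanism has the following effects (ignoring momentum and \(\epsilon\), a layer's step has norm \(\alpha \eta \|w_t\|\)):
- Layers with larger weights receive proportionally larger updates
- Layers with larger gradients do not receive larger updates, because the gradient norm is divided out
- The relative change in weights per step, \(\|\Delta w\| / \|w_t\|\), stays close to \(\alpha \eta\)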
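A quick numerical check of these effects, using the first LARS step (zero initial velocity, no weight decay) in plain JAX. `lars_delta` and `step_norm` are hypothetical helpers written for this illustration, not braintools functions.

```python
import jax.numpy as jnp


def l2(x):
    """L2 norm of an array."""
    return jnp.sqrt(jnp.sum(jnp.square(x)))


def lars_delta(w, g, lr=1.0, trust_coefficient=0.001, eps=1e-8):
    """First-step LARS change: -alpha * eta * ||w|| / (||g|| + eps) * g."""
    local_lr = trust_coefficient * l2(w) / (l2(g) + eps)
    return -lr * local_lr * g


def step_norm(w, g):
    """Norm of the first-step LARS update as a Python float."""
    return float(l2(lars_delta(w, g)))


w = jnp.array([1.0, -2.0, 2.0])  # ||w|| = 3
g = jnp.array([0.3, 0.0, 0.4])   # ||g|| = 0.5

print(step_norm(w, g))           # ≈ lr * eta * ||w|| = 0.003
print(step_norm(w, 100.0 * g))   # gradient 100x larger: step stays ≈ 0.003
print(step_norm(10.0 * w, g))    # weights 10x larger: step ≈ 0.03
```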
Examples
Basic usage with momentum:
>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize LARS optimizer
>>> optimizer = braintools.optim.Lars(lr=0.1, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With custom trust coefficient for fine-tuning:
>>> # Smaller trust coefficient for more conservative updates
>>> optimizer = braintools.optim.Lars(
...     lr=0.1,
...     momentum=0.9,
...     trust_coefficient=0.0001
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Using a learning rate scheduler:
>>> # Cosine annealing schedule
>>> scheduler = braintools.optim.CosineAnnealingLR(
...     base_lr=0.1,
...     T_max=100
... )
>>> optimizer = braintools.optim.Lars(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
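As a rough picture of what the scheduler contributes, the sketch below assumes CosineAnnealingLR follows the common closed form (as in PyTorch's scheduler of the same name) with a minimum learning rate `eta_min` of 0. Whether braintools uses exactly this formula and default is an assumption, and `cosine_annealing_lr` is a hypothetical helper.

```python
import math


def cosine_annealing_lr(step, base_lr=0.1, T_max=100, eta_min=0.0):
    """Assumed closed form: base_lr at step 0, decaying to eta_min at step T_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / T_max)) / 2


for step in (0, 50, 100):
    print(step, cosine_annealing_lr(step))  # ≈ 0.1, 0.05, 0.0
```

The schedule only sets the global rate \(\alpha\), which multiplies every layer's step equally; the trust ratio still determines how the step is distributed across layers.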
With weight decay for regularization:
>>> # LARS with weight decay
>>> optimizer = braintools.optim.Lars(
...     lr=0.1,
...     momentum=0.9,
...     weight_decay=5e-4
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Large batch training configuration:
>>> # Configuration for large batch training
>>> # Linear scaling rule: lr = base_lr * (batch_size / base_batch)
>>> batch_size = 4096
>>> base_batch = 256
>>> base_lr = 0.1
>>> scaled_lr = base_lr * (batch_size / base_batch)  # 1.6
>>>
>>> optimizer = braintools.optim.Lars(
...     lr=scaled_lr,
...     momentum=0.9,
...     weight_decay=5e-4,
...     trust_coefficient=0.001
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Complete training example with a CNN:
>>> import jax.numpy as jnp
>>>
>>> # Setup CNN model
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Conv2d(3, 64, kernel_size=3, padding=1),
...     brainstate.nn.ReLU(),
...     brainstate.nn.MaxPool2d(2),
...     brainstate.nn.Conv2d(64, 128, kernel_size=3, padding=1),
...     brainstate.nn.ReLU(),
...     brainstate.nn.AdaptiveAvgPool2d((1, 1)),
...     brainstate.nn.Flatten(),
...     brainstate.nn.Linear(128, 10)
... )
>>>
>>> # LARS for large batch CNN training
>>> optimizer = braintools.optim.Lars(
...     lr=1.6,  # Large lr for batch size 4096
...     momentum=0.9,
...     weight_decay=1e-4,
...     trust_coefficient=0.001
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> def train_step(images, labels):
...     def loss_fn():
...         logits = model(images)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, labels)
...         )
...
...     # Gradients with respect to the trainable parameters
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()  # loss evaluated after the update
>>>
>>> # Dummy batch for illustration (the lr above assumes a batch size of 4096)
>>> x = jnp.ones((256, 3, 32, 32))
>>> y = jnp.zeros((256, 10))
>>> # loss = train_step(x, y)