SGD
- class braintools.optim.SGD(lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, grad_clip_norm=None, grad_clip_value=None)
Stochastic Gradient Descent (SGD) optimizer with momentum and weight decay.
Implements the standard SGD algorithm with optional momentum, Nesterov momentum, and weight decay (L2 regularization).
- Parameters:
  - lr (float | LRScheduler) – Learning rate. Can be a float or an LRScheduler instance.
  - momentum (float) – Momentum factor. Set to 0 for vanilla SGD.
  - weight_decay (float) – Weight decay (L2 penalty) coefficient.
  - nesterov (bool) – Whether to use Nesterov momentum.
  - grad_clip_norm (float | None) – Maximum gradient norm for clipping.
  - grad_clip_value (float | None) – Maximum gradient value for clipping.
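The parameter list does not spell out how clipping and weight decay act on the gradients. Under the common conventions, which are an assumption here and not confirmed by this page, grad_clip_value clamps each gradient entry into [-grad_clip_value, grad_clip_value], grad_clip_norm rescales the whole gradient when its global L2 norm exceeds the limit, and weight_decay adds weight_decay * theta to the gradient. This is a framework-free sketch. The helper names are hypothetical, and the order in which braintools applies these steps is not specified here.

```python
import math


def clip_by_value(grads, max_value):
    # Clamp each entry into [-max_value, max_value].
    return [max(-max_value, min(max_value, g)) for g in grads]


def clip_by_global_norm(grads, max_norm):
    # Rescale all entries together if the global L2 norm exceeds max_norm;
    # the direction of the gradient vector is preserved.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * max_norm / norm for g in grads]


def add_weight_decay(grads, params, weight_decay):
    # The gradient of the L2 penalty (weight_decay / 2) * ||theta||^2
    # is weight_decay * theta, which is added to the loss gradient.
    return [g + weight_decay * p for g, p in zip(grads, params)]


# Usage on flat lists of scalar gradients (hypothetical values).
print(clip_by_value([3.0, -5.0, 0.5], 1.0))              # [1.0, -1.0, 0.5]
print(clip_by_global_norm([3.0, 4.0], 1.0))              # norm 5 -> approx [0.6, 0.8]
print(add_weight_decay([1.0, 2.0], [10.0, -10.0], 0.1))  # [2.0, 1.0]
```

Clipping by value can change the gradient direction. Clipping by global norm only shrinks its length.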
Notes
The SGD update with momentum is computed as:
\[
\begin{aligned}
v_{t+1} &= \mu v_t + g_t \\
\theta_{t+1} &= \theta_t - \alpha v_{t+1}
\end{aligned}
\]
where \(v_t\) is the momentum buffer (velocity), \(\mu\) is the momentum factor, \(g_t\) is the gradient at step \(t\), \(\alpha\) is the learning rate, and \(\theta\) are the parameters.
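The momentum update can be sketched without any framework. This is a minimal illustration using plain arithmetic, so the same function also works elementwise on NumPy or JAX arrays. The name sgd_momentum_step is hypothetical and not part of braintools, and weight decay and clipping are omitted. The usage runs the update on f(theta) = theta**2 / 2, whose gradient is theta itself.

```python
def sgd_momentum_step(theta, v, g, lr, mu):
    """One classical momentum step (hypothetical helper)."""
    v_new = mu * v + g              # v_{t+1} = mu * v_t + g_t
    theta_new = theta - lr * v_new  # theta_{t+1} = theta_t - alpha * v_{t+1}
    return theta_new, v_new


# Usage: minimize f(theta) = theta**2 / 2, whose gradient is g = theta.
theta, v = 1.0, 0.0
for _ in range(100):
    theta, v = sgd_momentum_step(theta, v, g=theta, lr=0.1, mu=0.9)
print(theta)  # decays toward 0 with a damped oscillation
```

With mu = 0 the buffer is just the current gradient, and the step reduces to vanilla SGD.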
With Nesterov momentum, the update becomes:
\[
\begin{aligned}
v_{t+1} &= \mu v_t + g_t \\
\theta_{t+1} &= \theta_t - \alpha (\mu v_{t+1} + g_t)
\end{aligned}
\]
Examples
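The Nesterov rule from the Notes can be illustrated in plain Python before the library examples. The helper sgd_nesterov_step is hypothetical and not part of braintools. The sketch implements exactly the two Nesterov equations given above, so the parameter step uses \(\mu v_{t+1} + g_t\) instead of \(v_{t+1}\).

```python
def sgd_nesterov_step(theta, v, g, lr, mu):
    """One Nesterov momentum step in the form given in the Notes (hypothetical helper)."""
    v_new = mu * v + g                         # v_{t+1} = mu * v_t + g_t
    theta_new = theta - lr * (mu * v_new + g)  # look-ahead: mu * v_{t+1} + g_t
    return theta_new, v_new


# From theta=1, v=0, g=1: classical momentum would move theta to 0.9,
# while the Nesterov form also applies mu * v_{t+1}, giving 1 - 0.1 * 1.9.
theta, v = sgd_nesterov_step(1.0, 0.0, 1.0, lr=0.1, mu=0.9)
print(theta, v)  # approx 0.81 1.0

# With mu = 0 the rule reduces to plain SGD: theta - lr * g.
print(sgd_nesterov_step(1.0, 0.0, 1.0, lr=0.1, mu=0.0))  # approx (0.9, 1.0)
```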
Basic SGD without momentum:
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.SGD(lr=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
SGD with momentum:
>>> optimizer = braintools.optim.SGD(lr=0.01, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
SGD with Nesterov momentum:
>>> optimizer = braintools.optim.SGD(lr=0.01, momentum=0.9, nesterov=True)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
SGD with learning rate scheduling:
>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training code here: compute `grads`, the gradients of the loss
...     # with respect to the registered weights.
...     optimizer.step(grads)
...     # Advance the schedule once per epoch; with step_size=30 the
...     # learning rate decays by gamma every 30 epochs.
...     scheduler.step()
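This sketch assumes that braintools.optim.StepLR follows the same convention as PyTorch's StepLR, which this page does not confirm. Under that assumption, the learning rate the optimizer sees in the scheduling loop after epoch e is base_lr * gamma ** (e // step_size). The function step_lr is hypothetical and only illustrates that closed form with the example's settings.

```python
def step_lr(epoch, base_lr=0.1, step_size=30, gamma=0.1):
    # Learning rate after `epoch` calls to scheduler.step(), assuming the
    # PyTorch-style StepLR rule: multiply by gamma every step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 29, 30, 60, 99):
    print(epoch, step_lr(epoch))
# 0 and 29 give 0.1; 30 gives ~0.01; 60 gives ~0.001; 99 gives ~1e-4
```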
See also