SGD

class braintools.optim.SGD(lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, grad_clip_norm=None, grad_clip_value=None)

Stochastic Gradient Descent (SGD) optimizer with momentum and weight decay.

Implements the standard SGD algorithm with optional momentum, Nesterov momentum, weight decay (L2 regularization), and gradient clipping by norm or by value.

Parameters:
  • lr (float | LRScheduler) – Learning rate. Can be a float or LRScheduler instance.

  • momentum (float) – Momentum factor. Set to 0 for vanilla SGD.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • nesterov (bool) – Whether to use Nesterov momentum. Only has an effect when momentum > 0.

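  • grad_clip_norm (float | None) – Maximum gradient norm for clipping. If None (the default), no norm clipping is applied.

  • grad_clip_value (float | None) – Maximum absolute gradient value for clipping. If None (the default), no value clipping is applied.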
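The clipping and weight-decay parameters correspond to standard gradient preprocessing steps. The plain-Python sketch below shows what each one typically does on a flat list of floats. It rests on assumptions this page does not state: norm clipping uses the global L2 norm over all gradient entries, value clipping is element-wise to [-grad_clip_value, grad_clip_value], and weight decay is the coupled L2 form that adds weight_decay * theta to the gradient. The order in which braintools applies these steps is not specified here, so the functions are shown separately, and the helper names are hypothetical.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit: leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads]


def clip_by_value(grads: List[float], max_value: float) -> List[float]:
    """Clamp every gradient entry to the interval [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


def add_weight_decay(grads: List[float], params: List[float],
                     weight_decay: float) -> List[float]:
    """Coupled L2 penalty (assumed form): g + weight_decay * theta per entry."""
    return [g + weight_decay * p for g, p in zip(grads, params)]


print(clip_by_global_norm([3.0, 4.0], max_norm=1.0))  # norm 5 is rescaled to 1
print(clip_by_value([-2.0, 0.5, 3.0], max_value=1.0))
print(add_weight_decay([0.1, -0.2], [1.0, 2.0], weight_decay=0.01))
```

Norm clipping preserves the direction of the gradient vector and only shrinks its length, whereas value clipping acts on each entry independently and can change the direction.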

Notes

The SGD update with momentum is computed as:

\[
\begin{aligned}
v_{t+1} &= \mu v_t + g_t \\
\theta_{t+1} &= \theta_t - \alpha v_{t+1}
\end{aligned}
\]

where \(\mu\) is the momentum factor, \(v_t\) is the velocity (momentum buffer), \(g_t\) is the gradient at step \(t\), \(\alpha\) is the learning rate, and \(\theta\) are the parameters.
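The momentum update can be sketched in plain Python as an illustrative re-implementation of the two equations above on flat lists of floats. It is not braintools' internal code; the function name sgd_momentum_step and the quadratic test objective are hypothetical. With mu=0 the step reduces to vanilla SGD, theta - lr * g.

```python
from typing import List, Tuple


def sgd_momentum_step(theta: List[float], v: List[float], g: List[float],
                      lr: float, mu: float) -> Tuple[List[float], List[float]]:
    """One heavy-ball step: v <- mu * v + g, then theta <- theta - lr * v."""
    v_new = [mu * vi + gi for vi, gi in zip(v, g)]
    theta_new = [ti - lr * vi for ti, vi in zip(theta, v_new)]
    return theta_new, v_new


def grad_quadratic(theta: List[float]) -> List[float]:
    # The gradient of f(theta) = 0.5 * sum(theta_i ** 2) is theta itself.
    return list(theta)


theta, v = [1.0], [0.0]  # start at theta = 1 with an empty velocity buffer
for _ in range(2):
    theta, v = sgd_momentum_step(theta, v, grad_quadratic(theta), lr=0.1, mu=0.9)
print(theta, v)
```

By hand: the first step gives v = 1.0 and theta = 0.9; the second gives v = 0.9 * 1.0 + 0.9 = 1.8 and theta = 0.9 - 0.1 * 1.8 = 0.72.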

With Nesterov momentum, the update becomes:

\[
\begin{aligned}
v_{t+1} &= \mu v_t + g_t \\
\theta_{t+1} &= \theta_t - \alpha (\mu v_{t+1} + g_t)
\end{aligned}
\]
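A plain-Python sketch of the Nesterov variant, again illustrative only and using a hypothetical helper name on flat lists of floats. The velocity update is identical to plain momentum; only the parameter step changes, using mu * v_{t+1} + g_t in place of v_{t+1}.

```python
from typing import List, Tuple


def sgd_nesterov_step(theta: List[float], v: List[float], g: List[float],
                      lr: float, mu: float) -> Tuple[List[float], List[float]]:
    """One Nesterov step: v <- mu * v + g, then theta <- theta - lr * (mu * v + g)."""
    v_new = [mu * vi + gi for vi, gi in zip(v, g)]
    # Look-ahead: the step uses the updated velocity plus the current gradient.
    theta_new = [ti - lr * (mu * vn + gi) for ti, vn, gi in zip(theta, v_new, g)]
    return theta_new, v_new


def grad_quadratic(theta: List[float]) -> List[float]:
    # The gradient of f(theta) = 0.5 * sum(theta_i ** 2) is theta itself.
    return list(theta)


theta, v = [1.0], [0.0]
for _ in range(2):
    theta, v = sgd_nesterov_step(theta, v, grad_quadratic(theta), lr=0.1, mu=0.9)
print(theta, v)
```

On this objective (lr=0.1, mu=0.9, starting at theta=1.0) the first step moves theta to 0.81, whereas the plain momentum equations would give 0.9, because the look-ahead term adds mu * v_1 = 0.9 to the gradient. With mu=0 the Nesterov step also reduces to vanilla SGD.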

Examples

Basic SGD without momentum:

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.SGD(lr=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

SGD with momentum:

>>> optimizer = braintools.optim.SGD(lr=0.01, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

SGD with Nesterov momentum:

>>> optimizer = braintools.optim.SGD(lr=0.01, momentum=0.9, nesterov=True)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

SGD with learning rate scheduling:

>>> scheduler = braintools.optim.StepLR(base_lr=0.1, step_size=30, gamma=0.1)
>>> optimizer = braintools.optim.SGD(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> for epoch in range(100):
...     # Training code here: compute `grads` for the registered weights
...     optimizer.step(grads)
...     scheduler.step()  # advance the learning-rate schedule once per epoch
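StepLR(base_lr=0.1, step_size=30, gamma=0.1) is assumed here to follow the common step-decay convention (as in PyTorch's StepLR): the learning rate is multiplied by gamma once every step_size scheduler steps. Under that assumption, the rate after a given number of scheduler steps can be computed as below; step_lr is a hypothetical helper, not part of braintools.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Step decay: multiply base_lr by gamma once every step_size scheduler steps."""
    # Integer division counts how many full decay periods have elapsed.
    return base_lr * gamma ** (epoch // step_size)


# Learning rate around each decay boundary for the example schedule.
for epoch in (0, 29, 30, 60, 90, 99):
    print(epoch, step_lr(0.1, 30, 0.1, epoch))
```

If the scheduler is stepped once per epoch over 100 epochs, this convention gives 0.1 for epochs 0 to 29, 0.01 for 30 to 59, 0.001 for 60 to 89, and 0.0001 for 90 to 99.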

See also

Adam

Adam optimizer with adaptive learning rates

RMSprop

RMSprop optimizer

Momentum

Pure momentum optimizer

default_tx()

Create the SGD-specific gradient transformation.