Adagrad
- class braintools.optim.Adagrad(lr=0.01, lr_decay=0.0, weight_decay=0.0, initial_accumulator_value=0.0, eps=1e-10, grad_clip_norm=None, grad_clip_value=None)
Adagrad optimizer with adaptive learning rates.
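Adagrad adapts the learning rate of each parameter using that parameter's gradient history: the base learning rate is divided by the square root of the accumulated sum of squared gradients. Parameters that have received large or frequent gradients therefore take smaller steps, while parameters with small or infrequent gradients take comparatively larger ones. This makes it well suited to sparse data, where some features appear only rarely.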
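To make the per-parameter behaviour concrete, the plain-Python sketch below (not the braintools API) tracks the effective step size lr / (sqrt(G) + eps) for two scalar parameters: one that receives a gradient at every step and one that receives a gradient only every fifth step. The helper name `effective_step_sizes` is hypothetical, and the sketch assumes initial_accumulator_value=0.0.

```python
import math


def effective_step_sizes(grad_history, lr=0.01, eps=1e-10):
    """Return lr / (sqrt(G) + eps) for one scalar parameter after each step."""
    accumulator = 0.0  # G_0, assuming initial_accumulator_value=0.0
    sizes = []
    for g in grad_history:
        accumulator += g * g  # add this step's squared gradient
        sizes.append(lr / (math.sqrt(accumulator) + eps))
    return sizes


# A "dense" parameter gets a gradient every step;
# a "sparse" one only gets a non-zero gradient every fifth step.
dense_grads = [1.0] * 10
sparse_grads = [1.0 if t % 5 == 0 else 0.0 for t in range(10)]

dense_final = effective_step_sizes(dense_grads)[-1]
sparse_final = effective_step_sizes(sparse_grads)[-1]
print(f"dense:  {dense_final:.5f}")
print(f"sparse: {sparse_final:.5f}")
```

After ten steps the dense parameter's step size is 0.01/√10 ≈ 0.0032, while the sparse parameter's is 0.01/√2 ≈ 0.0071, about √5 ≈ 2.2 times larger.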
- Parameters:
  - lr (float | LRScheduler) – Learning rate. Can be a float or an LRScheduler instance.
  - lr_decay (float) – Learning rate decay applied over each update.
  - weight_decay (float) – Weight decay (L2 penalty) coefficient.
  - initial_accumulator_value (float) – Initial value of the gradient accumulator.
  - eps (float) – Term added to the denominator to improve numerical stability.
  - grad_clip_norm (float | None) – Maximum gradient norm for clipping.
  - grad_clip_value (float | None) – Maximum gradient value for clipping.
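The two clipping arguments are usually understood as follows; the sketch below illustrates them in plain Python under several assumptions that this page does not confirm: None disables the corresponding clipping, grad_clip_norm bounds the L2 norm taken over all gradient components together, grad_clip_value clamps each component to [-value, value], and norm clipping runs before value clipping. The helpers `clip_by_norm`, `clip_by_value`, and `preprocess` are hypothetical, and braintools may order or scope these steps differently.

```python
import math


def clip_by_value(grads, max_value):
    """Clamp every gradient component to [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


def clip_by_norm(grads, max_norm):
    """Rescale the gradient vector so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def preprocess(grads, grad_clip_norm=None, grad_clip_value=None):
    # Assumption: None disables a step, and norm clipping precedes value clipping.
    if grad_clip_norm is not None:
        grads = clip_by_norm(grads, grad_clip_norm)
    if grad_clip_value is not None:
        grads = clip_by_value(grads, grad_clip_value)
    return grads


grads = [3.0, -4.0]  # L2 norm is exactly 5.0
print(preprocess(grads, grad_clip_norm=1.0))
print(preprocess(grads, grad_clip_value=2.0))
```

Norm clipping preserves the gradient's direction (here it becomes approximately [0.6, -0.8]), whereas value clipping acts on each component independently and can change the direction ([2.0, -2.0]).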
Notes
The Adagrad update is computed as:
\[
\begin{aligned}
G_t &= G_{t-1} + g_t^2 \\
\theta_t &= \theta_{t-1} - \frac{\alpha}{\sqrt{G_t} + \epsilon}\, g_t
\end{aligned}
\]
where \(G_t\) accumulates squared gradients (starting from \(G_0\) = initial_accumulator_value), \(g_t\) is the gradient, \(\alpha\) is the learning rate, and \(\epsilon\) is a small constant for numerical stability. All operations are element-wise.
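A direct transcription of these two equations for a list of scalar parameters might look like the following. It is an illustrative plain-Python sketch, not braintools' implementation; it ignores lr_decay, weight_decay, and clipping, and the function name `adagrad_step` is hypothetical.

```python
import math


def adagrad_step(params, grads, accumulators, lr=0.01, eps=1e-10):
    """One element-wise Adagrad step.

    Returns (new_params, new_accumulators); the input lists are not mutated.
    """
    new_params, new_accs = [], []
    for theta, g, acc in zip(params, grads, accumulators):
        acc = acc + g * g  # G_t = G_{t-1} + g_t^2
        theta = theta - lr / (math.sqrt(acc) + eps) * g  # theta_t update
        new_params.append(theta)
        new_accs.append(acc)
    return new_params, new_accs


params = [1.0, -2.0]
accumulators = [0.0, 0.0]  # G_0 = initial_accumulator_value = 0.0
grads = [0.5, -0.1]
params, accumulators = adagrad_step(params, grads, accumulators, lr=0.1)
print(params, accumulators)
```

Because \(G_0 = 0\) here, \(\sqrt{G_1} = |g_1|\), so the first step moves each component by almost exactly lr against the sign of its gradient, regardless of the gradient's magnitude: the parameters end up near 0.9 and -1.9.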
Adagrad’s main weakness is that \(G_t\) is a running sum of non-negative terms and therefore never decreases. The effective learning rate \(\alpha / (\sqrt{G_t} + \epsilon)\) can consequently only shrink over time, and it may become so small that training effectively stalls.
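A short numeric illustration of this decay, assuming a single scalar parameter that receives a constant gradient of 1.0, with lr=0.01, eps=1e-10, and initial_accumulator_value=0.0:

```python
import math

lr, eps = 0.01, 1e-10
accumulator = 0.0
effective_lr = []
for t in range(1, 10_001):
    g = 1.0  # constant gradient, chosen only for illustration
    accumulator += g * g  # with g = 1, G_t equals t after t steps
    effective_lr.append(lr / (math.sqrt(accumulator) + eps))

for t in (1, 100, 10_000):
    print(f"step {t:>6}: {effective_lr[t - 1]:.6f}")
```

With a constant unit gradient \(G_t = t\), so the effective learning rate falls as lr/√t: 0.01 at step 1, 0.001 at step 100, and 0.0001 at step 10,000.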
Examples
Basic Adagrad usage:
>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.Adagrad(lr=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Adagrad with custom epsilon for stability:
>>> optimizer = braintools.optim.Adagrad(lr=0.01, eps=1e-8)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Adagrad with weight decay:
>>> optimizer = braintools.optim.Adagrad(lr=0.01, weight_decay=0.01)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
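This page does not spell out how weight_decay and lr_decay enter the update. As an assumption, the sketch below follows the convention of PyTorch's torch.optim.Adagrad, which adds weight_decay * θ to the gradient before it is accumulated and scales the learning rate by 1 / (1 + (step - 1) * lr_decay); braintools may implement either argument differently. The function name `adagrad_step_torch_style` is hypothetical.

```python
import math


def adagrad_step_torch_style(theta, g, acc, step, lr=0.01, lr_decay=0.0,
                             weight_decay=0.0, eps=1e-10):
    """One scalar Adagrad step in the style of torch.optim.Adagrad.

    `step` is 1-based. Mirroring PyTorch here is an assumption, not a
    description of braintools' internals.
    """
    g = g + weight_decay * theta  # L2 penalty folded into the gradient
    clr = lr / (1 + (step - 1) * lr_decay)  # decayed learning rate
    acc = acc + g * g  # accumulate the (penalized) squared gradient
    theta = theta - clr * g / (math.sqrt(acc) + eps)
    return theta, acc


theta, acc = 1.0, 0.0
for step in range(1, 4):
    # Zero data gradient: any movement comes from weight decay alone.
    theta, acc = adagrad_step_torch_style(theta, 0.0, acc, step,
                                          lr=0.1, weight_decay=0.01)
print(theta)
```

With a zero data gradient, weight decay alone pulls θ toward zero (from 1.0 to roughly 0.78 after three steps with lr=0.1 and weight_decay=0.01); with weight_decay=0.0 the same loop leaves θ at exactly 1.0.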