SM3

class braintools.optim.SM3(lr=1.0, momentum=0.9, eps=1e-08, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)

SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) optimizer.

SM3 is a memory-efficient adaptive optimizer designed for training models with large embedding tables and sparse gradients. It saves memory by factorizing the second-moment statistics: instead of one accumulator value per parameter, it stores one value per index along each dimension of the parameter tensor (for a matrix, one per row and one per column).

SM3 is particularly effective for:

  • Models with large embedding layers (e.g., recommendation systems, NLP)

  • Sparse gradient scenarios (word embeddings, sparse features)

  • Memory-constrained training environments

  • Models where most parameters are in embedding tables

The key insight is that for parameters with sparse gradients, the second moment of each entry can be approximated from the small per-dimension accumulators that cover it, so the full per-parameter estimate never has to be stored.

Parameters:
  • lr (float | LRScheduler) – Learning rate, given as a float or an LRScheduler instance. A float is automatically converted to a ConstantLR scheduler. Note: SM3 typically uses larger base learning rates than Adam (e.g., 1.0).

  • momentum (float) – Momentum coefficient for the first moment. When > 0, maintains an exponential moving average of gradients. Set to 0 to disable momentum.

  • eps (float) – Term added to the denominator for numerical stability. Prevents division by zero when second moment estimates are very small.

  • weight_decay (float) – Weight decay coefficient (L2 penalty). When greater than 0, applies L2 regularization to the parameters.

  • grad_clip_norm (float | None) – Maximum gradient norm for gradient clipping. If None, no gradient norm clipping is applied.

  • grad_clip_value (float | None) – Maximum absolute gradient value for element-wise gradient clipping. If None, no gradient value clipping is applied.
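The two clipping options bound different quantities. The sketch below illustrates the conventional definitions with NumPy. It is not braintools' implementation: the helper names `clip_by_norm` and `clip_by_value` are hypothetical, and it assumes a single gradient array, whereas a library would typically compute the norm across all parameters.

```python
import numpy as np

def clip_by_norm(grad, max_norm):
    """Rescale grad so its L2 norm is at most max_norm (direction preserved)."""
    norm = np.linalg.norm(grad)
    scale = min(1.0, max_norm / max(norm, 1e-12))
    return grad * scale

def clip_by_value(grad, max_value):
    """Clamp every element of grad into [-max_value, max_value]."""
    return np.clip(grad, -max_value, max_value)

grad = np.array([3.0, -4.0])          # L2 norm is 5
by_norm = clip_by_norm(grad, 1.0)     # whole vector rescaled to unit norm
by_value = clip_by_value(grad, 3.5)   # only the -4.0 entry is clamped
```

Norm clipping keeps the gradient direction and shrinks its length, while value clipping can change the direction because each element is clamped independently.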

Notes

The SM3 update rules are:

For a parameter tensor \(\theta\) of shape \((d_1, d_2, ..., d_k)\):

\[
\begin{aligned}
V_t^{(i)} &= \max\left(V_{t-1}^{(i)}, G_t^2\right) \quad \text{for each dimension } i \\
v_t &= \sqrt{\min_i V_t^{(i)} + \epsilon} \\
M_t &= \rho M_{t-1} + (1 - \rho) G_t \quad \text{(if momentum > 0)} \\
\theta_{t+1} &= \theta_t - \alpha \frac{M_t}{v_t}
\end{aligned}
\]

where:

  • \(G_t\) is the gradient at step t

  • \(V_t^{(i)}\) is the second moment accumulator for dimension i

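  • \(v_t\) is the per-parameter scale: the square root of the smallest dimension accumulator covering that parameter, with \(\epsilon\) added under the root

  • \(M_t\) is the momentum buffer (maintained only when momentum > 0)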

  • \(\rho\) is the momentum coefficient

  • \(\alpha\) is the learning rate
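The update rules can be sketched for a 2-D parameter with NumPy. This is a minimal illustration under stated assumptions, not the braintools implementation: the function name `sm3_step` is hypothetical, and the sketch assumes that \(G_t^2\) is reduced with a max over the other axis so that each accumulator \(V_t^{(i)}\) has one entry per index along dimension i (the formulas above leave this reduction implicit).

```python
import numpy as np

def sm3_step(theta, grad, row_acc, col_acc, mom, lr=1.0, rho=0.9, eps=1e-8):
    """One SM3-style update for a 2-D parameter of shape (n, m)."""
    g2 = grad ** 2
    # V^(i): running max of squared gradients, reduced over the other axis
    row_acc = np.maximum(row_acc, g2.max(axis=1))  # shape (n,)
    col_acc = np.maximum(col_acc, g2.max(axis=0))  # shape (m,)
    # v = sqrt(min_i V^(i) + eps), broadcast back to shape (n, m)
    v = np.sqrt(np.minimum(row_acc[:, None], col_acc[None, :]) + eps)
    # Momentum: exponential moving average of gradients (rho = 0 gives mom == grad)
    mom = rho * mom + (1.0 - rho) * grad
    theta = theta - lr * mom / v
    return theta, row_acc, col_acc, mom

theta = np.zeros((2, 3))
grad = np.array([[1.0, 0.0, 0.0],
                 [0.0, 2.0, 0.0]])
theta, row_acc, col_acc, mom = sm3_step(
    theta, grad, np.zeros(2), np.zeros(3), np.zeros((2, 3))
)
```

Only `row_acc` and `col_acc` carry second-moment state between steps; the full-size scale `v` is rebuilt from them on every call.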

Memory comparison for parameter shape (n, m):

  • Adam: Stores 2nm values (first + second moment)

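  • SM3: Stores n + m accumulator values (one per row and one per column); when momentum > 0, the momentum buffer \(M_t\) adds a further nm values

  • Savings: For large embeddings (e.g., 100k × 512) with momentum disabled, SM3 keeps 100,512 values against Adam's 102.4M, a reduction of about 99.9%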
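The counts for a given shape can be checked with a few lines of arithmetic. The helper name `state_sizes` is hypothetical, and the sketch assumes that SM3 keeps a full-size momentum buffer whenever momentum > 0, as the \(M_t\) update implies.

```python
def state_sizes(n, m, momentum=True):
    """Count optimizer state values for a parameter of shape (n, m)."""
    adam = 2 * n * m                      # first + second moment
    sm3_accumulators = n + m              # one value per row and per column
    sm3_total = sm3_accumulators + (n * m if momentum else 0)
    return adam, sm3_accumulators, sm3_total

adam, sm3_acc, sm3_with_momentum = state_sizes(100_000, 512)
accumulator_saving = 1 - sm3_acc / adam   # fraction saved with momentum disabled
total_saving = 1 - sm3_with_momentum / adam
```

For this shape the accumulators alone are a tiny fraction of Adam's state, while enabling momentum brings the total back to roughly half of Adam's, so `momentum=0.0` is the setting to use when memory is the binding constraint.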

Key advantages of SM3:

  • Extreme memory efficiency: O(sum of dimensions) vs O(product of dimensions)

  • Sparse gradient friendly: Designed for sparse updates

  • Adaptive learning rates: Maintains per-parameter adaptation

  • Simple and stable: No complex hyperparameter tuning needed

  • Embedding-optimized: Ideal for large embedding layers

SM3 is particularly well-suited for:

  • Training models with large vocabulary embeddings (NLP, RecSys)

  • Sparse gradient scenarios (word2vec, matrix factorization)

  • Memory-constrained environments (edge devices, limited GPU memory)

  • Recommendation systems with large item/user embeddings

Comparison with other optimizers:

  • vs Adam: Much less memory, competitive performance on sparse tasks

  • vs Adagrad: Far smaller accumulators (Adagrad keeps one value per parameter), better performance with momentum

  • vs SGD: Adaptive rates help with sparse features

  • vs Adafactor: Different factorization, better for embeddings

Examples

Basic usage with default settings:

>>> import brainstate
>>> import braintools
>>>
>>> # Create a small linear model (stands in for a model with an embedding layer)
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize SM3 with default lr=1.0
>>> optimizer = braintools.optim.SM3(lr=1.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With custom learning rate:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Use smaller learning rate
>>> optimizer = braintools.optim.SM3(lr=0.1)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With momentum for better convergence:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Higher momentum for smoother updates
>>> optimizer = braintools.optim.SM3(lr=1.0, momentum=0.95)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Without momentum (pure adaptive):

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Disable momentum
>>> optimizer = braintools.optim.SM3(lr=1.0, momentum=0.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With learning rate scheduler:

>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Step decay schedule: scale the learning rate by gamma every step_size steps
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1.0,
...     step_size=100,
...     gamma=0.9
... )
>>> optimizer = braintools.optim.SM3(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
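Step schedules conventionally multiply the learning rate by `gamma` once every `step_size` steps. Assuming braintools' StepLR follows that convention (not confirmed here), the schedule configured above would behave like this plain-Python sketch; `step_lr` is a hypothetical helper, not part of the library.

```python
def step_lr(step, base_lr=1.0, step_size=100, gamma=0.9):
    """Learning rate after `step` updates under a step-decay schedule."""
    return base_lr * gamma ** (step // step_size)

lr_start = step_lr(0)      # no decay applied yet
lr_99 = step_lr(99)        # still inside the first interval
lr_100 = step_lr(100)      # first decay
lr_250 = step_lr(250)      # two decays applied
```

The rate is piecewise constant: it stays at `base_lr` for the first 100 steps, then drops by a factor of 0.9 at each multiple of 100.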

Complete configuration for large embedding model:

>>> import brainstate
>>> import braintools
>>>
>>> # Stand-in for a large embedding model (a real table might be 100k vocabulary x 512 dimensions)
>>> model = brainstate.nn.Linear(1000, 500)
>>>
>>> # Learning rate schedule for long training
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1.0,
...     step_size=1000,
...     gamma=0.95
... )
>>>
>>> # Complete SM3 configuration
>>> optimizer = braintools.optim.SM3(
...     lr=scheduler,
...     momentum=0.9,
...     eps=1e-8,
...     weight_decay=0.0001
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

See also

  • Adam: Standard adaptive moment estimation

  • Adafactor: Another memory-efficient optimizer

  • Adagrad: Adaptive learning rates for sparse features

default_tx()

Create default gradient transformation with clipping and weight decay.