SM3
- class braintools.optim.SM3(lr=1.0, momentum=0.9, eps=1e-08, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) optimizer.
SM3 is a memory-efficient adaptive optimizer designed for training models with large embedding tables and sparse gradients. Instead of storing one second-moment value per parameter, it keeps one accumulator per index along each axis of the parameter tensor (for an n × m matrix, n row accumulators and m column accumulators) and derives each parameter's estimate from the accumulators that cover it.
SM3 is particularly effective for:
Models with large embedding layers (e.g., recommendation systems, NLP)
Sparse gradient scenarios (word embeddings, sparse features)
Memory-constrained training environments
Models where most parameters are in embedding tables
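The key insight is that each parameter's second-moment estimate can be taken as the minimum over the accumulators that cover it (for a matrix, its row and its column accumulators). This estimate never falls below Adagrad's per-parameter sum of squared gradients, and it tends to stay close to that sum when gradients are sparse or structured along the tensor's axes, so a much smaller state suffices.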
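The upper-bound property can be checked numerically. The sketch below is a minimal NumPy illustration for a 2-D parameter, assuming the SM3 recursion in which each entry's estimate is the minimum of its row and column accumulators plus its current squared gradient, and each accumulator then keeps the maximum estimate over the entries it covers. `sm3_vs_adagrad` is a hypothetical helper, not part of braintools.

```python
import numpy as np


def sm3_vs_adagrad(grads):
    """Run SM3's second-moment recursion on a sequence of 2-D gradients.

    Returns SM3's per-entry estimate from the last step and Adagrad's exact
    per-entry sum of squared gradients, for comparison.
    """
    n, m = grads[0].shape
    row = np.zeros(n)           # n row accumulators
    col = np.zeros(m)           # m column accumulators
    adagrad = np.zeros((n, m))  # Adagrad keeps n * m values
    v = np.zeros((n, m))
    for g in grads:
        # Per-entry estimate: min of the covering row/column accumulators + g^2.
        v = np.minimum(row[:, None], col[None, :]) + g ** 2
        # Each accumulator stores the max estimate over the entries it covers.
        row = v.max(axis=1)
        col = v.max(axis=0)
        adagrad += g ** 2
    return v, adagrad


rng = np.random.default_rng(0)
# Sparse gradients: roughly 70% of entries are zero at each step.
grads = [rng.normal(size=(6, 4)) * (rng.random((6, 4)) < 0.3) for _ in range(20)]
v, adagrad = sm3_vs_adagrad(grads)
print(bool(np.all(v >= adagrad)))  # True: SM3 never underestimates Adagrad
```

Only n + m accumulator values are carried between steps here, while the Adagrad reference needs n · m.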
- Parameters:
lr (float | LRScheduler) – Learning rate. Can be a float or an LRScheduler instance. If a float is provided, it is automatically converted to a ConstantLR scheduler. Note: SM3 typically uses larger base learning rates than Adam (e.g., 1.0).
momentum (float) – Momentum coefficient for the first moment. When > 0, maintains an exponential moving average of gradients. Set to 0 to disable momentum.
eps (float) – Term added to the denominator for numerical stability. Prevents division by zero when second moment estimates are very small.
weight_decay (float) – Weight decay coefficient (L2 penalty). When greater than 0, applies L2 regularization to the parameters.
grad_clip_norm (float | None) – Maximum gradient norm for gradient clipping. If None, no gradient norm clipping is applied.
grad_clip_value (float | None) – Maximum absolute gradient value for element-wise gradient clipping. If None, no gradient value clipping is applied.
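Assuming `grad_clip_value` clamps each gradient entry, `grad_clip_norm` rescales by the L2 norm (the usual meanings, as in optax's `clip` and `clip_by_global_norm`), and `weight_decay` adds `weight_decay * θ` to the gradient, the three preprocessing options behave as sketched below. The helper names are hypothetical; the sketch works on a single array, whereas a library may compute the norm globally across all parameters, and the order in which braintools applies these steps is not specified here.

```python
import numpy as np


def clip_by_value(grad, max_value):
    # grad_clip_value-style clipping: clamp every entry to [-max_value, max_value].
    return np.clip(grad, -max_value, max_value)


def clip_by_norm(grad, max_norm):
    # grad_clip_norm-style clipping: rescale so the L2 norm is at most max_norm,
    # keeping the gradient's direction.
    norm = np.linalg.norm(grad)
    return grad if norm <= max_norm else grad * (max_norm / norm)


def add_l2_penalty(grad, theta, weight_decay):
    # weight_decay as an L2 penalty: the gradient of 0.5 * wd * ||theta||^2 is wd * theta.
    return grad + weight_decay * theta


g = np.array([3.0, -4.0])  # L2 norm is 5
print(clip_by_value(g, 1.0))  # [ 1. -1.]
print(clip_by_norm(g, 1.0))   # [ 0.6 -0.8]
```

Value clipping can change the gradient's direction (here from (3, -4) to (1, -1)), while norm clipping only shrinks its length.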
Notes
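The SM3 update rules are:
For a parameter tensor \(\theta\) of shape \((d_1, d_2, \ldots, d_k)\), SM3 keeps one accumulator vector \(V^{(i)} \in \mathbb{R}^{d_i}\) for each dimension \(i\), initialised to zero. For every entry \(j = (j_1, \ldots, j_k)\) of \(\theta\):
\[
\begin{aligned}
v_t[j] &= \min_i V_{t-1}^{(i)}[j_i] + G_t[j]^2 \\
V_t^{(i)}[a] &= \max_{j \,:\, j_i = a} v_t[j] \quad \text{for each dimension } i \\
M_t &= \rho M_{t-1} + (1 - \rho) \frac{G_t}{\sqrt{v_t + \epsilon}} \\
\theta_{t+1} &= \theta_t - \alpha M_t
\end{aligned}
\]
where:
\(G_t\) is the gradient at step t, and \(G_t[j]^2\) is the square of its entry j
\(V_t^{(i)}\) is the second moment accumulator for dimension i, holding one value per index along that axis
\(v_t\) is the effective second moment: for each entry, the minimum of the accumulators covering it plus the current squared gradient
\(M_t\) is the momentum buffer; with \(\rho = 0\) it reduces to the preconditioned gradient \(G_t / \sqrt{v_t + \epsilon}\)
\(\rho\) is the momentum coefficient
\(\epsilon\) is the numerical-stability term eps
\(\alpha\) is the learning rate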
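A minimal NumPy sketch of one SM3 step for a tensor of any rank follows. It assumes the SM3-II recursion (per-entry estimate = min of the covering accumulators + squared gradient; each accumulator = max of the estimates it covers) with momentum applied to the preconditioned gradient, which is the form optax's `scale_by_sm3` uses. `sm3_init` and `sm3_step` are hypothetical names; this is not braintools' internal code.

```python
import functools

import numpy as np


def sm3_init(shape):
    # One accumulator vector per axis (length d_i) plus a full-size momentum buffer.
    return [np.zeros(d) for d in shape], np.zeros(shape)


def sm3_step(theta, grad, accs, mom, lr=1.0, rho=0.9, eps=1e-8):
    ndim = grad.ndim
    # Reshape V^(i) to (1, ..., d_i, ..., 1) so it broadcasts along the other axes.
    expanded = [a.reshape([-1 if j == i else 1 for j in range(ndim)])
                for i, a in enumerate(accs)]
    # v_t = min_i V_{t-1}^(i) + G_t^2
    v = functools.reduce(np.minimum, expanded) + grad ** 2
    # M_t = rho * M_{t-1} + (1 - rho) * G_t / sqrt(v_t + eps)
    mom = rho * mom + (1.0 - rho) * grad / np.sqrt(v + eps)
    # theta_{t+1} = theta_t - alpha * M_t
    theta = theta - lr * mom
    # V_t^(i) = max of v_t over every axis except i
    new_accs = []
    for i in range(ndim):
        other = tuple(j for j in range(ndim) if j != i)
        new_accs.append(v.max(axis=other) if other else v)
    return theta, new_accs, mom


rng = np.random.default_rng(0)
theta = rng.normal(size=(100, 8))
accs, mom = sm3_init(theta.shape)
loss_before = float(np.sum(theta ** 2))
for _ in range(5):
    grad = 2.0 * theta  # gradient of sum(theta ** 2)
    theta, accs, mom = sm3_step(theta, grad, accs, mom, lr=0.1)
print([a.shape for a in accs], mom.shape)  # [(100,), (8,)] (100, 8)
print(float(np.sum(theta ** 2)) < loss_before)  # True
```

The accumulators keep 100 + 8 values for this 100 × 8 parameter, while the momentum buffer still has the full parameter shape.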
Memory comparison for parameter shape (n, m):
Adam: Stores 2nm values (first + second moment)
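SM3: Stores n + m values for the second moment (one per row and one per column), plus nm for the momentum buffer when momentum > 0
Savings: For a large embedding (e.g., 100k × 512), the second-moment state drops from 51.2M to about 100.5k values (~99.8% reduction); with momentum enabled, total optimizer state is about half of Adam's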
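The counts above can be reproduced with a few lines of Python. `state_sizes` is a hypothetical helper that counts stored values for one parameter tensor, assuming Adam keeps two full-size moments and SM3 keeps one accumulator per index along each axis plus an optional full-size momentum buffer.

```python
from math import prod


def state_sizes(shape, momentum=True):
    """Count stored optimizer-state values for one parameter tensor."""
    n_params = prod(shape)
    return {
        "params": n_params,
        "adam": 2 * n_params,                # full-size first + second moment
        "sm3_second_moment": sum(shape),     # one value per index along each axis
        "sm3_total": sum(shape) + (n_params if momentum else 0),  # + momentum buffer
    }


s = state_sizes((100_000, 512))
print(s["adam"], s["sm3_second_moment"], s["sm3_total"])  # 102400000 100512 51300512
print(f"{1 - s['sm3_second_moment'] / s['params']:.1%}")  # 99.8%
print(f"{1 - s['sm3_total'] / s['adam']:.1%}")             # 49.9%
```

The same helper shows the O(sum) vs O(product) gap for higher-rank tensors: a 64 × 64 × 64 parameter has 262,144 entries but only 192 SM3 accumulator values.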
Key advantages of SM3:
Extreme memory efficiency: O(sum of dimensions) second-moment state vs O(product of dimensions)
Sparse gradient friendly: Designed for sparse updates
Adaptive learning rates: Maintains per-parameter adaptation
Simple to configure: besides the learning rate, only momentum and eps; there is no second-moment decay rate to tune
Embedding-optimized: Ideal for large embedding layers
SM3 is particularly well-suited for:
Training models with large vocabulary embeddings (NLP, RecSys)
Sparse gradient scenarios (word2vec, matrix factorization)
Memory-constrained environments (edge devices, limited GPU memory)
Recommendation systems with large item/user embeddings
Comparison with other optimizers:
vs Adam: Much less memory, competitive performance on sparse tasks
vs Adagrad: Far less memory (n + m accumulator values instead of nm) while approximating Adagrad's per-parameter step sizes; also supports momentum
vs SGD: Adaptive rates help with sparse features
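vs Adafactor: Comparable O(n + m) second-moment memory, but a different approximation (min/max over row and column accumulators instead of Adafactor's rank-1 row/column factorization); better suited to embeddings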
Examples
Basic usage with default settings:
>>> import brainstate
>>> import braintools
>>>
>>> # Create a small model (a Linear layer stands in for an embedding table)
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize SM3 with the default lr=1.0
>>> optimizer = braintools.optim.SM3(lr=1.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With custom learning rate:
>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Use a smaller learning rate
>>> optimizer = braintools.optim.SM3(lr=0.1)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With a higher momentum coefficient:
>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Higher momentum for smoother updates
>>> optimizer = braintools.optim.SM3(lr=1.0, momentum=0.95)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Without momentum (pure adaptive):
>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Disable momentum
>>> optimizer = braintools.optim.SM3(lr=1.0, momentum=0.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With learning rate scheduler:
>>> import brainstate
>>> import braintools
>>>
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Step decay: lr is multiplied by gamma=0.9 every step_size=100 scheduler steps
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1.0,
...     step_size=100,
...     gamma=0.9
... )
>>> optimizer = braintools.optim.SM3(lr=scheduler, momentum=0.9)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Complete configuration for large embedding model:
>>> import brainstate
>>> import braintools
>>>
>>> # Stand-in for a large embedding model (a real one might be 100k vocabulary x 512 dimensions)
>>> model = brainstate.nn.Linear(1000, 500)
>>>
>>> # Learning rate schedule for long training
>>> scheduler = braintools.optim.StepLR(
...     base_lr=1.0,
...     step_size=1000,
...     gamma=0.95
... )
>>>
>>> # Complete SM3 configuration
>>> optimizer = braintools.optim.SM3(
...     lr=scheduler,
...     momentum=0.9,
...     eps=1e-8,
...     weight_decay=0.0001
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
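Both scheduler examples assume that StepLR multiplies the rate by `gamma` once every `step_size` scheduler steps, the rule PyTorch's `StepLR` uses. `step_lr` below is a hypothetical plain-Python sketch of that rule, not braintools' implementation, showing the rates the first schedule would produce.

```python
def step_lr(step, base_lr=1.0, step_size=100, gamma=0.9):
    # Multiply the base rate by gamma once for every completed step_size steps.
    return base_lr * gamma ** (step // step_size)


for step in (0, 99, 100, 250, 1000):
    print(step, round(step_lr(step), 4))
# 0 1.0
# 99 1.0
# 100 0.9
# 250 0.81
# 1000 0.3487
```

With the long-training configuration (`step_size=1000`, `gamma=0.95`), the rate would be about 0.9025 after 2,500 steps.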
See also