Adafactor
- class braintools.optim.Adafactor(lr=None, eps=(1e-30, 0.001), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, factored=True, grad_clip_norm=None, grad_clip_value=None)
Adafactor optimizer (memory-efficient variant of Adam).
Adafactor is designed to reduce memory usage when training large models by using factored second moment estimation. Instead of storing a full second moment matrix, it maintains only row and column statistics, which significantly reduces memory requirements, especially for models with large embedding tables.
- Parameters:
lr (float|LRScheduler|None) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance. If None, uses adaptive learning rate based on step count and RMS of parameters.
eps (Tuple[float,float]) – Regularization constants for squared gradient and parameter scale (eps[0], eps[1]). The first value prevents division by zero, the second clips the parameter scale.
clip_threshold (float) – Threshold for gradient clipping by root mean square. Helps prevent gradient explosions.
decay_rate (float) – Controls the decay of the second moment estimate. Negative values result in polynomial decay: decay = 1 - (step + 1)^decay_rate.
beta1 (float|None) – Momentum parameter for first moment. If None, no momentum is used.
weight_decay (float) – Weight decay (L2 penalty) coefficient.
factored (bool) – Whether to use factored second moment estimation. When True, significantly reduces memory usage. Set to False to use full second moment (more memory).
grad_clip_norm (float|None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.
grad_clip_value (float|None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
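The polynomial decay described for decay_rate can be tabulated with a few lines of plain Python. This is only a sketch of the stated formula: the helper name second_moment_decay and the 0-based step counter are assumptions for illustration, not part of the braintools API.

```python
def second_moment_decay(step: int, decay_rate: float = -0.8) -> float:
    """Decay factor from the formula above: 1 - (step + 1) ** decay_rate.

    `step` is assumed to start at 0 (hypothetical convention).
    """
    return 1.0 - (step + 1) ** decay_rate


# Print the decay factor at a few (assumed 0-based) steps.
for step in (0, 1, 9, 99, 999):
    print(step, round(second_moment_decay(step), 4))
```

With a negative decay_rate the factor starts at 0 on the first step and approaches 1 as training proceeds, so early estimates follow the current gradient closely while later estimates average over a longer history.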
Notes
Adafactor’s key innovation is factored second moment estimation. Instead of maintaining a full matrix \(V_t \in \mathbb{R}^{n \times m}\), it maintains row and column averages \(R_t \in \mathbb{R}^n\) and \(C_t \in \mathbb{R}^m\):
\[\begin{aligned}R_t &= \beta_2 R_{t-1} + (1 - \beta_2) \text{mean}(G_t^2, \text{axis}=1)\\C_t &= \beta_2 C_{t-1} + (1 - \beta_2) \text{mean}(G_t^2, \text{axis}=0)\end{aligned}\]
The second moment is approximated as:
\[V_t \approx R_t \otimes C_t / \text{mean}(R_t)\]
where \(\otimes\) denotes the outer product.
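The factored approximation can be checked numerically. The sketch below uses NumPy rather than the library's own internals and looks at a single step only (equivalent to \(\beta_2 = 0\)), so the variable names and setup are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 3
G = rng.normal(size=(n, m))   # stand-in gradient for an n x m weight matrix
G2 = G ** 2

# Single-step statistics (the running averages with beta2 = 0).
R = G2.mean(axis=1)           # row statistics, shape (n,)
C = G2.mean(axis=0)           # column statistics, shape (m,)

# Rank-1 reconstruction of the full second moment, shape (n, m).
V_approx = np.outer(R, C) / R.mean()

# The reconstruction preserves the row and column means of G2.
print(np.allclose(V_approx.mean(axis=1), R))
print(np.allclose(V_approx.mean(axis=0), C))
```

Only R and C need to be stored between steps; V_approx is rebuilt from them when the update is computed.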
The update rule with optional momentum is:
\[\begin{aligned}M_t &= \beta_1 M_{t-1} + (1 - \beta_1) G_t \quad \text{(if beta1 is not None)}\\\theta_{t+1} &= \theta_t - \alpha_t \frac{M_t}{\sqrt{V_t} + \epsilon}\end{aligned}\]
Key advantages of Adafactor:
Memory efficient: O(n+m) instead of O(n×m) for factored mode
Adaptive learning rate: Can work without explicit learning rate
Large models: Designed for transformer and large embedding models
Stable training: Built-in gradient clipping
Automatic scheduling: Polynomial decay of second moment
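To make the O(n+m) versus O(n×m) claim concrete, the following sketch counts second moment entries for a single weight matrix. The helper second_moment_entries is hypothetical, and the 50000 × 512 shape is borrowed from the embedding layer used in the Examples.

```python
def second_moment_entries(n: int, m: int, factored: bool) -> int:
    """Number of stored second moment values for an n x m parameter."""
    return n + m if factored else n * m


n, m = 50_000, 512
full = second_moment_entries(n, m, factored=False)      # 25,600,000 entries
factored = second_moment_entries(n, m, factored=True)   # 50,512 entries
print(full, factored, round(full / factored, 1))
```

For this shape the factored state is roughly 500 times smaller than the full matrix. The count covers the second moment only; enabling beta1 adds a first moment buffer of the same size as the parameter.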
Adafactor is particularly well-suited for:
Training very large transformer models (BERT, GPT, T5)
Models with large embedding tables
Situations with limited GPU memory
Long training runs where adaptive scheduling helps
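The equations in the Notes can be combined into one update step. The function below is a NumPy sketch under stated assumptions, not the braintools implementation: adafactor_step is a hypothetical name, a single scalar eps is used as in the update rule above, the gradient is used directly when beta1 is None, and clip_threshold and weight_decay are left out.

```python
import numpy as np


def adafactor_step(theta, G, R, C, M=None, *, lr=1e-3, beta2=0.999,
                   beta1=None, eps=1e-30):
    """One factored update for a 2-D parameter (illustrative sketch only).

    Assumes G is not identically zero, otherwise mean(R) would be zero.
    """
    G2 = G ** 2
    # Running row and column statistics of the squared gradient.
    R = beta2 * R + (1.0 - beta2) * G2.mean(axis=1)
    C = beta2 * C + (1.0 - beta2) * G2.mean(axis=0)
    # Rank-1 approximation of the full second moment.
    V = np.outer(R, C) / R.mean()
    # Optional first moment; without momentum the gradient is used as is.
    if beta1 is not None:
        if M is None:
            M = np.zeros_like(G)
        M = beta1 * M + (1.0 - beta1) * G
    else:
        M = G
    theta = theta - lr * M / (np.sqrt(V) + eps)
    return theta, R, C, M


rng = np.random.default_rng(0)
n, m = 4, 3
theta = rng.normal(size=(n, m))
G = rng.normal(size=(n, m))
R, C = np.zeros(n), np.zeros(m)
theta_new, R, C, M = adafactor_step(theta, G, R, C, lr=1e-2, beta2=0.0)
print(theta_new.shape, R.shape, C.shape)
```

The state carried between steps is R, C and, when beta1 is set, M. The full matrix V exists only transiently inside the step.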
Examples
Basic usage with automatic learning rate:
>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize Adafactor with auto learning rate
>>> optimizer = braintools.optim.Adafactor()
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With explicit learning rate:
>>> # Explicit learning rate
>>> optimizer = braintools.optim.Adafactor(lr=1e-3)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With momentum for better convergence:
>>> # Add momentum (first moment)
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     beta1=0.9
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Non-factored mode (exact second moment, at the cost of more memory):
>>> # Use full second moment (more memory)
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     factored=False
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
For large transformer training:
>>> import jax.numpy as jnp
>>>
>>> # Setup large transformer model
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Embedding(50000, 512),  # Large vocabulary
...     brainstate.nn.Linear(512, 512),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(512, 50000)
... )
>>>
>>> # Adafactor with factored mode for memory efficiency
>>> optimizer = braintools.optim.Adafactor(
...     lr=None,  # Adaptive learning rate
...     beta1=0.9,
...     factored=True,
...     clip_threshold=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> def train_step(tokens, targets):
...     def loss_fn():
...         logits = model(tokens)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, targets)
...         )
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()
>>>
>>> # Train with large batches
>>> tokens = jnp.ones((128, 512), dtype=jnp.int32)
>>> targets = jnp.zeros((128, 50000))
>>> # loss = train_step(tokens, targets)
With weight decay and custom parameters:
>>> # Complete configuration
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     eps=(1e-30, 1e-3),
...     clip_threshold=1.0,
...     decay_rate=-0.8,
...     beta1=0.9,
...     weight_decay=0.01,
...     factored=True
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
See also