Adafactor


class braintools.optim.Adafactor(lr=None, eps=(1e-30, 0.001), clip_threshold=1.0, decay_rate=-0.8, beta1=None, weight_decay=0.0, factored=True, grad_clip_norm=None, grad_clip_value=None)

Adafactor optimizer (memory-efficient variant of Adam).

Adafactor is designed to reduce memory usage when training large models by using a factored estimate of the second moment. Instead of storing a full second-moment matrix for each matrix-shaped parameter, it maintains row and column statistics, which significantly reduces memory requirements, especially for models with large embedding tables.

Parameters:
  • lr (float | LRScheduler | None) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance. If None, an adaptive learning rate is computed from the step count and the RMS of the parameters.

  • eps (Tuple[float, float]) – Regularization constants (eps[0], eps[1]). eps[0] is added to the squared gradient to prevent division by zero; eps[1] is a lower bound on the parameter scale (the RMS of the parameters).

  • clip_threshold (float) – Threshold for root-mean-square (RMS) clipping of the update. If the RMS of the normalized update exceeds this value, the update is rescaled so that its RMS equals the threshold, which helps prevent exploding updates.

  • decay_rate (float) – Exponent of the second-moment decay schedule: decay = 1 - (step + 1)^decay_rate. With a negative value (the default is -0.8), the decay coefficient starts at 0 and increases toward 1, so the estimate averages over a growing window of past gradients.

  • beta1 (float | None) – Decay rate of the first-moment (momentum) estimate. If None, no momentum is used.

  • weight_decay (float) – Weight decay (L2 penalty) coefficient.

  • factored (bool) – Whether to use factored second moment estimation. When True, significantly reduces memory usage. Set to False to use full second moment (more memory).

  • grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.

  • grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
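The decay_rate formula above can be tabulated directly to see how quickly the second-moment estimate starts to average over many steps. The sketch below also shows one plausible form of the lr=None step size, taken from the original Adafactor paper: a relative step \(\rho_t = \min(10^{-2}, 1/\sqrt{t})\) scaled by \(\max(\text{eps}[1], \mathrm{RMS}(\theta))\). That formula and the helper names `second_moment_decay` and `relative_step_size` are assumptions for illustration, not braintools' API, and braintools' actual schedule may differ.

```python
import math


def second_moment_decay(step: int, decay_rate: float = -0.8) -> float:
    """Decay coefficient 1 - (step + 1) ** decay_rate for a zero-based step."""
    return 1.0 - (step + 1) ** decay_rate


def relative_step_size(step: int, param_rms: float, eps2: float = 1e-3) -> float:
    """Hypothetical lr=None step size, following the Adafactor paper:
    rho_t = min(1e-2, 1 / sqrt(t)) and alpha_t = max(eps2, RMS(theta)) * rho_t.
    The schedule braintools actually uses may differ."""
    t = step + 1  # 1-based step count
    rho = min(1e-2, 1.0 / math.sqrt(t))
    return max(eps2, param_rms) * rho


# The decay coefficient starts at 0 and approaches 1; the step size is
# constant for the first 10,000 steps and then shrinks as 1 / sqrt(t).
for step in (0, 1, 9, 99, 39999):
    print(step, round(second_moment_decay(step), 4),
          relative_step_size(step, param_rms=0.05))
```

Because the decay coefficient is exactly 0 at the first step, the initial estimate is simply the current squared gradient, so no bias correction is needed.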

Notes

Adafactor’s key innovation is factored second-moment estimation. For a parameter matrix of shape \(n \times m\), instead of maintaining a full matrix \(V_t \in \mathbb{R}^{n \times m}\), it maintains running averages of the row and column means of the squared gradients, \(R_t \in \mathbb{R}^n\) and \(C_t \in \mathbb{R}^m\):

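\[ \begin{aligned}R_t &= \hat\beta_{2,t} R_{t-1} + (1 - \hat\beta_{2,t})\,\text{mean}(G_t^2 + \epsilon_1, \text{axis}=1)\\C_t &= \hat\beta_{2,t} C_{t-1} + (1 - \hat\beta_{2,t})\,\text{mean}(G_t^2 + \epsilon_1, \text{axis}=0)\end{aligned} \]

where \(\hat\beta_{2,t} = 1 - (t+1)^{\text{decay\_rate}}\) for a zero-based step counter \(t\), and \(\epsilon_1\) is eps[0].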

The second moment is approximated as:

\[V_t \approx R_t \otimes C_t / \text{mean}(R_t)\]

where \(\otimes\) denotes outer product.
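A minimal NumPy sketch of the factored estimate for a single step (that is, \(\hat\beta_{2,t} = 0\), so \(R_t\) and \(C_t\) are just the row and column means of \(G_t^2\); the \(\epsilon_1\) term is omitted). The helper names `factored_second_moment` and `reconstruct` are hypothetical. For a rank-1, non-negative matrix, the outer-product formula recovers the full matrix exactly while storing only \(n + m\) numbers instead of \(n \times m\):

```python
import numpy as np


def factored_second_moment(grad_sq: np.ndarray):
    """Row and column means of an (n, m) array of squared gradients."""
    r = grad_sq.mean(axis=1)  # R_t, shape (n,)
    c = grad_sq.mean(axis=0)  # C_t, shape (m,)
    return r, c


def reconstruct(r: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Approximate second moment: outer(R_t, C_t) / mean(R_t)."""
    return np.outer(r, c) / r.mean()


rng = np.random.default_rng(0)
# Build a rank-1, non-negative "squared gradient" matrix.
u = rng.uniform(0.5, 2.0, size=4)
v = rng.uniform(0.5, 2.0, size=3)
v_full = np.outer(u, v)

r, c = factored_second_moment(v_full)
print(np.allclose(reconstruct(r, c), v_full))  # True
print(v_full.size, r.size + c.size)            # 12 7
```

For general matrices the reconstruction is only approximate, but it always preserves the row and column means of \(G_t^2\), which is why dividing by \(\text{mean}(R_t)\) is the right normalization.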

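The update rule, with RMS clipping of the normalized update and optional momentum, is:

\[ \begin{aligned}U_t &= G_t / \sqrt{V_t}\\\hat{U}_t &= U_t / \max\!\left(1, \mathrm{RMS}(U_t)/d\right)\\M_t &= \beta_1 M_{t-1} + (1 - \beta_1)\,\hat{U}_t \quad \text{(if beta1 is not None; otherwise } M_t = \hat{U}_t\text{)}\\\theta_{t+1} &= \theta_t - \alpha_t M_t\end{aligned} \]

where division and square root are element-wise, \(d\) is clip_threshold, \(\mathrm{RMS}(U) = \sqrt{\text{mean}(U^2)}\), and \(\alpha_t\) is the step size (lr, or the adaptive step size when lr is None).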
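The NumPy sketch below performs factored updates on a single 2-D parameter, assuming the original Adafactor ordering in which the gradient is first normalized by \(\sqrt{V_t}\), then RMS-clipped with clip_threshold, and only then passed through momentum. It uses a fixed lr and omits weight decay, the lr=None schedule, and grad_clip_norm/grad_clip_value; `adafactor_step` and its `state` dictionary are hypothetical names, not part of braintools.

```python
import numpy as np


def rms(x: np.ndarray) -> float:
    """Root mean square of all entries."""
    return float(np.sqrt(np.mean(x ** 2)))


def adafactor_step(theta, grad, state, lr=1e-3, beta1=0.9, decay_rate=-0.8,
                   clip_threshold=1.0, eps1=1e-30):
    """One factored update for a 2-D parameter (hypothetical sketch).

    `state` holds the zero-based step count and the R, C, M buffers.
    """
    t = state["step"]
    beta2 = 1.0 - (t + 1) ** decay_rate          # increasing decay coefficient
    g2 = grad ** 2 + eps1                        # eps[0] keeps the estimate positive
    state["R"] = beta2 * state["R"] + (1 - beta2) * g2.mean(axis=1)
    state["C"] = beta2 * state["C"] + (1 - beta2) * g2.mean(axis=0)
    v = np.outer(state["R"], state["C"]) / state["R"].mean()
    u = grad / np.sqrt(v)                        # normalized update U_t
    u = u / max(1.0, rms(u) / clip_threshold)    # RMS update clipping
    if beta1 is not None:                        # optional momentum on the clipped update
        state["M"] = beta1 * state["M"] + (1 - beta1) * u
        u = state["M"]
    state["step"] = t + 1
    return theta - lr * u, state


rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 3))
state = {"step": 0, "R": np.zeros(4), "C": np.zeros(3), "M": np.zeros((4, 3))}
for _ in range(3):
    grad = 2.0 * theta  # gradient of sum(theta ** 2)
    theta, state = adafactor_step(theta, grad, state)
print(state["step"], theta.shape)  # 3 (4, 3)
```

Clipping the normalized update, rather than the raw gradient, bounds the size of each step relative to the second-moment estimate, which matters early in training when \(\hat\beta_{2,t}\) is small and \(V_t\) is noisy.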

Key advantages of Adafactor:

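  • Memory efficient: for an n×m parameter, the second-moment state is O(n+m) in factored mode instead of O(n×m)

  • Adaptive learning rate: Can train without an explicit learning rate (lr=None)

  • Large models: Designed for transformers and models with large embedding tables

  • Stable training: Built-in RMS clipping of updates (clip_threshold)

  • Automatic scheduling: The second-moment decay coefficient increases polynomially over training (decay_rate)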
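To make the O(n+m) versus O(n×m) claim concrete, the sketch below counts stored second-moment entries for the weight and bias shapes of the embedding + MLP model from the large-model example in the Examples section, assuming brainstate stores Linear weights as (in, out) with a separate bias. It also assumes that only parameters with two or more dimensions are factored (over their last two axes) while vectors keep a full estimate; braintools' actual rule for which parameters to factor may differ, and `second_moment_entries` is a hypothetical helper.

```python
from math import prod


def second_moment_entries(shapes, factored=True):
    """Count stored second-moment entries for a list of parameter shapes.

    Assumption: only parameters with at least two dimensions are factored
    (row + column statistics over the last two axes); vectors stay full.
    """
    total = 0
    for shape in shapes:
        if factored and len(shape) >= 2:
            *batch, n, m = shape
            total += prod(batch) * (n + m)
        else:
            total += prod(shape)
    return total


# Assumed shapes: Embedding(50000, 512), Linear(512, 512) + bias, Linear(512, 50000) + bias
shapes = [(50000, 512), (512, 512), (512,), (512, 50000), (50000,)]
full = second_moment_entries(shapes, factored=False)
fact = second_moment_entries(shapes, factored=True)
print(full, fact)  # 51512656 152560
```

Under these assumptions the factored state is about 153 thousand entries versus about 51.5 million for the full estimate, a reduction of more than 300×.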

Adafactor is particularly well-suited for:

  • Training very large transformer models (BERT, GPT, T5)

  • Models with large embedding tables

  • Situations with limited GPU memory

  • Long training runs where adaptive scheduling helps


Examples

Basic usage with automatic learning rate:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize Adafactor with auto learning rate
>>> optimizer = braintools.optim.Adafactor()
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With explicit learning rate:

>>> # Explicit learning rate
>>> optimizer = braintools.optim.Adafactor(lr=1e-3)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With momentum for better convergence:

>>> # Add momentum (first moment)
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     beta1=0.9
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Non-factored mode, which keeps the exact (unapproximated) second moment at the cost of more memory:

>>> # Use full second moment (more memory)
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     factored=False
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

For large-vocabulary models such as transformers:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Large-vocabulary model: embedding + MLP head (a stand-in for a transformer)
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Embedding(50000, 512),  # Large vocabulary
...     brainstate.nn.Linear(512, 512),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(512, 50000)
... )
>>>
>>> # Adafactor with factored mode for memory efficiency
>>> optimizer = braintools.optim.Adafactor(
...     lr=None,  # Adaptive learning rate
...     beta1=0.9,
...     factored=True,
...     clip_threshold=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> def train_step(tokens, targets):
...     def loss_fn():
...         logits = model(tokens)  # shape (batch, seq_len, vocab)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, targets)
...         )
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()
>>>
>>> # Token IDs have shape (batch, seq_len); one-hot targets must match the logits' shape
>>> tokens = jnp.ones((4, 64), dtype=jnp.int32)
>>> targets = jax.nn.one_hot(jnp.zeros((4, 64), dtype=jnp.int32), 50000)
>>> # loss = train_step(tokens, targets)

With weight decay and custom parameters:

>>> # Complete configuration
>>> optimizer = braintools.optim.Adafactor(
...     lr=1e-3,
...     eps=(1e-30, 1e-3),
...     clip_threshold=1.0,
...     decay_rate=-0.8,
...     beta1=0.9,
...     weight_decay=0.01,
...     factored=True
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

See also

  • Adam – Standard adaptive moment estimation

  • AdamW – Adam with decoupled weight decay

  • SM3 – Another memory-efficient adaptive optimizer

default_tx()

Create the default gradient transformation, including gradient clipping and weight decay.