RAdam
- class braintools.optim.RAdam(lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)
RAdam optimizer (Rectified Adam).
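RAdam addresses the poor convergence Adam can show early in training, when the adaptive learning rate has a large variance because only a few squared gradients have been averaged into the second-moment estimate. It rectifies this variance, which acts as a dynamic warmup schedule that adapts automatically to the number of steps taken.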
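To make the dynamic warmup concrete, the sketch below evaluates the rectification factor \(r_t\) defined in the Notes section for the default beta2 = 0.999 at a few step counts. It is a standalone plain-Python illustration, assuming the threshold \(\rho_t > 4\) stated in the Notes; the function name `rectification` is hypothetical and not part of braintools.

```python
import math


def rectification(t, beta2=0.999):
    """Return (rho_t, r_t) for a 1-based step t; r_t is None while rho_t <= 4.

    Hypothetical helper written from the Notes formulas, not braintools code.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        # Variance of the adaptive learning rate is not tractable yet:
        # RAdam falls back to the un-adapted update at this step.
        return rho_t, None
    r_t = math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )
    return rho_t, r_t


for t in (1, 4, 5, 10, 100, 1000, 10000):
    rho_t, r_t = rectification(t)
    print(t, round(rho_t, 4), None if r_t is None else round(r_t, 4))
```

With beta2 = 0.999, \(\rho_t\) first exceeds 4 at step 5, where \(r_t\) is below 0.02, and \(r_t\) then climbs toward 1 as \(\rho_t\) approaches \(\rho_{\infty} = 1999\). That slow climb is the implicit warmup. Some other implementations (PyTorch's, for example) enable the rectified step only when \(\rho_t > 5\), which shifts the first rectified step by one.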
- Parameters:
  - lr (float | LRScheduler) – Learning rate. Can be a float (converted to ConstantLR) or any LRScheduler instance.
  - betas (Tuple[float, float]) – Coefficients for computing running averages of the gradient and its square. The first value (beta1) controls the exponential decay rate for the first moment, and the second value (beta2) controls the decay rate for the second moment.
  - eps (float) – Term added to the denominator for numerical stability.
  - weight_decay (float) – Weight decay (L2 penalty) coefficient.
  - grad_clip_norm (float | None) – Maximum norm for gradient clipping. If specified, gradients are clipped when their global norm exceeds this value.
  - grad_clip_value (float | None) – Maximum absolute value for gradient clipping. If specified, gradients are clipped element-wise to [-grad_clip_value, grad_clip_value].
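The two clipping options follow the standard definitions: global-norm clipping rescales all gradients by one common factor, while value clipping clamps each element independently. Below is a minimal NumPy sketch of both, assuming gradients are held in a dict of arrays; the helper names are hypothetical and this is not braintools' implementation (this page also does not say in which order the two are applied when both are set).

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients by one factor so their joint L2 norm is at most max_norm."""
    global_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads.values())))
    if global_norm <= max_norm:
        return dict(grads)  # already within bounds: returned unchanged
    scale = max_norm / global_norm
    return {name: g * scale for name, g in grads.items()}


def clip_by_value(grads, max_value):
    """Clamp every gradient element into [-max_value, max_value]."""
    return {name: np.clip(g, -max_value, max_value) for name, g in grads.items()}


grads = {
    "weight": np.array([[3.0, 0.0], [0.0, 4.0]]),
    "bias": np.array([0.0, 0.0]),
}
by_norm = clip_by_global_norm(grads, max_norm=1.0)  # global norm is 5, so every entry is scaled by 0.2
by_value = clip_by_value(grads, max_value=2.0)      # the entries 3 and 4 are clamped to 2
print(by_norm["weight"])
print(by_value["weight"])
```

Global-norm clipping preserves the direction of the full update, whereas value clipping can change it: here the (3, 4) component pair becomes (2, 2).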
Notes
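RAdam introduces a rectification term \(r_t\) that scales the adaptive step according to an estimate of the variance of the adaptive learning rate. Let \(g_t\) be the gradient at step \(t\), \(m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t\) and \(v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2\) the exponential moving averages of the gradient and its square, \(\hat{m}_t = m_t / (1-\beta_1^t)\) and \(\hat{v}_t = v_t / (1-\beta_2^t)\) their bias-corrected versions, and \(\alpha\) the learning rate. The rectification is computed as:

\[
\begin{aligned}
\rho_t &= \rho_{\infty} - \frac{2t\beta_2^t}{1-\beta_2^t}, \\
r_t &= \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\rho_{\infty}}{(\rho_{\infty} - 4)(\rho_{\infty} - 2)\rho_t}},
\end{aligned}
\]

where \(\rho_{\infty} = \frac{2}{1-\beta_2} - 1\).

When \(\rho_t > 4\), the variance is tractable and the rectified adaptive update is used:

\[\theta_{t+1} = \theta_t - \alpha \cdot r_t \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}\]

Otherwise, it falls back to an un-adapted, momentum-only update:

\[\theta_{t+1} = \theta_t - \alpha \cdot \hat{m}_t\]

Because \(\rho_t\) is small during the first steps and grows toward \(\rho_{\infty}\), RAdam performs warmup automatically without a manually tuned schedule, making it more robust than standard Adam in the early stages of training.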
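The update rule above can be traced for a single parameter array with a minimal NumPy sketch, written directly from the Notes formulas (threshold \(\rho_t > 4\)). The function `radam_step` and its argument layout are hypothetical and are not braintools' API; weight decay and gradient clipping are omitted.

```python
import math

import numpy as np


def radam_step(theta, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One RAdam update for a 1-based step t (hypothetical helper, not braintools code)."""
    beta1, beta2 = betas
    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias-corrected estimates.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Length of the approximated simple moving average.
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t > 4.0:
        # Variance is tractable: rectified adaptive step.
        r_t = math.sqrt(
            (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
            / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
        )
        theta = theta - lr * r_t * m_hat / (np.sqrt(v_hat) + eps)
    else:
        # Early steps: un-adapted, momentum-only step.
        theta = theta - lr * m_hat
    return theta, m, v


# Minimize f(theta) = sum(theta ** 2), whose gradient is 2 * theta.
theta = np.array([1.0, -2.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 501):
    grad = 2.0 * theta
    theta, m, v = radam_step(theta, grad, m, v, t)
print(theta)
```

At step 1, \(\rho_1 = 1\), so the update is simply \(\theta - \alpha g_1\); the adaptive branch only takes over once \(\rho_t\) passes the threshold.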
Examples
Basic usage with float learning rate:
>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize RAdam optimizer
>>> optimizer = braintools.optim.RAdam(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
Using custom beta values for different convergence behavior:
>>> # Smaller beta1: the first moment decays faster and tracks recent gradients more closely
>>> optimizer = braintools.optim.RAdam(lr=0.001, betas=(0.8, 0.999))
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
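A quick way to see what a smaller beta1 changes is to feed the first-moment estimate a gradient that flips sign and count how long the estimate takes to follow. The NumPy sketch below does that for beta1 = 0.9 and 0.8; the helper names and the +1/-1 gradient sequence are illustrative assumptions, not part of braintools.

```python
import numpy as np


def first_moment_trace(grads, beta1):
    """Bias-corrected first-moment estimates m_hat_t, one per gradient (hypothetical helper)."""
    m = 0.0
    trace = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1.0 - beta1) * g
        trace.append(m / (1.0 - beta1 ** t))
    return np.array(trace)


def sign_flip_step(beta1, steps_before_flip=20):
    """1-based step at which m_hat first turns negative after the gradient flips from +1 to -1."""
    grads = np.concatenate([np.ones(steps_before_flip), -np.ones(steps_before_flip)])
    trace = first_moment_trace(grads, beta1)
    return int(np.argmax(trace < 0)) + 1


for beta1 in (0.9, 0.8):
    print(f"beta1={beta1}: first-moment estimate turns negative at step {sign_flip_step(beta1)}")
```

A common rule of thumb is that the first moment averages over roughly 1 / (1 - beta1) recent gradients, about 10 for beta1 = 0.9 and 5 for beta1 = 0.8, so the smaller value reacts sooner but smooths gradient noise less.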
With a learning rate scheduler:
>>> # Exponential learning rate decay
>>> scheduler = braintools.optim.ExponentialLR(base_lr=0.001, gamma=0.95)
>>> optimizer = braintools.optim.RAdam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
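Assuming ExponentialLR follows the conventional exponential-decay rule (as PyTorch's scheduler of the same name does), the learning rate after k scheduler steps is base_lr * gamma ** k. Whether k counts epochs or iterations depends on how often the scheduler is stepped, which this page does not specify. The plain-Python sketch below tabulates that schedule for the values in the example; `exponential_lr` is a hypothetical helper.

```python
import math


def exponential_lr(base_lr: float, gamma: float, step: int) -> float:
    """Learning rate after `step` scheduler steps under exponential decay (assumed rule)."""
    return base_lr * gamma ** step


schedule = [exponential_lr(0.001, 0.95, k) for k in range(5)]
# Smallest number of scheduler steps after which the rate is at most half of base_lr.
half_life = math.ceil(math.log(0.5) / math.log(0.95))
print(schedule)
print(half_life)
```

With gamma = 0.95 the rate halves after 14 scheduler steps, so the choice of stepping per epoch or per iteration makes a large difference in practice.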
Using gradient clipping for stability:
>>> # Clip gradients by global norm
>>> optimizer = braintools.optim.RAdam(
...     lr=0.001,
...     grad_clip_norm=1.0
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
With weight decay for regularization:
>>> # Add L2 regularization
>>> optimizer = braintools.optim.RAdam(
...     lr=0.001,
...     weight_decay=0.01
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
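The parameter list describes weight_decay as an L2 penalty. Assuming it is applied in the classic, coupled way (weight_decay * theta added to the gradient before the moment estimates are updated, rather than AdamW-style decoupled decay), it corresponds to minimizing loss + (weight_decay / 2) * ||theta||^2. The NumPy sketch below checks that correspondence with finite differences on a toy quadratic loss; all function names are hypothetical.

```python
import numpy as np


def loss(theta):
    """Toy objective: squared distance to the all-ones vector."""
    return float(np.sum((theta - 1.0) ** 2))


def penalized_loss(theta, weight_decay):
    """Objective plus the L2 penalty (weight_decay / 2) * ||theta||^2."""
    return loss(theta) + 0.5 * weight_decay * float(np.sum(theta ** 2))


def l2_regularized_grad(theta, weight_decay):
    """Gradient of the loss with weight_decay * theta added (assumed coupled L2 decay)."""
    return 2.0 * (theta - 1.0) + weight_decay * theta


theta = np.array([0.5, -1.0, 2.0])
weight_decay = 0.01
h = 1e-6
# Central finite differences of the penalized objective, one coordinate at a time.
numeric_grad = np.array([
    (penalized_loss(theta + h * e, weight_decay) - penalized_loss(theta - h * e, weight_decay)) / (2 * h)
    for e in np.eye(theta.size)
])
analytic_grad = l2_regularized_grad(theta, weight_decay)
print(numeric_grad, analytic_grad)
```

Under this assumption the decay term also passes through RAdam's adaptive denominator; decoupled weight decay (as in AdamW) would instead shrink theta directly, independent of the gradient statistics.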
Complete training loop example:
>>> import jax.numpy as jnp
>>>
>>> # Setup
>>> model = brainstate.nn.Linear(10, 5)
>>> optimizer = braintools.optim.RAdam(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training step
>>> @brainstate.transform.jit
... def train_step(input_data, target):
...     def loss_fn():
...         pred = model(input_data)
...         return jnp.mean((pred - target) ** 2)
...
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()  # loss re-evaluated with the updated parameters
>>>
>>> # Train
>>> x = jnp.ones((32, 10))
>>> y = jnp.zeros((32, 5))
>>> loss = train_step(x, y)