OptaxOptimizer#
- class braintools.optim.OptaxOptimizer(tx=None, lr=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)#
Base class for Optax-based optimizers with a PyTorch-like interface.
This class provides a unified interface for JAX/Optax optimizers with advanced features including learning rate scheduling, parameter groups, gradient clipping, and weight decay.
- Parameters:
tx (GradientTransformation | None) – An Optax gradient transformation. If None, one is created from optimizer-specific parameters via the _create_default_tx() method.
lr (float | LRScheduler) – Learning rate. Can be a float (automatically converted to ConstantLR) or an LRScheduler instance for advanced scheduling.
weight_decay (float) – Weight decay coefficient (L2 penalty).
grad_clip_norm (float | None) – Maximum gradient norm for clipping. If provided, gradients are clipped using optax.clip_by_global_norm.
grad_clip_value (float | None) – Maximum gradient value for clipping. If provided, gradients are clipped element-wise using optax.clip.
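The two clipping options bound different quantities. The following plain-Python sketch illustrates only the arithmetic; it is not the Optax implementation, and the helper names `clip_by_global_norm` and `clip_elementwise` are hypothetical stand-ins that operate on nested lists rather than pytrees of arrays.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together if their joint L2 norm exceeds max_norm."""
    global_norm = math.sqrt(sum(g * g for group in grads for g in group))
    if global_norm <= max_norm:
        return [list(group) for group in grads]  # already within the bound
    scale = max_norm / global_norm
    return [[g * scale for g in group] for group in grads]

def clip_elementwise(grads, max_value):
    """Clamp every entry independently to [-max_value, max_value]."""
    return [[max(-max_value, min(max_value, g)) for g in group] for group in grads]

# Two hypothetical parameter tensors, flattened; their joint L2 norm is 5.0.
grads = [[3.0, 0.0], [0.0, 4.0]]

by_norm = clip_by_global_norm(grads, max_norm=1.0)   # every entry scaled by 1/5
by_value = clip_elementwise(grads, max_value=1.0)    # each entry clamped on its own

print(by_norm)
print(by_value)
```

Norm-based clipping keeps the direction of the overall gradient and only shortens it, whereas element-wise clipping can change the direction because large entries are cut while small ones are left alone.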
- param_states#
Container for PyTree of brainstate.State objects representing trainable parameters.
- Type:
- opt_state#
Optimizer state containing momentum, variance, and other optimizer-specific values.
- Type:
OptimState
- step_count#
Number of optimization steps taken.
- Type:
OptimState
- param_groups#
List of parameter groups with their own hyperparameters. Each group is a dictionary with keys 'params', 'lr', 'weight_decay', etc.
Notes
This base class implements the unified learning rate handling where all learning rates (whether float or LRScheduler) are internally managed through a scheduler. Float learning rates are automatically converted to ConstantLR for consistent handling.
Examples
Basic usage with float learning rate:
>>> import brainstate
>>> import braintools
>>>
>>> # Define a simple model
>>> class Model(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = brainstate.nn.Linear(10, 5)
...
>>> model = Model()
>>> optimizer = braintools.optim.Adam(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
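In the basic example the learning rate is a bare float, which the Notes say is wrapped in a constant schedule internally. A minimal sketch of that normalization pattern might look as follows; `ConstantLRSketch` and `as_scheduler` are hypothetical names used for illustration, and the method names are assumptions rather than the library's actual LRScheduler interface.

```python
class ConstantLRSketch:
    """Hypothetical stand-in for a constant schedule; not braintools' ConstantLR."""

    def __init__(self, base_lr):
        self.base_lr = float(base_lr)

    def current_lr(self):
        return self.base_lr

    def step(self):
        pass  # a constant schedule has nothing to advance


def as_scheduler(lr):
    """Wrap bare numbers so downstream code only ever sees a scheduler-like object."""
    if isinstance(lr, (int, float)):
        return ConstantLRSketch(lr)
    return lr  # assumed to already be a scheduler; passed through unchanged


from_float = as_scheduler(0.001)       # float gets wrapped
from_sched = as_scheduler(from_float)  # scheduler-like object is left as-is

print(from_float.current_lr())
```

Normalizing at construction time means the update step never has to branch on whether the learning rate is a number or a schedule.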
Using learning rate scheduler:
>>> import brainstate
>>> import braintools
>>>
>>> model = Model()
>>> scheduler = braintools.optim.StepLR(base_lr=0.01, step_size=10, gamma=0.1)
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop
>>> for epoch in range(100):
...     # ... compute gradients
...     optimizer.step(grads)
...     scheduler.step()
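For intuition about the scheduler settings used in this example, the sketch below assumes StepLR follows the conventional step-decay rule (multiply the rate by gamma once every step_size scheduler steps). This section does not spell out the rule, so treat it as an assumption; `step_decay_lr` is a hypothetical helper, not a braintools function.

```python
def step_decay_lr(base_lr, step_size, gamma, epoch):
    """Conventional step decay: one multiplication by gamma per completed window."""
    return base_lr * gamma ** (epoch // step_size)


# Same hyperparameters as the example: base_lr=0.01, step_size=10, gamma=0.1.
schedule = {epoch: step_decay_lr(0.01, 10, 0.1, epoch) for epoch in (0, 9, 10, 25)}

for epoch, lr in schedule.items():
    print(f"epoch {epoch:>2}: lr = {lr:.6f}")
```

Under that assumption the rate stays at 0.01 for epochs 0 to 9, drops to about 0.001 at epoch 10, and to about 0.0001 from epoch 20 onward.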
Multiple parameter groups with different learning rates:
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> model = Model()
>>> optimizer = braintools.optim.Adam(lr=0.001)
>>>
>>> # Register main parameters
>>> optimizer.register_trainable_weights(model.linear.states(brainstate.ParamState))
>>>
>>> # Add another group with different lr
>>> special_params = {'special': brainstate.ParamState(jnp.zeros(5))}
>>> optimizer.add_param_group(special_params, lr=0.0001)
- add_param_group(params, **kwargs)[source]#
Add a parameter group with specific hyperparameters.
- Parameters:
params – A pytree (dict) of brainstate.State objects.
**kwargs – Additional hyperparameters for this group.
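One way to picture a parameter group is as a dictionary that pairs a set of parameters with its own hyperparameters. The sketch below assumes that any hyperparameter not passed for a group falls back to the optimizer-wide value; the source does not state this fallback behaviour, and `make_param_group` is a hypothetical helper, not part of braintools.

```python
def make_param_group(params, defaults, **overrides):
    """Build one group dict: optimizer-wide defaults first, then per-group overrides."""
    group = dict(defaults)   # copy so the shared defaults are never mutated
    group.update(overrides)  # per-group values take precedence
    group["params"] = params
    return group


# Placeholder strings stand in for brainstate.State objects in this sketch.
defaults = {"lr": 0.001, "weight_decay": 0.0}
param_groups = [
    make_param_group({"linear.weight": "<State>"}, defaults),
    make_param_group({"special": "<State>"}, defaults, lr=0.0001),
]

for group in param_groups:
    print(sorted(group["params"]), group["lr"], group["weight_decay"])
```

Here the second group overrides only `lr` and keeps the default `weight_decay`, mirroring the call `add_param_group(special_params, lr=0.0001)` in the example above.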
- property current_lr#
Get current learning rate.
- register_trainable_weights(param_states)[source]#
Register trainable weights and initialize optimizer state.
- Parameters:
param_states – A pytree (dict) of brainstate.State objects representing parameters.
- state_dict()[source]#
Return the state of the optimizer as a dictionary.
- Returns:
Dictionary containing optimizer state, step count, and hyperparameters.