OptaxOptimizer#

class braintools.optim.OptaxOptimizer(tx=None, lr=0.001, weight_decay=0.0, grad_clip_norm=None, grad_clip_value=None)#

Base class for Optax-based optimizers with a PyTorch-like interface.

This class provides a unified interface for JAX/Optax optimizers with advanced features including learning rate scheduling, parameter groups, gradient clipping, and weight decay.

Parameters:
  • tx (GradientTransformation | None) – An Optax gradient transformation. If None, one is created from the optimizer-specific parameters via the _create_default_tx() method.

  • lr (float | LRScheduler) – Learning rate. Can be a float (automatically converted to ConstantLR) or an LRScheduler instance for advanced scheduling.

  • weight_decay (float) – Weight decay coefficient (L2 penalty).

  • grad_clip_norm (float | None) – Maximum gradient norm for clipping. If provided, gradients will be clipped using optax.clip_by_global_norm.

  • grad_clip_value (float | None) – Maximum gradient value for clipping. If provided, gradients will be clipped element-wise using optax.clip.
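
To make the two clipping options concrete, the sketch below reproduces their arithmetic in plain Python. It is an illustration only: the helper names clip_by_global_norm and clip_by_value and the dict-of-lists gradient layout are assumptions made for this example, not braintools code. The formulas follow the documented behavior of optax.clip_by_global_norm and optax.clip.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale every gradient so the joint L2 norm is at most max_norm."""
    # The norm is taken over all leaves together, not per parameter.
    global_norm = math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))
    if global_norm <= max_norm:
        return grads
    scale = max_norm / global_norm
    return {name: [g * scale for g in leaf] for name, leaf in grads.items()}


def clip_by_value(grads, max_value):
    """Clamp each gradient entry independently to [-max_value, max_value]."""
    return {
        name: [max(-max_value, min(max_value, g)) for g in leaf]
        for name, leaf in grads.items()
    }


# Hypothetical gradients; the joint norm is sqrt(3^2 + 0^2 + 4^2) = 5.
grads = {"weight": [3.0, 0.0], "bias": [4.0]}
normed = clip_by_global_norm(grads, max_norm=1.0)   # direction preserved
valued = clip_by_value(grads, max_value=1.0)        # direction may change
```

Norm clipping keeps the direction of the overall gradient and only shortens it, while value clipping treats every entry separately and can therefore change the direction.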

param_states#

Container for a PyTree of brainstate.State objects representing trainable parameters.

Type:

UniqueStateManager

opt_state#

Optimizer state containing momentum, variance, and other optimizer-specific values.

Type:

OptimState

step_count#

Number of optimization steps taken.

Type:

OptimState

param_groups#

List of parameter groups, each with its own hyperparameters. Each group is a dictionary with keys such as 'params', 'lr', and 'weight_decay'.

Type:

list of dict

base_lr#

Base learning rate (read-only property).

Type:

float

lr#

Current learning rate (can be modified by schedulers).

Type:

float

register_trainable_weights(param_states)[source]#

Register parameters to be optimized.

add_param_group(params, **kwargs)[source]#

Add a new parameter group with custom hyperparameters.

step(grads, closure=None)[source]#

Perform a single optimization step.

state_dict()[source]#

Get optimizer state for checkpointing.

load_state_dict(state_dict)[source]#

Load optimizer state from checkpoint.

Notes

This base class unifies learning rate handling: every learning rate, whether given as a float or as an LRScheduler, is managed internally through a scheduler. Float learning rates are automatically converted to ConstantLR so that both cases are handled consistently.
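
The idea of normalizing both input kinds into a scheduler can be sketched as follows. This is a minimal illustration under stated assumptions: ConstantSchedule, as_schedule, and halving are hypothetical names invented for this example and do not show how braintools implements ConstantLR or LRScheduler.

```python
from typing import Callable, Union


class ConstantSchedule:
    """Hypothetical stand-in for a constant scheduler: same rate at every step."""

    def __init__(self, base_lr: float):
        self.base_lr = base_lr

    def __call__(self, step: int) -> float:
        return self.base_lr


def as_schedule(lr: Union[float, Callable[[int], float]]) -> Callable[[int], float]:
    """Normalize a float or a schedule into a callable mapping step -> rate."""
    if isinstance(lr, (int, float)):
        return ConstantSchedule(float(lr))   # wrap plain numbers
    if callable(lr):
        return lr                            # schedules pass through unchanged
    raise TypeError(f"unsupported learning rate type: {type(lr).__name__}")


def halving(step: int) -> float:
    """Hypothetical schedule that halves the rate at every step."""
    return 0.1 * 0.5 ** step


constant = as_schedule(0.001)
decaying = as_schedule(halving)
rates = [constant(s) for s in range(3)]
```

After normalization, downstream code can always call the schedule with a step index and never needs to branch on whether the user passed a float.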

Examples

Basic usage with float learning rate:

>>> import brainstate
>>> import braintools
>>>
>>> # Define a simple model
>>> class Model(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.linear = brainstate.nn.Linear(10, 5)
...
>>> model = Model()
>>> optimizer = braintools.optim.Adam(lr=0.001)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Using learning rate scheduler:

>>> import brainstate
>>> import braintools
>>>
>>> model = Model()
>>> scheduler = braintools.optim.StepLR(base_lr=0.01, step_size=10, gamma=0.1)
>>> optimizer = braintools.optim.Adam(lr=scheduler)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Training loop
>>> for epoch in range(100):
...     # ... compute gradients
...     optimizer.step(grads)
...     scheduler.step()
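
For intuition about the schedule used above, the following sketch computes a step-decay learning rate by hand. It assumes StepLR follows the conventional step-decay rule (multiply the rate by gamma once every step_size epochs); the function name step_decay_lr is hypothetical and this page does not show the braintools implementation.

```python
def step_decay_lr(base_lr, step_size, gamma, epoch):
    """Conventional step decay: scale base_lr by gamma once per step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Same hyperparameters as the example above: base_lr=0.01, step_size=10, gamma=0.1.
epochs = (0, 9, 10, 25)
schedule = [step_decay_lr(0.01, 10, 0.1, e) for e in epochs]
```

Under this rule the rate stays at 0.01 for epochs 0 to 9, drops to roughly 0.001 at epoch 10, and to roughly 0.0001 from epoch 20 onward.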

Multiple parameter groups with different learning rates:

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> model = Model()
>>> optimizer = braintools.optim.Adam(lr=0.001)
>>>
>>> # Register main parameters
>>> optimizer.register_trainable_weights(model.linear.states(brainstate.ParamState))
>>>
>>> # Add another group with different lr
>>> special_params = {'special': brainstate.ParamState(jnp.zeros(5))}
>>> optimizer.add_param_group(special_params, lr=0.0001)

See also

Adam

Adam optimizer with adaptive learning rates

SGD

Stochastic gradient descent optimizer

AdamW

AdamW optimizer with decoupled weight decay

StepLR

Learning rate scheduler with step decay

add_param_group(params, **kwargs)[source]#

Add a parameter group with specific hyperparameters.

Parameters:
  • params (Dict[str, State]) – A pytree (dict) of brainstate.State objects.

  • **kwargs – Additional hyperparameters for this group.
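
The way per-group hyperparameters override optimizer-wide defaults can be sketched in plain Python. This is an assumption-laden illustration: resolve_group is a hypothetical helper, the parameter names are placeholder strings rather than brainstate.State objects, and the actual group layout in braintools may contain further keys.

```python
def resolve_group(defaults, params, **overrides):
    """Hypothetical helper: build one group dict, letting overrides win."""
    group = dict(defaults)      # start from the optimizer-wide defaults
    group.update(overrides)     # group-specific values take precedence
    group["params"] = params
    return group


defaults = {"lr": 0.001, "weight_decay": 0.0}
param_groups = [
    # Main group: inherits every default.
    resolve_group(defaults, params=["linear.weight", "linear.bias"]),
    # Extra group: only the learning rate is overridden.
    resolve_group(defaults, params=["special"], lr=0.0001),
]
```

Any hyperparameter that a group does not override keeps the default value, so the second group above still uses a weight decay of 0.0.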

add_scheduler(scheduler)[source]#

Add a learning rate scheduler.

property base_lr: float#

Get base learning rate.

property current_lr#

Get current learning rate.

default_tx()[source]#

Create default gradient transformation with clipping and weight decay.

get_last_lr()[source]#

Get last computed learning rates from schedulers.

Return type:

List[float]

load_state_dict(state_dict)[source]#

Load optimizer state from a dictionary.

Parameters:

state_dict (Dict[str, Any]) – Dictionary containing optimizer state.

lr_apply(apply_fn)[source]#

Apply a function to modify the current learning rate.

register_trainable_weights(param_states)[source]#

Register trainable weights and initialize optimizer state.

Parameters:

param_states (Dict[str, State]) – A pytree (dict) of brainstate.State objects representing parameters.

state_dict()[source]#

Return the state of the optimizer as a dictionary.

Returns:

Dictionary containing optimizer state, step count, and hyperparameters.
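
A checkpoint round trip with state_dict() and load_state_dict() can be illustrated with a toy class. Everything here is an assumption for the sake of the example: ToyOptimizer is hypothetical, and the dictionary keys step_count, opt_state, and lr are placeholders, since this page does not list the actual keys that braintools stores.

```python
import pickle


class ToyOptimizer:
    """Hypothetical optimizer that tracks only a step count and a momentum buffer."""

    def __init__(self, lr):
        self.lr = lr
        self.step_count = 0
        self.momentum = {}

    def step(self, grads):
        # Accumulate a simple momentum term per parameter name.
        for name, g in grads.items():
            self.momentum[name] = 0.9 * self.momentum.get(name, 0.0) + g
        self.step_count += 1

    def state_dict(self):
        # Return copies so later steps do not mutate the saved checkpoint.
        return {
            "step_count": self.step_count,
            "opt_state": dict(self.momentum),
            "lr": self.lr,
        }

    def load_state_dict(self, state_dict):
        self.step_count = state_dict["step_count"]
        self.momentum = dict(state_dict["opt_state"])
        self.lr = state_dict["lr"]


opt = ToyOptimizer(lr=0.01)
opt.step({"w": 1.0})
opt.step({"w": 1.0})

blob = pickle.dumps(opt.state_dict())        # serialize the checkpoint
restored = ToyOptimizer(lr=0.0)
restored.load_state_dict(pickle.loads(blob))  # resume from the checkpoint
```

Saving the step count together with the optimizer state matters because schedulers and bias-corrected optimizers depend on how many steps have already been taken.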

step(grads)[source]#

Perform a single optimization step.

Parameters:

grads (Dict[str, Any] | None) – Gradients for parameters. If None, closure must be provided.

Returns:

Optional loss value if closure is provided.

update(grads)[source]#

Update the model states with gradients. Kept for backward compatibility.