braintools.optim module#
Optimization Algorithms and Learning Rate Schedulers.
This module provides a comprehensive collection of optimization algorithms and learning rate schedulers for training neural networks and spiking neural networks. It includes modern deep learning optimizers (Adam, SGD, etc.), specialized optimizers for scientific computing (SciPy, Nevergrad), and flexible learning rate scheduling strategies.
Key Features:
Gradient-Based Optimizers: Adam, SGD, RMSprop, Adagrad, and variants
Advanced Optimizers: AdamW, RAdam, Lamb, Lion, AdaBelief, etc.
SciPy Integration: Gradient-free and constrained optimization
Nevergrad Integration: Black-box optimization with evolutionary strategies
Learning Rate Schedulers: Step, exponential, cosine, warmup, and custom schedules
PyTorch-like Interface: Familiar API for PyTorch users
JAX/Optax Backend: High-performance optimization with automatic differentiation
Quick Start - Basic Optimization:
import jax.numpy as jnp
import brainstate as bst
from braintools.optim import Adam

# Define a simple model
class SimpleModel(bst.Module):
    def __init__(self):
        super().__init__()
        self.w = bst.ParamState(jnp.zeros((10, 5)))
        self.b = bst.ParamState(jnp.zeros(5))

    def __call__(self, x):
        return jnp.dot(x, self.w.value) + self.b.value

# Create model and optimizer
model = SimpleModel()
optimizer = Adam(lr=0.001)
# Register trainable parameters
optimizer.register_trainable_weights(model.states(bst.ParamState))

# Training step: differentiate the loss with respect to the model's ParamStates
@bst.transform.grad(grad_states=model.states(bst.ParamState), return_value=True)
def loss_fn(data, target):
    pred = model(data)
    return jnp.mean((pred - target) ** 2)

# Placeholder batch (illustrative shapes: 32 samples, 10 inputs, 5 outputs)
data = jnp.ones((32, 10))
target = jnp.zeros((32, 5))
# Update step
grads, loss = loss_fn(data, target)
optimizer.update(grads)
Quick Start - With Learning Rate Scheduler:
from braintools.optim import Adam, CosineAnnealingLR
# Create optimizer with cosine annealing schedule
scheduler = CosineAnnealingLR(T_max=1000, eta_min=1e-6)
optimizer = Adam(lr=scheduler, weight_decay=1e-4)
optimizer.register_trainable_weights(model.states(bst.ParamState))
# Training loop (reuses bst, model, loss_fn, data, and target from the basic example)
for epoch in range(100):
    grads, loss = loss_fn(data, target)
    optimizer.update(grads)
    # Scheduler step is handled automatically
Gradient-Based Optimizers:
from braintools.optim import (
    SGD, Momentum, Adam, AdamW, RMSprop,
    Adagrad, Adadelta, Nadam, RAdam
)
# Stochastic Gradient Descent
sgd = SGD(lr=0.01, weight_decay=1e-4)
# Momentum
momentum = Momentum(lr=0.01, momentum=0.9, nesterov=True)
# Adam (most popular)
adam = Adam(lr=0.001, betas=(0.9, 0.999), eps=1e-8)
# AdamW (Adam with decoupled weight decay)
adamw = AdamW(lr=0.001, weight_decay=0.01)
# RMSprop
rmsprop = RMSprop(lr=0.001, alpha=0.99, eps=1e-8)
# Adagrad (adaptive learning rates)
adagrad = Adagrad(lr=0.01, eps=1e-10)
# Adadelta (extension of Adagrad)
adadelta = Adadelta(lr=1.0, rho=0.9, eps=1e-6)
# Nadam (Adam + Nesterov momentum)
nadam = Nadam(lr=0.001, betas=(0.9, 0.999))
# RAdam (rectified Adam)
radam = RAdam(lr=0.001, betas=(0.9, 0.999))
Advanced Optimizers:
from braintools.optim import (
    Lamb, Lars, Lion, AdaBelief,
    Adafactor, Yogi, Lookahead
)
# Lamb (for large batch training)
lamb = Lamb(lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)
# Lars (layer-wise adaptive rate scaling)
lars = Lars(lr=0.01, momentum=0.9, weight_decay=1e-4)
# Lion (evolved sign momentum)
lion = Lion(lr=0.0001, betas=(0.9, 0.99), weight_decay=0.01)
# AdaBelief (adapting stepsizes by belief in gradient direction)
adabelief = AdaBelief(lr=0.001, betas=(0.9, 0.999), eps=1e-16)
# Adafactor (memory-efficient adaptive learning rates)
adafactor = Adafactor(lr=0.001, min_dim_size_to_factor=128)
# Yogi (adaptive learning rate with controlled increases)
yogi = Yogi(lr=0.01, betas=(0.9, 0.999))
# Lookahead (wrapper for other optimizers)
lookahead = Lookahead(
    base_optimizer=Adam(lr=0.001),
    sync_period=5,
    slow_step_size=0.5
)
Learning Rate Schedulers:
from braintools.optim import (
    StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, PolynomialLR,
    WarmupScheduler, OneCycleLR, CyclicLR,
    WarmupCosineSchedule
)
# Step decay
step_lr = StepLR(initial_lr=0.1, step_size=30, gamma=0.1)
# Multi-step decay
multistep_lr = MultiStepLR(initial_lr=0.1, milestones=[30, 60, 90], gamma=0.1)
# Exponential decay
exp_lr = ExponentialLR(initial_lr=0.1, gamma=0.95)
# Cosine annealing
cosine_lr = CosineAnnealingLR(initial_lr=0.1, T_max=100, eta_min=1e-6)
# Polynomial decay
poly_lr = PolynomialLR(initial_lr=0.1, total_steps=1000, power=2.0)
# Warmup then constant
warmup_lr = WarmupScheduler(
    warmup_steps=1000,
    peak_lr=0.001,
    init_lr=1e-6
)
# One-cycle policy
onecycle_lr = OneCycleLR(
    max_lr=0.01,
    total_steps=1000,
    pct_start=0.3,
    div_factor=25.0
)
# Cyclic learning rate
cyclic_lr = CyclicLR(
    base_lr=0.001,
    max_lr=0.01,
    step_size_up=2000,
    mode='triangular'
)
# Warmup + cosine schedule
warmup_cosine = WarmupCosineSchedule(
    warmup_steps=1000,
    total_steps=10000,
    peak_lr=0.001,
    end_lr=1e-6
)
SciPy Optimization:
from braintools.optim import ScipyOptimizer
# Use SciPy's BFGS for gradient-based optimization
scipy_opt = ScipyOptimizer(
    method='BFGS',
    options={'maxiter': 1000, 'gtol': 1e-6}
)
# Use Nelder-Mead for gradient-free optimization
nelder_mead = ScipyOptimizer(
    method='Nelder-Mead',
    options={'maxiter': 5000, 'xatol': 1e-8}
)
# Constrained optimization with bounds
constrained = ScipyOptimizer(
    method='L-BFGS-B',
    bounds=[(0, 1), (-10, 10)],
    options={'maxiter': 1000}
)
Nevergrad Optimization:
from braintools.optim import NevergradOptimizer
# Differential evolution
ng_de = NevergradOptimizer(
    optimizer='TwoPointsDE',
    budget=1000,
    num_workers=4
)
# CMA-ES (Covariance Matrix Adaptation)
ng_cma = NevergradOptimizer(
    optimizer='CMA',
    budget=2000,
    num_workers=1
)
# Particle swarm optimization
ng_pso = NevergradOptimizer(
    optimizer='PSO',
    budget=1000,
    num_workers=8
)
Gradient Clipping:
from braintools.optim import Adam
# Clip by global norm
optimizer = Adam(lr=0.001, grad_clip_norm=1.0)
# Clip by value
optimizer = Adam(lr=0.001, grad_clip_value=0.5)
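To make the difference between the two clipping modes concrete, the following pure-Python sketch shows what clipping by global norm and clipping by value conventionally compute. It is an illustration only, not the braintools implementation; the helper names `clip_by_global_norm` and `clip_by_value` are hypothetical.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients jointly so their combined L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return list(grads)
    scale = max_norm / global_norm
    return [g * scale for g in grads]

def clip_by_value(grads, clip_value):
    """Clamp each gradient entry independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]

grads = [3.0, -4.0]                          # joint L2 norm is 5.0
by_norm = clip_by_global_norm(grads, 1.0)    # direction preserved, norm scaled to 1.0
by_value = clip_by_value(grads, 0.5)         # each entry clamped, direction may change
```

Norm clipping keeps the gradient direction and only shortens the step, whereas value clipping can change the direction because each entry is clamped on its own.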
Weight Decay:
from braintools.optim import SGD, AdamW
# L2 regularization (coupled with gradients)
sgd = SGD(lr=0.01, weight_decay=1e-4)
# Decoupled weight decay (better for Adam-like optimizers)
adamw = AdamW(lr=0.001, weight_decay=0.01)
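The practical difference between coupled and decoupled weight decay shows up once the gradient is rescaled adaptively. The sketch below performs a single Adam-style step on one scalar weight in both modes. It is a simplified illustration under stated assumptions (zero initial moment estimates, one common formulation of the decoupled update); the function `adam_first_step` is hypothetical and is not the braintools code path.

```python
import math

def adam_first_step(w, g, lr=0.001, b1=0.9, b2=0.999, eps=1e-8,
                    weight_decay=0.0, decoupled=False):
    """First Adam step for a scalar weight, starting from zero moment estimates."""
    if not decoupled:
        # Coupled L2: the decay term is added to the gradient and is
        # therefore rescaled by the adaptive denominator below.
        g = g + weight_decay * w
    m = (1 - b1) * g
    v = (1 - b2) * g * g
    m_hat = m / (1 - b1)          # bias correction at step 1
    v_hat = v / (1 - b2)
    w_new = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        # Decoupled: the decay shrinks the weight directly, outside the adaptive scaling.
        w_new -= lr * weight_decay * w
    return w_new

w, g = 10.0, 0.5
coupled = adam_first_step(w, g, weight_decay=0.01, decoupled=False)
decoupled = adam_first_step(w, g, weight_decay=0.01, decoupled=True)
```

In this example the coupled variant moves the weight by roughly `lr` regardless of the decay strength, because the decay is normalized away together with the gradient, while the decoupled variant applies an additional shrinkage of `lr * weight_decay * w`.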
Advanced Scheduler Patterns:
from braintools.optim import (
    ChainedScheduler, SequentialLR,
    ReduceLROnPlateau, PiecewiseConstantSchedule,
    # Schedulers used as building blocks below
    ConstantLR, ExponentialLR,
    WarmupScheduler, CosineAnnealingLR
)
# Chain multiple schedulers
scheduler = ChainedScheduler([
    WarmupScheduler(warmup_steps=1000, peak_lr=0.001),
    CosineAnnealingLR(initial_lr=0.001, T_max=9000)
])
# Sequential schedulers (switch at milestones)
sequential = SequentialLR(
    schedulers=[
        ConstantLR(0.001),
        ExponentialLR(initial_lr=0.001, gamma=0.95)
    ],
    milestones=[5000]
)
# Reduce on plateau (requires manual metric tracking)
reduce_plateau = ReduceLROnPlateau(
    initial_lr=0.01,
    factor=0.5,
    patience=10,
    mode='min'
)
# Piecewise constant
piecewise = PiecewiseConstantSchedule(
    boundaries=[1000, 5000, 8000],
    values=[0.1, 0.01, 0.001, 0.0001]
)
Comprehensive optimization toolkit for brain modeling, featuring PyTorch-like optimizers, learning rate schedulers, and advanced optimization algorithms from SciPy and Nevergrad.
Overview#
The braintools.optim module provides:
Modern gradient-based optimizers with PyTorch-compatible APIs
Learning rate schedulers for dynamic learning rate adjustment
Black-box optimization via SciPy and Nevergrad wrappers
State management utilities for optimization workflows
Base Classes#
These classes provide the foundational architecture for all optimizers in the module.
The Optimizer class defines the common interface, while OptaxOptimizer serves
as the base for all gradient-based optimizers built on top of the Optax library.
Base Optimizer Class.
Base class for Optax-based optimizers with PyTorch-like interface.
Gradient-Based Optimizers#
These optimizers use gradient information to update model parameters. They follow a PyTorch-like API, making them familiar to users coming from the PyTorch ecosystem. All gradient-based optimizers support features like weight decay, gradient clipping, and integration with learning rate schedulers.
Standard Optimizers#
These are the most commonly used optimizers in deep learning and neural network training. They provide a good balance between convergence speed and stability for most applications.
Stochastic Gradient Descent (SGD) optimizer with momentum and weight decay.
Momentum optimizer.
Nesterov Momentum optimizer.
Adam (Adaptive Moment Estimation) optimizer.
AdamW optimizer with decoupled weight decay regularization.
Adagrad optimizer with adaptive learning rates.
Adadelta optimizer - an extension of Adagrad.
RMSprop (Root Mean Square Propagation) optimizer.
Adamax optimizer - variant of Adam based on infinity norm.
Nadam optimizer - Adam with Nesterov accelerated gradient.
Advanced Optimizers#
These optimizers implement state-of-the-art optimization algorithms designed for specific use cases or improved performance. They often provide better convergence properties for large-scale models, handle sparse gradients more effectively, or offer improved stability in challenging optimization landscapes.
RAdam optimizer (Rectified Adam).
LAMB optimizer (Layer-wise Adaptive Moments).
LARS optimizer (Layer-wise Adaptive Rate Scaling).
Lookahead optimizer wrapper.
Yogi optimizer (improvement over Adam).
L-BFGS optimizer (Limited-memory Broyden-Fletcher-Goldfarb-Shanno).
Rprop optimizer (Resilient Backpropagation).
Adafactor optimizer (memory-efficient variant of Adam).
AdaBelief optimizer - Adapts step size according to belief in gradient direction.
Lion (EvoLved Sign Momentum) optimizer - Discovered through program search.
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) optimizer.
Novograd (Normalized Gradient) optimizer - Layer-wise gradient normalization with momentum.
Fromage (Frobenius matched gradient descent) optimizer.
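Of the entries above, the Lookahead wrapper is the one whose mechanics are least obvious from its name. The sketch below shows the usual slow-weights/fast-weights scheme on a one-dimensional quadratic, with plain SGD as the inner optimizer. It is a minimal illustration and not the braintools implementation; `lookahead_minimize` and `sgd_step` are hypothetical helpers, and the parameter names `sync_period` and `slow_step_size` are borrowed from the module example.

```python
def sgd_step(w, grad, lr):
    """One plain gradient-descent step (stands in for any inner optimizer)."""
    return w - lr * grad

def lookahead_minimize(grad_fn, w0, lr=0.1, sync_period=5,
                       slow_step_size=0.5, n_steps=20):
    slow = w0   # slow weights: updated once per sync period
    fast = w0   # fast weights: updated by the inner optimizer every step
    for step in range(1, n_steps + 1):
        fast = sgd_step(fast, grad_fn(fast), lr)
        if step % sync_period == 0:
            # Move the slow weights part of the way toward the fast weights,
            # then restart the fast weights from the new slow position.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
    return slow

# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
result = lookahead_minimize(lambda w: 2.0 * (w - 3.0), w0=0.0)
```

With `slow_step_size=1.0` the scheme reduces to the inner optimizer alone; smaller values interpolate more conservatively between synchronization points.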
Learning Rate Schedulers#
Learning rate schedulers dynamically adjust the learning rate during training to improve convergence and final model performance. They can help escape local minima, fine-tune models more effectively, and achieve better generalization. All schedulers are compatible with any gradient-based optimizer.
Base Scheduler#
The abstract base class that defines the interface for all learning rate schedulers. Custom schedulers should inherit from this class.
Base class for learning rate schedulers.
Step-based Schedulers#
These schedulers adjust the learning rate at fixed intervals or following predetermined patterns. They are simple to configure and work well for many standard training scenarios.
Step learning rate scheduler - Decays learning rate by gamma every step_size epochs.
Multi-step learning rate scheduler - Decays learning rate at specific milestone epochs.
Constant learning rate scheduler - Multiplies learning rate by a constant factor.
Linear learning rate scheduler - Linearly scales learning rate between two factors.
Exponential learning rate scheduler - Decays learning rate exponentially.
Polynomial learning rate scheduler - Decays learning rate using polynomial function.
Exponential decay learning rate scheduler with step-based control.
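The decay rules behind several of these schedulers can be written as closed-form functions of the step count. The following sketch uses the conventional formulas, with argument names borrowed from the module examples; the functions themselves are hypothetical stand-ins, and the `end_lr` argument of `polynomial_lr` is an assumption. The exact conventions in braintools may differ.

```python
def step_lr(step, initial_lr=0.1, step_size=30, gamma=0.1):
    """Multiply the rate by gamma once every step_size steps."""
    return initial_lr * gamma ** (step // step_size)

def multistep_lr(step, initial_lr=0.1, milestones=(30, 60, 90), gamma=0.1):
    """Multiply the rate by gamma each time a milestone has been reached."""
    passed = sum(1 for m in milestones if step >= m)
    return initial_lr * gamma ** passed

def exponential_lr(step, initial_lr=0.1, gamma=0.95):
    """Multiply the rate by gamma at every step."""
    return initial_lr * gamma ** step

def polynomial_lr(step, initial_lr=0.1, total_steps=1000, power=2.0, end_lr=0.0):
    """Decay from initial_lr to end_lr along (1 - progress) ** power."""
    progress = min(step, total_steps) / total_steps
    return (initial_lr - end_lr) * (1 - progress) ** power + end_lr

schedule = [step_lr(s) for s in (0, 29, 30, 60)]
```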
Annealing Schedulers#
These schedulers smoothly decrease the learning rate following mathematical functions like cosine curves. They often provide better convergence than step-based approaches and are particularly effective for fine-tuning and achieving optimal final performance.
Cosine annealing learning rate scheduler - Smoothly anneals learning rate using cosine function.
Cosine annealing with warm restarts - SGDR (Stochastic Gradient Descent with Warm Restarts).
Warmup + Cosine annealing schedule for smooth training transitions.
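As a reference for the shape of these curves, the sketch below evaluates the standard cosine annealing formula and a linear-warmup-then-cosine variant. It is an illustration under assumptions: the functions are hypothetical, the rate is held at its minimum after the last step, and warmup starts from a hypothetical `init_lr` of zero. None of this is taken from the braintools source.

```python
import math

def cosine_annealing(step, initial_lr=0.1, T_max=100, eta_min=1e-6):
    """Anneal from initial_lr to eta_min over T_max steps along a half cosine."""
    t = min(step, T_max)  # assumption: hold at eta_min once T_max is reached
    return eta_min + 0.5 * (initial_lr - eta_min) * (1 + math.cos(math.pi * t / T_max))

def warmup_cosine(step, warmup_steps=1000, total_steps=10000,
                  peak_lr=0.001, end_lr=1e-6, init_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to end_lr."""
    if step < warmup_steps:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return end_lr + 0.5 * (peak_lr - end_lr) * (1 + math.cos(math.pi * progress))

midpoint = cosine_annealing(50)  # halfway between initial_lr and eta_min
```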
Cyclic Schedulers#
These schedulers vary the learning rate in cycles, allowing the model to escape sharp minima and explore the loss landscape more effectively. They can lead to better generalization and faster convergence.
Cyclic learning rate scheduler - Oscillates learning rate between bounds.
One cycle learning rate scheduler - Super-convergence training policy.
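The triangular cyclic policy can be sketched in a few lines. The function below follows the commonly used triangular formula, with equal-length rising and falling halves (an assumption here); `triangular_clr` is a hypothetical name, and only the argument names come from the module example.

```python
import math

def triangular_clr(step, base_lr=0.001, max_lr=0.01, step_size_up=2000):
    """Rise linearly from base_lr to max_lr over step_size_up steps, then fall back."""
    cycle = math.floor(1 + step / (2 * step_size_up))   # 1-based index of the current cycle
    x = abs(step / step_size_up - 2 * cycle + 1)         # 1 at cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

one_cycle = [triangular_clr(s) for s in (0, 1000, 2000, 3000, 4000)]
```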
Adaptive Schedulers#
These schedulers adjust the learning rate based on training dynamics or combine multiple scheduling strategies. They can automatically adapt to the training progress or provide complex scheduling patterns for specialized training regimes.
Reduce learning rate when a metric has stopped improving - Adaptive LR based on validation metrics.
Warmup learning rate scheduler - Linearly increases learning rate during warmup phase.
Piecewise constant learning rate schedule with step-wise transitions.
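Two of these schedulers are easy to misread from a one-line summary, so here is a minimal sketch of the logic each one typically follows. Both pieces are hypothetical stand-ins written for illustration, not the braintools classes. In particular, the sketch assumes that a reduction triggers once the number of non-improving steps exceeds `patience`, and that a boundary step already uses the next value.

```python
from bisect import bisect_right

class PlateauTracker:
    """Reduce the rate by `factor` after more than `patience` steps without improvement."""
    def __init__(self, initial_lr=0.01, factor=0.5, patience=10, min_lr=0.0):
        self.lr = initial_lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")   # mode='min': lower metric is better
        self.bad_steps = 0

    def step(self, metric):
        if metric < self.best:
            self.best = metric
            self.bad_steps = 0
        else:
            self.bad_steps += 1
            if self.bad_steps > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_steps = 0
        return self.lr

def piecewise_constant(step, boundaries=(1000, 5000, 8000),
                       values=(0.1, 0.01, 0.001, 0.0001)):
    """Pick the value for the interval that contains `step`."""
    return values[bisect_right(boundaries, step)]

tracker = PlateauTracker(initial_lr=0.01, factor=0.5, patience=2)
lrs = [tracker.step(m) for m in [1.0, 0.9, 0.9, 0.9, 0.9, 0.8]]
```

The plateau tracker has to be fed the monitored metric explicitly at each evaluation, which is why the module example describes this scheduler as requiring manual metric tracking.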
Composite Schedulers#
These schedulers combine multiple scheduling strategies, allowing you to chain different schedulers together or switch between them at different training phases. They provide maximum flexibility for complex training scenarios.
Chain multiple schedulers together - Applies multiple schedulers simultaneously.
Sequential learning rate scheduler - Chains multiple schedulers based on epoch milestones.
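Switching between schedules at milestones amounts to choosing which schedule is active for the current step. The sketch below illustrates that selection with plain functions; `sequential_lr` is hypothetical, and it assumes each schedule restarts its own step count at the milestone where it becomes active, which may not match the braintools behavior.

```python
from bisect import bisect_right

def constant(step):
    return 0.001

def exponential(step):
    return 0.001 * 0.95 ** step

def sequential_lr(step, schedules, milestones):
    """Use schedules[i] between milestones[i - 1] and milestones[i]."""
    idx = bisect_right(milestones, step)
    start = milestones[idx - 1] if idx > 0 else 0
    # Assumption: the active schedule sees steps counted from its own start.
    return schedules[idx](step - start)

before = sequential_lr(4999, [constant, exponential], [5000])
after = sequential_lr(5002, [constant, exponential], [5000])
```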
Black-Box Optimizers#
These optimizers are designed for derivative-free optimization problems where gradients are not available or are expensive to compute. They are particularly useful for hyperparameter optimization, neural architecture search, and optimizing non-differentiable objectives. These wrappers provide a unified interface to powerful optimization libraries.
SciPy-based optimizer with dict/sequence bounds compatible with Nevergrad.
Ask/tell optimizer wrapper around Nevergrad with batched evaluation support.
ScipyOptimizer: Wraps SciPy’s optimization algorithms including BFGS, L-BFGS-B, Nelder-Mead, Powell, and other classical optimization methods. Best for low to medium dimensional problems with smooth objectives.
NevergradOptimizer: Integrates Facebook’s Nevergrad library, providing access to evolutionary algorithms, particle swarm optimization, differential evolution, and other population-based methods. Excellent for high-dimensional, noisy, or discrete optimization problems.
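The ask/tell pattern with batched evaluation mentioned above can be illustrated without either library. The toy optimizer below proposes a batch of candidates, receives their losses, and keeps the best point seen so far. It is a deliberately simple random search written for illustration; the `RandomSearch` class and its methods are hypothetical and do not reflect the `NevergradOptimizer` API.

```python
import random

class RandomSearch:
    """Toy ask/tell optimizer: Gaussian perturbations around the best point so far."""
    def __init__(self, x0, sigma=0.5, seed=0):
        self.best_x = list(x0)
        self.best_loss = float("inf")
        self.sigma = sigma
        self.rng = random.Random(seed)

    def ask(self, n):
        """Propose n candidate points to evaluate (possibly in parallel)."""
        return [[xi + self.rng.gauss(0.0, self.sigma) for xi in self.best_x]
                for _ in range(n)]

    def tell(self, candidates, losses):
        """Report the losses of evaluated candidates back to the optimizer."""
        for x, loss in zip(candidates, losses):
            if loss < self.best_loss:
                self.best_x, self.best_loss = list(x), loss

def sphere(x):
    return sum(xi * xi for xi in x)

opt = RandomSearch(x0=[2.0, -1.5])
opt.tell([opt.best_x], [sphere(opt.best_x)])   # evaluate the starting point
initial_loss = opt.best_loss
for _ in range(50):
    batch = opt.ask(8)                          # one batch per iteration
    opt.tell(batch, [sphere(x) for x in batch])
```

Separating proposal from evaluation is what lets the objective be evaluated in batches or by several workers, since the optimizer never calls the objective itself.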
Utilities#
Helper classes and functions that support optimization workflows, including state management for complex optimization scenarios.
A class to manage unique State objects in a PyTree structure.
The UniqueStateManager helps manage unique state objects in PyTree structures,
ensuring proper state isolation and preventing unintended state sharing during
optimization of complex models with nested components.
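To illustrate why uniqueness matters, the sketch below walks a nested container and collects each state object once, by identity. A parameter shared between two components would otherwise appear twice in the collected list. This is a conceptual illustration only: the `State` class is a stand-in rather than the brainstate class, `unique_states` is a hypothetical helper, and the real `UniqueStateManager` interface is not shown here.

```python
class State:
    """Stand-in for a state container (not brainstate.State)."""
    def __init__(self, value):
        self.value = value

def unique_states(tree):
    """Collect every State in a nested dict/list/tuple exactly once, by identity."""
    seen = set()
    unique = []

    def visit(node):
        if isinstance(node, State):
            if id(node) not in seen:
                seen.add(id(node))
                unique.append(node)
        elif isinstance(node, dict):
            for child in node.values():
                visit(child)
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)

    visit(tree)
    return unique

shared = State(1.0)
tree = {"encoder": {"w": shared}, "decoder": [shared, State(2.0)]}
found = unique_states(tree)   # the shared state is reported once
```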