braintools.optim module#

Optimization Algorithms and Learning Rate Schedulers.

This module provides a comprehensive collection of optimization algorithms and learning rate schedulers for training neural networks and spiking neural networks. It includes modern deep learning optimizers (Adam, SGD, etc.), specialized optimizers for scientific computing (SciPy, Nevergrad), and flexible learning rate scheduling strategies.

Key Features:

  • Gradient-Based Optimizers: Adam, SGD, RMSprop, Adagrad, and variants

  • Advanced Optimizers: AdamW, RAdam, Lamb, Lion, AdaBelief, etc.

  • SciPy Integration: Gradient-free and constrained optimization

  • Nevergrad Integration: Black-box optimization with evolutionary strategies

  • Learning Rate Schedulers: Step, exponential, cosine, warmup, and custom schedules

  • PyTorch-like Interface: Familiar API for PyTorch users

  • JAX/Optax Backend: High-performance optimization with automatic differentiation

Quick Start - Basic Optimization:

import jax.numpy as jnp
import brainstate as bst
from braintools.optim import Adam

# Define a simple model
class SimpleModel(bst.Module):
    def __init__(self):
        super().__init__()
        self.w = bst.ParamState(jnp.zeros((10, 5)))
        self.b = bst.ParamState(jnp.zeros(5))

    def __call__(self, x):
        return jnp.dot(x, self.w.value) + self.b.value

# Create model and optimizer
model = SimpleModel()
optimizer = Adam(lr=0.001)

# Register trainable parameters
optimizer.register_trainable_weights(model.states(bst.ParamState))

# Training step
@bst.transform.grad(model.states(bst.ParamState), return_value=True)
def loss_fn(data, target):
    pred = model(data)
    return jnp.mean((pred - target) ** 2)

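# Example batch (shapes match the model: 10 inputs, 5 outputs)
data = jnp.ones((32, 10))
target = jnp.zeros((32, 5))

# Update step
grads, loss = loss_fn(data, target)
optimizer.update(grads)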

Quick Start - With Learning Rate Scheduler:

from braintools.optim import Adam, CosineAnnealingLR

# Create optimizer with cosine annealing schedule
scheduler = CosineAnnealingLR(T_max=1000, eta_min=1e-6)
optimizer = Adam(lr=scheduler, weight_decay=1e-4)

optimizer.register_trainable_weights(model.states(bst.ParamState))

# Training loop
for epoch in range(100):
    grads, loss = loss_fn(data, target)
    optimizer.update(grads)
    # Scheduler step is handled automatically

Gradient-Based Optimizers:

from braintools.optim import (
    SGD, Momentum, Adam, AdamW, RMSprop,
    Adagrad, Adadelta, Nadam, RAdam
)

# Stochastic Gradient Descent
sgd = SGD(lr=0.01, weight_decay=1e-4)

# Momentum
momentum = Momentum(lr=0.01, momentum=0.9, nesterov=True)

# Adam (most popular)
adam = Adam(lr=0.001, betas=(0.9, 0.999), eps=1e-8)

# AdamW (Adam with decoupled weight decay)
adamw = AdamW(lr=0.001, weight_decay=0.01)

# RMSprop
rmsprop = RMSprop(lr=0.001, alpha=0.99, eps=1e-8)

# Adagrad (adaptive learning rates)
adagrad = Adagrad(lr=0.01, eps=1e-10)

# Adadelta (extension of Adagrad)
adadelta = Adadelta(lr=1.0, rho=0.9, eps=1e-6)

# Nadam (Adam + Nesterov momentum)
nadam = Nadam(lr=0.001, betas=(0.9, 0.999))

# RAdam (rectified Adam)
radam = RAdam(lr=0.001, betas=(0.9, 0.999))

Advanced Optimizers:

from braintools.optim import (
    Lamb, Lars, Lion, AdaBelief,
    Adafactor, Yogi, Lookahead
)

# Lamb (for large batch training)
lamb = Lamb(lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)

# Lars (layer-wise adaptive rate scaling)
lars = Lars(lr=0.01, momentum=0.9, weight_decay=1e-4)

# Lion (evolved sign momentum)
lion = Lion(lr=0.0001, betas=(0.9, 0.99), weight_decay=0.01)

# AdaBelief (adapting stepsizes by belief in gradient direction)
adabelief = AdaBelief(lr=0.001, betas=(0.9, 0.999), eps=1e-16)

# Adafactor (memory-efficient adaptive learning rates)
adafactor = Adafactor(lr=0.001, min_dim_size_to_factor=128)

# Yogi (adaptive learning rate with controlled increases)
yogi = Yogi(lr=0.01, betas=(0.9, 0.999))

# Lookahead (wrapper for other optimizers)
lookahead = Lookahead(
    base_optimizer=Adam(lr=0.001),
    sync_period=5,
    slow_step_size=0.5
)

Learning Rate Schedulers:

from braintools.optim import (
    StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, PolynomialLR,
    WarmupScheduler, OneCycleLR, CyclicLR,
    WarmupCosineSchedule
)

# Step decay
step_lr = StepLR(initial_lr=0.1, step_size=30, gamma=0.1)

# Multi-step decay
multistep_lr = MultiStepLR(initial_lr=0.1, milestones=[30, 60, 90], gamma=0.1)

# Exponential decay
exp_lr = ExponentialLR(initial_lr=0.1, gamma=0.95)

# Cosine annealing
cosine_lr = CosineAnnealingLR(initial_lr=0.1, T_max=100, eta_min=1e-6)

# Polynomial decay
poly_lr = PolynomialLR(initial_lr=0.1, total_steps=1000, power=2.0)

# Warmup then constant
warmup_lr = WarmupScheduler(
    warmup_steps=1000,
    peak_lr=0.001,
    init_lr=1e-6
)

# One-cycle policy
onecycle_lr = OneCycleLR(
    max_lr=0.01,
    total_steps=1000,
    pct_start=0.3,
    div_factor=25.0
)

# Cyclic learning rate
cyclic_lr = CyclicLR(
    base_lr=0.001,
    max_lr=0.01,
    step_size_up=2000,
    mode='triangular'
)

# Warmup + cosine schedule
warmup_cosine = WarmupCosineSchedule(
    warmup_steps=1000,
    total_steps=10000,
    peak_lr=0.001,
    end_lr=1e-6
)

SciPy Optimization:

from braintools.optim import ScipyOptimizer

# Use SciPy's BFGS for gradient-based optimization
scipy_opt = ScipyOptimizer(
    method='BFGS',
    options={'maxiter': 1000, 'gtol': 1e-6}
)

# Use Nelder-Mead for gradient-free optimization
nelder_mead = ScipyOptimizer(
    method='Nelder-Mead',
    options={'maxiter': 5000, 'xatol': 1e-8}
)

# Constrained optimization with bounds
constrained = ScipyOptimizer(
    method='L-BFGS-B',
    bounds=[(0, 1), (-10, 10)],
    options={'maxiter': 1000}
)

Nevergrad Optimization:

from braintools.optim import NevergradOptimizer

# Differential evolution
ng_de = NevergradOptimizer(
    optimizer='TwoPointsDE',
    budget=1000,
    num_workers=4
)

# CMA-ES (Covariance Matrix Adaptation)
ng_cma = NevergradOptimizer(
    optimizer='CMA',
    budget=2000,
    num_workers=1
)

# Particle swarm optimization
ng_pso = NevergradOptimizer(
    optimizer='PSO',
    budget=1000,
    num_workers=8
)

Gradient Clipping:

from braintools.optim import Adam

# Clip by global norm
optimizer = Adam(lr=0.001, grad_clip_norm=1.0)

# Clip by value
optimizer = Adam(lr=0.001, grad_clip_value=0.5)
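
The difference between the two clipping modes can be illustrated in plain Python. This is a minimal sketch, not the braintools implementation: it assumes `grad_clip_norm` means clipping by global L2 norm and `grad_clip_value` means elementwise clamping, as the comments above suggest. The helper names `clip_by_global_norm` and `clip_by_value` are hypothetical.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm or total == 0.0:
        return list(grads)
    scale = max_norm / total
    return [g * scale for g in grads]

def clip_by_value(grads, limit):
    """Clamp each gradient entry independently to [-limit, limit]."""
    return [max(-limit, min(limit, g)) for g in grads]

grads = [3.0, 4.0]                          # global L2 norm is 5.0
by_norm = clip_by_global_norm(grads, 1.0)   # direction preserved, norm rescaled to 1.0
by_value = clip_by_value(grads, 0.5)        # each entry clamped, direction may change
```

Norm clipping keeps the gradient direction and only shrinks its length, while value clipping can change the direction because each entry is clamped on its own.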

Weight Decay:

from braintools.optim import SGD, AdamW

# L2 regularization (coupled with gradients)
sgd = SGD(lr=0.01, weight_decay=1e-4)

# Decoupled weight decay (better for Adam-like optimizers)
adamw = AdamW(lr=0.001, weight_decay=0.01)
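
Why decoupling matters for Adam-like optimizers can be seen in a scalar sketch. The code below is an illustration under stated assumptions, not the library's code: `adaptive_direction` is a hypothetical stand-in for an Adam-style normalized step whose magnitude is roughly 1 regardless of the gradient scale. For plain SGD the two variants coincide; they differ once the step is rescaled adaptively.

```python
def adaptive_direction(g, eps=1e-8):
    """Hypothetical stand-in for an Adam-style step: magnitude is about 1 for any scale of g."""
    return g / (abs(g) + eps)

w, grad, lr, wd = 2.0, 100.0, 0.1, 0.01

# Coupled (L2 regularization): the decay term is folded into the gradient,
# so the adaptive rescaling also rescales the decay and nearly erases it here.
coupled = w - lr * adaptive_direction(grad + wd * w)

# Decoupled: the decay acts on the weight directly, outside the rescaling.
decoupled = w - lr * adaptive_direction(grad) - lr * wd * w
```

In this example the coupled update moves the weight to about 1.9 (the decay has almost no effect), while the decoupled update also shrinks the weight by `lr * wd * w`, giving about 1.898.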

Advanced Scheduler Patterns:

from braintools.optim import (
    ChainedScheduler, SequentialLR,
    ReduceLROnPlateau, PiecewiseConstantSchedule,
    WarmupScheduler, CosineAnnealingLR,
    ConstantLR, ExponentialLR
)

# Chain multiple schedulers
scheduler = ChainedScheduler([
    WarmupScheduler(warmup_steps=1000, peak_lr=0.001),
    CosineAnnealingLR(initial_lr=0.001, T_max=9000)
])

# Sequential schedulers (switch at milestones)
sequential = SequentialLR(
    schedulers=[
        ConstantLR(0.001),
        ExponentialLR(initial_lr=0.001, gamma=0.95)
    ],
    milestones=[5000]
)

# Reduce on plateau (requires manual metric tracking)
reduce_plateau = ReduceLROnPlateau(
    initial_lr=0.01,
    factor=0.5,
    patience=10,
    mode='min'
)

# Piecewise constant
piecewise = PiecewiseConstantSchedule(
    boundaries=[1000, 5000, 8000],
    values=[0.1, 0.01, 0.001, 0.0001]
)
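
The lookup behind a piecewise constant schedule can be sketched in plain Python. This is an illustration, not the braintools implementation; the function name `piecewise_constant` is hypothetical, and it assumes a boundary step already belongs to the next segment.

```python
from bisect import bisect_right

def piecewise_constant(boundaries, values, step):
    """Return the value of the segment that contains step.

    There must be exactly one more value than there are boundaries.
    """
    if len(values) != len(boundaries) + 1:
        raise ValueError("need len(values) == len(boundaries) + 1")
    return values[bisect_right(boundaries, step)]

boundaries = [1000, 5000, 8000]
values = [0.1, 0.01, 0.001, 0.0001]
lr_early = piecewise_constant(boundaries, values, 999)    # first segment
lr_late = piecewise_constant(boundaries, values, 8000)    # last segment
```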

Comprehensive optimization toolkit for brain modeling, featuring PyTorch-like optimizers, learning rate schedulers, and advanced optimization algorithms from SciPy and Nevergrad.

Overview#

The braintools.optim module provides:

  • Modern gradient-based optimizers with PyTorch-compatible APIs

  • Learning rate schedulers for dynamic learning rate adjustment

  • Black-box optimization via SciPy and Nevergrad wrappers

  • State management utilities for optimization workflows

Base Classes#

These classes provide the foundational architecture for all optimizers in the module. The Optimizer class defines the common interface, while OptaxOptimizer serves as the base for all gradient-based optimizers built on top of the Optax library.

Optimizer

Base Optimizer Class.

OptaxOptimizer

Base class for Optax-based optimizers with PyTorch-like interface.

Gradient-Based Optimizers#

These optimizers use gradient information to update model parameters. They follow a PyTorch-like API, making them familiar to users coming from the PyTorch ecosystem. All gradient-based optimizers support features like weight decay, gradient clipping, and integration with learning rate schedulers.

Standard Optimizers#

These are the most commonly used optimizers in deep learning and neural network training. They provide a good balance between convergence speed and stability for most applications.

SGD

Stochastic Gradient Descent (SGD) optimizer with momentum and weight decay.

Momentum

Momentum optimizer.

MomentumNesterov

Nesterov Momentum optimizer.

Adam

Adam (Adaptive Moment Estimation) optimizer.

AdamW

AdamW optimizer with decoupled weight decay regularization.

Adagrad

Adagrad optimizer with adaptive learning rates.

Adadelta

Adadelta optimizer - an extension of Adagrad.

RMSprop

RMSprop (Root Mean Square Propagation) optimizer.

Adamax

Adamax optimizer - variant of Adam based on infinity norm.

Nadam

Nadam optimizer - Adam with Nesterov accelerated gradient.
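
As a reference point for the `betas` and `eps` parameters shared by the Adam family, here is a scalar sketch of the standard Adam update rule. It is written in plain Python for illustration and is not the Optax-backed code used by the module; the function name `adam_step` is hypothetical.

```python
import math

def adam_step(w, g, m, v, t, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a scalar weight w with gradient g at step t (t starts at 1)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g          # first-moment (mean) estimate
    v = b2 * v + (1 - b2) * g * g      # second-moment (uncentered variance) estimate
    m_hat = m / (1 - b1 ** t)          # bias correction for the zero initialization
    v_hat = v / (1 - b2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# Minimize f(w) = w**2, whose gradient is 2 * w.
w, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    w, m, v = adam_step(w, 2.0 * w, m, v, t)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the weight moves by approximately `lr` regardless of the gradient magnitude.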

Advanced Optimizers#

These optimizers implement state-of-the-art optimization algorithms designed for specific use cases or improved performance. They often provide better convergence properties for large-scale models, handle sparse gradients more effectively, or offer improved stability in challenging optimization landscapes.

RAdam

RAdam optimizer (Rectified Adam).

Lamb

LAMB optimizer (Layer-wise Adaptive Moments).

Lars

LARS optimizer (Layer-wise Adaptive Rate Scaling).

Lookahead

Lookahead optimizer wrapper.

Yogi

Yogi optimizer (improvement over Adam).

LBFGS

L-BFGS optimizer (Limited-memory Broyden-Fletcher-Goldfarb-Shanno).

Rprop

Rprop optimizer (Resilient Backpropagation).

Adafactor

Adafactor optimizer (memory-efficient variant of Adam).

AdaBelief

AdaBelief optimizer - Adapts step size according to belief in gradient direction.

Lion

Lion (EvoLved Sign Momentum) optimizer - Discovered through program search.

SM3

SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients) optimizer.

Novograd

Novograd (Normalized Gradient) optimizer - Layer-wise gradient normalization with momentum.

Fromage

Fromage (Frobenius matched gradient descent) optimizer.
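
The Lookahead wrapper is easiest to understand as two sets of weights. The sketch below follows the published Lookahead algorithm with plain gradient descent as the inner optimizer. It is an assumption-laden illustration rather than the module's code: the function `lookahead_sgd` is hypothetical, and only the parameter names `sync_period` and `slow_step_size` are taken from the example earlier on this page.

```python
def lookahead_sgd(grad_fn, w0, lr=0.1, sync_period=5, slow_step_size=0.5, steps=20):
    """Run gradient descent on fast weights and periodically pull slow weights toward them."""
    slow = w0
    fast = w0
    for step in range(1, steps + 1):
        fast = fast - lr * grad_fn(fast)                  # inner (fast) optimizer step
        if step % sync_period == 0:
            slow = slow + slow_step_size * (fast - slow)  # interpolate slow toward fast
            fast = slow                                   # restart fast weights from slow
    return slow

# Minimize f(w) = w**2, whose gradient is 2 * w.
w_final = lookahead_sgd(lambda w: 2.0 * w, w0=1.0)
```

With `slow_step_size=1.0` the slow weights simply copy the fast weights and the wrapper reduces to the inner optimizer.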

Learning Rate Schedulers#

Learning rate schedulers dynamically adjust the learning rate during training to improve convergence and final model performance. They can help escape local minima, fine-tune models more effectively, and achieve better generalization. All schedulers are compatible with any gradient-based optimizer.

Base Scheduler#

The abstract base class that defines the interface for all learning rate schedulers. Custom schedulers should inherit from this class.

LRScheduler

Base class for learning rate schedulers.

Step-based Schedulers#

These schedulers adjust the learning rate at fixed intervals or following predetermined patterns. They are simple to configure and work well for many standard training scenarios.

StepLR

Step learning rate scheduler - Decays learning rate by gamma every step_size epochs.

MultiStepLR

Multi-step learning rate scheduler - Decays learning rate at specific milestone epochs.

ConstantLR

Constant learning rate scheduler - Multiplies learning rate by a constant factor.

LinearLR

Linear learning rate scheduler - Linearly scales learning rate between two factors.

ExponentialLR

Exponential learning rate scheduler - Decays learning rate exponentially.

PolynomialLR

Polynomial learning rate scheduler - Decays learning rate using polynomial function.

ExponentialDecayLR

Exponential decay learning rate scheduler with step-based control.
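
The decay rules named above have simple closed forms. The following plain-Python sketch uses the conventional PyTorch-style definitions. Whether braintools counts epochs or optimizer steps, and how it treats the boundary epoch, is an assumption here, and the function names are hypothetical.

```python
def step_lr(initial_lr, step_size, gamma, epoch):
    """Multiply the rate by gamma once every step_size epochs."""
    return initial_lr * gamma ** (epoch // step_size)

def multistep_lr(initial_lr, milestones, gamma, epoch):
    """Multiply the rate by gamma once for every milestone already reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return initial_lr * gamma ** passed

def exponential_lr(initial_lr, gamma, epoch):
    """Multiply the rate by gamma every epoch."""
    return initial_lr * gamma ** epoch

lr_step = step_lr(0.1, step_size=30, gamma=0.1, epoch=30)
lr_multi = multistep_lr(0.1, milestones=[30, 60, 90], gamma=0.1, epoch=60)
lr_exp = exponential_lr(0.1, gamma=0.95, epoch=2)
```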

Annealing Schedulers#

These schedulers smoothly decrease the learning rate following mathematical functions like cosine curves. They often provide better convergence than step-based approaches and are particularly effective for fine-tuning and achieving optimal final performance.

CosineAnnealingLR

Cosine annealing learning rate scheduler - Smoothly anneals learning rate using cosine function.

CosineAnnealingWarmRestarts

Cosine annealing with warm restarts - SGDR (Stochastic Gradient Descent with Warm Restarts).

WarmupCosineSchedule

Warmup + Cosine annealing schedule for smooth training transitions.
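
The cosine curves behind these schedulers can be written down directly. This sketch uses the standard cosine annealing formula and a linear warmup. Holding the final value after the last step, and the default `init_lr=0.0`, are assumptions made for the illustration, and the function names are hypothetical rather than part of the module.

```python
import math

def cosine_annealing(initial_lr, T_max, eta_min, t):
    """Anneal from initial_lr at t=0 to eta_min at t=T_max along half a cosine period."""
    t = min(t, T_max)  # assumption: hold eta_min after T_max
    return eta_min + 0.5 * (initial_lr - eta_min) * (1 + math.cos(math.pi * t / T_max))

def warmup_cosine(peak_lr, end_lr, warmup_steps, total_steps, t, init_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to end_lr at total_steps."""
    if t < warmup_steps:
        return init_lr + (peak_lr - init_lr) * t / warmup_steps
    progress = min(1.0, (t - warmup_steps) / (total_steps - warmup_steps))
    return end_lr + 0.5 * (peak_lr - end_lr) * (1 + math.cos(math.pi * progress))

lr_mid = cosine_annealing(0.1, T_max=100, eta_min=1e-6, t=50)
lr_warm = warmup_cosine(0.001, 1e-6, warmup_steps=1000, total_steps=10000, t=500)
```

Halfway through the annealing phase the rate sits at the midpoint between its start and end values, which is a quick sanity check for any implementation.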

Cyclic Schedulers#

These schedulers vary the learning rate in cycles, allowing the model to escape sharp minima and explore the loss landscape more effectively. They can lead to better generalization and faster convergence.

CyclicLR

Cyclic learning rate scheduler - Oscillates learning rate between bounds.

OneCycleLR

One cycle learning rate scheduler - Super-convergence training policy.
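
For the `'triangular'` mode, the rate rises linearly from the lower bound to the upper bound and back. The sketch below uses the triangular formula from the original cyclical learning rate proposal. It is an illustration only, the function name is hypothetical, and it assumes the descending half of a cycle has the same length as `step_size_up`.

```python
import math

def triangular_cyclic_lr(base_lr, max_lr, step_size_up, t):
    """Triangular wave between base_lr and max_lr with period 2 * step_size_up."""
    cycle = math.floor(1 + t / (2 * step_size_up))
    x = abs(t / step_size_up - 2 * cycle + 1)   # distance from the peak, in [0, 1]
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

lr_start = triangular_cyclic_lr(0.001, 0.01, step_size_up=2000, t=0)
lr_peak = triangular_cyclic_lr(0.001, 0.01, step_size_up=2000, t=2000)
lr_half = triangular_cyclic_lr(0.001, 0.01, step_size_up=2000, t=1000)
```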

Adaptive Schedulers#

These schedulers adjust the learning rate based on training dynamics or combine multiple scheduling strategies. They can automatically adapt to the training progress or provide complex scheduling patterns for specialized training regimes.

ReduceLROnPlateau

Reduce learning rate when a metric has stopped improving - Adaptive LR based on validation metrics.

WarmupScheduler

Warmup learning rate scheduler - Linearly increases learning rate during warmup phase.

PiecewiseConstantSchedule

Piecewise constant learning rate schedule with step-wise transitions.
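
The plateau logic can be sketched as a small tracker that is fed the monitored metric once per epoch. This is a simplified illustration of the usual reduce-on-plateau rule, without threshold, cooldown, or minimum rate. The class name `PlateauTracker` and its `step` method are hypothetical and do not describe the `ReduceLROnPlateau` interface.

```python
class PlateauTracker:
    """Cut the learning rate by `factor` after more than `patience` epochs without improvement."""

    def __init__(self, initial_lr, factor=0.5, patience=10, mode='min'):
        self.lr = initial_lr
        self.factor = factor
        self.patience = patience
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def step(self, metric):
        improved = (
            self.best is None
            or (self.mode == 'min' and metric < self.best)
            or (self.mode == 'max' and metric > self.best)
        )
        if improved:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor   # plateau detected: reduce the rate
                self.bad_epochs = 0
        return self.lr

tracker = PlateauTracker(initial_lr=0.01, factor=0.5, patience=2)
lrs = [tracker.step(loss) for loss in [1.0, 0.9, 0.9, 0.9, 0.9]]
```

In this run the loss stops improving after the second epoch, and the rate is halved once the number of stalled epochs exceeds `patience`.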

Composite Schedulers#

These schedulers combine multiple scheduling strategies, allowing you to chain different schedulers together or switch between them at different training phases. They provide maximum flexibility for complex training scenarios.

ChainedScheduler

Chain multiple schedulers together - Applies multiple schedulers simultaneously.

SequentialLR

Sequential learning rate scheduler - Chains multiple schedulers based on epoch milestones.
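
Switching between schedules at milestones can be sketched with plain functions of the step index. This illustration assumes that each schedule restarts its own step count at the milestone where it becomes active, as PyTorch's `SequentialLR` does. Whether braintools follows the same convention is not stated here, and the function name `sequential_lr` is hypothetical.

```python
def sequential_lr(schedules, milestones, t):
    """Pick the schedule active at step t and evaluate it at its local step."""
    index = sum(1 for m in milestones if t >= m)       # number of milestones reached
    start = milestones[index - 1] if index > 0 else 0  # step at which this schedule began
    return schedules[index](t - start)

def constant(t):
    return 0.001

def exponential(t):
    return 0.001 * 0.95 ** t

lr_before = sequential_lr([constant, exponential], [5000], 4999)
lr_after = sequential_lr([constant, exponential], [5000], 5001)
```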

Black-Box Optimizers#

These optimizers are designed for derivative-free optimization problems where gradients are not available or are expensive to compute. They are particularly useful for hyperparameter optimization, neural architecture search, and optimizing non-differentiable objectives. These wrappers provide a unified interface to powerful optimization libraries.

ScipyOptimizer

SciPy-based optimizer with dict/sequence bounds compatible with Nevergrad.

NevergradOptimizer

Ask/tell optimizer wrapper around Nevergrad with batched evaluation support.

  • ScipyOptimizer: Wraps SciPy’s optimization algorithms, including BFGS, L-BFGS-B, Nelder-Mead, Powell, and other classical methods. Best suited to low- to medium-dimensional problems with smooth objectives.

  • NevergradOptimizer: Integrates Facebook’s Nevergrad library, providing access to evolutionary algorithms, particle swarm optimization, differential evolution, and other population-based methods. Excellent for high-dimensional, noisy, or discrete optimization problems.
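
The ask/tell pattern with batched evaluation can be shown with a toy optimizer in plain Python. This sketch uses uniform random search in place of an evolutionary strategy. The class `RandomSearch` and its methods are hypothetical and are not the `NevergradOptimizer` interface; they only illustrate the control flow of proposing a batch, evaluating it, and reporting the losses back.

```python
import random

class RandomSearch:
    """Toy ask/tell optimizer: proposes uniform samples inside bounds and keeps the best."""

    def __init__(self, bounds, seed=0):
        self.bounds = bounds
        self.rng = random.Random(seed)
        self.best_x = None
        self.best_loss = float('inf')

    def ask(self, n):
        """Propose n candidate points, one coordinate per (low, high) bound."""
        return [[self.rng.uniform(lo, hi) for lo, hi in self.bounds] for _ in range(n)]

    def tell(self, candidates, losses):
        """Report the losses of evaluated candidates."""
        for x, loss in zip(candidates, losses):
            if loss < self.best_loss:
                self.best_x, self.best_loss = x, loss

def sphere(x):
    return sum(v * v for v in x)

opt = RandomSearch(bounds=[(0, 1), (-10, 10)])
for _ in range(50):
    batch = opt.ask(4)                            # a batch can be evaluated in parallel
    opt.tell(batch, [sphere(x) for x in batch])
```

Because the objective is only ever called on proposed points, no gradient is needed, which is what makes this pattern suitable for non-differentiable objectives.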

Utilities#

Helper classes and functions that support optimization workflows, including state management for complex optimization scenarios.

UniqueStateManager

A class to manage unique State objects in a PyTree structure.

The UniqueStateManager helps manage unique state objects in PyTree structures, ensuring proper state isolation and preventing unintended state sharing during optimization of complex models with nested components.
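
The idea of keeping each state object only once can be illustrated with an identity-based traversal of a nested structure. This is a conceptual sketch in plain Python and makes no claim about how `UniqueStateManager` is implemented. The `State` class and `unique_leaves` function are hypothetical stand-ins.

```python
class State:
    """Hypothetical stand-in for a trainable state object."""
    def __init__(self, value):
        self.value = value

def unique_leaves(tree):
    """Collect leaves of nested dicts/lists/tuples, keeping each object once (by identity)."""
    seen = set()
    leaves = []

    def visit(node):
        if isinstance(node, dict):
            for child in node.values():
                visit(child)
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        elif id(node) not in seen:
            seen.add(id(node))
            leaves.append(node)

    visit(tree)
    return leaves

shared = State(1.0)   # the same object is referenced from two places
tree = {'encoder': {'w': shared}, 'decoder': [shared, State(2.0)]}
leaves = unique_leaves(tree)
```

Deduplicating by identity rather than by value means a weight shared between two components is collected once, so it would also be updated only once per step.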