LBFGS

class braintools.optim.LBFGS(lr=1.0, memory_size=10, scale_init_precond=True, linesearch=None, grad_clip_norm=None, grad_clip_value=None)

L-BFGS optimizer (Limited-memory Broyden-Fletcher-Goldfarb-Shanno).

L-BFGS is a quasi-Newton optimization method that approximates the inverse Hessian matrix from a limited history of recent updates. On smooth, unconstrained problems it typically converges in far fewer iterations than first-order methods, and it is widely used in scientific computing, machine learning, and numerical optimization.

This optimizer is particularly effective for:

  • Medium to large-scale optimization problems

  • Smooth, differentiable objective functions

  • Full-batch or deterministic gradient computations

  • Scientific computing and parameter estimation

  • Neural network fine-tuning with small datasets

Parameters:
  • lr (float | LRScheduler) – Learning rate (step size). Can be a float or any LRScheduler instance. L-BFGS typically uses lr=1.0 because the line search selects the step size at each iteration. Adjust it only if line search is disabled or for specific needs.

  • memory_size (int) – Number of past (position difference, gradient difference) pairs to store for the Hessian approximation. Typical values are 3 to 20. Larger values give a better approximation but increase memory use and per-step computation.

  • scale_init_precond (bool, default=True) – Whether to scale the initial Hessian approximation using gradient information. Improves convergence by adapting to problem scale. Recommended for most cases.

  • grad_clip_norm (float | None) – Maximum norm for gradient clipping. Gradients are scaled when their global norm exceeds this value. Useful for numerical stability.

  • grad_clip_value (float | None) – Maximum absolute value for element-wise gradient clipping. Each gradient component is clipped to [-grad_clip_value, grad_clip_value].
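The difference between the two clipping options can be sketched in plain NumPy. This is an illustration of the behavior described above, not the library's implementation; the helper names `clip_by_global_norm` and `clip_by_value` are hypothetical.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale every gradient by one factor when the joint L2 norm exceeds max_norm."""
    global_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads.values()))
    scale = min(1.0, max_norm / global_norm) if global_norm > 0 else 1.0
    return {name: g * scale for name, g in grads.items()}

def clip_by_value(grads, max_value):
    """Clip each component independently to [-max_value, max_value]."""
    return {name: np.clip(g, -max_value, max_value) for name, g in grads.items()}

# Hypothetical gradients; their global norm is sqrt(9 + 16 + 144) = 13.
grads = {"w": np.array([3.0, -4.0]), "b": np.array([12.0])}

by_norm = clip_by_global_norm(grads, max_norm=1.0)   # direction preserved, norm becomes 1
by_value = clip_by_value(grads, max_value=5.0)       # only the 12.0 entry changes
```

Norm clipping keeps the gradient direction intact, while value clipping can change it, which matters for a method that builds curvature estimates from gradient differences.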

Notes

Mathematical Formulation:

L-BFGS approximates the inverse Hessian matrix \(H_k^{-1}\) using the limited-memory BFGS update formula. The parameter update is:

\[\theta_{k+1} = \theta_k - \alpha_k H_k^{-1} \nabla f(\theta_k)\]

The inverse Hessian approximation uses \(m\) stored pairs:

\[s_i = \theta_{i+1} - \theta_i \quad \text{(position difference)}\]
\[y_i = \nabla f(\theta_{i+1}) - \nabla f(\theta_i) \quad \text{(gradient difference)}\]

with curvature information:

\[\rho_i = \frac{1}{y_i^T s_i}\]
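As a small worked instance of these definitions, the sketch below computes one pair \((s, y)\) and its \(\rho\) for a hypothetical quadratic objective \(f(\theta) = \tfrac{1}{2}\theta^T A \theta\). The matrix `A` and the two iterates are made up for illustration. L-BFGS implementations commonly keep a pair only when \(y^T s > 0\), since that preserves positive definiteness of the approximation; whether this implementation applies such a check is not stated here.

```python
import numpy as np

# Hypothetical positive definite matrix defining f(theta) = 0.5 * theta^T A theta.
A = np.array([[3.0, 1.0],
              [1.0, 2.0]])

def grad_f(theta):
    return A @ theta

theta_old = np.array([1.0, -1.0])
theta_new = np.array([0.5, -0.25])

s = theta_new - theta_old                    # position difference
y = grad_f(theta_new) - grad_f(theta_old)    # gradient difference (equals A @ s here)
curvature = float(y @ s)                     # y^T s, positive because A is positive definite
rho = 1.0 / curvature
```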

Two-Loop Recursion Algorithm:

  1. First loop (newest to oldest): Compute direction adjustments

  2. Initial scaling: \(H_0^{-1} = \gamma_k I\) where \(\gamma_k = \frac{s_{k-1}^T y_{k-1}}{y_{k-1}^T y_{k-1}}\)

  3. Second loop (oldest to newest): Apply BFGS corrections
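The three steps above can be sketched in NumPy as follows. This is a textbook-style two-loop recursion written for illustration, not the code used by this optimizer; the function name `two_loop_direction` and the test problem are assumptions.

```python
import numpy as np

def two_loop_direction(grad, s_list, y_list):
    """Return the search direction -H^{-1} grad from stored pairs (oldest first)."""
    q = np.array(grad, dtype=float)
    rhos = [1.0 / float(y @ s) for s, y in zip(s_list, y_list)]
    alphas = []
    # First loop: newest pair to oldest pair.
    for s, y, rho in zip(reversed(s_list), reversed(y_list), reversed(rhos)):
        alpha = rho * float(s @ q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial scaling gamma_k from the most recent pair (identity if no history).
    if s_list:
        gamma = float(s_list[-1] @ y_list[-1]) / float(y_list[-1] @ y_list[-1])
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: oldest pair to newest pair; alphas were stored newest first.
    for s, y, rho, alpha in zip(s_list, y_list, rhos, reversed(alphas)):
        beta = rho * float(y @ r)
        r = r + (alpha - beta) * s
    return -r

# Hypothetical quadratic f(theta) = 0.5 * theta^T A theta, so y = A @ s.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
s_list = [np.array([1.0, 0.0]), np.array([1.0, -3.0])]   # chosen A-conjugate
y_list = [A @ s for s in s_list]

g = np.array([1.0, 1.0])
direction = two_loop_direction(g, s_list, y_list)   # equals -A^{-1} g = [-0.2, -0.4]
```

In this two-dimensional example the stored steps are A-conjugate, so two pairs are enough to reproduce the exact Newton direction. With an empty history the function falls back to steepest descent.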

Line Search with Zoom Algorithm:

This implementation includes automatic zoom line search finding step size \(\alpha_k\) satisfying the strong Wolfe conditions for robust convergence.
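For reference, the strong Wolfe conditions require sufficient decrease, \(f(\theta + \alpha d) \le f(\theta) + c_1 \alpha \nabla f(\theta)^T d\), and a curvature bound, \(|\nabla f(\theta + \alpha d)^T d| \le c_2 |\nabla f(\theta)^T d|\). The sketch below only checks a candidate step against them; it does not implement the zoom search. The helper name `satisfies_strong_wolfe` is hypothetical, and `c1=1e-4`, `c2=0.9` are conventional textbook values that may differ from the defaults used here.

```python
import numpy as np

def satisfies_strong_wolfe(f, grad_f, theta, direction, alpha, c1=1e-4, c2=0.9):
    """Check a candidate step size against the strong Wolfe conditions."""
    f0 = f(theta)
    slope0 = float(grad_f(theta) @ direction)       # negative for a descent direction
    theta_new = theta + alpha * direction
    sufficient_decrease = f(theta_new) <= f0 + c1 * alpha * slope0
    curvature = abs(float(grad_f(theta_new) @ direction)) <= c2 * abs(slope0)
    return bool(sufficient_decrease and curvature)

# Hypothetical objective f(theta) = 0.5 * ||theta||^2 with a steepest-descent direction.
f = lambda theta: 0.5 * float(theta @ theta)
grad_f = lambda theta: theta
theta = np.array([2.0, 0.0])
direction = -grad_f(theta)

ok_unit = satisfies_strong_wolfe(f, grad_f, theta, direction, alpha=1.0)    # lands on the minimum
ok_long = satisfies_strong_wolfe(f, grad_f, theta, direction, alpha=2.0)    # no decrease
ok_short = satisfies_strong_wolfe(f, grad_f, theta, direction, alpha=0.01)  # slope barely changes
```

Only the unit step passes: the long step fails sufficient decrease, and the very short step fails the curvature bound.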

Key Characteristics:

  • Fast local convergence: Typically needs far fewer iterations than first-order methods near an optimum

  • Memory efficient: O(mn) storage for n parameters, m history size

  • Curvature aware: Uses second-order information without computing Hessian

  • Self-scaling: Adapts to problem geometry automatically

  • Robust line search: Ensures sufficient decrease and curvature conditions

Limitations:

  • Not suitable for stochastic mini-batch optimization

  • Requires full gradients for best performance

  • Memory scales with memory_size × parameter_count

  • Line search requires additional function evaluations
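A rough estimate of the history cost helps when choosing memory_size. The sketch below assumes that each stored pair holds two vectors of parameter size and that values are float32; the helper name `lbfgs_history_bytes` is hypothetical, and the actual optimizer state may include additional buffers. The parameter count corresponds to a 784-128-10 network with biases.

```python
def lbfgs_history_bytes(num_params, memory_size, bytes_per_value=4):
    """Approximate storage for the (s, y) history: two vectors per stored pair."""
    return 2 * memory_size * num_params * bytes_per_value

# Linear(784, 128) followed by Linear(128, 10), both assumed to have biases.
num_params = 784 * 128 + 128 + 128 * 10 + 10          # 101,770 parameters

history_m10 = lbfgs_history_bytes(num_params, memory_size=10)   # 8,141,600 bytes, about 8.1 MB
history_m20 = lbfgs_history_bytes(num_params, memory_size=20)   # doubles with memory_size
```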

Important Usage Note:

L-BFGS with line search requires additional function evaluations. For best performance, use with optax.value_and_grad_from_state to reuse computations:

>>> import optax
>>> value_and_grad = optax.value_and_grad_from_state(objective)
>>> value, grad = value_and_grad(params, state=opt_state)
>>> updates, opt_state = optimizer.tx.update(
...     grad, opt_state, params,
...     value=value, grad=grad, value_fn=objective
... )

Examples

Basic usage for batch optimization:

>>> import brainstate
>>> import braintools
>>>
>>> # Create model
>>> model = brainstate.nn.Linear(10, 5)
>>>
>>> # Initialize L-BFGS optimizer
>>> optimizer = braintools.optim.LBFGS(lr=1.0)
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With custom memory size:

>>> # Larger memory for better Hessian approximation
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=20
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

With smaller memory for efficiency:

>>> # Smaller memory footprint
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=5
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Disabling initial Hessian scaling:

>>> # Without Hessian scaling
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=10,
...     scale_init_precond=False
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))

Fine-tuning example with full-batch training:

>>> import jax.numpy as jnp
>>>
>>> # Setup for fine-tuning
>>> model = brainstate.nn.Sequential(
...     brainstate.nn.Linear(784, 128),
...     brainstate.nn.ReLU(),
...     brainstate.nn.Linear(128, 10)
... )
>>>
>>> # L-BFGS for fine-tuning with full batch
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=10
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Full-batch training step
>>> def train_step(data_x, data_y):
...     def loss_fn():
...         logits = model(data_x)
...         return jnp.mean(
...             braintools.metric.softmax_cross_entropy(logits, data_y)
...         )
...
...     # Compute gradients on full dataset
...     grads = brainstate.transform.grad(loss_fn, model.states(brainstate.ParamState))()
...     optimizer.update(grads)
...     return loss_fn()
>>>
>>> # Use entire dataset (not mini-batch)
>>> x_full = jnp.ones((1000, 784))
>>> y_full = jnp.zeros((1000, 10))
>>> # loss = train_step(x_full, y_full)

Convex optimization example:

>>> # L-BFGS excels at convex problems
>>> model = brainstate.nn.Linear(50, 1)  # Linear regression
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=15
... )
>>> optimizer.register_trainable_weights(model.states(brainstate.ParamState))
>>>
>>> # Typically converges in fewer iterations than first-order methods
>>> for epoch in range(100):
...     # Full-batch gradient computation
...     pass  # training code here

Scientific computing - parameter fitting:

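>>> # Fitting exponential decay model
>>> # Synthetic data for illustration only
>>> time_points = jnp.linspace(0.0, 5.0, 50)
>>> observations = 2.0 * jnp.exp(-1.5 * time_points) + 0.5
>>>
>>> def exponential_model(params, t):
...     return params['A'] * jnp.exp(-params['k'] * t) + params['C']
>>>
>>> def loss_fn(params):
...     predictions = exponential_model(params, time_points)
...     return jnp.mean((predictions - observations) ** 2)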
>>>
>>> # L-BFGS for precise parameter estimation
>>> optimizer = braintools.optim.LBFGS(
...     lr=1.0,
...     memory_size=20,  # Higher accuracy for scientific computing
...     scale_init_precond=True
... )

Hybrid optimization strategy:

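>>> # `dataloader`, `full_dataset`, `compute_batch_gradients`, and
>>> # `compute_full_gradients` are user-supplied placeholders.
>>>
>>> # Stage 1: Adam for exploration (stochastic)
>>> adam_opt = braintools.optim.Adam(lr=0.001)
>>> adam_opt.register_trainable_weights(model.states(brainstate.ParamState))
>>> for epoch in range(50):
...     for batch in dataloader:
...         grads = compute_batch_gradients(batch)
...         adam_opt.update(grads)
>>>
>>> # Stage 2: L-BFGS for refinement (deterministic)
>>> lbfgs_opt = braintools.optim.LBFGS(lr=1.0, memory_size=20)
>>> lbfgs_opt.register_trainable_weights(model.states(brainstate.ParamState))
>>> for epoch in range(20):
...     grads = compute_full_gradients(full_dataset)
...     lbfgs_opt.update(grads)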

Memory size comparison:

>>> # Small memory (fast, less accurate)
>>> opt_small = braintools.optim.LBFGS(memory_size=3)
>>>
>>> # Medium memory (balanced)
>>> opt_medium = braintools.optim.LBFGS(memory_size=10)
>>>
>>> # Large memory (slower, more accurate)
>>> opt_large = braintools.optim.LBFGS(memory_size=30)

See also

SGD

First-order stochastic gradient descent

Adam

Adaptive moment estimation for stochastic optimization

Rprop

Resilient propagation for batch learning

Adagrad

Adaptive gradient algorithm

default_tx()

Create default gradient transformation with clipping and weight decay.

update(grads, value=None, value_fn=None, **kwargs)

Update parameters with LBFGS optimizer.

Parameters:
  • grads (dict) – Dictionary of gradients for each parameter.

  • value (float, optional) – Current value of the objective function. Required for linesearch.

  • value_fn (callable, optional) – Function to compute objective value. Required for linesearch.

  • **kwargs – Additional arguments passed to the optimizer update.

Notes

LBFGS requires additional arguments for the line search:

  • value: current objective function value

  • grad: gradients (automatically passed)

  • value_fn: callable to evaluate objective function

For best performance, use with optax.value_and_grad_from_state:

>>> import optax
>>> value_and_grad = optax.value_and_grad_from_state(loss_fn)
>>> value, grad = value_and_grad(params, state=opt_state)
>>> # Low-level optax call on the wrapped transformation
>>> updates, opt_state = optimizer.tx.update(
...     grad, opt_state, params,
...     value=value, grad=grad, value_fn=loss_fn
... )