NevergradOptimizer
- class braintools.optim.NevergradOptimizer(batched_loss_fun, bounds, n_sample, method='DE', use_nevergrad_recommendation=False, budget=None, num_workers=1, method_params=None)
Ask/tell optimizer wrapper around Nevergrad with batched evaluation support.
This optimizer draws n_sample candidate parameter sets per iteration (via ask()), evaluates them in a single batch with a user-provided loss function, and reports the losses back to Nevergrad (via tell()). It then returns the current best parameters, taken either as those with the lowest observed loss or as Nevergrad's recommendation.
- Parameters:
  - batched_loss_fun (Callable) – Callable that evaluates a batch of candidate parameters and returns one scalar loss per candidate. Its signature depends on bounds:
    - If bounds is a sequence/tuple, the callable is invoked as batched_loss_fun(*params), where each element of params is a JAX array stacked over the candidate dimension, i.e. of shape (n_sample, ...) per argument.
    - If bounds is a dict, the callable is invoked as batched_loss_fun(**params), where each value is a stacked JAX array of shape (n_sample, ...).
    The return value must be a 1D array-like of length n_sample containing the loss of each candidate.
  - bounds (Sequence | Dict | None) – Search-space bounds. Each bound is a pair (min, max). Values can be scalars or arrays (no broadcasting is applied), optionally wrapped as brainunit.Quantity to specify units. All leaves within a pair must have identical shapes. Two forms are supported:
    - dict: {"name": (min, max), ...} produces named parameters;
    - sequence/tuple: [(min, max), ...] produces positional parameters, passed to batched_loss_fun in the given order.
  - n_sample (int) – Number of candidates to evaluate per iteration.
  - method (str) – Nevergrad optimizer name, e.g. 'DE', 'TwoPointsDE', 'CMA', 'PSO', 'OnePlusOne', or any other valid key of nevergrad.optimizers.registry.
  - use_nevergrad_recommendation (bool) – If True, return Nevergrad's recommendation (based on its internal sampling history) instead of the parameters with the lowest loss observed so far. When evaluations are noisy and candidate losses are very close, the lowest observed loss may be a lucky draw, so the recommendation can sometimes be preferable.
  - budget (int | None) – Maximum number of evaluations given to Nevergrad. None lets the optimizer run without an explicit budget limit.
  - num_workers (int) – Degree of parallelism hinted to Nevergrad.
  - method_params (Dict | None) – Extra keyword arguments forwarded to the Nevergrad optimizer constructor.
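The use_nevergrad_recommendation option matters mainly when losses are noisy. The short NumPy sketch below does not reproduce Nevergrad's recommendation logic. It only illustrates, with made-up candidate losses and an assumed Gaussian noise level, why the single lowest observed loss tends to be an optimistic estimate when the true losses of the candidates are close together.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sample, noise_std = 64, 0.5  # assumed batch size and noise level, for illustration

# Hypothetical candidates whose true losses lie within 0.01 of each other.
true_loss = 1.0 + 0.01 * rng.uniform(size=n_sample)

# One noisy evaluation per candidate, as a single batched call would return.
observed = true_loss + noise_std * rng.normal(size=n_sample)

# "Lowest observed loss" selection.
i = int(np.argmin(observed))
print(f"picked candidate {i}: observed {observed[i]:.3f}, true {true_loss[i]:.3f}")
# The observed minimum sits below the picked candidate's true loss (selection bias),
# and with losses this close the pick is usually not the truly best candidate.
print("truly best candidate:", int(np.argmin(true_loss)))
```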
- errors
  Aggregated losses corresponding to candidates.
  - Type: numpy.ndarray
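The ask/tell cycle described at the top of this page can be sketched without Nevergrad. In the minimal NumPy version below, the class RandomSearchAskTell and the function minimize_batched are hypothetical stand-ins: uniform random proposals replace the chosen Nevergrad method, only tuple-style bounds are handled, and the best parameters are picked by lowest observed loss. The returned errors array is this sketch's own per-iteration record and is not meant to mirror the layout of the errors attribute.

```python
import numpy as np


class RandomSearchAskTell:
    """Hypothetical stand-in for a Nevergrad optimizer: proposes uniform candidates."""

    def __init__(self, lower, upper, seed=0):
        self.lower = np.asarray(lower, dtype=float)
        self.upper = np.asarray(upper, dtype=float)
        self.rng = np.random.default_rng(seed)
        self.history = []  # (candidate, loss) pairs received via tell()

    def ask(self):
        # One candidate: a vector with one entry per positional parameter.
        return self.rng.uniform(self.lower, self.upper)

    def tell(self, candidate, loss):
        self.history.append((candidate, float(loss)))


def minimize_batched(batched_loss_fun, lower, upper, n_sample, n_iter, seed=0):
    opt = RandomSearchAskTell(lower, upper, seed)
    best_x, best_loss = None, np.inf
    errors = []
    for _ in range(n_iter):
        # ask: draw n_sample candidates and stack them, shape (n_sample, n_params).
        candidates = [opt.ask() for _ in range(n_sample)]
        batch = np.stack(candidates)
        # Evaluate the whole batch in one call; each column becomes one positional
        # argument of shape (n_sample,), mirroring batched_loss_fun(*params).
        losses = np.asarray(batched_loss_fun(*batch.T))
        if losses.shape != (n_sample,):
            raise ValueError(f"expected {n_sample} losses, got shape {losses.shape}")
        # tell: report every candidate's loss back to the optimizer.
        for cand, loss in zip(candidates, losses):
            opt.tell(cand, loss)
        errors.append(losses)
        # Keep the parameters with the lowest loss observed so far.
        j = int(np.argmin(losses))
        if losses[j] < best_loss:
            best_x, best_loss = batch[j], float(losses[j])
    return best_x, best_loss, np.stack(errors)


best_x, best_loss, errors = minimize_batched(
    lambda x, y: x**2 + y**2, lower=[-5.0, -3.0], upper=[5.0, 3.0], n_sample=8, n_iter=20
)
print(best_x.shape, errors.shape)  # (2,) (20, 8)
```

Stacking all candidates and evaluating them in one call is the point of the batched design: a vectorized loss pays its overhead once per iteration rather than once per candidate.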
Examples
Optimize two scalars with tuple bounds and a simple quadratic loss:
>>> from braintools.optim import NevergradOptimizer
>>> def batched_loss_fun(x, y):
...     # x, y have shape (n_sample,)
...     return x**2 + y**2
>>> bounds = [(-5.0, 5.0), (-3.0, 3.0)]
>>> opt = NevergradOptimizer(batched_loss_fun, bounds, n_sample=8, method='OnePlusOne')
>>> best = opt.minimize(n_iter=5, verbose=False)
>>> len(best) == 2
True
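The example above relies on batched_loss_fun returning one loss per candidate. A small check like the hypothetical check_batched_loss helper below can catch the common mistake of reducing over the candidate axis before the loss ever reaches the optimizer. NumPy stands in for JAX here so the sketch runs on its own, and the random inputs are illustrative.

```python
import numpy as np


def check_batched_loss(batched_loss_fun, params, n_sample):
    """Hypothetical helper: call the loss as the docs describe and check its shape."""
    # Dict params are passed as keywords, sequences as positional arguments.
    if isinstance(params, dict):
        out = batched_loss_fun(**params)
    else:
        out = batched_loss_fun(*params)
    out = np.asarray(out)
    if out.shape != (n_sample,):
        raise ValueError(f"expected loss shape ({n_sample},), got {out.shape}")
    return out


def good_loss(x, y):
    # Elementwise: one loss per candidate, shape (n_sample,).
    return x**2 + y**2


def bad_loss(x, y):
    # Mistake: sums over the candidate axis and returns a single scalar.
    return np.sum(x**2 + y**2)


n_sample = 8
rng = np.random.default_rng(0)
x = rng.uniform(-5.0, 5.0, size=n_sample)
y = rng.uniform(-3.0, 3.0, size=n_sample)

print(check_batched_loss(good_loss, [x, y], n_sample).shape)  # (8,)
bad_msg = None
try:
    check_batched_loss(bad_loss, [x, y], n_sample)
except ValueError as err:
    bad_msg = str(err)
print(bad_msg)  # expected loss shape (8,), got ()
```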
Optimize named parameters using dict bounds:
>>> def batched_loss_fun(**p):
...     # p['a'], p['b'] have shape (n_sample,)
...     return p['a']**2 + (p['b'] - 1.0)**2
>>> bounds = {"a": (-5.0, 5.0), "b": (-3.0, 3.0)}
>>> opt = NevergradOptimizer(batched_loss_fun, bounds, n_sample=8, method='DE')
>>> best = opt.minimize(n_iter=3, verbose=False)
>>> set(best.keys()) == {"a", "b"}
True
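In the dict example every parameter is a scalar, so each stacked value has shape (n_sample,). For an array-valued bound the stacked value gains trailing axes, e.g. (n_sample, 3), and the loss has to reduce over those trailing axes only, never over the candidate axis. The sketch below assumes uniform sampling from the bounds (the hypothetical sample_uniform helper) as a stand-in for Nevergrad's proposals and uses NumPy in place of JAX; the names w and b and the target vector are illustrative.

```python
import numpy as np


def sample_uniform(bounds, n_sample, rng):
    """Hypothetical stand-in for Nevergrad's proposals: uniform draws per named bound."""
    params = {}
    for name, (lo, hi) in bounds.items():
        lo, hi = np.asarray(lo, dtype=float), np.asarray(hi, dtype=float)
        if lo.shape != hi.shape:
            raise ValueError(f"bound {name!r}: min and max shapes differ")
        # Prepend the candidate axis: result has shape (n_sample, *lo.shape).
        params[name] = rng.uniform(lo, hi, size=(n_sample,) + lo.shape)
    return params


def batched_loss_fun(**p):
    # p["w"] has shape (n_sample, 3); p["b"] has shape (n_sample,).
    target_w = np.array([1.0, -2.0, 0.5])
    # Reduce over the parameter axis (axis=1) only, keeping one value per candidate.
    w_err = np.sum((p["w"] - target_w) ** 2, axis=1)
    return w_err + (p["b"] - 1.0) ** 2


# Array-valued bound for "w": min and max have identical shapes (3,).
bounds = {"w": (np.full(3, -5.0), np.full(3, 5.0)), "b": (-3.0, 3.0)}
params = sample_uniform(bounds, n_sample=8, rng=np.random.default_rng(0))
losses = batched_loss_fun(**params)
print(params["w"].shape, params["b"].shape, losses.shape)  # (8, 3) (8,) (8,)
```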