NevergradOptimizer

class braintools.optim.NevergradOptimizer(batched_loss_fun, bounds, n_sample, method='DE', use_nevergrad_recommendation=False, budget=None, num_workers=1, method_params=None)

Ask/tell optimizer wrapper around Nevergrad with batched evaluation support.

On each iteration, this optimizer asks Nevergrad for n_sample candidate parameter sets (via ask), evaluates all of them in a single batched call to a user-provided loss function, and reports the resulting losses back to Nevergrad (via tell). It then returns the best parameters found so far: either those with the lowest observed loss, or Nevergrad’s own recommendation (see use_nevergrad_recommendation).
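The per-iteration ask, batched evaluation, and tell cycle described above can be sketched with NumPy alone. This is a minimal illustration, not the braintools implementation: ToyRandomSearch and run_iteration are hypothetical names, the toy optimizer simply samples uniformly inside the box bounds, and it returns plain NumPy arrays rather than Nevergrad's candidate objects.

```python
import numpy as np


class ToyRandomSearch:
    """Hypothetical stand-in for a Nevergrad optimizer: samples uniformly
    inside box bounds and records every (candidate, loss) pair it is told."""

    def __init__(self, lower, upper, seed=0):
        self.lower = np.asarray(lower, dtype=float)
        self.upper = np.asarray(upper, dtype=float)
        self.rng = np.random.default_rng(seed)
        self.history = []  # (candidate, loss) pairs reported via tell()

    def ask(self):
        # Propose one candidate: one value per parameter.
        return self.rng.uniform(self.lower, self.upper)

    def tell(self, candidate, loss):
        self.history.append((candidate, float(loss)))


def run_iteration(optimizer, batched_loss_fun, n_sample):
    # 1) ask: draw n_sample candidates, stacked to shape (n_sample, n_params)
    candidates = np.stack([optimizer.ask() for _ in range(n_sample)])
    # 2) evaluate the whole batch in one call; candidates.T has one row per
    #    parameter, so each positional argument has shape (n_sample,)
    losses = np.asarray(batched_loss_fun(*candidates.T))
    if losses.shape != (n_sample,):
        raise ValueError("batched_loss_fun must return one loss per candidate")
    # 3) tell: report each loss back to the optimizer
    for cand, loss in zip(candidates, losses):
        optimizer.tell(cand, loss)
    # 4) return the best parameters observed so far and their loss
    best_cand, best_loss = min(optimizer.history, key=lambda pair: pair[1])
    return best_cand, best_loss


def quadratic(x, y):
    # x and y each have shape (n_sample,)
    return x**2 + y**2


opt = ToyRandomSearch(lower=[-5.0, -3.0], upper=[5.0, 3.0])
for _ in range(20):
    best, best_loss = run_iteration(opt, quadratic, n_sample=8)
```

Evaluating the whole batch in one call is what lets a vectorized (for example, JAX-jitted) loss amortize its overhead across n_sample candidates.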

Parameters:
  • batched_loss_fun (Callable) –

    Callable evaluating a batch of candidate parameters and returning one scalar loss per candidate. Its signature depends on bounds:

    • If bounds is a sequence/tuple, the callable is invoked as batched_loss_fun(*params) where each element of params is a JAX array stacked over the candidate dimension, e.g., shape (n_sample, ...) per argument.

    • If bounds is a dict, the callable is invoked as batched_loss_fun(**params) where each value is a stacked JAX array of shape (n_sample, ...).

    The return value must be a 1D array-like of length n_sample with the loss per candidate.

  • bounds (Sequence | Dict | None) –

    Search space bounds. Each bound is a pair (min, max). Values can be scalars or arrays (broadcasting not applied), optionally wrapped as brainunit.Quantity to specify units. All leaves within a pair must have identical shapes. Two forms are supported:

    • dict: {"name": (min, max), ...} producing named parameters;

    • sequence/tuple: [(min, max), ...] producing positional parameters passed to batched_loss_fun in the given order.

  • n_sample (int) – Number of candidates to evaluate per iteration.

  • method (str) – Nevergrad optimizer name, e.g. 'DE', 'TwoPointsDE', 'CMA', 'PSO', 'OnePlusOne', or any valid key from nevergrad.optimizers.registry.

  • use_nevergrad_recommendation (bool) – If True, return Nevergrad’s recommendation (based on its internal sampling history) instead of the parameters with the lowest observed loss so far. When losses are noisy and several candidates score almost equally, the single lowest observation may owe its rank to noise, so Nevergrad’s recommendation can be the more reliable choice.

  • budget (int | None) – Maximum number of evaluations given to Nevergrad. None lets the optimizer run without an explicit budget limit.

  • num_workers (int) – Number of parallel workers that Nevergrad should assume when proposing candidates.

  • method_params (Dict | None) – Extra keyword arguments forwarded to the Nevergrad optimizer constructor.
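The two calling conventions for batched_loss_fun, and the shape requirement on each (min, max) pair, can be illustrated with a small NumPy sketch. The helpers check_bounds and stack_and_call are hypothetical and only mirror the documented behavior; they ignore brainunit units and do not reproduce how braintools samples candidates internally.

```python
import numpy as np


def check_bounds(bounds):
    """Hypothetical helper: each (min, max) pair must have identical shapes."""
    pairs = bounds.values() if isinstance(bounds, dict) else bounds
    for lo, hi in pairs:
        if np.shape(lo) != np.shape(hi):
            raise ValueError(f"shape mismatch: {np.shape(lo)} vs {np.shape(hi)}")


def stack_and_call(batched_loss_fun, bounds, samples):
    """Hypothetical helper: stack per-candidate samples along a new leading
    axis of length n_sample, then call the loss positionally (sequence
    bounds) or by keyword (dict bounds)."""
    check_bounds(bounds)
    if isinstance(bounds, dict):
        # samples: one dict per candidate
        params = {k: np.stack([s[k] for s in samples]) for k in bounds}
        return batched_loss_fun(**params)
    # samples: one tuple per candidate, in the same order as bounds
    params = [np.stack([s[i] for s in samples]) for i in range(len(bounds))]
    return batched_loss_fun(*params)


rng = np.random.default_rng(0)
n_sample = 4

# Sequence bounds: a scalar x and a length-3 vector w.
seq_bounds = [(-1.0, 1.0), (np.zeros(3), np.ones(3))]
seq_samples = [(rng.uniform(-1, 1), rng.uniform(0, 1, size=3)) for _ in range(n_sample)]


def seq_loss(x, w):
    # x: (n_sample,), w: (n_sample, 3) -> one loss per candidate
    return x**2 + (w**2).sum(axis=1)


losses = stack_and_call(seq_loss, seq_bounds, seq_samples)

# Dict bounds: named scalars a and b.
dict_bounds = {"a": (-5.0, 5.0), "b": (-3.0, 3.0)}
dict_samples = [{"a": rng.uniform(-5, 5), "b": rng.uniform(-3, 3)} for _ in range(n_sample)]


def dict_loss(a, b):
    # a, b: (n_sample,)
    return a**2 + (b - 1.0)**2


losses2 = stack_and_call(dict_loss, dict_bounds, dict_samples)
```

In both cases the loss receives arrays whose leading axis is the candidate dimension and must return a 1D array of length n_sample.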

candidates

History of all parameter sets evaluated (one entry per candidate).

Type:

list

errors

Aggregated losses corresponding to candidates.

Type:

numpy.ndarray
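When use_nevergrad_recommendation is False, the returned parameters are those with the lowest observed loss. Assuming errors[i] is the loss of candidates[i] (an assumption about how the two attributes line up), that selection reduces to an argmin over the history. The candidate values below are made up for illustration, and best_observed is a hypothetical helper.

```python
import numpy as np

# Illustrative history: one dict per evaluated candidate, with losses from the
# same objective as the dict-bounds example, a**2 + (b - 1)**2.
candidates = [{"a": 2.0, "b": 0.5}, {"a": -0.1, "b": 1.2}, {"a": 0.3, "b": 0.9}]
errors = np.array([c["a"]**2 + (c["b"] - 1.0)**2 for c in candidates])


def best_observed(candidates, errors):
    """Hypothetical helper: return the candidate with the lowest recorded
    loss together with that loss."""
    idx = int(np.argmin(errors))
    return candidates[idx], float(errors[idx])


best, loss = best_observed(candidates, errors)
```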

Examples

Optimize two scalars with tuple bounds and a simple quadratic loss:

>>> import jax.numpy as jnp
>>> from braintools.optim import NevergradOptimizer
>>> def batched_loss_fun(x, y):
...     # x, y have shape (n_sample,)
...     return x**2 + y**2
>>> bounds = [(-5.0, 5.0), (-3.0, 3.0)]
>>> opt = NevergradOptimizer(batched_loss_fun, bounds, n_sample=8, method='OnePlusOne')
>>> best = opt.minimize(n_iter=5, verbose=False)
>>> len(best) == 2
True

Optimize named parameters using dict bounds:

>>> def batched_loss_fun(**p):
...     # p['a'], p['b'] have shape (n_sample,)
...     return p['a']**2 + (p['b']-1.0)**2
>>> bounds = {"a": (-5.0, 5.0), "b": (-3.0, 3.0)}
>>> opt = NevergradOptimizer(batched_loss_fun, bounds, n_sample=8, method='DE')
>>> best = opt.minimize(n_iter=3, verbose=False)
>>> set(best.keys()) == {"a", "b"}
True