brainmass.objectives.combine

brainmass.objectives.combine(*weighted_objectives)

Combine weighted objective callables into a single objective.

Parameters:

*weighted_objectives (tuple of (float, callable)) – (weight, objective) pairs, where each objective is a callable returned by one of the builders in this module, with signature objective(prediction, target).

Returns:

A callable loss(prediction, target=None) that returns a scalar equal to sum(weight * objective(prediction, target)), taken over all pairs.

Return type:

callable
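The weighted-sum semantics described under Returns can be illustrated with a minimal pure-Python sketch. This is not the brainmass source: combine_sketch and rmse are hypothetical stand-ins, and the sketch assumes each objective returns a plain scalar. It only shows how a closure over the (weight, objective) pairs yields one loss callable.

```python
import math
from typing import Any, Callable, Tuple

# Hypothetical objective type: (prediction, target) -> scalar loss.
Objective = Callable[[Any, Any], float]


def combine_sketch(*weighted_objectives: Tuple[float, Objective]) -> Callable[..., float]:
    """Hypothetical re-implementation of the documented weighted sum (not brainmass code)."""
    pairs = list(weighted_objectives)  # captured once, when the loss is built

    def loss(prediction, target=None) -> float:
        # sum(weight * objective(prediction, target)) over all pairs
        return sum(w * obj(prediction, target) for w, obj in pairs)

    return loss


def rmse(prediction, target) -> float:
    """Hypothetical stand-in for an RMSE objective on flat sequences."""
    sq = [(p - t) ** 2 for p, t in zip(prediction, target)]
    return math.sqrt(sum(sq) / len(sq))


loss = combine_sketch((2.0, rmse), (0.5, rmse))
print(loss([1.0, 1.0, 1.0], [0.0, 0.0, 0.0]))  # 2.5
```

As in the doctest example, each RMSE term is 1.0, so the combined value is 2.0 * 1.0 + 0.5 * 1.0 = 2.5.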

Examples

>>> import jax.numpy as jnp
>>> from brainmass import objectives
>>> loss = objectives.combine(
...     (2.0, objectives.timeseries_rmse()),
...     (0.5, objectives.timeseries_rmse()),
... )
>>> x = jnp.zeros((10, 3))
>>> float(loss(x + 1.0, x))   # (2.0 + 0.5) * 1.0
2.5