brainmass.objectives.combine
- brainmass.objectives.combine(*weighted_objectives)
Combine weighted objective callables into a single objective.
- Parameters:
*weighted_objectives (tuple of (float, callable)) – Pairs of (weight, objective), where each objective is a callable returned by the builders in this module (signature objective(prediction, target)).
- Returns:
loss(prediction, target=None) -> scalar, equal to sum(weight * objective(prediction, target)).
- Return type:
callable
Examples
>>> import jax.numpy as jnp
>>> from brainmass import objectives
>>> loss = objectives.combine(
...     (2.0, objectives.timeseries_rmse()),
...     (0.5, objectives.timeseries_rmse()),
... )
>>> x = jnp.zeros((10, 3))
>>> float(loss(x + 1.0, x))  # (2.0 + 0.5) * 1.0
2.5
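The weighted-sum rule described under Returns can be illustrated with a small standalone sketch. It is not the brainmass implementation. It assumes only that each objective is a plain callable objective(prediction, target) that returns a scalar. The names combine_sketch and rmse are hypothetical stand-ins, and NumPy replaces JAX so the sketch runs without brainmass installed.

```python
import numpy as np
from typing import Callable, Optional

# Hypothetical type alias: an objective maps (prediction, target) to a scalar.
Objective = Callable[[np.ndarray, Optional[np.ndarray]], float]


def rmse() -> Objective:
    """Hypothetical stand-in for a builder such as objectives.timeseries_rmse()."""
    def objective(prediction, target):
        # Root-mean-square error over every element of the arrays.
        diff = np.asarray(prediction) - np.asarray(target)
        return float(np.sqrt(np.mean(diff ** 2)))
    return objective


def combine_sketch(*weighted_objectives):
    """Hypothetical weighted-sum combiner (a sketch, not the brainmass API)."""
    pairs = []
    for weight, objective in weighted_objectives:
        if not callable(objective):
            raise TypeError(f"objective must be callable, got {type(objective)!r}")
        # Convert weights eagerly so a bad weight fails at construction time.
        pairs.append((float(weight), objective))

    def loss(prediction, target=None):
        # sum(weight * objective(prediction, target)) over all pairs.
        return sum(w * obj(prediction, target) for w, obj in pairs)

    return loss


x = np.zeros((10, 3))
loss = combine_sketch((2.0, rmse()), (0.5, rmse()))
print(loss(x + 1.0, x))  # 2.5, since each RMSE term equals 1.0
```

Converting each weight with float() up front is a choice made for this sketch: a malformed weight or a non-callable objective is reported when the loss is built rather than on its first evaluation.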