make_fenchel_young_loss
- class braintools.metric.make_fenchel_young_loss(max_fun)
Create a Fenchel-Young loss function from a max function.
Fenchel-Young losses provide a framework for building differentiable loss functions from convex regularizers. They are particularly useful for structured prediction tasks, where they give a principled way to construct losses that encourage sparsity or other specific structure in predictions.
The Fenchel-Young loss is defined as:
\[\ell_{FY}(y, \theta) = \Omega(\theta) - \langle y, \theta \rangle\]
where \(\Omega\) is a convex regularizer (the max function), \(\theta\) are the scores, and \(y\) are the targets.
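The definition above can be checked numerically. The following is a minimal sketch, not the braintools implementation: fy_loss_1d is a hypothetical helper that evaluates the formula directly for a single 1-D score vector. With logsumexp as the max function and a one-hot target, the value should coincide with the usual softmax cross-entropy.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def fy_loss_1d(max_fun, scores, targets):
    """Hypothetical helper: max_fun(scores) - <targets, scores> for 1-D inputs."""
    return max_fun(scores) - jnp.vdot(targets, scores)


scores = jnp.array([2.0, 1.0, 0.5])
targets = jnp.array([1.0, 0.0, 0.0])  # one-hot target for class 0

# Fenchel-Young loss with logsumexp as the max function.
fy_value = fy_loss_1d(logsumexp, scores, targets)

# Cross-entropy of softmax(scores) against the same one-hot target.
cross_entropy = -jnp.sum(targets * jax.nn.log_softmax(scores))

print(fy_value, cross_entropy)
```

The two values agree because log_softmax(scores) equals scores minus logsumexp(scores), and the one-hot target sums to one.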
- Parameters:
max_fun (MaxFun) – The max function (convex regularizer) on which the Fenchel-Young loss is built. Common choices include jax.scipy.special.logsumexp for softmax-based losses or custom max functions for structured outputs.
- Returns:
A Fenchel-Young loss function with signature fenchel_young_loss(scores, targets, *args, **kwargs) that computes the loss between scores and targets.
- Return type:
callable
Notes
Warning
The resulting loss operates over the last dimension of the input arrays and accepts arbitrary leading dimensions. This differs from some other implementations that flatten inputs into 1D vectors.
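To illustrate what operating over the last dimension means, here is a sketch of one way such a loss could be vectorized with jnp.vectorize. The factory make_last_axis_loss is a hypothetical stand-in written for this example; it is an assumption about the behaviour described above, not the library's actual code.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def make_last_axis_loss(max_fun):
    """Hypothetical sketch of a last-axis Fenchel-Young loss factory."""
    # Lift 1-D -> scalar functions so they act on the last axis only
    # and broadcast over any leading dimensions.
    max_last = jnp.vectorize(max_fun, signature="(n)->()")
    vdot_last = jnp.vectorize(jnp.vdot, signature="(n),(n)->()")

    def loss(scores, targets):
        return max_last(scores) - vdot_last(targets, scores)

    return loss


loss_fn = make_last_axis_loss(logsumexp)

# Two leading dimensions (4, 2) and 3 classes on the last axis.
scores = jnp.zeros((4, 2, 3))
targets = jnp.zeros((4, 2, 3)).at[..., 0].set(1.0)  # one-hot on class 0

per_example = loss_fn(scores, targets)
print(per_example.shape)  # one loss value per leading index
```

Under this sketch the output keeps the leading dimensions and drops only the last one, so any reduction (mean, sum) over examples is left to the caller.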
The choice of max function determines the properties of the resulting loss:
- logsumexp: Creates a softmax-based cross-entropy loss
- max: Creates a max-margin loss
- Custom functions: Can create structured losses for specific applications
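The effect of the max function can be seen by evaluating the formula with both choices on the same scores. This is a small sketch using a hypothetical helper, fy_loss_1d, that applies the definition to 1-D inputs; it does not call braintools itself.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def fy_loss_1d(max_fun, scores, targets):
    """Hypothetical helper: max_fun(scores) - <targets, scores> for 1-D inputs."""
    return max_fun(scores) - jnp.vdot(targets, scores)


scores = jnp.array([2.0, 1.0, 0.5])
correct = jnp.array([1.0, 0.0, 0.0])  # target is the top-scoring class
wrong = jnp.array([0.0, 1.0, 0.0])    # target is a lower-scoring class

# With the hard max, the loss is the gap between the best score
# and the score of the target class.
hard_correct = fy_loss_1d(jnp.max, scores, correct)  # 2.0 - 2.0
hard_wrong = fy_loss_1d(jnp.max, scores, wrong)      # 2.0 - 1.0

# With logsumexp, the loss stays positive even when the target wins.
soft_correct = fy_loss_1d(logsumexp, scores, correct)

print(hard_correct, hard_wrong, soft_correct)
```

With the hard max the loss vanishes as soon as the target class has the highest score, whereas the logsumexp variant keeps pushing the target score further above the others.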
Examples
Create a softmax-based Fenchel-Young loss:
>>> import jax.numpy as jnp
>>> from jax.scipy.special import logsumexp
>>> import braintools
>>> # Create the loss function
>>> fy_loss = braintools.metric.make_fenchel_young_loss(max_fun=logsumexp)
>>> # Example usage
>>> scores = jnp.array([[2.0, 1.0, 0.5], [1.5, 2.5, 1.0]])
>>> targets = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> loss = fy_loss(scores, targets)
>>> print(f"Fenchel-Young loss: {loss}")
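Because the loss is differentiable, its gradient with respect to the scores can be obtained with jax.grad. The sketch below uses a hypothetical local function, softmax_fy_loss, that evaluates the formula directly for one example rather than calling braintools; for the logsumexp case the gradient should equal softmax(scores) minus the targets.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def softmax_fy_loss(scores, targets):
    """Hypothetical helper: logsumexp(scores) - <targets, scores> for 1-D inputs."""
    return logsumexp(scores) - jnp.vdot(targets, scores)


scores = jnp.array([2.0, 1.0, 0.5])
targets = jnp.array([1.0, 0.0, 0.0])

# Gradient with respect to the first argument (scores).
grad_scores = jax.grad(softmax_fy_loss)(scores, targets)

# Closed form for comparison: softmax(scores) - targets.
expected = jax.nn.softmax(scores) - targets

print(grad_scores, expected)
```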
Create a custom max function for structured prediction:
>>> def custom_max(x):
...     return jnp.max(x) + 0.1 * jnp.sum(x**2)  # L2 regularized max
>>> structured_loss = braintools.metric.make_fenchel_young_loss(max_fun=custom_max)
See also
jax.scipy.special.logsumexp: Common choice for softmax-based losses
braintools.metric.sigmoid_binary_cross_entropy: Alternative binary loss
References