make_fenchel_young_loss


class braintools.metric.make_fenchel_young_loss(max_fun)

Create a Fenchel-Young loss function from a max function.

Fenchel-Young losses are a framework for building differentiable loss functions from convex regularizers. They are particularly useful for structured prediction tasks, where they give a principled way to construct losses that encourage sparsity or other specific structure in the predictions.

The Fenchel-Young loss is defined as:

\[\ell_{FY}(y, \theta) = \Omega(\theta) - \langle y, \theta \rangle\]

where \(\Omega\) is a convex regularizer (the max function), \(\theta\) are the scores, and \(y\) are the targets.
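As a rough illustration of this definition, the following sketch evaluates the formula directly for a single score vector. The helper name fy_loss_by_hand is hypothetical and is not part of braintools; it assumes 1D inputs and uses logsumexp as the max function.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def fy_loss_by_hand(scores, targets, max_fun=logsumexp):
    # Hypothetical helper: Omega(theta) - <y, theta> for one 1D score vector.
    return max_fun(scores) - jnp.vdot(targets, scores)


scores = jnp.array([2.0, 1.0, 0.5])   # theta
targets = jnp.array([1.0, 0.0, 0.0])  # y, one-hot on the first class

loss = fy_loss_by_hand(scores, targets)
print(loss)
```

The loss shrinks as the score of the target class grows relative to the other scores.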

Parameters:

max_fun (MaxFun) – The max function (convex regularizer) on which the Fenchel-Young loss is built. Common choices include jax.scipy.special.logsumexp for softmax-based losses or custom max functions for structured outputs.

Returns:

A Fenchel-Young loss function with signature fenchel_young_loss(scores, targets, *args, **kwargs) that computes the loss between scores and targets.

Return type:

callable

Notes

Warning

The resulting loss operates over the last dimension of the input arrays and accepts arbitrary leading dimensions. This differs from some other implementations that flatten inputs into 1D vectors.
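The shape behavior described above can be illustrated with a stand-in written directly in JAX. This is a sketch under the assumption that the loss reduces only over the last axis; batched_fy_loss is a hypothetical helper, not the braintools implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def batched_fy_loss(scores, targets):
    # Hypothetical stand-in: reduce over the last axis only,
    # so every leading dimension is preserved in the output.
    return logsumexp(scores, axis=-1) - jnp.sum(targets * scores, axis=-1)


# Two leading dimensions (4, 2) and three classes on the last axis.
scores = jnp.zeros((4, 2, 3))
targets = jnp.zeros((4, 2, 3)).at[..., 0].set(1.0)

losses = batched_fy_loss(scores, targets)
print(losses.shape)  # (4, 2): one loss value per leading index
```

Reducing the result to a scalar, for example with jnp.mean, is left to the caller.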

The choice of max function determines the properties of the resulting loss:

  • logsumexp: Creates a softmax-based cross-entropy loss

  • max: Creates a max-margin loss

  • Custom functions: Can create structured losses for specific applications
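To make the difference between these choices concrete, the sketch below evaluates the formula from the definition with two max functions on the same scores. fy_loss is a hypothetical helper for 1D inputs rather than the library function.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def fy_loss(max_fun, scores, targets):
    # Hypothetical helper: max_fun(theta) - <y, theta> for 1D inputs.
    return max_fun(scores) - jnp.vdot(targets, scores)


scores = jnp.array([2.0, 1.0, 0.5])
correct = jnp.array([1.0, 0.0, 0.0])  # one-hot on the top-scoring class
wrong = jnp.array([0.0, 1.0, 0.0])    # one-hot on a lower-scoring class

soft_correct = fy_loss(logsumexp, scores, correct)
hard_correct = fy_loss(jnp.max, scores, correct)  # 2.0 - 2.0
hard_wrong = fy_loss(jnp.max, scores, wrong)      # 2.0 - 1.0

# Softmax cross-entropy for the same one-hot target, for comparison.
cross_entropy = -jax.nn.log_softmax(scores)[0]
```

With jnp.max the loss is zero whenever the target is the top-scoring class, while the logsumexp version stays positive and matches the softmax cross-entropy for a one-hot target.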

Examples

Create a softmax-based Fenchel-Young loss:

>>> import jax.numpy as jnp
>>> from jax.scipy.special import logsumexp
>>> import braintools
>>> # Create the loss function
>>> fy_loss = braintools.metric.make_fenchel_young_loss(max_fun=logsumexp)
>>> # Example usage
>>> scores = jnp.array([[2.0, 1.0, 0.5], [1.5, 2.5, 1.0]])
>>> targets = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> loss = fy_loss(scores, targets)
>>> print(f"Fenchel-Young loss: {loss}")

Create a custom max function for structured prediction:

>>> def custom_max(x):
...     return jnp.max(x) + 0.1 * jnp.sum(x**2)  # L2 regularized max
>>> structured_loss = braintools.metric.make_fenchel_young_loss(max_fun=custom_max)

See also

jax.scipy.special.logsumexp

Common choice for softmax-based losses

braintools.metric.sigmoid_binary_cross_entropy

Alternative binary loss
