multiclass_hinge_loss

class braintools.metric.multiclass_hinge_loss(scores, labels)

Compute multiclass hinge loss for classification.

The multiclass hinge loss extends the binary hinge loss to multiple classes. It penalizes a sample unless the score of the correct class exceeds the highest score among the incorrect classes by a margin of at least 1.

The loss is defined as:

\[L = \max(0, \max_{j \neq y} s_j - s_y + 1)\]

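where \(s_y\) is the score of the correct class \(y\) and \(s_j\), \(j \neq y\), are the scores of the other classes. The loss is zero once \(s_y\) leads every other score by at least 1, and otherwise grows linearly with the size of the margin violation.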
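As a worked instance of the formula, the sketch below evaluates it for one sample at a time in plain Python. The helper name `hinge_one_sample` is hypothetical and not part of braintools; it only mirrors the definition above and makes no claim about how the library computes the loss.

```python
def hinge_one_sample(scores, label):
    """Evaluate max(0, max_{j != y} s_j - s_y + 1) for a single sample.

    Hypothetical helper for illustration only; not a braintools function.
    """
    # Highest score among the incorrect classes (j != y).
    best_wrong = max(s for j, s in enumerate(scores) if j != label)
    # Zero loss once the correct score leads by a margin of at least 1.
    return max(0.0, best_wrong - scores[label] + 1.0)


# Correct class leads by exactly 1.0, so the margin is met and the loss is 0.
print(hinge_one_sample([1.0, 2.0, 0.5], 1))
# Correct class leads by only 0.4, so the loss is about 0.6.
print(hinge_one_sample([0.8, 0.3, 1.2], 2))
```

Note that the second sample is classified correctly yet still incurs a loss, because its lead over the runner-up is smaller than the required margin.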

Parameters:
  • scores (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Model output scores with shape [..., num_classes]. These are raw scores (not probabilities) for each class.

  • labels (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Ground-truth integer class labels with shape [...]. Each label should be in the range [0, num_classes).

Returns:

Per-sample hinge loss values with shape [...], matching the leading dimensions of scores (the class axis is reduced, the leading axes are not).
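The shape contract can be illustrated with a vectorized stand-in. This is a sketch written with NumPy rather than JAX, and `hinge_reference` is a hypothetical name, not the braintools implementation; it assumes only the formula and the shapes documented above.

```python
import numpy as np


def hinge_reference(scores, labels):
    """NumPy stand-in for the documented formula (illustration only)."""
    scores = np.asarray(scores, dtype=float)   # shape [..., num_classes]
    labels = np.asarray(labels)                # shape [...], integer labels
    # Score of the correct class for every sample, shape [...].
    correct = np.take_along_axis(scores, labels[..., None], axis=-1)[..., 0]
    # Exclude the correct class, then take the max over the class axis.
    is_correct = np.arange(scores.shape[-1]) == labels[..., None]
    best_wrong = np.where(is_correct, -np.inf, scores).max(axis=-1)
    return np.maximum(0.0, best_wrong - correct + 1.0)


scores = np.zeros((4, 5, 3))            # leading dims (4, 5), 3 classes
labels = np.zeros((4, 5), dtype=int)    # one label per leading index
print(hinge_reference(scores, labels).shape)  # (4, 5)
```

Because the result is not reduced over samples, a scalar training objective would still need an explicit reduction such as a mean over the leading axes.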

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Scores for 3 classes, 2 samples
>>> scores = jnp.array([[1.0, 2.0, 0.5], [0.8, 0.3, 1.2]])
>>> labels = jnp.array([1, 2])  # Correct classes
>>> loss = braintools.metric.multiclass_hinge_loss(scores, labels)
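>>> # Sample 0: max(0, 1.0 - 2.0 + 1) = 0; sample 1: max(0, 0.8 - 1.2 + 1) = 0.6
>>> print(jnp.round(loss, 3))  # rounded to hide float32 rounding noise
[0.  0.6]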
