multiclass_perceptron_loss


class braintools.metric.multiclass_perceptron_loss(scores, labels)

Compute multiclass perceptron loss for classification.

The multiclass perceptron loss measures the gap between the highest class score and the score of the correct class. It is zero when the correct class attains the highest score and positive otherwise. It is used in structured perceptron learning for multiclass classification.

The loss is defined as:

\[L = \max_j s_j - s_y\]

where \(s_y\) is the score for the correct class \(y\) and \(\max_j s_j\) is the maximum score across all classes.
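As a minimal sketch of the formula above (not the library's implementation), the loss can be computed by taking the maximum over the class axis and subtracting the gathered score of the correct class. The helper name `perceptron_loss_reference` is hypothetical, and NumPy is used here as an assumption for portability; the same calls exist in `jax.numpy`.

```python
import numpy as np


def perceptron_loss_reference(scores, labels):
    """Hypothetical reference helper: max_j s_j - s_y along the last axis."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    # Highest score over the class axis.
    max_score = scores.max(axis=-1)
    # Score of the correct class, gathered per sample.
    true_score = np.take_along_axis(scores, labels[..., None], axis=-1)[..., 0]
    return max_score - true_score


scores = np.array([[1.0, 2.0, 0.5], [0.8, 0.3, 1.2]])

# Both labels point at the top-scoring class, so the loss is zero.
correct = perceptron_loss_reference(scores, np.array([1, 2]))

# The first label now points at class 0 (score 1.0) while class 1 scores 2.0,
# so the first sample incurs a loss of 2.0 - 1.0 = 1.0.
wrong = perceptron_loss_reference(scores, np.array([0, 2]))
```

The second call illustrates that the loss becomes positive only for samples whose correct class is not the top-scoring one.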

Parameters:
  • scores (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Model output scores with shape [..., num_classes]. These are raw scores (not probabilities) for each class.

  • labels (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Ground-truth integer class labels with shape [...]. Each label should be in the range [0, num_classes).

Returns:

Perceptron loss values with shape [...], same as the leading dimensions of scores.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
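The shape contract described above can be illustrated with a small sketch that applies the definition directly rather than calling the library. The leading dimensions `(2, 4)` and the random inputs are arbitrary assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 4, 3))       # [..., num_classes] with num_classes = 3
labels = rng.integers(0, 3, size=(2, 4))  # integer labels in [0, num_classes)

# Definition: max over the class axis minus the score of the correct class.
true_score = np.take_along_axis(scores, labels[..., None], axis=-1).squeeze(-1)
loss = scores.max(axis=-1) - true_score

print(loss.shape)  # (2, 4)
```

The class axis is reduced away, so the result keeps only the leading dimensions of `scores`, and every entry is non-negative because the maximum is never smaller than any single class score.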

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Scores for 3 classes, 2 samples
>>> scores = jnp.array([[1.0, 2.0, 0.5], [0.8, 0.3, 1.2]])
>>> labels = jnp.array([1, 2])  # Correct classes
>>> loss = braintools.metric.multiclass_perceptron_loss(scores, labels)
>>> print(loss)
[0. 0.]

References