softmax_cross_entropy_with_integer_labels


class braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)

Compute softmax cross entropy between logits and integer labels.

This is a more efficient variant of softmax cross entropy for labels given as integer class indices rather than one-hot encoded vectors. It measures the probability error in discrete classification tasks in which the classes are mutually exclusive, that is, each input belongs to exactly one class.
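The relationship to the one-hot form can be illustrated with a small pure-Python sketch. It is not the library's implementation; the helper names `log_softmax`, `cross_entropy_one_hot`, and `cross_entropy_integer` are hypothetical and the input values are made up for illustration. With a one-hot target, every term of the cross-entropy sum except the one at the true class is zero, so the integer-label form can skip building the one-hot vector and look up a single log-probability.

```python
import math

def log_softmax(logits):
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]

def cross_entropy_one_hot(logits, one_hot):
    # General form: -sum_k target_k * log_softmax(logits)_k
    return -sum(t * lp for t, lp in zip(one_hot, log_softmax(logits)))

def cross_entropy_integer(logits, label):
    # With a one-hot target only the term at the true class survives,
    # so the sum collapses to a single lookup.
    return -log_softmax(logits)[label]

logits = [1.5, -0.3, 0.8]   # illustrative values, three classes
label = 2                   # index of the true class
one_hot = [0.0, 0.0, 1.0]   # the same label, one-hot encoded

loss_one_hot = cross_entropy_one_hot(logits, one_hot)
loss_integer = cross_entropy_integer(logits, label)
```

Both results agree; the integer-label version simply avoids the multiplication by a mostly-zero target vector.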

Parameters:
  • logits (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Unnormalized log probabilities with shape [..., num_classes], where the last axis indexes the classes. Must be of floating-point type.

  • labels (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Integer class indices giving the correct class for each input, with shape [...] (the shape of logits without its last axis). Values should lie in the range [0, num_classes). Must be of integer type.
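The shape contract described above can be sketched with a NumPy reference. This is an assumption-laden illustration, not braintools code: `reference_loss` is a hypothetical helper, and the random inputs exist only to show how the leading axes of logits and labels line up.

```python
import numpy as np

def reference_loss(logits, labels):
    # Hypothetical NumPy reference, not the braintools implementation.
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Log-softmax over the last (class) axis, shifted by the max for stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class for every leading index.
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked[..., 0]

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 4))        # [..., num_classes] with num_classes = 4
labels = rng.integers(0, 4, size=(2, 5))   # [...] with values in [0, 4)

loss = reference_loss(logits, labels)      # one loss value per leading index
```

Here the class axis is consumed by the lookup, so the result has the same shape as labels.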

Returns:

Cross entropy between each prediction and its corresponding target label, with shape [...].

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([[2.0, 1.0, 0.1]])
>>> labels = jnp.array([0])  # Class 0
>>> loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
>>> print(loss)
[0.4170299]
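The printed value can be checked by hand, since for a single example the loss reduces to the log of the softmax normalizer minus the logit of the true class. The following is a standalone check in plain Python and does not call braintools.

```python
import math

logits = [2.0, 1.0, 0.1]
label = 0

# loss = log(sum_k exp(logits_k)) - logits[label]
log_normalizer = math.log(sum(math.exp(z) for z in logits))
loss = log_normalizer - logits[label]
```

The result agrees with the value printed above up to floating-point precision.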
