softmax_cross_entropy


class braintools.metric.softmax_cross_entropy(logits, labels)

Compute the softmax cross entropy between logits and labels.

Measures the probability error in discrete classification tasks where the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.
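For reference, the textbook definition of softmax cross entropy for a single example is given below. Here z is the logits vector over C classes and y is the target distribution. The braintools source is not reproduced on this page, so treat this as the standard formula rather than a statement of the exact implementation. It is consistent with the value in the Examples section.

```latex
\ell(z, y) = -\sum_{c=1}^{C} y_c \log \frac{e^{z_c}}{\sum_{k=1}^{C} e^{z_k}}
           = \log \sum_{k=1}^{C} e^{z_k} - \sum_{c=1}^{C} y_c \, z_c
```

The second form uses $\sum_{c} y_c = 1$. It shows that only a log-sum-exp of the logits and a label-weighted sum of the logits are needed.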

Parameters:
  • logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Unnormalized log-probabilities with shape [..., num_classes].

  • labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Valid probability distributions over the classes (non-negative entries that sum to 1 along the last axis). A typical example is a one-hot encoding that marks the correct class for each input. Must have a shape broadcastable to [..., num_classes].
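The sketch below illustrates the shape contract described above. The class axis is reduced, so logits of shape [..., num_classes] give a loss of shape [...], and labels only need to broadcast against the logits. It assumes the standard definition -sum(labels * log_softmax(logits), axis=-1). It uses NumPy instead of JAX so it runs standalone. `softmax_cross_entropy_ref` is a hypothetical helper, not part of braintools.

```python
import numpy as np


def softmax_cross_entropy_ref(logits, labels):
    """Hypothetical reference helper (not a braintools API).

    Reduces over the last (class) axis: [..., num_classes] -> [...].
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # Numerically stable log-softmax: subtract the per-row max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # labels broadcast against log_probs; then sum over the class axis.
    return -(labels * log_probs).sum(axis=-1)


# A (2, 4) batch of examples, each with 3 classes.
logits = np.random.default_rng(0).normal(size=(2, 4, 3))
one_hot = np.eye(3)[np.array([[0, 1, 2, 0], [2, 2, 1, 0]])]  # shape (2, 4, 3)
loss = softmax_cross_entropy_ref(logits, one_hot)
print(loss.shape)  # (2, 4)

# One soft-label vector of shape (3,) broadcasts to every example.
soft = np.array([0.8, 0.1, 0.1])
print(softmax_cross_entropy_ref(logits, soft).shape)  # (2, 4)
```

Subtracting the row maximum leaves the softmax unchanged but keeps `np.exp` from overflowing for large logits.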

Returns:

Cross entropy between each prediction and its target distribution, reduced over the last (class) axis, with shape [...].

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([[2.0, 1.0, 0.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0]])
>>> loss = braintools.metric.softmax_cross_entropy(logits, labels)
>>> print(loss)
[0.4170299]
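The printed value can be checked by hand, assuming the standard definition of softmax cross entropy. With one-hot labels only the true-class term survives, so the loss is the log-sum-exp of the logits minus the logit of the true class. This check uses only the Python standard library.

```python
import math

logits = [2.0, 1.0, 0.1]
labels = [1.0, 0.0, 0.0]

# log(e^2 + e^1 + e^0.1) = log(11.2125...), roughly 2.41703
lse = math.log(sum(math.exp(z) for z in logits))
# Subtract the label-weighted logits; with one-hot labels this is just z_true = 2.0
loss = lse - sum(y * z for y, z in zip(labels, logits))
print(round(loss, 5))  # 0.41703
```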

References