softmax_cross_entropy_with_integer_labels
- class braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
Compute softmax cross entropy between logits and integer labels.
This is a more efficient version of softmax cross entropy for labels given as integer class indices rather than one-hot encoded vectors. It measures the probability error in discrete classification tasks where the classes are mutually exclusive.
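The relationship between the integer-label form and the one-hot form can be sketched as follows. This is a minimal reference sketch, not the braintools implementation: the helper names `xent_integer_labels` and `xent_one_hot` are hypothetical, and it assumes the loss is the negative log-softmax probability of the true class.

```python
import jax
import jax.numpy as jnp


def xent_integer_labels(logits, labels):
    # Hypothetical reference helper, not the library's code.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Gather the log-probability of the true class for each example.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked[..., 0]


def xent_one_hot(logits, labels):
    # Same quantity, computed through an explicit one-hot target.
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return -jnp.sum(one_hot * jax.nn.log_softmax(logits, axis=-1), axis=-1)


logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = jnp.array([0, 1])

loss_int = xent_integer_labels(logits, labels)
loss_oh = xent_one_hot(logits, labels)
print(jnp.allclose(loss_int, loss_oh))
```

The integer-label path avoids materializing a `[..., num_classes]` one-hot array and summing over it, which is where the efficiency gain is expected to come from.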
- Parameters:
logits (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Unnormalized log probabilities with shape [..., num_classes]. Must be floating point type.
labels (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Integer class indices specifying the correct class for each input. Values should be in the range [0, num_classes). Shape [...]. Must be integer type.
- Returns:
Cross entropy between each prediction and the corresponding target distributions, with shape [...].
- Return type:
Array|ndarray|bool|number|bool|int|float|complex|Quantity
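The shape contract above (logits `[..., num_classes]`, labels `[...]`, loss `[...]`) can be illustrated with extra leading dimensions. This is a sketch under assumptions: `reference_xent` is a hypothetical stand-in written with plain JAX calls, and the `(batch, time, num_classes)` layout is chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def reference_xent(logits, labels):
    # Hypothetical reference helper for illustration only.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]


# Assumed layout: (batch, time, num_classes) logits, (batch, time) labels.
logits = jnp.zeros((2, 3, 4))
labels = jnp.array([[0, 1, 2], [3, 0, 1]])

per_example = reference_xent(logits, labels)  # one loss per (batch, time) entry
mean_loss = per_example.mean()                # reduce to a scalar for training
print(per_example.shape)
```

Because the result is unreduced, the caller decides how to aggregate it (mean, sum, or a masked average).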
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([[2.0, 1.0, 0.1]])
>>> labels = jnp.array([0])  # Class 0
>>> loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
>>> print(loss)
[0.4170299]
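The value printed in the example can be checked by hand. The sketch below uses only the standard library and assumes the loss equals log-sum-exp of the logits minus the logit of the true class, which is the usual definition of softmax cross entropy for a single label.

```python
import math

logits = [2.0, 1.0, 0.1]
label = 0  # Class 0, as in the example

# Subtract the largest logit before exponentiating for numerical stability.
m = max(logits)
log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))

# Cross entropy for one integer label: logsumexp(logits) - logits[label].
loss = log_sum_exp - logits[label]
print(loss)
```

The result agrees with the example's `0.4170299` up to float32 rounding.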