poly_loss_cross_entropy
- class braintools.metric.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
Compute PolyLoss cross entropy between logits and labels.
PolyLoss is a polynomial expansion of commonly used classification loss functions. It decomposes loss functions into weighted polynomial bases inspired by the Taylor expansion of cross-entropy and focal loss.
The PolyLoss is defined as:
\[L_{Poly} = \sum_{j=1}^\infty \alpha_j \cdot (1 - P_t)^j\]
This function implements a simplified version in which only the first polynomial term is modified:
\[L = -\log(P_t) + \epsilon \cdot (1 - P_t)\]
where \(P_t\) is the predicted probability for the true class.
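The step from the general series to the simplified form can be reconstructed as follows. This is a sketch based on the Taylor expansion of \(-\log(P_t)\) around \(P_t = 1\), valid for \(0 < P_t \le 1\); it describes the math only and makes no claim about the library's internals:
\[-\log(P_t) = \sum_{j=1}^\infty \frac{1}{j} \cdot (1 - P_t)^j\]
so ordinary cross-entropy corresponds to \(\alpha_j = 1/j\). Changing only the first coefficient to \(\alpha_1 = 1 + \epsilon\) gives
\[L = (1 + \epsilon) \cdot (1 - P_t) + \sum_{j=2}^\infty \frac{1}{j} \cdot (1 - P_t)^j = -\log(P_t) + \epsilon \cdot (1 - P_t)\]
Setting \(\epsilon = 0\) therefore recovers ordinary cross-entropy.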
- Parameters:
  - logits (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Unnormalized log probabilities with shape [..., num_classes].
  - labels (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Valid probability distributions (non-negative, sum to 1), e.g., a one-hot encoding specifying the correct class for each input. Must have shape broadcastable to [..., num_classes].
  - epsilon (float) – Coefficient of the first polynomial term. Controls the emphasis on difficult examples:
    - For ImageNet 2D classification: epsilon = 2.0 (recommended)
    - For 2D instance segmentation/object detection: epsilon = -1.0
    - Task-specific tuning via grid search is recommended
- Returns:
PolyLoss values between each prediction and corresponding target distributions, with shape [...].
- Return type:
Array|ndarray|bool|number|bool|int|float|complex|Quantity
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([[2.0, 1.0, 0.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0]])
>>> loss = braintools.metric.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
>>> print(f"PolyLoss: {loss[0]:.4f}")
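To make the formula concrete, the same inputs can be evaluated by hand. The following is a minimal sketch, not the braintools implementation: `poly1_sketch` is a hypothetical helper, and it assumes that \(P_t\) is the label-weighted softmax probability and that the \(-\log(P_t)\) term is computed as the usual softmax cross-entropy.

```python
import jax
import jax.numpy as jnp


def poly1_sketch(logits, labels, epsilon=2.0):
    """Hypothetical reference for L = -log(P_t) + epsilon * (1 - P_t)."""
    # Log-probabilities over the class axis (numerically stable softmax).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy term; equals -log(P_t) when labels are one-hot.
    cross_entropy = -jnp.sum(labels * log_probs, axis=-1)
    # Assumption: P_t is the probability mass assigned to the true class.
    p_t = jnp.sum(labels * jnp.exp(log_probs), axis=-1)
    return cross_entropy + epsilon * (1.0 - p_t)


logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([[1.0, 0.0, 0.0]])
loss = poly1_sketch(logits, labels, epsilon=2.0)
print(f"PolyLoss (sketch): {float(loss[0]):.4f}")
```

Under these assumptions, \(P_t \approx 0.659\), so the sketch evaluates to about 0.417 + 2 × 0.341 ≈ 1.099. The library's own output may differ if its implementation departs from the assumptions stated here.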
Notes
PolyLoss can improve model calibration and performance on imbalanced datasets by adjusting the emphasis on difficult examples through the epsilon parameter.
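One way to see how epsilon shifts the emphasis is to look at the additive term \(\epsilon \cdot (1 - P_t)\) on its own. The sketch below uses made-up logits for one confident and one unconfident prediction and a few illustrative epsilon values; it assumes \(P_t\) is the label-weighted softmax probability and does not call the library.

```python
import jax
import jax.numpy as jnp

# Two samples whose true class is index 0: one confident, one not.
logits = jnp.array([[4.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
labels = jnp.array([[1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0]])

# Assumption: P_t is the softmax probability of the true class.
p_t = jnp.sum(labels * jax.nn.softmax(logits, axis=-1), axis=-1)

# Additive term epsilon * (1 - P_t) for a few illustrative epsilon values.
for epsilon in (-1.0, 0.0, 2.0):
    extra = epsilon * (1.0 - p_t)
    print(f"epsilon={epsilon:+.1f}  "
          f"confident={float(extra[0]):+.4f}  "
          f"unconfident={float(extra[1]):+.4f}")
```

Because the term scales with \(1 - P_t\), a positive epsilon adds a larger penalty to low-confidence samples than to confident ones, a negative epsilon reduces their loss by the same margin, and epsilon = 0 leaves cross-entropy unchanged.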
References