poly_loss_cross_entropy


class braintools.metric.poly_loss_cross_entropy(logits, labels, epsilon=2.0)

Compute PolyLoss cross entropy between logits and labels.

PolyLoss is a polynomial expansion of commonly used classification loss functions. It expresses a loss as a weighted sum of polynomial bases \((1 - P_t)^j\), a view inspired by the Taylor expansions of cross-entropy and focal loss.

The PolyLoss is defined as:

\[L_{Poly} = \sum_{j=1}^\infty \alpha_j \cdot (1 - P_t)^j\]
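As a worked instance of this definition (a derivation added here for illustration, using the standard Taylor series \(-\log(1 - x) = \sum_{j \ge 1} x^j / j\) with \(x = 1 - P_t\), valid for \(0 < P_t \le 1\)), cross-entropy itself corresponds to the coefficients \(\alpha_j = 1/j\):

\[-\log(P_t) = \sum_{j=1}^\infty \frac{1}{j} (1 - P_t)^j = (1 - P_t) + \frac{1}{2}(1 - P_t)^2 + \frac{1}{3}(1 - P_t)^3 + \cdots\]

Adding \(\epsilon \cdot (1 - P_t)\) to cross-entropy therefore changes only the first coefficient, from \(\alpha_1 = 1\) to \(\alpha_1 = 1 + \epsilon\), and leaves every \(\alpha_j\) with \(j \ge 2\) equal to \(1/j\).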

This function implements the simplified Poly-1 formulation, which modifies only the coefficient of the first polynomial term:

\[L = -\log(P_t) + \epsilon \cdot (1 - P_t)\]

where \(P_t\) is the predicted probability for the true class.
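To make the formula concrete, here is a minimal from-scratch sketch in JAX. The function name `poly1_cross_entropy` is hypothetical and this is not braintools' implementation; it assumes \(P_t\) is computed as the probability mass the softmax assigns to the label distribution, which equals the true-class probability for one-hot labels.

```python
import jax
import jax.numpy as jnp


def poly1_cross_entropy(logits, labels, epsilon=2.0):
    """Hypothetical reference sketch of the Poly-1 formula (not braintools' code).

    logits: [..., num_classes] unnormalized log probabilities.
    labels: [..., num_classes] probability distributions (e.g. one-hot).
    Returns an array with shape [...].
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy term: -sum_c y_c * log p_c, i.e. -log(P_t) for one-hot labels.
    cross_entropy = -jnp.sum(labels * log_probs, axis=-1)
    # P_t: probability mass the model assigns to the target distribution.
    p_t = jnp.sum(labels * jnp.exp(log_probs), axis=-1)
    # First polynomial term (1 - P_t), weighted by epsilon.
    return cross_entropy + epsilon * (1.0 - p_t)


logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, -1.0]])
# Integer class indices converted to one-hot label distributions.
labels = jax.nn.one_hot(jnp.array([0, 1]), num_classes=3)

loss = poly1_cross_entropy(logits, labels, epsilon=2.0)
print(loss.shape)  # (2,)

# epsilon = 0 drops the polynomial term, leaving plain softmax cross-entropy.
ce = poly1_cross_entropy(logits, labels, epsilon=0.0)
```

The last axis is reduced, so a `[2, 3]` input yields a loss of shape `[2]`. For the first row (the same logits and labels as in the Examples section), the sketch gives approximately 1.0990 with `epsilon=2.0`: about 0.4170 from \(-\log(P_t)\) plus \(2 \times 0.3410\) from the polynomial term.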

Parameters:
  • logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Unnormalized log probabilities with shape [..., num_classes].

  • labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Valid probability distributions (non-negative, summing to 1), e.g., a one-hot encoding specifying the correct class for each input. Must have a shape broadcastable to [..., num_classes].

  • epsilon (float) –

    Coefficient of the first polynomial term (1 - P_t). It controls the emphasis on difficult examples; epsilon = 0 recovers standard cross-entropy:

    • For ImageNet 2D classification: epsilon = 2.0 (recommended)

    • For 2D instance segmentation/object detection: epsilon = -1.0

    • Task-specific tuning via grid search is recommended
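One way to see what epsilon changes is to differentiate the loss with respect to the true-class logit \(z_t\). Since \(\partial P_t / \partial z_t = P_t (1 - P_t)\), the gradient is \(-(1 - P_t)(1 + \epsilon P_t)\): positive epsilon enlarges it and negative epsilon shrinks it. The sketch below checks this closed form with `jax.grad`, again using a hypothetical `poly1_cross_entropy` helper that follows the formula above (an assumption, not braintools' implementation).

```python
import jax
import jax.numpy as jnp


def poly1_cross_entropy(logits, labels, epsilon):
    # Hypothetical Poly-1 sketch: -log(P_t) + epsilon * (1 - P_t) over the last axis.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    p_t = jnp.sum(labels * jnp.exp(log_probs), axis=-1)
    return -jnp.sum(labels * log_probs, axis=-1) + epsilon * (1.0 - p_t)


logits = jnp.array([2.0, 1.0, 0.1])
labels = jnp.array([1.0, 0.0, 0.0])  # true class t = 0
p_t = jax.nn.softmax(logits)[0]

for epsilon in (-1.0, 0.0, 2.0):
    # Autodiff gradient w.r.t. the logits; keep the true-class entry.
    grad_t = jax.grad(poly1_cross_entropy)(logits, labels, epsilon)[0]
    # Closed form: -(1 - P_t) * (1 + epsilon * P_t).
    closed_form = -(1.0 - p_t) * (1.0 + epsilon * p_t)
    print(epsilon, float(grad_t), float(closed_form))
```

With epsilon = 0 the factor \(1 + \epsilon P_t\) equals 1 and the cross-entropy gradient \(-(1 - P_t)\) is unchanged; the recommended values above tune the size of this extra factor.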

Returns:

PolyLoss value for each prediction and its corresponding target distribution, with shape [...].

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> logits = jnp.array([[2.0, 1.0, 0.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0]])
>>> loss = braintools.metric.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
>>> print(f"PolyLoss: {loss[0]:.4f}")

Notes

PolyLoss can improve model calibration and performance on imbalanced datasets by adjusting the emphasis on difficult examples through the epsilon parameter.

References