sigmoid_focal_loss

class braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)

Compute sigmoid focal loss for addressing class imbalance.

Focal loss was introduced for dense object detection, where easy background examples vastly outnumber the objects of interest. It multiplies the cross-entropy loss by a modulating factor that shrinks the contribution of easy, well-classified examples, so that training concentrates on the hard ones.

The focal loss is defined as:

\[FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)\]

where \(p_t\) is the predicted probability for the true class, \(\alpha_t\) is a class-dependent weighting factor, and \(\gamma\) is the focusing parameter.
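
To make the formula concrete, here is a minimal scalar sketch in plain Python. It is not the braintools implementation: the helper name `focal_loss_scalar` is hypothetical, and the sketch assumes the convention of the original focal loss paper, in which \(p_t = p\) and \(\alpha_t = \alpha\) for positive labels, and \(p_t = 1 - p\) and \(\alpha_t = 1 - \alpha\) for negative labels, with \(p\) the sigmoid of the logit.

```python
import math


def focal_loss_scalar(logit, label, alpha=None, gamma=2.0):
    """Focal loss for a single binary prediction (illustrative sketch only).

    Assumes alpha weights positives and (1 - alpha) weights negatives.
    Not numerically hardened: very large |logit| can overflow or hit log(0).
    """
    p = 1.0 / (1.0 + math.exp(-logit))   # sigmoid: probability of the positive class
    p_t = p if label == 1 else 1.0 - p   # probability assigned to the true class
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:
        alpha_t = alpha if label == 1 else 1.0 - alpha
        loss *= alpha_t
    return loss


easy = focal_loss_scalar(2.0, 1, alpha=0.25)    # positive example, confidently correct
hard = focal_loss_scalar(-2.0, 1, alpha=0.25)   # positive example, confidently wrong
print(f"easy: {easy:.6f}, hard: {hard:.6f}")
```

With `gamma=0.0` and `alpha=None` the sketch reduces to ordinary binary cross-entropy, which is a quick sanity check for any focal loss implementation.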

Parameters:
  • logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Unnormalized predictions (logits) for binary classification. Can have any shape for element-wise binary predictions.

  • labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Binary labels with values in {0, 1}. Must have the same shape as logits. Use 1 for positive class, 0 for negative class.

  • alpha (float | None) – Weighting factor in range (0, 1) to balance positive vs negative examples. If None, no class-based weighting is applied.

  • gamma (float) – Focusing parameter (exponent) that controls the rate at which easy examples are down-weighted. Higher values focus more on hard examples. Common values: 0.5, 1.0, 2.0, 5.0.
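
The effect of gamma can be seen by evaluating the modulating factor \((1 - p_t)^\gamma\) on its own. The sketch below uses two illustrative probabilities, 0.9 for an easy example and 0.1 for a hard one (both values are assumptions chosen for the illustration), together with the gamma values listed above plus 0.0.

```python
# Modulating factor (1 - p_t) ** gamma for two illustrative examples.
P_EASY = 0.9  # true class predicted with high confidence (assumed value)
P_HARD = 0.1  # true class predicted with low confidence (assumed value)

factors = {}
for gamma in (0.0, 0.5, 1.0, 2.0, 5.0):
    easy = (1.0 - P_EASY) ** gamma
    hard = (1.0 - P_HARD) ** gamma
    factors[gamma] = (easy, hard)
    print(f"gamma={gamma}: easy={easy:.5f}, hard={hard:.5f}, hard/easy={hard / easy:.1f}")
```

At gamma=0.0 both factors are 1 and the loss is plain cross-entropy. As gamma grows, the factor for the easy example shrinks much faster than the factor for the hard one, so the hard example dominates the total loss.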

Returns:

Focal loss values with the same shape as input logits and labels.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Imbalanced binary classification
>>> logits = jnp.array([2.0, -1.0, 0.5, -2.0])
>>> labels = jnp.array([1.0, 0.0, 1.0, 0.0])
>>> # Standard focal loss
>>> loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
>>> print(f"Focal loss: {loss}")
>>> # Compare with unweighted version
>>> loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)

Notes

Use this loss function when classes are not mutually exclusive (multi-label classification) or when dealing with severe class imbalance. For mutually exclusive classes, consider using softmax-based focal loss variants.

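The alpha parameter can be set from the inverse class frequency, so that the rarer class receives the larger weight, or treated as a hyperparameter and tuned together with gamma. The two interact: the original focal loss paper (Lin et al., 2017) found alpha=0.25 with gamma=2.0 to work best for dense object detection even though positives were rare, because the modulating factor already down-weights the easy negatives.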
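
The inverse-class-frequency heuristic can be sketched as follows. The sketch assumes that alpha multiplies the loss of positive examples and 1 - alpha that of negative examples (the convention of the original focal loss paper, not confirmed here for braintools). The helper name `inverse_frequency_alpha` and the toy label list are made up for illustration.

```python
def inverse_frequency_alpha(labels):
    """Normalized inverse-frequency weight for the positive class (sketch)."""
    n = len(labels)
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("both classes must be present")
    inv_pos = n / n_pos  # inverse frequency of the positive class
    inv_neg = n / n_neg  # inverse frequency of the negative class
    # Normalize so the positive and negative weights sum to 1;
    # algebraically this equals n_neg / n.
    return inv_pos / (inv_pos + inv_neg)


labels = [1, 0, 0, 0, 1, 0, 0, 0]  # toy data with 25% positives
alpha = inverse_frequency_alpha(labels)
print(f"alpha = {alpha:.2f}")
```

For the toy labels the result is close to 0.75, so the rarer positive class receives the larger weight. Treat such a value as a starting point and tune it alongside gamma.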

References