sigmoid_focal_loss
- class braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
Compute sigmoid focal loss for addressing class imbalance.
Focal loss was introduced to address class imbalance in dense object detection. It multiplies the cross-entropy loss by a modulating factor that shrinks the contribution of easy, well-classified examples, so that training concentrates on hard, misclassified ones.
The focal loss is defined as:
\[FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)\]
where \(p_t\) is the predicted probability for the true class, \(\alpha_t\) is a class-dependent weighting factor, and \(\gamma\) is the focusing parameter.
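The formula can be sketched directly in JAX. This is a minimal reference sketch, not the braintools implementation: the name `focal_loss_sketch` is hypothetical, and it assumes the common convention that \(\alpha_t\) equals `alpha` for positive labels and `1 - alpha` for negative labels.

```python
import jax
import jax.numpy as jnp


def focal_loss_sketch(logits, labels, alpha=None, gamma=2.0):
    """Element-wise sigmoid focal loss (illustrative sketch only)."""
    # Predicted probability of the positive class.
    p = jax.nn.sigmoid(logits)
    # p_t: probability assigned to the true class of each element.
    p_t = labels * p + (1.0 - labels) * (1.0 - p)
    # -log(p_t), computed with log_sigmoid for numerical stability.
    ce = -(labels * jax.nn.log_sigmoid(logits)
           + (1.0 - labels) * jax.nn.log_sigmoid(-logits))
    # Modulating factor (1 - p_t)^gamma down-weights easy examples.
    loss = (1.0 - p_t) ** gamma * ce
    if alpha is not None:
        # Assumed convention: alpha for positives, 1 - alpha for negatives.
        alpha_t = alpha * labels + (1.0 - alpha) * (1.0 - labels)
        loss = alpha_t * loss
    return loss
```

With `gamma=0.0` and `alpha=None` the sketch reduces to ordinary sigmoid binary cross-entropy, which is a convenient sanity check.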
- Parameters:
logits (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Unnormalized predictions (logits) for binary classification. Can have any shape for element-wise binary predictions.
labels (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Binary labels with values in {0, 1}. Must have the same shape as logits. Use 1 for positive class, 0 for negative class.
alpha (float|None) – Weighting factor in range (0, 1) to balance positive vs negative examples. If None, no class-based weighting is applied.
gamma (float) – Focusing parameter (exponent) that controls the rate at which easy examples are down-weighted. Higher values focus more on hard examples. Common values: 0.5, 1.0, 2.0, 5.0.
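To get a feel for gamma, it can help to look at the modulating factor \((1 - p_t)^\gamma\) on its own. The sketch below uses a hypothetical helper, `modulating_factor`, which is not part of braintools, and prints the factor for an easy, an ambiguous, and a hard example across the common gamma values listed above.

```python
import jax.numpy as jnp


def modulating_factor(p_t, gamma):
    """Return (1 - p_t) ** gamma, the factor that scales the cross-entropy term."""
    return (1.0 - jnp.asarray(p_t)) ** gamma


# p_t for an easy (0.9), an ambiguous (0.5), and a hard (0.1) example.
p_t = jnp.array([0.9, 0.5, 0.1])
for gamma in (0.0, 0.5, 1.0, 2.0, 5.0):
    print(f"gamma={gamma}: {modulating_factor(p_t, gamma)}")
```

At `gamma=0.0` the factor is 1 for every example, so nothing is down-weighted. As gamma grows, the factor for the easy example falls much faster than the factor for the hard one.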
- Returns:
Focal loss values with the same shape as input logits and labels.
- Return type:
Array|ndarray|bool|number|bool|int|float|complex|Quantity
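Because the result is element-wise, a training loop usually has to reduce it to a scalar before differentiating. The helper below is a hypothetical sketch, not a braintools function, showing the usual reduction choices.

```python
import jax.numpy as jnp


def reduce_loss(per_element_loss, reduction="mean"):
    """Reduce an element-wise loss array (hypothetical helper)."""
    if reduction == "mean":
        return jnp.mean(per_element_loss)
    if reduction == "sum":
        return jnp.sum(per_element_loss)
    if reduction == "none":
        # Leave the per-element values untouched.
        return per_element_loss
    raise ValueError(f"Unknown reduction: {reduction!r}")
```

A typical call would be `reduce_loss(braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0))`.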
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> # Imbalanced binary classification
>>> logits = jnp.array([2.0, -1.0, 0.5, -2.0])
>>> labels = jnp.array([1.0, 0.0, 1.0, 0.0])
>>> # Standard focal loss
>>> loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
>>> print(f"Focal loss: {loss}")
>>> # Compare with unweighted version
>>> loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
Notes
Use this loss function when classes are not mutually exclusive (multi-label classification) or when dealing with severe class imbalance. For mutually exclusive classes, consider using softmax-based focal loss variants.
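The alpha parameter trades off positive against negative examples: in the usual formulation, positives are weighted by alpha and negatives by 1 - alpha. A common starting point is alpha=0.25 with gamma=2.0, the setting reported to work best in the original focal loss paper. Tune it for your data rather than deriving it directly from the class frequency.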
References