Tutorial 1: Classification Losses

This tutorial introduces the core classification losses in braintools.metric and shows how to choose and use them:

  • Binary/multilabel: braintools.metric.sigmoid_binary_cross_entropy

  • Multiclass: braintools.metric.softmax_cross_entropy, braintools.metric.softmax_cross_entropy_with_integer_labels

  • Imbalanced data: braintools.metric.sigmoid_focal_loss

  • Regularization: braintools.metric.smooth_labels

All examples use JAX arrays and are shape- and type-checked.

import jax.numpy as jnp
import braintools

Binary / Multilabel (Sigmoid Cross-Entropy)

Use when each class is an independent yes/no decision (not mutually exclusive).

# Binary/multilabel setup (elementwise BCE)
logits = jnp.array([1.0, -1.0, 0.0])        # unnormalized logits
labels = jnp.array([1.0,  0.0, 1.0])        # binary targets in {0,1}

loss = braintools.metric.sigmoid_binary_cross_entropy(logits, labels)
print(loss)          # per-element BCE, same shape as logits
print(loss.mean())   # common reduction
[0.3132617 0.3132617 0.6931472]
0.4398902

Tips

  • Labels must be float (0/1) or probabilities, shape-broadcastable with logits.

  • It also applies to multilabel problems, where a sample can belong to several classes at once (one logit per class).
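Elementwise BCE from logits is −[y·log σ(x) + (1 − y)·log σ(−x)]. The sketch below uses a hypothetical helper, bce_with_logits (not a braintools function), built on jax.nn.log_sigmoid, which avoids computing log(sigmoid(x)) directly. It should reproduce the per-element values printed above, and it shows one common reduction for a multilabel batch: sum over classes, then mean over samples. That reduction is an assumption for illustration; averaging over classes is equally common.

```python
import jax
import jax.numpy as jnp

def bce_with_logits(logits, labels):
    """Hypothetical helper: elementwise binary cross-entropy from logits.

    log_sigmoid is used instead of log(sigmoid(x)) for numerical stability.
    """
    return -(labels * jax.nn.log_sigmoid(logits)
             + (1.0 - labels) * jax.nn.log_sigmoid(-logits))

# Same inputs as the binary example above
logits = jnp.array([1.0, -1.0, 0.0])
labels = jnp.array([1.0, 0.0, 1.0])
per_elem = bce_with_logits(logits, labels)
print(per_elem)  # per-element BCE, same shape as logits

# Multilabel batch: rows are samples, columns are independent classes
ml_logits = jnp.array([[2.0, -1.0, 0.5],
                       [-0.5, 1.5, -2.0]])
ml_labels = jnp.array([[1.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0]])
per_sample = bce_with_logits(ml_logits, ml_labels).sum(axis=-1)  # sum over classes
print(per_sample.shape)   # (2,)
print(per_sample.mean())  # mean over samples
```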

Multiclass (Softmax Cross-Entropy)

Use when classes are mutually exclusive.

# One-hot or probability targets
# logits shape [..., num_classes]
logits = jnp.array([[2.0, 1.0, 0.1]])
targets = jnp.array([[1.0, 0.0, 0.0]])  # one-hot

loss = braintools.metric.softmax_cross_entropy(logits, targets)
print(loss)  # shape [...]
[0.41702995]
# Integer labels (preferred for single-label classification)
logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([0])  # class index

loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
print(loss)
[0.41702995]

Notes

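  • softmax_cross_entropy expects float targets (one-hot vectors or class-probability distributions); softmax_cross_entropy_with_integer_labels expects integer class indices.

  • Shapes: logits are [..., C]; integer labels are [...]; float targets are [..., C].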
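Both softmax variants compute the same quantity, −Σ_c target_c · log_softmax(logits)_c; they differ only in how the target is encoded. A from-scratch sketch in jax.numpy, with hypothetical helpers softmax_ce and softmax_ce_int written only to show the math (not braintools' implementation), should give the 0.41702995 printed above for both encodings:

```python
import jax
import jax.numpy as jnp

def softmax_ce(logits, targets):
    # Float targets [..., C]: -sum_c targets_c * log_softmax(logits)_c
    return -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1), axis=-1)

def softmax_ce_int(logits, labels):
    # Integer labels [...]: pick the log-probability of the true class
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]

logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([0])
one_hot = jax.nn.one_hot(labels, num_classes=logits.shape[-1])  # [[1., 0., 0.]]

print(softmax_ce(logits, one_hot))     # ≈ 0.41703
print(softmax_ce_int(logits, labels))  # ≈ 0.41703
```

The integer form skips building the one-hot array, which is one reason it is the usual choice for single-label classification.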

Focal Loss (Imbalanced Data)

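Focal loss down-weights well-classified (easy) examples so that training concentrates on hard, misclassified ones, which helps when easy examples of the majority class would otherwise dominate the loss. For binary and multilabel tasks, use sigmoid_focal_loss.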
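In formula form, the standard focal loss for a logit x and binary label y is written below. That sigmoid_focal_loss uses exactly this parameterization is an assumption, although it is consistent with the values printed in the example that follows. With gamma = 0 and no alpha weighting, it reduces to ordinary binary cross-entropy, −log p_t.

```latex
p = \sigma(x), \qquad
p_t = p\,y + (1 - p)(1 - y), \qquad
\alpha_t = \alpha\,y + (1 - \alpha)(1 - y)

\mathrm{FL}(x, y) = -\,\alpha_t \,(1 - p_t)^{\gamma}\, \log p_t
```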

# Imbalanced binary classification example
logits = jnp.array([2.0, -1.0, 0.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])

# alpha weights positives (negatives get 1 - alpha); gamma down-weights easy examples
loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
print(loss)
print(loss.mean())

# Unweighted focal (no class weighting)
loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
print(loss_unweighted)
[0.00045089 0.01699354 0.01689337 0.00135267]
0.008922619
[0.00180356 0.02265805 0.06757348 0.00180356]

Guidance

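  • alpha weights the positive class and 1 − alpha the negative class. alpha = 0.25 with gamma = 2.0 is a widely used default; treat alpha as a hyperparameter and tune it on validation data rather than deriving it from class frequencies alone.

  • gamma = 0 recovers (alpha-weighted) binary cross-entropy; larger gamma (typically 1–5) down-weights easy examples more strongly, so harder examples dominate the loss.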
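To see what alpha and gamma do, here is a from-scratch sketch in jax.numpy. The helper focal_loss is hypothetical (not part of braintools) and follows the standard formulation; the test values are the sigmoid_focal_loss outputs printed above, so the sketch assumes braintools uses the same convention (alpha weights positives, 1 − alpha weights negatives).

```python
import jax
import jax.numpy as jnp

def focal_loss(logits, labels, alpha=0.25, gamma=2.0):
    """Hypothetical helper: elementwise sigmoid focal loss."""
    p = jax.nn.sigmoid(logits)
    # Binary cross-entropy from logits (numerically stable form)
    ce = -(labels * jax.nn.log_sigmoid(logits)
           + (1.0 - labels) * jax.nn.log_sigmoid(-logits))
    # Probability the model assigns to the true label
    p_t = p * labels + (1.0 - p) * (1.0 - labels)
    # Modulating factor: close to 0 for easy examples (p_t near 1)
    loss = ce * (1.0 - p_t) ** gamma
    if alpha is not None:
        # Class weighting: alpha for positives, 1 - alpha for negatives
        alpha_t = alpha * labels + (1.0 - alpha) * (1.0 - labels)
        loss = alpha_t * loss
    return loss

logits = jnp.array([2.0, -1.0, 0.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])

weighted = focal_loss(logits, labels, alpha=0.25, gamma=2.0)
unweighted = focal_loss(logits, labels, alpha=None, gamma=2.0)
plain_bce = focal_loss(logits, labels, alpha=None, gamma=0.0)  # gamma=0 -> ordinary BCE

print(weighted)
print(unweighted)  # largest entry is the hardest example (logit 0.5, label 1)
print(plain_bce)
```

Setting gamma=0.0 and alpha=None is a convenient sanity check: the result should equal plain BCE.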

Label Smoothing (Regularization)

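Label smoothing discourages overconfidence by mixing one-hot labels with a uniform distribution over the C classes: smoothed = (1 − alpha) · one_hot + alpha / C. Use the smoothed targets with softmax_cross_entropy.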

# One-hot labels [..., C]
labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])
smoothed = braintools.metric.smooth_labels(labels_one_hot, alpha=0.1)

logits = jnp.array([[2.0, 1.0, 0.5]])
loss = braintools.metric.softmax_cross_entropy(logits, smoothed)
print(loss)
print(smoothed)
[0.54770213]
[[0.93333334 0.03333334 0.03333334]]

Rules of thumb

  • alpha in [0.05, 0.2] is common; larger values smooth more.

  • Often improves calibration and robustness; too much smoothing can lower peak accuracy.
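The smoothed targets printed above match the rule (1 − alpha) · y + alpha / C (0.9 + 0.1/3 ≈ 0.9333). The sketch below uses hypothetical helpers smooth and softmax_ce (in practice, use braintools.metric.smooth_labels and softmax_cross_entropy) to reproduce those numbers and to show why smoothing curbs overconfidence: with smoothed targets, the loss is minimized at a finite logit gap, log(0.9333 / 0.0333) = log 28, so pushing the correct logit further ahead makes the loss worse, not better.

```python
import jax
import jax.numpy as jnp

def smooth(one_hot, alpha):
    """Hypothetical helper: mix one-hot targets with a uniform distribution."""
    num_classes = one_hot.shape[-1]
    return (1.0 - alpha) * one_hot + alpha / num_classes

def softmax_ce(logits, targets):
    # Cross-entropy against (possibly soft) float targets
    return -jnp.sum(targets * jax.nn.log_softmax(logits, axis=-1), axis=-1)

labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])
smoothed = smooth(labels_one_hot, alpha=0.1)   # ≈ [[0.9333, 0.0333, 0.0333]]

logits = jnp.array([[2.0, 1.0, 0.5]])
loss = softmax_ce(logits, smoothed)            # ≈ [0.5477]

# Optimal logit gap under smoothed targets is log(0.9333 / 0.0333) = log(28)
gap = jnp.log(28.0)
at_optimum = softmax_ce(jnp.array([[gap, 0.0, 0.0]]), smoothed)
overconfident = softmax_ce(jnp.array([[10.0, 0.0, 0.0]]), smoothed)
print(at_optimum, overconfident)  # the overconfident logits have the higher loss
```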

Which Loss Should I Use?

  • Binary/multilabel tasks → sigmoid_binary_cross_entropy

  • Multiclass single-label → softmax_cross_entropy_with_integer_labels

  • Heavily imbalanced (binary/multilabel) → sigmoid_focal_loss

  • Overconfident models → smooth_labels + softmax CE

Common Pitfalls

  • Do not feed integer labels to softmax_cross_entropy (use the integer-label variant).

  • For multilabel problems, use sigmoid-based losses (not softmax).

  • Always match shapes: logits[..., C] with one-hot targets [..., C], or with integer labels [...].
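The first two pitfalls can be checked numerically with plain jax.nn functions (no braintools calls are assumed here). For a sample in which two classes are both present, softmax must split its probability mass between them, while independent sigmoids can give each a high probability. Integer labels can be converted to the one-hot form that softmax_cross_entropy expects with jax.nn.one_hot.

```python
import jax
import jax.numpy as jnp

# One multilabel sample: classes 0 and 1 are both present, class 2 is absent
logits = jnp.array([3.0, 3.0, -3.0])

softmax_probs = jax.nn.softmax(logits)  # forced to sum to 1 across classes
sigmoid_probs = jax.nn.sigmoid(logits)  # one independent probability per class

print(softmax_probs)  # each positive class gets just under 0.5
print(sigmoid_probs)  # both positive classes get about 0.95

# Integer labels -> one-hot float targets for a float-target loss
int_labels = jnp.array([2, 0])
one_hot = jax.nn.one_hot(int_labels, num_classes=3)
print(one_hot)  # [[0. 0. 1.] [1. 0. 0.]]
```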

Extras: NLL with Log-Probabilities

If you already have log-probabilities (for example, from jax.nn.log_softmax), pass them to nll_loss directly.

log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
target = 1
print(braintools.metric.nll_loss(log_probs, target))
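-0.35667497

Note the sign: by definition, the negative log-likelihood here is −log(0.7) ≈ 0.3567, whereas the printed value equals log(0.7). Check the sign convention of nll_loss in your installed version, and negate the result if needed before minimizing it.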
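By the standard definition, the negative log-likelihood is minus the log-probability of the target class, and applying it to log_softmax outputs gives exactly softmax cross-entropy. A minimal sketch with a hypothetical helper, nll_from_log_probs (standard definition, not braintools' implementation):

```python
import jax
import jax.numpy as jnp

def nll_from_log_probs(log_probs, target):
    """Hypothetical helper: -log p(target), indexed along the last (class) axis."""
    target = jnp.asarray(target)
    picked = jnp.take_along_axis(log_probs, target[..., None], axis=-1)
    return -picked[..., 0]

# Single example: -log(0.7)
log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
single = nll_from_log_probs(log_probs, 1)
print(single)  # ≈ 0.3567

# log_softmax followed by NLL is the same as softmax cross-entropy
logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([0])
batch = nll_from_log_probs(jax.nn.log_softmax(logits, axis=-1), labels)
print(batch)  # ≈ [0.41703], the softmax_cross_entropy value from earlier
```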

See also

  • API reference: braintools.metric → classification functions

  • Ranking and regression losses are covered in separate tutorials.