Tutorial 1: Classification Losses
This tutorial introduces the core classification losses in braintools.metric and how to pick and use them effectively:
- Binary/multilabel: braintools.metric.sigmoid_binary_cross_entropy
- Multiclass: braintools.metric.softmax_cross_entropy, braintools.metric.softmax_cross_entropy_with_integer_labels
- Imbalanced data: braintools.metric.sigmoid_focal_loss
- Regularization: braintools.metric.smooth_labels
All examples use JAX arrays and are shape- and type-checked.
```python
import jax.numpy as jnp
import braintools
```
Binary / Multilabel (Sigmoid Cross-Entropy)
Use when each class is an independent yes/no decision (not mutually exclusive).
```python
# Binary/multilabel setup (elementwise BCE)
logits = jnp.array([1.0, -1.0, 0.0])  # unnormalized logits
labels = jnp.array([1.0, 0.0, 1.0])   # binary targets in {0, 1}
loss = braintools.metric.sigmoid_binary_cross_entropy(logits, labels)
print(loss)         # per-element BCE, same shape as logits
print(loss.mean())  # common reduction
```

```
[0.3132617 0.3132617 0.6931472]
0.4398902
```
Tips
- Labels must be floats (0/1) or probabilities, with a shape broadcastable to the shape of logits.
- Use this loss for multilabel problems as well, where several classes can be active for the same example.
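For intuition, the per-element values printed above can be reproduced from the standard numerically stable form of binary cross-entropy on logits, max(x, 0) - x*y + log(1 + exp(-|x|)). This is a minimal sketch using only jax.numpy; the helper name bce_with_logits is hypothetical, and the matching numbers only show that braintools computes the same quantity, not that it uses this code. The second part shows the multilabel case, where each column is an independent yes/no decision.

```python
import jax.numpy as jnp

def bce_with_logits(logits, labels):
    # Hypothetical helper: elementwise binary cross-entropy on raw logits.
    # Stable rewrite of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))].
    return (jnp.maximum(logits, 0.0) - logits * labels
            + jnp.log1p(jnp.exp(-jnp.abs(logits))))

logits = jnp.array([1.0, -1.0, 0.0])
labels = jnp.array([1.0, 0.0, 1.0])
loss = bce_with_logits(logits, labels)
print(loss)         # per-element loss, same shape as logits
print(loss.mean())

# Multilabel: rows are examples, columns are independent binary labels.
ml_logits = jnp.array([[2.0, -1.0, 0.5],
                       [-0.5, 1.5, -2.0]])
ml_labels = jnp.array([[1.0, 0.0, 1.0],
                       [0.0, 1.0, 0.0]])
per_example = bce_with_logits(ml_logits, ml_labels).mean(axis=-1)  # shape [2]
print(per_example)
```

Averaging over the label axis gives one loss per example; averaging again over the batch gives the usual scalar training loss.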
Multiclass (Softmax Cross-Entropy)
Use when classes are mutually exclusive.
```python
# One-hot or probability targets
# logits shape [..., num_classes]
logits = jnp.array([[2.0, 1.0, 0.1]])
targets = jnp.array([[1.0, 0.0, 0.0]])  # one-hot
loss = braintools.metric.softmax_cross_entropy(logits, targets)
print(loss)  # shape [...]
```

```
[0.41702995]
```

```python
# Integer labels (preferred for single-label classification)
logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jnp.array([0])  # class index
loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)
print(loss)
```

```
[0.41702995]
```
Notes
- softmax_cross_entropy expects float targets (one-hot or probabilities); softmax_cross_entropy_with_integer_labels expects integer class indices.
- Shapes: logits [..., C], with integer labels [...] or float targets [..., C].
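Both functions compute the same quantity, -sum_c targets[c] * log_softmax(logits)[c]; the integer variant simply picks out the log-probability of the labeled class. The sketch below checks this equivalence with jax.nn.log_softmax and jax.nn.one_hot. It illustrates the math only and is not braintools' implementation.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1]])  # [batch, C]
labels = jnp.array([0])                # [batch] integer class indices
num_classes = logits.shape[-1]

log_p = jax.nn.log_softmax(logits, axis=-1)

# Float-target form: cross-entropy against a one-hot (or soft) distribution.
targets = jax.nn.one_hot(labels, num_classes)          # [batch, C]
loss_from_targets = -jnp.sum(targets * log_p, axis=-1)  # [batch]

# Integer-label form: gather the log-probability of the labeled class.
loss_from_labels = -jnp.take_along_axis(log_p, labels[:, None], axis=-1)[:, 0]

print(loss_from_targets, loss_from_labels)  # both have shape [batch]
```

The integer form avoids materializing a [batch, C] one-hot array, which is one reason it is preferred for single-label classification with many classes.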
Focal Loss (Imbalanced Data)
Use focal loss to focus learning on hard, misclassified examples in imbalanced settings. For multilabel/binary, use sigmoid_focal_loss.
```python
# Imbalanced binary classification example
logits = jnp.array([2.0, -1.0, 0.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])
# Alpha balances positive vs negative; gamma focuses on hard examples
loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
print(loss)
print(loss.mean())
# Unweighted focal (no class weighting)
loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
print(loss_unweighted)
```

```
[0.00045089 0.01699354 0.01689337 0.00135267]
0.008922619
[0.00180356 0.02265805 0.06757348 0.00180356]
```
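Guidance
- alpha weights positive examples and 1 - alpha weights negatives; alpha=None disables this weighting. alpha=0.25 with gamma=2.0 (as above) is a common starting point, but tune both on validation data.
- Increase gamma (typically 1–5) to emphasize hard examples more; gamma=0 reduces focal loss to (alpha-weighted) binary cross-entropy.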
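The printed values are consistent with the usual focal-loss definition: each element's binary cross-entropy is scaled by (1 - p_t)^gamma, where p_t is the predicted probability of the true label, and, if alpha is given, by alpha for positives and 1 - alpha for negatives. The sketch below writes this out with jax.numpy. focal_loss_sketch is a hypothetical name, and matching numbers do not mean braintools uses this exact code.

```python
import jax
import jax.numpy as jnp

def focal_loss_sketch(logits, labels, alpha=0.25, gamma=2.0):
    # Hypothetical helper following the standard focal-loss definition.
    p = jax.nn.sigmoid(logits)
    # Stable elementwise binary cross-entropy on logits.
    ce = (jnp.maximum(logits, 0.0) - logits * labels
          + jnp.log1p(jnp.exp(-jnp.abs(logits))))
    p_t = p * labels + (1.0 - p) * (1.0 - labels)  # probability of the true label
    loss = ce * (1.0 - p_t) ** gamma                # shrink loss on easy examples
    if alpha is not None:
        # alpha weights positives, (1 - alpha) weights negatives.
        loss = loss * (alpha * labels + (1.0 - alpha) * (1.0 - labels))
    return loss

logits = jnp.array([2.0, -1.0, 0.5, -2.0])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])
weighted = focal_loss_sketch(logits, labels, alpha=0.25, gamma=2.0)
unweighted = focal_loss_sketch(logits, labels, alpha=None, gamma=2.0)
plain_bce = focal_loss_sketch(logits, labels, alpha=None, gamma=0.0)  # gamma=0: plain BCE
print(weighted)
print(unweighted)
print(plain_bce)
```

Comparing unweighted with plain_bce shows the effect of gamma: the confidently correct elements (logits 2.0 and -2.0) shrink by roughly 70x, while the least confident one (logit 0.5) shrinks only about 7x.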
Label Smoothing (Regularization)
Smoothing prevents overconfidence by blending one-hot labels with a uniform distribution. Combine with softmax CE.
```python
# One-hot labels [..., C]
labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])
smoothed = braintools.metric.smooth_labels(labels_one_hot, alpha=0.1)
logits = jnp.array([[2.0, 1.0, 0.5]])
loss = braintools.metric.softmax_cross_entropy(logits, smoothed)
print(loss)
print(smoothed)
```

```
[0.54770213]
[[0.93333334 0.03333334 0.03333334]]
```
Rules of thumb
- alpha (here the smoothing strength, unrelated to the focal-loss alpha) in [0.05, 0.2] is common; larger values smooth more.
- Smoothing often improves calibration and robustness, but too much can slightly lower peak accuracy.
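The smoothed targets printed above match (1 - alpha) * labels + alpha / C for C = 3 classes: 0.9 + 0.1/3 ≈ 0.9333 for the true class and 0.1/3 ≈ 0.0333 elsewhere. A minimal jax.numpy sketch of that formula, combined with cross-entropy against soft targets (smooth_labels_sketch is a hypothetical name, not a braintools function):

```python
import jax
import jax.numpy as jnp

def smooth_labels_sketch(labels, alpha):
    # Hypothetical helper: mix one-hot labels with a uniform distribution over C classes.
    num_classes = labels.shape[-1]
    return (1.0 - alpha) * labels + alpha / num_classes

labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])
smoothed = smooth_labels_sketch(labels_one_hot, alpha=0.1)

logits = jnp.array([[2.0, 1.0, 0.5]])
# Cross-entropy with soft targets: -sum_c q_c * log_softmax(logits)_c
loss = -jnp.sum(smoothed * jax.nn.log_softmax(logits, axis=-1), axis=-1)

print(smoothed)  # each row still sums to 1
print(loss)
```

Because the true class never gets a target of exactly 1, the loss can no longer be driven to zero by pushing one logit to infinity, which is what discourages overconfidence.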
Which Loss Should I Use?
- Binary/multilabel tasks → sigmoid_binary_cross_entropy
- Multiclass, single-label → softmax_cross_entropy_with_integer_labels
- Heavily imbalanced (binary/multilabel) → sigmoid_focal_loss
- Overconfident models → smooth_labels + softmax_cross_entropy
Common Pitfalls
- Do not feed integer labels to softmax_cross_entropy; use softmax_cross_entropy_with_integer_labels instead.
- For multilabel problems, use sigmoid-based losses, not softmax.
- Always match shapes: logits [..., C] with one-hot targets [..., C] or integer labels [...].
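When a pipeline produces integer class indices but you need float targets (for example, to apply label smoothing before softmax_cross_entropy), convert them explicitly and check shapes first. A sketch using jax.nn.one_hot; the shape assertions are illustrative checks, not part of braintools:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, -1.0]])  # [batch=2, C=3]
int_labels = jnp.array([0, 1])          # [batch], integer class indices

# Integer labels must have the logits' shape minus the trailing class axis.
assert int_labels.shape == logits.shape[:-1]

# Convert to float one-hot targets with the same shape and dtype as logits.
targets = jax.nn.one_hot(int_labels, logits.shape[-1], dtype=logits.dtype)
assert targets.shape == logits.shape

print(targets)
```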
Extras: NLL with Log-Probabilities
If you already have log-probabilities, use nll_loss directly.
```python
log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
target = 1
print(braintools.metric.nll_loss(log_probs, target))
```

```
-0.35667497
```
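The printed value equals log 0.7 ≈ -0.357, the log-probability of the target class itself. The conventional negative log-likelihood is its negation, -log p(target) ≈ 0.357, so check the sign convention of nll_loss before minimizing it. A jax.numpy sketch of the conventional definition, which also shows that applying it to jax.nn.log_softmax(logits) gives softmax cross-entropy:

```python
import jax
import jax.numpy as jnp

log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))
target = 1

# Conventional NLL: negative log-probability of the target class.
nll = -log_probs[target]
print(nll)

# Applied to log_softmax(logits), NLL equals softmax cross-entropy.
logits = jnp.array([2.0, 1.0, 0.1])
ce = -jax.nn.log_softmax(logits)[0]
print(ce)
```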
See also
- API reference: braintools.metric → classification functions
- Ranking and regression losses are covered in separate tutorials.