smooth_labels#
- class braintools.metric.smooth_labels(labels, alpha)#
Apply label smoothing regularization to one-hot encoded labels.
Label smoothing is a regularization technique that prevents neural networks from becoming overconfident in their predictions by introducing controlled uncertainty in the training labels. This technique replaces hard targets with a weighted mixture of the original one-hot labels and a uniform distribution over all classes.
The smoothing transformation is defined as:
\[\tilde{y}_k = (1 - \alpha) y_k + \frac{\alpha}{K}\]where \(y_k\) is the original label for class \(k\), \(\alpha\) is the smoothing parameter, \(K\) is the number of classes, and \(\tilde{y}_k\) is the smoothed label.
- Parameters:
labels (
Array|ndarray|bool|number|bool|int|float|complex|Quantity) – One-hot encoded labels with shape(..., num_classes)where the last dimension represents class probabilities. Must be floating-point type. Each row should contain exactly one 1.0 and zeros elsewhere for proper one-hot encoding.alpha (
float) –Smoothing parameter in the range [0, 1] controlling the degree of smoothing:
alpha = 0.0: No smoothing (original hard labels)alpha = 0.1: Light smoothing (common choice)alpha = 1.0: Maximum smoothing (uniform distribution)
Typical values range from 0.05 to 0.2 depending on the task complexity.
- Returns:
Smoothed label distribution with the same shape as input. Each row sums to 1.0 and contains the smoothed probability distribution over classes.
- Return type:
Array
Notes
Label smoothing provides several benefits:
Improved calibration: Reduces overconfident predictions
Better generalization: Acts as regularization to prevent overfitting
Robustness: Less sensitive to label noise and annotation errors
Gradient stability: Provides more stable training dynamics
The technique is particularly effective for:
Image classification with large numbers of classes
Tasks with potential label ambiguity or noise
Training very deep networks prone to overconfidence
Knowledge distillation scenarios
Common usage patterns:
Use with cross-entropy loss for classification
Combine with other regularization techniques (dropout, weight decay)
Tune alpha based on validation performance
Examples
Basic label smoothing for 3-class classification:
>>> import jax.numpy as jnp >>> import braintools as braintools >>> # One-hot labels for 2 samples, 3 classes >>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) >>> smoothed = braintools.metric.smooth_labels(labels, alpha=0.1) >>> print("Original:") >>> print(labels) >>> print("Smoothed:") >>> print(smoothed) [[0.93333334 0.03333333 0.03333333] [0.03333333 0.93333334 0.03333333]]
Effect of different smoothing parameters:
>>> single_label = jnp.array([[1.0, 0.0, 0.0]]) >>> # Light smoothing >>> light = braintools.metric.smooth_labels(single_label, alpha=0.05) >>> print(f"Light smoothing (α=0.05): {light[0]}") >>> # Moderate smoothing >>> moderate = braintools.metric.smooth_labels(single_label, alpha=0.1) >>> print(f"Moderate smoothing (α=0.1): {moderate[0]}") >>> # Heavy smoothing >>> heavy = braintools.metric.smooth_labels(single_label, alpha=0.5) >>> print(f"Heavy smoothing (α=0.5): {heavy[0]}")
Batch processing with different numbers of classes:
>>> # 5-class problem >>> labels_5class = jnp.eye(5) # Identity matrix as one-hot labels >>> smoothed_5class = braintools.metric.smooth_labels(labels_5class, alpha=0.1) >>> print(f"5-class smoothed shape: {smoothed_5class.shape}") >>> print(f"Row sum (should be ~1.0): {jnp.sum(smoothed_5class[0]):.6f}")
Integration with cross-entropy loss:
>>> # Typical usage in training loop >>> logits = jnp.array([[2.0, 1.0, 0.5]]) # Model predictions >>> hard_labels = jnp.array([[1.0, 0.0, 0.0]]) >>> smooth_labels_result = braintools.metric.smooth_labels(hard_labels, alpha=0.1) >>> # Use smooth_labels_result with cross-entropy loss function
Verify probability distribution properties:
>>> smoothed = braintools.metric.smooth_labels(jnp.eye(4), alpha=0.2) >>> print(f"All rows sum to 1: {jnp.allclose(jnp.sum(smoothed, axis=1), 1.0)}") >>> print(f"All values non-negative: {jnp.all(smoothed >= 0)}")
See also
braintools.metric.sigmoid_binary_cross_entropyBinary classification loss
jax.nn.softmax_cross_entropyStandard cross-entropy with smoothed labels
jax.numpy.eyeCreate one-hot encoded labels
References