Tutorial 2: Regression Losses
This tutorial shows how to use core regression losses in braintools.metric and when to choose each:
L1 / MAE: absolute_error, l1_loss
L2 / MSE: squared_error, l2_loss
Robust: huber_loss, log_cosh
Embeddings: cosine_distance (and cosine_similarity)
Most functions accept a reduction argument ('none', 'mean', or 'sum') and an optional axis argument for per-sample aggregation.
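The sketch below shows what the reduction and axis options typically mean. reduce_loss is a hypothetical helper written only for this illustration and is not part of braintools, so check each braintools function for its actual defaults. The code uses NumPy so it runs without JAX; the same calls exist in jax.numpy.

```python
import numpy as np

def reduce_loss(values, reduction="mean", axis=None):
    """Hypothetical helper illustrating 'none' | 'mean' | 'sum' semantics.

    'none' returns the elementwise values unchanged; 'mean' and 'sum'
    aggregate over `axis` (over all elements when axis is None).
    """
    if reduction == "none":
        return values
    if reduction == "mean":
        return np.mean(values, axis=axis)
    if reduction == "sum":
        return np.sum(values, axis=axis)
    raise ValueError(f"unknown reduction: {reduction!r}")

# Elementwise absolute errors for two samples with three features each
errors = np.abs(np.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])
                - np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]]))

print(reduce_loss(errors, "none").shape)     # (2, 3): one value per element
print(reduce_loss(errors, "mean", axis=-1))  # per-sample mean, shape (2,)
print(reduce_loss(errors, "sum"))            # one scalar over all elements, roughly 0.9
```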
import jax.numpy as jnp
import braintools as bt
Setup: sample predictions and targets
We’ll use simple arrays for clarity; in practice these are model outputs and labels.
y_pred = jnp.array([[1.0, 2.0, 3.0],
                    [2.0, 2.5, 2.0]])
y_true = jnp.array([[1.1, 1.9, 3.2],
                    [2.0, 2.0, 2.0]])
# Same as y_pred except the last column, which holds large errors (to test robustness)
y_outlier = jnp.array([[1.0, 2.0, 10.0],
                       [2.0, 2.5, -5.0]])
L1 loss (Mean Absolute Error)
L1 grows linearly with the size of each residual, so a few large errors cannot dominate the loss. Use it when robustness to outliers is important.
# Elementwise absolute error, then mean over last axis (per-sample MAE)
mae_per_sample = bt.metric.absolute_error(y_pred, y_true, axis=-1, reduction='mean')
print('MAE per sample:', mae_per_sample)
# Direct L1 loss API; called with defaults it returns 0.9, the sum of the absolute errors (the mean would be 0.15)
l1 = bt.metric.l1_loss(y_pred, y_true)
print('l1_loss (sum):', l1)
# Outlier comparison
print('MAE with outlier:', bt.metric.absolute_error(y_outlier, y_true, axis=-1, reduction='mean'))
MAE per sample: [0.13333337 0.16666667]
l1_loss (sum): 0.9000001
MAE with outlier: [2.3333335 2.5 ]
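The printed values can be reproduced with plain array arithmetic. This also confirms that the 0.9 reported by l1_loss equals the sum of the six absolute errors (0.1 + 0.1 + 0.2 + 0 + 0.5 + 0), not their mean. The sketch is a hand-written NumPy reference, not braintools' implementation, and mae is a hypothetical helper name.

```python
import numpy as np

y_pred = np.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])
y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])
y_outlier = np.array([[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]])

def mae(pred, target, axis=None):
    # Mean of |pred - target|; axis=-1 gives one value per sample
    return np.mean(np.abs(pred - target), axis=axis)

print(mae(y_pred, y_true, axis=-1))      # about [0.1333, 0.1667], as printed above
print(np.sum(np.abs(y_pred - y_true)))   # about 0.9, the value l1_loss printed
print(mae(y_outlier, y_true, axis=-1))   # about [2.3333, 2.5]
```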
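L2 loss (Mean Squared Error)
L2 squares each residual, so larger errors are penalized disproportionately: the outlier residual of 7.0 contributes 49.0, while a residual of 0.1 contributes only 0.01. Use L2 when large errors should be penalized heavily.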
# Squared error mean over last axis (per-sample MSE)
mse_per_sample = bt.metric.squared_error(y_pred, y_true, axis=-1, reduction='mean')
print('MSE per sample:', mse_per_sample)
# Direct L2 loss API; it returns elementwise 0.5 * (y_pred - y_true)**2 with no reduction, not a mean
l2 = bt.metric.l2_loss(y_pred, y_true)
print('l2_loss (elementwise):', l2)
# Outlier comparison
print('MSE with outlier:', bt.metric.squared_error(y_outlier, y_true, axis=-1, reduction='mean'))
MSE per sample: [0.02000001 0.08333334]
l2_loss (elementwise): [[0.005 0.005 0.02000001]
[0. 0.125 0. ]]
MSE with outlier: [15.420001 16.416668]
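A short NumPy check reproduces both results. squared_error with reduction='mean' over axis=-1 is the per-sample MSE. The l2_loss output matches 0.5 * (y_pred - y_true)**2 applied elementwise. That second relationship is inferred from the printed values, not taken from the braintools source.

```python
import numpy as np

y_pred = np.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])
y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])

residual = y_pred - y_true
mse_per_sample = np.mean(residual ** 2, axis=-1)  # about [0.02, 0.0833]
half_sq = 0.5 * residual ** 2                     # elementwise, same values as the l2_loss output

print(mse_per_sample)
print(half_sq)
```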
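Huber loss (robust L2)
Huber loss is quadratic for small residuals and linear for large ones: 0.5 * r^2 when |r| <= delta, and delta * (|r| - 0.5 * delta) otherwise. It therefore behaves like L2 near zero and like L1 for outliers. delta sets the transition point.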
huber = bt.metric.huber_loss(y_pred, y_true, delta=1.0)
huber_outlier = bt.metric.huber_loss(y_outlier, y_true, delta=1.0)
print('Huber (elementwise):', huber)
print('Huber with outlier (elementwise):', huber_outlier)
Huber (elementwise): [[0.005 0.005 0.02000001]
[0. 0.125 0. ]]
Huber with outlier (elementwise): [[5.000002e-03 5.000002e-03 6.300000e+00]
[0.000000e+00 1.250000e-01 6.500000e+00]]
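Below is a minimal NumPy sketch of the piecewise Huber formula, written by hand for illustration (huber is a hypothetical name, not the braintools function). It reproduces the outlier output above: the residual 6.8 maps to 1.0 * (6.8 - 0.5) = 6.3. It also shows how a smaller delta moves a moderate residual into the linear regime.

```python
import numpy as np

def huber(residual, delta=1.0):
    # Quadratic inside |r| <= delta, linear outside; the two pieces meet at |r| = delta
    abs_r = np.abs(residual)
    quadratic = 0.5 * residual ** 2
    linear = delta * (abs_r - 0.5 * delta)
    return np.where(abs_r <= delta, quadratic, linear)

y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])
y_outlier = np.array([[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]])

print(huber(y_outlier - y_true, delta=1.0))  # about [[0.005, 0.005, 6.3], [0.0, 0.125, 6.5]]

# Residual 0.5: quadratic under delta=1.0 (0.125), linear under delta=0.25 (0.09375)
print(huber(np.array(0.5), delta=1.0), huber(np.array(0.5), delta=0.25))
```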
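log-cosh (smooth robust loss)
log_cosh computes log(cosh(r)) for each residual r. It is approximately r^2 / 2 for small residuals and |r| - log(2) for large ones, so it acts like L2 near zero and like L1 for outliers while remaining smooth everywhere. For example, the residual 6.8 below maps to about 6.107.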
lc = bt.metric.log_cosh(y_pred - y_true)
lc_outlier = bt.metric.log_cosh(y_outlier - y_true)
print('log-cosh (elementwise):', lc)
print('log-cosh with outlier (elementwise):', lc_outlier)
log-cosh (elementwise): [[0.00499171 0.00499171 0.01986814]
[0. 0.12011451 0. ]]
log-cosh with outlier (elementwise): [[4.99171019e-03 4.99171019e-03 6.10685444e+00]
[0.00000000e+00 1.20114505e-01 6.30685377e+00]]
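Evaluating log(cosh(r)) literally overflows for large residuals, because cosh itself overflows (beyond roughly |r| = 89 in float32 and |r| = 710 in float64). A common fix uses the identity log(cosh(r)) = |r| + log1p(exp(-2|r|)) - log(2). The NumPy sketch below implements that identity. log_cosh_stable is a hypothetical name, and whether braintools.metric.log_cosh uses this form internally is an assumption not covered here.

```python
import numpy as np

def log_cosh_stable(r):
    """log(cosh(r)) rewritten as |r| + log1p(exp(-2|r|)) - log(2).

    exp(-2|r|) is at most 1, so nothing overflows even for very large |r|.
    """
    a = np.abs(r)
    return a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0)

# Residuals from this tutorial plus one extreme value
r = np.array([0.1, 0.5, 6.8, 1000.0])

with np.errstate(over="ignore"):
    naive = np.log(np.cosh(r))  # cosh(1000.0) overflows float64, so the last entry is inf

print(naive)
print(log_cosh_stable(r))       # finite everywhere; about 999.307 for r = 1000
```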
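Cosine distance (1 - cosine similarity)
Cosine distance compares the directions of vectors such as embeddings. It ignores vector length (it is scale-invariant) and lies in [0, 2]: 0 for parallel vectors, 1 for orthogonal ones, and 2 for opposite ones.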
# Row-aligned pairs: inputs of shape [..., D] give one distance per row, with shape [...]
v1 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
v2 = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, -1.0]])
cd = bt.metric.cosine_distance(v1, v2, epsilon=1e-8)
print('Cosine distance:', cd)
# cosine_similarity compares the same aligned rows, so cosine_distance == 1 - cosine_similarity.
# braintools.metric.cosine_similarity(X, Y) is also described as offering a pairwise-matrix version; see its docs.
cs_aligned = bt.metric.cosine_similarity(v1, v2)
print('Cosine similarity:', cs_aligned)
Cosine distance: [1. 1. 1.]
Cosine similarity: [0. 0. 0.]
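The following NumPy reference sketch checks the properties stated above: orthogonal pairs give distance 1, scaling leaves the result unchanged, opposite vectors give 2, and epsilon prevents NaN for a zero vector. cosine_similarity_ref and cosine_distance_ref are hypothetical names. Clamping the norm product at epsilon is one common choice; where braintools actually applies epsilon is an assumption.

```python
import numpy as np

def cosine_similarity_ref(x, y, epsilon=1e-8):
    # Row-aligned cosine similarity: dot product of matching rows over the product of their norms.
    # Clamping the denominator at epsilon keeps zero vectors from producing NaN (an assumed choice).
    dot = np.sum(x * y, axis=-1)
    norms = np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1)
    return dot / np.maximum(norms, epsilon)

def cosine_distance_ref(x, y, epsilon=1e-8):
    # Distance is 1 - similarity, so it lies in [0, 2]
    return 1.0 - cosine_similarity_ref(x, y, epsilon)

v1 = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
v2 = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, -1.0]])

print(cosine_distance_ref(v1, v2))        # [1. 1. 1.]: every pair is orthogonal
print(cosine_distance_ref(3.0 * v1, v2))  # unchanged: scaling does not matter
print(cosine_distance_ref(np.array([1.0, 0.0]), np.array([-1.0, 0.0])))  # 2.0: opposite directions
print(cosine_distance_ref(np.zeros(2), np.array([1.0, 0.0])))            # 1.0: zero vector, no NaN
```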
Guidance
Prefer L1/Huber/log-cosh when outliers are present or robustness is desired.
Use L2/MSE for well-behaved noise where larger errors should be penalized quadratically.
For embeddings, cosine distance normalizes vector lengths implicitly, so features need no rescaling.
Use axis to aggregate per sample (e.g., axis=-1) and set reduction explicitly when needed.
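The derivative of each loss with respect to a single residual r shows why the robustness guidance holds. The gradients below are written out by hand from each loss's formula, with L2 taken as r^2 rather than the 0.5 * r^2 returned by l2_loss. They are illustrative and are not computed by braintools.

```python
import numpy as np

r = np.array([0.1, 0.5, 7.0])  # residuals from this tutorial; 7.0 is the outlier

# d(loss)/dr for each loss, derived analytically
grads = {
    "L1 (|r|)": np.sign(r),                     # constant magnitude 1
    "L2 (r^2)": 2.0 * r,                        # grows linearly with r
    "Huber (delta=1)": np.clip(r, -1.0, 1.0),   # r inside delta, capped at delta outside
    "log-cosh": np.tanh(r),                     # smooth, approaches 1 for large r
}
for name, g in grads.items():
    print(f"{name:16s} {g}")
```

Under L2, the outlier's gradient (14.0) is 70 times that of the 0.1 residual (0.2), so it dominates each update. Under L1, Huber, and log-cosh it stays at or below 1.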
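Pitfalls
Ensure predictions and targets have the same shape for arithmetic losses; mismatched shapes can broadcast silently instead of raising an error.
For cosine metrics, avoid zero vectors or set a small epsilon.
Be explicit about reduction to avoid surprises, because defaults differ among functions: in the examples above, l1_loss returned a summed scalar, while l2_loss, huber_loss, and log_cosh returned elementwise values.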
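The shape pitfall is easy to hit because NumPy and JAX broadcast silently. A (3, 1) prediction minus a (3,) target yields a (3, 3) residual matrix, and the loss is computed without any error. The minimal NumPy illustration below uses made-up arrays; jax.numpy broadcasts the same way.

```python
import numpy as np

y_pred = np.array([[1.0], [2.0], [3.0]])  # shape (3, 1), e.g. a model head with one output unit
y_true = np.array([1.0, 2.0, 3.0])        # shape (3,)

# Broadcasting silently produces a (3, 3) residual matrix instead of raising an error
bad = y_pred - y_true
print(bad.shape)             # (3, 3)
print(np.mean(bad ** 2))     # nonzero (about 1.333) even though every prediction is exact

# Fix: make the shapes match before computing the loss
good = y_pred.squeeze(-1) - y_true
print(good.shape)            # (3,)
print(np.mean(good ** 2))    # 0.0
```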