Tutorial 2: Regression Losses

This tutorial shows how to use core regression losses in braintools.metric and when to choose each:

  • L1 / MAE: absolute_error, l1_loss

  • L2 / MSE: squared_error, l2_loss

  • Robust: huber_loss, log_cosh

  • Embeddings: cosine_distance (and cosine_similarity)

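Most functions accept a reduction argument ('none', 'mean', or 'sum') and an optional axis for per-sample aggregation. Defaults are not uniform across functions, so check the shape of what each call returns.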

import jax.numpy as jnp
import braintools as bt
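To make the reduction options concrete, here is a minimal sketch of what 'none', 'mean', and 'sum' conventionally mean, written in plain NumPy. The helper reduce_loss is hypothetical and is not part of braintools; it only illustrates the semantics assumed in the rest of this tutorial.

```python
import numpy as np

def reduce_loss(values, reduction="none", axis=None):
    """Hypothetical helper (not a braintools function) mimicking common reductions."""
    if reduction == "none":
        return values                      # keep elementwise values
    if reduction == "mean":
        return values.mean(axis=axis)      # average over `axis` (all axes if None)
    if reduction == "sum":
        return values.sum(axis=axis)       # total over `axis` (all axes if None)
    raise ValueError(f"unknown reduction: {reduction!r}")

# Elementwise losses for a batch of 2 samples with 3 outputs each.
values = np.arange(6.0).reshape(2, 3)      # [[0, 1, 2], [3, 4, 5]]

per_element = reduce_loss(values, "none")          # shape (2, 3)
per_sample = reduce_loss(values, "mean", axis=-1)  # shape (2,): [1.0, 4.0]
batch_total = reduce_loss(values, "sum")           # scalar: 15.0
```

With axis=-1 the feature dimension is collapsed and one value per sample remains; with axis=None everything is collapsed to a scalar.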

Setup: sample predictions and targets

We’ll use simple arrays for clarity; in practice these are model outputs and labels.

y_pred = jnp.array([[1.0, 2.0, 3.0],
                     [2.0, 2.5, 2.0]])
y_true = jnp.array([[1.1, 1.9, 3.2],
                     [2.0, 2.0, 2.0]])
y_outlier = jnp.array([[1.0, 2.0, 10.0],
                       [2.0, 2.5, -5.0]])  # to show robustness

L1 loss (Mean Absolute Error)

Use L1 when robustness to outliers matters: the penalty grows only linearly with the residual, so a single large error does not dominate the loss.

# Elementwise absolute error, then mean over last axis (per-sample MAE)
mae_per_sample = bt.metric.absolute_error(y_pred, y_true, axis=-1, reduction='mean')
print('MAE per sample:', mae_per_sample)

# Direct L1 loss API (with its defaults, the value printed below is the summed absolute error, not a mean)
l1 = bt.metric.l1_loss(y_pred, y_true)
print('l1_loss (default reduction):', l1)

# Outlier comparison
print('MAE with outlier:', bt.metric.absolute_error(y_outlier, y_true, axis=-1, reduction='mean'))
MAE per sample: [0.13333337 0.16666667]
l1_loss (default reduction): 0.9000001
MAE with outlier: [2.3333335 2.5      ]
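The printed values can be cross-checked with a few lines of plain NumPy. This is a reference sketch, independent of braintools. It also shows that the l1_loss value of about 0.9 equals the sum of the absolute errors; that is an inference from the printed number, not a statement about the documented default.

```python
import numpy as np

y_pred = np.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])
y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])
y_outlier = np.array([[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]])

abs_err = np.abs(y_pred - y_true)        # elementwise |pred - target|
mae_per_sample = abs_err.mean(axis=-1)   # approx. [0.1333, 0.1667]
total_abs_err = abs_err.sum()            # approx. 0.9 (0.4 + 0.5)

# One outlier per row adds roughly 6.8 and 7.0 to the row sums,
# so the per-sample MAE rises to approx. [2.3333, 2.5].
mae_outlier = np.abs(y_outlier - y_true).mean(axis=-1)
```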

L2 loss (Mean Squared Error)

Use L2 when larger errors should be penalized more heavily.

# Squared error mean over last axis (per-sample MSE)
mse_per_sample = bt.metric.squared_error(y_pred, y_true, axis=-1, reduction='mean')
print('MSE per sample:', mse_per_sample)

# Direct L2 loss API (with its defaults it returns elementwise 0.5 * squared error, as the output shows)
l2 = bt.metric.l2_loss(y_pred, y_true)
print('l2_loss (elementwise):', l2)

# Outlier comparison
print('MSE with outlier:', bt.metric.squared_error(y_outlier, y_true, axis=-1, reduction='mean'))
MSE per sample: [0.02000001 0.08333334]
l2_loss (elementwise): [[0.005      0.005      0.02000001]
 [0.         0.125      0.        ]]
MSE with outlier: [15.420001 16.416668]
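As a sanity check, the following NumPy sketch reproduces the squared-error numbers by hand. It assumes, based only on the printed l2_loss array, that the unreduced value is 0.5 * (pred - target)**2 per element; if a scalar is wanted, reduce it explicitly.

```python
import numpy as np

y_pred = np.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])
y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])

sq_err = (y_pred - y_true) ** 2          # elementwise squared error
mse_per_sample = sq_err.mean(axis=-1)    # approx. [0.02, 0.0833]

# Halved squared error, elementwise; matches the array printed for l2_loss.
half_sq_err = 0.5 * sq_err               # approx. [[0.005, 0.005, 0.02], [0, 0.125, 0]]

# Explicit scalar reduction when a single training loss is needed.
scalar_loss = half_sq_err.mean()
```

Squaring is what makes L2 outlier-sensitive: a residual of 7 contributes 49 to the sum, versus 7 under L1.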

Huber loss (robust L2)

Huber behaves like L2 near zero and L1 for large residuals; set delta to tune the transition.

huber = bt.metric.huber_loss(y_pred, y_true, delta=1.0)
huber_outlier = bt.metric.huber_loss(y_outlier, y_true, delta=1.0)
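# Called with these arguments, huber_loss returns elementwise values (no reduction is applied)
print('Huber (elementwise):', huber)
print('Huber with outlier (elementwise):', huber_outlier)
Huber (elementwise): [[0.005      0.005      0.02000001]
 [0.         0.125      0.        ]]
Huber with outlier (elementwise): [[5.000002e-03 5.000002e-03 6.300000e+00]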
 [0.000000e+00 1.250000e-01 6.500000e+00]]
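The piecewise definition behind these numbers can be written out directly. The function huber_reference below is a hypothetical NumPy reference using the standard Huber formula (quadratic for |r| <= delta, linear beyond); it is a sketch for checking values, not the library's implementation.

```python
import numpy as np

def huber_reference(pred, target, delta=1.0):
    """Standard Huber loss per element (reference sketch, not braintools code)."""
    r = np.abs(pred - target)
    quadratic = 0.5 * r ** 2               # used where |r| <= delta
    linear = delta * (r - 0.5 * delta)     # used where |r| > delta
    return np.where(r <= delta, quadratic, linear)

y_true = np.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])
y_outlier = np.array([[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]])

elementwise = huber_reference(y_outlier, y_true, delta=1.0)
# Residuals 6.8 and 7.0 fall in the linear branch: 6.8 - 0.5 = 6.3 and 7.0 - 0.5 = 6.5.
scalar_loss = elementwise.mean()           # explicit reduction to a scalar
```

A smaller delta moves more residuals into the linear branch (closer to L1); a larger delta keeps more of them quadratic (closer to L2).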

log-cosh (smooth robust loss)

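log_cosh is a smooth loss that behaves like x^2 / 2 for small residuals and like |x| - log(2) for large ones, so it is less sensitive to outliers than L2 while staying differentiable everywhere.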

lc = bt.metric.log_cosh(y_pred - y_true)
lc_outlier = bt.metric.log_cosh(y_outlier - y_true)
# log_cosh returns elementwise values here (no reduction is applied)
print('log-cosh (elementwise):', lc)
print('log-cosh with outlier (elementwise):', lc_outlier)
log-cosh (elementwise): [[0.00499171 0.00499171 0.01986814]
 [0.         0.12011451 0.        ]]
log-cosh with outlier (elementwise): [[4.99171019e-03 4.99171019e-03 6.10685444e+00]
 [0.00000000e+00 1.20114505e-01 6.30685377e+00]]
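A naive log(cosh(x)) overflows for large |x| because cosh grows exponentially. The sketch below uses the identity log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2), which is numerically stable. The function log_cosh_reference is a hypothetical NumPy helper; whether braintools uses this exact formulation internally is not assumed here.

```python
import numpy as np

def log_cosh_reference(x):
    """Numerically stable log(cosh(x)) per element (reference sketch)."""
    a = np.abs(x)
    # exp(-2a) lies in (0, 1], so log1p never overflows.
    return a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0)

residuals = np.array([0.1, 0.5, 6.8, 1000.0])
values = log_cosh_reference(residuals)

# Small residuals: close to x**2 / 2 (quadratic regime).
small_approx = 0.5 * residuals[0] ** 2
# Large residuals: close to |x| - log(2) (linear regime).
large_approx = residuals[2] - np.log(2.0)
```

The value for a residual of 1000 stays finite, whereas evaluating cosh(1000) directly would overflow in floating point.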

Cosine distance (1 - cosine similarity)

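Use cosine distance to compare the directions of vectors such as embeddings. It is invariant to positive rescaling of either vector and bounded in [0, 2]: 0 for the same direction, 1 for orthogonal vectors, 2 for opposite directions.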

# Pairwise aligned vectors [..., D] -> [...]
v1 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
v2 = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, -1.0]])
cd = bt.metric.cosine_distance(v1, v2, epsilon=1e-8)
print('Cosine distance:', cd)

# Also available: cosine_similarity for aligned vectors, and a pairwise matrix version, braintools.metric.cosine_similarity(X, Y)
cs_aligned = bt.metric.cosine_similarity(v1, v2)
print('Cosine similarity:', cs_aligned)
Cosine distance: [1. 1. 1.]
Cosine similarity: [0. 0. 0.]
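The scale invariance and the role of epsilon are easiest to see in a hand-written version. The function cosine_distance_reference below is a hypothetical NumPy sketch; in particular, clamping the norm product with epsilon is an assumption about one common way to guard against zero vectors, not a description of how braintools applies its epsilon argument.

```python
import numpy as np

def cosine_distance_reference(a, b, epsilon=1e-8):
    """1 - cosine similarity over the last axis (reference sketch)."""
    dot = np.sum(a * b, axis=-1)
    norms = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    # Assumed guard: clamp the denominator so zero vectors do not give NaN.
    return 1.0 - dot / np.maximum(norms, epsilon)

v = np.array([[1.0, 2.0]])

same_direction = cosine_distance_reference(v, 10.0 * v)      # approx. 0: scale is ignored
opposite = cosine_distance_reference(v, -v)                  # approx. 2: upper bound
with_zero = cosine_distance_reference(v, np.zeros_like(v))   # 1.0 rather than NaN
```

Because only direction matters, multiplying an embedding by a positive constant leaves the distance unchanged.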

Guidance

  • Prefer L1/Huber/log-cosh when outliers are present or robustness is desired.

  • Use L2/MSE for well-behaved noise where larger errors should be penalized quadratically.

  • For embeddings, use cosine distance: it normalizes vector lengths implicitly, so features do not need to be rescaled first.

  • Use axis to aggregate per-sample (e.g., axis=-1) and set reduction explicitly when needed.
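One way to see why the robust losses behave differently is to compare their gradients with respect to the residual r. The sketch below uses the analytic derivatives of the standard definitions (squared error r**2, absolute error |r|, Huber with delta = 1, and log-cosh); it is an illustration in plain NumPy, not output from braintools.

```python
import numpy as np

residuals = np.array([0.1, 1.0, 7.0])
delta = 1.0

grad_l2 = 2.0 * residuals                       # [0.2, 2.0, 14.0]: grows with the residual
grad_l1 = np.sign(residuals)                    # [1.0, 1.0, 1.0]: constant magnitude
grad_huber = np.clip(residuals, -delta, delta)  # [0.1, 1.0, 1.0]: clipped at delta
grad_log_cosh = np.tanh(residuals)              # smooth, saturates towards 1
```

An outlier with residual 7 pulls on the parameters 14 times harder than a unit-gradient loss would under L2, while L1, Huber, and log-cosh all cap its influence at about 1.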

Pitfalls

  • Ensure predictions and targets have the same shape for arithmetic losses.

  • For cosine metrics, avoid zero vectors or set a small epsilon.

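  • Be explicit about reduction, because defaults differ among functions: in the outputs above, l1_loss returned a summed scalar, while l2_loss, huber_loss, and log_cosh returned elementwise arrays.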
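The shape pitfall deserves a concrete case, because broadcasting hides it: subtracting a (3, 1) array from a (3,) array silently yields a (3, 3) matrix. The guard check_same_shape below is a hypothetical helper written for this illustration, not a braintools function.

```python
import numpy as np

def check_same_shape(pred, target):
    """Hypothetical guard: fail loudly instead of silently broadcasting."""
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")

pred = np.array([1.0, 2.0, 3.0])             # shape (3,)
target = np.array([[1.0], [2.0], [3.0]])     # shape (3, 1), same values

# Without a guard, the difference broadcasts to all 3 x 3 pairs.
silent = np.abs(pred - target)               # shape (3, 3)
wrong_mae = silent.mean()                    # 8/9, although pred equals target

# Align shapes first, then compute the loss.
target_fixed = target.reshape(pred.shape)
check_same_shape(pred, target_fixed)
right_mae = np.abs(pred - target_fixed).mean()   # 0.0
```

Calling check_same_shape(pred, target) on the unaligned arrays raises a ValueError, which is usually preferable to a plausible-looking but wrong loss value.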