{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 2: Regression Losses\n",
    "\n",
    "This tutorial shows how to use core regression losses in `braintools.metric` and when to choose each:\n",
    "\n",
    "- L1 / MAE: `absolute_error`, `l1_loss`\n",
    "- L2 / MSE: `squared_error`, `l2_loss`\n",
    "- Robust: `huber_loss`, `log_cosh`\n",
    "- Embeddings: `cosine_distance` (and `cosine_similarity`)\n",
    "\n",
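    "Many of these functions accept a `reduction` argument (`'none'`, `'mean'`, or `'sum'`) and an optional `axis` for per-sample aggregation. Defaults differ between functions, so it pays to check the shape and value each call actually returns."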
   ],
   "id": "b5d7d3248eb722ba"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.370631Z",
     "start_time": "2025-09-23T12:34:31.399828Z"
    }
   },
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt"
   ],
   "id": "9756ef66e449c927",
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup: sample predictions and targets\n\n",
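    "We use small arrays for clarity; in practice these would be model outputs and labels. `y_outlier` equals `y_pred` except that the last prediction in each row is far from its target, which lets us compare how each loss reacts to outliers."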
   ],
   "id": "6a1dc6839f1b32c7"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.467663Z",
     "start_time": "2025-09-23T12:34:34.376645Z"
    }
   },
   "source": [
    "y_pred = jnp.array([[1.0, 2.0, 3.0],\n",
    "                     [2.0, 2.5, 2.0]])\n",
    "y_true = jnp.array([[1.1, 1.9, 3.2],\n",
    "                     [2.0, 2.0, 2.0]])\n",
    "y_outlier = jnp.array([[1.0, 2.0, 10.0],\n",
    "                       [2.0, 2.5, -5.0]])  # to show robustness"
   ],
   "id": "780df0e7c7c4cfc5",
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## L1 loss (Mean Absolute Error)\n\n",
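    "L1 grows linearly with the size of each residual, so a single large error has far less influence than under L2. Use it when robustness to outliers matters."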
   ],
   "id": "fe3377e71ae614e3"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.576793Z",
     "start_time": "2025-09-23T12:34:34.472669Z"
    }
   },
   "source": [
    "# Elementwise absolute error, then mean over last axis (per-sample MAE)\n",
    "mae_per_sample = bt.metric.absolute_error(y_pred, y_true, axis=-1, reduction='mean')\n",
    "print('MAE per sample:', mae_per_sample)\n\n",
    "# Direct L1 loss API: in this run the default reduction sums the absolute errors\n",
    "# (0.1 + 0.1 + 0.2 + 0.0 + 0.5 + 0.0 = 0.9) rather than averaging them\n",
    "l1 = bt.metric.l1_loss(y_pred, y_true)\n",
    "print('l1_loss (sum):', l1)\n\n",
    "# Outlier comparison\n",
    "print('MAE with outlier:', bt.metric.absolute_error(y_outlier, y_true, axis=-1, reduction='mean'))"
   ],
   "id": "fb6681de8fbd7ebb",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE per sample: [0.13333337 0.16666667]\n",
      "l1_loss (sum): 0.9000001\n",
      "MAE with outlier: [2.3333335 2.5      ]\n"
     ]
    }
   ],
   "execution_count": 3
  },
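  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The numbers above can be reproduced by hand. The cell below is a minimal sketch that recomputes them with plain `jax.numpy` (it repeats the setup arrays so it runs on its own). It assumes, based on the recorded output rather than on the library documentation, that `l1_loss` with default arguments sums all absolute errors: the sum is 0.9, whereas the overall mean would be 0.15."
   ],
   "id": "l1-by-hand-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt\n",
    "\n",
    "# Same values as the setup cell\n",
    "y_pred = jnp.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])\n",
    "y_true = jnp.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])\n",
    "\n",
    "abs_err = jnp.abs(y_pred - y_true)      # elementwise |residual|, shape (2, 3)\n",
    "mae_per_sample = abs_err.mean(axis=-1)  # mean over the last axis: one MAE per row\n",
    "l1_sum = abs_err.sum()                  # sum over all six entries\n",
    "l1_mean = abs_err.mean()                # mean over all six entries\n",
    "\n",
    "assert jnp.allclose(mae_per_sample, jnp.array([0.4 / 3, 0.5 / 3]))\n",
    "assert jnp.allclose(l1_sum, 0.9)\n",
    "assert jnp.allclose(l1_mean, 0.15)\n",
    "\n",
    "# Compare with the library calls used above\n",
    "print('matches absolute_error:', jnp.allclose(bt.metric.absolute_error(y_pred, y_true, axis=-1, reduction='mean'), mae_per_sample))\n",
    "print('l1_loss equals the sum:', jnp.allclose(bt.metric.l1_loss(y_pred, y_true), l1_sum))"
   ],
   "id": "l1-by-hand-code",
   "outputs": [],
   "execution_count": null
  },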
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## L2 loss (Mean Squared Error)\n\n",
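    "L2 squares each residual, so large errors dominate the loss. Use it when larger errors should be penalized disproportionately; in the outlier comparison below, the per-sample MSE jumps to about 15, versus about 2.3 to 2.5 for MAE."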
   ],
   "id": "945326fb062554f8"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.628166Z",
     "start_time": "2025-09-23T12:34:34.588452Z"
    }
   },
   "source": [
    "# Squared error mean over last axis (per-sample MSE)\n",
    "mse_per_sample = bt.metric.squared_error(y_pred, y_true, axis=-1, reduction='mean')\n",
    "print('MSE per sample:', mse_per_sample)\n\n",
    "# Direct L2 loss API: in this run it returns 0.5 * (y_pred - y_true)**2 per element, unreduced\n",
    "# (e.g. 0.5 * 0.5**2 = 0.125); call .mean() or .sum() on it if you need a scalar\n",
    "l2 = bt.metric.l2_loss(y_pred, y_true)\n",
    "print('l2_loss (elementwise):', l2)\n\n",
    "# Outlier comparison\n",
    "print('MSE with outlier:', bt.metric.squared_error(y_outlier, y_true, axis=-1, reduction='mean'))"
   ],
   "id": "d0f12ef0dac2600d",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE per sample: [0.02000001 0.08333334]\n",
      "l2_loss (elementwise): [[0.005      0.005      0.02000001]\n",
      " [0.         0.125      0.        ]]\n",
      "MSE with outlier: [15.420001 16.416668]\n"
     ]
    }
   ],
   "execution_count": 4
  },
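  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check, `squared_error` with `reduction='mean'` over `axis=-1` is the per-row mean of squared residuals. The elementwise `l2_loss` values recorded above equal half of each squared residual (for example 0.5 * 0.5**2 = 0.125), which suggests that, in this version, `l2_loss` uses the 1/2 convention and applies no reduction. The sketch below checks that reading against the library rather than relying on its documentation."
   ],
   "id": "l2-by-hand-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt\n",
    "\n",
    "# Same values as the setup cell\n",
    "y_pred = jnp.array([[1.0, 2.0, 3.0], [2.0, 2.5, 2.0]])\n",
    "y_true = jnp.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])\n",
    "\n",
    "sq_err = (y_pred - y_true) ** 2        # elementwise squared residual\n",
    "mse_per_sample = sq_err.mean(axis=-1)  # one MSE per row\n",
    "half_sq = 0.5 * sq_err                 # the 1/2 * r**2 convention\n",
    "\n",
    "assert jnp.allclose(mse_per_sample, jnp.array([0.06 / 3, 0.25 / 3]))\n",
    "assert jnp.allclose(half_sq, jnp.array([[0.005, 0.005, 0.02], [0.0, 0.125, 0.0]]))\n",
    "\n",
    "print('l2_loss equals 0.5 * squared error:', jnp.allclose(bt.metric.l2_loss(y_pred, y_true), half_sq))\n",
    "print('l2_loss reduced to a scalar mean:', bt.metric.l2_loss(y_pred, y_true).mean())"
   ],
   "id": "l2-by-hand-code",
   "outputs": [],
   "execution_count": null
  },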
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Huber loss (robust L2)\n\n",
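    "Huber behaves like L2 near zero and like L1 for large residuals: for a residual $r$ it equals $\\tfrac{1}{2} r^2$ when $|r| \\le \\delta$ and $\\delta\\,(|r| - \\tfrac{1}{2}\\delta)$ otherwise. Set `delta` to choose where the transition happens."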
   ],
   "id": "c363242fabeb3bc5"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.674726Z",
     "start_time": "2025-09-23T12:34:34.634084Z"
    }
   },
   "source": [
    "# huber_loss returns one value per element here (no reduction is applied)\n",
    "huber = bt.metric.huber_loss(y_pred, y_true, delta=1.0)\n",
    "huber_outlier = bt.metric.huber_loss(y_outlier, y_true, delta=1.0)\n",
    "print('Huber (elementwise):', huber)\n",
    "print('Huber with outlier (elementwise):', huber_outlier)"
   ],
   "id": "a54ccf01ad8b1cff",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Huber (elementwise): [[0.005      0.005      0.02000001]\n",
      " [0.         0.125      0.        ]]\n",
      "Huber with outlier (elementwise): [[5.000002e-03 5.000002e-03 6.300000e+00]\n",
      " [0.000000e+00 1.250000e-01 6.500000e+00]]\n"
     ]
    }
   ],
   "execution_count": 5
  },
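  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Huber loss is piecewise: for a residual r it is 0.5 * r**2 when |r| <= delta and delta * (|r| - 0.5 * delta) otherwise, and the two pieces meet at |r| = delta. The sketch below implements that textbook definition in a hypothetical helper, `huber_reference`, and reproduces the outlier values above: with delta = 1, residuals of 6.8 and -7.0 give 6.3 and 6.5 instead of the half-squared values 23.12 and 24.5."
   ],
   "id": "huber-by-hand-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt\n",
    "\n",
    "# Hypothetical helper: the textbook piecewise Huber loss, applied elementwise\n",
    "def huber_reference(residual, delta=1.0):\n",
    "    abs_r = jnp.abs(residual)\n",
    "    quadratic = 0.5 * residual ** 2          # used when |r| <= delta\n",
    "    linear = delta * (abs_r - 0.5 * delta)   # used when |r| > delta\n",
    "    return jnp.where(abs_r <= delta, quadratic, linear)\n",
    "\n",
    "# Same values as the setup cell\n",
    "y_outlier = jnp.array([[1.0, 2.0, 10.0], [2.0, 2.5, -5.0]])\n",
    "y_true = jnp.array([[1.1, 1.9, 3.2], [2.0, 2.0, 2.0]])\n",
    "\n",
    "h = huber_reference(y_outlier - y_true, delta=1.0)\n",
    "assert jnp.allclose(h, jnp.array([[0.005, 0.005, 6.3], [0.0, 0.125, 6.5]]))\n",
    "print('matches huber_loss:', jnp.allclose(bt.metric.huber_loss(y_outlier, y_true, delta=1.0), h))"
   ],
   "id": "huber-by-hand-code",
   "outputs": [],
   "execution_count": null
  },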
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## log-cosh (smooth robust loss)\n\n",
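    "`log_cosh` computes $\\log\\cosh(r)$ of a residual $r$ (here passed as `y_pred - y_true`). It is about $r^2/2$ for small residuals and about $|r| - \\log 2$ for large ones, so it acts like a smooth (infinitely differentiable) counterpart of Huber and is less sensitive to outliers than L2."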
   ],
   "id": "4c78f79daecd3faf"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:34.795598Z",
     "start_time": "2025-09-23T12:34:34.679736Z"
    }
   },
   "source": [
    "# log_cosh is applied to the residual and returns one value per element here\n",
    "lc = bt.metric.log_cosh(y_pred - y_true)\n",
    "lc_outlier = bt.metric.log_cosh(y_outlier - y_true)\n",
    "print('log-cosh (elementwise):', lc)\n",
    "print('log-cosh with outlier (elementwise):', lc_outlier)"
   ],
   "id": "fa98f097d8e079d8",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "log-cosh (elementwise): [[0.00499171 0.00499171 0.01986814]\n",
      " [0.         0.12011451 0.        ]]\n",
      "log-cosh with outlier (elementwise): [[4.99171019e-03 4.99171019e-03 6.10685444e+00]\n",
      " [0.00000000e+00 1.20114505e-01 6.30685377e+00]]\n"
     ]
    }
   ],
   "execution_count": 6
  },
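  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two limits explain the numbers above: log(cosh(r)) is close to r**2 / 2 for small residuals (0.1 gives about 0.005) and close to |r| - log 2 for large ones (6.8 gives about 6.107). Evaluating log(cosh(r)) literally can overflow in float32 because cosh grows like exp(|r|); a common stable rewrite is logaddexp(r, -r) - log 2. The sketch below uses that rewrite in a hypothetical helper, `log_cosh_reference`; it does not assume braintools computes the loss the same way internally."
   ],
   "id": "logcosh-by-hand-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical helper: log(cosh(x)) = log(exp(x) + exp(-x)) - log(2), computed stably\n",
    "def log_cosh_reference(x):\n",
    "    return jnp.logaddexp(x, -x) - jnp.log(2.0)\n",
    "\n",
    "small = jnp.array([0.1, 0.2])\n",
    "large = jnp.array([6.8, -7.0])\n",
    "\n",
    "# Small residuals: close to x**2 / 2 (L2-like)\n",
    "assert jnp.allclose(log_cosh_reference(small), 0.5 * small ** 2, atol=1e-3)\n",
    "# Large residuals: close to |x| - log(2) (L1-like)\n",
    "assert jnp.allclose(log_cosh_reference(large), jnp.abs(large) - jnp.log(2.0), atol=1e-4)\n",
    "\n",
    "# The literal formula overflows in float32, while the rewrite stays finite\n",
    "x_big = jnp.array([100.0], dtype=jnp.float32)\n",
    "print('naive:', jnp.log(jnp.cosh(x_big)))\n",
    "print('stable:', log_cosh_reference(x_big))"
   ],
   "id": "logcosh-by-hand-code",
   "outputs": [],
   "execution_count": null
  },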
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cosine distance (1 - cosine similarity)\n\n",
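    "Use it to compare the directions of vectors such as embeddings. It is unchanged when either vector is multiplied by a positive constant, and it is bounded in $[0, 2]$: 0 for parallel vectors, 1 for orthogonal ones, and 2 for opposite ones."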
   ],
   "id": "2ff055e770c375eb"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:34:35.069001Z",
     "start_time": "2025-09-23T12:34:34.801606Z"
    }
   },
   "source": [
    "# Row-aligned vector pairs: inputs of shape [..., D] give one distance per pair, shape [...]\n",
    "# Each row of v1 is orthogonal to the matching row of v2, so every distance is 1\n",
    "v1 = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])\n",
    "v2 = jnp.array([[0.0, 1.0], [1.0, 0.0], [1.0, -1.0]])\n",
    "cd = bt.metric.cosine_distance(v1, v2, epsilon=1e-8)\n",
    "print('Cosine distance:', cd)\n",
    "\n",
    "# cosine_similarity on the same aligned pairs; cosine_distance = 1 - cosine_similarity\n",
    "cs_aligned = bt.metric.cosine_similarity(v1, v2)\n",
    "print('Cosine similarity:', cs_aligned)"
   ],
   "id": "6efa751f309814e6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cosine distance: [1. 1. 1.]\n",
      "Cosine similarity: [0. 0. 0.]\n"
     ]
    }
   ],
   "execution_count": 7
  },
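  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the example above every pair is orthogonal, so all distances are 1. The sketch below uses a parallel pair and an opposite pair to show both ends of the $[0, 2]$ range, and checks scale invariance. `cosine_similarity_reference` is a hypothetical helper that guards the denominator with max(norm product, eps); the exact epsilon handling in braintools may differ, which only matters for vectors with near-zero norm."
   ],
   "id": "cosine-by-hand-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt\n",
    "\n",
    "# Hypothetical helper: row-wise cosine similarity for aligned vectors of shape [..., D]\n",
    "def cosine_similarity_reference(a, b, eps=1e-8):\n",
    "    dot = jnp.sum(a * b, axis=-1)\n",
    "    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)\n",
    "    return dot / jnp.maximum(norms, eps)\n",
    "\n",
    "a = jnp.array([[1.0, 2.0], [3.0, -1.0]])\n",
    "b = jnp.array([[2.0, 4.0], [-3.0, 1.0]])  # row 0 parallel to a, row 1 opposite\n",
    "\n",
    "dist = 1.0 - cosine_similarity_reference(a, b)\n",
    "assert jnp.allclose(dist, jnp.array([0.0, 2.0]), atol=1e-6)\n",
    "\n",
    "# Rescaling a vector by a positive factor leaves the distance unchanged\n",
    "assert jnp.allclose(1.0 - cosine_similarity_reference(10.0 * a, b), dist, atol=1e-6)\n",
    "print('braintools cosine_distance:', bt.metric.cosine_distance(a, b, epsilon=1e-8))"
   ],
   "id": "cosine-by-hand-code",
   "outputs": [],
   "execution_count": null
  },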
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Guidance\n\n",
    "- Prefer L1/Huber/log-cosh when outliers are present or robustness is desired.\n",
    "- Use L2/MSE when the noise is well behaved (few outliers) and larger errors should be penalized quadratically.\n",
    "- For embeddings, cosine distance divides by the vector norms, so there is no need to rescale the vectors first.\n",
    "- Use `axis` to aggregate per-sample (e.g., `axis=-1`) and set `reduction` explicitly when needed."
   ],
   "id": "6c45e986af9fb898"
  },
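  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shape mismatches deserve a concrete check, because losses written with plain `jax.numpy` arithmetic follow broadcasting rules and do not raise an error. In the sketch below, a prediction vector of shape (3,) and a target column of shape (3, 1) broadcast to a (3, 3) residual matrix, so the resulting mean is not the intended per-element loss. Whether a given braintools loss validates shapes itself is not assumed here; checking explicitly is cheap."
   ],
   "id": "shape-check-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "pred = jnp.array([1.0, 2.0, 3.0])          # shape (3,)\n",
    "target = jnp.array([[1.0], [2.0], [3.0]])  # shape (3, 1), e.g. labels stored as a column\n",
    "\n",
    "residual = pred - target                   # broadcasts silently to shape (3, 3)\n",
    "assert residual.shape == (3, 3)\n",
    "print('broadcast MSE (wrong):', jnp.mean(residual ** 2))\n",
    "\n",
    "# Align the shapes explicitly before computing the loss\n",
    "target_aligned = target.reshape(pred.shape)\n",
    "assert pred.shape == target_aligned.shape\n",
    "print('aligned MSE:', jnp.mean((pred - target_aligned) ** 2))"
   ],
   "id": "shape-check-code",
   "outputs": [],
   "execution_count": null
  },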
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pitfalls\n\n",
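    "- Ensure predictions and targets have the same shape for elementwise losses; with `jax.numpy` arithmetic, mismatched shapes can broadcast silently instead of raising an error.\n",
    "- For cosine metrics, avoid zero vectors or set a small `epsilon`.\n",
    "- Be explicit about `reduction`, because defaults differ among functions: in the runs above, `l1_loss` returned a scalar sum, while `l2_loss`, `huber_loss`, and `log_cosh` returned elementwise arrays."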
   ],
   "id": "1ce2b3b70c52bcfd"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
