{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 1: Classification Losses\n",
    "\n",
    "This tutorial introduces the core classification losses in `braintools.metric` and how to pick and use them effectively:\n",
    "\n",
    "- Binary/multilabel: `braintools.metric.sigmoid_binary_cross_entropy`\n",
    "- Multiclass: `braintools.metric.softmax_cross_entropy`, `braintools.metric.softmax_cross_entropy_with_integer_labels`\n",
    "- Imbalanced data: `braintools.metric.sigmoid_focal_loss`\n",
    "- Regularization: `braintools.metric.smooth_labels`\n",
    "\n",
    "All examples operate on JAX arrays (`jax.numpy`). Each section notes the expected shapes and label dtypes, so check them there."
   ],
   "id": "78251f0b2ed7be9c"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:57.842185Z",
     "start_time": "2025-09-23T12:27:54.648202Z"
    }
   },
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools"
   ],
   "id": "9f78fe97b19b2313",
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary / Multilabel (Sigmoid Cross-Entropy)\n\n",
    "Use when each class is an independent yes/no decision (not mutually exclusive)."
   ],
   "id": "7fd55c6e4dc027ac"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.067179Z",
     "start_time": "2025-09-23T12:27:57.847193Z"
    }
   },
   "source": [
    "# Binary/multilabel setup (elementwise BCE)\n",
    "logits = jnp.array([1.0, -1.0, 0.0])        # unnormalized logits\n",
    "labels = jnp.array([1.0,  0.0, 1.0])        # binary targets in {0,1}\n",
    "\n",
    "loss = braintools.metric.sigmoid_binary_cross_entropy(logits, labels)\n",
    "print(loss)          # per-element BCE, same shape as logits\n",
    "print(loss.mean())   # common reduction"
   ],
   "id": "eac8494697d131df",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.3132617 0.3132617 0.6931472]\n",
      "0.4398902\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tips\n",
    "- Labels must be float (0/1) or probabilities, shape-broadcastable with `logits`.\n",
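    "- The same loss covers multilabel problems. With `logits` and `labels` of shape `[..., C]`, each of the `C` classes is scored as an independent binary decision."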
   ],
   "id": "7c3efabbbea46f17"
  },
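  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on what this loss computes, the sketch below re-implements elementwise sigmoid BCE in plain JAX. It uses the numerically stable form `max(x, 0) - x * y + log(1 + exp(-|x|))`. The sketch assumes that `sigmoid_binary_cross_entropy` computes the standard unweighted BCE, which matches the values printed above. `bce_from_logits` is a hypothetical helper and is not part of `braintools`. The second half applies it to a multilabel batch of shape `[batch, C]`."
   ],
   "id": "3b9c1e7a5d2f4a60"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def bce_from_logits(logits, labels):\n",
    "    # Hypothetical helper (not braintools API): elementwise BCE computed from logits.\n",
    "    # max(x, 0) - x * y + log1p(exp(-|x|)) avoids overflow for large |x|.\n",
    "    return jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))\n",
    "\n",
    "# Same inputs as the braintools example above\n",
    "logits = jnp.array([1.0, -1.0, 0.0])\n",
    "labels = jnp.array([1.0, 0.0, 1.0])\n",
    "print(bce_from_logits(logits, labels))  # expected to match the per-element values above\n",
    "\n",
    "# Multilabel batch: 2 examples, 3 independent classes each\n",
    "ml_logits = jnp.array([[2.0, -1.0, 0.5],\n",
    "                       [-0.5, 1.5, -2.0]])\n",
    "ml_labels = jnp.array([[1.0, 0.0, 1.0],\n",
    "                       [0.0, 1.0, 0.0]])\n",
    "per_class = bce_from_logits(ml_logits, ml_labels)  # shape [2, 3]: one loss per class\n",
    "per_example = per_class.mean(axis=-1)              # shape [2]: averaged over classes\n",
    "print(per_class.shape, per_example.shape)"
   ],
   "id": "3b9c1e7a5d2f4a61",
   "outputs": [],
   "execution_count": null
  },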
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiclass (Softmax Cross-Entropy)\n\n",
    "Use when classes are mutually exclusive."
   ],
   "id": "8dffda99007410e0"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.173845Z",
     "start_time": "2025-09-23T12:27:58.075199Z"
    }
   },
   "source": [
    "# One-hot or probability targets\n",
    "# logits shape [..., num_classes]\n",
    "logits = jnp.array([[2.0, 1.0, 0.1]])\n",
    "targets = jnp.array([[1.0, 0.0, 0.0]])  # one-hot\n",
    "\n",
    "loss = braintools.metric.softmax_cross_entropy(logits, targets)\n",
    "print(loss)  # shape [...]"
   ],
   "id": "9206274e4394d29b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.41702995]\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.366477Z",
     "start_time": "2025-09-23T12:27:58.181396Z"
    }
   },
   "source": [
    "# Integer labels (preferred for single-label classification)\n",
    "logits = jnp.array([[2.0, 1.0, 0.1]])\n",
    "labels = jnp.array([0])  # class index\n",
    "\n",
    "loss = braintools.metric.softmax_cross_entropy_with_integer_labels(logits, labels)\n",
    "print(loss)"
   ],
   "id": "eb80a83204afa8cd",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.41702995]\n"
     ]
    }
   ],
   "execution_count": 4
  },
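  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both variants compute the same quantity, `-sum(targets * log_softmax(logits), axis=-1)`, once the integer labels are converted to one-hot targets. The plain-JAX sketch below illustrates this with `jax.nn.log_softmax` and `jax.nn.one_hot`. It assumes that both `braintools` functions implement standard softmax cross-entropy, which is consistent with the identical outputs above."
   ],
   "id": "5e2a7c90d41b3f12"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "logits = jnp.array([[2.0, 1.0, 0.1]])\n",
    "labels = jnp.array([0])  # integer class indices, shape [batch]\n",
    "\n",
    "log_probs = jax.nn.log_softmax(logits, axis=-1)  # shape [batch, C]\n",
    "\n",
    "# One-hot targets: weighted sum of log-probabilities over the class axis\n",
    "targets = jax.nn.one_hot(labels, logits.shape[-1])  # shape [batch, C]\n",
    "loss_dense = -jnp.sum(targets * log_probs, axis=-1)\n",
    "\n",
    "# Integer labels: select the log-probability of the true class directly\n",
    "loss_int = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]\n",
    "\n",
    "print(loss_dense, loss_int)  # both approximately [0.41703]"
   ],
   "id": "5e2a7c90d41b3f13",
   "outputs": [],
   "execution_count": null
  },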
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notes\n",
    "- `softmax_cross_entropy` expects float targets; `*_with_integer_labels` expects integer labels.\n",
    "- Shapes: `logits[..., C]`, `labels[...]` or `targets[..., C]`."
   ],
   "id": "ea2d5a2e956acaf0"
  },
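  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In these shapes, `...` stands for any number of leading batch dimensions. The plain-JAX sketch below checks the conventions on a hypothetical batch of 2 sequences of 4 items over `C = 5` classes, using random placeholder inputs. Integer labels have one axis fewer than the logits, and one-hot targets have the same shape as the logits. The loss removes only the class axis."
   ],
   "id": "8c4d1f6b2a9e7030"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "C = 5\n",
    "key_logits, key_labels = jax.random.split(jax.random.PRNGKey(0))\n",
    "logits = jax.random.normal(key_logits, (2, 4, C))       # logits[..., C]\n",
    "labels = jax.random.randint(key_labels, (2, 4), 0, C)  # labels[...], integer dtype\n",
    "targets = jax.nn.one_hot(labels, C)                    # targets[..., C], float dtype\n",
    "\n",
    "log_probs = jax.nn.log_softmax(logits, axis=-1)\n",
    "loss = -jnp.sum(targets * log_probs, axis=-1)          # class axis reduced: shape [2, 4]\n",
    "\n",
    "print(labels.dtype, targets.dtype)\n",
    "print(loss.shape)"
   ],
   "id": "8c4d1f6b2a9e7031",
   "outputs": [],
   "execution_count": null
  },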
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Focal Loss (Imbalanced Data)\n\n",
    "Focal loss down-weights easy, well-classified examples, so training concentrates on hard, misclassified ones. This helps when one class dominates the data. For binary and multilabel tasks, use `sigmoid_focal_loss`."
   ],
   "id": "6bd1f8c2660d7a40"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.665680Z",
     "start_time": "2025-09-23T12:27:58.373010Z"
    }
   },
   "source": [
    "# Imbalanced binary classification example\n",
    "logits = jnp.array([2.0, -1.0, 0.5, -2.0])\n",
    "labels = jnp.array([1.0, 0.0, 1.0, 0.0])\n",
    "\n",
    "# Alpha balances positive vs negative; gamma focuses on hard examples\n",
    "loss = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)\n",
    "print(loss)\n",
    "print(loss.mean())\n",
    "\n",
    "# Unweighted focal (no class weighting)\n",
    "loss_unweighted = braintools.metric.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)\n",
    "print(loss_unweighted)"
   ],
   "id": "863f3ad4b0c81030",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.00045089 0.01699354 0.01689337 0.00135267]\n",
      "0.008922619\n",
      "[0.00180356 0.02265805 0.06757348 0.00180356]\n"
     ]
    }
   ],
   "execution_count": 5
  },
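  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the standard sigmoid focal loss multiplies the BCE term by `(1 - p_t)**gamma`, where `p_t` is the predicted probability of the true label. It can optionally weight positives by `alpha` and negatives by `1 - alpha`. The plain-JAX sketch below implements that formula with a hypothetical `focal_from_logits` helper; it is not the `braintools` source. It assumes that `sigmoid_focal_loss` follows the same definition, which is consistent with the outputs printed above."
   ],
   "id": "a7f3e0c58b1d6e20"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def focal_from_logits(logits, labels, alpha=None, gamma=2.0):\n",
    "    # Hypothetical helper (not braintools API): elementwise sigmoid focal loss.\n",
    "    p = jax.nn.sigmoid(logits)\n",
    "    # BCE from logits in numerically stable form\n",
    "    ce = jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))\n",
    "    p_t = p * labels + (1 - p) * (1 - labels)  # probability assigned to the true label\n",
    "    loss = ce * (1 - p_t) ** gamma             # shrink the loss of easy examples (p_t near 1)\n",
    "    if alpha is not None:\n",
    "        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)  # alpha for positives, 1 - alpha for negatives\n",
    "        loss = alpha_t * loss\n",
    "    return loss\n",
    "\n",
    "logits = jnp.array([2.0, -1.0, 0.5, -2.0])\n",
    "labels = jnp.array([1.0, 0.0, 1.0, 0.0])\n",
    "print(focal_from_logits(logits, labels, alpha=0.25, gamma=2.0))\n",
    "print(focal_from_logits(logits, labels, alpha=None, gamma=2.0))"
   ],
   "id": "a7f3e0c58b1d6e21",
   "outputs": [],
   "execution_count": null
  },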
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Guidance\n",
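    "- `alpha` weights positive examples and `1 - alpha` weights negatives. Treat it as a tunable hyperparameter rather than the positive-class frequency. The original focal-loss paper (dense object detection) reported that `alpha=0.25` with `gamma=2.0` worked best. That pair is a common starting point, but tune both on validation data.\n",
    "- With `gamma=0` the loss reduces to (alpha-weighted) BCE. Increasing `gamma`, typically from 1 to 5, down-weights easy, well-classified examples more strongly."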
   ],
   "id": "6933e7bcc561e090"
  },
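  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setting `gamma = 0` removes the modulating factor, so the unweighted focal loss reduces to plain BCE. Larger values of `gamma` shrink the loss of confidently correct (easy) examples much faster than that of misclassified (hard) ones. The plain-JAX sketch below demonstrates this with a hypothetical `unweighted_focal` helper that implements the standard formula without `alpha`. The two logits are arbitrary: one easy positive and one misclassified positive."
   ],
   "id": "c2b86d4e19f07a50"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def unweighted_focal(logits, labels, gamma):\n",
    "    # Hypothetical helper (not braintools API): sigmoid focal loss without the alpha term.\n",
    "    ce = jnp.maximum(logits, 0) - logits * labels + jnp.log1p(jnp.exp(-jnp.abs(logits)))\n",
    "    p = jax.nn.sigmoid(logits)\n",
    "    p_t = p * labels + (1 - p) * (1 - labels)\n",
    "    return ce * (1 - p_t) ** gamma\n",
    "\n",
    "labels = jnp.array([1.0, 1.0])\n",
    "logits = jnp.array([3.0, -1.0])  # an easy positive and a hard (misclassified) positive\n",
    "\n",
    "for gamma in [0.0, 1.0, 2.0, 5.0]:\n",
    "    loss = unweighted_focal(logits, labels, gamma)\n",
    "    # The hard-to-easy loss ratio grows as gamma increases\n",
    "    print(gamma, loss, loss[1] / loss[0])"
   ],
   "id": "c2b86d4e19f07a51",
   "outputs": [],
   "execution_count": null
  },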
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label Smoothing (Regularization)\n\n",
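    "Label smoothing discourages overconfident predictions by blending one-hot targets with a uniform distribution over the classes. The smoothed targets are floats, so pass them to `softmax_cross_entropy`, not to the integer-label variant."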
   ],
   "id": "7f78046874fa32a3"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.719649Z",
     "start_time": "2025-09-23T12:27:58.670807Z"
    }
   },
   "source": [
    "# One-hot labels [..., C]\n",
    "labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])\n",
    "smoothed = braintools.metric.smooth_labels(labels_one_hot, alpha=0.1)\n",
    "\n",
    "logits = jnp.array([[2.0, 1.0, 0.5]])\n",
    "loss = braintools.metric.softmax_cross_entropy(logits, smoothed)\n",
    "print(loss)\n",
    "print(smoothed)"
   ],
   "id": "6844628a0a1e66a3",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.54770213]\n",
      "[[0.93333334 0.03333334 0.03333334]]\n"
     ]
    }
   ],
   "execution_count": 6
  },
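  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printed targets match `(1 - alpha) * labels + alpha / C`. With `alpha = 0.1` and `C = 3`, the true class gets `0.9 + 0.1/3 ≈ 0.9333` and every other class gets `0.1/3 ≈ 0.0333`. The plain-JAX sketch below reproduces the smoothed targets and the loss with a hypothetical `smooth` helper. It assumes that `smooth_labels` uses this standard definition, which agrees with the output above."
   ],
   "id": "e61f0a3c7d28b940"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def smooth(labels, alpha):\n",
    "    # Hypothetical helper (not braintools API): mix one-hot targets with a uniform distribution over C classes.\n",
    "    num_classes = labels.shape[-1]\n",
    "    return (1.0 - alpha) * labels + alpha / num_classes\n",
    "\n",
    "labels_one_hot = jnp.array([[1.0, 0.0, 0.0]])\n",
    "smoothed = smooth(labels_one_hot, alpha=0.1)\n",
    "print(smoothed)\n",
    "print(smoothed.sum(axis=-1))  # each row still sums to 1\n",
    "\n",
    "# Softmax cross-entropy against the smoothed (float) targets\n",
    "logits = jnp.array([[2.0, 1.0, 0.5]])\n",
    "loss = -jnp.sum(smoothed * jax.nn.log_softmax(logits, axis=-1), axis=-1)\n",
    "print(loss)"
   ],
   "id": "e61f0a3c7d28b941",
   "outputs": [],
   "execution_count": null
  },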
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rules of thumb\n",
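    "- Values of `alpha` between 0.05 and 0.2 are common, and 0.1 is a frequent default. Larger values smooth more.\n",
    "- Smoothing often improves calibration. Too much smoothing can reduce accuracy because the targets stop rewarding confident, correct predictions."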
   ],
   "id": "a388b597f975280c"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Which Loss Should I Use?\n\n",
    "- Binary/multilabel tasks → `sigmoid_binary_cross_entropy`\n",
    "- Multiclass single-label → `softmax_cross_entropy_with_integer_labels` (or `softmax_cross_entropy` for one-hot or soft targets)\n",
    "- Heavily imbalanced (binary/multilabel) → `sigmoid_focal_loss`\n",
    "- Overconfident multiclass models → `smooth_labels` + `softmax_cross_entropy`"
   ],
   "id": "570b88ca390d19ff"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Common Pitfalls\n\n",
    "- Do not feed integer labels to `softmax_cross_entropy` (use the integer-label variant).\n",
    "- For multilabel problems, use sigmoid-based losses (not softmax).\n",
    "- Always match shapes: `logits[..., C]` with one-hot targets `[..., C]` or integer `[...]`."
   ],
   "id": "24798e008625e0d6"
  },
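  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Softmax is the wrong choice for multilabel targets because softmax probabilities always sum to 1 across classes. Two classes therefore cannot both receive high probability, whereas sigmoid scores each class independently. A minimal plain-JAX illustration, using arbitrary logits for an example where classes 0 and 1 are both present:"
   ],
   "id": "f0d93b2a6c47e150"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "logits = jnp.array([3.0, 3.0, -3.0])  # model is confident that classes 0 and 1 are present\n",
    "\n",
    "p_softmax = jax.nn.softmax(logits)  # coupled: probabilities must sum to 1\n",
    "p_sigmoid = jax.nn.sigmoid(logits)  # independent per-class probabilities\n",
    "\n",
    "print(p_softmax, p_softmax.sum())  # classes 0 and 1 each stay below 0.5\n",
    "print(p_sigmoid)                   # classes 0 and 1 each near 0.95"
   ],
   "id": "f0d93b2a6c47e151",
   "outputs": [],
   "execution_count": null
  },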
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extras: NLL with Log-Probabilities\n\n",
    "If you already have log-probabilities (for example, from `jax.nn.log_softmax`), pass them to `nll_loss` directly instead of recomputing them from logits."
   ],
   "id": "2ee2d98a0645b6ac"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:27:58.832500Z",
     "start_time": "2025-09-23T12:27:58.725655Z"
    }
   },
   "source": [
    "log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))\n",
    "target = 1\n",
    "print(braintools.metric.nll_loss(log_probs, target))"
   ],
   "id": "9ee09e0814b9fb0f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.35667497\n"
     ]
    }
   ],
   "execution_count": 7
  },
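  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The negative log-likelihood of the target class is `-log p[target]`, which is non-negative for valid probabilities. For the example above it is `-log(0.7) ≈ 0.357`. The printed value has the opposite sign. Before relying on `nll_loss`, check its docstring for the sign convention and the expected input shapes, for example whether it expects batched `[N, C]` log-probabilities with `[N]` targets. The plain-JAX sketch below computes the expected values directly as a cross-check; it does not call `braintools`."
   ],
   "id": "1a5e8d7c3b0f9260"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "log_probs = jnp.log(jnp.array([0.1, 0.7, 0.2]))\n",
    "target = 1\n",
    "\n",
    "# NLL of the target class: negate the selected log-probability\n",
    "nll = -log_probs[target]\n",
    "print(nll)  # approximately 0.35667\n",
    "\n",
    "# Batched form: log-probabilities [N, C], integer targets [N]\n",
    "batch_log_probs = jnp.log(jnp.array([[0.1, 0.7, 0.2],\n",
    "                                     [0.5, 0.25, 0.25]]))\n",
    "batch_targets = jnp.array([1, 0])\n",
    "batch_nll = -jnp.take_along_axis(batch_log_probs, batch_targets[:, None], axis=-1)[:, 0]\n",
    "print(batch_nll, batch_nll.mean())"
   ],
   "id": "1a5e8d7c3b0f9261",
   "outputs": [],
   "execution_count": null
  },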
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n\n",
    "See also\n",
    "- API reference: `braintools.metric` → classification functions\n",
    "- Ranking and regression losses are covered in separate tutorials."
   ],
   "id": "47fdb881cf96ceda"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
