{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 3: Ranking for Learning-to-Rank\n",
    "\n",
    "This tutorial shows how to use the listwise ranking loss in `braintools.metric`\n",
    "with masking and reduction options. It is suited for information retrieval,\n",
    "recommendation, and other Learning-to-Rank tasks.\n",
    "\n",
    "Covered API:\n",
    "- `bt.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=jnp.mean)`\n",
    "  - `where`: boolean mask for valid items (padding handling)\n",
    "  - `weights`: per-item weights\n",
    "  - `reduce_fn`: `jnp.mean`, `jnp.sum`, or `None` (unreduced)\n"
   ],
   "id": "addece12dfdf29ff"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:31:47.505099Z",
     "start_time": "2025-09-23T12:31:44.169242Z"
    }
   },
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt"
   ],
   "id": "d8681da74a65d471",
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1) Basic usage (single list)\n\n",
    "`logits` are scores to be ranked; `labels` are non-negative relevances (e.g.,\n",
    "binary relevance or graded). The loss operates on the last dimension."
   ],
   "id": "a834db8618baa401"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:31:47.772738Z",
     "start_time": "2025-09-23T12:31:47.517110Z"
    }
   },
   "source": [
    "# One list of 4 items\n",
    "logits = jnp.array([2.0, 1.0, 0.5, 0.2])\n",
    "labels = jnp.array([1.0, 0.0, 0.0, 0.0])  # item 0 is most relevant\n",
    "loss = bt.metric.ranking_softmax_loss(logits, labels)\n",
    "print(loss)  # scalar (default reduce_fn=jnp.mean)"
   ],
   "id": "e5b9bf82024adf2a",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5632142\n"
     ]
    }
   ],
   "execution_count": 2
  },
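  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the value above can be reproduced by hand. The cell below is a\n",
    "minimal sketch that assumes the loss is the softmax cross-entropy\n",
    "`-sum(labels * log_softmax(logits))` over one list, which is consistent with the\n",
    "printed result. `softmax_cross_entropy` is a hypothetical helper written for this\n",
    "check, not part of `braintools`; it uses only the standard library."
   ],
   "id": "ref-single-list-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import math\n",
    "\n",
    "def softmax_cross_entropy(scores, relevances):\n",
    "    # Log of the softmax normalizer, shifted by the max score for numerical stability.\n",
    "    top = max(scores)\n",
    "    log_z = top + math.log(sum(math.exp(s - top) for s in scores))\n",
    "    # Cross-entropy: -sum_i y_i * log softmax(s)_i, where log softmax(s)_i = s_i - log_z.\n",
    "    return -sum(y * (s - log_z) for s, y in zip(scores, relevances))\n",
    "\n",
    "ref = softmax_cross_entropy([2.0, 1.0, 0.5, 0.2], [1.0, 0.0, 0.0, 0.0])\n",
    "print(f'{ref:.4f}')  # 0.5632, the library value above rounded to 4 decimals"
   ],
   "id": "ref-single-list-code",
   "outputs": [],
   "execution_count": null
  },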
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2) Batched lists with masks (variable lengths)\n\n",
    "Use `where` to ignore padded items. It must be a boolean array with the same\n",
    "shape as logits and labels."
   ],
   "id": "a94f18b5075e4c1"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:31:48.088249Z",
     "start_time": "2025-09-23T12:31:47.778748Z"
    }
   },
   "source": [
    "# Two lists, padded to length 5\n",
    "logits = jnp.array([[2.0, 1.0, 0.5, -1.0,  0.0],\n",
    "                     [0.8, 0.3, 1.2, -2.0, -1.0]])\n",
    "labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],\n",
    "                     [0.0, 0.0, 1.0, 0.0, 0.0]])\n",
    "# First list has 4 valid items; second has 3 valid items\n",
    "where  = jnp.array([[ True,  True,  True,  True, False],\n",
    "                    [ True,  True,  True, False, False]])\n\n",
    "# Default reduce_fn=jnp.mean -> scalar over batch\n",
    "loss_mean = bt.metric.ranking_softmax_loss(logits, labels, where=where)\n",
    "print('Mean loss (scalar):', loss_mean)\n\n",
    "# Unreduced per-list losses\n",
    "loss_per_list = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)\n",
    "print('Per-list loss:', loss_per_list)  # shape (batch,)"
   ],
   "id": "e54dbdcad18cb2f2",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean loss (scalar): 0.6130267\n",
      "Per-list loss: [0.49518192 0.73087144]\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3) Reductions: sum vs mean vs none\n\n",
    "- `reduce_fn=jnp.sum`: sum across the batch\n",
    "- `reduce_fn=jnp.mean`: average across the batch (default)\n",
    "- `reduce_fn=None`: return unreduced per-batch values\n\n",
    "When there are no valid items (mask all-False) and inputs contain no NaN,\n",
    "the mean reduction returns 0.0 to avoid NaNs."
   ],
   "id": "320c02a0265aa31f"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:31:48.148863Z",
     "start_time": "2025-09-23T12:31:48.100265Z"
    }
   },
   "source": [
    "sum_loss  = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.sum)\n",
    "mean_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.mean)\n",
    "none_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)\n",
    "print('sum:',  sum_loss)\n",
    "print('mean:', mean_loss)\n",
    "print('none:', none_loss, ' sum(none)=', jnp.sum(none_loss))"
   ],
   "id": "38e7667bac16ced7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum: 1.2260534\n",
      "mean: 0.6130267\n",
      "none: [0.49518192 0.73087144]  sum(none)= 1.2260534\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4) Per-item weights\n\n",
    "Provide `weights` to emphasize specific items in lists. `weights` must match\n",
    "the shape of `labels`/`logits` and is applied to the labels prior to the\n",
    "softmax cross-entropy."
   ],
   "id": "aaa395d6eb0f7b57"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:31:48.172047Z",
     "start_time": "2025-09-23T12:31:48.166883Z"
    }
   },
   "source": [
    "weights = jnp.array([[1.0, 0.5, 0.5, 1.0, 0.0],\n",
    "                     [1.0, 1.0, 2.0, 0.0, 0.0]])\n",
    "weighted_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, weights=weights, reduce_fn=None)\n",
    "print('Weighted per-list loss:', weighted_loss)"
   ],
   "id": "bea740301aa1c00c",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weighted per-list loss: [0.49518192 1.4617429 ]\n"
     ]
    }
   ],
   "execution_count": 5
  },
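  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-list values printed in sections 2 and 4 can be reproduced with a small\n",
    "reference computation. This is a sketch under two assumptions inferred from those\n",
    "outputs rather than from the library source: masked items are dropped from the\n",
    "softmax normalizer, and `weights` multiplies the labels. `masked_weighted_loss` is a\n",
    "hypothetical helper, not a `braintools` function."
   ],
   "id": "ref-masked-weighted-md"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import math\n",
    "\n",
    "def masked_weighted_loss(scores, relevances, valid, weights=None):\n",
    "    if weights is None:\n",
    "        weights = [1.0] * len(scores)\n",
    "    # Keep only valid (unpadded) items and scale each label by its weight.\n",
    "    kept = [(s, y * w) for s, y, w, v in zip(scores, relevances, weights, valid) if v]\n",
    "    if not kept:\n",
    "        return 0.0  # assumption: a list with no valid items contributes no loss\n",
    "    top = max(s for s, _ in kept)\n",
    "    log_z = top + math.log(sum(math.exp(s - top) for s, _ in kept))\n",
    "    return -sum(y * (s - log_z) for s, y in kept)\n",
    "\n",
    "# Second list of the batch above: 3 valid items, relevant item at index 2.\n",
    "scores = [0.8, 0.3, 1.2, -2.0, -1.0]\n",
    "relevances = [0.0, 0.0, 1.0, 0.0, 0.0]\n",
    "valid = [True, True, True, False, False]\n",
    "\n",
    "plain = masked_weighted_loss(scores, relevances, valid)\n",
    "weighted = masked_weighted_loss(scores, relevances, valid, weights=[1.0, 1.0, 2.0, 0.0, 0.0])\n",
    "print(f'{plain:.4f} {weighted:.4f}')  # 0.7309 1.4617"
   ],
   "id": "ref-masked-weighted-code",
   "outputs": [],
   "execution_count": null
  },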
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5) Tips & Pitfalls\n\n",
    "- Shapes: operate on the last dimension `(…, list_size)`; batch dims are leading.\n",
    "- `where` must be boolean and broadcastable to `(…, list_size)`.\n",
    "- `weights` must match the labels/logits shape.\n",
    "- `reduce_fn=None` returns per-batch values; you can aggregate manually.\n",
    "- If a list has no valid items (`where` all-False), mean reduction returns 0.0 (when inputs have no NaN).\n",
    "- For large batches and lists, prefer JIT-compiling code paths that call this loss with static shapes."
   ],
   "id": "c4e5abf852e0b2b2"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
