{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 4: Pairwise & Embedding Similarity\n",
    "\n",
    "This tutorial covers two cosine-similarity APIs in `braintools.metric`:\n",
    "\n",
    "- Aligned embeddings:\n",
    "  - `bt.metric.cosine_similarity(predictions, targets)` → per-pair similarity (…,)\n",
    "  - `bt.metric.cosine_distance(predictions, targets)`   → 1 − similarity (…,)\n",
    "- Pairwise similarity matrix:\n",
    "  - `pairwise_cosine_similarity(X, Y=None)` → (n, m) matrix (X vs Y) or (n, n) if Y is None\n",
    "\n",
    "Note: both functions are named `cosine_similarity` internally. In the public namespace,\n",
    "`bt.metric.cosine_similarity` is bound to the aligned version (`predictions, targets`).\n",
    "To use the pairwise (matrix) variant, import it explicitly from the private module\n",
    "`braintools.metric._pariwise` (spelled as in the library), as shown below. Because the\n",
    "module name starts with an underscore, this import path may change between releases."
   ],
   "id": "aa48a33118193195"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:37:54.115651Z",
     "start_time": "2025-09-23T12:37:51.212637Z"
    }
   },
   "source": [
    "import jax.numpy as jnp\n",
    "import braintools as bt\n",
    "# Import the pairwise (matrix) version explicitly and alias it\n",
    "from braintools.metric._pariwise import cosine_similarity as pairwise_cosine_similarity"
   ],
   "id": "8127cc971b34b80f",
   "outputs": [],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1) Aligned embeddings: similarity and distance\n\n",
    "Use these when you have matched pairs `(prediction_i, target_i)` and want one score per pair.\n",
    "Both are scale-invariant: rescaling either vector by a positive factor leaves the score unchanged.\n",
    "Similarity lies in [−1, 1]; distance is `1 − similarity`, so it lies in [0, 2]."
   ],
   "id": "ba21ce096030352a"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:37:54.506892Z",
     "start_time": "2025-09-23T12:37:54.121577Z"
    }
   },
   "source": [
    "pred = jnp.array([[1.0, 0.0, 0.0],\n",
    "                   [0.0, 1.0, 0.0],\n",
    "                   [1.0, 1.0, 0.0]])\n",
    "targ = jnp.array([[1.0, 0.0, 0.0],\n",
    "                   [1.0, 0.0, 0.0],\n",
    "                   [0.0, 1.0, 0.0]])\n",
    "sim = bt.metric.cosine_similarity(pred, targ)\n",
    "dist = bt.metric.cosine_distance(pred, targ)\n",
    "print('similarity:', sim)\n",
    "print('distance  :', dist)"
   ],
   "id": "a885f94a2f22d6c0",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "similarity: [1.         0.         0.70710677]\n",
      "distance  : [0.         1.         0.29289323]\n"
     ]
    }
   ],
   "execution_count": 2
  },
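  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula behind these numbers explicit, here is a plain-JAX sketch of cosine similarity along the last axis,\n",
    "$\\cos(a, b) = \\dfrac{a \\cdot b}{\\lVert a \\rVert \\, \\lVert b \\rVert}$, with distance `1 - cos(a, b)`.\n",
    "`cosine_similarity_ref` is a hypothetical helper written for this tutorial, not the braintools implementation, and its `eps`\n",
    "clamp on the denominator is an assumed safeguard against division by zero. It should reproduce the output above, and the\n",
    "last line checks scale invariance."
   ],
   "id": "5c3e8a1f9b7d2046"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def cosine_similarity_ref(a, b, eps=1e-8):\n",
    "    # Hypothetical reference helper (not part of braintools):\n",
    "    # dot product over the feature axis divided by the product of L2 norms.\n",
    "    dot = jnp.sum(a * b, axis=-1)\n",
    "    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)\n",
    "    # eps is an assumed safeguard against division by zero\n",
    "    return dot / jnp.maximum(norms, eps)\n",
    "\n",
    "pred = jnp.array([[1.0, 0.0, 0.0],\n",
    "                  [0.0, 1.0, 0.0],\n",
    "                  [1.0, 1.0, 0.0]])\n",
    "targ = jnp.array([[1.0, 0.0, 0.0],\n",
    "                  [1.0, 0.0, 0.0],\n",
    "                  [0.0, 1.0, 0.0]])\n",
    "sim_ref = cosine_similarity_ref(pred, targ)\n",
    "dist_ref = 1.0 - sim_ref  # distance = 1 - similarity\n",
    "print('similarity:', sim_ref)\n",
    "print('distance  :', dist_ref)\n",
    "# Scale invariance: multiplying the predictions by 10 leaves the similarity unchanged\n",
    "print('scaled    :', cosine_similarity_ref(10.0 * pred, targ))"
   ],
   "id": "e2b74f0c8a19d563",
   "outputs": [],
   "execution_count": null
  },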
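  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tips\n",
    "- Avoid zero vectors: their norm is 0, so the similarity is undefined (0/0). If you cannot rule them out, pass a small `epsilon` to `cosine_distance`.\n",
    "- Both functions already return one value per pair (the feature axis is reduced). To get a scalar loss, reduce over the batch axis, e.g. `jnp.mean(dist)`."
   ],
   "id": "ae21396269b2cd40"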
  },
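  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tips above in a small sketch: with a zero vector the plain formula evaluates to 0/0 = nan, clamping the denominator with an\n",
    "epsilon turns that entry into 0, and a mean over the batch axis reduces the per-pair distances to a scalar loss.\n",
    "`cosine_sim_with_eps` is a hypothetical helper; clamping the product of norms is only one way to apply an epsilon, and\n",
    "braintools' `epsilon` argument may be applied differently, so check its docstring."
   ],
   "id": "7a91d0c4e5f3b628"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def cosine_sim_with_eps(a, b, eps=0.0):\n",
    "    # Hypothetical helper: clamp the product of norms from below by eps\n",
    "    dot = jnp.sum(a * b, axis=-1)\n",
    "    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)\n",
    "    return dot / jnp.maximum(norms, eps)\n",
    "\n",
    "a = jnp.array([[0.0, 0.0, 0.0],   # a zero vector\n",
    "               [1.0, 1.0, 0.0]])\n",
    "b = jnp.array([[1.0, 0.0, 0.0],\n",
    "               [0.0, 1.0, 0.0]])\n",
    "print(cosine_sim_with_eps(a, b))               # eps=0: the first entry is nan (0 / 0)\n",
    "sim_eps = cosine_sim_with_eps(a, b, eps=1e-8)  # the first entry becomes 0.0\n",
    "# The result is already per pair; average over the batch axis for a scalar loss\n",
    "loss = jnp.mean(1.0 - sim_eps)\n",
    "print('similarity:', sim_eps)\n",
    "print('loss      :', loss)"
   ],
   "id": "0d6f2b8e4c7a9135",
   "outputs": [],
   "execution_count": null
  },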
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2) Pairwise similarity matrix (X vs Y)\n\n",
    "Use this to compute all-pairs similarities between two sets of embeddings.\n",
    "For `X: (n, d)` and `Y: (m, d)`, the result has shape `(n, m)`, where entry `(i, j)` is the cosine similarity of `X[i]` and `Y[j]`.\n",
    "With `Y=None`, it returns the `(n, n)` similarities within `X`."
   ],
   "id": "d851189896d34125"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:37:54.866918Z",
     "start_time": "2025-09-23T12:37:54.512466Z"
    }
   },
   "source": [
    "X = jnp.array([[1.0, 0.0, 0.0],\n",
    "               [0.0, 1.0, 0.0],\n",
    "               [1.0, 1.0, 0.0]])\n",
    "Y = jnp.array([[1.0, 1.0, 1.0],\n",
    "               [0.0, 0.0, 1.0]])\n",
    "S = pairwise_cosine_similarity(X, Y)\n",
    "print('pairwise shape:', S.shape)\n",
    "print(S)\n",
    "# Within-set similarities (X vs X)\n",
    "S_xx = pairwise_cosine_similarity(X)\n",
    "print('within shape:', S_xx.shape)"
   ],
   "id": "b9d1fea1de0d3aec",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pairwise shape: (3, 2)\n",
      "[[0.57735026 0.        ]\n",
      " [0.57735026 0.        ]\n",
      " [0.8164966  0.        ]]\n",
      "within shape: (3, 3)\n"
     ]
    }
   ],
   "execution_count": 3
  },
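  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A pairwise cosine matrix is commonly computed by scaling every row to unit length and taking one matrix product, $\\hat X \\hat Y^\\top$.\n",
    "The sketch below follows that approach. `pairwise_cosine_ref` is a hypothetical helper and not necessarily how braintools computes\n",
    "the matrix, but it should reproduce the values printed above."
   ],
   "id": "b4e8c21a6f0d7359"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def pairwise_cosine_ref(X, Y=None, eps=1e-8):\n",
    "    # Hypothetical reference helper (not the braintools implementation)\n",
    "    if Y is None:\n",
    "        Y = X  # within-set similarities\n",
    "    # Scale every row to unit L2 norm (eps guards against all-zero rows) ...\n",
    "    Xn = X / jnp.maximum(jnp.linalg.norm(X, axis=1, keepdims=True), eps)\n",
    "    Yn = Y / jnp.maximum(jnp.linalg.norm(Y, axis=1, keepdims=True), eps)\n",
    "    # ... then a single (n, d) @ (d, m) product yields all n * m cosines\n",
    "    return Xn @ Yn.T\n",
    "\n",
    "X = jnp.array([[1.0, 0.0, 0.0],\n",
    "               [0.0, 1.0, 0.0],\n",
    "               [1.0, 1.0, 0.0]])\n",
    "Y = jnp.array([[1.0, 1.0, 1.0],\n",
    "               [0.0, 0.0, 1.0]])\n",
    "S_ref = pairwise_cosine_ref(X, Y)\n",
    "print('pairwise shape:', S_ref.shape)\n",
    "print(S_ref)\n",
    "S_xx_ref = pairwise_cosine_ref(X)\n",
    "# A within-set matrix is symmetric, with (approximately) ones on the diagonal\n",
    "print('diagonal:', jnp.diag(S_xx_ref))"
   ],
   "id": "9f2a5c7e1b3d8064",
   "outputs": [],
   "execution_count": null
  },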
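  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Performance notes\n",
    "- Pairwise matrices can be large: memory grows with `n × m`, i.e. quadratically when both sets grow together. A float32 similarity matrix between two sets of 100,000 embeddings holds 10¹⁰ entries, about 40 GB.\n",
    "- For very large sets, process queries (or candidates) in batches so that only one `(batch, m)` block is held in memory at a time.\n",
    "- JIT-compile hot paths with `jax.jit`, and keep input shapes static (for example, by padding the last batch) to avoid recompilation."
   ],
   "id": "bf2eb32f69f1b1f7"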
  },
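  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batching advice above can be sketched as follows: normalize the items once, then process the queries in blocks and keep only\n",
    "each block's top-k results, so that only a `(block, m)` similarity matrix is materialized per step. With float32 and 100,000 items,\n",
    "a block of 1,024 queries holds about 410 MB of similarities. `normalize_rows`, `topk_block` and `chunked_topk` are hypothetical\n",
    "helpers in plain JAX (`jax.jit`, `jax.lax.top_k`), not braintools functions, and the block size is an arbitrary choice."
   ],
   "id": "36d1f9a0b8c2e74d"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def normalize_rows(A, eps=1e-8):\n",
    "    # Scale each row to unit L2 norm; eps guards against all-zero rows\n",
    "    return A / jnp.maximum(jnp.linalg.norm(A, axis=1, keepdims=True), eps)\n",
    "\n",
    "@partial(jax.jit, static_argnames='k')\n",
    "def topk_block(q_block, items_n, k):\n",
    "    # Similarities for one block of queries only: shape (block, n_items)\n",
    "    S = normalize_rows(q_block) @ items_n.T\n",
    "    # (values, indices) of the k largest entries per row, in descending order\n",
    "    return jax.lax.top_k(S, k)\n",
    "\n",
    "def chunked_topk(queries, items, k, block=1024):\n",
    "    # Hypothetical helper: only a (block, n_items) similarity block is built per step.\n",
    "    # A shorter final block triggers one extra compilation; pad it if that matters.\n",
    "    items_n = normalize_rows(items)\n",
    "    vals, idxs = [], []\n",
    "    for start in range(0, queries.shape[0], block):\n",
    "        v, i = topk_block(queries[start:start + block], items_n, k=k)\n",
    "        vals.append(v)\n",
    "        idxs.append(i)\n",
    "    return jnp.concatenate(vals), jnp.concatenate(idxs)\n",
    "\n",
    "# Usage on random data (sizes chosen only for illustration)\n",
    "kq, ki = jax.random.split(jax.random.PRNGKey(0))\n",
    "queries_big = jax.random.normal(kq, (10, 8))\n",
    "items_big = jax.random.normal(ki, (50, 8))\n",
    "top_vals, top_idx = chunked_topk(queries_big, items_big, k=3, block=4)\n",
    "print(top_vals.shape, top_idx.shape)"
   ],
   "id": "c8a0e5b27d4f1936",
   "outputs": [],
   "execution_count": null
  },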
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3) Simple retrieval example (top‑k)\n\n",
    "Compute similarities between `queries` and `items`, then take top‑k indices."
   ],
   "id": "5450716c89348a0b"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-23T12:37:55.083301Z",
     "start_time": "2025-09-23T12:37:54.873114Z"
    }
   },
   "source": [
    "queries = jnp.array([[1.0, 0.0, 0.0],\n",
    "                      [0.0, 1.0, 1.0]])\n",
    "items   = jnp.array([[1.0, 0.0, 0.0],\n",
    "                      [0.0, 1.0, 0.0],\n",
    "                      [0.0, 1.0, 1.0]])\n",
    "S_qi = pairwise_cosine_similarity(queries, items)  # (n_query, n_item)\n",
    "# Top‑k via argsort (descending)\n",
    "topk = 2\n",
    "topk_idx = jnp.argsort(-S_qi, axis=1)[:, :topk]\n",
    "print('top‑k indices per query:', topk_idx)\n",
    "print('top‑k sims per query  :', jnp.take_along_axis(S_qi, topk_idx, axis=1))"
   ],
   "id": "b83c1d959180d018",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "top‑k indices per query: [[0 1]\n",
      " [2 1]]\n",
      "top‑k sims per query  : [[1.         0.        ]\n",
      " [1.0000001  0.70710677]]\n"
     ]
    }
   ],
   "execution_count": 4
  },
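  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`argsort` sorts every row completely. When only the best k items are needed, `jax.lax.top_k` is a common alternative: it works on\n",
    "the last axis and returns both the values and the indices, sorted in descending order. A sketch on the same data follows; the\n",
    "similarity matrix is recomputed in plain JAX so the cell runs on its own. Ties (query 0 has two items with similarity 0) may be\n",
    "ordered differently from the argsort version."
   ],
   "id": "f1c6b3d9e2a05874"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "queries = jnp.array([[1.0, 0.0, 0.0],\n",
    "                     [0.0, 1.0, 1.0]])\n",
    "items = jnp.array([[1.0, 0.0, 0.0],\n",
    "                   [0.0, 1.0, 0.0],\n",
    "                   [0.0, 1.0, 1.0]])\n",
    "# Cosine similarities via row normalization and one matrix product: (n_query, n_item)\n",
    "qn = queries / jnp.linalg.norm(queries, axis=1, keepdims=True)\n",
    "itn = items / jnp.linalg.norm(items, axis=1, keepdims=True)\n",
    "S_lax = qn @ itn.T\n",
    "# top_k acts on the last axis and returns values sorted in descending order\n",
    "lax_sims, lax_idx = jax.lax.top_k(S_lax, 2)\n",
    "print('top-k indices per query:', lax_idx)\n",
    "print('top-k sims per query   :', lax_sims)"
   ],
   "id": "4b7e09a3c6d1f582",
   "outputs": [],
   "execution_count": null
  },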
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4) Choosing the right API\n\n",
    "- Use `bt.metric.cosine_similarity` / `bt.metric.cosine_distance` for aligned pairs (inputs of the same shape).\n",
    "- Use `pairwise_cosine_similarity` to build `(n, m)` similarity matrices for retrieval and matching.\n",
    "- Explicit normalization is not required: cosine metrics compare directions, not magnitudes, so rescaling an input by a positive factor does not change the score."
   ],
   "id": "2025fdeb06045e30"
  }
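  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two APIs are related: for matched rows, the aligned scores equal the diagonal of the pairwise matrix, so building the full\n",
    "`(n, n)` matrix only to read its diagonal wastes n² − n entries. A plain-JAX check on the data from section 1 (`normalize_rows`\n",
    "is a hypothetical helper, not a braintools function):"
   ],
   "id": "a3d5f8e10c9b2467"
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def normalize_rows(A, eps=1e-8):\n",
    "    # Scale each row to unit L2 norm; eps guards against all-zero rows\n",
    "    return A / jnp.maximum(jnp.linalg.norm(A, axis=1, keepdims=True), eps)\n",
    "\n",
    "pred = jnp.array([[1.0, 0.0, 0.0],\n",
    "                  [0.0, 1.0, 0.0],\n",
    "                  [1.0, 1.0, 0.0]])\n",
    "targ = jnp.array([[1.0, 0.0, 0.0],\n",
    "                  [1.0, 0.0, 0.0],\n",
    "                  [0.0, 1.0, 0.0]])\n",
    "pn, tn = normalize_rows(pred), normalize_rows(targ)\n",
    "aligned = jnp.sum(pn * tn, axis=-1)  # (n,): one score per matched pair\n",
    "pairwise = pn @ tn.T                  # (n, n): every prediction against every target\n",
    "print('aligned          :', aligned)\n",
    "print('pairwise diagonal:', jnp.diag(pairwise))"
   ],
   "id": "68e2c4a9f1b07d35",
   "outputs": [],
   "execution_count": null
  }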
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
