Tutorial 4: Pairwise & Embedding Similarity

This tutorial covers two cosine-similarity APIs in braintools.metric:

  • Aligned embeddings:

    • bt.metric.cosine_similarity(predictions, targets) → per-pair similarity (…,)

    • bt.metric.cosine_distance(predictions, targets) → 1 − similarity (…,)

  • Pairwise similarity matrix:

    • pairwise_cosine_similarity(X, Y=None) → (n, m) matrix (X vs Y) or (n, n) if Y is None

Note: both functions are named cosine_similarity internally; in the public namespace, the aligned version (predictions, targets) is bound to bt.metric.cosine_similarity. To access the pairwise (matrix) variant, import it explicitly as shown below.

import jax.numpy as jnp
import braintools as bt
# Import the pairwise (matrix) version explicitly and alias it
from braintools.metric._pariwise import cosine_similarity as pairwise_cosine_similarity

1) Aligned embeddings: similarity and distance

Use these when you have matched pairs (prediction_i, target_i) and want one score per pair. Cosine similarity is scale-invariant and lies in [−1, 1]. Cosine distance is 1 − similarity and therefore lies in [0, 2].

pred = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
sim = bt.metric.cosine_similarity(pred, targ)
dist = bt.metric.cosine_distance(pred, targ)
print('similarity:', sim)
print('distance  :', dist)

similarity: [1.         0.         0.70710677]
distance  : [0.         1.         0.29289323]
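
As a cross-check, the numbers above can be reproduced with plain jax.numpy. The sketch below is a hand-written reference for the cosine formula (dot product divided by the product of the norms), not braintools' implementation; manual_cosine_similarity is a hypothetical helper name. It also illustrates scale invariance: rescaling either input leaves the scores unchanged.

```python
import jax.numpy as jnp

def manual_cosine_similarity(a, b):
    """Hypothetical reference helper: cosine similarity along the last axis."""
    dot = jnp.sum(a * b, axis=-1)
    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    return dot / norms  # NaN if either vector is all zeros (0 / 0)

pred = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

sim = manual_cosine_similarity(pred, targ)   # one score per row, shape (3,)
dist = 1.0 - sim                             # cosine distance

# Rescaling the inputs by positive factors does not change the result.
sim_scaled = manual_cosine_similarity(10.0 * pred, 0.5 * targ)

print('manual similarity:', sim)
print('manual distance  :', dist)
print('scale-invariant  :', bool(jnp.allclose(sim, sim_scaled)))
```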

Tips

  • Avoid zero vectors: a zero vector has no direction, so its cosine is undefined. If zero vectors can occur, pass a small epsilon to cosine_distance.

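  • For batch aggregation, reduce the per-pair scores over the remaining batch axes (for example with jnp.mean) to obtain a scalar before loss computation; the feature axis has already been reduced.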
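
Both tips can be sketched together in plain jax.numpy. The guard below clamps each norm to a small eps, which is an illustrative choice only (braintools' own epsilon handling may differ), and safe_cosine_distance is a hypothetical helper name. The per-pair distances are then averaged over the batch to give a scalar loss.

```python
import jax.numpy as jnp

def safe_cosine_distance(a, b, eps=1e-8):
    """Hypothetical helper: cosine distance with norms clamped away from zero."""
    dot = jnp.sum(a * b, axis=-1)
    norm_a = jnp.maximum(jnp.linalg.norm(a, axis=-1), eps)
    norm_b = jnp.maximum(jnp.linalg.norm(b, axis=-1), eps)
    return 1.0 - dot / (norm_a * norm_b)

# The second prediction is a zero vector; without the clamp it would give NaN.
pred = jnp.array([[1.0, 0.0],
                  [0.0, 0.0]])
targ = jnp.array([[1.0, 0.0],
                  [0.0, 1.0]])

per_pair = safe_cosine_distance(pred, targ)  # shape (2,), one distance per pair
loss = jnp.mean(per_pair)                    # scalar: mean over the batch axis

print('per-pair distance:', per_pair)
print('mean loss        :', loss)
```

With this guard a zero vector is treated as orthogonal to everything (similarity 0, distance 1). Whether that is the right convention depends on the application.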

2) Pairwise similarity matrix (X vs Y)

Use this to compute all-pairs similarities between two sets of embeddings. For X: (n, d), Y: (m, d), the result is (n, m). With Y=None, returns (n, n) similarities within X.

X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])
S = pairwise_cosine_similarity(X, Y)
print('pairwise shape:', S.shape)
print(S)
# Within-set similarities (X vs X)
S_xx = pairwise_cosine_similarity(X)
print('within shape:', S_xx.shape)

pairwise shape: (3, 2)
[[0.57735026 0.        ]
 [0.57735026 0.        ]
 [0.8164966  0.        ]]
within shape: (3, 3)
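
The matrix above can be understood as "normalize every row, then take a matrix product". The following is a minimal jax.numpy sketch of that idea, assuming 2-D inputs with no zero rows; manual_pairwise_cosine is a hypothetical name, and this is not the library's code.

```python
import jax.numpy as jnp

def manual_pairwise_cosine(X, Y=None):
    """Hypothetical reference helper: (n, d) vs (m, d) -> (n, m) cosine matrix."""
    if Y is None:
        Y = X  # within-set similarities
    X_unit = X / jnp.linalg.norm(X, axis=1, keepdims=True)
    Y_unit = Y / jnp.linalg.norm(Y, axis=1, keepdims=True)
    return X_unit @ Y_unit.T  # entry (i, j) is cos(X[i], Y[j])

X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])

S = manual_pairwise_cosine(X, Y)   # shape (3, 2)
S_xx = manual_pairwise_cosine(X)   # shape (3, 3), symmetric with unit diagonal

print('pairwise shape:', S.shape)
print('within shape  :', S_xx.shape)
```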

Performance notes

  • Pairwise matrices can be large: an (n, m) result needs memory proportional to n × m, so doubling both set sizes quadruples it.

  • For very large sets, consider batching queries or candidates to keep memory under control.

  • JIT-compile hot paths with static shapes when possible.
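
To make the batching advice concrete: a float32 matrix with n = m = 100,000 would need about 40 GB, whereas processing queries in chunks only ever holds a (chunk_size, m) block. The sketch below uses plain jax.numpy and jax.jit rather than the braintools function; block_similarity, best_match_chunked, and chunk_size are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def block_similarity(q_block, items_unit):
    # Cosine similarities of one query block against pre-normalized items.
    q_unit = q_block / jnp.linalg.norm(q_block, axis=1, keepdims=True)
    return q_unit @ items_unit.T  # shape (block_size, n_items)

def best_match_chunked(queries, items, chunk_size=2):
    """Hypothetical helper: index of the most similar item for each query."""
    items_unit = items / jnp.linalg.norm(items, axis=1, keepdims=True)
    best = []
    for start in range(0, queries.shape[0], chunk_size):
        block = queries[start:start + chunk_size]
        sims = block_similarity(block, items_unit)
        best.append(jnp.argmax(sims, axis=1))  # keep only the reduced result
    return jnp.concatenate(best)

items = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0]])
queries = jnp.array([[2.0, 0.0, 0.0],
                     [0.0, 3.0, 0.0],
                     [0.0, 1.0, 1.0],
                     [1.0, 0.1, 0.0],
                     [0.0, 0.1, 1.0]])

best_idx = best_match_chunked(queries, items, chunk_size=2)
print('best item per query:', best_idx)
```

Because jax.jit specializes on input shapes, a final chunk that is shorter than chunk_size triggers one extra compilation. Padding the last chunk to the full size is one way to keep shapes static.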

3) Simple retrieval example (top‑k)

Compute similarities between queries and items, then take top‑k indices.

queries = jnp.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 1.0]])
items   = jnp.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 1.0, 1.0]])
S_qi = pairwise_cosine_similarity(queries, items)  # (n_query, n_item)
# Top‑k via argsort (descending)
topk = 2
topk_idx = jnp.argsort(-S_qi, axis=1)[:, :topk]
print('top‑k indices per query:', topk_idx)
print('top‑k sims per query  :', jnp.take_along_axis(S_qi, topk_idx, axis=1))

top‑k indices per query: [[0 1]
 [2 1]]
top‑k sims per query  : [[1.         0.        ]
 [1.0000001  0.70710677]]
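
Sorting every row just to keep k entries does more work than necessary when the item set is large. One alternative, sketched here under the assumption that only the k best matches are needed, is jax.lax.top_k, which returns the k largest values and their indices along the last axis. The similarity matrix is hard-coded with (rounded) values from the example so that the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

# Query-vs-item similarities, shape (n_query, n_item), copied from the example.
S_qi = jnp.array([[1.0, 0.0,        0.0],
                  [0.0, 0.70710677, 1.0]])

k = 2
topk_vals, topk_idx = jax.lax.top_k(S_qi, k)  # both have shape (n_query, k)

print('top-k indices per query:', topk_idx)
print('top-k sims per query   :', topk_vals)
```

Note that the first query has two items tied at similarity 0, so which of them fills the second slot depends on how the chosen routine breaks ties.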

4) Choosing the right API

  • Use bt.metric.cosine_similarity / cosine_distance for aligned pairs (same shape).

  • Use pairwise_cosine_similarity to build (n, m) similarity matrices for retrieval/matching.

  • Normalize inputs if needed; cosine metrics compare directions, not magnitudes.
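
The relationship between the two APIs can be illustrated in plain jax.numpy: once rows are L2-normalized, cosine similarity is just a dot product, and the aligned scores are the diagonal of the pairwise matrix. This is a sketch with a hypothetical l2_normalize helper, not a description of how braintools computes either quantity.

```python
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
    """Hypothetical helper: scale each row to unit length (eps guards zero rows)."""
    norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
    return x / jnp.maximum(norm, eps)

A = jnp.array([[3.0, 0.0],
               [0.0, 2.0],
               [1.0, 1.0]])
B = jnp.array([[1.0, 0.0],
               [5.0, 0.0],
               [0.0, 4.0]])

A_unit, B_unit = l2_normalize(A), l2_normalize(B)

aligned = jnp.sum(A_unit * B_unit, axis=-1)  # (3,): row i of A vs row i of B
pairwise = A_unit @ B_unit.T                 # (3, 3): every row of A vs every row of B

print('aligned :', aligned)
print('diagonal:', jnp.diagonal(pairwise))
```

If only matched pairs matter, the aligned form avoids computing the off-diagonal entries at all.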