Tutorial 4: Pairwise & Embedding Similarity
This tutorial covers two cosine-similarity APIs in braintools.metric:
Aligned embeddings:
bt.metric.cosine_similarity(predictions, targets) → per-pair similarity, shape (…,)
bt.metric.cosine_distance(predictions, targets) → 1 − similarity, shape (…,)
Pairwise similarity matrix:
pairwise_cosine_similarity(X, Y=None) → (n, m) matrix (X vs Y) or (n, n) if Y is None
Note: both functions are named cosine_similarity internally; in the public namespace, the
aligned version (predictions, targets) is bound to bt.metric.cosine_similarity. To access
the pairwise (matrix) variant, import it explicitly as shown below.
import jax.numpy as jnp
import braintools as bt
# Import the pairwise (matrix) version explicitly and alias it
from braintools.metric._pariwise import cosine_similarity as pairwise_cosine_similarity
1) Aligned embeddings: similarity and distance
Use these when you have matched pairs (prediction_i, target_i) and want per-pair scores.
Both are scale-invariant. Similarity lies in [−1, 1]; distance is 1 − similarity, so it lies in [0, 2].
pred = jnp.array([[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0]])
sim = bt.metric.cosine_similarity(pred, targ)
dist = bt.metric.cosine_distance(pred, targ)
print('similarity:', sim)
print('distance :', dist)
similarity: [1. 0. 0.70710677]
distance : [0. 1. 0.29289323]
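For reference, the per-pair score is the dot product divided by the product of the two norms, taken over the last axis. The sketch below recomputes the numbers above in plain jax.numpy. The helper name cosine_similarity_reference is an assumption made for illustration; it is not the braintools implementation.

```python
import jax.numpy as jnp

def cosine_similarity_reference(predictions, targets):
    """Illustrative formula only: dot(p, t) / (||p|| * ||t||) over the last axis."""
    dot = jnp.sum(predictions * targets, axis=-1)
    norms = jnp.linalg.norm(predictions, axis=-1) * jnp.linalg.norm(targets, axis=-1)
    return dot / norms

pred = jnp.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0]])
targ = jnp.array([[1.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])

sim_ref = cosine_similarity_reference(pred, targ)   # one score per row pair
dist_ref = 1.0 - sim_ref                            # distance = 1 - similarity
print('similarity:', sim_ref)
print('distance :', dist_ref)
```

The third pair gives 1/√2 ≈ 0.707 because the vectors are 45° apart.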
Tips
Avoid zero vectors; if necessary, pass a small epsilon to cosine_distance.
For batch aggregation, reduce the per-pair scores to a scalar (for example, with a mean over the batch axis) before using them as a loss.
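To see why zero vectors matter, the sketch below compares an unguarded cosine with a guarded one in plain jax.numpy. The helper safe_cosine_similarity and its eps argument are hypothetical; how braintools applies its own epsilon is not shown in this tutorial, so treat this as one possible guard rather than the library's behavior.

```python
import jax.numpy as jnp

def safe_cosine_similarity(a, b, eps=1e-8):
    """Hypothetical helper: clamp the norm product so zero vectors give 0, not NaN."""
    dot = jnp.sum(a * b, axis=-1)
    norms = jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    return dot / jnp.maximum(norms, eps)

a = jnp.array([[0.0, 0.0, 0.0],    # zero vector: its direction is undefined
               [1.0, 2.0, 2.0]])
b = jnp.array([[1.0, 0.0, 0.0],
               [1.0, 2.0, 2.0]])

# Unguarded formula: the first pair evaluates 0 / 0, which is NaN.
naive = jnp.sum(a * b, axis=-1) / (
    jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1))

safe = safe_cosine_similarity(a, b)

# Aggregate per-pair distances into a scalar, e.g. for use as a loss.
mean_distance = jnp.mean(1.0 - safe)

print('naive:', naive)
print('safe :', safe)
print('mean distance:', mean_distance)
```

A single NaN in the per-pair scores would propagate through the mean, which is why the guard is applied before aggregation.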
2) Pairwise similarity matrix (X vs Y)
Use this to compute all-pairs similarities between two sets of embeddings.
For X: (n, d), Y: (m, d), the result is (n, m). With Y=None, returns (n, n) similarities within X.
X = jnp.array([[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
[0.0, 0.0, 1.0]])
S = pairwise_cosine_similarity(X, Y)
print('pairwise shape:', S.shape)
print(S)
# Within-set similarities (X vs X)
S_xx = pairwise_cosine_similarity(X)
print('within shape:', S_xx.shape)
pairwise shape: (3, 2)
[[0.57735026 0. ]
[0.57735026 0. ]
[0.8164966 0. ]]
within shape: (3, 3)
Performance notes
Pairwise matrices can be large: an (n, m) result takes memory proportional to n × m.
For very large sets, consider batching queries or candidates to keep memory under control.
JIT-compile hot paths with static shapes when possible.
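One way to batch queries is sketched below: similarities are computed one query block at a time and only the best match per query is kept, so the full (n, m) matrix never exists in memory. This is a plain jax.numpy sketch; the helpers l2_normalize, block_similarity, and chunked_top1 are names invented for this example and are not part of braintools. It assumes no zero rows.

```python
import jax
import jax.numpy as jnp

def l2_normalize(x):
    """Scale each row to unit length (assumes no zero rows)."""
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)

@jax.jit
def block_similarity(query_block, items_unit):
    """Cosine similarity of one query block against pre-normalized items."""
    return l2_normalize(query_block) @ items_unit.T

def chunked_top1(queries, items, chunk_size):
    """Best-matching item index per query, one (chunk_size, m) block at a time."""
    items_unit = l2_normalize(items)          # normalize candidates once
    best = []
    for start in range(0, queries.shape[0], chunk_size):
        block = block_similarity(queries[start:start + chunk_size], items_unit)
        best.append(jnp.argmax(block, axis=1))
    return jnp.concatenate(best)

queries = jnp.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 1.0],
                     [0.0, 2.0, 0.0],
                     [3.0, 0.0, 0.0]])
items = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0]])

best_idx = chunked_top1(queries, items, chunk_size=2)
print('best item per query:', best_idx)
```

Equal-sized chunks share one compiled shape. A smaller final chunk would trigger one extra compilation, so padding the last chunk is worth considering when the number of queries is not a multiple of chunk_size.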
3) Simple retrieval example (top‑k)
Compute similarities between queries and items, then take top‑k indices.
queries = jnp.array([[1.0, 0.0, 0.0],
[0.0, 1.0, 1.0]])
items = jnp.array([[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 1.0, 1.0]])
S_qi = pairwise_cosine_similarity(queries, items) # (n_query, n_item)
# Top‑k via argsort (descending)
topk = 2
topk_idx = jnp.argsort(-S_qi, axis=1)[:, :topk]
print('top‑k indices per query:', topk_idx)
print('top‑k sims per query :', jnp.take_along_axis(S_qi, topk_idx, axis=1))
top‑k indices per query: [[0 1]
[2 1]]
top‑k sims per query : [[1. 0. ]
[1.0000001 0.70710677]]
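The argsort approach above sorts every row in full. When only a few matches per query are needed, jax.lax.top_k returns the k largest values and their indices along the last axis directly. The similarity values below are made up for illustration and chosen without ties, so the ordering is unambiguous.

```python
import jax
import jax.numpy as jnp

# Illustrative (n_query, n_item) similarity matrix; values are invented.
S_example = jnp.array([[0.90, 0.10, 0.40],
                       [0.20, 0.70, 0.95]])

# top_k works along the last axis and returns (values, indices),
# each of shape (n_query, k), sorted in descending order.
topk_vals, topk_idx = jax.lax.top_k(S_example, 2)

print('top-k indices per query:', topk_idx)
print('top-k sims per query :', topk_vals)
```

Both outputs already line up row by row, so no separate take_along_axis step is needed to recover the scores.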
4) Choosing the right API
Use bt.metric.cosine_similarity / cosine_distance for aligned pairs (same shape).
Use pairwise_cosine_similarity to build (n, m) similarity matrices for retrieval/matching.
Normalize inputs if needed; cosine metrics compare directions, not magnitudes.
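The last point can be checked directly. In the sketch below, rows are scaled to unit length first, after which the pairwise cosine matrix is just a matrix product, and rescaling the inputs leaves the result unchanged. The helper l2_normalize and its eps guard are assumptions for this example, written in plain jax.numpy rather than taken from braintools.

```python
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
    """Hypothetical helper: scale each row to unit length; eps guards zero rows."""
    norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
    return x / jnp.maximum(norm, eps)

X = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0],
               [1.0, 1.0, 0.0]])
Y = jnp.array([[1.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])

# For unit-length rows, cosine similarity reduces to a dot product.
S_unit = l2_normalize(X) @ l2_normalize(Y).T               # shape (3, 2)

# Rescaling either input changes magnitudes but not directions.
S_rescaled = l2_normalize(10.0 * X) @ l2_normalize(0.5 * Y).T

print(S_unit)
print('unchanged by rescaling:', bool(jnp.allclose(S_unit, S_rescaled, atol=1e-6)))
```

Pre-normalizing a fixed set of embeddings once, then reusing the unit vectors across many queries, avoids recomputing norms on every call.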