Tutorial 3: Ranking for Learning-to-Rank
This tutorial shows how to use the listwise ranking loss in braintools.metric
with masking and reduction options. It is suited for information retrieval,
recommendation, and other Learning-to-Rank tasks.
Covered API:
bt.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=jnp.mean)
- where: boolean mask for valid items (padding handling)
- weights: per-item weights
- reduce_fn: jnp.mean, jnp.sum, or None (unreduced)
import jax.numpy as jnp
import braintools as bt
1) Basic usage (single list)
logits are the scores to be ranked; labels are non-negative relevance values
(binary or graded). The loss operates on the last dimension.
# One list of 4 items
logits = jnp.array([2.0, 1.0, 0.5, 0.2])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])  # only item 0 is relevant
loss = bt.metric.ranking_softmax_loss(logits, labels)
print(loss) # scalar (default reduce_fn=jnp.mean)
0.5632142
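To make the printed value concrete, here is a minimal pure-JAX sketch of a listwise softmax cross-entropy, -sum(labels * log_softmax(logits)) over the last dimension. The helper name softmax_xent_reference is hypothetical, and the assumption that ranking_softmax_loss follows this formula is not taken from the library source; it is only consistent with the number printed above.

```python
import jax
import jax.numpy as jnp

def softmax_xent_reference(logits, labels):
    """Hypothetical reference: -sum(labels * log_softmax(logits)) over the last axis."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(labels * log_probs, axis=-1)

# Same single list as in the tutorial example
logits = jnp.array([2.0, 1.0, 0.5, 0.2])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])

reference = softmax_xent_reference(logits, labels)
print(reference)  # compare with the 0.5632142 printed by the tutorial
```

With a single relevant item, the loss reduces to the negative log-probability that the softmax assigns to that item, so raising its score relative to the others lowers the loss.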
2) Batched lists with masks (variable lengths)
Use where to ignore padded items. It must be a boolean array with the same
shape as logits and labels (or broadcastable to that shape).
# Two lists, padded to length 5
logits = jnp.array([[2.0, 1.0, 0.5, -1.0, 0.0],
                    [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 0.0]])
# First list has 4 valid items; second has 3 valid items
where = jnp.array([[True, True, True, True, False],
                   [True, True, True, False, False]])
# Default reduce_fn=jnp.mean -> scalar over batch
loss_mean = bt.metric.ranking_softmax_loss(logits, labels, where=where)
print('Mean loss (scalar):', loss_mean)
# Unreduced per-list losses
loss_per_list = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('Per-list loss:', loss_per_list) # shape (batch,)
Mean loss (scalar): 0.6130267
Per-list loss: [0.49518192 0.73087144]
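One way to picture the mask is that padded items are removed from the softmax entirely. The sketch below is a pure-JAX approximation of that behavior, not the library's implementation: the helper masked_softmax_xent is hypothetical, and it assumes masked items receive zero probability and contribute nothing to the sum. Under that assumption it can be checked against the per-list values printed above.

```python
import jax
import jax.numpy as jnp

def masked_softmax_xent(logits, labels, where):
    """Hypothetical per-list softmax cross-entropy that skips items where `where` is False."""
    # Masked logits become -inf, so they get zero softmax probability.
    masked_logits = jnp.where(where, logits, -jnp.inf)
    log_probs = jax.nn.log_softmax(masked_logits, axis=-1)
    # Select 0.0 for masked items so that 0 * -inf never reaches the sum.
    per_item = jnp.where(where, labels * log_probs, 0.0)
    return -jnp.sum(per_item, axis=-1)

# Same padded batch as in the tutorial example
logits = jnp.array([[2.0, 1.0, 0.5, -1.0, 0.0],
                    [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 0.0]])
where = jnp.array([[True, True, True, True, False],
                   [True, True, True, False, False]])

per_list = masked_softmax_xent(logits, labels, where)
print(per_list)  # compare with the per-list loss printed by the tutorial
```

Because the padded scores never enter the softmax, the values placed in padded positions (0.0, -2.0, -1.0 here) do not affect the result.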
3) Reductions: sum vs mean vs none
- reduce_fn=jnp.sum: sum across the batch
- reduce_fn=jnp.mean: average across the batch (default)
- reduce_fn=None: return unreduced per-list values
When no items are valid (where is all False) and the inputs contain no NaN, the mean reduction returns 0.0 rather than NaN.
sum_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.sum)
mean_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.mean)
none_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('sum:', sum_loss)
print('mean:', mean_loss)
print('none:', none_loss, ' sum(none)=', jnp.sum(none_loss))
sum: 1.2260534
mean: 0.6130267
none: [0.49518192 0.73087144] sum(none)= 1.2260534
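The unreduced output also allows aggregations that reduce_fn does not cover directly. As a sketch, the snippet below takes the per-list values printed above as plain data, reproduces the sum and mean, and then computes a length-weighted average in which each list counts in proportion to its number of valid items. The length weighting is an illustrative choice, not something the library prescribes.

```python
import jax.numpy as jnp

# Per-list losses as printed by the tutorial (reduce_fn=None)
per_list = jnp.array([0.49518192, 0.73087144])
where = jnp.array([[True, True, True, True, False],
                   [True, True, True, False, False]])

total = jnp.sum(per_list)     # corresponds to reduce_fn=jnp.sum
average = jnp.mean(per_list)  # corresponds to reduce_fn=jnp.mean

# Illustrative manual aggregation: weight each list by its number of valid items.
n_valid = jnp.sum(where, axis=-1)  # valid items per list: 4 and 3
length_weighted = jnp.sum(per_list * n_valid) / jnp.sum(n_valid)

print('sum:', total)
print('mean:', average)
print('length-weighted mean:', length_weighted)
```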
4) Per-item weights
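Provide weights to emphasize specific items in each list. weights must match
the shape of labels/logits and is applied to the labels before the
softmax cross-entropy, so a weight on an item whose label is 0 has no effect.
In the example below, the first list's loss is unchanged, while the weight of
2.0 on the relevant item doubles the second list's loss.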
weights = jnp.array([[1.0, 0.5, 0.5, 1.0, 0.0],
                     [1.0, 1.0, 2.0, 0.0, 0.0]])
weighted_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, weights=weights, reduce_fn=None)
print('Weighted per-list loss:', weighted_loss)
Weighted per-list loss: [0.49518192 1.4617429 ]
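The printed values are consistent with weights multiplying the labels element-wise before the cross-entropy. The sketch below checks that reading on the valid items of each list, without padding. The helper weighted_softmax_xent is hypothetical, and the element-wise multiplication is an assumption inferred from the description and the output above.

```python
import jax
import jax.numpy as jnp

def weighted_softmax_xent(logits, labels, weights):
    """Hypothetical reference: weights scale the labels before the cross-entropy."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(labels * weights * log_probs, axis=-1)

# Valid items of the second list: the relevant item carries weight 2.0
logits_b = jnp.array([0.8, 0.3, 1.2])
labels_b = jnp.array([0.0, 0.0, 1.0])
unweighted_b = weighted_softmax_xent(logits_b, labels_b, jnp.ones(3))
weighted_b = weighted_softmax_xent(logits_b, labels_b, jnp.array([1.0, 1.0, 2.0]))

# Valid items of the first list: only zero-label items have weights other than 1.0
logits_a = jnp.array([2.0, 1.0, 0.5, -1.0])
labels_a = jnp.array([1.0, 0.0, 0.0, 0.0])
unweighted_a = weighted_softmax_xent(logits_a, labels_a, jnp.ones(4))
weighted_a = weighted_softmax_xent(logits_a, labels_a, jnp.array([1.0, 0.5, 0.5, 1.0]))

print(unweighted_b, weighted_b)  # the weighted value is twice the unweighted one
print(unweighted_a, weighted_a)  # identical: the weights sit on zero labels
```

A practical consequence is that weights only matter for items with non-zero relevance; to change how irrelevant items influence the loss, the scores or the mask have to change instead.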
5) Tips & Pitfalls
- Shapes: the loss operates on the last dimension (…, list_size); batch dims are leading.
- where must be boolean and broadcastable to (…, list_size).
- weights must match the labels/logits shape.
- reduce_fn=None returns per-list values; you can aggregate manually.
- If a list has no valid items (where all-False), mean reduction returns 0.0 (when inputs have no NaN).
- For large batches and lists, prefer JIT-compiling code paths that call this loss with static shapes.
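The static-shape tip can be sketched as follows: pad every variable-length list to one fixed list_size, build the matching where mask, and pass both to a jitted function so that each batch has the same shape. The helpers pad_lists and mean_masked_loss are hypothetical. mean_masked_loss is a pure-JAX stand-in used only to keep the example self-contained; in real code its body would presumably be a call to bt.metric.ranking_softmax_loss(logits, labels, where=where).

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_lists(lists, list_size, pad_value=0.0):
    """Hypothetical helper: pad 1-D lists to (batch, list_size) and build a boolean mask."""
    padded = np.full((len(lists), list_size), pad_value, dtype=np.float32)
    mask = np.zeros((len(lists), list_size), dtype=bool)
    for i, items in enumerate(lists):
        if len(items) > list_size:
            raise ValueError("list longer than list_size")
        padded[i, :len(items)] = items
        mask[i, :len(items)] = True
    return jnp.asarray(padded), jnp.asarray(mask)

@jax.jit
def mean_masked_loss(logits, labels, where):
    # Stand-in loss (assumption): masked softmax cross-entropy, averaged over lists.
    masked_logits = jnp.where(where, logits, -jnp.inf)
    log_probs = jax.nn.log_softmax(masked_logits, axis=-1)
    per_item = jnp.where(where, labels * log_probs, 0.0)
    return jnp.mean(-jnp.sum(per_item, axis=-1))

# Variable-length lists (the valid items from the tutorial batch)
raw_logits = [[2.0, 1.0, 0.5, -1.0], [0.8, 0.3, 1.2]]
raw_labels = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

logits, where = pad_lists(raw_logits, list_size=5)
labels, _ = pad_lists(raw_labels, list_size=5)

loss = mean_masked_loss(logits, labels, where)
print(logits.shape, loss)  # compare the loss with the mean loss printed in section 2
```

Keeping list_size fixed across batches means the jitted function is traced once for that shape; a new shape would trigger a new compilation.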