Tutorial 3: Ranking for Learning-to-Rank

This tutorial shows how to use the listwise ranking softmax loss in braintools.metric, including its masking, weighting, and reduction options. The loss suits information retrieval, recommendation, and other Learning-to-Rank tasks.

Covered API:

  • bt.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=jnp.mean)

    • where: boolean mask marking valid items (True) versus padding (False)

    • weights: per-item weights, applied to the labels

    • reduce_fn: jnp.mean (default), jnp.sum, or None to return unreduced per-list losses

import jax.numpy as jnp
import braintools as bt

1) Basic usage (single list)

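logits are the scores assigned to the items to be ranked; labels are non-negative relevance values (binary or graded). The loss treats the last dimension as the list of items: it applies a softmax to the logits along that axis and takes the cross-entropy against the labels.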

# One list of 4 items
logits = jnp.array([2.0, 1.0, 0.5, 0.2])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])  # item 0 is most relevant
loss = bt.metric.ranking_softmax_loss(logits, labels)
print(loss)  # scalar (default reduce_fn=jnp.mean)
0.5632142
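To see where this number comes from, it can be reproduced by hand. The sketch below assumes the loss is the softmax cross-entropy -sum_i labels_i * log_softmax(logits)_i over the last axis. That definition is an assumption, not something the API listing states, but it matches the printed 0.5632142 for this one-hot example, where the loss reduces to -log of the softmax probability of item 0.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, 1.0, 0.5, 0.2])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])

# Assumed definition (sketch): softmax over the last axis, then
# cross-entropy against the relevance labels.
log_probs = jax.nn.log_softmax(logits, axis=-1)
manual_loss = -jnp.sum(labels * log_probs, axis=-1)

# With a one-hot label this is just -log P(item 0).
print(manual_loss)  # approximately 0.5632
```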

2) Batched lists with masks (variable lengths)

Use where to ignore padded items: True marks a valid item, False a padded one. It must be a boolean array with the same shape as logits and labels.

# Two lists, padded to length 5
logits = jnp.array([[2.0, 1.0, 0.5, -1.0,  0.0],
                     [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                     [0.0, 0.0, 1.0, 0.0, 0.0]])
# First list has 4 valid items; second has 3 valid items
where  = jnp.array([[ True,  True,  True,  True, False],
                    [ True,  True,  True, False, False]])

# Default reduce_fn=jnp.mean -> scalar over batch
loss_mean = bt.metric.ranking_softmax_loss(logits, labels, where=where)
print('Mean loss (scalar):', loss_mean)

# Unreduced per-list losses
loss_per_list = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('Per-list loss:', loss_per_list)  # shape (batch,)
Mean loss (scalar): 0.6130267
Per-list loss: [0.49518192 0.73087144]
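One way to read the mask is that padded items are dropped from the softmax. The hypothetical helper masked_softmax_ce below is a sketch of that idea, not the braintools implementation: padded scores are set to -inf so they receive zero probability, and padded positions are excluded from the sum. Under this assumption it reproduces the per-list values printed above.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.5, -1.0,  0.0],
                    [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 0.0]])
where = jnp.array([[True, True, True, True, False],
                   [True, True, True, False, False]])

def masked_softmax_ce(logits, labels, where):
    """Hypothetical helper: listwise softmax cross-entropy ignoring padded items."""
    # Padded items get a score of -inf, so exp(-inf) = 0 removes them from the softmax.
    masked_logits = jnp.where(where, logits, -jnp.inf)
    log_probs = jax.nn.log_softmax(masked_logits, axis=-1)
    # labels * log_probs is NaN (0 * -inf) at padded positions; select 0.0 there instead.
    per_item = jnp.where(where, labels * log_probs, 0.0)
    return -jnp.sum(per_item, axis=-1)

per_list = masked_softmax_ce(logits, labels, where)
print(per_list)            # approximately [0.4952, 0.7309]
print(jnp.mean(per_list))  # approximately 0.6130
```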

3) Reductions: sum vs mean vs none

  • reduce_fn=jnp.sum: sum across the batch

  • reduce_fn=jnp.mean: average across the batch (default)

  • reduce_fn=None: return the unreduced loss, one value per list (shape equal to the leading batch dimensions)

If no item is valid at all (where is all-False) and the inputs contain no NaN, mean reduction returns 0.0 instead of NaN.

sum_loss  = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.sum)
mean_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=jnp.mean)
none_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, reduce_fn=None)
print('sum:',  sum_loss)
print('mean:', mean_loss)
print('none:', none_loss, ' sum(none)=', jnp.sum(none_loss))
sum: 1.2260534
mean: 0.6130267
none: [0.49518192 0.73087144]  sum(none)= 1.2260534
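The reductions and the NaN guard for the all-masked case can be sketched with plain jax.numpy. A masked jnp.mean over zero elements computes 0/0, which is NaN, so some guard is needed to return 0.0. The helper safe_masked_mean is hypothetical and assumes that lists with no valid item are excluded from the mean through a per-list mask; it mirrors the documented behavior but is not the library code.

```python
import jax.numpy as jnp

def safe_masked_mean(per_list, list_valid):
    """Hypothetical helper: mean over valid lists, 0.0 if none are valid.

    A NaN already present in the inputs is not hidden.
    """
    out = jnp.mean(per_list, where=list_valid)  # NaN when list_valid is all False
    inputs_ok = jnp.logical_not(jnp.any(jnp.isnan(per_list)))
    return jnp.where(jnp.isnan(out) & inputs_ok, 0.0, out)

per_list = jnp.array([0.49518192, 0.73087144])  # unreduced values from above

# sum and mean are both reductions of the same unreduced per-list values.
print(jnp.sum(per_list), jnp.mean(per_list))
print(safe_masked_mean(per_list, jnp.array([True, True])))    # same as the plain mean
print(safe_masked_mean(per_list, jnp.array([False, False])))  # 0.0 instead of NaN
```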

4) Per-item weights

Pass weights to emphasize specific items within each list. weights must have the same shape as labels and logits; it multiplies the labels before the softmax cross-entropy is computed.

weights = jnp.array([[1.0, 0.5, 0.5, 1.0, 0.0],
                     [1.0, 1.0, 2.0, 0.0, 0.0]])
weighted_loss = bt.metric.ranking_softmax_loss(logits, labels, where=where, weights=weights, reduce_fn=None)
print('Weighted per-list loss:', weighted_loss)
Weighted per-list loss: [0.49518192 1.4617429 ]
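Only the second list changes. Assuming the weights simply multiply the labels with no renormalization afterwards, the weighted target for list 2 is 2.0 at its relevant item, so its loss doubles from about 0.7309 to about 1.4617. The relevant item of list 1 has weight 1.0, so its loss is unchanged, and weights on items whose label is 0 have no effect. The sketch below checks this with the same masked cross-entropy idea; it illustrates the assumption and is not the library code.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.5, -1.0,  0.0],
                    [0.8, 0.3, 1.2, -2.0, -1.0]])
labels = jnp.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 0.0, 0.0]])
where = jnp.array([[True, True, True, True, False],
                   [True, True, True, False, False]])
weights = jnp.array([[1.0, 0.5, 0.5, 1.0, 0.0],
                     [1.0, 1.0, 2.0, 0.0, 0.0]])

# Assumption: weights scale the labels element-wise before the cross-entropy.
weighted_labels = labels * weights

# Padded items get -inf scores; their NaN products (0 * -inf) are replaced by 0.0.
log_probs = jax.nn.log_softmax(jnp.where(where, logits, -jnp.inf), axis=-1)
per_list = -jnp.sum(jnp.where(where, weighted_labels * log_probs, 0.0), axis=-1)
print(per_list)  # approximately [0.4952, 1.4617]
```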

5) Tips & Pitfalls

  • Shapes: the loss operates on the last dimension (…, list_size); any leading dimensions are batch dimensions.

  • where must be a boolean array of shape (…, list_size), matching logits and labels.

  • weights must match the labels/logits shape.

  • reduce_fn=None returns one loss per list; aggregate the values yourself if you need a custom reduction.

  • If no item is valid at all (where all-False), mean reduction returns 0.0 instead of NaN, provided the inputs contain no NaN.

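  • For large batches and long lists, wrap the code that calls this loss in jax.jit and pad every batch to a fixed list_size (marking padding with where), so shapes stay static and the function is not recompiled for every new list length.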
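The JIT tip can be sketched as follows: pad every list to a fixed list_size, build the where mask while padding, and JIT-compile the loss so it is traced once for that shape. pad_lists and LIST_SIZE are hypothetical names. The jitted function uses a stand-in masked cross-entropy so the block runs with JAX alone; in real code you would call bt.metric.ranking_softmax_loss inside the jitted function instead.

```python
import jax
import jax.numpy as jnp
import numpy as np

LIST_SIZE = 5  # hypothetical fixed padded length shared by every batch

def pad_lists(score_lists, label_lists, list_size=LIST_SIZE):
    """Hypothetical helper: pad variable-length lists and build the `where` mask."""
    n = len(score_lists)
    logits = np.zeros((n, list_size), dtype=np.float32)
    labels = np.zeros((n, list_size), dtype=np.float32)
    where = np.zeros((n, list_size), dtype=bool)
    for i, (scores, rels) in enumerate(zip(score_lists, label_lists)):
        k = len(scores)
        logits[i, :k] = scores
        labels[i, :k] = rels
        where[i, :k] = True
    return jnp.asarray(logits), jnp.asarray(labels), jnp.asarray(where)

@jax.jit
def batch_loss(logits, labels, where):
    # Stand-in for bt.metric.ranking_softmax_loss(logits, labels, where=where).
    # Assumes every list has at least one valid item.
    log_probs = jax.nn.log_softmax(jnp.where(where, logits, -jnp.inf), axis=-1)
    per_list = -jnp.sum(jnp.where(where, labels * log_probs, 0.0), axis=-1)
    return jnp.mean(per_list)

logits, labels, where = pad_lists(
    [[2.0, 1.0, 0.5, -1.0], [0.8, 0.3, 1.2]],
    [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
)
print(logits.shape, batch_loss(logits, labels, where))
```

Because every batch has shape (batch, LIST_SIZE), later calls with the same batch size reuse the compiled function.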