ranking_softmax_loss


braintools.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=<function mean>)

Compute ranking softmax loss for learning-to-rank applications.

Calculates a differentiable ranking loss that measures how well the ranking induced by item scores agrees with ground-truth relevance labels. It is intended for learning-to-rank tasks such as information retrieval and recommendation, where the goal is to place relevant items above less relevant ones.

The loss is computed as the negative log-likelihood of the softmax distribution over items, weighted by their relevance labels:

\[\ell(s, y) = -\sum_{i=1}^{n} y_i \log \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)}\]

where \(s_i\) are the logit scores, \(y_i\) are the relevance labels (after the masking, weighting, and normalization described under Parameters), and \(n\) is the number of items in the list.
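As a minimal sketch, the equation can be evaluated literally for a single list with NumPy. The helper names log_softmax and softmax_loss_formula are hypothetical, and the labels are used exactly as given (no masking, weighting, or normalization), so the value need not match what braintools.metric.ranking_softmax_loss returns for the same inputs.

```python
import numpy as np


def log_softmax(s):
    """Log-softmax over the last axis, shifted by the max logit for stability."""
    s = np.asarray(s, dtype=np.float64)
    m = np.max(s, axis=-1, keepdims=True)
    return s - (m + np.log(np.sum(np.exp(s - m), axis=-1, keepdims=True)))


def softmax_loss_formula(s, y):
    """Hypothetical helper: -sum_i y_i * log softmax(s)_i, term by term."""
    y = np.asarray(y, dtype=np.float64)
    return -np.sum(y * log_softmax(s), axis=-1)


s = np.array([2.0, 1.0, 3.0])  # logit scores
y = np.array([1.0, 0.0, 2.0])  # raw relevance labels
print(softmax_loss_formula(s, y))  # about 2.2228
```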

Parameters:
  • logits (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Predicted scores for each item with shape (..., list_size). Higher scores should indicate higher relevance. The function operates on the last dimension, treating leading dimensions as batch dimensions.

  • labels (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Ground-truth relevance labels with shape (..., list_size). Typically non-negative, with higher values indicating greater relevance. Labels are normalized over the list before the loss is computed, so only their relative magnitudes matter.

  • where (Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Boolean mask with shape (..., list_size) indicating valid items. Items whose mask entry is False are excluded from the loss computation, which is useful for variable-length lists or missing data.

  • weights (Array | ndarray | bool | number | bool | int | float | complex | Quantity | None) – Per-item weights with shape (..., list_size) for emphasizing certain items. They scale the labels before the softmax cross-entropy is computed.

  • reduce_fn (Callable[..., Array | ndarray | bool | number | bool | int | float | complex | Quantity] | None) –

    Function that reduces the per-list losses over the batch dimensions. Common choices:

    • jax.numpy.mean (default): average the per-list losses

    • jax.numpy.sum: sum the per-list losses

    • None: return the unreduced per-list losses
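The NumPy sketch below shows one plausible way where and weights combine with the formula, assuming (as the parameter descriptions state) that masked items are dropped, weights scale the labels, and labels are then normalized to sum to one over the valid items. The helper per_list_loss is hypothetical and not part of braintools; returning 0.0 for lists with no valid items or no positive label mass is a choice made for this sketch.

```python
import numpy as np


def per_list_loss(logits, labels, where=None, weights=None):
    """Hypothetical reference sketch: one loss per list, shape logits.shape[:-1]."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=logits.dtype)  # cast labels to the logit dtype
    if where is None:
        where = np.ones(logits.shape, dtype=bool)
    where = np.asarray(where, dtype=bool)
    if weights is not None:
        labels = labels * np.asarray(weights, dtype=logits.dtype)  # weights scale labels
    labels = np.where(where, labels, 0.0)  # masked items carry no label mass

    # Log-softmax over valid items only: masked logits behave like -inf.
    masked = np.where(where, logits, -np.inf)
    m = np.max(masked, axis=-1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)  # guard lists with no valid items
    with np.errstate(divide="ignore"):
        lse = m + np.log(np.sum(np.exp(masked - m), axis=-1, keepdims=True))
    log_p = np.where(where, logits - lse, 0.0)

    # Assumed normalization: labels sum to one per list; lists with no
    # positive label mass get all-zero targets and hence a loss of 0.0.
    total = np.sum(labels, axis=-1, keepdims=True)
    y_hat = np.where(total > 0, labels / np.where(total > 0, total, 1.0), 0.0)
    return -np.sum(y_hat * log_p, axis=-1)


logits = np.array([[2.0, 1.0, 0.0], [1.0, 0.5, 1.5]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
where = np.array([[True, True, False], [True, True, True]])
print(per_list_loss(logits, labels, where=where))  # one loss per query, shape (2,)
```

With this construction, masking an item gives the same loss as deleting it from the list, and multiplying all labels of a list by a constant leaves its loss unchanged.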

Returns:

Ranking softmax loss. Shape depends on reduce_fn:

  • If reduce_fn is a full reduction such as jax.numpy.mean or jax.numpy.sum: a scalar loss value

  • If reduce_fn is None: an array of per-list losses with shape logits.shape[:-1]

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity
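To make the shape rules concrete, here is a small NumPy sketch of the reduction step alone. Per-list losses (computed with the literal formula, a simplification for illustration) have shape logits.shape[:-1], and a full reduction such as a mean or sum collapses them to a scalar. The helpers per_list_losses and reduce_losses are hypothetical.

```python
import numpy as np


def per_list_losses(logits, labels):
    """Hypothetical: literal formula per list; output shape is logits.shape[:-1]."""
    m = np.max(logits, axis=-1, keepdims=True)
    log_p = logits - (m + np.log(np.sum(np.exp(logits - m), axis=-1, keepdims=True)))
    return -np.sum(labels * log_p, axis=-1)


def reduce_losses(losses, reduce_fn=np.mean):
    """reduce_fn=None mirrors 'no reduction': return the per-list losses as-is."""
    return losses if reduce_fn is None else reduce_fn(losses)


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 2, 5))  # batch dims (4, 2), list_size 5
labels = rng.integers(0, 3, size=(4, 2, 5)).astype(float)

print(reduce_losses(per_list_losses(logits, labels), reduce_fn=None).shape)  # (4, 2)
print(np.ndim(reduce_losses(per_list_losses(logits, labels))))  # 0, i.e. a scalar
```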

Notes

This loss function implements a probabilistic approach to ranking where:

  • Items with higher relevance labels should receive higher probability mass

  • The softmax operation ensures valid probability distributions

  • Masked items (those whose where entry is False) are ignored

  • The loss is differentiable w.r.t. logits, enabling gradient-based optimization
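When the labels are normalized to ŷ with sum one, the gradient of the loss with respect to the logits has the closed form ∂ℓ/∂s_i = softmax(s)_i - ŷ_i, so the gradient pushes probability mass toward relevant items and its components sum to zero. The NumPy sketch below checks this closed form against central finite differences; it assumes the normalized-label form of the loss, and the helper names are hypothetical.

```python
import numpy as np


def loss(s, y_hat):
    """-sum_i y_hat_i * log softmax(s)_i for a single list."""
    m = s.max()
    return -np.sum(y_hat * (s - m - np.log(np.sum(np.exp(s - m)))))


def analytic_grad(s, y_hat):
    """Closed-form gradient softmax(s) - y_hat (valid when y_hat sums to one)."""
    p = np.exp(s - s.max())
    return p / p.sum() - y_hat


def numeric_grad(s, y_hat, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(s)
    for i in range(s.size):
        e = np.zeros_like(s)
        e[i] = eps
        g[i] = (loss(s + e, y_hat) - loss(s - e, y_hat)) / (2 * eps)
    return g


s = np.array([2.0, 1.0, 3.0])
y = np.array([1.0, 0.0, 2.0])
y_hat = y / y.sum()  # labels normalized to sum to one
print(analytic_grad(s, y_hat))
print(np.max(np.abs(analytic_grad(s, y_hat) - numeric_grad(s, y_hat))))  # close to zero
```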

The function handles edge cases gracefully:

  • Empty masks (all items invalid) return 0.0 instead of NaN

  • Numerical stability is maintained through log-softmax computation

  • Mixed data types are handled by casting labels to match logit precision
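The stability point is easy to see numerically. Evaluating exp(s_i) / Σ_j exp(s_j) directly overflows once logits reach the hundreds, while log-softmax computed after subtracting the maximum logit (log-softmax is unchanged when a constant is added to every logit) stays finite. This NumPy sketch illustrates the general technique, not the library's code; the helper names are hypothetical.

```python
import numpy as np


def naive_loss(s, y):
    """Direct transcription of the formula; overflows for large logits."""
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        p = np.exp(s) / np.sum(np.exp(s))
        return -np.sum(y * np.log(p))


def stable_loss(s, y):
    """Same loss via log-softmax with the max logit subtracted first."""
    m = np.max(s)
    log_p = s - m - np.log(np.sum(np.exp(s - m)))
    return -np.sum(y * log_p)


y = np.array([1.0, 0.0, 2.0])
small = np.array([2.0, 1.0, 3.0])
large = small + 998.0  # [1000, 999, 1001]

print(naive_loss(large, y))   # nan: exp(1000) overflows to inf
print(stable_loss(large, y))  # finite, equal to the value for `small`
print(stable_loss(small, y))
```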

Examples

Basic ranking loss with single query:

>>> import jax.numpy as jnp
>>> import braintools
>>> # Scores for 3 items
>>> logits = jnp.array([2.0, 1.0, 3.0])
>>> # Relevance: item 3 most relevant, item 1 second, item 2 least
>>> labels = jnp.array([1.0, 0.0, 2.0])
>>> loss = braintools.metric.ranking_softmax_loss(logits, labels)
>>> print(f"Loss: {loss:.4f}")

Batch processing with masking:

>>> # Batch of 2 queries with 3 items each
>>> logits = jnp.array([[2.0, 1.0, 0.0], [1.0, 0.5, 1.5]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
>>> # First query: only its first 2 items are valid
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> loss = braintools.metric.ranking_softmax_loss(logits, labels, where=where)
>>> print(f"Batch loss: {loss:.4f}")

Per-item weighting:

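>>> # Weights scale the labels, so emphasize an item with a nonzero label
>>> item_logits = jnp.array([2.0, 1.0, 3.0])
>>> item_labels = jnp.array([1.0, 1.0, 0.0])
>>> weights = jnp.array([1.0, 2.0, 1.0])  # Double the weight of the second item
>>> loss = braintools.metric.ranking_softmax_loss(item_logits, item_labels, weights=weights)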
>>> print(f"Weighted loss: {loss:.4f}")

Unreduced losses for analysis:

>>> batch_losses = braintools.metric.ranking_softmax_loss(
...     logits, labels, where=where, reduce_fn=None
... )
>>> print(f"Individual losses: {batch_losses}")

See also

jax.nn.log_softmax

Underlying log-softmax computation

optax.softmax_cross_entropy

Related cross-entropy function

braintools.metric.sigmoid_binary_cross_entropy

Alternative for binary relevance

References