ranking_softmax_loss
- braintools.metric.ranking_softmax_loss(logits, labels, *, where=None, weights=None, reduce_fn=<function mean>)
Compute ranking softmax loss for learning-to-rank applications.
Calculates a differentiable ranking loss that measures how well the ranking induced by the item scores agrees with the ground-truth relevance labels. It is commonly used in information retrieval, recommendation systems, and other ranking tasks where the goal is to place relevant items first.
The loss is computed as the negative log-likelihood of the softmax distribution over items, weighted by their relevance labels:
\[\ell(s, y) = -\sum_{i=1}^{n} y_i \log \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)}\]
where \(s_i\) are the logit scores, \(y_i\) are the relevance labels, and \(n\) is the number of items in the list. When the where argument masks items out, both sums run only over the remaining valid items.
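The unmasked, unweighted case can be written directly with jax.nn.log_softmax as a sanity check on the formula. The helper ranking_softmax_loss_ref below is hypothetical and sketches only the math; it is not the braintools implementation. It ignores where, weights, and reduce_fn.

```python
import jax
import jax.numpy as jnp


def ranking_softmax_loss_ref(logits, labels):
    """Hypothetical helper: -sum_i y_i * log softmax(s)_i over the last axis."""
    labels = labels.astype(logits.dtype)  # match logit precision
    log_probs = jax.nn.log_softmax(logits, axis=-1)  # s_i - log sum_j exp(s_j)
    return -jnp.sum(labels * log_probs, axis=-1)


logits = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([1.0, 0.0, 2.0])
loss = ranking_softmax_loss_ref(logits, labels)

# The same value written out term by term from the formula.
log_z = jnp.log(jnp.sum(jnp.exp(logits)))
manual = -(1.0 * (logits[0] - log_z) + 2.0 * (logits[2] - log_z))
print(float(loss), float(manual))  # both approximately 2.2228
```

The labels are not normalized in this formula. Scaling every label by a constant therefore scales the loss by the same constant.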
- Parameters:
logits (Array | ndarray | bool | number | int | float | complex | Quantity) – Predicted scores for each item, with shape (..., list_size). Higher scores should indicate higher relevance. The function operates on the last dimension and treats leading dimensions as batch dimensions.
labels (Array | ndarray | bool | number | int | float | complex | Quantity) – Ground-truth relevance labels with shape (..., list_size). Typically non-negative, with higher values indicating greater relevance. Labels are not passed through the softmax. As in the formula above, they weight each item's log-softmax term, and they are cast to the dtype of logits.
where (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Boolean mask with shape (..., list_size) indicating valid items. Items for which the mask is False are excluded from the loss computation. This is useful for variable-length lists or missing data.
weights (Array | ndarray | bool | number | int | float | complex | Quantity | None) – Per-item weights with shape (..., list_size) for emphasizing certain items. They multiply the labels before the softmax cross-entropy is computed.
reduce_fn (Callable[..., Array | ndarray | bool | number | int | float | complex | Quantity] | None) – Function that reduces the per-list losses across batch dimensions. Common choices:
  jax.numpy.mean (default): average the loss over lists
  jax.numpy.sum: sum the loss over lists
  None: return the unreduced per-list losses
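The following is a minimal sketch of how a where mask can exclude items. It assumes masked entries are given a zero label and a logit of -inf, so they drop out of both the softmax normalizer and the outer sum. The helper masked_ranking_softmax_loss is hypothetical, and braintools may implement masking differently. Under this assumption, masking the last item should give the same loss as removing it from the list.

```python
import jax
import jax.numpy as jnp


def masked_ranking_softmax_loss(logits, labels, where=None):
    """Hypothetical helper: per-list ranking softmax loss with a validity mask."""
    labels = labels.astype(logits.dtype)  # match logit precision
    if where is None:
        where = jnp.ones(logits.shape, dtype=bool)
    # Masked items get label 0 and logit -inf, so they contribute nothing to
    # the softmax normalizer sum_j exp(s_j) or to the outer sum over i.
    labels = jnp.where(where, labels, 0.0)
    logits = jnp.where(where, logits, -jnp.inf)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # 0 * (-inf) is NaN, so keep only the valid terms before summing.
    terms = jnp.where(where, labels * log_probs, 0.0)
    return -jnp.sum(terms, axis=-1)


logits = jnp.array([2.0, 1.0, 0.0])
labels = jnp.array([1.0, 0.0, 0.0])
where = jnp.array([True, True, False])

masked = masked_ranking_softmax_loss(logits, labels, where=where)
truncated = masked_ranking_softmax_loss(logits[:2], labels[:2])
print(float(masked), float(truncated))  # equal: the third item is ignored
```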
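- Returns:
Ranking softmax loss. Its shape depends on reduce_fn:
If reduce_fn is a full reduction such as jax.numpy.mean or jax.numpy.sum: a scalar loss value
If reduce_fn is None: an array of per-list losses with the batch shape, i.e. logits.shape[:-1]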
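The shapes involved can be illustrated without calling the library. The sketch below computes the per-list losses straight from the formula, assuming no mask or weights, and then applies the two common reductions. It shows the shapes only; it is not a call into braintools.

```python
import jax
import jax.numpy as jnp

# Batch of 2 lists with 3 items each: batch shape (2,), list_size 3.
logits = jnp.array([[2.0, 1.0, 0.0], [1.0, 0.5, 1.5]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# Per-list losses (the reduce_fn=None case): shape logits.shape[:-1].
per_list = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
print(per_list.shape)  # (2,)

# jnp.mean and jnp.sum over all axes collapse the batch to a scalar.
print(jnp.mean(per_list).shape, jnp.sum(per_list).shape)  # () ()
```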
- Return type:
Array | ndarray | bool | number | int | float | complex | Quantity
Notes
This loss function implements a probabilistic approach to ranking where:
Items with higher relevance labels should receive higher probability mass
The softmax operation ensures valid probability distributions
Masked items (entries for which the where mask is False) are ignored in the loss
The loss is differentiable with respect to the logits, enabling gradient-based optimization
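For the unmasked case, differentiating the formula gives \(\partial\ell/\partial s_k = (\sum_i y_i)\,\mathrm{softmax}(s)_k - y_k\). The sketch below uses a hypothetical helper, loss_fn, to check that jax.grad agrees with this closed form. The gradient pushes up the logits of items whose share of the label mass exceeds their predicted probability, and it sums to zero across items.

```python
import jax
import jax.numpy as jnp


def loss_fn(logits, labels):
    # Hypothetical helper: unmasked, unweighted per-list loss from the formula.
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)


logits = jnp.array([2.0, 1.0, 3.0])
labels = jnp.array([1.0, 0.0, 2.0])

grad = jax.grad(loss_fn)(logits, labels)  # gradient w.r.t. logits only
# Closed form: d loss / d s_k = (sum_i y_i) * softmax(s)_k - y_k
closed_form = jnp.sum(labels) * jax.nn.softmax(logits) - labels
print(grad, closed_form)
```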
The function handles edge cases gracefully:
Empty masks (all items invalid) return 0.0 instead of NaN
Numerical stability is maintained through log-softmax computation
Mixed data types are handled by casting labels to match logit precision
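Both numerical points can be illustrated with plain JAX. The helper safe_masked_mean is hypothetical and assumes one way of getting the 0.0-instead-of-NaN behaviour: it averages only over lists with at least one valid item, then maps the NaN from an empty average to 0.0. The per-list loss values are placeholders. The second part shows why log-softmax is preferred over taking the log of a softmax computed by hand.

```python
import jax
import jax.numpy as jnp


def safe_masked_mean(per_list_loss, valid_list):
    """Hypothetical helper: mean over lists that have at least one valid item.

    With no valid list, the mean over zero elements is NaN (0 / 0), so it is
    replaced by 0.0.
    """
    out = jnp.mean(per_list_loss, where=valid_list)
    return jnp.where(jnp.isnan(out), 0.0, out)


where = jnp.array([[True, True, False], [False, False, False]])
per_list_loss = jnp.array([0.31, 0.0])  # placeholder values; list 2 is empty
valid_list = jnp.any(where, axis=-1)    # [True, False]

mean_some = safe_masked_mean(per_list_loss, valid_list)                   # approx. 0.31
mean_none = safe_masked_mean(per_list_loss, jnp.zeros(2, dtype=bool))     # 0.0, not NaN

# Stability: exp(1000) overflows float32, so the naive form produces NaN.
big = jnp.array([1000.0, 0.0])
naive = jnp.log(jnp.exp(big) / jnp.sum(jnp.exp(big)))  # [nan, -inf]
stable = jax.nn.log_softmax(big)                        # [0., -1000.]
print(mean_some, mean_none, naive, stable)
```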
Examples
Basic ranking loss with a single query:
>>> import jax.numpy as jnp
>>> import braintools
>>> # Scores for 3 items
>>> logits = jnp.array([2.0, 1.0, 3.0])
>>> # Relevance: item 3 most relevant, item 1 second, item 2 least
>>> labels = jnp.array([1.0, 0.0, 2.0])
>>> loss = braintools.metric.ranking_softmax_loss(logits, labels)
>>> print(f"Loss: {loss:.4f}")
Batch processing with masking:
>>> # Batch of 2 queries with 3 items each
>>> logits = jnp.array([[2.0, 1.0, 0.0], [1.0, 0.5, 1.5]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
>>> # First query has only its first 2 items valid; all items of the second are valid
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> loss = braintools.metric.ranking_softmax_loss(logits, labels, where=where)
>>> print(f"Batch loss: {loss:.4f}")
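Per-item weighting:
>>> # Weights scale the labels, so emphasize the relevant (first) item;
>>> # upweighting an item whose label is 0 would not change the loss
>>> weights = jnp.array([2.0, 1.0, 1.0])
>>> loss = braintools.metric.ranking_softmax_loss(logits[0], labels[0], weights=weights)
>>> print(f"Weighted loss: {loss:.4f}")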
Unreduced losses for analysis:
>>> batch_losses = braintools.metric.ranking_softmax_loss(
...     logits, labels, where=where, reduce_fn=None
... )
>>> print(f"Individual losses: {batch_losses}")
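The per-item weighting example relies on weights multiplying the labels, so a weighted loss equals the unweighted loss computed on labels * weights. One consequence is that upweighting an item whose label is 0 leaves the loss unchanged. The sketch below shows this with a hypothetical helper, per_list_loss, and assumes no mask.

```python
import jax
import jax.numpy as jnp


def per_list_loss(logits, labels):
    # Hypothetical helper: unmasked per-list loss from the formula.
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)


logits = jnp.array([2.0, 1.0, 0.0])
labels = jnp.array([1.0, 0.0, 0.0])

w_relevant = jnp.array([2.0, 1.0, 1.0])    # upweights the relevant item
w_irrelevant = jnp.array([1.0, 2.0, 1.0])  # upweights an item with label 0

loss_unweighted = per_list_loss(logits, labels)
loss_relevant = per_list_loss(logits, labels * w_relevant)      # twice the unweighted loss
loss_irrelevant = per_list_loss(logits, labels * w_irrelevant)  # same as unweighted
print(loss_unweighted, loss_relevant, loss_irrelevant)
```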
See also
jax.nn.log_softmax: Underlying log-softmax computation
optax.softmax_cross_entropy: Related cross-entropy function
braintools.metric.sigmoid_binary_cross_entropy: Alternative for binary relevance
References