convex_kl_divergence
- braintools.metric.convex_kl_divergence(log_predictions, targets)
Compute a convex version of the Kullback-Leibler divergence loss.
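This function computes a modified KL divergence that is jointly convex in the target probabilities P and the predicted probabilities Q, where Q is recovered from the log-space input as Q = exp(log_predictions). Compared with the standard KL divergence, it adds a linear correction term that keeps the loss non-negative even when P or Q is not normalized.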
The convex KL divergence is defined as:
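\[D_{\mathrm{convex}}(P \| Q) = D_{KL}(P \| Q) + \sum_i \big(Q(i) - P(i)\big), \qquad D_{KL}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}\]
where P is given by targets and Q(i) = exp(log_predictions[..., i]). The second term is linear in P and Q, so it preserves the joint convexity of D_{KL}. It vanishes whenever both P and Q sum to 1, and for unnormalized inputs it makes every summand P(i) log(P(i)/Q(i)) - P(i) + Q(i) non-negative.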
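The formula can be checked with a minimal NumPy sketch. This is not the braintools implementation: the helper names kl_reference and convex_kl_reference are hypothetical, and the sketch assumes the conventions above (predictions in log-space, targets as probabilities, reduction over the last axis, 0 · log 0 treated as 0). It shows that the correction term disappears for a normalized prediction and matters for an unnormalized one.

```python
import numpy as np

def kl_reference(log_q: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Hypothetical helper: sum_i P(i) * (log P(i) - log Q(i)), with 0 * log 0 treated as 0."""
    safe_log_p = np.log(np.where(p > 0, p, 1.0))  # log(1) = 0 wherever P(i) = 0
    return np.sum(p * (safe_log_p - log_q), axis=-1)

def convex_kl_reference(log_q: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Hypothetical helper: D_KL(P||Q) + sum_i (Q(i) - P(i)), with Q = exp(log_q)."""
    return kl_reference(log_q, p) + np.sum(np.exp(log_q) - p, axis=-1)

p = np.array([0.7, 0.2, 0.1])

# Normalized prediction: the correction term is (numerically) zero.
log_q = np.log(np.array([0.6, 0.3, 0.1]))
print(kl_reference(log_q, p), convex_kl_reference(log_q, p))  # both approximately 0.0268

# Unnormalized prediction (sums to 3): the standard KL becomes negative,
# while the convex version stays non-negative.
log_q_unnorm = np.log(np.array([1.0, 1.0, 1.0]))
print(kl_reference(log_q_unnorm, p), convex_kl_reference(log_q_unnorm, p))
```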
- Parameters:
  - log_predictions (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Log probabilities of the predicted distribution with shape [..., num_classes]. Must be in log-space.
  - targets (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Probabilities of the target distribution with shape [..., num_classes]. Values should be non-negative and sum to 1 along the last axis.
- Returns:
  Convex KL divergence values with shape [...], i.e. the input shape with the last (class) axis reduced.
- Return type:
  Array|ndarray|bool|number|bool|int|float|complex|Quantity
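The shape contract (inputs [..., num_classes], output [...]) can be illustrated with a small NumPy sketch. The helper convex_kl_reference is hypothetical and only mirrors the documented formula; the batch shape (2, 5) and num_classes = 3 are arbitrary choices, and a log-softmax is used here as one assumed way to produce valid log-space predictions.

```python
import numpy as np

def convex_kl_reference(log_q: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Hypothetical helper: sum over the last axis of P*(log P - log Q) + Q - P."""
    safe_log_p = np.log(np.where(p > 0, p, 1.0))  # treat 0 * log 0 as 0
    return np.sum(p * (safe_log_p - log_q) + np.exp(log_q) - p, axis=-1)

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 3))  # leading dims (2, 5), num_classes = 3

# Log-softmax over the last axis: normalized log-probabilities
log_q = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))

p = np.full((2, 5, 3), 1.0 / 3.0)  # uniform targets; each row sums to 1

out = convex_kl_reference(log_q, p)
print(out.shape)  # (2, 5): the class axis is reduced
```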
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> targets = jnp.array([[0.7, 0.2, 0.1]])
>>> log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1]]))
>>> conv_kl = braintools.metric.convex_kl_divergence(log_preds, targets)
>>> print(f"Convex KL divergence: {conv_kl[0]:.4f}")
Notes
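The convexity property can benefit optimization algorithms that rely on convexity guarantees. When both targets and exp(log_predictions) sum to 1 along the last axis, the correction term is zero and the result equals the standard KL divergence; the two differ only for unnormalized inputs.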
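Joint convexity in (P, Q) can be spot-checked numerically with the chord inequality f(tA + (1 - t)B) ≤ t f(A) + (1 - t) f(B). The NumPy sketch below is a sanity check on random positive (not necessarily normalized) vectors, not a proof; the helper convex_kl_from_probs is hypothetical and takes probabilities Q directly rather than log-probabilities, since the convexity statement is about Q.

```python
import numpy as np

def convex_kl_from_probs(q: np.ndarray, p: np.ndarray) -> np.ndarray:
    """Hypothetical helper: sum_i P(i) log(P(i)/Q(i)) - P(i) + Q(i) for positive P, Q."""
    return np.sum(p * np.log(p / q) - p + q, axis=-1)

rng = np.random.default_rng(0)
violations = 0
for _ in range(1000):
    # Two random points (p_a, q_a) and (p_b, q_b) in the positive orthant
    p_a, q_a, p_b, q_b = rng.uniform(0.05, 1.0, size=(4, 5))
    t = rng.uniform()
    # Value at the convex combination of the two points
    mixed = convex_kl_from_probs(t * q_a + (1 - t) * q_b, t * p_a + (1 - t) * p_b)
    # Convex combination of the values at the two points
    chord = t * convex_kl_from_probs(q_a, p_a) + (1 - t) * convex_kl_from_probs(q_b, p_b)
    if mixed > chord + 1e-12:  # small tolerance for floating-point rounding
        violations += 1

print("violations:", violations)
```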
References