kl_divergence_with_log_targets
- class braintools.metric.kl_divergence_with_log_targets(log_predictions, log_targets)
Compute KL divergence when both predictions and targets are in log-space.
This is a numerically stable way to compute the KL divergence when both the target and the predicted distributions are given as log-probabilities. Staying in log-space avoids the underflow that can occur when very small probabilities are represented directly.
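The computation uses the log-space formula below. P is the target distribution, Q is the predicted distribution, and the sum runs over the last (class) axis:
\[D_{KL}(P||Q) = \sum_i \exp(\log P(i)) \cdot (\log P(i) - \log Q(i))\]
- Parameters: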
log_predictions (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Log probabilities of the predicted distribution, with shape [..., num_classes]. Must be in log-space.
log_targets (Array|ndarray|bool|number|bool|int|float|complex|Quantity) – Log probabilities of the target distribution, with shape [..., num_classes]. Must be in log-space.
- Returns:
KL divergence values with shape [...]. There is one value per distribution, because the num_classes axis is summed out.
- Return type:
Array|ndarray|bool|number|bool|int|float|complex|Quantity
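The formula and the shape contract above can be cross-checked with a minimal sketch in plain JAX. The sketch assumes the sum runs over the last (num_classes) axis, as the input and output shapes indicate. `kl_log_space` is a hypothetical helper written for illustration only. It is not the braintools implementation.

```python
import jax.numpy as jnp


def kl_log_space(log_predictions, log_targets):
    """Hypothetical reference helper: KL(P || Q) with P = targets, Q = predictions.

    Both inputs are log-probabilities of shape [..., num_classes];
    summing over the last axis yields shape [...].
    """
    # exp(log P(i)) * (log P(i) - log Q(i)), summed over the class axis
    return jnp.sum(jnp.exp(log_targets) * (log_targets - log_predictions), axis=-1)


# A batch of two distributions over three classes.
log_targets = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                                 [0.25, 0.25, 0.5]]))
log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1],
                               [0.25, 0.25, 0.5]]))

kl = kl_log_space(log_preds, log_targets)
print(kl.shape)  # (2,)
print(kl)        # approximately [0.0268, 0.0]; identical distributions give 0
```

The second row has identical targets and predictions, so its divergence is zero. This is a quick sanity check for any implementation of the formula.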
Examples
>>> import jax.numpy as jnp
>>> import braintools
>>> # Both distributions in log-space
>>> log_targets = jnp.log(jnp.array([[0.7, 0.2, 0.1]]))
>>> log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1]]))
>>> kl_div = braintools.metric.kl_divergence_with_log_targets(log_preds, log_targets)
>>> print(f"KL divergence: {kl_div[0]:.4f}")
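The example can be worked by hand with the log-space formula, taking P as the targets, Q as the predictions and natural logarithms. The result is roughly 0.0268. The last printed digit may differ slightly in float32.

\[D_{KL}(P||Q) = 0.7\ln\frac{0.7}{0.6} + 0.2\ln\frac{0.2}{0.3} + 0.1\ln\frac{0.1}{0.1} \approx 0.7(0.1542) + 0.2(-0.4055) + 0 \approx 0.0268\]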
Notes
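Prefer this function when working with very small probabilities, or when both distributions are already available in log-space (for example, as outputs of a log-softmax). Using the log-probabilities directly avoids a round trip through probability space, where tiny values can underflow to zero.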
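The stability benefit can be illustrated with a minimal sketch in plain JAX. It compares two paths from the same logits. The first keeps everything in log-space via `jax.nn.log_softmax`. The second computes probabilities with `jax.nn.softmax` and then takes their log. The logits are chosen so that one predicted probability (about e^-200) underflows in float32. `kl_log_space` is a hypothetical helper implementing the formula above, not the braintools implementation.

```python
import jax
import jax.numpy as jnp


def kl_log_space(log_predictions, log_targets):
    # Hypothetical helper: sum_i exp(log P(i)) * (log P(i) - log Q(i)) over the last axis
    return jnp.sum(jnp.exp(log_targets) * (log_targets - log_predictions), axis=-1)


# The second predicted probability is about e^-200, far below float32's smallest value.
pred_logits = jnp.array([0.0, -200.0])
target_logits = jnp.array([0.0, 0.0])

# Log-space path: log_softmax keeps -200 as a finite log-probability.
kl_log = kl_log_space(jax.nn.log_softmax(pred_logits),
                      jax.nn.log_softmax(target_logits))

# Probability-space path: softmax rounds e^-200 to 0.0, and log(0.0) is -inf.
kl_naive = kl_log_space(jnp.log(jax.nn.softmax(pred_logits)),
                        jnp.log(jax.nn.softmax(target_logits)))

print(kl_log)    # finite: 100 - ln 2, about 99.307
print(kl_naive)  # inf
```

In the probability-space path, the target still puts mass 0.5 on a class whose predicted probability has been rounded to zero, so the divergence blows up to infinity. The log-space path keeps that log-probability at -200 and returns a finite value.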