kl_divergence_with_log_targets


class braintools.metric.kl_divergence_with_log_targets(log_predictions, log_targets)

Compute KL divergence when both predictions and targets are in log-space.

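This is a numerically stable way to compute the KL divergence when both the target and the predicted distributions are given as log-probabilities. Because the log-probabilities are used directly, they never have to be exponentiated and then logged again, a round trip in which very small probabilities can underflow to zero and turn the logarithm into -inf.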

The computation uses the log-space formula:

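\[D_{KL}(P \,\|\, Q) = \sum_i \exp(\log P(i)) \cdot \left(\log P(i) - \log Q(i)\right)\]

where P is the target distribution (log_targets), Q is the predicted distribution (log_predictions), \(\exp(\log P(i)) = P(i)\), and the sum runs over the last (class) axis.
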
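To make the formula concrete, here is a minimal pure-JAX sketch of it. The name `kl_from_log_space` is hypothetical and is not part of braintools; it only illustrates the documented math, assuming P is the target and Q the prediction, with the class axis last.

```python
import jax.numpy as jnp


def kl_from_log_space(log_predictions, log_targets):
    """Hypothetical reference sketch of D_KL(P || Q) (not the braintools source).

    P is the target distribution and Q the predicted one; both arrive as
    log-probabilities with shape [..., num_classes]. The class axis is summed.
    """
    p = jnp.exp(log_targets)  # exp(log P(i)) = P(i), the weight of each term
    # log P(i) - log Q(i) is formed directly from the log-space inputs
    return jnp.sum(p * (log_targets - log_predictions), axis=-1)


# Batch of two examples: the first row matches the Examples section below,
# the second compares a distribution with itself, so its KL should be 0.
targets = jnp.array([[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]])
preds = jnp.array([[0.6, 0.3, 0.1], [0.25, 0.25, 0.5]])
kl = kl_from_log_space(jnp.log(preds), jnp.log(targets))
print(kl.shape)  # (2,): one value per distribution
```

The log-space form agrees with the familiar probability-space form \(\sum_i P(i) \log(P(i)/Q(i))\) whenever no probability underflows.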
Parameters:
  • log_predictions (Array | ndarray | bool | number | int | float | complex | Quantity) – Log probabilities of the predicted distribution with shape [..., num_classes]. Must be in log-space.

  • log_targets (Array | ndarray | bool | number | int | float | complex | Quantity) – Log probabilities of the target distribution with shape [..., num_classes]. Must be in log-space.

Returns:

KL divergence values with shape [...], i.e. the input shape with the last (class) axis summed out, giving one value per distribution.

Return type:

Array | ndarray | bool | number | int | float | complex | Quantity
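Both arguments must already be log-probabilities. When a model produces unnormalized scores (logits), one way to obtain valid inputs is `jax.nn.log_softmax`, which normalizes along the class axis without leaving log-space. The sketch below assumes hypothetical random logits of shape [batch, time, num_classes] and checks that each row is a valid log-distribution.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical unnormalized scores, shape [batch=2, time=4, num_classes=3]
logits_pred = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))
logits_true = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 3))

# log_softmax normalizes along the class axis and stays in log-space,
# so no probability is materialized and then logged again
log_predictions = jax.nn.log_softmax(logits_pred, axis=-1)
log_targets = jax.nn.log_softmax(logits_true, axis=-1)

# Sanity check: each row is a valid log-distribution, so logsumexp == 0
row_norms = logsumexp(log_predictions, axis=-1)
print(row_norms.shape)  # (2, 4): the [...] part of the input shape
```

Passing `log_predictions` and `log_targets` from this sketch to `braintools.metric.kl_divergence_with_log_targets` should, per the Returns entry above, yield one KL value per (batch, time) position, i.e. an array of shape (2, 4).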

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Both distributions in log-space
>>> log_targets = jnp.log(jnp.array([[0.7, 0.2, 0.1]]))
>>> log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1]]))
>>> kl_div = braintools.metric.kl_divergence_with_log_targets(log_preds, log_targets)
>>> print(f"KL divergence: {kl_div[0]:.4f}")
KL divergence: 0.0268

Notes

This function is preferred when working with very small probabilities or when both distributions are naturally available in log-space, as it provides better numerical stability.
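The stability benefit can be illustrated with a small pure-JAX comparison. This is a sketch of the general failure mode under default float32 precision, not a description of the braintools internals: one predicted class has log-probability -200, and exp(-200) underflows to 0.0 in float32, so a naive route that exponentiates and then re-takes the log produces -inf and an infinite divergence, while the log-space formula stays finite.

```python
import jax.numpy as jnp

# Both distributions given directly in log-space (float32 by default).
log_targets = jnp.log(jnp.array([0.5, 0.3, 0.2]))
# The last predicted class has probability exp(-200), far below the smallest
# float32 subnormal (about 1.4e-45), so exp(-200) rounds to 0.0 in float32.
log_preds = jnp.array([jnp.log(0.6), jnp.log(0.4), -200.0])

p = jnp.exp(log_targets)

# Naive route: leave log-space, then take the log again.
probs_preds = jnp.exp(log_preds)  # third entry underflows to 0.0
naive = jnp.sum(p * (log_targets - jnp.log(probs_preds)))  # log(0.0) = -inf

# Log-space route: use log Q(i) directly, as in the formula above.
stable = jnp.sum(p * (log_targets - log_preds))

print(naive)   # inf
print(stable)  # large but finite, roughly 39.5
```

With 64-bit floats enabled (`jax_enable_x64`), exp(-200) is representable and the naive route would not overflow in this particular case; smaller log-probabilities still trigger the same problem.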