kl_divergence


class braintools.metric.kl_divergence(log_predictions, targets)

Compute the Kullback-Leibler divergence (relative entropy) loss.

KL divergence measures the information lost when the predicted distribution is used to approximate the target distribution. It is always non-negative and equals zero exactly when the two distributions are identical.

The KL divergence is defined as:

\[D_{KL}(P||Q) = \sum_i P(i) \log\frac{P(i)}{Q(i)}\]

where P is the target distribution and Q is the predicted distribution.
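Because the predicted distribution is supplied as log-probabilities, it helps to split the logarithm of the ratio. Each term then uses log Q(i) directly, which is what log_predictions holds, and terms with P(i) = 0 are taken to be zero by the usual convention 0 log 0 = 0:

\[D_{KL}(P||Q) = \sum_i P(i)\left[\log P(i) - \log Q(i)\right]\]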

Parameters:
  • log_predictions (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Log-probabilities of the predicted distribution with shape [..., num_classes]. Must be given in log space (for example, the output of jax.nn.log_softmax) to avoid numerical underflow.

  • targets (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – Probabilities of the target distribution with shape [..., num_classes]. Values should be non-negative and sum to 1 along the last axis. Zero entries are allowed; their terms contribute zero to the divergence.
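Why log space matters can be seen with a small sketch, assuming JAX's default 32-bit floats. Normalizing logits first and taking the log afterwards can underflow a tiny probability to exactly 0, whose log is -inf; any non-zero target on that class would then make the divergence infinite. Computing log-probabilities directly with jax.nn.log_softmax keeps the value finite. The logits here are illustrative only.

```python
import jax
import jax.numpy as jnp

# Logits for two classes; the second class is extremely unlikely.
logits = jnp.array([0.0, -200.0])  # float32 by default in JAX

# Naive route: normalize first, then take the log.
# exp(-200) ~ 1.4e-87 is far below the smallest float32 value, so the
# probability underflows to 0 and its log becomes -inf.
naive_log_probs = jnp.log(jax.nn.softmax(logits))

# Stable route: compute log-probabilities directly.
stable_log_probs = jax.nn.log_softmax(logits)

print(naive_log_probs)   # second entry is -inf
print(stable_log_probs)  # second entry is about -200
```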

Returns:

KL divergence values with shape [...], i.e. the input shape with the last (class) axis summed out, giving one divergence per pair of target and predicted distributions.
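The following is a minimal sketch of how the documented behaviour could be computed with jax.numpy. The name kl_divergence_sketch is hypothetical and not part of braintools, and masking zero targets with jnp.where is an assumed way of implementing the zero-probability rule, not the library's confirmed source. Only the last (class) axis is reduced, so a batch of two distributions yields a result of shape (2,).

```python
import jax.numpy as jnp


def kl_divergence_sketch(log_predictions, targets):
    """Hypothetical reference computation, not the braintools source.

    log_predictions: log Q(i), shape [..., num_classes]
    targets:         P(i),     shape [..., num_classes]
    Returns one divergence per distribution, shape [...].
    """
    positive = targets > 0
    # log P(i) where P(i) > 0; a dummy value of log(1) = 0 elsewhere so
    # that log(0) is never evaluated.
    log_targets = jnp.log(jnp.where(positive, targets, 1.0))
    # Zero-probability targets contribute exactly 0 (0 * log 0 := 0),
    # even if the matching log-prediction is -inf.
    terms = jnp.where(positive, targets * (log_targets - log_predictions), 0.0)
    # Reduce only the class axis; leading (batch) axes are kept.
    return jnp.sum(terms, axis=-1)


targets = jnp.array([[0.7, 0.2, 0.1],
                     [0.5, 0.5, 0.0]])
log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1],
                               [0.4, 0.4, 0.2]]))

kl = kl_divergence_sketch(log_preds, targets)
print(kl.shape)   # (2,)
loss = kl.mean()  # one possible scalar training loss: the batch mean
```

Averaging over the batch is only one way to obtain a scalar loss; a sum or a weighted mean works the same way on the per-distribution values.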

Return type:

Array | ndarray | bool | number | bool | int | float | complex | Quantity

Examples

>>> import jax.numpy as jnp
>>> import braintools
>>> # Target and predicted distributions
>>> targets = jnp.array([[0.7, 0.2, 0.1]])
>>> log_preds = jnp.log(jnp.array([[0.6, 0.3, 0.1]]))
>>> kl_div = braintools.metric.kl_divergence(log_preds, targets)
>>> print(f"KL divergence: {kl_div[0]:.4f}")
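As a hand check on this example, and assuming the function evaluates the definition above with the natural logarithm (the example builds its log-probabilities with jnp.log), the single pair of distributions gives roughly 0.0268 nats:

\[D_{KL}(P||Q) = 0.7\ln\frac{0.7}{0.6} + 0.2\ln\frac{0.2}{0.3} + 0.1\ln\frac{0.1}{0.1} \approx 0.10791 - 0.08109 + 0 \approx 0.0268\]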

Notes

KL divergence is not symmetric: in general, D_KL(P||Q) ≠ D_KL(Q||P). It measures the information lost when Q is used to approximate P. Entries where the target probability P(i) is zero contribute zero to the sum, following the convention 0 log 0 = 0.
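The asymmetry can be checked numerically with the distributions from the example. This sketch uses a small hypothetical helper, kl, that evaluates the definition directly and assumes both arguments are strictly positive, so no zero handling is needed.

```python
import jax.numpy as jnp


def kl(p, q):
    # Hypothetical helper: D_KL(p || q) for strictly positive p and q,
    # computed directly from the definition (natural log).
    return jnp.sum(p * (jnp.log(p) - jnp.log(q)), axis=-1)


p = jnp.array([0.7, 0.2, 0.1])
q = jnp.array([0.6, 0.3, 0.1])

forward = kl(p, q)  # D_KL(P || Q), about 0.0268
reverse = kl(q, p)  # D_KL(Q || P), about 0.0291
print(forward, reverse)  # the two directions give different values
```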
