KL divergence and cross-entropy loss
Cross-entropy loss — the go-to for classification — is actually just a simpler special case of the more general KL divergence.
Quick recap
- KL divergence measures how different two probability distributions $P$ and $Q$ are:

  $$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}$$
- Cross-entropy is a special case of KL divergence used when the target distribution $P$ is one-hot (only one correct answer, everything else zero):

  $$H(P, Q) = -\sum_i P(i) \log Q(i)$$

  In practice, this becomes simply $-\log Q(\text{correct class})$ since $P(i)$ is zero everywhere except the correct answer.
Cross-entropy and KL divergence differ only by the entropy of the target distribution: $H(P, Q) = H(P) + D_{\mathrm{KL}}(P \,\|\, Q)$. That entropy term doesn't depend on the model (and is exactly zero for one-hot labels), so minimizing cross-entropy is the same as minimizing KL divergence.
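To make that concrete, here's a minimal sketch in plain Python (the toy numbers are made up purely for illustration) that checks the identity numerically:

```python
import math

# Toy target distribution P and model prediction Q (illustrative values only)
P = [0.8, 0.1, 0.1]
Q = [0.6, 0.3, 0.1]

kl = sum(p * math.log(p / q) for p, q in zip(P, Q))           # D_KL(P || Q)
entropy = -sum(p * math.log(p) for p in P)                    # H(P)
cross_entropy = -sum(p * math.log(q) for p, q in zip(P, Q))   # H(P, Q)

# H(P, Q) == H(P) + D_KL(P || Q)
print(f'{cross_entropy:.4f} == {entropy + kl:.4f}')  # both sides print 0.7593

# With a one-hot target, H(P) = 0, so cross-entropy and KL divergence coincide
P_onehot = [1.0, 0.0, 0.0]
ce_onehot = -math.log(Q[0])  # only the correct class contributes
kl_onehot = sum(p * math.log(p / q) for p, q in zip(P_onehot, Q) if p > 0)
print(f'{ce_onehot:.4f} == {kl_onehot:.4f}')  # both sides print 0.5108
```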
When to pick each
| Scenario | Target distribution | Best loss |
|---|---|---|
| Standard supervised classification | One-hot (0 or 1 only) | Cross-entropy |
| Label smoothing | Soft probabilities | Cross-entropy |
| Model distillation, VAE training | Arbitrary probability dist. | KL divergence |
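For reference, here's a quick sketch of how the last two rows typically look in code. The tensors and the temperature `T` are illustrative placeholders; `label_smoothing` is a built-in argument of `F.cross_entropy` in recent PyTorch versions (1.10+):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)             # model (student) logits
labels = torch.randint(0, 10, (4,))     # hard class labels
teacher_logits = torch.randn(4, 10)     # e.g. a distillation teacher

# Label smoothing: still plain cross-entropy, just with softened targets
smoothed_ce = F.cross_entropy(logits, labels, label_smoothing=0.1)

# Distillation: match the teacher's full distribution with KL divergence
T = 2.0  # temperature (illustrative value)
kl = F.kl_div(
    F.log_softmax(logits / T, dim=1),
    F.softmax(teacher_logits / T, dim=1),
    reduction='batchmean',
)
```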
PyTorch example
Here's a quick side-by-side in PyTorch to show both in action:
```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, -0.8], [0.2, 2.5, -0.3]])
y_true = torch.tensor([0, 1])

# Cross-entropy loss (for classification)
ce_loss = F.cross_entropy(logits, y_true)
print(f'Cross-entropy loss: {ce_loss:.4f}')
# Cross-entropy loss: 0.2912

# KL divergence loss (for matching arbitrary distributions)
soft_targets = torch.tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2]])
log_probs = F.log_softmax(logits, dim=1)
kl_loss = F.kl_div(log_probs, soft_targets, reduction='batchmean')
print(f'KL divergence loss: {kl_loss:.4f}')
# KL divergence loss: 0.1108
```
Note that `F.kl_div` expects log-probabilities as its first argument (hence the `log_softmax` above); working in log space keeps things numerically stable, just as `F.cross_entropy` applies log-softmax to the raw logits internally.
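Tying the two snippets together, here's a quick sketch (not part of the example above) showing that feeding one-hot targets to `F.kl_div` reproduces the cross-entropy number, since the target's entropy is zero. This relies on recent PyTorch versions treating the 0 · log 0 terms in the target as zero:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, -0.8], [0.2, 2.5, -0.3]])
y_true = torch.tensor([0, 1])

# One-hot version of the hard labels, as float probabilities
one_hot = F.one_hot(y_true, num_classes=3).float()

log_probs = F.log_softmax(logits, dim=1)
kl_one_hot = F.kl_div(log_probs, one_hot, reduction='batchmean')
ce_loss = F.cross_entropy(logits, y_true)

print(f'{kl_one_hot:.4f} == {ce_loss:.4f}')  # both print 0.2912
```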
Layman's takeaway
Cross-entropy is basically a shortcut for KL divergence in the case where there's only one "correct" answer (one-hot targets). If your labels are simply right or wrong (0 or 1), use cross-entropy; it's simpler and faster.
Noice!