Cross-entropy loss — the go-to for classification — is really just a special case of the more general KL divergence: with one-hot targets, the two are identical.

Quick recap

  • KL divergence measures how different two probability distributions P and Q are:
    D_{KL}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}

  • Cross-entropy is typically used when the target distribution P is one-hot (only one correct answer, everything else zero), and in that case it coincides exactly with KL divergence:
    L_{CE}(P, Q) = -\sum_i P(i) \log Q(i)
    In practice, this becomes simply -\log Q_y since P is zero everywhere except the correct class y.
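    For example, if the model outputs Q = (0.7, 0.2, 0.1) and the first class is the correct one, the loss is -\log 0.7 \approx 0.357; the closer Q_y gets to 1, the closer the loss gets to 0.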

More generally, cross-entropy is KL divergence plus an extra term, the entropy of the target distribution: H(P, Q) = H(P) + D_{KL}(P \| Q). That entropy doesn't depend on the model (and is exactly zero for one-hot labels), so minimizing cross-entropy is the same as minimizing KL divergence.
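
You can check this identity numerically. Here's a minimal sketch in plain PyTorch (the probability vectors are made-up numbers, and cross_entropy / kl / entropy are just illustrative helpers, not library calls):

import torch

q = torch.tensor([0.6485, 0.2637, 0.0878])   # model's predicted distribution Q
p_soft = torch.tensor([0.8, 0.1, 0.1])       # soft target P
p_onehot = torch.tensor([1.0, 0.0, 0.0])     # one-hot target P

def cross_entropy(p, q):
    # H(P, Q) = -sum_i P(i) log Q(i)
    return -(p * q.log()).sum()

def kl(p, q):
    # D_KL(P || Q) = sum_i P(i) log(P(i) / Q(i)), treating 0 log 0 as 0
    mask = p > 0
    return (p[mask] * (p[mask] / q[mask]).log()).sum()

def entropy(p):
    # H(P) = -sum_i P(i) log P(i), treating 0 log 0 as 0
    mask = p > 0
    return -(p[mask] * p[mask].log()).sum()

# Soft targets: cross-entropy = entropy of P + KL(P || Q)
print(cross_entropy(p_soft, q), entropy(p_soft) + kl(p_soft, q))

# One-hot targets: entropy of P is zero, so cross-entropy == KL(P || Q)
print(cross_entropy(p_onehot, q), kl(p_onehot, q))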

When to pick each

| Scenario | Target distribution | Best loss |
| --- | --- | --- |
| Standard supervised classification | One-hot (0 or 1 only) | Cross-entropy |
| Label smoothing | Soft probabilities | Cross-entropy |
| Model distillation, VAE training | Arbitrary probability dist. | KL divergence |
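
For the label-smoothing row above, you can keep calling cross-entropy even though the effective targets are soft. A quick sketch, assuming PyTorch 1.10+ where F.cross_entropy accepts a label_smoothing argument (the logits below are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, -0.8], [0.2, 2.5, -0.3]])
y_true = torch.tensor([0, 1])

# Hard one-hot targets
print(F.cross_entropy(logits, y_true))

# label_smoothing=0.1 mixes the one-hot target with a uniform distribution:
# the correct class gets 0.9 + 0.1/3, each other class gets 0.1/3
print(F.cross_entropy(logits, y_true, label_smoothing=0.1))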

PyTorch example

Here's a quick side-by-side in PyTorch to show both in action:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.2, 0.3, -0.8], [0.2, 2.5, -0.3]])
y_true = torch.tensor([0, 1])

# Cross-entropy loss (for classification)
ce_loss = F.cross_entropy(logits, y_true)
print(f'Cross-entropy loss: {ce_loss:.4f}')
# Cross-entropy loss: 0.2912

# KL divergence loss (for matching arbitrary distributions)
soft_targets = torch.tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2]])
log_probs = F.log_softmax(logits, dim=1)
kl_loss = F.kl_div(log_probs, soft_targets, reduction='batchmean')
print(f'KL divergence loss: {kl_loss:.4f}')
# KL divergence loss: 0.1108

Note that F.kl_div expects log-probabilities for the input (hence the log_softmax) and plain probabilities for the target by default; working in log space keeps things numerically stable, and F.cross_entropy does the same log-softmax internally, which is why it takes raw logits.
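
If you want to double-check what F.kl_div is computing, you can reproduce the same number by hand from the soft_targets and log_probs tensors above (a sanity check, not an alternative API):

# D_KL(P || Q) = sum_i P(i) * (log P(i) - log Q(i)), averaged over the batch
manual_kl = (soft_targets * (soft_targets.log() - log_probs)).sum(dim=1).mean()
print(f'Manual KL divergence: {manual_kl:.4f}')
# Should print the same value as F.kl_div with reduction='batchmean'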

Layman's takeaway

Cross-entropy is basically KL divergence with the constant entropy term dropped, a convenient shortcut for cases where there's only one "correct" answer (one-hot targets). If your labels can only be right or wrong (0 or 1), just use cross-entropy: it's simpler, faster, and it optimizes exactly the same thing.

Noice!