KL divergence (discrete)

Difficulty: Easy


Background

KL divergence is an asymmetric "distance" between two probability distributions; computed with the natural log, it is measured in nats. It underpins a large amount of modern ML: cross-entropy, the ELBO in variational inference, label smoothing, knowledge distillation, and the KL penalty in RLHF all involve it. It measures how many extra nats you pay, on average, to encode samples from $p$ using a code optimised for $q$.
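The "on average" can be made concrete: $\text{KL}(p \,\Vert\, q)$ is the expectation, under $p$, of $\log\frac{p(x)}{q(x)}$. The minimal sketch below assumes NumPy's `Generator` API and the distributions from Example 1. It estimates that expectation by sampling category indices from $p$ and compares the estimate with the exact sum. The seed and sample size are arbitrary illustrative choices.

```python
import numpy as np

# Distributions from Example 1.
p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

# Exact value: sum_i p_i * log(p_i / q_i). There are no zeros in p here, so no mask is needed.
exact = float(np.sum(p * np.log(p / q)))

# Monte Carlo estimate: draw indices x ~ p, then average log(p(x) / q(x)).
rng = np.random.default_rng(0)
samples = rng.choice(len(p), size=200_000, p=p)
estimate = float(np.mean(np.log(p[samples] / q[samples])))

print(f"exact={exact:.4f}  estimate={estimate:.4f}")
```

The estimate should agree with the exact value up to sampling noise, which shrinks roughly as one over the square root of the sample size.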

Problem statement

Implement kl_divergence(p, q), the discrete KL divergence:

$$\text{KL}(p \,\Vert\, q) = \sum_i p_i \,\log\frac{p_i}{q_i}$$

Use the convention $0 \cdot \log 0 = 0$: positions where $p_i = 0$ contribute nothing. Implement this by masking out the entries with $p_i = 0$ before taking the log. Do not add an epsilon inside the log; it biases the result.
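The following is a minimal sketch of one way to meet this specification, assuming NumPy. It is not presented as the official reference solution. It relies only on the requirements stated on this page: a `p > 0` mask, the natural log, no epsilon, and a scalar `float` return value.

```python
import numpy as np

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Discrete KL(p || q) in nats, using the 0 * log 0 = 0 convention."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Keep only positions where p_i > 0; every other position contributes exactly 0.
    mask = p > 0
    # Natural log, so the result is in nats. No epsilon is added inside the log.
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

print(kl_divergence(np.array([0.5, 0.5]), np.array([0.25, 0.75])))
```

Wrapping the sum in `float(...)` turns NumPy's scalar into a plain Python float, which matches the "return a scalar" requirement.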

Input

  • p: 1-D np.ndarray, a probability distribution (non-negative, sums to 1).
  • q: 1-D np.ndarray with the same shape and constraints, and with $q_i > 0$ wherever $p_i > 0$.

Output

Returns a float: $\text{KL}(p \,\Vert\, q)$ in nats (natural log).

Examples

Example 1 — a hand-checked case

Input:  p = [0.5, 0.5], q = [0.25, 0.75]
Output: ≈ 0.1438

Explanation: $\text{KL} = 0.5\log\frac{0.5}{0.25} + 0.5\log\frac{0.5}{0.75} = 0.5\log 2 + 0.5\log\tfrac{2}{3} \approx 0.3466 - 0.2027 = 0.1438$ nats.
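The hand calculation can be reproduced term by term with the standard-library `math.log`, which is the natural log:

```python
import math

# Each term is p_i * ln(p_i / q_i) for p = [0.5, 0.5], q = [0.25, 0.75].
term_1 = 0.5 * math.log(0.5 / 0.25)  # 0.5 * ln 2, about 0.3466
term_2 = 0.5 * math.log(0.5 / 0.75)  # 0.5 * ln(2/3), about -0.2027
kl = term_1 + term_2

print(round(term_1, 4), round(term_2, 4), round(kl, 4))
```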

Example 2 — a zero in p contributes nothing

Input:  p = [0.0, 0.5, 0.5], q = [0.1, 0.4, 0.5]
Output: ≈ 0.1116

Explanation: the $p_0 = 0$ term is dropped (the $0 \cdot \log 0 = 0$ convention), leaving $0.5\log\frac{0.5}{0.4} + 0.5\log\frac{0.5}{0.5} = 0.5\log 1.25 + 0 \approx 0.1116$ nats.
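This example shows why the mask is needed. In IEEE floating point, $\log 0 = -\infty$ and $0 \cdot (-\infty)$ is NaN, so applying the formula naively to the whole array turns the entire sum into NaN. The sketch below assumes NumPy. `np.errstate` is used only to silence the divide and invalid-value warnings.

```python
import numpy as np

p = np.array([0.0, 0.5, 0.5])
q = np.array([0.1, 0.4, 0.5])

# Naive formula: at i = 0, log(0 / 0.1) = -inf and 0 * -inf = nan, so the sum is nan.
with np.errstate(divide="ignore", invalid="ignore"):
    naive = np.sum(p * np.log(p / q))

# Masked formula: drop entries with p_i = 0 before taking the log.
mask = p > 0
masked = float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

print(naive, round(masked, 4))
```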

Constraints

  • $\text{KL}(p \,\Vert\, p) = 0$, and $\text{KL}(p \,\Vert\, q) \ge 0$ for any valid $p, q$ (Gibbs' inequality), with equality iff $p = q$.
  • Handle $p_i = 0$ via a p > 0 mask; the contribution there is exactly 0 even though $\log 0 = -\infty$.
  • Do not add an epsilon inside the log; the zero convention is exact.
  • Use the natural log (result in nats); return a scalar float.
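The epsilon rule can be illustrated with a small experiment. The sketch below compares the exact masked sum with a hypothetical "smoothed" variant that computes $\log\frac{p_i + \epsilon}{q_i + \epsilon}$. Where the epsilon goes, its value, and the test distributions are all assumptions made for this illustration. The bias is easiest to see when some $q_i$ is about the same size as $\epsilon$.

```python
import numpy as np

def kl_masked(p, q):
    # Exact: drop p_i = 0, no epsilon.
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def kl_with_eps(p, q, eps=1e-6):
    # Hypothetical smoothed variant that the problem rules out:
    # eps is added to both numerator and denominator inside the log.
    return float(np.sum(p * np.log((p + eps) / (q + eps))))

# A valid q with one tiny entry, comparable in size to eps.
p = np.array([0.5, 0.5])
q = np.array([1e-6, 1 - 1e-6])

exact = kl_masked(p, q)
smoothed = kl_with_eps(p, q)
print(f"exact={exact:.4f}  with eps={smoothed:.4f}  bias={exact - smoothed:.4f}")
```

For these inputs the smoothed value underestimates the exact KL by roughly $0.5 \log 2 \approx 0.35$ nats. The mask avoids this error because it never changes any term that contributes to the sum.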

Notes

  • Why "divergence", not "distance". KL is asymmetric: $\text{KL}(p \,\Vert\, q) \ne \text{KL}(q \,\Vert\, p)$ in general. The first argument is the "true" distribution and the second the "approximate" one; swapping them changes the answer (and the behaviour: forward vs reverse KL).
  • Related. Cross-entropy equals the entropy of $p$ plus $\text{KL}(p \,\Vert\, q)$, i.e. $H(p, q) = H(p) + \text{KL}(p \,\Vert\, q)$; see cross-entropy gradient and label smoothing loss.
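Both notes can be checked numerically on Example 1. The sketch below assumes NumPy and defines its own small helpers for illustration, using the same `p > 0` mask throughout: cross-entropy $H(p, q) = -\sum_i p_i \log q_i$ and entropy $H(p) = -\sum_i p_i \log p_i$. The reverse direction $\text{KL}(q \,\Vert\, p)$ is well defined here only because $p_i > 0$ everywhere.

```python
import numpy as np

def kl_divergence(p, q):
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def entropy(p):
    # H(p) = -sum_i p_i log p_i, with the 0 * log 0 = 0 mask.
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(p[mask])))

def cross_entropy(p, q):
    # H(p, q) = -sum_i p_i log q_i, over positions with p_i > 0.
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(q[mask])))

p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

forward = kl_divergence(p, q)           # KL(p || q)
reverse = kl_divergence(q, p)           # KL(q || p)
gap = cross_entropy(p, q) - entropy(p)  # should equal KL(p || q)

print(round(forward, 4), round(reverse, 4), round(gap, 4))
```

The forward and reverse values differ (about 0.1438 vs 0.1308 nats), while the cross-entropy minus entropy gap matches the forward KL.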

This problem ships 6 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue, and check the following:

  • KL(p || p) = 0 for any p
  • KL > 0 for p != q (non-negativity)
  • Diagnostic: matches the explicit formula on a hand-checked case
  • Asymmetry: KL(p||q) != KL(q||p) in general
  • Handles zero entries in p (0 * log(0) = 0 by convention)
  • Returns a scalar, not an array
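The hidden tests are not shown. The sketch below is only an assumption about what checks along those lines might look like, written as plain asserts against a masked implementation. The Dirichlet inputs, seed, loop count, and tolerances are illustrative choices and are not taken from the actual test suite.

```python
import numpy as np

def kl_divergence(p, q):
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

rng = np.random.default_rng(42)

for _ in range(100):
    # Random strictly positive distributions drawn from a flat Dirichlet.
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    assert abs(kl_divergence(p, p)) < 1e-12  # KL(p || p) = 0
    assert kl_divergence(p, q) > 0           # Gibbs' inequality for p != q

# Zero entries in p contribute nothing (Example 2).
assert abs(kl_divergence(np.array([0.0, 0.5, 0.5]),
                         np.array([0.1, 0.4, 0.5])) - 0.1116) < 1e-4

# The result is a plain Python float, not an array.
assert isinstance(kl_divergence(np.array([0.5, 0.5]), np.array([0.25, 0.75])), float)
print("all checks passed")
```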