KL divergence (discrete)
Background
KL divergence is an asymmetric "distance" between two probability distributions, measured in nats when the natural log is used. It underpins much of modern ML: cross-entropy, the ELBO in variational inference, label smoothing, knowledge distillation, and the KL penalty in RLHF all involve it. It measures how many extra nats you pay, on average, to encode samples from $p$ using a code optimised for $q$ instead of one optimised for $p$.
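The "extra nats" reading can be checked by simulation. The sketch below assumes NumPy and reuses the distributions from Example 1. It draws samples from p and averages ln p(x) - ln q(x), which is the per-sample difference between the ideal code length under a code built for q and one built for p. That average should approach the exact sum.

```python
import numpy as np

p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

# Exact value: sum_i p_i * ln(p_i / q_i); every p_i > 0 here, so no masking is needed.
exact = float(np.sum(p * np.log(p / q)))

# Draw outcome indices x ~ p and average the per-sample log-ratio.
# ln(1/q(x)) is the ideal code length (nats) under a code built for q,
# ln(1/p(x)) the ideal length under a code built for p; their difference
# is the extra cost of using the "wrong" code.
rng = np.random.default_rng(0)
x = rng.choice(len(p), size=200_000, p=p)
extra_nats = np.log(p[x]) - np.log(q[x])
estimate = float(extra_nats.mean())

print(f"exact    = {exact:.4f}")
print(f"estimate = {estimate:.4f}")
```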
Problem statement
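Implement kl_divergence(p, q), the discrete KL divergence:

$$D_{\mathrm{KL}}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}$$

Use the convention $0 \log \frac{0}{q_i} = 0$: positions where $p_i = 0$ contribute nothing. Implement that by masking out entries with $p_i = 0$ before taking the log; do not add an epsilon inside the log (it biases the result).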
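A minimal sketch of the masking approach follows. It assumes NumPy and inputs that already satisfy the contract below, and it performs no validation. Boolean indexing drops the p_i = 0 entries before the log is evaluated, so those entries never produce -inf or nan.

```python
import numpy as np


def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Discrete KL divergence D_KL(p || q) in nats.

    Entries with p_i == 0 are skipped (the 0 * log 0 = 0 convention);
    no epsilon is added inside the log.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # only entries in the support of p contribute
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))
```

For example, kl_divergence([0.5, 0.5], [0.25, 0.75]) should return roughly 0.1438, matching Example 1.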
Input
- p: 1-D np.ndarray, a probability distribution (non-negative, sums to 1).
- q: 1-D np.ndarray with the same shape and constraints, and $q_i > 0$ wherever $p_i > 0$.
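If you want to enforce the input contract explicitly, a validation helper could look like the sketch below. The names check_inputs and atol are hypothetical and used only for illustration; they are not part of the required interface. The support check matters because any term with p_i > 0 and q_i = 0 makes the divergence infinite.

```python
import numpy as np


def check_inputs(p, q, atol: float = 1e-9) -> None:
    """Hypothetical helper: raise ValueError if (p, q) violate the input contract."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    if p.ndim != 1 or p.shape != q.shape:
        raise ValueError("p and q must be 1-D arrays of the same shape")
    for name, d in (("p", p), ("q", q)):
        if np.any(d < 0):
            raise ValueError(f"{name} has negative entries")
        if not np.isclose(d.sum(), 1.0, atol=atol):
            raise ValueError(f"{name} does not sum to 1")
    # Support condition: q_i > 0 wherever p_i > 0, otherwise KL(p || q) is infinite.
    if np.any((p > 0) & (q == 0)):
        raise ValueError("q_i == 0 where p_i > 0: KL(p || q) is infinite")
```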
Output
Returns a float: $D_{\mathrm{KL}}(p \,\|\, q)$ in nats (natural log).
Examples
Example 1 — a hand-checked case
Input: p = [0.5, 0.5], q = [0.25, 0.75]
Output: ≈ 0.1438
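Explanation: $0.5 \ln\frac{0.5}{0.25} + 0.5 \ln\frac{0.5}{0.75} = 0.5 \ln 2 + 0.5 \ln\frac{2}{3} = 0.5 \ln\frac{4}{3} \approx 0.1438$ nats.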
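The same arithmetic in plain Python, using only the standard-library math module. It evaluates the two terms separately and compares their sum with the closed form 0.5 ln(4/3):

```python
import math

# Term by term: 0.5 * ln(0.5 / 0.25) + 0.5 * ln(0.5 / 0.75)
terms = [0.5 * math.log(0.5 / 0.25), 0.5 * math.log(0.5 / 0.75)]
kl = sum(terms)

# The logs combine: 0.5 * (ln 2 + ln(2/3)) = 0.5 * ln(4/3)
closed_form = 0.5 * math.log(4 / 3)

print(round(kl, 4))  # 0.1438
```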
Example 2 — a zero in p contributes nothing
Input: p = [0.0, 0.5, 0.5], q = [0.1, 0.4, 0.5]
Output: ≈ 0.1116
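Explanation: the first entry has $p_1 = 0$, so its term is dropped (the $0 \log 0 = 0$ convention), leaving $0.5 \ln\frac{0.5}{0.4} + 0.5 \ln\frac{0.5}{0.5} = 0.5 \ln 1.25 \approx 0.1116$ nats.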
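The sketch below shows why the mask is needed. It assumes NumPy and compares the unmasked one-liner with the masked version on Example 2. Without the mask, log(0 / 0.1) is -inf and 0 * -inf is nan, so the whole sum becomes nan. Here np.errstate only silences the corresponding NumPy warnings.

```python
import numpy as np

p = np.array([0.0, 0.5, 0.5])
q = np.array([0.1, 0.4, 0.5])

# Naive: log(0 / 0.1) = -inf, and 0 * -inf = nan, which poisons the sum.
with np.errstate(divide="ignore", invalid="ignore"):
    naive = float(np.sum(p * np.log(p / q)))

# Masked: drop the p_i == 0 entry before taking the log.
mask = p > 0
masked = float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

print(naive)             # nan
print(round(masked, 4))  # 0.1116
```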
Constraints
- $D_{\mathrm{KL}}(p \,\|\, p) = 0$, and $D_{\mathrm{KL}}(p \,\|\, q) \ge 0$ for any valid $q$ (Gibbs' inequality), with equality iff $p = q$.
- Handle $p_i = 0$ via a p > 0 mask; the contribution there is exactly 0 even though $\log 0 = -\infty$.
- Do not add an epsilon inside the log; the zero convention is exact.
- Use the natural log (result in nats); return a scalar float.
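The sketch below, which assumes NumPy, illustrates two of these constraints. The helper kl_eps is hypothetical and shows one common epsilon-in-the-log variant. It shifts every ratio toward 1, so it changes the answer even when no entry is zero. The loop then spot-checks Gibbs' inequality on random distributions drawn with Generator.dirichlet.

```python
import numpy as np


def kl_masked(p, q):
    # Exact: skip p_i == 0 entries, no epsilon.
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def kl_eps(p, q, eps=1e-2):
    # Hypothetical discouraged variant: epsilon added inside the log.
    return float(np.sum(p * np.log((p + eps) / (q + eps))))


p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])
print(round(kl_masked(p, q), 4))  # 0.1438 (exact)
print(round(kl_eps(p, q), 4))     # 0.1374 (biased)

# Gibbs' inequality on random distributions: KL(a || b) >= 0 and KL(a || a) == 0.
rng = np.random.default_rng(0)
for _ in range(1000):
    a = rng.dirichlet(np.ones(5))
    b = rng.dirichlet(np.ones(5))
    assert kl_masked(a, b) >= 0.0
    assert kl_masked(a, a) == 0.0
```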
Notes
- Why "divergence", not "distance". KL is asymmetric: $D_{\mathrm{KL}}(p \,\|\, q) \ne D_{\mathrm{KL}}(q \,\|\, p)$ in general. The first argument is the "true" distribution and the second is the "approximate" one; swapping them changes the answer (and the behaviour: forward vs reverse KL).
- Related. Cross-entropy equals the entropy of $p$ plus the KL divergence: $H(p, q) = H(p) + D_{\mathrm{KL}}(p \,\|\, q)$; see cross-entropy gradient and label smoothing loss.
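Both notes can be checked numerically on Example 1's distributions. The helpers kl, entropy and cross_entropy below are illustrative names. The sketch assumes NumPy, and each helper uses the same p > 0 mask:

```python
import numpy as np


def kl(p, q):
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def entropy(p):
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(p[mask])))


def cross_entropy(p, q):
    mask = p > 0
    return float(-np.sum(p[mask] * np.log(q[mask])))


p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

print(round(kl(p, q), 4))  # 0.1438
print(round(kl(q, p), 4))  # 0.1308  (forward != reverse)

# H(p, q) = H(p) + KL(p || q)
print(np.isclose(cross_entropy(p, q), entropy(p) + kl(p, q)))  # True
```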
This problem ships 6 hidden tests:
- KL(p || p) = 0 for any p
- KL > 0 for p != q (non-negativity)
- Diagnostic: matches the explicit formula on a hand-checked case
- Asymmetry: KL(p||q) != KL(q||p) in general
- Handles zero entries in p (0 * log(0) = 0 by convention)
- Returns a scalar, not an array
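The hidden tests themselves are not shown. A rough local stand-in might look like the sketch below. It assumes NumPy and assumes the checks resemble the descriptions above; the inputs and tolerances are guesses, not the real test cases.

```python
import numpy as np


def kl_divergence(p, q):
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # skip p_i == 0 entries
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


p = np.array([0.5, 0.5])
q = np.array([0.25, 0.75])

assert kl_divergence(p, p) == 0.0                                # KL(p || p) = 0
assert kl_divergence(p, q) > 0.0                                 # non-negativity
assert np.isclose(kl_divergence(p, q), 0.5 * np.log(4 / 3))      # hand-checked case
assert not np.isclose(kl_divergence(p, q), kl_divergence(q, p))  # asymmetry
assert np.isclose(kl_divergence([0.0, 0.5, 0.5], [0.1, 0.4, 0.5]),
                  0.5 * np.log(1.25))                            # zero entry in p
assert isinstance(kl_divergence(p, q), float)                    # scalar, not array
print("all checks passed")
```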