Top-p (nucleus) sampling
Background
Top-p (nucleus) sampling is a widely used decoding strategy for modern LLMs. It sits between greedy decoding (deterministic, often dull) and full-distribution sampling (creative, occasionally bizarre). At each step it samples only from the nucleus: the smallest set of top tokens whose probabilities sum to at least p. Crucially, the nucleus adapts to the model's confidence: it may hold one or two tokens when the model is sure and dozens when it is uncertain, which is its main advantage over fixed-size top-k.
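To make the "adaptive nucleus" idea concrete, here is a minimal sketch that counts how many tokens land in the nucleus for a peaked and a flat distribution. The helper name nucleus_size and both toy distributions are assumptions made for illustration; they are not part of the problem's API.

```python
import numpy as np

def nucleus_size(probs, p):
    """Hypothetical helper: how many top tokens are needed to reach mass p."""
    sorted_probs = np.sort(probs)[::-1]        # highest probability first
    cumulative = np.cumsum(sorted_probs)
    # Index of the first running sum that reaches p, converted to a count.
    k = int(np.searchsorted(cumulative, p)) + 1
    return min(max(k, 1), len(probs))          # keep the count within [1, V]

confident = np.array([0.97, 0.01, 0.01, 0.01])  # model is sure
uncertain = np.full(8, 0.125)                   # uniform over 8 tokens

print(nucleus_size(confident, 0.9))  # 1
print(nucleus_size(uncertain, 0.9))  # 8
```

With the same p, the peaked distribution keeps a single candidate while the flat one keeps all eight, whereas top-k would keep k candidates in both cases.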
Problem statement
Implement top_p_sample(probs, p, rng):
- Sort token probabilities descending.
- Take the smallest prefix whose cumulative probability is at least p, but always keep at least one token (the argmax), even if it alone already exceeds p.
- Renormalise that prefix to sum to 1.
- Sample one token from it using rng, and map back to the original index.
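The steps above can be sketched as follows. This is one possible implementation, not a reference solution: it assumes probs is already a valid distribution, and the upper clamp on the kept count is an added safeguard against floating-point rounding rather than something the statement requires.

```python
import numpy as np

def top_p_sample(probs, p, rng):
    """Sketch of top-p sampling; returns an index into the original probs."""
    probs = np.asarray(probs, dtype=np.float64)
    order = np.argsort(-probs, kind="stable")    # original indices, highest first
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)
    # Smallest prefix whose cumulative probability is at least p.
    k = int(np.searchsorted(cumulative, p)) + 1
    k = min(max(k, 1), len(probs))               # always keep at least the argmax
    nucleus = sorted_probs[:k] / sorted_probs[:k].sum()   # renormalise to 1
    pos = rng.choice(k, p=nucleus)               # position within the sorted prefix
    return int(order[pos])                       # map back to the original index

rng = np.random.default_rng(0)
draws = {top_p_sample([0.6, 0.35, 0.04, 0.01], 0.9, rng) for _ in range(200)}
print(draws <= {0, 1})                                  # True
print(top_p_sample([0.95, 0.03, 0.02], 0.5, rng))       # 0
```

The two calls at the end mirror the problem's Example 1 and Example 2: the tail indices never appear, and a single-token nucleus always returns index 0.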
Input
- probs: 1-D np.ndarray of length V, a valid probability distribution (non-negative, sums to 1).
- p: float, the nucleus threshold (a cumulative-probability cutoff between 0 and 1).
- rng: np.random.Generator, the only source of randomness to use.
Output
Returns an int: the sampled index in [0, V) (original-probs space).
Examples
Example 1 — cut to the nucleus
Input: probs = [0.6, 0.35, 0.04, 0.01], p = 0.9
Output: samples index 0 or 1 only (indices 2, 3: never)
Explanation: the descending cumulative sum is 0.6, then 0.6+0.35 = 0.95 ≥ 0.9, so the nucleus is the top 2 tokens {0, 1}. Renormalise (0.6, 0.35 → 0.632, 0.368) and sample; the long tail (indices 2, 3) is excluded.
Example 2 — always keep the top token
Input: probs = [0.95, 0.03, 0.02], p = 0.5
Output: 0
Explanation: the top token alone has probability 0.95 ≥ 0.5, so the nucleus is just {0}. The "keep at least one" rule guarantees it survives (a naïve cutoff might keep zero tokens); with a single token, sampling is deterministic → index 0.
Constraints
- Sort descending, take a cumulative sum, and find the cutoff (np.searchsorted(cumulative, p) + 1 is one call); clamp the kept count to at least 1.
- Renormalise the nucleus to sum to 1 before sampling; map the sampled position back to the original index via the sort order.
- Use only the rng argument, never np.random global state (the determinism test relies on this).
- p ≈ 1.0 samples from (essentially) the full distribution; a tiny p collapses to the argmax.
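The cutoff and determinism constraints can be checked in isolation. The sketch below assumes probabilities that are already sorted descending; kept_count is a hypothetical helper name, and the second array is a made-up case where rounding leaves the running total just short of 1.

```python
import numpy as np

def kept_count(cumulative, p):
    """Hypothetical helper: nucleus size from a descending cumulative sum."""
    # searchsorted returns the first index whose running sum is >= p.
    k = int(np.searchsorted(cumulative, p)) + 1
    return min(max(k, 1), len(cumulative))       # clamp to [1, V]

cumulative = np.cumsum([0.6, 0.35, 0.04, 0.01])
print(kept_count(cumulative, 0.9))   # 2
print(kept_count(cumulative, 0.5))   # 1

# If rounding leaves the total slightly below p, the raw count would be V + 1;
# the upper clamp keeps it at V.
short = np.array([0.5, 0.8, 0.9999999])
print(kept_count(short, 1.0))        # 3

# Two generators built from the same seed produce the same draw, which is
# why routing all randomness through the rng argument makes results repeatable.
a = np.random.default_rng(42).choice(3, p=[0.5, 0.3, 0.2])
b = np.random.default_rng(42).choice(3, p=[0.5, 0.3, 0.2])
print(a == b)                        # True
```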
Notes
- Top-p vs top-k. Top-p's candidate count is adaptive (it grows and shrinks with the model's confidence), whereas top-k always keeps a fixed k. A value around top_p ≈ 0.9 is a common default in production LLMs.
- Pipeline. Like top-k, it usually follows temperature scaling and replaces the plain argmax of greedy decoding.
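As an illustration of the pipeline note, the sketch below applies temperature scaling to a set of logits and then measures the resulting nucleus. The logits, the temperatures, and both helper names are assumptions chosen for the example; real decoders typically operate on much larger vocabularies.

```python
import numpy as np

def softmax_with_temperature(logits, temperature):
    """Hypothetical helper: softmax of logits / temperature."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max()                              # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def nucleus_size(probs, p):
    """Hypothetical helper: number of top tokens needed to reach mass p."""
    cumulative = np.cumsum(np.sort(probs)[::-1])
    k = int(np.searchsorted(cumulative, p)) + 1
    return min(max(k, 1), len(probs))

logits = np.array([2.0, 1.0, 0.5, 0.0, -1.0])
sizes = [nucleus_size(softmax_with_temperature(logits, t), 0.9)
         for t in (0.5, 1.0, 4.0)]
print(sizes)  # [2, 4, 5]
```

A lower temperature sharpens the distribution, so fewer tokens are needed to reach p; a higher temperature flattens it and widens the nucleus.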
This problem ships 6 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- Returns an int in [0, V)
- p ≈ 1.0 samples from the full distribution
- Tiny p picks the argmax
- Cuts to the nucleus (excludes long tail)
- Always keeps the top token even if it alone exceeds p
- Determinism: same rng + inputs → same output