Top-p (nucleus) sampling

Difficulty: Medium

Background

Top-p (nucleus) sampling is a widely used decoding strategy for modern LLMs. It balances greedy decoding (deterministic, often dull) against full-distribution sampling (creative, occasionally bizarre). At each step it samples only from the nucleus: the smallest set of top tokens whose probabilities sum to at least p. Crucially, the nucleus adapts to confidence (one or two tokens when the model is sure, dozens when it is uncertain), which is its main advantage over fixed-size top-k.

Problem statement

Implement top_p_sample(probs, p, rng):

  1. Sort token probabilities descending.
  2. Take the smallest prefix whose cumulative probability is ≥ p, but always keep at least one token (the argmax), even if it alone already exceeds p.
  3. Renormalise that prefix to sum to 1.
  4. Sample one token from it using rng, and map back to the original index.
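
The four steps above can be sketched as follows. This is one possible implementation rather than the reference solution; the stable sort (for reproducible tie-breaking) and the upper clamp on the kept count are assumptions added here, not requirements stated by the problem.

```python
import numpy as np


def top_p_sample(probs: np.ndarray, p: float, rng: np.random.Generator) -> int:
    """Sample one index from the top-p nucleus of `probs`."""
    probs = np.asarray(probs, dtype=np.float64)

    # 1. Sort token indices by probability, highest first (stable for ties).
    order = np.argsort(-probs, kind="stable")
    sorted_probs = probs[order]

    # 2. Smallest prefix whose cumulative probability reaches p.
    cumulative = np.cumsum(sorted_probs)
    keep = int(np.searchsorted(cumulative, p)) + 1
    # Assumption: clamp to [1, V], since rounding can leave the last
    # cumulative value slightly below p when p is 1.0.
    keep = max(1, min(keep, probs.size))

    # 3. Renormalise the nucleus so it sums to 1.
    nucleus = sorted_probs[:keep]
    nucleus = nucleus / nucleus.sum()

    # 4. Sample a position inside the nucleus using only `rng`,
    #    then map it back to the original index.
    position = rng.choice(keep, p=nucleus)
    return int(order[position])
```

Because all randomness comes from the rng argument, two generators created with the same seed should yield the same sample for the same inputs.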

Input

  • probs (1-D np.ndarray of length V): a valid probability distribution (non-negative, sums to 1).
  • p (float in (0, 1]): the nucleus threshold.
  • rng (np.random.Generator): the only source of randomness to use.

Output

Returns an int: the sampled index in [0, V), expressed in the index space of the original probs.

Examples

Example 1 — cut to the nucleus

Input:  probs = [0.6, 0.35, 0.04, 0.01], p = 0.9
Output: samples index 0 or 1 only (indices 2, 3: never)

Explanation: the descending cumulative sum is 0.6, then 0.6+0.35 = 0.95 ≥ 0.9, so the nucleus is the top 2 tokens {0, 1}. Renormalise (0.6, 0.35 → 0.632, 0.368) and sample; the long tail (indices 2, 3) is excluded.
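
The arithmetic in Example 1 can be checked directly with NumPy. The snippet below is only an illustration of the explanation above; the variable names are chosen here for readability.

```python
import numpy as np

probs = np.array([0.6, 0.35, 0.04, 0.01])
p = 0.9

order = np.argsort(-probs, kind="stable")        # indices, most probable first
cumulative = np.cumsum(probs[order])             # running total over sorted probs
keep = int(np.searchsorted(cumulative, p)) + 1   # first prefix that reaches p

nucleus_indices = order[:keep]
nucleus_probs = probs[nucleus_indices] / probs[nucleus_indices].sum()

print(nucleus_indices.tolist())                  # [0, 1]
print(np.round(nucleus_probs, 3).tolist())       # [0.632, 0.368]
```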

Example 2 — always keep the top token

Input:  probs = [0.95, 0.03, 0.02], p = 0.5
Output: 0

Explanation: the top token alone has probability 0.95 ≥ 0.5, so the nucleus is just {0}. The "keep at least one" rule guarantees it survives (a naïve cutoff might keep zero tokens); with a single token, sampling is deterministic → index 0.

Constraints

  • Sort descending, take a cumulative sum, and find the cutoff (np.searchsorted(cumulative, p) + 1 does it in one call); clamp the kept count to at least 1 and at most V, since rounding can leave the final cumulative sum slightly below p.
  • Renormalise the nucleus to sum to 1 before sampling; map the sampled position back to the original index via the sort order.
  • Use only the rng argument — never np.random global state (the determinism test relies on this).
  • p ≈ 1.0 samples from (essentially) the full distribution; a tiny p collapses to the argmax.
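
To make the cutoff and its two extremes concrete, here is a small sketch of the searchsorted step on its own. The helper name nucleus_size is hypothetical (it is not part of the problem's API), and the distribution is made up for illustration.

```python
import numpy as np


def nucleus_size(probs: np.ndarray, p: float) -> int:
    """Number of top tokens a top-p cutoff keeps (hypothetical helper)."""
    sorted_probs = np.sort(probs)[::-1]           # descending probabilities
    cumulative = np.cumsum(sorted_probs)
    # First position whose cumulative sum is >= p, converted to a count.
    keep = int(np.searchsorted(cumulative, p)) + 1
    # Clamp to [1, V] so rounding error in the cumulative sum cannot
    # push the count past the vocabulary size.
    return max(1, min(keep, probs.size))


probs = np.array([0.1, 0.5, 0.15, 0.25])

tiny = nucleus_size(probs, 1e-6)   # tiny p: only the argmax survives
mid = nucleus_size(probs, 0.6)     # 0.5 alone is short of 0.6, so two tokens
full = nucleus_size(probs, 1.0)    # p = 1.0: the whole distribution
```

With these inputs the kept counts should be 1, 2 and 4 respectively.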

Notes

  • Top-p vs top-k. Top-p's candidate count is adaptive (it grows and shrinks with the model's confidence), whereas top-k always keeps a fixed k. Values around top_p ≈ 0.9 are a common choice in practice.
  • Pipeline. Like top-k, it usually follows temperature scaling and replaces the plain argmax of greedy decoding.
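
The interaction between temperature and the adaptive nucleus can be illustrated with a short sketch. The logits are invented, and softmax_with_temperature and nucleus_size are hypothetical helper names used only for this example; the point is that a sharper distribution yields a smaller nucleus at the same p, whereas top-k would keep k tokens in every case.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Temperature-scaled softmax (hypothetical helper)."""
    scaled = logits / temperature
    scaled = scaled - scaled.max()    # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()


def nucleus_size(probs: np.ndarray, p: float) -> int:
    """Number of top tokens a top-p cutoff keeps (hypothetical helper)."""
    cumulative = np.cumsum(np.sort(probs)[::-1])
    keep = int(np.searchsorted(cumulative, p)) + 1
    return max(1, min(keep, probs.size))


logits = np.array([4.0, 2.0, 1.0, 0.5, 0.0, -1.0])   # made-up values

# Lower temperature sharpens the distribution; higher temperature flattens it.
sizes = {
    temperature: nucleus_size(softmax_with_temperature(logits, temperature), 0.9)
    for temperature in (0.5, 1.0, 2.0)
}
print(sizes)
```

For these logits the nucleus at p = 0.9 grows as the temperature rises, which is the adaptivity described in the first note.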

This problem ships 6 hidden tests, which run in the browser via Pyodide. They check:

  • Returns an int in [0, V)
  • p ≈ 1.0 samples from the full distribution
  • Tiny p picks the argmax
  • Cuts to the nucleus (excludes long tail)
  • Always keeps the top token even if it alone exceeds p
  • Determinism: same rng + inputs → same output