Top-k sampling (Easy)

Background

Top-k sampling is a next-token decoding rule: keep only the k highest-probability tokens, renormalise them into a distribution, and sample one. It's the fixed-cardinality alternative to top-p (nucleus) sampling — you always draw from exactly k candidates regardless of how confident the model is. GPT-2's typical setting was top_k = 40; it survives today as a fallback knob and in inference engines that prefer fixed-size candidate sets for batching.

Problem statement

Implement top_k_sample(probs, k, rng):

  1. Find the indices of the k largest entries in probs.
  2. Renormalise those k probabilities to sum to 1: $p'_i = p_i / \sum_{j \in \text{top-}k} p_j$.
  3. Sample one position from that renormalised distribution using rng.
  4. Return the corresponding original-vocabulary index.
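
The four steps above can be sketched as follows. This is one possible implementation rather than the reference solution: the clamp `min(k, V)` is an assumption about how to handle k >= V, and `rng.choice` is just one convenient way to draw from the renormalised distribution.

```python
import numpy as np


def top_k_sample(probs: np.ndarray, k: int, rng: np.random.Generator) -> int:
    """Sample one vocabulary index from the k most probable tokens."""
    probs = np.asarray(probs, dtype=np.float64)
    vocab_size = probs.shape[0]
    k = min(k, vocab_size)  # assumption: k >= V keeps the whole vocabulary

    # Step 1: indices of the k largest entries (in no particular order).
    top_idx = np.argpartition(probs, -k)[-k:]

    # Step 2: renormalise the kept probabilities so they sum to 1.
    top_p = probs[top_idx]
    top_p = top_p / top_p.sum()

    # Step 3: sample a position in [0, k) using only the supplied generator.
    pos = rng.choice(k, p=top_p)

    # Step 4: map the position back to the original vocabulary index.
    return int(top_idx[pos])
```

Calling `top_k_sample(np.array([0.1, 0.05, 0.7, 0.15]), 1, np.random.default_rng(0))` returns 2, since only the argmax survives when k = 1.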

Input

  • probs (1-D np.ndarray of length V): a valid probability distribution (non-negative, sums to 1).
  • k (int, $\ge 1$): how many tokens to keep. If k >= V, no truncation.
  • rng (np.random.Generator): the only source of randomness to use.

Output

Returns an int: the sampled vocabulary index in $[0, V)$.

Examples

Example 1 — k=1 is deterministic (argmax)

Input:  probs = [0.1, 0.05, 0.7, 0.15], k = 1
Output: 2   (every time)

Explanation: with k=1 only the single highest-probability token survives, so after renormalising it has probability 1 — sampling always returns the argmax, index 2.

Example 2 — k=2 cuts the tail and renormalises

Input:  probs = [0.6, 0.3, 0.05, 0.03, 0.02], k = 2
Output: index 0 with prob ≈ 2/3, index 1 with prob ≈ 1/3 (indices 2–4: never)

Explanation: keep the top 2 (indices 0 and 1) and renormalise: 0.6 and 0.3 become 0.6/0.9 = 2/3 and 0.3/0.9 = 1/3. The long tail (indices 2–4) is removed entirely, so those tokens are sampled with probability 0.
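
The arithmetic in Example 2 can be reproduced directly. The sketch below builds a full-length view of the truncated distribution. The variable names are illustrative, and `argsort` is used here only for readability.

```python
import numpy as np

probs = np.array([0.6, 0.3, 0.05, 0.03, 0.02])
k = 2

top_idx = np.argsort(probs)[-k:]   # indices of the 2 largest entries
kept_mass = probs[top_idx].sum()   # 0.6 + 0.3, i.e. about 0.9

# Full-length view of the truncated distribution: zeros outside the top k.
truncated = np.zeros_like(probs)
truncated[top_idx] = probs[top_idx] / kept_mass

print(truncated)  # index 0 holds about 2/3, index 1 about 1/3, the rest 0
```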

Constraints

  • Select the top k by probability (np.argpartition(probs, -k)[-k:] is $O(V)$; argsort works but is $O(V \log V)$).
  • Renormalise the kept probabilities to sum to 1 before sampling; map the sampled position back to the original vocabulary index.
  • Use only the rng argument for randomness — never np.random global state (the determinism test relies on this).
  • k >= V means sample from the full distribution; tokens outside the top k are never sampled.
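
As a rough illustration of the selection and determinism constraints, the sketch below compares the two selection routes and shows that identically seeded generators give the same draw. `top_k_indices` is a hypothetical helper, and the probabilities are made up for the example.

```python
import numpy as np

probs = np.array([0.05, 0.40, 0.10, 0.25, 0.20])
V = probs.shape[0]


def top_k_indices(probs: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: unordered indices of the k largest entries."""
    # Clamp first so that k >= V selects everything; np.argpartition
    # rejects a kth position that falls outside the array.
    k = min(k, probs.shape[0])
    return np.argpartition(probs, -k)[-k:]


# Both routes select the same set; only the cost and the ordering differ.
via_partition = set(top_k_indices(probs, 3).tolist())
via_sort = set(np.argsort(probs)[-3:].tolist())

# k >= V keeps every index.
everything = set(top_k_indices(probs, 99).tolist())

# Two generators built from the same seed produce the same draw, which is
# why all randomness should flow through the rng argument.
draw_a = np.random.default_rng(123).choice(V, p=probs)
draw_b = np.random.default_rng(123).choice(V, p=probs)
```

Note that `argpartition` returns the top-k indices in arbitrary order, which is fine here because the order does not affect the renormalised probabilities.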

Notes

  • Top-k vs top-p. Top-k always samples from a fixed k candidates (simple, fixed compute); top-p adapts the candidate count to the model's confidence (1–2 tokens when sure, 50+ when uncertain).
  • Pipeline. Usually applied after temperature scaling reshapes the distribution; both replace the plain argmax of greedy decoding.
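
To make both notes concrete, the sketch below applies temperature scaling to some made-up logits and compares the candidate counts. `softmax` and `nucleus_size` are hypothetical helpers written for this illustration, and p = 0.9 and k = 10 are arbitrary choices.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (numerically stabilised)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z -= z.max()
    e = np.exp(z)
    return e / e.sum()


def nucleus_size(probs: np.ndarray, p: float) -> int:
    """Smallest number of top tokens whose cumulative probability reaches p."""
    cumulative = np.cumsum(np.sort(probs)[::-1])
    # searchsorted finds the first position where the cumulative mass >= p.
    return min(int(np.searchsorted(cumulative, p)) + 1, probs.shape[0])


logits = np.linspace(4.0, 0.0, 50)       # illustrative logits, V = 50
sharp = softmax(logits, temperature=0.25)  # confident model
flat = softmax(logits, temperature=2.0)    # uncertain model

k = 10
top_k_count = min(k, logits.shape[0])    # the same for both distributions
sharp_nucleus = nucleus_size(sharp, 0.9)
flat_nucleus = nucleus_size(flat, 0.9)

print(top_k_count, sharp_nucleus, flat_nucleus)
```

The top-k count stays at k for both distributions, while the nucleus is noticeably smaller for the sharp distribution than for the flat one.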

This problem ships 6 hidden tests. They run in the browser via Pyodide, with no backend and no submission queue:

  • Returns an int in [0, V)
  • k=1 picks the argmax deterministically
  • k >= V samples from the full distribution
  • Diagnostic: k=2 cuts the long tail entirely
  • Determinism: same rng + inputs -> same output
  • Approximate empirical proportions within the kept set
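
The hidden tests themselves are not shown, so the following is only a guess at what checks of this kind might look like. `check_sampler` and `reference_top_k_sample` are hypothetical names, and the seeds, draw count, and tolerance are arbitrary.

```python
import numpy as np


def reference_top_k_sample(probs, k, rng):
    """Stand-in sampler so the checks below have something to run against."""
    probs = np.asarray(probs, dtype=np.float64)
    keep = np.argsort(probs)[-min(k, probs.shape[0]):]
    kept = probs[keep] / probs[keep].sum()
    return int(keep[rng.choice(keep.shape[0], p=kept)])


def check_sampler(sampler, n_draws: int = 20_000) -> None:
    probs = np.array([0.6, 0.3, 0.05, 0.03, 0.02])

    # Determinism: two generators with the same seed give the same draw.
    first = sampler(probs, 2, np.random.default_rng(0))
    second = sampler(probs, 2, np.random.default_rng(0))
    assert first == second

    # Empirical proportions within the kept set, with a loose tolerance.
    rng = np.random.default_rng(1)
    draws = np.array([sampler(probs, 2, rng) for _ in range(n_draws)])
    counts = np.bincount(draws, minlength=probs.shape[0])
    assert counts[2:].sum() == 0                     # the tail is never sampled
    assert abs(counts[0] / n_draws - 2 / 3) < 0.02
    assert abs(counts[1] / n_draws - 1 / 3) < 0.02


check_sampler(reference_top_k_sample)
```

Any function with the signature `(probs, k, rng)` can be passed to `check_sampler` in place of the stand-in.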