Top-k sampling
Background
Top-k sampling is a next-token decoding rule: keep only the k highest-probability tokens, renormalise them into a distribution, and sample one. It's the fixed-cardinality alternative to top-p (nucleus) sampling — you always draw from exactly k candidates regardless of how confident the model is. GPT-2's typical setting was top_k = 40; it survives today as a fallback knob and in inference engines that prefer fixed-size candidate sets for batching.
Problem statement
Implement top_k_sample(probs, k, rng):
- Find the indices of the k largest entries in probs.
- Renormalise those k probabilities to sum to 1: $p'_i = p_i / \sum_{j \in \text{top-}k} p_j$.
- Sample one position from that renormalised distribution using rng.
- Return the corresponding original-vocabulary index.
Input
- probs: 1-D np.ndarray of length V, a valid probability distribution (non-negative, sums to 1).
- k: int, how many tokens to keep. If k >= V, no truncation.
- rng: np.random.Generator, the only source of randomness to use.
Output
Returns an int: the sampled vocabulary index in [0, V).
Examples
Example 1 — k=1 is deterministic (argmax)
Input: probs = [0.1, 0.05, 0.7, 0.15], k = 1
Output: 2 (every time)
Explanation: with k=1 only the single highest-probability token survives, so after renormalising it has probability 1 — sampling always returns the argmax, index 2.
Example 2 — k=2 cuts the tail and renormalises
Input: probs = [0.6, 0.3, 0.05, 0.03, 0.02], k = 2
Output: index 0 with prob ≈ 2/3, index 1 with prob ≈ 1/3 (indices 2–4: never)
Explanation: keep the top 2 (indices 0 and 1) and renormalise: 0.6 and 0.3 become 0.6/0.9 = 2/3 and 0.3/0.9 = 1/3. The long tail (indices 2–4) is removed entirely, so those tokens are sampled with probability 0.
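The arithmetic in Example 2 can be checked numerically. The snippet below is a small sketch, not part of the problem's reference material: the seed and the number of draws are arbitrary choices made for illustration.

```python
import numpy as np

probs = np.array([0.6, 0.3, 0.05, 0.03, 0.02])
k = 2

# Indices of the k largest entries (their order within the top-k is unspecified).
top_idx = np.argpartition(probs, -k)[-k:]

# Renormalise the kept mass (0.6 + 0.3 = 0.9) so it sums to 1.
kept = probs[top_idx] / probs[top_idx].sum()

# Scatter the renormalised mass back onto the full vocabulary for inspection.
renorm = np.zeros_like(probs)
renorm[top_idx] = kept
print(renorm)  # approximately [0.667, 0.333, 0, 0, 0]

# Empirical check: the seed (0) and sample size are arbitrary assumptions.
rng = np.random.default_rng(0)
draws = rng.choice(top_idx, size=20_000, p=kept)
freq = np.bincount(draws, minlength=probs.size) / draws.size
print(freq)  # indices 2-4 should show a frequency of exactly 0
```

The empirical frequencies for indices 0 and 1 should land close to 2/3 and 1/3, while the tail never appears because it is not among the candidates passed to the sampler.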
Constraints
- Select the top k by probability (np.argpartition(probs, -k)[-k:] is O(V); argsort works but is O(V log V)).
- Renormalise the kept probabilities to sum to 1 before sampling; map the sampled position back to the original vocabulary index.
- Use only the rng argument for randomness, never np.random global state (the determinism test relies on this).
- k >= V means sample from the full distribution; tokens outside the top k are never sampled.
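One possible way to satisfy the constraints above is sketched here. It is a minimal illustration rather than the reference solution, and it assumes k >= 1 and a well-formed probs (no input validation is performed); those assumptions are not stated in the problem.

```python
import numpy as np


def top_k_sample(probs: np.ndarray, k: int, rng: np.random.Generator) -> int:
    """Sketch of top-k sampling; assumes k >= 1 and probs is a valid distribution."""
    probs = np.asarray(probs, dtype=np.float64)
    V = probs.shape[0]

    if k >= V:
        # No truncation: every vocabulary index stays a candidate.
        idx = np.arange(V)
    else:
        # argpartition places the k largest entries in the last k slots (unordered).
        idx = np.argpartition(probs, -k)[-k:]

    # Renormalise the kept probabilities so they sum to 1.
    kept = probs[idx]
    kept = kept / kept.sum()

    # Sample a position within the kept set using only the supplied generator,
    # then map it back to the original vocabulary index.
    pos = rng.choice(idx.size, p=kept)
    return int(idx[pos])


# Usage: with k=1 the result is always the argmax (index 2 here).
rng = np.random.default_rng(0)  # the seed is an arbitrary choice
print(top_k_sample(np.array([0.1, 0.05, 0.7, 0.15]), 1, rng))  # 2
```

Sampling a position and indexing back into idx keeps the mapping step explicit; passing idx directly to rng.choice would also work.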
Notes
- Top-k vs top-p. Top-k always samples from a fixed k candidates (simple, fixed compute); top-p adapts the candidate count to the model's confidence (1–2 tokens when sure, 50+ when uncertain).
- Pipeline. Usually applied after temperature scaling reshapes the distribution; both replace the plain argmax of greedy decoding.
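To make the top-k versus top-p contrast concrete, the sketch below counts how many candidates each rule keeps on a peaked and on a flat distribution. The helper nucleus_size and both toy distributions are hypothetical, introduced only for this illustration; top-p is taken here to mean the smallest set of highest-probability tokens whose cumulative mass reaches p.

```python
import numpy as np


def nucleus_size(probs: np.ndarray, p: float) -> int:
    """Hypothetical helper: number of tokens top-p keeps (smallest prefix with mass >= p)."""
    cum = np.cumsum(np.sort(probs)[::-1])  # cumulative mass, most likely token first
    # First position where the cumulative mass reaches p, converted to a count.
    return min(int(np.searchsorted(cum, p)) + 1, probs.size)


V, k, p = 64, 40, 0.9

# Peaked toy distribution: one token holds about 0.937 of the mass.
peaked = np.full(V, 0.001)
peaked[0] = 1.0 - 0.001 * (V - 1)

# Flat toy distribution: uniform over all V tokens.
flat = np.full(V, 1.0 / V)

for name, dist in [("peaked", peaked), ("flat", flat)]:
    print(name, "top-p keeps", nucleus_size(dist, p), "| top-k keeps", min(k, V))
# peaked: top-p keeps 1,  top-k keeps 40
# flat:   top-p keeps 58, top-k keeps 40
```

The top-k count is the same in both cases, which is the fixed-compute property the note describes, while the top-p count tracks how spread out the distribution is.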
This problem ships 6 hidden tests. They run in your browser via Pyodide (no backend, no submission queue).
- Returns an int in [0, V)
- k=1 picks the argmax deterministically
- k >= V samples from the full distribution
- Diagnostic: k=2 cuts the long tail entirely
- Determinism: same rng + inputs -> same output
- Approximate empirical proportions within the kept set
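The hidden tests themselves are not shown, so the following is only a guess at the kind of property the determinism item refers to: two np.random.Generator objects built from the same seed produce identical draws, which holds only if the sampler touches no other source of randomness. The seed and the candidate set are arbitrary assumptions.

```python
import numpy as np

# Kept candidates and renormalised probabilities from Example 2.
idx = np.array([0, 1])
p = np.array([2 / 3, 1 / 3])

# Two independent generators seeded identically (seed 123 is arbitrary).
a = np.random.default_rng(123).choice(idx, size=10, p=p)
b = np.random.default_rng(123).choice(idx, size=10, p=p)

# Same generator state + same inputs -> same sequence of sampled indices.
print(np.array_equal(a, b))  # True
```

A call to the legacy global functions such as np.random.choice would draw from separate global state instead, so the result would no longer be controlled by the rng argument.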