Greedy decoding loop
Background
Greedy decoding is the simplest autoregressive generation strategy — at each step, ask the model for next-token logits and just take the most likely token. It is the baseline every other decoding rule (temperature, top-k, top-p) is measured against, and the right default when you want determinism: classifier outputs, code completion, structured-output extraction. Its pathologies — repetition loops, no diversity — are exactly why the sampling strategies exist.
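Greedy decoding never needs to turn logits into probabilities. Softmax preserves the ordering of its inputs, so the argmax is the same either way. The small NumPy check below uses the logits from Example 1 and a hand-written `softmax` helper. That helper is a hypothetical illustration and not part of the problem's interface.

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Hypothetical helper. Subtracting the max is for numerical stability only;
    # it does not change the resulting probabilities.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()

logits = np.array([1.0, 2.0, 5.0, 3.0])
probs = softmax(logits)

# exp is strictly increasing and every entry shares the same denominator,
# so the largest logit is also the largest probability.
print(int(np.argmax(logits)), int(np.argmax(probs)))  # 2 2
```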
Problem statement
Implement greedy_decode(predict_fn, prefix, max_new_tokens). Start from a copy of prefix as current_ids, then repeat max_new_tokens times:
    logits = predict_fn(current_ids)   # 1-D vector of length V
    next_id = argmax(logits)           # most-likely token
    current_ids.append(next_id)
Return the full sequence. Cast each argmax to a Python int before appending.
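Here is a minimal Python sketch of the loop above. It assumes NumPy is available and is one straightforward way to meet the spec, not the only one. It relies on `np.argmax` returning the first index when several logits tie for the maximum, so ties break toward the lowest token id.

```python
from typing import Callable, List

import numpy as np

def greedy_decode(
    predict_fn: Callable[[List[int]], np.ndarray],
    prefix: List[int],
    max_new_tokens: int,
) -> List[int]:
    # Copy so the caller's prefix list is never mutated.
    current_ids = list(prefix)
    for _ in range(max_new_tokens):
        # Pass a snapshot of the context; predict_fn returns logits of shape (V,).
        logits = predict_fn(list(current_ids))
        # np.argmax returns a NumPy integer; cast it to a plain Python int.
        next_id = int(np.argmax(logits))
        current_ids.append(next_id)
    return current_ids
```

Passing `list(current_ids)` instead of the live list costs one copy per step. In return, a `predict_fn` that keeps a reference to its argument, such as a test double that records its inputs, sees each context exactly as it was at call time.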
Input
- predict_fn: callable, predict_fn(ids: list[int]) -> np.ndarray of shape (V,) (next-token logits).
- prefix: list[int], the starting token ids (the prompt).
- max_new_tokens: int, how many new tokens to generate.
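The input contract expects `predict_fn` to return a 1-D logits vector of the same length V on every call. The sketch below wraps any `predict_fn` so that violations fail loudly while you debug. `checked_predict_fn` and the toy model are hypothetical helpers and not part of the problem.

```python
from typing import Callable, List, Optional

import numpy as np

PredictFn = Callable[[List[int]], np.ndarray]

def checked_predict_fn(predict_fn: PredictFn) -> PredictFn:
    """Hypothetical wrapper: enforces the (V,) logits contract on every call."""
    vocab_size: Optional[int] = None

    def wrapped(ids: List[int]) -> np.ndarray:
        nonlocal vocab_size
        logits = np.asarray(predict_fn(ids))
        if logits.ndim != 1:
            raise ValueError(f"expected 1-D logits, got shape {logits.shape}")
        if vocab_size is None:
            vocab_size = logits.shape[0]  # remember V from the first call
        elif logits.shape[0] != vocab_size:
            raise ValueError(f"vocab size changed from {vocab_size} to {logits.shape[0]}")
        return logits

    return wrapped

# Hypothetical toy model: ignores its input and returns fixed logits of shape (4,).
toy = checked_predict_fn(lambda ids: np.array([1.0, 2.0, 5.0, 3.0]))
print(toy([0]).shape)  # (4,)
```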
Output
Returns a list[int] of length len(prefix) + max_new_tokens — the prefix followed by the generated tokens.
Examples
Example 1 — constant logits repeat the same token
Input: predict_fn always returns [1.0, 2.0, 5.0, 3.0]; prefix = [0], max_new_tokens = 4
Output: [0, 2, 2, 2, 2]
Explanation: argmax([1,2,5,3]) = 2 at every step, so the decoder emits token 2 four times after the prefix.
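Example 1 can be reproduced directly. The sketch below uses a compact version of the loop from the problem statement and a lambda that ignores the context and always returns the same logits.

```python
import numpy as np

def greedy_decode(predict_fn, prefix, max_new_tokens):
    current_ids = list(prefix)  # copy; leave the caller's list alone
    for _ in range(max_new_tokens):
        current_ids.append(int(np.argmax(predict_fn(current_ids))))
    return current_ids

# The logits never change, so neither does the argmax (index 2).
constant_logits = np.array([1.0, 2.0, 5.0, 3.0])
out = greedy_decode(lambda ids: constant_logits, [0], 4)
print(out)  # [0, 2, 2, 2, 2]
```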
Example 2 — following a transition cycle
Input: predict_fn picks argmax = (last_token + 1) % 4; prefix = [0], max_new_tokens = 6
Output: [0, 1, 2, 3, 0, 1, 2]
Explanation: each step the next token is one greater mod 4, so decoding deterministically walks the 4-cycle 0 → 1 → 2 → 3 → 0. With a fixed predict_fn, greedy decoding is fully reproducible.
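The example does not specify which logits produce "argmax = (last_token + 1) % 4". The sketch below assumes a hypothetical transition table in which row i holds the logits used after token i. It then checks that two runs agree exactly, which is the reproducibility property the example describes.

```python
import numpy as np

def greedy_decode(predict_fn, prefix, max_new_tokens):
    current_ids = list(prefix)
    for _ in range(max_new_tokens):
        current_ids.append(int(np.argmax(predict_fn(current_ids))))
    return current_ids

# Hypothetical transition table: row i holds the logits used when the last token
# is i, and its largest entry sits at column (i + 1) % 4.
TRANSITIONS = np.array([
    [0.0, 3.0, 1.0, 0.5],  # 0 -> 1
    [0.2, 0.0, 4.0, 1.0],  # 1 -> 2
    [1.0, 0.5, 0.0, 2.0],  # 2 -> 3
    [5.0, 1.0, 2.0, 0.0],  # 3 -> 0
])

def predict_fn(ids):
    # This toy "model" only looks at the last token in the context.
    return TRANSITIONS[ids[-1]]

first_run = greedy_decode(predict_fn, [0], 6)
second_run = greedy_decode(predict_fn, [0], 6)
print(first_run)                # [0, 1, 2, 3, 0, 1, 2]
print(first_run == second_run)  # True: nothing in the loop is random
```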
Constraints
- Copy prefix (e.g. list(prefix)) before appending; do not mutate the caller's list. The prefix must appear unchanged at the start of the output: out[:len(prefix)] == list(prefix).
- predict_fn is called exactly max_new_tokens times, each time with the prefix grown by one token (lengths len(prefix), len(prefix)+1, …).
- Cast np.argmax(...) to a Python int before appending. NumPy integer scalars are not Python ints, so they fail strict type checks such as type(t) is int and are rejected by json.dumps.
- max_new_tokens = 0 returns the prefix unchanged.
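Each constraint above can be checked mechanically. The sketch below uses a hypothetical recording `predict_fn` that logs the context length on every call. It then asserts non-mutation, the call pattern, the output types, and the zero-token case. It also shows that an uncast `np.argmax` result is rejected by `json.dumps`.

```python
import json

import numpy as np

def greedy_decode(predict_fn, prefix, max_new_tokens):
    current_ids = list(prefix)  # copy: never mutate the caller's list
    for _ in range(max_new_tokens):
        logits = predict_fn(list(current_ids))      # pass a snapshot of the context
        current_ids.append(int(np.argmax(logits)))  # plain Python int
    return current_ids

seen_lengths = []

def recording_predict_fn(ids):
    # Hypothetical test double: record the context length, return fixed logits.
    seen_lengths.append(len(ids))
    return np.array([0.5, 0.1, 0.9])

prefix = [1, 2]
out = greedy_decode(recording_predict_fn, prefix, 3)

assert prefix == [1, 2]                   # caller's list untouched
assert out[:len(prefix)] == list(prefix)  # prefix preserved at the front
assert seen_lengths == [2, 3, 4]          # one call per new token, growing by 1
assert all(type(t) is int for t in out)   # no NumPy scalars leak out
assert greedy_decode(recording_predict_fn, prefix, 0) == [1, 2]
print(json.dumps(out))  # [1, 2, 2, 2, 2]

# Without the cast, the NumPy scalar returned by np.argmax is not JSON serializable.
try:
    json.dumps([np.argmax(np.array([0.5, 0.1, 0.9]))])
    numpy_scalar_rejected = False
except TypeError:
    numpy_scalar_rejected = True
print(numpy_scalar_rejected)  # True
```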
Notes
- Why determinism. Greedy always takes the top token, so the same prompt yields the same output — ideal when you need reproducibility, but prone to repetition and "dead air" on creative tasks.
- Related. The sampling alternatives temper this: temperature scaling, top-k, and top-p all replace the argmax with a sampled draw.
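To see how sampling replaces the argmax, here is a sketch of a hypothetical `sample_decode` that divides the logits by a temperature, optionally keeps only the top-k logits, and draws from the resulting softmax. It covers temperature and top-k only; top-p is not shown. It is an illustration under those assumptions, not a reference implementation of either method.

```python
from typing import Callable, List, Optional

import numpy as np

def sample_decode(
    predict_fn: Callable[[List[int]], np.ndarray],
    prefix: List[int],
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,  # assumes 1 <= top_k <= V when given
    seed: Optional[int] = None,
) -> List[int]:
    """Hypothetical sampling counterpart to greedy_decode (temperature + optional top-k)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = np.random.default_rng(seed)
    current_ids = list(prefix)
    for _ in range(max_new_tokens):
        logits = np.asarray(predict_fn(current_ids), dtype=float) / temperature
        if top_k is not None:
            # Mask every logit smaller than the k-th largest one.
            kth_largest = np.sort(logits)[-top_k]
            logits = np.where(logits < kth_largest, -np.inf, logits)
        probs = np.exp(logits - logits.max())  # masked entries become 0
        probs /= probs.sum()
        # The greedy argmax is replaced by a random draw from probs.
        current_ids.append(int(rng.choice(len(probs), p=probs)))
    return current_ids
```

With `top_k=1` and a unique maximum logit, every draw lands on the argmax, so the output matches greedy decoding. Fixing `seed` makes sampled runs repeatable, which is the usual way to recover reproducibility once the argmax is gone.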
This problem ships 5 hidden tests:
- Returns prefix + max_new_tokens tokens
- Picks the argmax at each step (constant logits -> repeated token)
- predict_fn sees the growing prefix at each call (length increases by 1 per step)
- Diagnostic: decoder follows a transition table deterministically
- max_new_tokens=0 returns the prefix unchanged