Greedy decoding loop (Easy)

Background

Greedy decoding is the simplest autoregressive generation strategy — at each step, ask the model for next-token logits and just take the most likely token. It is the baseline every other decoding rule (temperature, top-k, top-p) is measured against, and the right default when you want determinism: classifier outputs, code completion, structured-output extraction. Its pathologies — repetition loops, no diversity — are exactly why the sampling strategies exist.

Problem statement

Implement greedy_decode(predict_fn, prefix, max_new_tokens). Repeat max_new_tokens times:

logits = predict_fn(current_ids)   # 1-D vector of length V
next_id = argmax(logits)           # most-likely token
current_ids.append(next_id)

Start from a copy of prefix and return the full sequence. Cast each argmax result to a Python int before appending.
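The loop above can be sketched as follows. This is a minimal illustration rather than the reference solution, and it assumes NumPy is available and that predict_fn returns a 1-D array of logits. The constant_logits helper is hypothetical and exists only to exercise the function.

```python
import numpy as np


def greedy_decode(predict_fn, prefix, max_new_tokens):
    """Append the argmax token max_new_tokens times, starting from a copy of prefix."""
    current_ids = list(prefix)  # copy so the caller's list is never mutated
    for _ in range(max_new_tokens):
        logits = predict_fn(current_ids)   # 1-D vector of length V
        next_id = int(np.argmax(logits))   # cast the NumPy integer to a Python int
        current_ids.append(next_id)
    return current_ids


def constant_logits(ids):
    # Hypothetical stand-in model: the same logits regardless of context.
    return np.array([1.0, 2.0, 5.0, 3.0])


result = greedy_decode(constant_logits, [0], 4)
print(result)  # [0, 2, 2, 2, 2]
```

Note that np.argmax returns the first index when several logits tie for the maximum, so in this sketch ties resolve to the lowest token id.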

Input

  • predict_fn (callable): predict_fn(ids: list[int]) -> np.ndarray of shape (V,), the next-token logits.
  • prefix (list[int]): the starting token ids (the prompt).
  • max_new_tokens (int): how many new tokens to generate.

Output

Returns a list[int] of length len(prefix) + max_new_tokens — the prefix followed by the generated tokens.

Examples

Example 1 — constant logits repeat the same token

Input:  predict_fn always returns [1.0, 2.0, 5.0, 3.0]; prefix = [0], max_new_tokens = 4
Output: [0, 2, 2, 2, 2]

Explanation: argmax([1,2,5,3]) = 2 at every step, so the decoder emits token 2 four times after the prefix.

Example 2 — following a transition cycle

Input:  predict_fn returns logits whose argmax is (last_token + 1) % 4; prefix = [0], max_new_tokens = 6
Output: [0, 1, 2, 3, 0, 1, 2]

Explanation: at each step the next token is one greater mod 4, so decoding deterministically walks the 4-cycle 0 -> 1 -> 2 -> 3 -> 0. With a fixed predict_fn, greedy decoding is fully reproducible.
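One way to build a predict_fn with this behaviour is to make the target token the single largest logit. The sketch below is illustrative only: cycle_predict_fn and VOCAB_SIZE are hypothetical names, and the decoding loop is inlined so the snippet runs on its own.

```python
import numpy as np

VOCAB_SIZE = 4  # assumption: token ids are 0..3, matching the example


def cycle_predict_fn(ids):
    # Hypothetical model: make (last_token + 1) % VOCAB_SIZE the largest logit.
    logits = np.zeros(VOCAB_SIZE)
    logits[(ids[-1] + 1) % VOCAB_SIZE] = 1.0
    return logits


ids = [0]
for _ in range(6):
    ids.append(int(np.argmax(cycle_predict_fn(ids))))

print(ids)  # [0, 1, 2, 3, 0, 1, 2]
```

Because the logits depend only on the last token, rerunning the loop from the same prefix always produces the same sequence.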

Constraints

  • Copy prefix (e.g. list(prefix)) before appending — do not mutate the caller's list. The prefix must appear unchanged at the start of the output: out[:len(prefix)] == list(prefix).
  • predict_fn is called exactly max_new_tokens times, each time with the prefix grown by one token (lengths len(prefix), len(prefix)+1, …).
  • Cast np.argmax(...) to a Python int before appending (NumPy integer scalars such as np.int64 are not Python ints and can trip downstream type checks or JSON serialization).
  • max_new_tokens = 0 returns the prefix unchanged.
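These constraints can be checked locally with a recording predict_fn. The sketch below is an assumption about how one might test them, not the platform's hidden tests. The recording_predict_fn spy and its fixed logits are hypothetical, and a compact decoding loop is included so the checks run standalone.

```python
import numpy as np


def greedy_decode(predict_fn, prefix, max_new_tokens):
    # Compact version of the loop, included so the checks below are runnable.
    out = list(prefix)
    for _ in range(max_new_tokens):
        out.append(int(np.argmax(predict_fn(out))))
    return out


seen_lengths = []  # length of the ids list at each predict_fn call


def recording_predict_fn(ids):
    # Hypothetical spy: record the input length, then return fixed logits.
    seen_lengths.append(len(ids))
    return np.array([0.1, 0.9, 0.3])


prefix = [2, 0]
out = greedy_decode(recording_predict_fn, prefix, 3)

assert prefix == [2, 0]                    # caller's list was not mutated
assert out[:len(prefix)] == list(prefix)   # prefix preserved at the start
assert seen_lengths == [2, 3, 4]           # one call per new token, growing by one
assert all(type(t) is int for t in out)    # plain Python ints, not np.int64
assert greedy_decode(recording_predict_fn, prefix, 0) == prefix  # zero steps
```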

Notes

  • Why determinism. Greedy always takes the top token, so the same prompt yields the same output — ideal when you need reproducibility, but prone to repetition and "dead air" on creative tasks.
  • Related. The sampling alternatives temper this: temperature scaling, top-k, and top-p all replace the argmax with a sampled draw.
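To make the contrast with sampling concrete, here is a hedged sketch of a single temperature-scaled draw next to the greedy pick. The sample_next function is a hypothetical helper, not part of this problem's API, and the seed is fixed only so the demo is repeatable.

```python
import numpy as np


def sample_next(logits, temperature, rng):
    # Softmax over temperature-scaled logits, then draw one token id.
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()           # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


logits = [1.0, 2.0, 5.0, 3.0]
rng = np.random.default_rng(0)       # seeded only to make the demo repeatable

greedy_id = int(np.argmax(logits))                           # always token 2
cold_id = sample_next(logits, temperature=1e-3, rng=rng)     # effectively greedy
warm_ids = [sample_next(logits, temperature=2.0, rng=rng) for _ in range(10)]

print(greedy_id, cold_id, warm_ids)
```

As the temperature approaches zero the draw collapses onto the argmax, so greedy decoding can be read as the zero-temperature limit of sampling. Higher temperatures spread probability onto the other tokens.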

This problem ships 5 hidden tests, which run in the browser via Pyodide:

  • Returns len(prefix) + max_new_tokens tokens
  • Picks the argmax at each step (constant logits -> repeated token)
  • predict_fn sees the growing prefix at each call (length increases by 1 per step)
  • Diagnostic: decoder follows a transition table deterministically
  • Edge case: max_new_tokens=0 returns the prefix unchanged