Beam search decoding
Background
Beam search is the standard decoding strategy for sequence models when you want high-probability output without the cost of exhaustive search. Instead of keeping only the single best token at each step (greedy decoding) or scoring every possible sequence (exhaustive search), it keeps the beam_width most probable partial sequences at each step, expands each of them by every possible next token, and prunes back to the top beam_width. It is a middle ground: more thorough than greedy, far cheaper than brute force.
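To make "far cheaper than brute force" concrete, here is a small sketch that counts how many candidates each strategy scores in this problem's fixed-length setting. The helpers count_scored and count_exhaustive are hypothetical names, and the count ignores the cost of sorting and pruning.

```python
def count_scored(vocab: int, steps: int, beam_width: int) -> int:
    """Number of candidate prefixes beam search scores for a fixed-length decode."""
    scored = vocab                      # step 0: every first token is a candidate
    kept = min(beam_width, vocab)       # prune to the top beam_width
    for _ in range(steps - 1):
        candidates = kept * vocab       # expand each surviving beam by every token
        scored += candidates
        kept = min(beam_width, candidates)
    return scored


def count_exhaustive(vocab: int, steps: int) -> int:
    """Number of complete sequences brute force must score."""
    return vocab ** steps


print(count_scored(10, 5, 1))    # 50 (greedy: beam_width = 1)
print(count_scored(10, 5, 3))    # 130
print(count_exhaustive(10, 5))   # 100000
```

Beam search grows linearly in steps (about steps × beam_width × vocab candidates), while exhaustive search grows exponentially.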
Problem statement
Implement beam_search(start_logprobs, trans_logprobs, steps, beam_width) over a first-order Markov model and return (best_sequence, best_score): the highest-scoring sequence the search finds and its total log-probability.
- Score of a sequence = start_logprobs[t0] + sum_t trans_logprobs[prev, next], where the sum runs over every consecutive pair (prev, next) in the sequence.
- Initialize the candidates with every possible first token (scored by start_logprobs) and keep the top beam_width.
- For each remaining step, expand every beam by all next tokens, then keep the top beam_width by cumulative score.
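The three steps above can be sketched with NumPy as follows. This is one possible implementation under stated assumptions, not the reference solution. It vectorizes each expansion as a (beams, vocab) score matrix. It also breaks ties by lowest flattened index through a stable sort, which is an assumption because the problem does not specify tie-breaking.

```python
import numpy as np


def beam_search(start_logprobs: np.ndarray,
                trans_logprobs: np.ndarray,
                steps: int,
                beam_width: int) -> tuple[list[int], float]:
    """Fixed-length beam search over a first-order Markov model, in log-space."""
    vocab = start_logprobs.shape[0]
    # Step 0: every first token is a candidate; keep the top beam_width.
    order = np.argsort(-start_logprobs, kind="stable")[:beam_width]
    beams = [[int(t)] for t in order]
    scores = start_logprobs[order].astype(float)

    for _ in range(steps - 1):
        last = np.array([b[-1] for b in beams])
        # cand[k, j] = cumulative score of beam k extended by token j.
        cand = scores[:, None] + trans_logprobs[last]      # shape (n_beams, vocab)
        flat = cand.ravel()
        # Prune back to the beam_width best (beam, next-token) pairs.
        top = np.argsort(-flat, kind="stable")[:beam_width]
        beams = [beams[i // vocab] + [int(i % vocab)] for i in top]
        scores = flat[top]

    best = int(np.argmax(scores))
    return beams[best], float(scores[best])


start = np.log([0.6, 0.4])
trans = np.log([[0.7, 0.3], [0.4, 0.6]])
print(beam_search(start, trans, steps=2, beam_width=2))  # ([0, 0], about -0.8675 = log 0.42)
```

The full sort over beam_width × vocab candidates per step is the simplest choice. For large vocabularies, np.argpartition can select the top beam_width without fully sorting.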
Input
- start_logprobs: np.ndarray of shape (vocab,), the log-probability of each possible first token.
- trans_logprobs: np.ndarray of shape (vocab, vocab), where trans_logprobs[i, j] = log P(next=j | cur=i).
- steps: int, the total sequence length.
- beam_width: int, the number of beams kept per step.
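Because trans_logprobs[i, j] is a conditional log-probability, each row of exp(trans_logprobs) should sum to 1. A hypothetical validation helper could check the shapes and that normalization. Requiring steps ≥ 1 and beam_width ≥ 1, and treating normalization as an error, are both assumptions; the problem statement does not say how malformed input is handled.

```python
import numpy as np


def check_inputs(start_logprobs, trans_logprobs, steps, beam_width, atol=1e-6):
    """Validate shapes and normalization; return vocab on success."""
    start = np.asarray(start_logprobs, dtype=float)
    trans = np.asarray(trans_logprobs, dtype=float)
    if start.ndim != 1:
        raise ValueError("start_logprobs must have shape (vocab,)")
    vocab = start.shape[0]
    if trans.shape != (vocab, vocab):
        raise ValueError("trans_logprobs must have shape (vocab, vocab)")
    # exp(start) is a distribution over the first token.
    if not np.isclose(np.exp(start).sum(), 1.0, atol=atol):
        raise ValueError("exp(start_logprobs) should sum to 1")
    # Row i of exp(trans) is P(next | cur=i), so every row sums to 1.
    if not np.allclose(np.exp(trans).sum(axis=1), 1.0, atol=atol):
        raise ValueError("each row of exp(trans_logprobs) should sum to 1")
    if steps < 1 or beam_width < 1:
        raise ValueError("steps and beam_width must be positive")
    return vocab


print(check_inputs(np.log([0.6, 0.4]), np.log([[0.7, 0.3], [0.4, 0.6]]), 2, 2))  # 2
```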
Output
A tuple (best_sequence, best_score):
- best_sequence: list[int] of length steps.
- best_score: float, its total log-probability.
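Since best_score must equal the summed log-probability of best_sequence, a small helper that recomputes the score from a sequence is a useful consistency check. The name sequence_score is hypothetical. The helper indexes trans_logprobs with the arrays of predecessor and successor tokens, so it adds one transition term per consecutive pair.

```python
import numpy as np


def sequence_score(seq: list[int],
                   start_logprobs: np.ndarray,
                   trans_logprobs: np.ndarray) -> float:
    """Total log-probability of seq under the first-order Markov model."""
    idx = np.asarray(seq, dtype=int)
    # First token from start_logprobs, then trans_logprobs[prev, next] for each pair.
    return float(start_logprobs[idx[0]] + trans_logprobs[idx[:-1], idx[1:]].sum())


start = np.log([0.6, 0.4])
trans = np.log([[0.7, 0.3], [0.4, 0.6]])
print(np.exp(sequence_score([1, 1], start, trans)))  # about 0.24
```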
Examples
Example 1
start = log([0.6, 0.4])
trans = log([[0.7, 0.3], [0.4, 0.6]])
steps = 2, beam_width = 2
Output: ([0, 0], log(0.6) + log(0.7)) # probability 0.6 × 0.7 = 0.42, the most likely length-2 path
Explanation: among [0,0]=0.42, [1,1]=0.24, [0,1]=0.18, [1,0]=0.16, the sequence [0,0] has the highest probability.
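The explanation can be reproduced by scoring all four length-2 paths directly. Exhaustive enumeration is feasible here only because vocab**steps = 4. The script prints each path with its probability and its log-space score.

```python
import numpy as np

start = np.log([0.6, 0.4])
trans = np.log([[0.7, 0.3], [0.4, 0.6]])

# Score every length-2 path: log P(first) + log P(second | first).
paths = {}
for a in range(2):
    for b in range(2):
        paths[(a, b)] = start[a] + trans[a, b]

for path, score in sorted(paths.items(), key=lambda kv: -kv[1]):
    print(path, round(float(np.exp(score)), 2), round(float(score), 4))
# (0, 0) 0.42 -0.8675
# (1, 1) 0.24 -1.4271
# (0, 1) 0.18 -1.7148
# (1, 0) 0.16 -1.8326
```

With beam_width = 2 and vocab = 2, the beam never prunes a length-1 prefix, so beam search returns the same [0, 0] as this exhaustive check.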
Constraints
- Work in log-space: add log-probabilities rather than multiplying probabilities.
- Prune to the top beam_width candidates after every expansion.
- beam_width = 1 reduces to greedy decoding. A beam wide enough never to prune a prefix (at least vocab^(steps-1) beams) recovers the exact best sequence.
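As a cross-check for the beam_width = 1 case, here is a standalone greedy decoder that takes the argmax at every step. The name greedy_decode is hypothetical. np.argmax returns the first maximal index, so ties go to the lowest token id. That matches a beam implementation only if it breaks ties the same way, which is an assumption. The score accumulates by adding log-probabilities, in line with the log-space constraint.

```python
import numpy as np


def greedy_decode(start_logprobs, trans_logprobs, steps):
    """Keep only the single best token at every step (beam_width = 1)."""
    seq = [int(np.argmax(start_logprobs))]
    score = float(start_logprobs[seq[0]])
    for _ in range(steps - 1):
        row = trans_logprobs[seq[-1]]      # log P(next | current token)
        nxt = int(np.argmax(row))
        seq.append(nxt)
        score += float(row[nxt])           # add log-probs instead of multiplying probs
    return seq, score


start = np.log([0.6, 0.4])
trans = np.log([[0.7, 0.3], [0.4, 0.6]])
print(greedy_decode(start, trans, 3))  # ([0, 0, 0], about log 0.294)
```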
Notes
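- Beam search is not guaranteed to be optimal: a high-scoring prefix can crowd out the globally best sequence. Widening the beam typically narrows that gap, approaching the exact (Viterbi) answer.
- Practical decoders for variable-length output often add length normalization, because every extra token adds another negative log-prob and so favors short sequences. In this problem every candidate has length steps, so normalizing by length would not change which sequence wins.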
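For a first-order Markov model, the exact answer that a wider beam approaches can be computed with the Viterbi algorithm in O(steps · vocab²) time. Below is a sketch that keeps, for each token j, the best prefix ending in j and traces the backpointers afterward. The function name viterbi and the small model are illustrative, not part of the problem. On this model greedy picks token 0 first and ends at [0, 0] with probability 0.33, but the optimum is [1, 0] with probability 0.36.

```python
import numpy as np


def viterbi(start_logprobs, trans_logprobs, steps):
    """Exact highest-scoring length-`steps` sequence of a first-order Markov model."""
    best = np.asarray(start_logprobs, dtype=float)   # best[j]: best prefix score ending in j
    backptrs = []
    for _ in range(steps - 1):
        cand = best[:, None] + trans_logprobs        # cand[i, j]: best prefix ending in i, then j
        backptrs.append(np.argmax(cand, axis=0))     # best predecessor i for each next token j
        best = cand.max(axis=0)
    # Trace back from the best final token.
    last = int(np.argmax(best))
    seq = [last]
    for bp in reversed(backptrs):
        seq.append(int(bp[seq[-1]]))
    return seq[::-1], float(best[last])


start = np.log([0.6, 0.4])
trans = np.log([[0.55, 0.45], [0.9, 0.1]])
print(viterbi(start, trans, 2))  # ([1, 0], about log 0.36)
```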
This problem ships 5 hidden tests, which run in the browser via Pyodide:
- Reference example
- Wide beam matches brute-force optimum
- beam_width = 1 is greedy decoding
- Returned sequence has the requested length
- Score equals the sequence's summed log-probabilities