Multi-Head Attention (full layer)
Background
This is the full Multi-Head Attention layer — the unit stacked twelve times in GPT-2 small. It composes everything earlier in the series: a fused QKV projection, the head split (build-gpt-05), single-head scaled dot-product attention run in parallel across heads (build-gpt-03), the optional causal mask (build-gpt-04), the head recombine (build-gpt-05), and a final output projection. Multiple heads let the model attend to different relationships in parallel, then mix them back together.
Problem statement
Implement multi_head_attention(x, W_qkv, W_o, n_head, mask=None) in five steps:
- Project: qkv = x @ W_qkv, shape (B, T, 3C), then split into Q, K, V (each (B, T, C)).
- Split each of Q, K, V into heads → (B, H, T, D) with D = C / H.
- Run scaled dot-product attention per head (passing mask through).
- Recombine heads → (B, T, C).
- Project: out @ W_o → (B, T, C).
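The five steps can be sketched in NumPy as follows. This is a minimal illustration, not the series' reference solution: the softmax helper, the inner split_heads helper, and the choice of -1e9 as the fill value for blocked positions are assumptions made here, and the earlier build-gpt problems may define these differently.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_qkv, W_o, n_head, mask=None):
    B, T, C = x.shape
    assert C % n_head == 0, "n_head must divide C"
    D = C // n_head

    # 1. Fused projection, then split the last axis into Q, K, V.
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)            # each (B, T, C)

    # 2. Split heads: (B, T, C) -> (B, T, H, D) -> (B, H, T, D).
    def split_heads(t):
        return t.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)
    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # 3. Scaled dot-product attention, all heads at once.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)    # (B, H, T, T)
    if mask is not None:
        # Assumption: blocked positions (mask == False) get a large negative score.
        scores = np.where(mask, scores, -1e9)
    out = softmax(scores) @ v                             # (B, H, T, D)

    # 4. Recombine heads: (B, H, T, D) -> (B, T, H, D) -> (B, T, C).
    out = out.transpose(0, 2, 1, 3).reshape(B, T, C)

    # 5. Output projection.
    return out @ W_o

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))
W_qkv = rng.normal(size=(8, 24))
W_o = rng.normal(size=(8, 8))
causal = np.tril(np.ones((4, 4), dtype=bool))

y = multi_head_attention(x, W_qkv, W_o, n_head=2, mask=causal)
print(y.shape)  # (2, 4, 8)
```

All heads are processed with one batched matmul over the (B, H) leading dims rather than a Python loop over heads; the two are equivalent, and the batched form keeps the code short.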
The parameter budget is 3C² (fused QKV) + C² (output) = 4C² per block. For GPT-2 small (C = 768) that is 4 · 768² ≈ 2.36M params/block, or roughly 28.3M across its twelve blocks: about 23% of the 124M-parameter model.
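The arithmetic behind that budget can be checked in a few lines. This sketch counts weight matrices only (the function signature here has no bias terms) and uses the rounded 124M total quoted in the text, so the percentage is approximate.

```python
# Attention parameter budget for GPT-2 small, weights only (no biases).
C = 768          # model width
n_layer = 12     # number of stacked blocks
total = 124e6    # rounded model size quoted in the text

qkv_params = C * 3 * C        # W_qkv is (C, 3C)
out_params = C * C            # W_o is (C, C)
per_block = qkv_params + out_params

print(per_block)                                     # 2359296
print(n_layer * per_block)                           # 28311552
print(round(100 * n_layer * per_block / total, 1))   # 22.8
```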
Input
- x: np.ndarray of shape (B, T, C), the input sequence batch.
- W_qkv: (C, 3C), the fused projection producing Q, K, V together.
- W_o: (C, C), the output projection that mixes the heads back to C dims.
- n_head: int, the number of heads (must divide C).
- mask: optional (T, T) bool; True = attend, False = block. Broadcasts across the (B, H) leading dims.
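As an illustration of the mask convention, a causal mask can be built with np.tril and applied to a (B, H, T, T) score tensor through ordinary broadcasting. The zero-filled scores array is a stand-in, and filling blocked positions with -inf is one common choice assumed here rather than something this problem prescribes.

```python
import numpy as np

B, H, T = 2, 2, 4

# Lower-triangular boolean mask: position i may attend to positions j <= i.
mask = np.tril(np.ones((T, T), dtype=bool))

scores = np.zeros((B, H, T, T))            # stand-in for per-head attention scores
masked = np.where(mask, scores, -np.inf)   # (T, T) broadcasts over the (B, H) dims

print(mask.astype(int))
print(masked.shape)  # (2, 2, 4, 4)
```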
Output
Returns an np.ndarray of shape (B, T, C) — same shape as x.
Examples
Example 1 — the pipeline and its shapes (B=2, T=4, C=8, H=2)
x: (2, 4, 8), W_qkv: (8, 24), W_o: (8, 8), n_head=2
x @ W_qkv -> (2, 4, 24) # fused QKV
split Q,K,V -> three (2, 4, 8)
split_heads -> three (2, 2, 4, 4) # H=2, D=4
attention -> (2, 2, 4, 4) # per head, in parallel
combine_heads -> (2, 4, 8)
@ W_o -> (2, 4, 8) # output, same shape as x
Explanation: data flows projection → per-head attention → recombine → output projection; the output shape equals the input shape, which is what lets these layers stack.
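The split_heads and combine_heads lines in that trace are pure reshapes and transposes. A small sketch, using the shapes of Example 1 and a reshape/transpose formulation that is assumed here (build-gpt-05 may phrase it differently), shows that they are exact inverses and that each head is a contiguous slice of the channel axis.

```python
import numpy as np

B, T, C, H = 2, 4, 8, 2
D = C // H
t = np.arange(B * T * C, dtype=float).reshape(B, T, C)

heads = t.reshape(B, T, H, D).transpose(0, 2, 1, 3)      # split:   (B, H, T, D)
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, C)    # combine: (B, T, C)

print(heads.shape)                                # (2, 2, 4, 4)
print(np.array_equal(merged, t))                  # True
# Head h holds channels h*D .. (h+1)*D - 1 of every token.
print(np.array_equal(heads[:, 1], t[:, :, D:]))   # True
```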
Example 2 — n_head=1 is plain single-head attention
multi_head_attention(x, W_qkv, W_o, n_head=1)
≡ project QKV -> attention(Q, K, V, mask) -> project W_o
Explanation: with one head the split/combine are no-ops, so the layer reduces exactly to the single-head attention of build-gpt-03 wrapped in the QKV and output projections. Different n_head values reshape the attention into different per-head subspaces and therefore produce different outputs.
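Both claims can be checked on random data. The sketch below is illustrative only: q and k stand in for already-projected queries and keys, and the reshape/transpose head split is the same assumed formulation as elsewhere. With one head the split only adds a singleton axis; with two heads, each head scores tokens using half the channels and a different scale, so its scores differ from the single-head ones.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, C = 2, 4, 8
q = rng.normal(size=(B, T, C))
k = rng.normal(size=(B, T, C))

# n_head = 1: head size D equals C, so split/combine change nothing but the shape.
heads = q.reshape(B, T, 1, C).transpose(0, 2, 1, 3)      # (B, 1, T, C)
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, C)
print(np.array_equal(heads, q[:, None]))  # True
print(np.array_equal(merged, q))          # True

# n_head = 2: head 0 sees only channels 0..3 and scales by sqrt(4), not sqrt(8).
full = q @ k.transpose(0, 2, 1) / np.sqrt(C)
half = q[..., :4] @ k[..., :4].transpose(0, 2, 1) / np.sqrt(4)
print(np.allclose(full, half))            # False
```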
Constraints
- n_head must divide C; each head has size D = C / n_head.
- Project Q, K, V together with W_qkv (one matmul), then np.split(..., 3, axis=-1).
- The (T, T) mask, if given, broadcasts across the (B, H) leading dims; no special handling needed.
- The output projection W_o is applied after recombining heads, mapping (B, T, C) → (B, T, C).
- Output shape equals input shape (B, T, C); tests compare against a reference with atol≈1e-6.
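To see why a single fused matmul followed by np.split is equivalent to three separate projections, the sketch below builds W_qkv by concatenating three (C, C) matrices along the column axis. The names W_q, W_k, and W_v are hypothetical and exist only for this comparison; the problem itself supplies W_qkv directly.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, C = 2, 4, 8
x = rng.normal(size=(B, T, C))

# Hypothetical separate projections, used only to build the fused matrix.
W_q, W_k, W_v = (rng.normal(size=(C, C)) for _ in range(3))
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)   # (C, 3C)

# One matmul, then three equal chunks along the last axis.
Q, K, V = np.split(x @ W_qkv, 3, axis=-1)         # each (B, T, C)

print(np.allclose(Q, x @ W_q), np.allclose(K, x @ W_k), np.allclose(V, x @ W_v))
# True True True
```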
Notes
- Parameter budget. 4C² per attention block: three C×C blocks fused into W_qkv plus the C×C output projection. This is a large slice of a transformer's parameters.
- Series finale (attention). Wires together build-gpt-03 (attention), build-gpt-04 (mask), and build-gpt-05 (head split/combine); build-gpt-08 drops this layer into a full transformer block.
This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- Output shape is (B, T, C)
- Matches reference implementation (no mask)
- Causal mask is honoured: early positions don't see future
- Diagnostic: H=1 reduces to a single-head attention (vs explicit ref)
- Different H values produce different outputs (heads aren't trivially equal)