Multi-Head Attention (full layer)
Difficulty: Medium

Background

This is the full Multi-Head Attention layer — the unit stacked twelve times in GPT-2 small. It composes everything earlier in the series: a fused QKV projection, the head split (build-gpt-05), single-head scaled dot-product attention run in parallel across heads (build-gpt-03), the optional causal mask (build-gpt-04), the head recombine (build-gpt-05), and a final output projection. Multiple heads let the model attend to different relationships in parallel, then mix them back together.

Problem statement

Implement multi_head_attention(x, W_qkv, W_o, n_head, mask=None) in five steps:

  1. Project: qkv = x @ W_qkv, shape (B, T, 3C), then split into Q, K, V (each (B, T, C)).
  2. Split each of Q, K, V into heads → (B, H, T, D) with D = C / H.
  3. Run scaled dot-product attention per head (passing mask through).
  4. Recombine heads → (B, T, C).
  5. Project: out @ W_o → (B, T, C).
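
The five steps above can be sketched in NumPy as follows. This is one possible implementation under stated assumptions, not the reference solution: the softmax helper and the inline split_heads function are illustrative stand-ins for the earlier build-gpt pieces, and filling blocked scores with -inf assumes every mask row contains at least one True entry (which holds for a causal mask).

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_qkv, W_o, n_head, mask=None):
    B, T, C = x.shape
    assert C % n_head == 0, "n_head must divide C"
    D = C // n_head

    # 1. Fused projection, then split the last axis into Q, K, V.
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)           # each (B, T, C)

    # 2. Split heads: (B, T, C) -> (B, T, H, D) -> (B, H, T, D).
    def split_heads(t):
        return t.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)
    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # 3. Scaled dot-product attention, all heads at once.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)   # (B, H, T, T)
    if mask is not None:
        # True = attend, False = block; (T, T) broadcasts over (B, H).
        # Assumption: each row has at least one True, so no row is all -inf.
        scores = np.where(mask, scores, -np.inf)
    out = softmax(scores) @ v                            # (B, H, T, D)

    # 4. Recombine heads: (B, H, T, D) -> (B, T, H, D) -> (B, T, C).
    out = out.transpose(0, 2, 1, 3).reshape(B, T, C)

    # 5. Output projection.
    return out @ W_o
```

Calling it with x of shape (2, 4, 8), W_qkv of shape (8, 24), W_o of shape (8, 8), and n_head=2 returns an array of shape (2, 4, 8).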

The parameter budget is $3C^2$ (fused QKV) $+\, C^2$ (output) $= 4C^2$ per block. For GPT-2 small ($C = 768$) that is $\approx 2.36\text{M}$ params/block $\times 12 \approx 28\text{M}$, about 23% of the 124M-parameter model.
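
The arithmetic can be checked in a few lines. This sketch counts weight matrices only (the function signature here takes no biases) and uses the rounded 124M total from the text as the denominator.

```python
C, n_layer = 768, 12          # GPT-2 small width and depth
total = 124_000_000           # rounded model size quoted in the text

per_block = 3 * C * C + C * C  # fused QKV (C x 3C) + output projection (C x C)
all_blocks = per_block * n_layer

print(per_block)                           # 2359296
print(all_blocks)                          # 28311552
print(round(100 * all_blocks / total, 1))  # 22.8
```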

Input

  • x: np.ndarray of shape (B, T, C). The input sequence batch.
  • W_qkv: (C, 3C). The fused projection producing Q, K, V together.
  • W_o: (C, C). The output projection that mixes the heads back to C dims.
  • n_head: int. The number of heads (must divide C).
  • mask: optional (T, T) bool array. True = attend, False = block. Broadcasts across the (B, H) leading dims.

Output

Returns an np.ndarray of shape (B, T, C) — same shape as x.

Examples

Example 1: the pipeline and its shapes ($B=2, T=4, C=8, H=2$)

x: (2, 4, 8), W_qkv: (8, 24), W_o: (8, 8), n_head=2
  x @ W_qkv      -> (2, 4, 24)       # fused QKV
  split Q,K,V    -> three (2, 4, 8)
  split_heads    -> three (2, 2, 4, 4)   # H=2, D=4
  attention      -> (2, 2, 4, 4)         # per head, in parallel
  combine_heads  -> (2, 4, 8)
  @ W_o          -> (2, 4, 8)            # output, same shape as x

Explanation: data flows projection → per-head attention → recombine → output projection; the output shape equals the input shape, which is what lets these layers stack.
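
The split_heads and combine_heads rows of the trace are pure reshapes and transposes. A minimal sketch with the Example 1 sizes, assuming the usual reshape-then-transpose layout (the build-gpt-05 helpers themselves are not shown here, so this is an illustration of the idea rather than their exact code):

```python
import numpy as np

B, T, C, H = 2, 4, 8, 2
D = C // H

q = np.arange(B * T * C, dtype=float).reshape(B, T, C)

# split_heads: carve the channel axis into H chunks of size D, then move H ahead of T.
heads = q.reshape(B, T, H, D).transpose(0, 2, 1, 3)
print(heads.shape)  # (2, 2, 4, 4)

# Head h holds channels [h*D, (h+1)*D) of every token.
assert np.array_equal(heads[:, 1], q[:, :, D:2 * D])

# combine_heads: undo the transpose, then merge (H, D) back into C.
merged = heads.transpose(0, 2, 1, 3).reshape(B, T, C)
assert np.array_equal(merged, q)
```

Because the round trip is lossless, everything the layer learns lives in W_qkv, W_o, and the attention weights computed between the two reshapes.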

Example 2 — n_head=1 is plain single-head attention

multi_head_attention(x, W_qkv, W_o, n_head=1)
  ≡ project QKV -> attention(Q, K, V, mask) -> project W_o

Explanation: with one head the split/combine are no-ops, so the layer reduces exactly to the single-head attention of build-gpt-03 wrapped in the QKV and output projections. Different n_head values reshape the attention into different per-head subspaces and therefore produce different outputs.
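
To see why the split and combine are no-ops when n_head=1, note that D = C, so the reshape only inserts a length-1 head axis. A small check, again assuming the reshape-then-transpose layout for the head split:

```python
import numpy as np

B, T, C = 2, 4, 8
q = np.random.default_rng(0).standard_normal((B, T, C))

# With H=1 and D=C, the head split only inserts a length-1 head axis.
heads = q.reshape(B, T, 1, C).transpose(0, 2, 1, 3)  # (B, 1, T, C)
assert np.array_equal(heads, q[:, None])

# Combining heads removes that axis again and returns q unchanged.
restored = heads.transpose(0, 2, 1, 3).reshape(B, T, C)
assert np.array_equal(restored, q)
```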

Constraints

  • n_head must divide C; each head has size D = C / n_head.
  • Project Q, K, V together with W_qkv (one matmul), then np.split(..., 3, axis=-1).
  • The (T, T) mask, if given, broadcasts across the (B, H) leading dims — no special handling needed.
  • The output projection W_o is applied after recombining heads, mapping (B, T, C) → (B, T, C).
  • Output shape equals input shape (B, T, C); tests compare against a reference with atol≈1e-6.
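
The mask constraint relies on ordinary NumPy broadcasting. The sketch below uses random numbers as a stand-in for the real (B, H, T, T) score tensor and a causal mask built with np.tril. The -inf fill value is an assumption, and a large negative constant gives the same weights within the test tolerance.

```python
import numpy as np

B, H, T = 2, 2, 4
rng = np.random.default_rng(0)
scores = rng.standard_normal((B, H, T, T))  # stand-in for Q @ K^T / sqrt(D)

# Causal mask: True on and below the diagonal (position t may see positions <= t).
mask = np.tril(np.ones((T, T), dtype=bool))

# (T, T) aligns with the trailing two axes and is reused for every (batch, head) pair.
masked = np.where(mask, scores, -np.inf)

# Softmax over the key axis; blocked entries get exactly zero weight.
e = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

assert weights.shape == (B, H, T, T)
assert np.all(weights[..., ~mask] == 0.0)
assert np.allclose(weights.sum(axis=-1), 1.0)
```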

Notes

  • Parameter budget. $4C^2$ per attention block: three $C \times C$ blocks fused into W_qkv plus the $C \times C$ output projection. For GPT-2 small this comes to roughly 23% of the model's parameters.
  • Series finale (attention). Wires together build-gpt-03 (attention), build-gpt-04 (mask), and build-gpt-05 (head split/combine); build-gpt-08 drops this layer into a full transformer block.

This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue:

  • Output shape is (B, T, C)
  • Matches reference implementation (no mask)
  • Causal mask is honoured — early positions don't see future
  • Diagnostic: H=1 reduces to a single-head attention (vs explicit ref)
  • Different H values produce different outputs (heads aren't trivially equal)
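
The hidden tests themselves are not shown, but checks in the same spirit can be run locally. In the sketch below, mha is a compact, hypothetical stand-in for multi_head_attention that follows the five steps of the problem statement. The perturbation test for causality and the specific sizes are assumptions chosen for illustration.

```python
import numpy as np

def mha(x, W_qkv, W_o, n_head, mask=None):
    # Compact stand-in for multi_head_attention (project, split, attend, combine, project).
    B, T, C = x.shape
    D = C // n_head
    q, k, v = (t.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)
               for t in np.split(x @ W_qkv, 3, axis=-1))
    s = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)
    if mask is not None:
        s = np.where(mask, s, -np.inf)  # assumes no fully blocked row
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return (w @ v).transpose(0, 2, 1, 3).reshape(B, T, C) @ W_o

rng = np.random.default_rng(0)
B, T, C = 2, 5, 8
x = rng.standard_normal((B, T, C))
W_qkv = rng.standard_normal((C, 3 * C))
W_o = rng.standard_normal((C, C))
causal = np.tril(np.ones((T, T), dtype=bool))

# Causality: changing tokens at positions >= t must leave outputs before t untouched.
t = 3
x_future = x.copy()
x_future[:, t:] = rng.standard_normal((B, T - t, C))
y = mha(x, W_qkv, W_o, 2, causal)
y_future = mha(x_future, W_qkv, W_o, 2, causal)
assert y.shape == (B, T, C)
assert np.allclose(y[:, :t], y_future[:, :t], atol=1e-6)
assert not np.allclose(y[:, t:], y_future[:, t:])

# Head count matters: the same weights with H=1 and H=2 give different outputs.
assert not np.allclose(mha(x, W_qkv, W_o, 1), mha(x, W_qkv, W_o, 2))
```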