Multi-head split + combine

Difficulty: Medium

Background

Multi-head attention runs n_head attention computations in parallel over slices of the feature dimension. To do that efficiently you reshape a (B, T, C) tensor so each head's features sit on their own axis, run attention, then stitch the heads back together. The whole trick is the order of operations: reshape then transpose to split, transpose then reshape to combine. Get the order wrong and the bytes are reinterpreted in the wrong layout — the data scrambles silently.

Problem statement

Implement two inverse functions, with C = H · D (where H = n_head and D is the per-head size):

  1. split_heads(x, n_head) — turn (B, T, C) into (B, H, T, D) by reshaping to (B, T, H, D) (splitting the last axis), then transposing the head axis in front of the time axis.
  2. combine_heads(x) — the exact inverse: turn (B, H, T, D) back into (B, T, C) by transposing to (B, T, H, D) first, then reshaping to merge the last two axes into H·D.
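
The two steps above can be sketched as follows. This is a minimal sketch rather than the reference solution: it assumes NumPy arrays (the statement does not name an array library) and chooses ValueError for the divisibility check, which is one of the two exception types the problem allows.

```python
import numpy as np


def split_heads(x: np.ndarray, n_head: int) -> np.ndarray:
    """(B, T, C) -> (B, H, T, D): reshape first, then transpose."""
    B, T, C = x.shape
    if C % n_head != 0:
        raise ValueError(f"n_head={n_head} must divide C={C}")
    D = C // n_head
    # Split the last axis into (H, D); element order is unchanged.
    x = x.reshape(B, T, n_head, D)
    # Move the head axis ahead of the time axis.
    return x.transpose(0, 2, 1, 3)


def combine_heads(x: np.ndarray) -> np.ndarray:
    """(B, H, T, D) -> (B, T, H*D): transpose first, then reshape."""
    B, H, T, D = x.shape
    # Put time back ahead of heads, then merge (H, D) into one axis.
    return x.transpose(0, 2, 1, 3).reshape(B, T, H * D)
```

In NumPy the final reshape in combine_heads has to copy, because the transposed array is no longer contiguous; with a framework that distinguishes views from copies, the equivalent call may need an explicit contiguity step.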

Input

  • split_heads(x, n_head): x has shape (B, T, C); n_head is an int that must divide C.
  • combine_heads(x): x has shape (B, H, T, D).

Output

  • split_heads returns (B, n_head, T, head_size) with head_size = C // n_head.
  • combine_heads returns (B, T, H * D).

Examples

Example 1 — the shape walk

split_heads(x, n_head=3),  x: (2, 8, 12)
  reshape   -> (2, 8, 3, 4)     # split C=12 into H=3, D=4
  transpose -> (2, 3, 8, 4)     # move the head axis ahead of time

combine_heads(y),  y: (2, 4, 8, 6)
  transpose -> (2, 8, 4, 6)
  reshape   -> (2, 8, 24)       # merge H*D = 24

Explanation: the reshape only relabels the last axis as (H, D) without changing element order; the transpose then moves the head axis ahead of the time axis, so each head's (T, D) block is addressable as x[b, h]. combine_heads undoes both steps in reverse order.
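
The shape walk can be reproduced step by step. The snippet below is an illustrative sketch that assumes NumPy and uses zero-filled arrays, since only the shapes matter here; the names step1 to step4 are made up for the example.

```python
import numpy as np

# split: (2, 8, 12) -> (2, 8, 3, 4) -> (2, 3, 8, 4)
x = np.zeros((2, 8, 12))
step1 = x.reshape(2, 8, 3, 4)          # split C=12 into H=3, D=4
step2 = step1.transpose(0, 2, 1, 3)    # swap the time and head axes
split_shapes = [x.shape, step1.shape, step2.shape]

# combine: (2, 4, 8, 6) -> (2, 8, 4, 6) -> (2, 8, 24)
y = np.zeros((2, 4, 8, 6))
step3 = y.transpose(0, 2, 1, 3)        # swap the head and time axes back
step4 = step3.reshape(2, 8, 24)        # merge H*D = 4*6 = 24
combine_shapes = [y.shape, step3.shape, step4.shape]

print(split_shapes)
print(combine_shapes)
```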

Example 2 — which features go to which head, and the round-trip

Input:  x[b, t, :] = [0, 1, 2, 3, 4, 5, 6, 7] for every (b, t);  n_head=2  (so D=4)
split_heads -> head 0 holds [0, 1, 2, 3],  head 1 holds [4, 5, 6, 7]
combine_heads(split_heads(x, 2)) == x        # exact, atol 1e-12

Explanation: splitting the last axis into contiguous chunks gives head 0 the first D features and head 1 the next D. Because combine_heads is the exact inverse, splitting then combining returns x bit-for-bit.
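
Example 2 can be checked directly. The sketch below assumes NumPy and picks B=2, T=3 arbitrarily, since the example leaves the batch and time sizes open; the split and combine steps are written inline so the snippet stands alone.

```python
import numpy as np

B, T, n_head = 2, 3, 2                      # B and T are arbitrary choices
features = np.arange(8, dtype=np.float64)   # [0, 1, ..., 7]
x = np.broadcast_to(features, (B, T, 8)).copy()
D = x.shape[-1] // n_head                   # D = 4

# split: reshape, then transpose
heads = x.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)
head0 = heads[0, 0, 0]   # head 0 at b=0, t=0
head1 = heads[0, 1, 0]   # head 1 at b=0, t=0
print(head0)             # [0. 1. 2. 3.]
print(head1)             # [4. 5. 6. 7.]

# combine: transpose, then reshape
restored = heads.transpose(0, 2, 1, 3).reshape(B, T, n_head * D)
round_trip_ok = np.array_equal(restored, x)
print(round_trip_ok)     # True
```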

Constraints

  • C = H · D; n_head must divide C, otherwise raise (AssertionError or ValueError).
  • split = reshape then transpose; combine = transpose then reshape. The reverse order silently scrambles the head/time layout.
  • combine_heads(split_heads(x, n_head)) must equal x exactly (atol=1e-12).
  • Head h holds feature indices [h·D, (h+1)·D) of the original C axis.
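
Two of these constraints lend themselves to a quick self-check. The sketch below assumes NumPy; check_divisible is a hypothetical helper introduced only for illustration, and the sizes B, T, C, H are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, C, H = 2, 5, 12, 3
D = C // H
x = rng.standard_normal((B, T, C))
heads = x.reshape(B, T, H, D).transpose(0, 2, 1, 3)

# Head h should be exactly the contiguous slice [h*D, (h+1)*D) of the C axis.
slices_match = all(
    np.array_equal(heads[:, h], x[:, :, h * D:(h + 1) * D]) for h in range(H)
)
print(slices_match)


def check_divisible(C: int, n_head: int) -> int:
    """Hypothetical helper: return the head size, or raise if n_head does not divide C."""
    if C % n_head != 0:
        raise ValueError(f"n_head={n_head} must divide C={C}")
    return C // n_head


print(check_divisible(12, 3))
```

An explicit check like this gives a clearer message than relying on the reshape call to fail on a size mismatch.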

Notes

  • Why order matters. A reshape reinterprets the existing byte order without moving anything; a transpose moves axes. Reshaping (B,T,C)→(B,T,H,D) keeps each token's features together, and only then does the transpose lift the head axis out — doing it the other way mixes heads with time.
  • Series. build-gpt-06-multi-head-attention calls these to wrap the single-head attention from build-gpt-03 into a full multi-head layer.
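
To make the ordering point concrete, here is a small sketch (assuming NumPy, with tiny hand-picked sizes) that compares the correct split against a direct reshape to (B, H, T, D). Both produce the same shape, which is why the mistake goes unnoticed, but the contents differ.

```python
import numpy as np

B, T, H, D = 1, 3, 2, 2
# Each token t holds features [4t, 4t+1, 4t+2, 4t+3].
x = np.arange(B * T * H * D).reshape(B, T, H * D)

correct = x.reshape(B, T, H, D).transpose(0, 2, 1, 3)
wrong = x.reshape(B, H, T, D)   # skips the transpose: right shape, wrong data

print(correct.shape == wrong.shape)   # True
print(correct[0, 0].tolist())         # [[0, 1], [4, 5], [8, 9]]
print(wrong[0, 0].tolist())           # [[0, 1], [2, 3], [4, 5]]
```

In the correct result, head 0 holds the first two features of every token. In the direct reshape, "head 0, time 1" holds [2, 3], which are actually token 0's head-1 features.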

This problem ships 6 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.

  • split_heads: shape (B, T, C) -> (B, H, T, D)
  • combine_heads: shape (B, H, T, D) -> (B, T, C)
  • Diagnostic: split then combine is identity
  • split_heads: head 0 contains the first head_size features along C
  • Asserts cleanly when n_head does not divide C
  • Larger GPT-2-scale shapes work