Multi-head split + combine
Background
Multi-head attention runs n_head attention computations in parallel over slices of the feature dimension. To do that efficiently you reshape a (B, T, C) tensor so each head's features sit on their own axis, run attention, then stitch the heads back together. The whole trick is the order of operations: reshape then transpose to split, transpose then reshape to combine. Get the order wrong and the bytes are reinterpreted in the wrong layout — the data scrambles silently.
Problem statement
Implement two mutually inverse functions, with C = H · D (where H = n_head and D = C / H is the per-head size):
- split_heads(x, n_head): turn (B, T, C) into (B, H, T, D) by reshaping to (B, T, H, D) (splitting the last axis), then transposing the head axis in front of the time axis.
- combine_heads(x): the exact inverse. Turn (B, H, T, D) back into (B, T, C) by first transposing to (B, T, H, D), then reshaping to merge the last two axes into H·D.
Input
- split_heads(x, n_head): x has shape (B, T, C); n_head is an int that must divide C.
- combine_heads(x): x has shape (B, H, T, D).
Output
- split_heads returns shape (B, n_head, T, head_size), with head_size = C // n_head.
- combine_heads returns shape (B, T, H * D).
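A minimal sketch of both functions, assuming the inputs are NumPy arrays (the array type the hidden tests pass is not stated here). It raises ValueError when n_head does not divide C, one of the two exception types the constraints allow. In PyTorch the axis swap would typically be written with `permute(0, 2, 1, 3)` or `transpose(1, 2)` instead of NumPy's `transpose(0, 2, 1, 3)`.

```python
import numpy as np


def split_heads(x: np.ndarray, n_head: int) -> np.ndarray:
    """(B, T, C) -> (B, H, T, D): reshape first, then transpose."""
    B, T, C = x.shape
    if C % n_head != 0:
        raise ValueError(f"n_head={n_head} must divide C={C}")
    head_size = C // n_head
    # Split the last axis into (H, D); each token's features stay contiguous.
    x = x.reshape(B, T, n_head, head_size)
    # Move the head axis in front of the time axis: (B, T, H, D) -> (B, H, T, D).
    return x.transpose(0, 2, 1, 3)


def combine_heads(x: np.ndarray) -> np.ndarray:
    """(B, H, T, D) -> (B, T, H*D): transpose first, then reshape."""
    B, H, T, D = x.shape
    # Undo the transpose: (B, H, T, D) -> (B, T, H, D).
    x = x.transpose(0, 2, 1, 3)
    # Merge the last two axes. np.reshape returns a copy when the transposed
    # layout cannot be viewed directly as (B, T, H*D).
    return x.reshape(B, T, H * D)


x = np.random.default_rng(0).standard_normal((2, 8, 12))
y = split_heads(x, n_head=3)
print(y.shape)                              # (2, 3, 8, 4)
print(np.array_equal(combine_heads(y), x))  # True
```

Both functions only change shapes and axis order, so the round trip is exact for any dtype.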
Examples
Example 1 — the shape walk
split_heads(x, n_head=3), x: (2, 8, 12)
reshape -> (2, 8, 3, 4) # split C=12 into H=3, D=4
transpose -> (2, 3, 8, 4) # move the head axis ahead of time
combine_heads(y), y: (2, 4, 8, 6)
transpose -> (2, 8, 4, 6)
reshape -> (2, 8, 24) # merge H*D = 24
Explanation: the reshape only splits the last axis into (H, D) without moving any data; the transpose then moves the head axis ahead of the time axis, so each head's (T, D) slice can be processed as its own batch entry. combine_heads undoes both steps in reverse order.
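The shape walk in Example 1 can be reproduced step by step. This sketch assumes NumPy arrays and uses zero-filled inputs, since only the shapes matter. The permutation `(0, 2, 1, 3)` swaps axes 1 and 2 and is its own inverse, which is why the same call appears in both directions.

```python
import numpy as np

# Split: (B, T, C) -> (B, T, H, D) -> (B, H, T, D)
x = np.zeros((2, 8, 12))
n_head = 3
B, T, C = x.shape
head_size = C // n_head                         # D = 4
step1 = x.reshape(B, T, n_head, head_size)      # split the last axis
step2 = step1.transpose(0, 2, 1, 3)             # head axis ahead of time
print(step1.shape, step2.shape)                 # (2, 8, 3, 4) (2, 3, 8, 4)

# Combine: (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
y = np.zeros((2, 4, 8, 6))
B, H, T, D = y.shape
step3 = y.transpose(0, 2, 1, 3)                 # undo the transpose
step4 = step3.reshape(B, T, H * D)              # merge H and D
print(step3.shape, step4.shape)                 # (2, 8, 4, 6) (2, 8, 24)
```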
Example 2 — which features go to which head, and the round-trip
Input: x[b, t, :] = [0, 1, 2, 3, 4, 5, 6, 7] for every (b, t); n_head=2 (so D=4)
split_heads -> head 0 holds [0, 1, 2, 3], head 1 holds [4, 5, 6, 7]
combine_heads(split_heads(x, 2)) == x # exact, atol 1e-12
Explanation: splitting the last axis into contiguous chunks of D = 4 gives head 0 the first four features and head 1 the next four. Because combine_heads is the exact inverse, splitting then combining returns x bit-for-bit.
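Example 2 can be checked directly. A short NumPy sketch, assuming B = 2 and T = 3 (the example leaves them open) and float64 values; `np.array_equal` tests exact equality, which is stricter than the atol=1e-12 tolerance in the constraints.

```python
import numpy as np

B, T, C, n_head = 2, 3, 8, 2
D = C // n_head                                             # 4
# Every (b, t) position holds the feature vector [0, 1, ..., 7].
x = np.broadcast_to(np.arange(C, dtype=np.float64), (B, T, C)).copy()

heads = x.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)   # (B, H, T, D)
print(heads[0, 0, 0])   # head 0 at (b=0, t=0): [0. 1. 2. 3.]
print(heads[0, 1, 0])   # head 1 at (b=0, t=0): [4. 5. 6. 7.]

back = heads.transpose(0, 2, 1, 3).reshape(B, T, C)        # combine
print(np.array_equal(back, x))                             # True
```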
Constraints
- C = H · D: n_head must divide C, otherwise raise (AssertionError or ValueError).
- split = reshape then transpose; combine = transpose then reshape. The reverse order silently scrambles the head/time layout.
- combine_heads(split_heads(x, n_head)) must equal x exactly (atol=1e-12).
- Head h holds feature indices h·D through (h+1)·D − 1 of the original C axis.
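The last two constraints can be turned into a check. `check_head_slices` below is a hypothetical helper, not part of the problem's API: it rejects an n_head that does not divide C with an AssertionError (one of the two allowed exception types) and then confirms that head h equals the slice `x[..., h*D:(h+1)*D]`. Bear in mind that `assert` statements are skipped under `python -O`, so ValueError is the sturdier choice outside of tests.

```python
import numpy as np


def check_head_slices(x: np.ndarray, n_head: int) -> None:
    """Hypothetical helper: verify head h of the split is x[..., h*D:(h+1)*D]."""
    B, T, C = x.shape
    assert C % n_head == 0, f"n_head={n_head} must divide C={C}"
    D = C // n_head
    heads = x.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)  # (B, H, T, D)
    for h in range(n_head):
        assert np.array_equal(heads[:, h], x[:, :, h * D:(h + 1) * D])


x = np.random.default_rng(0).standard_normal((2, 5, 12))
for n_head in (1, 2, 3, 4, 6, 12):      # every divisor of C = 12
    check_head_slices(x, n_head)

try:
    check_head_slices(x, 5)              # 5 does not divide 12
except AssertionError as err:
    print("rejected:", err)              # rejected: n_head=5 must divide C=12
```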
Notes
- Why order matters. A reshape reinterprets the existing element order without moving anything; a transpose moves axes. Reshaping (B, T, C) → (B, T, H, D) keeps each token's features together, and only then does the transpose lift the head axis out. Doing it the other way (for example, reshaping straight to (B, H, T, D)) mixes heads with time.
- Series. build-gpt-06-multi-head-attention calls these to wrap the single-head attention from build-gpt-03 into a full multi-head layer.
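The difference between the two orders shows up clearly with values that record their own position. In this NumPy sketch (shapes chosen arbitrarily, small enough to print), entry `x[0, t, c]` equals `10*t + c`. Reshaping straight to (B, H, T, D) yields an array of the right shape whose head 0 holds the complete feature vectors of tokens 0 and 1, and combining with a reshape alone, skipping the transpose, fails to round-trip.

```python
import numpy as np

B, T, C, n_head = 1, 4, 6, 2
D = C // n_head                                         # 3
# x[0, t, c] = 10 * t + c, so each value encodes its token and feature.
x = (10 * np.arange(T)[:, None] + np.arange(C))[None]   # (1, 4, 6)

right = x.reshape(B, T, n_head, D).transpose(0, 2, 1, 3)  # reshape, then transpose
wrong = x.reshape(B, n_head, T, D)                         # reshape only

print(right.shape == wrong.shape)    # True: the shapes agree...
print(np.array_equal(right, wrong))  # False: ...but the contents do not
print(right[0, 0])   # head 0: features 0-2 of each of the 4 tokens
print(wrong[0, 0])   # head 0: all 6 features of tokens 0 and 1, cut into rows of 3

# Combining must also transpose before reshaping to recover x.
print(np.array_equal(right.transpose(0, 2, 1, 3).reshape(B, T, C), x))  # True
print(np.array_equal(right.reshape(B, T, C), x))                       # False
```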
This problem ships 6 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- split_heads: shape (B, T, C) -> (B, H, T, D)
- combine_heads: shape (B, H, T, D) -> (B, T, C)
- Diagnostic: split then combine is identity
- split_heads: head 0 contains the first head_size features along C
- Asserts cleanly when n_head does not divide C
- Larger GPT-2-scale shapes work