Grouped-query attention
Background
Grouped-Query Attention (GQA) sits between multi-head attention (MHA) and multi-query attention (MQA). MHA gives every query head its own key/value head: accurate, but with a large KV cache. MQA shares a single K/V head across all query heads: tiny cache, some quality loss. GQA splits the $H$ query heads into $G$ groups; each group shares one K/V head. With $G$ between 1 and $H$ you trade cache size against quality. The larger Llama 2 models and Llama 3 use GQA for exactly this reason.
Problem statement
Implement grouped_query_attention(Q, K, V) returning per-head scaled dot-product attention where query heads share K/V heads by group.
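- Query head $h$ uses K/V group $g = \lfloor h / (H/G) \rfloor$.
- For each head: $\mathrm{out}_h = \mathrm{softmax}\left(Q_h K_g^\top / \sqrt{d}\right) V_g$.
Raise ValueError if $H$ is not divisible by $G$.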
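The statement above can be sketched in a few lines of NumPy. This is one possible reading of the task, not the reference solution: it assumes the conventional $1/\sqrt{d}$ scale of scaled dot-product attention, and the `softmax` helper is a name introduced here for illustration.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis (illustrative helper)."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def grouped_query_attention(Q, K, V):
    """Q: (H, N, d); K, V: (G, N, d). Returns an array of shape (H, N, d)."""
    Q, K, V = (np.asarray(a, dtype=float) for a in (Q, K, V))
    H, _, d = Q.shape
    G = K.shape[0]
    if G == 0 or H % G != 0:
        raise ValueError(f"H={H} is not divisible by G={G}")

    heads_per_group = H // G
    out = np.empty_like(Q)
    for h in range(H):
        g = h // heads_per_group             # K/V group shared by this head
        scores = Q[h] @ K[g].T / np.sqrt(d)  # (N, N); 1/sqrt(d) scale is assumed
        out[h] = softmax(scores, axis=-1) @ V[g]
    return out
```

The explicit loop over heads keeps the head-to-group routing visible; a vectorized version would typically reshape Q to (G, H // G, N, d) and broadcast against K and V.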
Input
- Q: np.ndarray of shape (H, N, d), H query heads.
- K, V: np.ndarray of shape (G, N, d), G key/value groups ($G$ divides $H$).
Output
An np.ndarray of shape (H, N, d): the attention output for every query head.
Examples
Example 1
Input: H = 4 query heads, G = 2 kv groups, N = 1 key position, d = 2
V groups = [[[1, 1]], [[9, 9]]]
Output: heads 0,1 -> [[1, 1]] (share group 0)
heads 2,3 -> [[9, 9]] (share group 1)
Explanation: with H/G = 2 heads per group, query heads 0–1 map to group 0 and heads 2–3 to group 1. With a single key position the softmax weight is 1, so each head's output equals its group's value.
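Example 1 can be checked numerically with a short script. The example does not specify Q or K, so the random values below are an assumption; with a single key position they cannot affect the result. The sketch expands the K/V groups with `np.repeat` so that every query head gets its own copy, then applies ordinary attention with an assumed $1/\sqrt{d}$ scale.

```python
import numpy as np

H, G, N, d = 4, 2, 1, 2
rng = np.random.default_rng(0)
Q = rng.normal(size=(H, N, d))  # arbitrary: irrelevant when N = 1
K = rng.normal(size=(G, N, d))  # arbitrary: irrelevant when N = 1
V = np.array([[[1.0, 1.0]], [[9.0, 9.0]]])

# Repeat each group H // G times so head h reads row h // (H // G).
K_full = np.repeat(K, H // G, axis=0)  # (H, N, d)
V_full = np.repeat(V, H // G, axis=0)  # (H, N, d)

scores = Q @ K_full.transpose(0, 2, 1) / np.sqrt(d)  # (H, N, N)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)       # softmax over the key axis
out = weights @ V_full                               # (H, N, d)

print(out[:, 0, :])  # heads 0,1 -> [1, 1]; heads 2,3 -> [9, 9]
```

Materializing `K_full` and `V_full` is convenient for checking small cases, but it gives up the memory saving that motivates GQA, so it is better suited to tests than to inference code.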
Constraints
- The group index for head h is h // (H // G) (integer division).
- Apply softmax over the key axis, with the $1/\sqrt{d}$ scale.
- G == H recovers standard multi-head attention; G == 1 recovers multi-query attention.
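To make the routing rule and its two limiting cases concrete, the snippet below prints the head-to-group mapping for H = 8 at three values of G. `group_index` is a hypothetical helper named here for illustration only.

```python
def group_index(h, H, G):
    """Return the K/V group used by query head h (hypothetical helper)."""
    if H % G != 0:
        raise ValueError(f"H={H} is not divisible by G={G}")
    return h // (H // G)


H = 8
for G in (8, 2, 1):
    print(G, [group_index(h, H, G) for h in range(H)])
# 8 [0, 1, 2, 3, 4, 5, 6, 7]   one group per head: MHA
# 2 [0, 0, 0, 0, 1, 1, 1, 1]   four heads per group: GQA
# 1 [0, 0, 0, 0, 0, 0, 0, 0]   every head shares group 0: MQA
```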
Notes
- The KV cache shrinks by a factor of $H/G$ versus MHA, which is the whole point during long-context inference.
- Heads within the same group see identical keys/values, so they differ only through their distinct query projections.
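The cache saving is simple arithmetic, sketched below. The model dimensions (32 layers, 32 query heads, 8 K/V groups, head size 128, 8192 tokens, 2 bytes per element, batch size 1) are illustrative assumptions rather than the configuration of any particular model, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_kv_heads, seq_len, head_dim, num_layers, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence (hypothetical helper)."""
    # Leading factor of 2: one tensor for keys, one for values.
    return 2 * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_elem


H, G = 32, 8  # assumed head and group counts
mha = kv_cache_bytes(num_kv_heads=H, seq_len=8192, head_dim=128, num_layers=32)
gqa = kv_cache_bytes(num_kv_heads=G, seq_len=8192, head_dim=128, num_layers=32)

print(mha / 2**30, "GiB for MHA")  # 4.0 GiB for MHA
print(gqa / 2**30, "GiB for GQA")  # 1.0 GiB for GQA
print(mha / gqa)                   # 4.0, i.e. H / G
```

Only the number of K/V heads changes between the two calls, so the ratio equals H / G regardless of the other dimensions.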
This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- Reference: head-to-group routing with a single key
- G == H reduces to standard multi-head attention
- G == 1 reduces to multi-query attention (shared K/V)
- Output shape is (H, N, d)
- Indivisible head/group count raises ValueError