Grouped-query attention (difficulty: Medium)

Background

Grouped-Query Attention (GQA) sits between multi-head attention (MHA) and multi-query attention (MQA). MHA gives every query head its own key/value head: this is accurate, but the KV cache is large. MQA shares a single K/V head across all query heads: the cache is tiny, at some cost in quality. GQA splits the H query heads into G groups, and each group shares one K/V head. Choosing G between 1 and H trades cache size against quality. The larger Llama 2 models and the Llama 3 family use GQA for exactly this reason.

Problem statement

Implement grouped_query_attention(Q, K, V) returning per-head scaled dot-product attention where query heads share K/V heads by group.

  • Query head h uses K/V group $g = h \,//\, (H/G)$.
  • For each head: $O_h = \operatorname{softmax}\!\left(\frac{Q_h K_g^\top}{\sqrt{d}}\right) V_g$.

Raise ValueError if H is not divisible by G.
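Here is a minimal sketch of the head-to-group routing and the divisibility check, using plain Python integers. The helper name `head_to_group` is hypothetical. The extra `G <= 0` guard goes beyond the stated spec; it avoids a modulo-by-zero error.

```python
def head_to_group(H: int, G: int) -> list[int]:
    """Return the K/V group index used by each of the H query heads (hypothetical helper)."""
    # Spec: H must be divisible by G. The G <= 0 check is an added safety guard.
    if G <= 0 or H % G != 0:
        raise ValueError(f"H={H} is not divisible by G={G}")
    heads_per_group = H // G  # each K/V head serves H/G consecutive query heads
    return [h // heads_per_group for h in range(H)]

print(head_to_group(4, 2))  # [0, 0, 1, 1]
print(head_to_group(8, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]  (G == H: one K/V head per query head)
print(head_to_group(8, 1))  # [0, 0, 0, 0, 0, 0, 0, 0]  (G == 1: a single shared K/V head)
```

Consecutive query heads land in the same group. This is the same layout that integer division h // (H/G) produces.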

Input

  • Q: np.ndarray of shape (H, N, d), the H query heads.
  • K, V: np.ndarray of shape (G, N, d), the G key/value groups (G divides H).

Output

An np.ndarray of shape (H, N, d): the attention output for every query head.
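Below is one possible vectorized NumPy sketch of `grouped_query_attention`, assuming float inputs with the shapes above. It expands each K/V group to its H/G query heads with `np.repeat`. It also subtracts the row maximum before exponentiating, which is a standard numerical-stability trick rather than a requirement of the problem.

```python
import numpy as np

def grouped_query_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention where H query heads share G K/V heads (sketch)."""
    H, N, d = Q.shape
    G = K.shape[0]
    if G == 0 or H % G != 0:
        raise ValueError(f"number of query heads H={H} is not divisible by G={G}")
    heads_per_group = H // G

    # np.repeat along axis 0 gives [g0, g0, ..., g1, g1, ...], so row h holds
    # group h // heads_per_group, matching the routing rule.
    K_exp = np.repeat(K, heads_per_group, axis=0)  # (H, N, d)
    V_exp = np.repeat(V, heads_per_group, axis=0)  # (H, N, d)

    # Batched scores Q_h K_g^T / sqrt(d): shape (H, N, N).
    scores = Q @ K_exp.transpose(0, 2, 1) / np.sqrt(d)

    # Softmax over the key axis (last axis), shifted by the row max for stability.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    return weights @ V_exp  # (H, N, d)

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 5, 16))
K = rng.standard_normal((2, 5, 16))
V = rng.standard_normal((2, 5, 16))
print(grouped_query_attention(Q, K, V).shape)  # (4, 5, 16)
```

The `np.repeat` copies use memory proportional to H. That is acceptable for a reference implementation, but it would undo the KV-cache saving if done inside a real inference kernel.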

Examples

Example 1

Input:  H = 4 query heads, G = 2 KV groups, N = 1 key position, d = 2
        Q, K = any values (with a single key they do not affect the output)
        V groups = [[[1, 1]], [[9, 9]]]
Output: heads 0,1 -> [[1, 1]]   (share group 0)
        heads 2,3 -> [[9, 9]]   (share group 1)

Explanation: with H/G = 2 heads per group, query heads 0–1 map to group 0 and heads 2–3 to group 1. With a single key position the softmax weight is 1, so each head's output equals its group's value.
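The example can be checked with a literal per-head loop that mirrors the per-head formula from the problem statement. This is a sketch: the name `gqa_loop` is hypothetical. Q and K are filled with random values to show that they do not matter when there is only one key.

```python
import numpy as np

def gqa_loop(Q, K, V):
    """Per-head GQA following O_h = softmax(Q_h K_g^T / sqrt(d)) V_g literally (sketch)."""
    H, N, d = Q.shape
    G = K.shape[0]
    if G == 0 or H % G != 0:
        raise ValueError("H must be divisible by G")
    out = np.empty((H, N, d))
    for h in range(H):
        g = h // (H // G)                       # K/V group for this query head
        s = Q[h] @ K[g].T / np.sqrt(d)          # (N, N) scaled scores
        w = np.exp(s - s.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)      # softmax over the key axis
        out[h] = w @ V[g]
    return out

# Example 1: H=4, G=2, N=1, d=2. Q and K are random; with one key each
# softmax weight is exactly 1, so every head returns its group's value row.
rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 1, 2))
K = rng.standard_normal((2, 1, 2))
V = np.array([[[1.0, 1.0]], [[9.0, 9.0]]])
print(gqa_loop(Q, K, V).tolist())
# [[[1.0, 1.0]], [[1.0, 1.0]], [[9.0, 9.0]], [[9.0, 9.0]]]
```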

Constraints

  • The group index for head h is h // (H // G) (integer division).
  • Apply softmax over the key axis, with the $1/\sqrt{d}$ scale.
  • G == H recovers standard multi-head attention; G == 1 recovers multi-query attention.
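The sketch below checks both reductions in the last constraint. It uses an alternative formulation chosen for illustration. Q is reshaped to (G, H/G, N, d) so that each group's K/V broadcast over that group's heads through `np.einsum`, without building an expanded copy of K and V. The names `gqa_grouped` and `softmax_attention` are hypothetical helpers.

```python
import numpy as np

def gqa_grouped(Q, K, V):
    """GQA by viewing Q as (G, H/G, N, d); K/V are broadcast per group, never copied."""
    H, N, d = Q.shape
    G = K.shape[0]
    if G == 0 or H % G != 0:
        raise ValueError("H must be divisible by G")
    # C-order reshape: head h lands at (h // (H//G), h % (H//G)), i.e. group h // (H/G).
    Qg = Q.reshape(G, H // G, N, d)
    s = np.einsum("gqnd,gmd->gqnm", Qg, K) / np.sqrt(d)   # scores per group and head
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)                     # softmax over keys (m)
    return np.einsum("gqnm,gmd->gqnd", w, V).reshape(H, N, d)

def softmax_attention(q, k, v):
    """Plain single-head scaled dot-product attention, used as the reference."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
H, N, d = 4, 3, 8
Q = rng.standard_normal((H, N, d))

# G == H: each query head has its own K/V head, i.e. standard MHA.
K, V = rng.standard_normal((H, N, d)), rng.standard_normal((H, N, d))
mha = np.stack([softmax_attention(Q[h], K[h], V[h]) for h in range(H)])
print(np.allclose(gqa_grouped(Q, K, V), mha))    # True

# G == 1: every head attends to the same K[0], V[0], i.e. MQA.
K1, V1 = rng.standard_normal((1, N, d)), rng.standard_normal((1, N, d))
mqa = np.stack([softmax_attention(Q[h], K1[0], V1[0]) for h in range(H)])
print(np.allclose(gqa_grouped(Q, K1, V1), mqa))  # True
```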

Notes

  • The KV cache shrinks by a factor of H/G versus MHA, which is the whole point during long-context inference.
  • Heads within the same group see identical keys/values, so they differ only through their distinct query projections.
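The cache saving in the first note is simple arithmetic. This sketch assumes a hypothetical model with 32 layers, 32 query heads of dimension 128, an 8192-token context, and 2-byte (fp16/bf16) cache entries. None of these numbers come from the problem, and `kv_cache_bytes` is a hypothetical helper that counts one sequence only.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes needed to cache K and V for one sequence (the leading 2 = keys + values)."""
    return 2 * n_layers * n_kv_heads * seq_len * head_dim * bytes_per_elem

# Hypothetical model shape, for illustration only (not from the problem).
n_layers, H, head_dim, seq_len = 32, 32, 128, 8192

for name, G in [("MHA", H), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_layers, G, seq_len, head_dim)
    print(f"{name} G={G}: {size // 2**20} MiB")
# MHA G=32: 4096 MiB
# GQA G=8: 1024 MiB
# MQA G=1: 128 MiB
```

Going from MHA to GQA with G = 8 divides the cache by H/G = 32/8 = 4, matching the factor stated in the note.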

This problem ships 5 hidden tests, which run in the browser via Pyodide:

  • Reference: head-to-group routing with a single key
  • G == H reduces to standard multi-head attention
  • G == 1 reduces to multi-query attention (shared K/V)
  • Output shape is (H, N, d)
  • Indivisible head/group count raises ValueError