KV cache compression (MLA)

Difficulty: Hard

Background

Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, shrinks the KV cache. Instead of caching full keys and values (2d numbers per token, assuming keys and values of width d each), MLA caches a small latent c of dimension d_c ≪ d, produced by a down-projection of the token. Keys and values are reconstructed on the fly by up-projecting c. Because the cache holds only the latents, memory drops by a factor of roughly 2d/d_c, while attention stays mathematically equivalent to caching the reconstructed K and V.
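As a back-of-the-envelope check on the 2d/d_c factor, the snippet below counts cached numbers for a single layer. The sizes (d = 1024, d_c = 128, 4096 tokens) are hypothetical choices for illustration, not DeepSeek-V2's configuration.

```python
# Hypothetical sizes for illustration only (not a real model configuration).
d = 1024         # width of each cached key row and each cached value row
d_c = 128        # latent width, d_c << d
n_tokens = 4096  # number of cached tokens

full_kv_per_token = 2 * d  # one K row plus one V row
mla_per_token = d_c        # one latent row

full_kv_total = n_tokens * full_kv_per_token
mla_total = n_tokens * mla_per_token

print(full_kv_total, mla_total, full_kv_total / mla_total)  # 8388608 524288 16.0
```

The ratio depends only on 2d/d_c, so the token count cancels out.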

Problem statement

Implement an MLACache class holding the projections W_dkv (down), W_uk, W_uv (up for K and V):

  • append(self, x): down-project an input token x (shape (d,)) to a latent c = x @ W_dkv (shape (d_c,)) and cache it.
  • attend(self, q): reconstruct K = C @ W_uk and V = C @ W_uv from the cached latents C, then apply scaled dot-product attention of q (treated as a row of shape (1, d_k)) over them:
$$\text{out} = \operatorname{softmax}\!\left(\frac{q\,K^\top}{\sqrt{d_k}}\right) V, \qquad K = C\,W_{uk},\quad V = C\,W_{uv}$$

Return shape (1, d_v).
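A minimal NumPy sketch of the class described above follows. The constructor signature MLACache(W_dkv, W_uk, W_uv) and the attribute name C for the latent cache are assumptions, since the statement does not fix either; the sketch also assumes at least one token has been appended before attend is called.

```python
import numpy as np


class MLACache:
    """Latent KV cache sketch.

    Assumption: the projections are passed to the constructor, and the
    cached latents live in self.C with shape (n_tokens, d_c).
    """

    def __init__(self, W_dkv, W_uk, W_uv):
        self.W_dkv = np.asarray(W_dkv, dtype=float)  # (d, d_c) down-projection
        self.W_uk = np.asarray(W_uk, dtype=float)    # (d_c, d_k) key up-projection
        self.W_uv = np.asarray(W_uv, dtype=float)    # (d_c, d_v) value up-projection
        self.C = np.zeros((0, self.W_dkv.shape[1]))  # cached latents, initially empty

    def append(self, x):
        # Down-project the token and store only its d_c-dimensional latent.
        c = np.asarray(x, dtype=float) @ self.W_dkv  # (d_c,)
        self.C = np.vstack([self.C, c[None, :]])     # (n_tokens, d_c)

    def attend(self, q):
        # Reconstruct full keys and values from the latents at attention time.
        K = self.C @ self.W_uk                          # (n_tokens, d_k)
        V = self.C @ self.W_uv                          # (n_tokens, d_v)
        q = np.asarray(q, dtype=float).reshape(1, -1)   # (1, d_k)
        scores = q @ K.T / np.sqrt(self.W_uk.shape[1])  # (1, n_tokens)
        # Stable softmax: subtract the row max before exponentiating.
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        return w @ V                                    # (1, d_v)


rng = np.random.default_rng(0)
d, d_c, d_k, d_v = 8, 2, 4, 3
cache = MLACache(rng.standard_normal((d, d_c)),
                 rng.standard_normal((d_c, d_k)),
                 rng.standard_normal((d_c, d_v)))
for _ in range(5):
    cache.append(rng.standard_normal(d))
out = cache.attend(rng.standard_normal(d_k))
print(cache.C.shape, out.shape)  # (5, 2) (1, 3)

# One cached token: its softmax weight is 1, so attend returns c @ W_uv.
single = MLACache(cache.W_dkv, cache.W_uk, cache.W_uv)
x0 = rng.standard_normal(d)
single.append(x0)
out1 = single.attend(rng.standard_normal(d_k))
```

Appending with np.vstack copies the whole cache on every step, which is fine at test sizes; a real decoder would more likely write into a preallocated buffer.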

Input

  • W_dkv: (d, d_c) down-projection.
  • W_uk: (d_c, d_k) key up-projection; W_uv: (d_c, d_v) value up-projection.
  • x: input token of shape (d,), passed to append.
  • q: query of shape (d_k,), passed to attend.

Output

  • append caches one latent row.
  • attend(q) returns an np.ndarray (1, d_v).

Examples

Example 1

With cached latents C, attend(q) equals standard attention using K = C @ W_uk and V = C @ W_uv. With a single cached token, the softmax weight is 1, so attend returns that token's reconstructed value c @ W_uv.

Explanation: the only difference from a normal KV cache is what is stored — latents c instead of full K, V. Reconstruct, then attend exactly as usual.
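The equivalence in Example 1 can be checked numerically. The sketch below uses random matrices with arbitrary, hypothetical sizes, and softmax_attention is a helper name introduced here. It computes attention once from a conventional cache of full K/V rows and once from latents reconstructed at attention time, then checks the single-token case.

```python
import numpy as np


def softmax_attention(q, K, V):
    """Scaled dot-product attention of one query row over K, V (stable softmax)."""
    s = q.reshape(1, -1) @ K.T / np.sqrt(K.shape[1])  # (1, n)
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V    # (1, d_v)


rng = np.random.default_rng(1)
d, d_c, d_k, d_v, n = 6, 3, 4, 5, 7
W_dkv = rng.standard_normal((d, d_c))
W_uk = rng.standard_normal((d_c, d_k))
W_uv = rng.standard_normal((d_c, d_v))
X = rng.standard_normal((n, d))  # n past tokens
q = rng.standard_normal(d_k)

# Conventional cache: store full K and V rows for every token.
K_full = (X @ W_dkv) @ W_uk
V_full = (X @ W_dkv) @ W_uv
out_full = softmax_attention(q, K_full, V_full)

# MLA-style cache: store only latents C, reconstruct K and V when attending.
C = X @ W_dkv  # (n, d_c)
out_mla = softmax_attention(q, C @ W_uk, C @ W_uv)
print(np.allclose(out_full, out_mla))  # True

# Single cached token: the lone softmax weight is 1.
out_one = softmax_attention(q, C[:1] @ W_uk, C[:1] @ W_uv)
print(np.allclose(out_one, C[:1] @ W_uv))  # True
```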

Constraints

  • The cache stores latents C of shape (n_tokens, d_c), never the full K/V.
  • Reconstruct K, V from C at attention time via the up-projections.
  • Use the 1/√d_k scale and a numerically stable softmax (subtract the maximum score before exponentiating).
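A small NumPy illustration of the scaling and stability constraints; stable_softmax and scaled_scores are hypothetical helper names. Large scores overflow exp in the naive formula, while the shifted version yields the same probabilities without overflow, because softmax is unchanged when the same constant is subtracted from every score.

```python
import numpy as np


def stable_softmax(scores):
    """Softmax along the last axis; subtracting the row max keeps exp() finite."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def scaled_scores(q, K):
    """q: (d_k,), K: (n, d_k) -> (1, n) dot products scaled by 1/sqrt(d_k)."""
    d_k = K.shape[1]
    return q.reshape(1, -1) @ K.T / np.sqrt(d_k)


s = np.array([[1000.0, 1001.0, 1002.0]])
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)  # inf / inf
print(naive)              # [[nan nan nan]]
print(stable_softmax(s))  # finite probabilities that sum to 1

# With d_k = 4, each raw dot product of 4 is scaled down to 4 / 2 = 2.
print(scaled_scores(np.ones(4), np.ones((2, 4))))
```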

Notes

  • MLA's win is memory: caching d_c per token instead of 2d lets DeepSeek-V2 keep far longer contexts in the same budget.
  • Reconstructing K/V each step adds a little compute, but decoding is memory-bandwidth bound, so the trade is favorable.
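The DeepSeek-V2 paper also notes that, at inference time, W_uk can be absorbed into the query projection and W_uv into the output projection, so K and V need not be materialized at all. The identity behind this is q Kᵀ = q (C W_uk)ᵀ = (q W_ukᵀ) Cᵀ and w V = (w C) W_uv. The NumPy sketch below checks it with hypothetical helper names and random sizes; it is an aside, since this problem's constraints ask for explicit reconstruction.

```python
import numpy as np


def attend_reconstruct(q, C, W_uk, W_uv):
    """Reference: materialize K and V from the latents, then attend."""
    K, V = C @ W_uk, C @ W_uv
    s = q.reshape(1, -1) @ K.T / np.sqrt(W_uk.shape[1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V


def attend_absorbed(q, C, W_uk, W_uv):
    """Same result without materializing K or V.

    q K^T = (q W_uk^T) C^T and w V = (w C) W_uv, so scores and the
    weighted sum are both computed in the d_c-dimensional latent space.
    """
    q_lat = q.reshape(1, -1) @ W_uk.T              # (1, d_c)
    s = q_lat @ C.T / np.sqrt(W_uk.shape[1])       # scale still uses d_k
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return (w @ C) @ W_uv                          # (1, d_v)


rng = np.random.default_rng(2)
d_c, d_k, d_v, n = 4, 8, 6, 10
C = rng.standard_normal((n, d_c))
W_uk = rng.standard_normal((d_c, d_k))
W_uv = rng.standard_normal((d_c, d_v))
q = rng.standard_normal(d_k)
print(np.allclose(attend_reconstruct(q, C, W_uk, W_uv),
                  attend_absorbed(q, C, W_uk, W_uv)))  # True
```

Note that the 1/√d_k scale is unchanged: absorption changes only where the matrix products happen, not the attention scores themselves.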

This problem ships 4 hidden tests:

  • Attention matches caching reconstructed K/V
  • Cache stores compact latents (n_tokens, d_c)
  • Single cached token returns its reconstructed value
  • Output shape is (1, d_v)