KV cache compression (MLA)
Background
Multi-head Latent Attention (MLA), introduced in DeepSeek-V2, shrinks the KV cache. Instead of caching full keys and values ($2d$ numbers per token), MLA caches a small latent $c$ of dimension $d_c$ produced by a down-projection. Keys and values are reconstructed on the fly by up-projecting $c$. The cache stores only the latents, so memory drops by a factor of $2d / d_c$ while attention stays mathematically equivalent to caching the reconstructed K/V.
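As a rough illustration of that saving, the snippet below compares per-token cache sizes for a full K/V cache (2d numbers) and a latent cache (d_c numbers). The helper name cache_floats_per_token and the sizes d = 512, d_c = 64 are made up for this example; they are not DeepSeek-V2's actual configuration.

```python
def cache_floats_per_token(d, d_c):
    """Return (full K/V floats, latent floats, compression factor) per token."""
    full = 2 * d      # one key vector and one value vector, each of size d
    latent = d_c      # a single shared latent per token
    return full, latent, full / latent


# Hypothetical sizes, chosen only for illustration.
full, latent, factor = cache_floats_per_token(d=512, d_c=64)
print(full, latent, factor)  # 1024 64 16.0
```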
Problem statement
Implement an MLACache class holding the projections W_dkv (down), W_uk, W_uv (up for K and V):
- append(self, x): down-project an input token x (shape (d,)) to a latent c = x @ W_dkv (shape (d_c,)) and cache it.
- attend(self, q): reconstruct K = C @ W_uk, V = C @ W_uv from the cached latents C, then do scaled dot-product attention of q over them:

$$\mathrm{out} = \mathrm{softmax}\!\left(\frac{q K^\top}{\sqrt{d_k}}\right) V$$
Return shape (1, d_v).
Input
- W_dkv: (d, d_c) down-projection.
- W_uk: (d_c, d_k) key up-projection; W_uv: (d_c, d_v) value up-projection.
- x: input token (d,) passed to append.
- q: query (d_k,) passed to attend.
Output
- append caches one latent row.
- attend(q) returns an np.ndarray of shape (1, d_v).
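One possible shape for the class is sketched below. The constructor signature, the attribute name C, and the use of np.vstack for growth are assumptions made for this sketch; the problem statement only fixes append, attend, and the shapes. It also assumes the usual 1/sqrt(d_k) scale and that attend is called after at least one append.

```python
import numpy as np


class MLACache:
    """Minimal sketch of a latent KV cache (names beyond append/attend are assumed)."""

    def __init__(self, W_dkv, W_uk, W_uv):
        self.W_dkv = np.asarray(W_dkv, dtype=float)  # (d, d_c) down-projection
        self.W_uk = np.asarray(W_uk, dtype=float)    # (d_c, d_k) key up-projection
        self.W_uv = np.asarray(W_uv, dtype=float)    # (d_c, d_v) value up-projection
        # Cached latents C, shape (n_tokens, d_c); starts with zero rows.
        self.C = np.zeros((0, self.W_dkv.shape[1]))

    def append(self, x):
        c = np.asarray(x, dtype=float) @ self.W_dkv  # latent, shape (d_c,)
        self.C = np.vstack([self.C, c[None, :]])     # store only the latent

    def attend(self, q):
        K = self.C @ self.W_uk                       # reconstructed keys (n, d_k)
        V = self.C @ self.W_uv                       # reconstructed values (n, d_v)
        scores = (K @ np.asarray(q, dtype=float)) / np.sqrt(K.shape[1])
        weights = np.exp(scores - scores.max())      # shift by max for stability
        weights /= weights.sum()
        return (weights @ V)[None, :]                # shape (1, d_v)


# Usage with small, arbitrary sizes.
rng = np.random.default_rng(0)
d, d_c, d_k, d_v = 8, 2, 4, 3
cache = MLACache(rng.normal(size=(d, d_c)),
                 rng.normal(size=(d_c, d_k)),
                 rng.normal(size=(d_c, d_v)))
for _ in range(5):
    cache.append(rng.normal(size=d))
out = cache.attend(rng.normal(size=d_k))
print(cache.C.shape, out.shape)  # (5, 2) (1, 3)
```

Growing the array with np.vstack copies the cache on every append, which is fine at this scale; a preallocated buffer or a Python list of latents stacked at attention time would be reasonable alternatives.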
Examples
Example 1
With cached latents C, attend(q) equals standard attention using K = C @ W_uk and V = C @ W_uv. A single cached token returns its reconstructed value.
Explanation: the only difference from a normal KV cache is what is stored — latents c instead of full K, V. Reconstruct, then attend exactly as usual.
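The equivalence in Example 1 can be checked numerically. The sketch below is written functionally rather than against any particular class; the helper attention, the random sizes, and the seed are arbitrary choices for illustration, and the 1/sqrt(d_k) scale is assumed.

```python
import numpy as np


def attention(q, K, V):
    """Scaled dot-product attention of one query over K, V; returns (1, d_v)."""
    scores = K @ q / np.sqrt(K.shape[1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return (weights @ V)[None, :]


rng = np.random.default_rng(0)
d, d_c, d_k, d_v, n = 8, 2, 4, 3, 5
W_dkv = rng.normal(size=(d, d_c))
W_uk = rng.normal(size=(d_c, d_k))
W_uv = rng.normal(size=(d_c, d_v))
X = rng.normal(size=(n, d))
q = rng.normal(size=d_k)

# Latent path: cache only C, reconstruct K and V at attention time.
C = X @ W_dkv
out_latent = attention(q, C @ W_uk, C @ W_uv)

# Reference path: a conventional cache that stores K and V rows per token.
K_cached = np.stack([(x @ W_dkv) @ W_uk for x in X])
V_cached = np.stack([(x @ W_dkv) @ W_uv for x in X])
out_full = attention(q, K_cached, V_cached)
print(np.allclose(out_latent, out_full))  # True

# Single cached token: softmax over one score is 1, so the output is its value.
C1 = C[:1]
out_single = attention(q, C1 @ W_uk, C1 @ W_uv)
print(np.allclose(out_single, C1 @ W_uv))  # True
```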
Constraints
- The cache stores latents C of shape (n_tokens, d_c), never the full K/V.
- Reconstruct K, V from C at attention time via the up-projections.
- Use the $1/\sqrt{d_k}$ scale and a stable softmax.
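A stable softmax usually means subtracting the maximum score before exponentiating, which leaves the result unchanged but keeps np.exp from overflowing. The function name stable_softmax and the test values below are illustrative only.

```python
import numpy as np


def stable_softmax(scores):
    """Softmax over a 1-D array, shifted by its max to avoid overflow."""
    scores = np.asarray(scores, dtype=float)
    shifted = scores - scores.max()   # largest entry becomes 0, so exp() <= 1
    exps = np.exp(shifted)
    return exps / exps.sum()


# Large scores would overflow a naive exp(); the shifted version stays finite.
big = np.array([1000.0, 1001.0, 1002.0])
weights = stable_softmax(big)
print(np.all(np.isfinite(weights)))                        # True
print(np.allclose(weights, stable_softmax([0.0, 1.0, 2.0])))  # True
```

The second check relies on softmax being invariant to adding a constant to every score, which is exactly why the shift is safe.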
Notes
- MLA's win is memory: caching d_c per token instead of 2d lets DeepSeek-V2 keep far longer contexts in the same budget.
- Reconstructing K/V each step adds a little compute, but decoding is memory-bandwidth bound, so the trade is favorable.
This problem ships 4 hidden tests:
- Attention matches caching reconstructed K/V
- Cache stores compact latents (n_tokens, d_c)
- Single cached token returns its reconstructed value
- Output shape is (1, d_v)