KV cache for autoregressive inference (Medium)

Background

During autoregressive generation a transformer produces one token at a time. Naively, each new step would recompute the keys and values for every previous token, which adds up to $O(n^2)$ key/value projections to emit $n$ tokens. The KV cache stores the keys/values already computed and only appends the new token's K/V each step, so attention at step $t$ reuses the $t$ cached rows and the projection work drops to $O(n)$. The attention itself still scores the new query against all $t$ cached keys, so step $t$ costs $O(t\,d)$. This is the single most important optimization for fast LLM decoding.
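
A toy sketch of the counting argument above, assuming a single attention head and hypothetical random projection matrices `W_k`, `W_v` (they are not part of the problem). It counts how many token rows get projected into keys/values under naive recomputation versus caching, and checks that both end with the same K and V.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 4, 6
X = rng.standard_normal((n, d))    # hypothetical token embeddings, one row per step
W_k = rng.standard_normal((d, d))  # hypothetical key projection
W_v = rng.standard_normal((d, d))  # hypothetical value projection

# Naive decoding: at step t, re-project all t tokens seen so far.
naive_projections = 0
for t in range(1, n + 1):
    K = X[:t] @ W_k
    V = X[:t] @ W_v
    naive_projections += t         # t rows projected this step

# Cached decoding: project only the newest token and append its row.
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
cached_projections = 0
for t in range(n):
    K_cache = np.vstack([K_cache, X[t:t + 1] @ W_k])
    V_cache = np.vstack([V_cache, X[t:t + 1] @ W_v])
    cached_projections += 1        # exactly one new row per step

print(naive_projections, cached_projections)             # 21 6
print(np.allclose(K, K_cache), np.allclose(V, V_cache))  # True True
```

Naive recomputation projects $1 + 2 + \dots + n = n(n+1)/2$ rows, while the cache projects each token once.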

Problem statement

Implement a KVCache class with:

  • append(self, k, v): append a new key row k and value row v (each of shape (d,) or (1, d)) to the cache.
  • attend(self, q): scaled dot-product attention of a new query q over all cached keys/values:
$$\text{out} = \operatorname{softmax}\!\Big(\frac{q\, K_{\text{cache}}^\top}{\sqrt d}\Big)\, V_{\text{cache}}$$

Return the output of shape (1, d). Use the max-subtraction trick for stability.
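
One possible sketch of the class, assuming NumPy and that the cache is kept as two growing `(t, d)` arrays. The attribute names `K`, `V` and the helper `_row` are illustrative choices, not part of the spec, and raising `ValueError` on an empty cache is likewise an assumption (the problem does not say what `attend` should do before any `append`).

```python
import numpy as np


class KVCache:
    """Minimal sketch: one cached key row and one value row per decoding step."""

    def __init__(self):
        self.K = None  # shape (t, d) after t appends
        self.V = None

    @staticmethod
    def _row(x):
        # Accept lists, (d,) arrays, or (1, d) arrays; return a float (1, d) row.
        return np.asarray(x, dtype=float).reshape(1, -1)

    def append(self, k, v):
        k, v = self._row(k), self._row(v)
        if self.K is None:
            self.K, self.V = k, v
        else:
            # np.vstack copies the whole cache (O(t * d) per append); fine for a sketch.
            self.K = np.vstack([self.K, k])
            self.V = np.vstack([self.V, v])

    def attend(self, q):
        if self.K is None:
            raise ValueError("attend() called on an empty cache")
        q = self._row(q)
        d = q.shape[1]
        scores = (q @ self.K.T) / np.sqrt(d)                  # (1, t)
        scores = scores - scores.max(axis=1, keepdims=True)   # max-subtraction
        weights = np.exp(scores)
        weights /= weights.sum(axis=1, keepdims=True)         # softmax over cached keys
        return weights @ self.V                               # (1, d)


cache = KVCache()
cache.append([1, 0], [10, 20]); print(cache.attend([1, 0]))  # [[10. 20.]]
cache.append([1, 0], [30, 40]); print(cache.attend([1, 0]))  # [[20. 30.]]
```

Reshaping every input to a `(1, d)` row up front lets the same code path handle lists, `(d,)` vectors, and `(1, d)` rows.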

Input

  • k, v: new key/value vectors of shape (d,) or (1, d), appended once per step. The examples pass plain Python lists, so array-like inputs must be accepted.
  • q: query vector of shape (d,) or (1, d), passed to attend.

Output

  • append updates the cache in place.
  • attend(q) returns an np.ndarray of shape (1, d): attention over everything cached so far.

Examples

Example 1

cache = KVCache()
cache.append([1, 0], [10, 20]); cache.attend([1, 0])  -> [[10, 20]]
cache.append([1, 0], [30, 40]); cache.attend([1, 0])  -> [[20, 30]]

Explanation: at step 1 only one key is cached, so its value passes through. At step 2 both keys are identical, so the softmax weights are 0.5/0.5 and the output is the average of the two cached values, $[20, 30]$.
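
The step-2 arithmetic, written out directly with NumPy (no cache class) so each intermediate can be checked. Vectors are kept 1-D here for readability; `attend` itself returns a `(1, d)` row.

```python
import numpy as np

q = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0], [1.0, 0.0]])    # both cached keys are identical
V = np.array([[10.0, 20.0], [30.0, 40.0]])

scores = K @ q / np.sqrt(2)   # both scores equal 1/sqrt(2), about 0.707
w = np.exp(scores - scores.max())
w /= w.sum()                  # equal scores give equal weights
out = w @ V                   # 0.5 * [10, 20] + 0.5 * [30, 40]
print(w, out)                 # [0.5 0.5] [20. 30.]
```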

Constraints

  • attend must score the query against all cached keys, then softmax over them.
  • Stack appended rows so the cache grows by one row per append.
  • Scale scores by $1/\sqrt d$ and subtract the row max before exponentiating.
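
Why the max-subtraction constraint matters, as a small sketch with made-up large scores: `np.exp` overflows to `inf` for arguments above roughly 709, so a naive softmax returns `nan`. Subtracting the max first leaves the result mathematically unchanged (softmax is shift-invariant) while keeping every exponent at most 0. The function names `softmax_naive` and `softmax_stable` are illustrative.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum()

def softmax_stable(x):
    e = np.exp(x - x.max())  # largest exponent becomes exp(0) = 1
    return e / e.sum()

scores = np.array([1000.0, 999.0, 998.0])  # made-up large scores

with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(scores))   # [nan nan nan]: exp overflows to inf, inf / inf = nan
print(softmax_stable(scores))      # finite weights that sum to 1

# Shift invariance: the stable result equals the naive softmax of the shifted scores.
print(np.allclose(softmax_stable(scores), softmax_naive(scores - 1000.0)))  # True
```

The $1/\sqrt d$ factor addresses a related issue: it keeps the dot products from growing with $d$, so the softmax does not saturate toward a one-hot distribution.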

Notes

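  • Incremental cached decoding produces the same outputs as recomputing full causal-masked attention at every step, up to floating-point rounding. It is purely a compute/memory optimization, not an approximation.
  • The cache grows linearly with sequence length (one K row and one V row per token, per layer and per key/value head), which is why long-context inference is memory-bound and motivates tricks such as multi-query attention (MQA), grouped-query attention (GQA), and multi-head latent attention (MLA).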
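
A quick numerical check of the first note, assuming random `Q`, `K`, `V` matrices in place of real projections: dense attention with a causal mask over all $n$ positions is compared against decoding one query at a time over a growing cache. The comparison uses `np.allclose` rather than `==` because the two paths may round differently.

```python
import numpy as np

def softmax_rows(s):
    s = s - s.max(axis=-1, keepdims=True)  # max-subtraction per row
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
n, d = 5, 8
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Dense: one (n, n) score matrix; row i may only see keys 0..i.
scores = Q @ K.T / np.sqrt(d)
future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True above the diagonal
scores[future] = -np.inf                            # exp(-inf) = 0 after softmax
dense = softmax_rows(scores) @ V                    # (n, d)

# Incremental: at step t, append row t, then attend query t over t + 1 cached rows.
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
steps = []
for t in range(n):
    K_cache = np.vstack([K_cache, K[t:t + 1]])
    V_cache = np.vstack([V_cache, V[t:t + 1]])
    s = Q[t:t + 1] @ K_cache.T / np.sqrt(d)         # (1, t + 1); no mask needed
    steps.append(softmax_rows(s) @ V_cache)
incremental = np.vstack(steps)

print(np.allclose(dense, incremental))  # True
```

The cache never holds future keys, so the causal mask is implicit in the incremental loop.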

Tests

This problem ships 4 hidden tests:

  • Reference two-step example
  • Cached decoding matches dense causal attention
  • Cache grows by one row per append
  • Single cached token returns its value