FlashAttention tiled forward

Difficulty: Hard

Background

Standard attention materializes the full $N \times N$ score matrix, costing $O(N^2)$ memory. FlashAttention computes the exact same result without ever storing that matrix: it streams over blocks of keys/values and maintains an online softmax (a running max $m$, a running normalizer $\ell$, and a running output $O$), rescaling the accumulators each time a new block shifts the max. This is the trick behind training long-context transformers efficiently.
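For comparison, a minimal NumPy sketch of the standard computation is shown here. The name standard_attention is chosen for illustration and is not part of the problem, and the row-max subtraction is a common stabilization step assumed here rather than something the text above prescribes.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Baseline attention that materializes the full score matrix (illustrative helper)."""
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)             # full (Nq, Nk) score matrix: the quadratic memory cost
    S = S - S.max(axis=1, keepdims=True)   # subtract each row's max so exp() cannot overflow
    P = np.exp(S)
    return (P / P.sum(axis=1, keepdims=True)) @ V

# Toy inputs taken from Example 1 of this problem.
Q = np.array([[1.0, 0.0]])
K = np.array([[1.0, 0.0], [0.0, 1.0]])
V = np.array([[1.0, 2.0], [3.0, 4.0]])
out = standard_attention(Q, K, V)          # shape (1, 2)
```

A tiled implementation can be checked against a baseline like this with np.allclose.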

Problem statement

Implement flash_attention(Q, K, V, block_size) returning the attention output, computed by tiling over key/value blocks with online softmax. For each key block $K_b$ with its value block $V_b$, compute $S = \frac{QK_b^\top}{\sqrt{d}}$, then update:

$$m^{\text{new}} = \max(m,\ \text{rowmax}(S)), \quad \alpha = e^{\,m - m^{\text{new}}}, \quad P = e^{\,S - m^{\text{new}}}$$

$$\ell \leftarrow \alpha\,\ell + \text{rowsum}(P), \qquad O \leftarrow \alpha\,O + P V_b, \qquad m \leftarrow m^{\text{new}}$$

After all blocks, divide: $O \leftarrow O/\ell$. Use scale $1/\sqrt{d}$.
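One way the update rule could be turned into code is sketched below. It is a plain NumPy loop under the stated assumptions (single head, 2-D inputs, at least one key), not a reference solution. The key count is read from K rather than Q, so differing query and key counts also work.

```python
import numpy as np

def flash_attention(Q, K, V, block_size):
    """Tiled attention forward pass with online softmax (illustrative sketch)."""
    Q = np.asarray(Q, dtype=np.float64)
    K = np.asarray(K, dtype=np.float64)
    V = np.asarray(V, dtype=np.float64)
    n_q, d = Q.shape
    n_k = K.shape[0]
    scale = 1.0 / np.sqrt(d)

    m = np.full((n_q, 1), -np.inf)    # running row max
    l = np.zeros((n_q, 1))            # running normalizer
    O = np.zeros((n_q, V.shape[1]))   # running unnormalized output

    for start in range(0, n_k, block_size):
        K_b = K[start:start + block_size]   # the last block may be shorter
        V_b = V[start:start + block_size]
        S = (Q @ K_b.T) * scale             # scores for this block only
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        alpha = np.exp(m - m_new)           # 0 on the first block, since m starts at -inf
        P = np.exp(S - m_new)
        l = alpha * l + P.sum(axis=1, keepdims=True)
        O = alpha * O + P @ V_b
        m = m_new
    return O / l                            # final normalization
```

Only one (Nq, block_size) slice of scores exists at any time, which is where the memory saving comes from.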

Input

  • Q, K, V: np.ndarray of shape (N, d) (single head, no batch).
  • block_size: int, number of keys processed per tile.

Output

An np.ndarray of shape (N, d): identical to standard softmax attention $\operatorname{softmax}(QK^\top/\sqrt{d})\,V$.

Examples

Example 1

Input:  Q = [[1, 0]], K = [[1, 0], [0, 1]], V = [[1, 2], [3, 4]], block_size = 1
Output: [[1.6605, 2.6605]]

Explanation: scores are $[1,0]/\sqrt{2} = [0.707, 0]$; softmax gives weights $[0.670, 0.330]$; the output is $0.670\,[1,2] + 0.330\,[3,4] = [1.660, 2.660]$, the same whether computed in one block or tiled.
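As a check on the update rule, the same example can be traced by hand with block_size = 1. Intermediate values are rounded to five decimals, and the block labels $K_1, V_1, K_2, V_2$ are added here for illustration.

$$
\begin{aligned}
\text{Block 1 } (K_1 = [1, 0],\ V_1 = [1, 2]):\quad & S = 0.70711,\quad m^{\text{new}} = 0.70711,\quad \alpha = e^{-\infty} = 0,\quad P = e^{0} = 1 \\
& \ell = 0 \cdot 0 + 1 = 1,\qquad O = 0 \cdot [0, 0] + 1 \cdot [1, 2] = [1, 2] \\
\text{Block 2 } (K_2 = [0, 1],\ V_2 = [3, 4]):\quad & S = 0,\quad m^{\text{new}} = \max(0.70711,\ 0) = 0.70711,\quad \alpha = e^{0} = 1,\quad P = e^{-0.70711} \approx 0.49307 \\
& \ell \approx 1 + 0.49307 = 1.49307,\qquad O \approx [1, 2] + 0.49307\,[3, 4] = [2.47921, 3.97228] \\
\text{Final:}\quad & O/\ell \approx [2.47921,\ 3.97228] / 1.49307 \approx [1.660,\ 2.660]
\end{aligned}
$$

In this trace the max does not change in the second block, so $\alpha = 1$ and no rescaling is needed; $\alpha < 1$ only appears when a later block contains a larger score.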

Constraints

  • Initialize the running max to $-\infty$ and the normalizer/output to 0; the first block's $\alpha = 0$ correctly discards the empty initial state.
  • Rescale both $\ell$ and $O$ by $\alpha = e^{m - m^{\text{new}}}$ before adding the new block's contribution.
  • The result must match standard attention for any block_size in $1..N$.
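The first constraint can be seen in isolation with a small sketch of a single update from the empty state. The score and value arrays below are made-up toy values, not data from the problem.

```python
import numpy as np

# Empty initial state for two query rows.
m = np.full((2, 1), -np.inf)   # running max
l = np.zeros((2, 1))           # running normalizer
O = np.zeros((2, 3))           # running output

# Hypothetical scores and values for the first key block (two keys).
S = np.array([[0.5, -1.0], [2.0, 3.0]])
V_b = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 4.0]])

m_new = np.maximum(m, S.max(axis=1, keepdims=True))
alpha = np.exp(m - m_new)      # exp(-inf) is exactly 0.0 for every row
P = np.exp(S - m_new)

# Rescale first, then add the block's contribution.
l = alpha * l + P.sum(axis=1, keepdims=True)
O = alpha * O + P @ V_b

print(alpha.ravel())           # [0. 0.]
```

Because alpha is exactly zero here, the state after the first block is just that block's own contribution, and no special case for the first iteration is needed.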

Notes

  • The online-softmax rescaling is what keeps the result exact: subtracting the running max prevents overflow, and $\alpha$ corrects earlier terms once a larger max appears.
  • Real FlashAttention also tiles the queries and fuses everything into one GPU kernel; the math here is the same, only the loop nest differs.
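Both effects in the first note can be illustrated with two large scores, chosen here as an assumption to force overflow in float64. The larger score is placed in the second block so that the running max shifts and $\alpha$ has to correct the earlier term.

```python
import numpy as np

s = np.array([1000.0, 999.0])  # toy scores; exp(1000) overflows in float64

# Naive softmax: exp overflows to inf, and inf / inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(s) / np.exp(s).sum()

# Max-subtracted softmax computed in one pass.
stable = np.exp(s - s.max()) / np.exp(s - s.max()).sum()

# Online version over two single-score blocks, smaller score first.
m, l = -np.inf, 0.0
for s_b in (999.0, 1000.0):
    m_new = max(m, s_b)
    alpha = np.exp(m - m_new)          # 0 on block 1, exp(-1) on block 2
    l = alpha * l + np.exp(s_b - m_new)
    m = m_new

online = np.exp(s - m) / l             # matches `stable` up to rounding
```

The normalizer ends up the same regardless of the order in which blocks arrive, which is why tiling does not change the result.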

This problem ships 6 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.

  • Reference small example
  • Matches standard attention for a partial block size
  • block_size = N (single block) matches standard
  • block_size = 1 (fully tiled) matches standard
  • Output shape is (Nq, d)
  • Handles different query and key counts (Nq != Nk)