FlashAttention tiled forward
Background
Standard attention materializes the full $N \times N$ score matrix $S = QK^\top/\sqrt{d}$, costing $\mathcal{O}(N^2)$ memory for sequence length $N$. FlashAttention computes exactly the same result without ever storing that matrix. It streams over blocks of keys/values and maintains an online softmax, consisting of a running max $m$, a running normalizer $\ell$, and a running output $O$. Each time a new block raises the max, it rescales the accumulators. This is the trick behind training long-context transformers efficiently.
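The running max and normalizer alone already give a streaming log-sum-exp, which is the log of the softmax denominator. The sketch below uses a hypothetical function name, online_logsumexp. It processes a non-empty 1-D array block by block and is meant only to isolate the rescaling step:

```python
import numpy as np

def online_logsumexp(x, block_size):
    """log(sum(exp(x))) for a non-empty 1-D array, streamed block by block.

    Keeps only a running max m and a running normalizer ell = sum(exp(x_i - m)).
    """
    m, ell = -np.inf, 0.0
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        m_new = max(m, float(block.max()))
        # Rescale the old normalizer to the new max, then add this block's terms.
        ell = ell * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m + np.log(ell)

x = np.array([1.0, 3.0, 2.0, 1000.0, -5.0])
for b in (1, 2, 5):
    print(b, float(online_logsumexp(x, b)))  # prints 1000.0 for every block size
```

The result does not depend on block_size. Every exponent is taken relative to the running max, so the 1000.0 entry does not overflow. By contrast, the naive np.log(np.exp(x).sum()) returns inf for this input.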
Problem statement
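Implement flash_attention(Q, K, V, block_size), which returns the attention output computed by tiling over key/value blocks with an online softmax. For each key block $K_j$ (with matching value block $V_j$), update the running state. The max, exponentials, and scaling are applied row-wise, one row per query:

$$
\begin{aligned}
S_j &= \frac{Q K_j^\top}{\sqrt{d}}, \\
m_{\text{new}} &= \max\bigl(m,\ \operatorname{rowmax}(S_j)\bigr), \\
\alpha &= e^{\,m - m_{\text{new}}}, \\
P_j &= e^{\,S_j - m_{\text{new}}}, \\
\ell &\leftarrow \alpha\,\ell + \operatorname{rowsum}(P_j), \\
O &\leftarrow \alpha\,O + P_j V_j, \\
m &\leftarrow m_{\text{new}}.
\end{aligned}
$$

After all blocks, divide row-wise: $\text{output} = O / \ell$. Use scale $1/\sqrt{d}$.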
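The update rules above translate almost line for line into NumPy. What follows is a minimal sketch, not the only valid solution. It assumes inputs can be converted to float64 arrays and that V has the same number of rows as K:

```python
import numpy as np

def flash_attention(Q, K, V, block_size):
    """Exact softmax attention, streaming over key/value blocks with an online softmax."""
    Q, K, V = (np.asarray(a, dtype=np.float64) for a in (Q, K, V))
    n_q, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n_q, -np.inf)          # running row max
    ell = np.zeros(n_q)                # running normalizer
    O = np.zeros((n_q, V.shape[1]))    # running unnormalized output
    for start in range(0, K.shape[0], block_size):
        K_j = K[start:start + block_size]         # last block may be shorter
        V_j = V[start:start + block_size]
        S_j = (Q @ K_j.T) * scale                 # (n_q, b) scores for this block only
        m_new = np.maximum(m, S_j.max(axis=1))
        alpha = np.exp(m - m_new)                 # 0 on the first block, since m = -inf
        P_j = np.exp(S_j - m_new[:, None])
        ell = alpha * ell + P_j.sum(axis=1)
        O = alpha[:, None] * O + P_j @ V_j        # rescale old rows, add this block
        m = m_new
    return O / ell[:, None]

out = flash_attention([[1, 0]], [[1, 0], [0, 1]], [[1, 2], [3, 4]], block_size=1)
print(out)  # about [[1.66048, 2.66048]]
```

Only one (N_q, block_size) score tile exists at a time, so the extra memory for scores scales with the block size rather than with the number of keys.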
Input
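- Q: np.ndarray of shape (N_q, d), the queries (single head, no batch).
- K, V: np.ndarray of shape (N_k, d), the keys and values; N_k may differ from N_q.
- block_size: int, the number of keys processed per tile.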
Output
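An np.ndarray of shape (N_q, d), one row per query (N_q is the number of rows of Q). It must be identical, up to floating-point rounding, to standard softmax attention $\operatorname{softmax}\bigl(QK^\top/\sqrt{d}\bigr)\,V$.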
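For testing, a direct reference that does materialize the full score matrix is handy. The name standard_attention is our own. This is a sketch and not necessarily what the hidden tests use:

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference softmax attention that builds the full (N_q, N_k) score matrix."""
    Q, K, V = (np.asarray(a, dtype=np.float64) for a in (Q, K, V))
    S = Q @ K.T / np.sqrt(Q.shape[1])              # full score matrix
    P = np.exp(S - S.max(axis=1, keepdims=True))   # subtract row max for stability
    return (P / P.sum(axis=1, keepdims=True)) @ V  # normalize rows, then mix values
```

Comparing any tiled implementation against this function with np.allclose, across several block sizes, is a quick way to check the "identical to standard attention" requirement.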
Examples
Example 1
Input: Q = [[1, 0]], K = [[1, 0], [0, 1]], V = [[1, 2], [3, 4]], block_size = 1
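Output: [[1.6605, 2.6605]]
Explanation: with $d = 2$, the scores are $QK^\top/\sqrt{2} = [1/\sqrt{2},\ 0] \approx [0.7071,\ 0]$. Softmax gives weights $\approx [0.66976,\ 0.33024]$. The output is $0.66976 \cdot [1, 2] + 0.33024 \cdot [3, 4] \approx [1.66048,\ 2.66048]$, the same whether computed in one block or tiled.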
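Tracing the tiled updates for this example with block_size = 1 shows the online softmax arriving at the same output. Intermediates are rounded to 4 decimals, and the final division uses unrounded values. The second block's score does not exceed the running max, so its rescale factor is 1:

```latex
\begin{aligned}
&\textbf{Block 1 } (K_1 = [1, 0]):\quad S_1 = \tfrac{1}{\sqrt{2}} \approx 0.7071,\quad m_{\text{new}} = 0.7071,\quad \alpha = e^{-\infty} = 0,\quad P_1 = e^{0} = 1,\\
&\qquad \ell = 0 \cdot 0 + 1 = 1,\quad O = 0 \cdot [0, 0] + 1 \cdot [1, 2] = [1, 2],\\
&\textbf{Block 2 } (K_2 = [0, 1]):\quad S_2 = 0,\quad m_{\text{new}} = \max(0.7071,\ 0) = 0.7071,\quad \alpha = e^{0} = 1,\quad P_2 = e^{-0.7071} \approx 0.4931,\\
&\qquad \ell = 1 \cdot 1 + 0.4931 \approx 1.4931,\quad O = 1 \cdot [1, 2] + 0.4931 \cdot [3, 4] \approx [2.4792,\ 3.9723],\\
&\textbf{Output: } O / \ell \approx [1.6605,\ 2.6605].
\end{aligned}
```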
Constraints
- Initialize the running max $m$ to $-\infty$ and the normalizer $\ell$ and output $O$ to 0. On the first block, the rescale factor $e^{m - m_{\text{new}}} = e^{-\infty} = 0$ correctly discards the empty initial state.
- Rescale both $\ell$ and $O$ by $\alpha = e^{m - m_{\text{new}}}$ (old running max minus new running max) before adding the new block's contribution.
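- The result must match standard attention for any block_size from 1 up to the number of keys, including sizes that do not divide it evenly (the last block is then shorter).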
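The first-block behavior described in the constraints can be checked in isolation. This sketch uses made-up first-block scores S_1 for three queries. It shows that NumPy evaluates $e^{-\infty}$ to exactly 0 without a warning, so the zero-initialized $\ell$ contributes nothing:

```python
import numpy as np

# Running state for three queries before any key block has been seen.
m = np.full(3, -np.inf)      # running max
ell = np.zeros(3)            # running normalizer

# Hypothetical scores of the first key block (3 queries x 2 keys).
S_1 = np.array([[0.2, -1.0],
                [3.0,  4.0],
                [-7.0, -2.0]])

m_new = np.maximum(m, S_1.max(axis=1))
alpha = np.exp(m - m_new)    # exp(-inf) evaluates to exactly 0.0
ell = alpha * ell + np.exp(S_1 - m_new[:, None]).sum(axis=1)
m = m_new

print(alpha)  # [0. 0. 0.]
print(ell)    # each entry >= 1, since the row max contributes exp(0) = 1
```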
Notes
- The online-softmax rescaling is what keeps the result exact. Subtracting the running max keeps every exponent at or below zero, which prevents overflow. Once a larger max appears, the factor $\alpha = e^{m - m_{\text{new}}}$ corrects the earlier terms.
- Real FlashAttention also tiles the queries and fuses everything into one GPU kernel; the math here is the same, only the loop nest differs.
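To illustrate the different loop nest, here is a sketch that also tiles the queries: an outer loop over query tiles and an inner loop over key/value tiles. The function name flash_attention_2d and the parameters q_block and kv_block are hypothetical, and the loop order shown is just one possible choice. In NumPy this gains nothing over tiling keys alone, because there is no on-chip memory or kernel fusion to exploit. It only shows that each query tile carries its own independent $m$, $\ell$, and $O$:

```python
import numpy as np

def flash_attention_2d(Q, K, V, q_block, kv_block):
    """Tiles queries (outer loop) and keys/values (inner loop); same online softmax."""
    Q, K, V = (np.asarray(a, dtype=np.float64) for a in (Q, K, V))
    n_q, d = Q.shape
    out = np.empty((n_q, V.shape[1]))
    for qs in range(0, n_q, q_block):
        Q_i = Q[qs:qs + q_block]                   # one query tile
        m = np.full(len(Q_i), -np.inf)             # per-tile running state
        ell = np.zeros(len(Q_i))
        O = np.zeros((len(Q_i), V.shape[1]))
        for ks in range(0, K.shape[0], kv_block):  # stream key/value tiles
            S = Q_i @ K[ks:ks + kv_block].T / np.sqrt(d)
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)
            P = np.exp(S - m_new[:, None])
            ell = alpha * ell + P.sum(axis=1)
            O = alpha[:, None] * O + P @ V[ks:ks + kv_block]
            m = m_new
        out[qs:qs + q_block] = O / ell[:, None]    # this query tile is finished
    return out
```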
This problem ships 6 hidden tests, which run in your browser via Pyodide:
- Reference small example
- Matches standard attention for a partial block size
- block_size = N (single block) matches standard
- block_size = 1 (fully tiled) matches standard
- Output shape is (Nq, d)
- Handles different query and key counts (Nq != Nk)