Efficient sparse window attention

Difficulty: Medium

Background

Dense attention costs $O(n^2)$ in sequence length $n$, which is prohibitive for long contexts. Sliding-window (local) attention restricts each query to a fixed window of nearby keys, cutting the cost to $O(n \cdot w)$ for window radius $w$. It is the backbone of long-context models like Longformer and Mistral: most relevant context is local, and stacking windowed layers still grows the effective receptive field.
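To make the cost difference concrete, the sketch below counts query-key score evaluations for dense attention and for a clipped window of radius w. The helper names dense_pairs and window_pairs are illustrative only and are not part of the problem's API.

```python
def dense_pairs(n: int) -> int:
    """Number of query-key scores computed by dense attention."""
    return n * n


def window_pairs(n: int, w: int) -> int:
    """Number of query-key scores when query i sees keys i-w .. i+w, clipped to [0, n)."""
    return sum(min(n, i + w + 1) - max(0, i - w) for i in range(n))


n, w = 4096, 128
print(dense_pairs(n))      # 16777216
print(window_pairs(n, w))  # 1036160
```

For w well below n, the windowed count is close to n * (2w + 1); the clipped windows at both ends of the sequence account for the small shortfall.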

Problem statement

Implement sparse_window_attention(Q, K, V, window_size, scale_factor=None). For each query position $i$, attend only to keys in the window $[i-w,\ i+w]$ (clipped to the sequence):

$$\text{out}_i = \operatorname{softmax}\!\Big(\frac{Q_i\, K_{\text{win}}^\top}{s}\Big)\, V_{\text{win}}, \qquad \text{win} = [\max(0,\, i-w),\ \min(n,\, i+w+1))$$

where $w$ is window_size and $s$ is scale_factor (defaults to $\sqrt{d_k}$). Use the max-subtraction trick for a stable softmax.
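The max-subtraction trick relies on softmax being unchanged when the same constant is subtracted from every score. A minimal sketch follows; the helper name stable_softmax is an assumption for illustration. With scores around 1000, calling np.exp on the raw values would overflow, while the shifted version stays finite.

```python
import numpy as np


def stable_softmax(scores: np.ndarray) -> np.ndarray:
    """Softmax over a 1-D score vector using max-subtraction."""
    shifted = scores - np.max(scores)  # largest entry becomes 0, so exp() cannot overflow
    weights = np.exp(shifted)
    return weights / weights.sum()


scores = np.array([1000.0, 1001.0, 1002.0])
print(stable_softmax(scores))  # finite weights that sum to 1
```

Because only differences between scores matter, the result is the same as the softmax of [0, 1, 2].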

Input

  • Q, K: np.ndarray of shape (seq_len, d_k).
  • V: np.ndarray of shape (seq_len, d_v).
  • window_size: int, the window radius (positions attended on each side).
  • scale_factor: float or None; if None, use $\sqrt{d_k}$.

Output

An np.ndarray of shape (seq_len, d_v): the windowed attention output.

Examples

Example 1

Input:  Q = [[1],[1],[1]], K = [[1],[1],[1]], V = [[1],[2],[3]], window_size = 1
Output: [[1.5], [2.0], [2.5]]

Explanation: position 0 sees V[0:2] → mean 1.5; position 1 sees all three → mean 2.0; position 2 sees V[1:3] → mean 2.5. All scores are equal here, so each window is a simple average.
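The window bounds and averages in this example can be reproduced by hand. The sketch below assumes equal scores, as in Example 1, so the softmax weights are uniform and each output row is the plain mean of the values in the window.

```python
import numpy as np

V = np.array([[1.0], [2.0], [3.0]])
n, w = 3, 1

rows = []
for i in range(n):
    start, end = max(0, i - w), min(n, i + w + 1)  # half-open slice [start, end)
    rows.append(V[start:end].mean(axis=0))         # uniform weights reduce to a mean
    print(i, (start, end), rows[-1])

result = np.stack(rows)  # rows are 1.5, 2.0, 2.5
```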

Constraints

  • The window covers positions i - w through i + w inclusive, clipped to the sequence. As a half-open slice: start = max(0, i - w), end = min(seq_len, i + w + 1).
  • Softmax is computed only over the window (with max-subtraction), then applied to the window's values.
  • scale_factor defaults to $\sqrt{d_k}$, matching standard scaled dot-product attention.
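Taken together, the constraints above suggest a straightforward per-query loop. The following is one possible sketch rather than the reference solution: it converts the inputs to float arrays (an assumption, so that list inputs like those in Example 1 also work) and trades speed for clarity.

```python
import numpy as np


def sparse_window_attention(Q, K, V, window_size, scale_factor=None):
    """Sliding-window attention: query i attends to keys within window_size positions."""
    Q = np.asarray(Q, dtype=float)
    K = np.asarray(K, dtype=float)
    V = np.asarray(V, dtype=float)
    seq_len, d_k = Q.shape
    d_v = V.shape[1]
    scale = np.sqrt(d_k) if scale_factor is None else scale_factor

    out = np.zeros((seq_len, d_v))
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        scores = K[start:end] @ Q[i] / scale     # shape (end - start,)
        weights = np.exp(scores - scores.max())  # max-subtraction for stability
        weights /= weights.sum()
        out[i] = weights @ V[start:end]          # weighted sum of window values
    return out


Q = K = [[1], [1], [1]]
V = [[1], [2], [3]]
print(sparse_window_attention(Q, K, V, window_size=1))  # rows: 1.5, 2.0, 2.5
```

Each iteration touches at most 2 * window_size + 1 keys, so the total work scales with seq_len * window_size rather than seq_len squared.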

Notes

  • A window of radius $w \ge n-1$ recovers full dense attention (every query sees every key).
  • A window of radius 0 makes each token attend only to itself, so the output equals V.
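Both limiting cases can be checked numerically. The sketch below uses a band-masked dense formulation instead of the per-window loop: it masks scores where |i - j| > w and runs an ordinary softmax. The helper banded_attention is hypothetical and costs O(n^2), so it is only meant as a cross-check, not as an efficient implementation.

```python
import numpy as np


def banded_attention(Q, K, V, w, scale=None):
    """Dense attention restricted by a band mask |i - j| <= w (illustrative helper)."""
    Q, K, V = (np.asarray(a, dtype=float) for a in (Q, K, V))
    n, d_k = Q.shape
    scale = np.sqrt(d_k) if scale is None else scale
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= w
    scores = np.where(mask, Q @ K.T / scale, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)  # diagonal is never masked, so row max is finite
    weights = np.exp(scores)                     # exp(-inf) = 0 for masked keys
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V


rng = np.random.default_rng(0)
n, d_k, d_v = 6, 4, 3
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))

# Plain dense attention for comparison.
s = Q @ K.T / np.sqrt(d_k)
p = np.exp(s - s.max(axis=1, keepdims=True))
dense = (p / p.sum(axis=1, keepdims=True)) @ V

print(np.allclose(banded_attention(Q, K, V, w=n - 1), dense))  # True
print(np.allclose(banded_attention(Q, K, V, w=0), V))          # True
```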

This problem ships 5 hidden tests, which run in the browser via Pyodide (no backend, no submission queue):

  • Reference example
  • Window radius >= seq_len recovers dense attention
  • Window radius 0 makes each token attend to itself (output == V)
  • Output shape is (seq_len, d_v)
  • A query only sees keys inside its window