Causal mask: build + apply (Easy)

Background

GPT-2 is a decoder-only model: at training time it predicts token $t+1$ from tokens $0..t$, so position $t$ must not attend to any position $> t$. Otherwise the model "cheats" by reading the answer straight out of its input. The fix is a causal (lower-triangular) mask applied to the attention scores before softmax. It takes one triangular matrix and one np.where call, but applying it in the right place (before softmax, not after) is what gives future positions exactly zero probability.
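To make the "cheating" concrete, here is a minimal NumPy sketch of single-head attention. The helpers softmax and attention are hypothetical, written only for this illustration. It alters the value vector of the last token and checks whether position 0's output moves.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, causal):
    # Hypothetical single-head scaled dot-product attention on (T, d) arrays.
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    if causal:
        future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True where j > i
        scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
T, d = 4, 8
q, k, v = (rng.normal(size=(T, d)) for _ in range(3))
v_changed = v.copy()
v_changed[-1] += 10.0  # alter only the last (future) token's value vector

sees_future = {}
for causal in (False, True):
    out = attention(q, k, v, causal)
    out_changed = attention(q, k, v_changed, causal)
    # Position 0 "sees the future" if its output depends on the last token.
    sees_future[causal] = not np.allclose(out[0], out_changed[0])
    print(f"causal={causal}: position 0 sees the future -> {sees_future[causal]}")
```

Without the mask, position 0 puts nonzero weight on position 3, so its output changes; with the mask, its output is exactly v[0] in both runs.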

Problem statement

Implement two functions:

  1. build_causal_mask(T): return a (T, T) boolean mask with mask[i, j] = True when $j \le i$ (position $i$ may attend to position $j$) and False when $j > i$. This is the lower triangle, including the diagonal.
  2. apply_causal_mask(scores): given a (T, T) score matrix (from $QK^\top/\sqrt{d}$), return a copy with the upper triangle ($j > i$) set to $-\infty$ and the rest unchanged, so that a subsequent softmax sends those entries to exactly 0.
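One possible sketch of the two functions in NumPy follows. It is not necessarily the reference solution, and the hidden tests may check details it does not cover. It relies on np.where returning a new array, which satisfies the "return a copy" requirement.

```python
import numpy as np

def build_causal_mask(T: int) -> np.ndarray:
    # Lower triangle including the diagonal: mask[i, j] is True iff j <= i.
    return np.tril(np.ones((T, T), dtype=bool))

def apply_causal_mask(scores: np.ndarray) -> np.ndarray:
    mask = build_causal_mask(scores.shape[-1])
    # np.where allocates a new float array, so the caller's scores stay untouched:
    # keep the score where the mask is True, use -inf where it is False.
    return np.where(mask, scores, -np.inf)

masked = apply_causal_mask(np.full((4, 4), 1.0))
print(masked)
```

Because the (T, T) mask broadcasts, the same apply_causal_mask also works on a stacked (batch, T, T) score array, which is how it would typically be used inside multi-head attention.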

Input

  • build_causal_mask(T): T is an int, the sequence length.
  • apply_causal_mask(scores): scores is an np.ndarray of shape (T, T) holding the raw attention scores.

Output

  • build_causal_mask returns a (T, T) bool array (lower-triangular True).
  • apply_causal_mask returns a (T, T) float array: the original scores at and below the diagonal, $-\infty$ above it.

Examples

Example 1 — the mask for T = 4

build_causal_mask(4) =
  [[ True, False, False, False],
   [ True,  True, False, False],
   [ True,  True,  True, False],
   [ True,  True,  True,  True]]

Explanation: row $i$ has True in columns $0..i$ (it can see itself and the past) and False afterward, so row $i$ has exactly $i+1$ True entries.
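The same mask can be built without np.tril by broadcasting a row-index column against a column-index row. This short sketch checks that three equivalent constructions agree and that row $i$ holds $i+1$ True values.

```python
import numpy as np

T = 4
i = np.arange(T)[:, None]  # row index, shape (T, 1)
j = np.arange(T)[None, :]  # column index, shape (1, T)
mask = j <= i              # broadcasts to (T, T): True iff j <= i

# Two other ways to build the same lower-triangular mask.
assert np.array_equal(mask, np.tril(np.ones((T, T), dtype=bool)))
assert np.array_equal(mask, np.tri(T, dtype=bool))

# Row i can see columns 0..i, so it has exactly i + 1 True entries.
print(mask.sum(axis=1))  # [1 2 3 4]
```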

Example 2 — applying it, then softmax

Input:  scores = np.full((4, 4), 1.0)
apply_causal_mask(scores) =
  [[ 1, -inf, -inf, -inf],
   [ 1,    1, -inf, -inf],
   [ 1,    1,    1, -inf],
   [ 1,    1,    1,    1]]
row 0 after softmax = [1, 0, 0, 0]      # future positions are exactly 0

Explanation: the upper triangle is replaced with $-\infty$ while the diagonal and lower triangle keep their scores. After a row-wise softmax, the $-\infty$ entries become exactly 0 and each row's surviving weights sum to 1 (row 0 attends only to position 0).
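Carrying Example 2 through the full softmax shows every row, not just row 0. This sketch uses a local, numerically stable softmax helper (an assumption; any row-wise softmax behaves the same here). Since all surviving scores are equal, row $i$ spreads its weight uniformly over its $i+1$ visible positions.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax, shifted by the row max for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
scores = np.full((T, T), 1.0)
future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True where j > i
masked = np.where(future, -np.inf, scores)
weights = softmax(masked)

# Row i is uniform over columns 0..i: 1, 1/2, 1/3, 1/4; future entries are 0.0.
print(weights)
```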

Constraints

  • mask[i, j] = True iff $j \le i$; row $i$ contains exactly $i + 1$ True values.
  • Mask before softmax with $-\infty$ (or a large negative number) so masked probabilities are exactly 0, not merely small (say 0.001).
  • Do not zero the weights after softmax: the softmax denominator has already included the future scores, so each row sums to less than 1 and the surviving weights (and their gradients) still depend on future positions.
  • apply_causal_mask leaves the at/below-diagonal scores unchanged.
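The "mask before, not after" constraints can be checked numerically. This sketch (NumPy, a local softmax helper, and random scores chosen only for illustration) compares masking before softmax with zeroing afterward, then bumps one future score to see which version of row 0 reacts.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax, shifted by the row max for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

T = 4
rng = np.random.default_rng(0)
scores = rng.normal(size=(T, T))
keep = np.tril(np.ones((T, T), dtype=bool))

right = softmax(np.where(keep, scores, -np.inf))  # mask, then softmax
wrong = np.where(keep, softmax(scores), 0.0)      # softmax, then zero

print(right.sum(axis=1))  # every row sums to 1
print(wrong.sum(axis=1))  # rows 0..T-2 sum to less than 1

# Perturb a future score (row 0, column 3) and recompute both versions.
scores2 = scores.copy()
scores2[0, 3] += 5.0
right2 = softmax(np.where(keep, scores2, -np.inf))
wrong2 = np.where(keep, softmax(scores2), 0.0)

print(np.allclose(right[0], right2[0]))  # True: row 0 ignores the future score
print(np.allclose(wrong[0], wrong2[0]))  # False: the denominator still sees it
```

The second comparison is the leak the constraint warns about: in the "zero after" version, position 0's kept weight is a function of a future score, so gradients flow back into it.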

Notes

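  • -inf vs -1e9. Both work here, because a causal mask always keeps the diagonal, so no row is fully masked. -np.inf is conceptually cleaner (exp(-inf) = 0 exactly). A finite large negative number avoids the NaN that -inf produces when an entire row is masked (for example, when a padding mask is combined with the causal mask). In fp16, however, -1e9 is outside the representable range (largest finite magnitude 65504) and rounds to -inf, so fp16 code commonly uses the dtype's most negative finite value instead. np.tril(np.ones((T, T), bool)) builds the mask in one call.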
  • Series. apply_causal_mask plugs into build-gpt-06-multi-head-attention; this is the masking half of the attention stack alongside build-gpt-03.
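The -inf vs large-negative trade-off can be observed directly. This sketch (NumPy, with a local softmax helper assumed for illustration) feeds a fully masked row through softmax both ways, then checks what happens to -1e9 when cast to fp16.

```python
import numpy as np

def softmax(x):
    # Softmax along the last axis, shifted by the max for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

# A fully masked row: never produced by the causal mask alone (the diagonal is
# always kept), but possible once other masks such as padding are combined with it.
with np.errstate(invalid="ignore"):
    row_inf = softmax(np.full(3, -np.inf))  # -inf - (-inf) is nan
row_big = softmax(np.full(3, -1e9))          # shifts to 0, stays finite

print(row_inf)  # [nan nan nan]
print(row_big)  # uniform: 1/3 each, no nan

# fp16 cannot represent -1e9; the cast overflows to -inf.
with np.errstate(over="ignore"):
    neg_big_fp16 = np.array([-1e9]).astype(np.float16)[0]
fp16_min = float(np.finfo(np.float16).min)  # most negative finite fp16 value

print(neg_big_fp16)  # -inf
print(fp16_min)      # -65504.0
```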

This problem ships 5 hidden tests, which run in the browser via Pyodide:

  • build_causal_mask: shape and dtype
  • build_causal_mask: diagonal True; below diagonal True; above False
  • apply_causal_mask: upper triangle becomes -inf
  • After softmax, masked positions are exactly 0 — the diagnostic
  • Larger T works without quadratic blowup