Clip gradients by global L2 norm

Difficulty: Easy

Background

Gradient clipping by global L2 norm is a few lines of code that prevent many exploding-gradient training crashes; GPT-2, GPT-3, and most large transformer training runs use it. The idea: treat all the gradient tensors as one giant concatenated vector, measure its length, and if it exceeds a threshold, shrink the whole thing back down by a single scalar. Because every component scales by the same factor, the update direction is preserved and only its magnitude is capped.

Problem statement

Implement clip_grad_norm(grads, max_norm). Compute the global norm across all tensors, scale them down in place if it exceeds the threshold, and return the original norm:

$$\text{total} = \sqrt{\sum_i \lVert g_i \rVert_2^2}, \qquad g_i \leftarrow g_i \cdot \frac{\text{max\_norm}}{\text{total}} \quad\text{if } \text{total} > \text{max\_norm}$$

If total <= max_norm, leave the gradients unchanged. Either way, return total (the norm before clipping).
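
One way the specification above might be implemented is sketched below. It is a minimal version rather than a definitive solution, and it assumes the gradients are float NumPy arrays (in-place scaling of integer arrays would raise an error).

```python
import numpy as np


def clip_grad_norm(grads, max_norm):
    """Clip a list of gradient arrays in place by their global L2 norm.

    Returns the global norm measured before any scaling.
    Assumes float arrays; integer arrays cannot be scaled in place.
    """
    # Global norm: square root of the summed squares over every tensor.
    total = float(np.sqrt(sum(float(np.sum(g * g)) for g in grads)))

    if total > max_norm:
        scale = max_norm / total
        for g in grads:
            g *= scale  # in place, so the caller's arrays are updated

    return total


# Usage: a single tensor with norm 10 and a threshold of 2.
grads = [np.array([6.0, 8.0])]
norm = clip_grad_norm(grads, max_norm=2.0)
print(norm, grads[0])
```

The single shared `scale` is what keeps the direction intact: every element of every tensor is multiplied by the same number.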

Input

  • grads (list[np.ndarray]): the gradient tensors. Mutated in place when clipping occurs.
  • max_norm (float, > 0): the clipping threshold.

Output

Returns a float: the original combined L2 norm (before any scaling) — useful for logging "did we clip this step?".
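
As a rough illustration of the logging use case, the sketch below summarises a series of returned norms. The helper `clip_stats` and the sample values are hypothetical and not part of the problem; a step counts as clipped when its pre-clip norm exceeds max_norm.

```python
def clip_stats(norms, max_norm):
    """Summarise logged pre-clip norms: which steps clipped, and how often."""
    clipped_steps = [i for i, n in enumerate(norms) if n > max_norm]
    rate = len(clipped_steps) / len(norms) if norms else 0.0
    return clipped_steps, rate


# Hypothetical norms, as they might be returned over five training steps.
logged_norms = [0.8, 0.9, 3.5, 0.7, 12.0]
steps, rate = clip_stats(logged_norms, max_norm=1.0)
print(steps, rate)
```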

Examples

Example 1 — below threshold, no change

Input:  grads = [[3, 4], [0, 12]], max_norm = 100
Output: returns 13.0     (gradients unchanged)

Explanation: $\lVert g_1\rVert = 5$, $\lVert g_2\rVert = 12$, so total $= \sqrt{5^2 + 12^2} = 13$. Since $13 \le 100$, nothing is scaled, but the function still returns the original norm 13.0.

Example 2 — above threshold, scaled down

Input:  grads = [[6, 8]], max_norm = 2
Output: returns 10.0;  grads -> [[1.2, 1.6]]

Explanation: $\lVert g\rVert = 10 > 2$, so every element scales by $2/10 = 0.2$, giving [1.2, 1.6] whose norm is exactly max_norm = 2. The return value is the original norm 10.0 (for logging), not the clipped one.
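
The arithmetic in Example 1 can be checked two ways, as in the sketch below: by combining per-tensor norms, or by flattening everything into one vector. Both routes should agree, which is what "global norm" means here. The variable names are illustrative only.

```python
import numpy as np

grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]

# Route 1: combine the per-tensor norms, as in the formula.
per_tensor = [float(np.linalg.norm(g)) for g in grads]
total_from_parts = float(np.sqrt(sum(n ** 2 for n in per_tensor)))

# Route 2: flatten everything into one vector and take a single norm.
flat = np.concatenate([g.ravel() for g in grads])
total_from_flat = float(np.linalg.norm(flat))

print(per_tensor, total_from_parts, total_from_flat)
```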

Constraints

  • The norm is global: sqrt(sum over all tensors of (g**2).sum()), as if everything were one flat vector. Clipping per-tensor instead would rotate the joint direction.
  • If total <= max_norm, do nothing; otherwise multiply every tensor by max_norm / total.
  • Mutate in place (g *= scale, not g = g * scale) so the caller's references update.
  • Return the original norm. After clipping, the new global norm equals max_norm (up to floating-point rounding), and the direction is unchanged (pure scalar shrink).
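
To see why the global rule matters, the sketch below compares it with per-tensor clipping on the tensors from Example 1, using an arbitrary threshold of 6.5. It scales out of place purely for the comparison; the `cosine` helper is illustrative and not part of the problem.

```python
import numpy as np


def cosine(a, b):
    """Cosine similarity between two flat vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


grads = [np.array([3.0, 4.0]), np.array([0.0, 12.0])]
max_norm = 6.5  # arbitrary threshold chosen for this illustration
flat_before = np.concatenate(grads)

# Global clipping: one shared scale factor for every tensor.
total = np.linalg.norm(flat_before)
global_clipped = [g * min(1.0, max_norm / total) for g in grads]

# Per-tensor clipping (for contrast): each tensor gets its own factor.
per_tensor_clipped = [
    g * min(1.0, max_norm / np.linalg.norm(g)) for g in grads
]

cos_global = cosine(flat_before, np.concatenate(global_clipped))
cos_per_tensor = cosine(flat_before, np.concatenate(per_tensor_clipped))
print(cos_global, cos_per_tensor)
```

With global clipping the cosine similarity to the original vector stays at 1 (up to rounding). With per-tensor clipping only the second tensor is shrunk, so the joint vector points in a different direction and the similarity drops below 1.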

Notes

  • Why return the pre-clip norm. PyTorch's clip_grad_norm_ does the same. Logging it lets you watch training health: a value hovering well below max_norm is healthy, while frequent spikes past it or a sustained climb are early warnings of loss divergence.
  • Related. Like SGD step, the in-place mutation is part of the contract — clipping runs between the backward pass and the optimiser step.
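
The in-place contract and the ordering can be illustrated with a toy step, sketched below. Everything here is hypothetical scaffolding: `optimiser_view` stands in for a second reference to the gradient array (such as one an optimiser might hold), and the learning rate and values are made up.

```python
import numpy as np

# Toy setup: a second reference to the same gradient array.
params = [np.array([1.0, 1.0])]
grads = [np.array([6.0, 8.0])]   # pretend this came from a backward pass
optimiser_view = grads[0]        # same underlying array as grads[0]

max_norm, lr = 2.0, 0.1
total = float(np.sqrt(sum(float(np.sum(g * g)) for g in grads)))
scale = max_norm / total if total > max_norm else 1.0

# Rebinding creates a new array; the other reference is unaffected.
for g in grads:
    g = g * scale
rebinding_reached_optimiser = bool(np.allclose(optimiser_view, [1.2, 1.6]))

# In-place scaling mutates the shared array, so every reference sees it.
for g in grads:
    g *= scale
inplace_reached_optimiser = bool(np.allclose(optimiser_view, [1.2, 1.6]))

# The optimiser step runs after clipping and uses the clipped gradients.
for p, g in zip(params, grads):
    p -= lr * g

print(rebinding_reached_optimiser, inplace_reached_optimiser, params[0])
```

The rebinding loop leaves the shared array untouched, which is why `g = g * scale` would silently skip clipping from the caller's point of view.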

This problem ships 6 hidden tests, which run in the browser via Pyodide:

  • Returns the original total L2 norm
  • No clipping when total norm <= max_norm
  • Diagnostic: when total norm > max_norm, gradients are scaled so the global norm is exactly max_norm
  • Direction is preserved (gradients are scaled, not zeroed or rotated)
  • Mutation is in place (caller's references see the new values)
  • Single tensor with very small max_norm shrinks correctly