Clip gradients by global L2 norm
Background
Gradient clipping by global L2 norm is the one line of code that prevents most exploding-gradient training crashes — GPT-2, GPT-3, and essentially every transformer run use it. The idea: treat all the gradient tensors as one giant concatenated vector, measure its length, and if it exceeds a threshold, shrink the whole thing back down by a single scalar. Because every component scales by the same factor, the update direction is preserved — only its magnitude is capped.
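The "one giant concatenated vector" view can be checked numerically. The sketch below assumes float NumPy arrays; the values and the names `flat`, `norm_from_flat`, and `norm_from_sums` are illustrative only and not part of the problem's API. It compares the norm of the concatenation with the square root of the summed squares, then confirms that a single scalar shrink keeps the cosine similarity with the original at 1.

```python
import numpy as np

# Illustrative gradients of different shapes (hypothetical values).
grads = [np.array([[1.0, -2.0], [3.0, 0.5]]), np.array([4.0, -1.0, 2.0])]

# View 1: concatenate everything into one flat vector and take its length.
flat = np.concatenate([g.ravel() for g in grads])
norm_from_flat = float(np.linalg.norm(flat))

# View 2: sum the squared entries tensor by tensor, then take one square root.
norm_from_sums = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))

# Shrinking by a single scalar caps the length but keeps the direction.
max_norm = 1.0  # hypothetical threshold
scaled = flat * (max_norm / norm_from_flat)
cosine = float(flat @ scaled / (np.linalg.norm(flat) * np.linalg.norm(scaled)))
```

Both views give the same norm, which is why an implementation never needs to build the concatenated vector.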
Problem statement
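Implement clip_grad_norm(grads, max_norm). Compute the global norm across all tensors, scale them down in place if it exceeds the threshold, and return the original norm:
$$\text{total} = \sqrt{\sum_i \lVert g_i \rVert_2^2}, \qquad g_i \leftarrow g_i \cdot \frac{\text{max\_norm}}{\text{total}} \quad \text{if } \text{total} > \text{max\_norm}$$
If total <= max_norm, leave the gradients unchanged. Either way, return total (the norm before clipping).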
Input
grads: list[np.ndarray]. The gradient tensors. Mutated in place when clipping occurs.
max_norm: float. The clipping threshold.
Output
Returns a float: the original combined L2 norm (before any scaling) — useful for logging "did we clip this step?".
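One possible implementation of this specification is sketched below. It is not the reference solution. It assumes every array has a float dtype, because an in-place `*=` by a fractional scale raises a casting error on integer arrays.

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Clip a list of arrays by their global L2 norm, in place.

    Returns the norm measured before any scaling.
    Assumes float arrays (in-place scaling of integer arrays would fail).
    """
    # One square root over the summed squares of every tensor.
    total = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    if total > max_norm:
        scale = max_norm / total
        for g in grads:
            g *= scale  # in place: the caller's arrays change
    return total

# Usage: a single tensor above the threshold.
grads = [np.array([6.0, 8.0])]
norm = clip_grad_norm(grads, 2.0)
```

Here `norm` is the pre-clip value 10.0, and `grads[0]` has been shrunk to approximately [1.2, 1.6].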
Examples
Example 1 — below threshold, no change
Input: grads = [[3, 4], [0, 12]], max_norm = 100
Output: returns 13.0 (gradients unchanged)
Explanation: $\lVert [3, 4] \rVert^2 = 25$, $\lVert [0, 12] \rVert^2 = 144$, so total $= \sqrt{25 + 144} = 13$. Since $13 \le 100$, nothing is scaled, but the function still returns the original norm 13.0.
Example 2 — above threshold, scaled down
Input: grads = [[6, 8]], max_norm = 2
Output: returns 10.0; grads -> [[1.2, 1.6]]
Explanation: total $= \sqrt{36 + 64} = 10 > 2$, so every element scales by $2 / 10 = 0.2$, giving [1.2, 1.6] whose norm is exactly max_norm = 2. The return value is the original norm 10.0 (for logging), not the clipped one.
Constraints
- The norm is global: sqrt(sum over all tensors of (g**2).sum()), as if everything were one flat vector. Clipping per-tensor instead would rotate the joint direction.
- If total <= max_norm, do nothing; otherwise multiply every tensor by max_norm / total.
- Mutate in place (g *= scale, not g = g * scale) so the caller's references update.
- Return the original norm; after clipping, the new global norm equals max_norm (up to floating-point rounding), and the direction is unchanged (pure scalar shrink).
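Two of these constraints are easy to get wrong, so here is a small illustration. It uses hypothetical values, and the helper `cosine` is introduced only for this sketch. The first part contrasts global clipping with per-tensor clipping. The second contrasts rebinding (`g = g * scale`) with in-place scaling (`g *= scale`).

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity between two flat vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

max_norm = 1.0  # hypothetical threshold
original = [np.array([3.0, 4.0]), np.array([0.0, 0.5])]
flat = np.concatenate(original)

# Global clipping: one scale factor shared by every tensor.
total = np.linalg.norm(flat)
global_clipped = np.concatenate([g * (max_norm / total) for g in original])

# Per-tensor clipping: each tensor gets its own factor
# (here only the first tensor exceeds the threshold).
per_tensor = np.concatenate(
    [g * min(1.0, max_norm / np.linalg.norm(g)) for g in original]
)

cos_global = cosine(flat, global_clipped)   # direction preserved
cos_per_tensor = cosine(flat, per_tensor)   # direction rotated

# Rebinding vs. in-place: only `*=` changes the array the caller holds.
caller_view = np.array([6.0, 8.0])
g = caller_view
g = g * 0.2   # builds a new array; caller_view is untouched
rebound_unchanged = bool(np.array_equal(caller_view, [6.0, 8.0]))
g = caller_view
g *= 0.2      # writes into the same buffer as caller_view
```

With these values `cos_global` stays at 1 up to rounding, while `cos_per_tensor` drops below 1 because the two tensors were shrunk by different factors.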
Notes
- Why return the pre-clip norm. PyTorch's clip_grad_norm_ does the same. Logging it lets you watch training health: a value hovering well below max_norm is healthy; frequent spikes past it, or a sustained climb, are an early warning of loss divergence.
- Related. As with the SGD step, the in-place mutation is part of the contract: clipping runs between the backward pass and the optimiser step.
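To show where clipping sits in a training step and how the returned norm can be logged, here is a toy loop. Everything in it is an assumption made for illustration: the gradients are synthetic random arrays standing in for a backward pass, the update is plain SGD, and `max_norm`, `lr`, `history`, and `clipped_steps` are hypothetical names and values.

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    # Compact stand-in for the function specified in the problem statement.
    total = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    if total > max_norm:
        for g in grads:
            g *= max_norm / total
    return total

rng = np.random.default_rng(0)
max_norm, lr = 1.0, 0.1          # hypothetical hyperparameters
params = [np.zeros(4), np.zeros((2, 2))]
history = []                     # pre-clip norm per step, for logging

for step in range(5):
    # Stand-in for the backward pass: synthetic gradients whose scale grows.
    grads = [rng.normal(size=p.shape) * (step + 1) for p in params]
    # Clip after the gradients exist and before the parameters move.
    history.append(clip_grad_norm(grads, max_norm))
    for p, g in zip(params, grads):
        p -= lr * g              # plain SGD step on the clipped gradients

# Number of steps whose pre-clip norm exceeded the threshold.
clipped_steps = sum(n > max_norm for n in history)
```

Plotting `history` over time, or tracking the fraction of clipped steps, is one simple way to watch for the warning signs described in the note above.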
This problem ships 6 hidden tests, which run in the browser via Pyodide:
- Returns the original total L2 norm
- No clipping when total norm <= max_norm
- Diagnostic: when total norm > max_norm, gradients are scaled to exactly max_norm
- Direction is preserved (gradients are scaled, not zeroed or rotated)
- Mutation is in place (caller's references see the new values)
- Single tensor with very small max_norm shrinks correctly