Softmax backward
Background
The backward pass through softmax, done the right way. For $y = \mathrm{softmax}(x) \in \mathbb{R}^V$, the softmax Jacobian $J = \mathrm{diag}(y) - y y^\top$ has $V^2$ entries, so materialising it as a matrix is a non-starter for real vocabularies (with $V$ in the tens of thousands, that's ~19 GB per example at fp32). Instead, the Jacobian-vector product collapses to a single expression: one dot product and one elementwise multiply. This is the same closed form that PyTorch, JAX, and TensorFlow compute under the hood; the full matrix never exists.
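To make the size argument concrete, the sketch below builds the explicit Jacobian for a tiny vector and estimates what a dense matrix would cost to store. It assumes the standard softmax identity J = diag(y) - y yᵀ; the helper names explicit_softmax_jacobian and jacobian_bytes are illustrative and not part of this problem's API.

```python
import numpy as np

def explicit_softmax_jacobian(y):
    """Dense V x V softmax Jacobian, J[i, j] = y[i] * (delta_ij - y[j]).

    Hypothetical helper for illustration; only sensible for small V.
    """
    y = np.asarray(y, dtype=np.float64)
    return np.diag(y) - np.outer(y, y)

def jacobian_bytes(V, bytes_per_element=4):
    """Storage needed for a dense V x V matrix (fp32 by default)."""
    return V * V * bytes_per_element

y = np.array([0.2, 0.3, 0.5])
J = explicit_softmax_jacobian(y)
print(J)                       # 3 x 3, symmetric, each row sums to ~0
print(jacobian_bytes(50_000))  # 10_000_000_000 bytes, i.e. 10 GB at fp32
```

The quadratic growth is the whole point: doubling V quadruples the storage, while the closed form only ever touches vectors of length V.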
Problem statement
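Implement softmax_backward(grad_out, softmax_out). With $y = \mathrm{softmax}(x)$ (softmax_out) and upstream gradient $g = \partial L / \partial y$ (grad_out):
$$\frac{\partial L}{\partial x} = y \odot \bigl(g - \langle y, g \rangle\bigr)$$
i.e. each $\partial L / \partial x_i = y_i \bigl(g_i - \sum_j y_j g_j\bigr)$: take how much $g_i$ differs from the $y$-weighted average of $g$, then weight by $y_i$.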
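A short derivation shows where the closed form comes from. Here $y$ denotes softmax_out, $g$ denotes grad_out, and $\delta_{ij}$ is the Kronecker delta. The only ingredient assumed is the standard softmax derivative $\partial y_j / \partial x_i = y_j(\delta_{ij} - y_i)$.

```latex
\begin{aligned}
\frac{\partial L}{\partial x_i}
  &= \sum_j g_j \,\frac{\partial y_j}{\partial x_i}
   = \sum_j g_j \, y_j \,(\delta_{ij} - y_i) \\
  &= y_i g_i - y_i \sum_j y_j g_j
   = y_i \bigl(g_i - \langle y, g \rangle\bigr).
\end{aligned}
```

The sum over $j$ is the same scalar for every $i$, which is why a single dot product suffices.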
Input
- grad_out: 1-D np.ndarray of length V, the upstream gradient $g = \partial L / \partial y$.
- softmax_out: 1-D np.ndarray of length V, the saved output $y = \mathrm{softmax}(x)$.
Output
Returns a 1-D np.ndarray of the same shape: the input gradient $\partial L / \partial x$.
Examples
Example 1 — hand-checked
Input: softmax_out = [0.2, 0.3, 0.5], grad_out = [1, 0, 0]
Output: [0.16, -0.06, -0.10]
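Explanation: $\langle y, g \rangle = 0.2$, so $\partial L / \partial x = y \odot (g - 0.2) = [0.2 \cdot 0.8,\ 0.3 \cdot (-0.2),\ 0.5 \cdot (-0.2)] = [0.16, -0.06, -0.10]$. (This equals $J^\top g$, but computed in $O(V)$.)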
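Example 1 can be reproduced in a few lines. The following is a minimal sketch of the closed form stated in this problem, cross-checked against an explicit Jacobian that is built only because the input is tiny. It is one possible way to write the function, not a reference solution.

```python
import numpy as np

def softmax_backward(grad_out, softmax_out):
    """Closed-form softmax backward: y * (g - <y, g>), O(V) time and memory."""
    g = np.asarray(grad_out, dtype=np.float64)
    y = np.asarray(softmax_out, dtype=np.float64)
    return y * (g - np.dot(y, g))

y = np.array([0.2, 0.3, 0.5])
g = np.array([1.0, 0.0, 0.0])

dx = softmax_backward(g, y)

# Explicit Jacobian, used here only as a cross-check on a tiny input.
J = np.diag(y) - np.outer(y, y)
assert np.allclose(dx, J.T @ g)
print(dx)  # approximately [0.16, -0.06, -0.10]
```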
Example 2 — a constant gradient gives zero
Input: softmax_out = softmax([1,2,3]), grad_out = [0.5, 0.5, 0.5]
Output: [0, 0, 0]
Explanation: for a constant $g = c$, $\langle y, g \rangle = c \sum_j y_j = c$ (the $y_j$'s sum to 1), so $g - \langle y, g \rangle = 0$. This reflects softmax's shift-invariance: adding a constant to every logit doesn't change the output, so a uniform upstream gradient produces no input gradient.
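Both halves of that argument are easy to observe numerically. The sketch below uses a small numerically stable softmax helper (written here for illustration) to show that shifting the logits leaves the output unchanged, and that a constant upstream gradient yields a zero input gradient up to rounding.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax for a 1-D array."""
    z = np.asarray(x, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()

x = np.array([1.0, 2.0, 3.0])
y = softmax(x)

# Shifting every logit by the same constant leaves the output unchanged.
assert np.allclose(softmax(x + 7.5), y)

# Correspondingly, a constant upstream gradient maps to a zero input gradient
# (up to floating-point rounding).
g = np.full_like(x, 0.5)
dx = y * (g - np.dot(y, g))
print(np.abs(dx).max())  # tiny: zero up to rounding error
```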
Constraints
- Use the closed form y * (grad_out - dot(y, grad_out)); do not build the Jacobian. The cost is $O(V)$, not $O(V^2)$.
- softmax_backward consumes the softmax output (the cached forward result), not the input logits.
- A constant grad_out must give a zero gradient (up to floating-point rounding); the analytic result must match both the explicit $J^\top g$ and a finite-difference gradient (atol≈1e-3).
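The finite-difference requirement can be checked locally along these lines. This is a sketch under stated assumptions: the scalar being differentiated is taken to be L(x) = ⟨g, softmax(x)⟩, the helper numerical_grad and the step eps=1e-5 are illustrative choices, and the hidden tests may be set up differently.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax for a 1-D float array."""
    e = np.exp(x - x.max())
    return e / e.sum()

def softmax_backward(grad_out, softmax_out):
    """Closed form from the constraints: y * (g - <y, g>)."""
    return softmax_out * (grad_out - np.dot(softmax_out, grad_out))

def numerical_grad(x, g, eps=1e-5):
    """Central differences of the scalar L(x) = <g, softmax(x)>."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        plus = np.dot(g, softmax(x + step))
        minus = np.dot(g, softmax(x - step))
        grad[i] = (plus - minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=6)   # random logits
g = rng.normal(size=6)   # random upstream gradient

analytic = softmax_backward(g, softmax(x))
numeric = numerical_grad(x, g)
assert np.allclose(analytic, numeric, atol=1e-3)
```

Central differences are used rather than forward differences because their truncation error shrinks with the square of the step size, which leaves plenty of headroom under the stated tolerance.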
Notes
- Why it collapses. $\langle y, g \rangle$ is the $y$-weighted mean of the upstream gradient; subtracting it centres $g$, so the result is invariant to adding a constant to $g$. This is the algebraic shadow of softmax's input shift-invariance.
- Related. Pairs with softmax from scratch; when softmax feeds cross-entropy, the two backwards combine into the much simpler softmax − one-hot.
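The softmax − one-hot simplification can be confirmed by chaining the two backward steps by hand. The sketch below assumes a single example with cross-entropy loss L = -log(y[target]); the variable target and the chosen logits are illustrative.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax for a 1-D float array."""
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([1.0, 2.0, 3.0])
target = 2  # index of the true class (illustrative)
y = softmax(x)

# Cross-entropy L = -log(y[target]); its gradient w.r.t. y is nonzero
# only at the target index.
g = np.zeros_like(y)
g[target] = -1.0 / y[target]

# Chain through the softmax backward closed form.
dx = y * (g - np.dot(y, g))

one_hot = np.zeros_like(y)
one_hot[target] = 1.0
assert np.allclose(dx, y - one_hot)
```

Here ⟨y, g⟩ evaluates to -1, so the closed form reduces to y * (g + 1), which is y everywhere except at the target, where it is y[target] - 1.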
This problem ships 5 hidden tests. They run in your browser via Pyodide: no backend, no submission queue.
- Output shape matches input shape
- Diagnostic: matches the explicit Jacobian-vector product
- Matches finite-difference numerical gradient
- Sum of gradients is ~0 (because softmax outputs sum to 1)
- Constant grad_out gives zero gradient (softmax invariant to constant shift)