Softmax backward

Difficulty: Medium

Background

The backward pass through softmax, done the right way. The softmax Jacobian has entries $J_{ij} = y_i(\delta_{ij} - y_j)$, but materialising it as a $(V, V)$ matrix is a non-starter for real vocabularies (at $V=50257$ that's ~10 GB per example at fp32). Instead, the Jacobian-vector product collapses to a single $O(V)$ expression: one dot product and one elementwise multiply. This is what PyTorch, JAX, and TensorFlow compute under the hood; the full matrix never exists.
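The storage argument can be checked with a back-of-the-envelope calculation. This is only a rough sketch: it counts raw fp32 bytes for one dense Jacobian and ignores any framework overhead; the variable names are illustrative.

```python
V = 50257                # vocabulary size quoted in the text
bytes_per_float = 4      # fp32

jacobian_bytes = V * V * bytes_per_float   # dense (V, V) Jacobian
vector_bytes = V * bytes_per_float         # a single length-V vector

print(f"Jacobian: {jacobian_bytes / 1e9:.1f} GB")   # Jacobian: 10.1 GB
print(f"Vector:   {vector_bytes / 1e3:.0f} kB")     # Vector:   201 kB
```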

Problem statement

Implement softmax_backward(grad_out, softmax_out). With $y = \text{softmax}(x)$ and upstream gradient $g =$ grad_out:

$$J_{ij} = y_i\,(\delta_{ij} - y_j), \qquad dx = y \odot \big(g - \langle y, g\rangle\big)$$

i.e. each $dx_i = y_i\big(g_i - \sum_j y_j g_j\big)$: take how far $g_i$ sits from the $y$-weighted average of $g$, then weight that difference by $y_i$.
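For readers who want the intermediate step, the closed form follows from the chain rule and the Jacobian entries alone. Writing $J_{ji} = \partial y_j / \partial x_i = y_j(\delta_{ji} - y_i)$:

$$dx_i = \sum_j J_{ji}\, g_j = \sum_j y_j\,(\delta_{ji} - y_i)\, g_j = y_i g_i - y_i \sum_j y_j g_j = y_i\big(g_i - \langle y, g\rangle\big)$$

The Kronecker delta picks out the single term $y_i g_i$, and the remaining sum is the scalar $\langle y, g\rangle$ shared by every component.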

Input

  • grad_out: 1-D np.ndarray of length V, the upstream gradient $\partial L/\partial y$.
  • softmax_out: 1-D np.ndarray of length V, the saved output $y = \text{softmax}(x)$.

Output

Returns a 1-D np.ndarray of the same shape: $\partial L/\partial x$.
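One possible shape for the function is sketched below. It is a minimal illustration of the closed form under the stated input/output contract, not a reference solution; the conversion to float arrays and the local name weighted_mean are assumptions made here for readability.

```python
import numpy as np

def softmax_backward(grad_out, softmax_out):
    """Return dL/dx given dL/dy (grad_out) and y = softmax(x) (softmax_out)."""
    g = np.asarray(grad_out, dtype=float)
    y = np.asarray(softmax_out, dtype=float)
    # <y, g>: the y-weighted average of the upstream gradient (a scalar).
    weighted_mean = np.dot(y, g)
    # Centre g by that scalar, then weight elementwise by y: O(V) time and memory.
    return y * (g - weighted_mean)

# Usage with the values from the hand-checked example.
dx = softmax_backward([1.0, 0.0, 0.0], [0.2, 0.3, 0.5])
print(np.allclose(dx, [0.16, -0.06, -0.10]))  # True
```

Only length-V vectors are allocated; no $(V, V)$ array appears anywhere.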

Examples

Example 1 — hand-checked

Input:  softmax_out = [0.2, 0.3, 0.5], grad_out = [1, 0, 0]
Output: [0.16, -0.06, -0.10]

Explanation: $\langle y, g\rangle = 0.2\cdot1 = 0.2$, so $dx = y\odot(g - 0.2) = [0.2\cdot0.8,\ 0.3\cdot(-0.2),\ 0.5\cdot(-0.2)] = [0.16, -0.06, -0.10]$. (This equals $(\operatorname{diag}(y) - yy^\top)\,g$, but computed in $O(V)$.)
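The equivalence with the explicit Jacobian can be confirmed numerically on this example. The sketch below builds the full matrix only because V = 3 here; it is meant as a diagnostic, not as something to do at vocabulary scale.

```python
import numpy as np

y = np.array([0.2, 0.3, 0.5])
g = np.array([1.0, 0.0, 0.0])

# Explicit route: build J = diag(y) - y y^T, an O(V^2) object.
J = np.diag(y) - np.outer(y, y)
dx_explicit = J @ g

# Closed-form route: never forms J.
dx_closed = y * (g - y @ g)

print(np.allclose(dx_explicit, dx_closed))  # True
```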

Example 2 — a constant gradient gives zero

Input:  softmax_out = softmax([1,2,3]), grad_out = [0.5, 0.5, 0.5]
Output: [0, 0, 0]

Explanation: for a constant $g$, $\langle y, g\rangle = 0.5$ (the $y$'s sum to 1), so $dx = y\odot(0.5 - 0.5) = 0$. This reflects softmax's shift-invariance: adding a constant to every logit doesn't change the output, so a uniform upstream gradient produces no input gradient.
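Both halves of that explanation can be checked in a few lines. The softmax helper below is an assumption added for this sketch (a standard max-subtracted implementation); the shift of 7.0 is arbitrary.

```python
import numpy as np

def softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    z = np.exp(x - np.max(x))
    return z / z.sum()

x = np.array([1.0, 2.0, 3.0])
y = softmax(x)
g = np.full(3, 0.5)          # constant upstream gradient

dx = y * (g - y @ g)
print(np.allclose(dx, 0.0))              # True
print(np.allclose(softmax(x + 7.0), y))  # True: shifting logits leaves y unchanged
```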

Constraints

  • Use the closed form y * (grad_out - dot(y, grad_out)); do not build the $(V, V)$ Jacobian. The cost is $O(V)$, not $O(V^2)$.
  • softmax_backward consumes the softmax output $y$ (the cached forward result), not the input logits.
  • A constant grad_out must give 0 (up to floating-point round-off); the analytic result must match both the explicit $J g$ and a finite-difference gradient (atol≈1e-3).
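A finite-difference comparison of the kind the last constraint describes might look like the sketch below. The choices here are assumptions for illustration: the scalar loss is taken to be $L = \langle g, \text{softmax}(x)\rangle$ so that $\partial L/\partial y = g$, central differences are used with step 1e-5, and the inputs are random.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

rng = np.random.default_rng(0)
x = rng.normal(size=6)
g = rng.normal(size=6)       # plays the role of grad_out

# Analytic gradient from the closed form.
y = softmax(x)
analytic = y * (g - y @ g)

# Central finite differences of L(x) = <g, softmax(x)>.
eps = 1e-5
numeric = np.zeros_like(x)
for i in range(x.size):
    step = np.zeros_like(x)
    step[i] = eps
    numeric[i] = (g @ softmax(x + step) - g @ softmax(x - step)) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-3))  # True
```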

Notes

  • Why it collapses. $\langle y, g\rangle$ is the $y$-weighted mean of the upstream gradient; subtracting it centres $g$, so $dx$ is unchanged when a constant is added to $g$. This is the algebraic shadow of softmax's input shift-invariance.
  • Related. Pairs with softmax from scratch; when softmax feeds cross-entropy, the two backwards combine into the much simpler softmax − one-hot.
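The softmax-plus-cross-entropy simplification mentioned in the notes can be seen by feeding the cross-entropy gradient through the closed form. In this sketch the loss is assumed to be $L = -\log y_t$ for a single target index $t$ (here `target = 2`, chosen arbitrarily), so $\partial L/\partial y$ is $-1/y_t$ at position $t$ and zero elsewhere.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - np.max(x))
    return z / z.sum()

x = np.array([1.0, 2.0, 3.0])
target = 2                     # hypothetical class index for illustration
y = softmax(x)

# Upstream gradient of L = -log(y[target]) with respect to y.
g = np.zeros_like(y)
g[target] = -1.0 / y[target]

# Generic softmax backward via the closed form.
dx = y * (g - y @ g)

# Combined shortcut: softmax minus one-hot.
one_hot = np.zeros_like(y)
one_hot[target] = 1.0
print(np.allclose(dx, y - one_hot))  # True
```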

This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.

  • Output shape matches input shape
  • Diagnostic: matches the explicit Jacobian-vector product
  • Matches finite-difference numerical gradient
  • Sum of gradients is ~0 (because softmax outputs sum to 1)
  • Constant grad_out gives zero gradient (softmax invariant to constant shift)
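The shape and sum-to-zero properties in that list are cheap to check locally. The snippet below is a rough local approximation of those two checks, not the hidden tests themselves; the random inputs and the 1e-12 tolerance are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
logits = rng.normal(size=8)
y = np.exp(logits - logits.max())
y /= y.sum()                   # y = softmax(logits)
g = rng.normal(size=8)

dx = y * (g - y @ g)

print(dx.shape == g.shape)     # True
# sum_i dx_i = <y, g> - <y, g> * sum_i y_i, which vanishes because sum_i y_i = 1.
print(abs(dx.sum()) < 1e-12)   # True
```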