GELU backward (tanh approximation) (Medium)


Background

GELU is the smooth activation that replaced ReLU in transformers (GPT, BERT). The forward pass (tanh approximation) is one line. The backward pass is a good product-rule and chain-rule workout: x enters the derivative in several places (directly, through g, and through g'), so it is easy to drop a term. Getting it right, and checking it against a numerical gradient, is exactly the skill you need to debug a real backprop bug.

Problem statement

Implement gelu_backward(grad_out, x), the derivative of GELU (tanh approximation) w.r.t. its input, times the upstream gradient. With

$$g(x) = \sqrt{\tfrac{2}{\pi}}\,\big(x + 0.044715\,x^3\big), \qquad g'(x) = \sqrt{\tfrac{2}{\pi}}\,\big(1 + 0.134145\,x^2\big)$$

the derivative of $\text{GELU}(x) = 0.5\,x\,(1 + \tanh g(x))$ is

$$\frac{d\,\text{GELU}}{dx} = 0.5\,(1 + \tanh g) + 0.5\,x\,(1 - \tanh^2 g)\,g'(x)$$

Return grad_out * d_GELU/dx. (The constant $0.134145 = 3\times0.044715$ comes from differentiating the $x^3$ term.)
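A minimal NumPy sketch of the formula above, assuming grad_out and x are same-shape arrays. Inputs go through np.asarray(..., dtype=float), so float32 inputs come back as float64 in this version. It evaluates np.tanh once and reuses the result in both terms.

```python
import numpy as np

SQRT_2_OVER_PI = np.sqrt(2.0 / np.pi)


def gelu_backward(grad_out, x):
    """d GELU(x)/dx for the tanh approximation, scaled elementwise by grad_out."""
    x = np.asarray(x, dtype=float)
    grad_out = np.asarray(grad_out, dtype=float)
    g = SQRT_2_OVER_PI * (x + 0.044715 * x**3)          # g(x)
    g_prime = SQRT_2_OVER_PI * (1.0 + 0.134145 * x**2)  # g'(x); 0.134145 = 3 * 0.044715
    t = np.tanh(g)                                      # evaluated once, used in both terms
    local = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * g_prime
    return grad_out * local


print(gelu_backward(np.array([1.0]), np.array([0.0])))  # [0.5]
```

The first term, $0.5(1+\tanh g)$, comes from differentiating the bare factor $x$; the second is the chain-rule term through $\tanh g$.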

Input

  • grad_out (np.ndarray): the upstream gradient w.r.t. the GELU output (same shape as x).
  • x (np.ndarray): the input that was fed to GELU (not the output).

Output

Returns an np.ndarray of the same shape — the gradient w.r.t. x.

Examples

Example 1: the slope at $x=0$

Input:  grad_out = [1.0], x = [0.0]
Output: [0.5]

Explanation: $g(0)=0$, so $\tanh g = 0$; the second term carries a factor of $x=0$, so the derivative is $0.5(1+0) + 0 = 0.5$.

Example 2: the asymptotic slopes

Input:  x = [10, 100, 1000],  grad_out = [1,1,1]   -> ≈ [1, 1, 1]
        x = [-10, -100],      grad_out = [1,1]     -> ≈ [0, 0]

Explanation: as $x\to+\infty$, $\tanh g\to 1$ and the second term decays (the factor $1-\tanh^2 g$ shrinks exponentially in $g$, far faster than $x\,g'(x)$ grows), so the slope $\to 1$ (GELU $\approx$ identity for large positive $x$). As $x\to-\infty$, $\tanh g\to -1$ and both terms collapse, so the slope $\to 0$ (GELU saturates to 0, like ReLU's left side but smooth).
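To see the two limits numerically, the sketch below evaluates the two terms of the derivative separately. The helper name gelu_slope_terms is hypothetical; it simply splits the formula from the problem statement into its product-rule term and its chain-rule term.

```python
import numpy as np

C = np.sqrt(2.0 / np.pi)


def gelu_slope_terms(x):
    """Hypothetical helper: return the two terms of dGELU/dx separately."""
    x = np.asarray(x, dtype=float)
    t = np.tanh(C * (x + 0.044715 * x**3))
    first = 0.5 * (1.0 + t)                                        # 0.5 (1 + tanh g)
    second = 0.5 * x * (1.0 - t**2) * C * (1.0 + 0.134145 * x**2)  # 0.5 x (1 - tanh^2 g) g'
    return first, second


for x in [0.0, 3.0, 10.0, -3.0, -10.0]:
    first, second = gelu_slope_terms(x)
    print(f"x={x:6.1f}  first={float(first):.6f}  second={float(second):+.2e}  "
          f"slope={float(first + second):.6f}")
```

For large $|x|$, np.tanh returns exactly $\pm 1$ in float64, so $1-\tanh^2 g$ underflows to 0 and the second term vanishes regardless of how large $x\,g'(x)$ is.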

Constraints

  • Use the product rule on $0.5\,x\,(1+\tanh g)$ plus the chain rule on $\tanh$: $\frac{d}{dx}\tanh g = (1-\tanh^2 g)\,g'(x)$.
  • gelu_backward takes the input x, not the GELU output.
  • Multiply the local derivative by grad_out (it scales the result elementwise).
  • The analytic gradient must match a finite-difference numerical gradient within atol≈1e-3 — the test that catches any sign or missing-term error.
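The hidden test's exact procedure is not shown on this page. The sketch below assumes a central-difference check with eps = 1e-5 and the atol ≈ 1e-3 stated above; the helper numerical_grad and the input grid are hypothetical. It also shows that a version missing the chain-rule term fails the same check.

```python
import numpy as np

C = np.sqrt(2.0 / np.pi)


def gelu(x):
    # Forward: 0.5 * x * (1 + tanh(g(x)))
    return 0.5 * x * (1.0 + np.tanh(C * (x + 0.044715 * x**3)))


def gelu_backward(grad_out, x):
    t = np.tanh(C * (x + 0.044715 * x**3))
    g_prime = C * (1.0 + 0.134145 * x**2)
    return grad_out * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * g_prime)


def numerical_grad(f, grad_out, x, eps=1e-5):
    # Central difference. GELU acts on each entry independently, so shifting
    # the whole array by eps yields every partial derivative at once.
    return grad_out * (f(x + eps) - f(x - eps)) / (2.0 * eps)


x = np.linspace(-4.0, 4.0, 17)
grad_out = np.linspace(0.5, 2.0, 17)

analytic = gelu_backward(grad_out, x)
numeric = numerical_grad(gelu, grad_out, x)
print("max abs diff:", np.max(np.abs(analytic - numeric)))
print("matches:", np.allclose(analytic, numeric, atol=1e-3))  # matches: True

# Buggy version: drops the term 0.5 * x * (1 - tanh^2 g) * g'.
buggy = grad_out * 0.5 * (1.0 + np.tanh(C * (x + 0.044715 * x**3)))
print("buggy matches:", np.allclose(buggy, numeric, atol=1e-3))  # buggy matches: False
```

The per-element shortcut in numerical_grad is only valid because GELU is elementwise; for a layer that mixes entries you would perturb one entry at a time.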

Notes

  • Compute tanh(g(x)) once. np.tanh is the expensive call, and you need it for both $(1+\tanh g)$ and $(1-\tanh^2 g)$; reuse it rather than recomputing.
  • Pairs with the forward. This is the backward pass for the GELU forward problem; together they form the activation inside the MLP block of GPT- and BERT-style transformers.
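The two notes above combine naturally if the forward pass caches tanh(g(x)) for the backward pass. A minimal sketch, assuming a stateful layer object: the class GeluTanh and its forward/backward methods are hypothetical and are not the interface this problem asks for (gelu_backward(grad_out, x) receives x and must recompute tanh itself).

```python
import numpy as np


class GeluTanh:
    """Hypothetical GELU layer: forward caches x and tanh(g(x)) for backward."""

    C = np.sqrt(2.0 / np.pi)

    def forward(self, x):
        x = np.asarray(x, dtype=float)
        t = np.tanh(self.C * (x + 0.044715 * x**3))
        self._x, self._t = x, t  # cached so backward does not call np.tanh again
        return 0.5 * x * (1.0 + t)

    def backward(self, grad_out):
        x, t = self._x, self._t
        g_prime = self.C * (1.0 + 0.134145 * x**2)
        return grad_out * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * g_prime)


layer = GeluTanh()
y = layer.forward(np.array([-1.0, 0.0, 1.0]))
dx = layer.backward(np.ones(3))
print(y, dx)
```

Caching saves one np.tanh per element in the backward pass, at the cost of keeping an extra array alive between the two calls.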

This problem ships 6 hidden tests. They run in your browser via Pyodide (no backend, no submission queue):

  • Output shape matches input
  • d_GELU/dx at x=0 is 0.5 (sanity)
  • Diagnostic: matches finite-difference numerical gradient
  • grad_out is broadcast (multiplied) into the result
  • Large positive x: d_GELU/dx -> 1 (looks like x for large x, slope -> 1)
  • Large negative x: d_GELU/dx -> 0 (saturates to 0)