Softmax from scratch

Difficulty: Easy

Background

Softmax turns a vector of arbitrary real numbers into a probability distribution. It is the usual final layer of a multi-class classifier, the normalisation applied to the scores in every attention head, and the operation that turns logits into next-token probabilities. Implementing it well is mostly about one thing: numerical stability. The naïve formula overflows on large logits (in float64, exp overflows once its argument exceeds roughly 709, and lower-precision formats overflow much sooner), and the fix is a one-line trick everyone is expected to know.

Problem statement

Implement softmax(x) for a 1-D array. The definition, and the shift trick that makes it stable:

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} \quad (\text{for any constant } c)$$

Pick $c = \max(x)$ so that the largest term is $e^0 = 1$ and nothing overflows. The output must be all-positive and sum to 1.
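The shift trick can be sketched in a few lines of NumPy. This is one possible implementation rather than the reference solution; the cast to float64 is an assumption made here so that integer inputs also work.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Max-shifted softmax for a 1-D array (illustrative sketch)."""
    x = np.asarray(x, dtype=np.float64)   # assumption: compute in float64
    shifted = x - np.max(x)               # largest entry becomes 0
    exps = np.exp(shifted)                # every term is at most 1, and one term is exactly 1
    return exps / np.sum(exps)            # denominator is at least 1, so never zero

probs = softmax(np.array([1.0, 2.0, 3.0]))
print(probs)         # approximately [0.0900, 0.2447, 0.6652]
print(probs.sum())
```

Because the maximum is subtracted first, the denominator always contains a term equal to 1, which rules out both overflow and division by zero.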

Input

  • x — 1-D np.ndarray of arbitrary real values.

Output

Returns a 1-D np.ndarray of the same shape, every element in $(0, 1)$, summing to 1.

Examples

Example 1 — a basic vector

Input:  x = [1.0, 2.0, 3.0]
Output: ≈ [0.0900, 0.2447, 0.6652]

Explanation: exponentiate and normalise. The gap of 1 between consecutive logits gives a ratio of $e \approx 2.718$ between consecutive probabilities, so the largest logit dominates while all three stay positive and sum to 1.
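The ratio claim is easy to check numerically. The snippet below is a small sketch that computes the probabilities for Example 1 inline and divides each one by its predecessor; the variable names are illustrative only.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
exps = np.exp(x - x.max())
probs = exps / exps.sum()

# p[i+1] / p[i] = exp(x[i+1] - x[i]); here every gap is 1, so every ratio is e.
ratios = probs[1:] / probs[:-1]
print(ratios)   # both approximately 2.718
```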

Example 2 — large logits, no overflow

Input:  x = [1000.0, 1001.0, 1002.0]
Output: ≈ [0.0900, 0.2447, 0.6652]   (identical to softmax([0,1,2]); no NaN/inf)

Explanation: naïve exp(1000) overflows to inf. Subtracting $\max(x) = 1002$ first turns the exponentials into $[e^{-2}, e^{-1}, e^{0}]$, which are finite and identical to those of softmax([0,1,2]) because softmax is invariant to a constant shift.
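To see the failure and the fix side by side, the following sketch evaluates both formulas on the Example 2 input, assuming float64 arithmetic. The np.errstate context only silences NumPy's overflow and invalid-value warnings; it does not change the result.

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive formula: exp(1000) exceeds the float64 range, so every term is inf
# and inf / inf evaluates to NaN.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.sum(np.exp(x))

# Shifted formula: the exponents become [-2, -1, 0], all safely representable.
exps = np.exp(x - np.max(x))
stable = exps / np.sum(exps)

print(naive)    # every entry is NaN
print(stable)   # approximately [0.0900, 0.2447, 0.6652]
```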

Constraints

  • Every output is in $(0, 1)$ and the array sums to 1 (atol≈1e-8); equal logits give the uniform distribution $1/n$.
  • Subtract max(x) before exp — otherwise inputs in the thousands overflow to inf (and produce NaN after the division).
  • Must stay finite for both large-positive ([1000,1001,1002]) and very-negative ([-1000,-1001,-999]) inputs.
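These constraints can be checked locally with a short script. The sketch below assumes a max-shifted softmax and uses a hypothetical helper, meets_constraints, that is not part of the problem's test harness. It also shows that very negative logits break the unshifted formula in a different way: every exponential underflows to 0, and 0/0 gives NaN.

```python
import numpy as np

def softmax(x):
    """Max-shifted softmax, assumed here as the implementation under test."""
    x = np.asarray(x, dtype=np.float64)
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

def meets_constraints(x):
    """Hypothetical checker: finite, strictly inside (0, 1), sums to 1."""
    p = softmax(x)
    return bool(
        np.all(np.isfinite(p))
        and np.all((p > 0.0) & (p < 1.0))
        and np.isclose(p.sum(), 1.0, rtol=0.0, atol=1e-8)
    )

print(meets_constraints([1000.0, 1001.0, 1002.0]))     # large positive logits
print(meets_constraints([-1000.0, -1001.0, -999.0]))   # very negative logits
print(softmax([5.0, 5.0, 5.0, 5.0]))                   # equal logits: each entry is 1/4

# Unshifted formula on very negative logits: exp underflows to 0, then 0 / 0.
x = np.array([-1000.0, -1001.0, -999.0])
with np.errstate(invalid="ignore"):
    naive = np.exp(x) / np.sum(np.exp(x))
print(naive)   # every entry is NaN
```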

Notes

  • Why the shift is exact. softmax(x) == softmax(x - c) because the factor $e^{-c}$ cancels between numerator and denominator, so the stabilising subtraction changes nothing mathematically.
  • Reused everywhere. The same max-shift appears in the cross-entropy gradient, attention, and label smoothing; the derivative of softmax itself is the subject of the softmax backward problem.
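Written out, the cancellation behind the shift invariance takes one line. It uses the same symbols as the definition in the problem statement and holds for any real constant $c$:

```latex
\frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \text{softmax}(x)_i
```

The identity is exact in real arithmetic. In floating point the two sides can differ by rounding error, which is why the sum-to-1 constraint is stated with a tolerance.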

This problem ships 6 hidden tests, which run in the browser via Pyodide:

  • Basic 3-element vector matches reference
  • Output sums to 1
  • All outputs in (0, 1)
  • Uniform distribution for equal logits
  • Numerically stable for large logits (uses max-shift)
  • Numerically stable for very negative logits