Wasserstein distance (1-D) (Difficulty: Medium)

Background

The Wasserstein-1 distance (a.k.a. Earth Mover's Distance) measures how much "work" it takes to morph one distribution into another: the amount of probability mass moved times the distance it travels. Unlike KL divergence, which becomes infinite when one distribution puts mass where the other has none, it stays finite and meaningful even when the distributions don't overlap. That is why it underpins WGANs and is a robust way to compare empirical distributions. In 1-D it has a beautiful closed form: the area between the two cumulative distribution functions.
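To make the non-overlap claim concrete, here is a minimal sketch comparing a discrete KL divergence with W1 for two point masses, one at 0 and one at d. The helpers `kl_divergence` and `w1_equal_size` are hypothetical names introduced for illustration, and representing the point masses on an integer grid is an assumption. KL comes out infinite for every d, while W1 equals d and so still reports how far apart the distributions are.

```python
import numpy as np

def kl_divergence(p, q):
    """Hypothetical helper: discrete KL(p || q) on a shared grid.

    Terms with p = 0 contribute 0; a term with p > 0 and q = 0 gives inf.
    """
    p, q = np.asarray(p, float), np.asarray(q, float)
    mask = p > 0
    with np.errstate(divide="ignore"):  # p / 0 -> inf is the intended result
        return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def w1_equal_size(u, v):
    """Hypothetical helper: W1 for equal-size samples, mean |sort(u) - sort(v)|."""
    return float(np.mean(np.abs(np.sort(u) - np.sort(v))))

# Point mass at 0 versus point mass at d, on the grid 0..10.
for d in (1, 5, 10):
    p = np.zeros(11); p[0] = 1.0
    q = np.zeros(11); q[d] = 1.0
    print(d, kl_divergence(p, q), w1_equal_size([0.0], [float(d)]))
# prints: 1 inf 1.0 / 5 inf 5.0 / 10 inf 10.0
```

KL cannot distinguish a shift of 1 from a shift of 10 here, whereas W1 grows linearly with the shift.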

Problem statement

Implement wasserstein_1d(u, v) for two 1-D sample sets:

$$W_1(u, v) = \int_{-\infty}^{\infty} \big|F_u(x) - F_v(x)\big|\, dx$$

Compute it by merging and sorting all values, taking the gaps $\Delta x$ between consecutive values, and summing $|F_u - F_v|\,\Delta x$ over those intervals. Both empirical CDFs are evaluated with searchsorted at the left endpoint of each interval, where they are constant until the next support point. This handles unequal sample sizes.
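A minimal NumPy sketch of the merged-support procedure, assuming both inputs are non-empty and contain finite values. It is one straightforward way to follow the constraints, not a reference solution; the intermediate variable names are illustrative.

```python
import numpy as np

def wasserstein_1d(u, v):
    """W1 between two 1-D empirical distributions (sizes may differ).

    Assumes u and v are non-empty arrays of finite samples.
    """
    u_sorted = np.sort(np.asarray(u, dtype=float))
    v_sorted = np.sort(np.asarray(v, dtype=float))

    # Merged, sorted support of both sample sets.
    support = np.sort(np.concatenate([u_sorted, v_sorted]))
    # Width of each interval [support[i], support[i + 1]).
    dx = np.diff(support)

    # Empirical CDFs at the left endpoint of each interval.
    # side="right" counts samples <= x, so each CDF holds that value
    # over the whole interval.
    x = support[:-1]
    cdf_u = np.searchsorted(u_sorted, x, side="right") / u_sorted.size
    cdf_v = np.searchsorted(v_sorted, x, side="right") / v_sorted.size

    # Area between the two step CDFs.
    return float(np.sum(np.abs(cdf_u - cdf_v) * dx))

print(round(wasserstein_1d([1, 2, 3], [4, 5, 6]), 10))  # 3.0
```

If SciPy is available, `scipy.stats.wasserstein_distance(u, v)` computes the same 1-D quantity and can serve as a cross-check.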

Input

  • u, v — 1-D np.ndarray of samples (sizes may differ).

Output

A float: the 1-D Wasserstein-1 distance.

Examples

Example 1

Input:  u = [1, 2, 3], v = [4, 5, 6]
Output: 3.0

Explanation: v is u shifted right by 3, so every point of u travels exactly 3 units to land on its counterpart in v; the distance is 3.
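The same answer falls out of the merged-support sum. Working Example 1 by hand, with $x_1, \dots, x_6$ the merged sorted support and each CDF evaluated at the left endpoint of its interval:

```latex
\begin{aligned}
x &= (1, 2, 3, 4, 5, 6), \qquad \Delta x_i = x_{i+1} - x_i = 1 \quad (i = 1, \dots, 5) \\
F_u(x_i) &= \left(\tfrac13, \tfrac23, 1, 1, 1\right), \qquad
F_v(x_i) = \left(0, 0, 0, \tfrac13, \tfrac23\right) \\
W_1(u, v) &= \sum_{i=1}^{5} \big|F_u(x_i) - F_v(x_i)\big|\,\Delta x_i
= \tfrac13 + \tfrac23 + 1 + \tfrac23 + \tfrac13 = 3
\end{aligned}
```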

Constraints

  • Build the merged sorted support; use $\Delta x$ between consecutive support points.
  • Empirical CDFs: searchsorted(sort(u), x, 'right') / len(u) (and likewise for v).
  • Sum $|F_u - F_v| \cdot \Delta x$; this equals the area between the CDFs.
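Why `'right'`: a small sketch showing how the `side` argument of `np.searchsorted` changes the empirical CDF when x coincides with a sample (including a tied value). The sample values are made up for illustration.

```python
import numpy as np

u_sorted = np.sort(np.array([2.0, 1.0, 3.0, 2.0]))  # [1, 2, 2, 3]
x = np.array([0.5, 1.0, 2.0, 3.0])

# side="right": number of samples <= x, i.e. the standard CDF P(U <= x).
cdf_right = np.searchsorted(u_sorted, x, side="right") / u_sorted.size
# side="left": number of samples < x, i.e. P(U < x), which misses the jump at x.
cdf_left = np.searchsorted(u_sorted, x, side="left") / u_sorted.size

print(cdf_right.tolist())  # [0.0, 0.25, 0.75, 1.0]
print(cdf_left.tolist())   # [0.0, 0.0, 0.25, 0.75]
```

In the merged-support sum each CDF is read at the left endpoint $x_i$ of an interval, so the value must already include the jump at $x_i$; `side="right"` provides that, while `side="left"` would lag by one step.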

Notes

  • For equal-size samples this reduces to the mean absolute difference of the sorted values, mean(|sort(u) - sort(v)|).
  • Wasserstein gives nonzero, smoothly varying gradients between disjoint distributions, which is a key reason WGAN training tends to be more stable than the original GAN.
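A quick numerical check of the equal-size note, assuming NumPy's `default_rng` for synthetic Gaussian data: the merged-support sum and the sorted-difference shortcut agree. Both helper names are hypothetical.

```python
import numpy as np

def w1_merged_support(u, v):
    """Hypothetical helper: area between empirical CDFs over the merged support."""
    u_s, v_s = np.sort(u), np.sort(v)
    support = np.sort(np.concatenate([u_s, v_s]))
    x, dx = support[:-1], np.diff(support)
    cdf_u = np.searchsorted(u_s, x, side="right") / u_s.size
    cdf_v = np.searchsorted(v_s, x, side="right") / v_s.size
    return float(np.sum(np.abs(cdf_u - cdf_v) * dx))

def w1_sorted_mean(u, v):
    """Hypothetical helper: equal-size shortcut, mean |sort(u) - sort(v)|."""
    assert len(u) == len(v), "shortcut only valid for equal sample sizes"
    return float(np.mean(np.abs(np.sort(u) - np.sort(v))))

rng = np.random.default_rng(0)
u = rng.normal(0.0, 1.0, size=50)
v = rng.normal(1.5, 2.0, size=50)
print(np.isclose(w1_merged_support(u, v), w1_sorted_mean(u, v)))  # True
```

The shortcut works because, with n samples each, the optimal transport plan in 1-D pairs the i-th smallest value of u with the i-th smallest value of v, each pair carrying mass 1/n.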

This problem ships 5 hidden tests:

  • Constant shift of 3
  • Identical samples give 0
  • Equal-size case equals mean absolute sorted difference
  • Symmetric in its arguments
  • Matches a hand-computed unequal-size value
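The hidden tests themselves are not shown. The following is a hypothetical set of assertions mirroring the five test names, run against a compact sketch of `wasserstein_1d`. The random inputs and the hand-computed unequal-size case (u = [0, 2], v = [1], where each half-unit of mass moves distance 1) are chosen for illustration and are not the actual test data.

```python
import numpy as np

def wasserstein_1d(u, v):
    """Compact sketch: area between empirical CDFs over the merged support."""
    u_s, v_s = np.sort(np.asarray(u, float)), np.sort(np.asarray(v, float))
    support = np.sort(np.concatenate([u_s, v_s]))
    x, dx = support[:-1], np.diff(support)
    cdf_u = np.searchsorted(u_s, x, side="right") / u_s.size
    cdf_v = np.searchsorted(v_s, x, side="right") / v_s.size
    return float(np.sum(np.abs(cdf_u - cdf_v) * dx))

rng = np.random.default_rng(1)
a = rng.uniform(-3, 3, size=20)
b = rng.uniform(0, 5, size=20)
c = rng.uniform(0, 5, size=7)

# Constant shift of 3
assert np.isclose(wasserstein_1d([1, 2, 3], [4, 5, 6]), 3.0)
# Identical samples give 0
assert wasserstein_1d(a, a) == 0.0
# Equal-size case equals mean absolute sorted difference
assert np.isclose(wasserstein_1d(a, b), np.mean(np.abs(np.sort(a) - np.sort(b))))
# Symmetric in its arguments
assert np.isclose(wasserstein_1d(a, c), wasserstein_1d(c, a))
# Unequal sizes, by hand: mass 1/2 at 0 and 1/2 at 2 each move 1 unit to reach 1
assert np.isclose(wasserstein_1d([0, 2], [1]), 1.0)
print("all checks passed")
```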