Exponential moving average of weights

Difficulty: Easy

Background

Keeping an exponential moving average (EMA) of model weights is a cheap trick that often improves final accuracy: alongside the weights being trained, you maintain a smoothed "shadow" copy that lags behind and averages out the noise of individual SGD steps. At evaluation time you use the EMA weights, which tend to sit closer to the center of the loss basin. It is standard in diffusion models, self-supervised and semi-supervised learning (e.g. BYOL, Mean Teacher), and many state-of-the-art training recipes.

Problem statement

Implement ema_update(ema, weights, decay=0.99) that performs one EMA step:

$$\text{ema} \leftarrow \beta \cdot \text{ema} + (1-\beta)\cdot \text{weights}$$

where $\beta$ is decay. The function must work on scalars and on np.ndarray inputs (elementwise).
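
A minimal sketch of one way to write this step, assuming NumPy is available and that ema and weights have matching shapes. The np.asarray conversion and the scalar unwrapping at the end are choices of this sketch, not requirements stated by the problem.

```python
import numpy as np

def ema_update(ema, weights, decay=0.99):
    """One EMA step: decay * ema + (1 - decay) * weights, elementwise."""
    # Assumption of this sketch: coerce inputs to float arrays so that
    # Python scalars, lists, and ndarrays all go through the same code path.
    ema = np.asarray(ema, dtype=float)
    weights = np.asarray(weights, dtype=float)
    out = decay * ema + (1.0 - decay) * weights
    # Hand back a plain float for scalar inputs, an ndarray otherwise.
    return float(out) if out.ndim == 0 else out
```

The function returns a new value rather than modifying ema in place, so the caller decides whether to overwrite the shadow copy.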

Input

  • ema — current shadow value(s) (scalar or np.ndarray).
  • weights — current model weight value(s), same shape as ema.
  • decay: float $\beta \in [0, 1]$; higher means slower, smoother tracking.

Output

The updated EMA value(s), same shape as the inputs.

Examples

Example 1

Input:  ema = 10.0, weights = 20.0, decay = 0.9
Output: 11.0

Explanation: $0.9 \cdot 10 + 0.1 \cdot 20 = 9 + 2 = 11$. The shadow weight moves 10% of the way toward the current weight.
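
The "moves 10% of the way" reading follows from rewriting the update as ema + (1 - decay) * (weights - ema), which is algebraically the same expression. The sketch below checks that numerically; ema_step_lerp is a hypothetical helper name used only for illustration, and the array values are made up.

```python
import numpy as np

def ema_step_lerp(ema, weights, decay):
    # Hypothetical helper: step a fraction (1 - decay) of the gap
    # from the old EMA toward the current weights.
    return ema + (1.0 - decay) * (weights - ema)

# Example 1: the gap is 20 - 10 = 10, and 10% of it is added to 10.
scalar_result = ema_step_lerp(10.0, 20.0, 0.9)

# The same update applied elementwise to (made-up) arrays.
ema = np.array([10.0, 0.0, -4.0])
weights = np.array([20.0, 1.0, 4.0])
array_result = ema_step_lerp(ema, weights, 0.9)

# Compare against the convex-combination form from the problem statement.
convex_form = 0.9 * ema + (1.0 - 0.9) * weights
print(np.isclose(scalar_result, 11.0))          # True
print(np.allclose(array_result, convex_form))   # True
```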

Constraints

  • The update is a convex combination: weight $\beta$ on the old EMA, $1-\beta$ on the new weights.
  • Support elementwise application over arrays.

Notes

  • A higher decay (e.g. 0.999) gives a slower, smoother average — it remembers more history; a lower decay tracks the live weights more closely.
  • The EMA has an effective window of roughly $1/(1-\beta)$ steps, so $\beta=0.99$ averages over the last ~100 updates.
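
One way to see where the $1/(1-\beta)$ rule of thumb comes from: unrolling the recursion, the weights from $k$ steps ago enter the current EMA with coefficient $(1-\beta)\beta^k$. The sketch below, which assumes NumPy and uses variable names chosen for illustration, sums those coefficients over the most recent $1/(1-\beta)$ steps.

```python
import numpy as np

decay = 0.99
window = int(round(1.0 / (1.0 - decay)))  # rule-of-thumb window, 100 here

# Coefficient on the weights seen k steps ago, for k = 0 .. window - 1.
k = np.arange(window)
coeffs = (1.0 - decay) * decay**k

# Share of the total averaging mass held by the last `window` steps.
# In closed form this is 1 - decay**window, about 0.63 for decay = 0.99.
mass_in_window = coeffs.sum()
print(window, mass_in_window)
```

So the window is a soft one: roughly 63% of the average comes from the last ~100 updates, and older updates still contribute with geometrically shrinking weight.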

This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.

  • Reference example
  • decay = 1 keeps the EMA unchanged
  • decay = 0 snaps EMA to the current weights
  • Works elementwise on arrays
  • Converges toward a constant weight over many steps
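
The hidden tests themselves are not shown, so the following is only a guess at what checks with these names might look like. The specific inputs, tolerances, and the compact ema_update stand-in are all assumptions of this sketch.

```python
import numpy as np

def ema_update(ema, weights, decay=0.99):
    # Compact stand-in for a solution; not the site's reference implementation.
    return decay * np.asarray(ema, dtype=float) + (1.0 - decay) * np.asarray(weights, dtype=float)

# Reference example
assert np.isclose(ema_update(10.0, 20.0, 0.9), 11.0)

# decay = 1 keeps the EMA unchanged
assert np.isclose(ema_update(3.0, 7.0, 1.0), 3.0)

# decay = 0 snaps the EMA to the current weights
assert np.isclose(ema_update(3.0, 7.0, 0.0), 7.0)

# Works elementwise on arrays
out = ema_update(np.array([0.0, 10.0]), np.array([10.0, 0.0]), 0.5)
assert np.allclose(out, [5.0, 5.0])

# Converges toward a constant weight over many steps:
# the gap shrinks by a factor of decay on every update.
ema = 0.0
for _ in range(1000):
    ema = ema_update(ema, 5.0, 0.99)
assert abs(ema - 5.0) < 1e-3
```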