Warmup + cosine decay LR schedule (Easy)

Background

This is the standard learning-rate schedule for modern transformer training: a linear warmup over the first few hundred to few thousand steps, then a smooth cosine decay down to a floor, then a flat tail. GPT-3, Llama, and many public LLM training runs use a variant of it. It is a pure function of the step: no state, no side effects, just the LR to use right now.

Problem statement

Implement lr_at(step, lr_max, lr_min, warmup_iters, max_iters) with three regimes:

$$
\text{lr}(s) =
\begin{cases}
\text{lr}_{\max}\cdot \dfrac{s}{\text{warmup}} & s < \text{warmup} \\[2mm]
\text{lr}_{\min} + \tfrac12(\text{lr}_{\max}-\text{lr}_{\min})\big(1 + \cos(\pi\,p)\big) & \text{warmup} \le s < \text{max} \\[2mm]
\text{lr}_{\min} & s \ge \text{max}
\end{cases}
\qquad
p = \frac{s - \text{warmup}}{\text{max} - \text{warmup}}
$$

At $s = \text{warmup}$ both branches equal $\text{lr}_{\max}$; at $s = \text{max}$ the cosine term gives $\cos\pi = -1$, so the LR is exactly $\text{lr}_{\min}$.
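
The three regimes translate almost line for line into code. The following is a minimal sketch of one possible implementation, not a reference solution; the order of the checks is a choice made here so that the division in each branch is only reached when its denominator is nonzero (for example when warmup_iters is 0).

```python
import math


def lr_at(step, lr_max, lr_min, warmup_iters, max_iters):
    """Learning rate at `step`: linear warmup, cosine decay, flat tail."""
    # Regime 1: linear ramp from 0 towards lr_max over [0, warmup_iters).
    # Never entered when warmup_iters == 0, so the division is safe.
    if step < warmup_iters:
        return lr_max * step / warmup_iters
    # Regime 3: flat tail at lr_min from max_iters onwards.
    # Checked before the cosine branch so max_iters == warmup_iters is safe.
    if step >= max_iters:
        return lr_min
    # Regime 2: cosine decay; p is 0 at warmup_iters and approaches 1 at max_iters.
    p = (step - warmup_iters) / (max_iters - warmup_iters)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * p))
```

With lr_max=3e-4, lr_min=3e-5, warmup_iters=100 and max_iters=1000, this returns 0.0 at step 0 and lr_min at steps 1000 and 5000. At step 100 the result equals lr_max only up to floating-point rounding, so a tolerance-based comparison such as math.isclose is the safer check there.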

Input

  • step: int ≥ 0, the current training step.
  • lr_max: peak LR, reached at step == warmup_iters.
  • lr_min: floor LR, reached at step == max_iters.
  • warmup_iters: int ≥ 0, duration of the linear warmup.
  • max_iters: int, total training duration; from this step onwards, returns lr_min.

Output

Returns a float: the learning rate at this step.

Examples

Example 1 — the boundary points (lr_max=3e-4, lr_min=3e-5, warmup=100, max=1000)

step = 0      -> 0.0      # start of warmup
step = 100    -> 3e-4     # end of warmup == lr_max
step = 1000   -> 3e-5     # end of decay == lr_min
step = 5000   -> 3e-5     # flat tail past max_iters

Explanation: a linear ramp from 0 to lr_max over [0, warmup), a cosine decay from lr_max to lr_min over [warmup, max), then a constant lr_min.

Example 2 — the smooth midpoints

step = 50   (warmup/2)                  -> 0.5 * lr_max          (linear half-way)
step = 550  (warmup + (max-warmup)/2)   -> (lr_max + lr_min)/2   (cos(pi/2)=0)

Explanation: halfway through warmup the linear ramp is at half of lr_max; halfway through decay the cosine argument is $\pi/2$, so $\cos(\pi/2) = 0$ and the LR sits exactly between lr_max and lr_min.
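
As a worked check, the two midpoints can be evaluated by hand. This assumes the parameters of Example 1 (lr_max=3e-4, lr_min=3e-5, warmup=100, max=1000), which is what steps 50 and 550 imply:

$$
\text{lr}(50) = 3\times10^{-4}\cdot\frac{50}{100} = 1.5\times10^{-4}
$$

$$
p = \frac{550 - 100}{1000 - 100} = 0.5,
\qquad
\text{lr}(550) = 3\times10^{-5} + \tfrac12\big(3\times10^{-4} - 3\times10^{-5}\big)\big(1 + \cos(\pi/2)\big)
= 3\times10^{-5} + 1.35\times10^{-4}
= 1.65\times10^{-4}
$$

The second value matches $(\text{lr}_{\max} + \text{lr}_{\min})/2 = 1.65\times10^{-4}$.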

Constraints

  • Three branches: linear warmup, cosine decay, then flat at lr_min.
  • $p = (\text{step} - \text{warmup}) / (\text{max} - \text{warmup})$, so $p=0$ gives lr_max and $p=1$ gives lr_min.
  • The schedule is continuous at both boundaries (no jump at step==warmup, no cliff at step==max).
  • Pure function — same inputs always give the same output, with no side effects.
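
The continuity constraint can be checked numerically by evaluating each branch formula on its own at the boundary steps. The sketch below is illustrative only: the helper names warmup_branch and cosine_branch are made up for this check, and the parameter values are borrowed from Example 1.

```python
import math

# Assumed values, taken from Example 1.
lr_max, lr_min, warmup, max_iters = 3e-4, 3e-5, 100, 1000


def warmup_branch(s):
    # Linear ramp formula, evaluated even outside its own regime.
    return lr_max * s / warmup


def cosine_branch(s):
    # Cosine decay formula, evaluated even outside its own regime.
    p = (s - warmup) / (max_iters - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * p))


# Gap between the ramp and the cosine where they meet (step == warmup).
jump_at_warmup = abs(warmup_branch(warmup) - cosine_branch(warmup))
# Gap between the cosine and the flat tail where they meet (step == max_iters).
jump_at_max = abs(cosine_branch(max_iters) - lr_min)

print(jump_at_warmup, jump_at_max)
```

Both gaps should be zero up to floating-point rounding, which is what "no jump" and "no cliff" mean in practice.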

Notes

  • Why warmup. Adam's second-moment estimate is noisy in the first few hundred steps; hitting fragile freshly-initialised weights with lr_max can blow up training. Warmup ramps the LR while the optimiser stabilises.
  • Why cosine + flat tail. The cosine has zero slope at both ends of the decay, so the LR leaves lr_max gently and meets the flat lr_min tail without a kink; the tail lets you train past max_iters without an abrupt LR cliff.
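
The slope behaviour at the two boundaries follows from differentiating the cosine branch with respect to the step $s$, treating $s$ as continuous:

$$
\frac{d\,\text{lr}}{ds}
= -\frac{\pi\,(\text{lr}_{\max} - \text{lr}_{\min})}{2\,(\text{max} - \text{warmup})}\,\sin(\pi\,p),
\qquad
p = \frac{s - \text{warmup}}{\text{max} - \text{warmup}}
$$

Since $\sin 0 = \sin\pi = 0$, the slope vanishes at $p=0$ and $p=1$. At $s = \text{max}$ this matches the zero slope of the flat tail. At $s = \text{warmup}$ the ramp arrives with slope $\text{lr}_{\max}/\text{warmup}$, so the schedule is continuous there but its slope changes.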

This problem ships 6 hidden tests, which run in the browser via Pyodide:

  • Step 0 (start of warmup): lr is exactly 0
  • End of warmup: lr is exactly lr_max
  • End of training: lr is exactly lr_min
  • Past max_iters: lr stays at lr_min (flat tail)
  • Diagnostic: midway through cosine decay, lr is at the half-amplitude point
  • Linear warmup is, well, linear