AdamW step (decoupled weight decay) · Medium

Background

AdamW is Adam with decoupled weight decay — the optimiser behind GPT-2, GPT-3, Llama, and essentially every modern transformer training run. It computes the same adaptive moment update as Adam, then applies an L2-style shrinkage term directly to the parameters instead of folding it into the gradient. That one change (Loshchilov & Hutter, 2017) is what makes weight decay regularise as intended in adaptive optimisers.

Problem statement

Implement adamw_step(params, grads, m, v, t, lr, beta1, beta2, eps, weight_decay) that applies one AdamW update, in place, to every parameter tensor. For each parameter $p$ with gradient $g$, moment buffers $m, v$, and weight-decay coefficient $\lambda$:

$$
\begin{aligned}
m &\leftarrow \beta_1\, m + (1-\beta_1)\, g \\
v &\leftarrow \beta_2\, v + (1-\beta_2)\, g^2 \\
\hat{m} &= \frac{m}{1-\beta_1^{\,t}}, \qquad \hat{v} = \frac{v}{1-\beta_2^{\,t}} \\
p &\leftarrow p - \mathrm{lr}\left(\frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} + \lambda\, p\right)
\end{aligned}
$$

The decay term $\lambda\, p$ sits outside the adaptive $\hat{m}/\sqrt{\hat{v}}$ scaling: it is not added to $g$ before the moments are computed.

Input

  • params: list[np.ndarray], the model parameters. Mutated in place.
  • grads: list[np.ndarray], same shapes as params; the gradient $g$ for each.
  • m: list[np.ndarray], first-moment buffers (start at 0). Mutated in place.
  • v: list[np.ndarray], second-moment buffers (start at 0). Mutated in place.
  • t: int, $t \ge 1$, the current step (1-indexed). Drives bias correction.
  • lr, beta1, beta2, eps: scalars; the learning rate, the first- and second-moment decay rates, and the stabiliser $\epsilon$.
  • weight_decay: scalar $\lambda \ge 0$, the decoupled decay coefficient. weight_decay=0 is plain Adam.

Output

Returns None. Updates params, m, and v in place, so the caller's references see the new values after the call.
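Why the augmented assignments matter: inside a loop, `p = p * factor` builds a new array and rebinds only the loop variable, so the caller's list still holds the old array, whereas `p *= factor` writes into the array the caller also references. The two helpers below (`shrink_rebind`, `shrink_in_place`) are hypothetical and exist only to show that difference.

```python
import numpy as np

def shrink_rebind(params, factor):
    for p in params:
        p = p * factor  # new array bound to the local name; caller's array untouched

def shrink_in_place(params, factor):
    for p in params:
        p *= factor     # writes into the existing array the caller also holds

params = [np.array([1.0, 2.0])]
shrink_rebind(params, 0.5)
print(params[0])        # unchanged
shrink_in_place(params, 0.5)
print(params[0])        # halved
```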

Examples

Example 1 — decay shrinks a parameter even when the gradient is zero

Input:  params=[1.0], grads=[0.0], m=[0.0], v=[0.0],
        t=1, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.5
Output: params≈[0.95]

Explanation: with $g = 0$ both moments stay 0, so the Adam term is $0/(0 + \epsilon) = 0$ and only the decay acts: $p \leftarrow p - \mathrm{lr}\,\lambda\, p = 1 - 0.1\cdot0.5\cdot1 = 0.95$. Plain Adam would leave $p$ unchanged; this shrinkage toward 0 is the weight decay.

Example 2 — the decay is decoupled (moments never see it)

Input:  params=[1.0], grads=[1.0], m=[0.0], v=[0.0],
        t=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.1
Output: m=[0.1]   (= (1-β₁)·g, NOT (1-β₁)·(g + λ·p))

Explanation: the first moment is $m = (1-\beta_1)\,g = 0.1\cdot1 = 0.1$. The decay term is added straight to the parameter step, never to $g$, so it never enters $m$ or $v$. The wrong "Adam-with-L2" form would instead give $m = (1-\beta_1)(g + \lambda p) = 0.1\cdot1.1 = 0.11$.
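Both forms can be compared numerically with the inputs of Example 2. The sketch uses plain Python floats and a hypothetical helper `one_step` that runs a single Adam step from zero moments; the only difference between the two calls is whether $\lambda p$ enters the moments or is added after the adaptive scaling.

```python
import math

p, g = 1.0, 1.0
t, lr, beta1, beta2, eps, lam = 1, 1e-3, 0.9, 0.999, 1e-8, 0.1

def one_step(p, g_eff, extra_decay):
    """One Adam step from zero moments at step t (hypothetical helper).
    g_eff feeds the moments; extra_decay is added outside the adaptive scaling."""
    m = (1 - beta1) * g_eff
    v = (1 - beta2) * g_eff ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m, p - lr * (m_hat / (math.sqrt(v_hat) + eps) + extra_decay)

m_adamw, p_adamw = one_step(p, g, lam * p)   # AdamW: moments see g only
m_l2, p_l2 = one_step(p, g + lam * p, 0.0)   # Adam + L2: decay folded into g
print(m_adamw, m_l2)  # about 0.1 vs 0.11
print(p_adamw, p_l2)  # about 0.9989 vs 0.999
```

In this one-step setting the L2 form's decay is almost entirely normalised away: at $t = 1$, $\hat m/\sqrt{\hat v} = (g + \lambda p)/|g + \lambda p| \approx 1$ (ignoring $\epsilon$) for any positive $g + \lambda p$, so the parameter moves by about $\mathrm{lr}$, while AdamW moves it by about $\mathrm{lr}\,(1 + \lambda p) = 1.1\times10^{-3}$.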

Constraints

  • For each index i, params[i], grads[i], m[i] and v[i] share the same shape.
  • $t \ge 1$ and is the true step count: bias correction divides by $1-\beta^{\,t}$, which would be 0 at $t = 0$.
  • Update in place (*=, +=, -=).
  • $\epsilon$ is added after the square root: $\sqrt{\hat{v}} + \epsilon$.
  • Weight decay is decoupled: apply $\lambda\, p$ to the parameter update directly; do not add $\lambda\, p$ to $g$ before computing $m, v$. weight_decay=0 recovers plain Adam exactly.
  • Tests compare with tolerance atol ≈ 1e-9.
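Where $\epsilon$ goes matters most when $\sqrt{\hat v}$ is comparable to $\epsilon$. A quick check with plain floats, assuming a hypothetical parameter whose gradients have consistently been $10^{-8}$ (so $\hat m = 10^{-8}$ and $\hat v = 10^{-16}$):

```python
import math

m_hat, v_hat, eps = 1e-8, 1e-16, 1e-8  # consistent gradients of 1e-8

after_sqrt = m_hat / (math.sqrt(v_hat) + eps)   # required: sqrt(v_hat) + eps
inside_sqrt = m_hat / math.sqrt(v_hat + eps)    # common mistake: sqrt(v_hat + eps)

print(after_sqrt)   # about 0.5
print(inside_sqrt)  # about 1e-4
```

The two placements differ by more than three orders of magnitude here, far outside the 1e-9 tolerance.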

Notes

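  • Decoupled vs L2. Folding $\lambda p$ into $g$ routes the decay through the adaptive $1/\sqrt{\hat{v}}$ scaling, so a parameter with a history of large gradients receives less decay than a parameter of the same value with small gradients, which is backwards from what weight decay is meant to do. AdamW keeps the decay outside that scaling, so the decay part of the update shrinks every parameter by the same fraction $\mathrm{lr}\,\lambda$ per step, the behaviour everyone expects from weight decay.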
  • State persistence. m and v are optimiser state reused across steps, so the in-place update is part of the contract; this problem builds directly on the Adam optimiser step.
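To put rough numbers on the Decoupled vs L2 note, the sketch below compares how much decay two parameters of equal value receive per step. It is an approximation, not a simulation: it assumes $p$ is roughly constant and $\hat m$ has already averaged in the constant term $\lambda p$, so under Adam + L2 that term's share of the step is $\mathrm{lr}\,\lambda p/(\sqrt{\hat v} + \epsilon)$, and it takes the two $\sqrt{\hat v}$ values (0.01 and 10) as hypothetical inputs rather than deriving them.

```python
import numpy as np

lr, lam, eps = 1e-3, 0.1, 1e-8
p = 1.0  # both parameters have the same value

# Hypothetical steady-state sqrt(v_hat): index 0 has a history of small
# gradients, index 1 a history of large ones.
sqrt_v_hat = np.array([0.01, 10.0])

# Adam + L2 (approximation): the decay's share of the step is divided by
# sqrt(v_hat) + eps along with the rest of m_hat.
l2_decay = lr * lam * p / (sqrt_v_hat + eps)

# AdamW: the decay is added after the adaptive scaling, so v_hat plays no part.
adamw_decay = np.full_like(sqrt_v_hat, lr * lam * p)

print(l2_decay)     # roughly 1e-2 and 1e-5: ~1000x less decay for the large-gradient parameter
print(adamw_decay)  # the same value (about 1e-4) for both
```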

This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend or submission queue.

  • weight_decay=0 reduces to plain Adam
  • Diagnostic: weight decay shrinks parameters even when grad is 0
  • Diagnostic: WD is decoupled — not folded into gradient before moments
  • Multiple parameter tensors all see the same decay rate
  • Mutation is in place — caller sees updated values