AdamW step (decoupled weight decay)
Background
AdamW is Adam with decoupled weight decay — the optimiser behind GPT-2, GPT-3, Llama, and essentially every modern transformer training run. It computes the same adaptive moment update as Adam, then applies an L2-style shrinkage term directly to the parameters instead of folding it into the gradient. That one change (Loshchilov & Hutter, 2017) is what makes weight decay regularise as intended in adaptive optimisers.
Problem statement
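Implement adamw_step(params, grads, m, v, t, lr, beta1, beta2, eps, weight_decay) that applies one AdamW update, in place, to every parameter tensor. For each parameter $p$ with gradient $g$, moment buffers $m, v$, and weight-decay coefficient $\lambda$:
$$m \leftarrow \beta_1 m + (1-\beta_1)\, g, \qquad v \leftarrow \beta_2 v + (1-\beta_2)\, g^2$$
$$\hat m = \frac{m}{1-\beta_1^t}, \qquad \hat v = \frac{v}{1-\beta_2^t}$$
$$p \leftarrow p - \mathrm{lr}\left(\frac{\hat m}{\sqrt{\hat v}+\epsilon} + \lambda p\right)$$
The decay term sits outside the adaptive scaling: it is not added to $g$ before the moments are computed.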
Input
- params: list[np.ndarray], the model parameters. Mutated in place.
- grads: list[np.ndarray], same shapes as params; the gradient for each.
- m: list[np.ndarray], first-moment buffers (start at 0). Mutated in place.
- v: list[np.ndarray], second-moment buffers (start at 0). Mutated in place.
- t: int, $t \ge 1$, the current step (1-indexed). Drives bias correction.
- lr, beta1, beta2, eps: scalars for the learning rate, the first- and second-moment decay rates, and the stabiliser $\epsilon$.
- weight_decay: scalar $\lambda$, the decoupled decay coefficient. weight_decay=0 is plain Adam.
Output
Returns None. Updates params, m, and v in place, so the caller's references see the new values after the call.
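One possible way to satisfy this contract is sketched below. It is a minimal illustration, assuming float NumPy arrays and the update rule implied by the examples on this page (eps added after the square root, decay applied to the parameter step), not a reference solution. The variable names bc1, bc2, m_hat, and v_hat are local choices made for this sketch.

```python
import numpy as np

def adamw_step(params, grads, m, v, t, lr, beta1, beta2, eps, weight_decay):
    """Apply one AdamW update in place; returns None."""
    # Bias-correction denominators depend only on the 1-indexed step count t.
    bc1 = 1.0 - beta1 ** t
    bc2 = 1.0 - beta2 ** t
    for p, g, m_i, v_i in zip(params, grads, m, v):
        # Moment buffers see only the raw gradient (no decay term folded in).
        m_i *= beta1
        m_i += (1.0 - beta1) * g
        v_i *= beta2
        v_i += (1.0 - beta2) * g * g
        m_hat = m_i / bc1
        v_hat = v_i / bc2
        # Adam term plus decoupled decay, subtracted from the parameter in place.
        p -= lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * p)

# Usage: the caller's arrays are the ones that change.
params = [np.array([1.0])]
grads = [np.array([0.0])]
m = [np.zeros(1)]
v = [np.zeros(1)]
before = params[0]
result = adamw_step(params, grads, m, v, t=1, lr=0.1,
                    beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.5)
print(result, params[0], before is params[0])
```

Using the augmented operators (*=, +=, -=) on the array objects is what makes the caller's references see the new values; rebinding with p = p - ... would create a new array and leave the list entry untouched.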
Examples
Example 1 — decay shrinks a parameter even when the gradient is zero
Input: params=[1.0], grads=[0.0], m=[0.0], v=[0.0],
t=1, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.5
Output: params≈[0.95]
Explanation: with $g = 0$ the Adam term is 0, so only the decay acts: $p \leftarrow 1.0 - 0.1 \cdot 0.5 \cdot 1.0 = 0.95$. Plain Adam would leave $p$ unchanged; this shrinkage toward 0 is the weight decay.
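To see where this leads over several steps, the scalar sketch below repeats the decay-only update. It assumes the gradient stays at zero on every step, so the moment buffers remain zero and each step reduces to multiplying the parameter by (1 - lr * weight_decay). The names history and step are illustrative.

```python
lr, weight_decay = 0.1, 0.5
p = 1.0
history = []

for step in range(3):
    # With g = 0 and zero moments the Adam term vanishes,
    # leaving p <- p - lr * weight_decay * p.
    p -= lr * weight_decay * p
    history.append(p)

print(history)
```

The first entry matches Example 1, and the later entries shrink geometrically by the same factor.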
Example 2 — the decay is decoupled (moments never see it)
Input: params=[1.0], grads=[1.0], m=[0.0], v=[0.0],
t=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.1
Output: m=[0.1] (= (1-β₁)·g, NOT (1-β₁)·(g + λ·p))
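Explanation: the first moment is $m = (1-\beta_1)\, g = 0.1 \cdot 1.0 = 0.1$. The decay term is added straight to the parameter step, never to $g$, so it never enters $m$ or $v$. The wrong "Adam-with-L2" form would instead give $m = (1-\beta_1)(g + \lambda p) = 0.1 \cdot 1.1 = 0.11$.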
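The two first-moment values can be checked with plain scalar arithmetic. This sketch uses the numbers from Example 2; the names m_adamw, m_l2, and g_l2 are illustrative labels for the decoupled and coupled variants.

```python
beta1, weight_decay = 0.9, 0.1
p, g, m0 = 1.0, 1.0, 0.0

# Decoupled (AdamW): the moment sees only the raw gradient.
m_adamw = beta1 * m0 + (1 - beta1) * g

# Coupled "Adam-with-L2": the decay term is folded into the gradient first.
g_l2 = g + weight_decay * p
m_l2 = beta1 * m0 + (1 - beta1) * g_l2

print(m_adamw, m_l2)
```

Up to floating-point rounding, the first value is 0.1 and the second is 0.11, which is the difference the decoupling test is looking for.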
Constraints
- For each index i, params[i], grads[i], m[i], v[i] share the same shape.
- $t \ge 1$ and is the true step count: bias correction divides by $1-\beta_1^t$ and $1-\beta_2^t$.
- Update in place (*=, +=, -=).
- $\epsilon$ is added after the square root: $\hat m / (\sqrt{\hat v} + \epsilon)$.
- Weight decay is decoupled: apply $\lambda p$ to the parameter update directly; do not add $\lambda p$ to $g$ before computing $m, v$.
- weight_decay=0 recovers plain Adam exactly.
- Tests compare with tolerance atol ≈ 1e-9.
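Two of these constraints are easy to get subtly wrong, so a small sketch may help. The first part assumes a constant gradient (purely for illustration) and shows that the bias-corrected first moment recovers that gradient at every step when t is the true step count. The second part compares the required epsilon placement with a variant that puts epsilon inside the square root, using a deliberately tiny moment so that the difference is visible.

```python
import numpy as np

beta1, eps = 0.9, 1e-8

# Bias correction: with a constant gradient, m_hat equals g at every step,
# provided t is the true 1-indexed step count.
g = np.array([2.0])
m = np.zeros(1)
for t in range(1, 4):
    m *= beta1                       # in-place update of the persistent buffer
    m += (1 - beta1) * g
    m_hat = m / (1 - beta1 ** t)

# Epsilon placement: tiny first moment whose second moment is its square.
small = np.array([1e-8])
outside = small / (np.sqrt(small ** 2) + eps)   # eps after the square root
inside = small / np.sqrt(small ** 2 + eps)      # eps inside the square root

print(m, m_hat, outside, inside)
```

With a tolerance near 1e-9, the two epsilon placements are far enough apart for small gradients that they would not be treated as equal.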
Notes
- Decoupled vs L2. Folding $\lambda p$ into $g$ routes the decay through the adaptive scaling, so parameters with large gradients (which need more regularisation) receive less decay, the opposite of what is wanted. AdamW keeps the decay outside that scaling, so every parameter is shrunk at the same relative rate, lr times $\lambda$, regardless of its gradient history.
- State persistence. m and v are optimiser state reused across steps, so the in-place update is part of the contract. This problem builds directly on the Adam optimiser step.
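The decoupled-versus-L2 note can be made concrete with a single-step comparison. The sketch below is a simplified illustration, assuming t = 1 (so the bias-corrected moments are the effective gradient and its square) and two parameters with the same value but very different gradient magnitudes. The quantity decay_coupled is the share of the coupled step attributable to the decay term, a decomposition chosen here for illustration.

```python
import numpy as np

lr, weight_decay, eps = 1e-3, 0.1, 1e-8
p = np.array([1.0, 1.0])        # two parameters with the same value
g = np.array([0.01, 100.0])     # small vs large gradient

# Decoupled: the decay contribution is lr * weight_decay * p, independent of g.
decay_decoupled = lr * weight_decay * p

# Coupled L2 at t = 1: m_hat = g' and v_hat = g'^2 with g' = g + weight_decay * p,
# so the decay part of the step is divided by |g'| + eps.
g_l2 = g + weight_decay * p
decay_coupled = lr * (weight_decay * p) / (np.abs(g_l2) + eps)

print(decay_decoupled, decay_coupled)
```

In the decoupled form both parameters receive the same shrinkage, while in the coupled form the parameter with the large gradient receives far less.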
This problem ships 5 hidden tests. They run in your browser via Pyodide, with no backend and no submission queue.
- weight_decay=0 reduces to plain Adam
- Diagnostic: weight decay shrinks parameters even when grad is 0
- Diagnostic: WD is decoupled, not folded into the gradient before the moments
- Multiple parameter tensors all see the same decay rate
- Mutation is in place: the caller sees updated values