Adam optimiser step

Difficulty: Medium

Background

Adam (Adaptive Moment Estimation) is a standard default optimiser for training deep networks and transformers. It keeps a per-parameter running estimate of the gradient's first moment (the mean, $m$) and second moment (the uncentered variance, $v$), then scales each parameter's step by those estimates, so parameters with consistently large or noisy gradients get smaller, better-conditioned updates. Adam, or its weight-decay variant AdamW, which differs by a single term, is the optimiser behind most LLM pre-training. Knowing it cold is table stakes for an ML interview.

Problem statement

Implement adam_step(params, grads, m, v, t, lr, beta1, beta2, eps) that applies one Adam update, in place, to every parameter tensor. For each parameter $p$ with gradient $g$ and moment buffers $m, v$:

$$
\begin{aligned}
m &\leftarrow \beta_1\, m + (1-\beta_1)\, g \\
v &\leftarrow \beta_2\, v + (1-\beta_2)\, g^2 \\
\hat{m} &= \frac{m}{1-\beta_1^{\,t}}, \qquad \hat{v} = \frac{v}{1-\beta_2^{\,t}} \\
p &\leftarrow p - \mathrm{lr}\cdot \frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon}
\end{aligned}
$$

Defaults from the original paper: lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8.
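
The four update rules above translate almost line for line into NumPy. The following is a minimal sketch rather than a reference solution; it assumes every array has a floating-point dtype (in-place updates into integer arrays would fail) and that every entry is an np.ndarray rather than a Python scalar.

```python
import numpy as np

def adam_step(params, grads, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update in place to each array in params (sketch)."""
    # The bias-correction denominators depend only on t, so compute them once.
    bc1 = 1.0 - beta1 ** t
    bc2 = 1.0 - beta2 ** t
    for p, g, m_i, v_i in zip(params, grads, m, v):
        # Augmented assignment mutates the existing buffers instead of rebinding.
        m_i *= beta1
        m_i += (1.0 - beta1) * g
        v_i *= beta2
        v_i += (1.0 - beta2) * g * g
        # Bias-corrected estimates are temporaries; only m_i and v_i persist.
        m_hat = m_i / bc1
        v_hat = v_i / bc2
        # eps is added after the square root, as in the formula above.
        p -= lr * m_hat / (np.sqrt(v_hat) + eps)
```

The function falls off the end without a return statement, so it returns None; all results are visible through the caller's own lists.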

Input

  • params: list[np.ndarray], the model parameters. Mutated in place.
  • grads: list[np.ndarray], same shapes as params; the gradient $g$ for each.
  • m: list[np.ndarray], first-moment buffers (start at 0). Mutated in place.
  • v: list[np.ndarray], second-moment buffers (start at 0). Mutated in place.
  • t: int, $t \ge 1$, the current step (1-indexed). Drives bias correction.
  • lr, beta1, beta2, eps: scalars (learning rate, first/second-moment decay rates, and the stabiliser $\epsilon$).
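
As an illustration of how a caller might prepare these inputs before the first step, here is a small sketch. The parameter shapes and the placeholder gradients are arbitrary assumptions chosen for the example, not part of the problem.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two hypothetical parameter tensors with different shapes.
params = [rng.normal(size=(3, 2)), np.zeros(2)]

# Placeholder gradients; in real training these come from backpropagation.
grads = [np.ones_like(p) for p in params]

# Moment buffers start at zero and match each parameter's shape and dtype.
m = [np.zeros_like(p) for p in params]
v = [np.zeros_like(p) for p in params]

# The step counter is 1-indexed: the first call uses t = 1.
t = 1
```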

Output

Returns None. Updates params, m, and v in place, so the caller's references see the new values after the call.

Examples

Example 1 — first step from zero state

Input:  params=[1.0], grads=[0.1], m=[0.0], v=[0.0],
        t=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8
Output: m=[0.01], v=[1e-5], params≈[0.999]

Explanation: $m = 0.1\cdot0.1 = 0.01$, so $\hat{m} = 0.01/(1-0.9) = 0.1$. $v = 0.001\cdot0.01 = 10^{-5}$, so $\hat{v} = 10^{-5}/(1-0.999) = 0.01$. Then $p \leftarrow p - 10^{-3}\cdot 0.1/\sqrt{0.01} = p - 10^{-3}$, giving $p \approx 0.999$.

Example 2 — bias correction at $t=1$

Input:  params=[10.0], grads=[2.0], m=[0.0], v=[0.0],
        t=1, lr=1.0, beta1=0.9, beta2=0.999, eps=1e-12
Output: params≈[9.0]

Explanation: at $t=1$ the $(1-\beta^t)$ denominator exactly cancels the $(1-\beta)$ factor in the numerator, so $\hat{m}=g=2$ and $\hat{v}=g^2=4$. The step is $\mathrm{lr}\cdot 2/\sqrt{4} = 1.0$, so $p \approx 9.0$. With zero initial state and negligible $\epsilon$, the first step is $\approx \mathrm{lr}\cdot\operatorname{sign}(g)$ regardless of $\beta$, which is what keeps Adam from stalling at step 1.
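
The cancellation at $t=1$ can be checked numerically for several decay rates and gradient magnitudes. This is a scalar sketch; first_step_size is a hypothetical helper introduced only for this check, and the specific beta and gradient values are arbitrary.

```python
import math

def first_step_size(g, lr, beta1, beta2, eps=1e-12):
    """Size of the Adam update at t = 1, starting from m = v = 0 (scalar sketch)."""
    m = (1 - beta1) * g          # beta1 * 0 + (1 - beta1) * g
    v = (1 - beta2) * g * g      # beta2 * 0 + (1 - beta2) * g^2
    m_hat = m / (1 - beta1)      # bias correction with t = 1
    v_hat = v / (1 - beta2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

# Every combination gives a step of magnitude lr with the sign of g.
steps = {}
for beta1, beta2 in [(0.9, 0.999), (0.5, 0.9), (0.0, 0.99)]:
    for g in (2.0, -0.03, 500.0):
        steps[(beta1, beta2, g)] = first_step_size(g, lr=1.0, beta1=beta1, beta2=beta2)
```

The gradient's magnitude drops out entirely, so a tiny gradient and a huge one move the parameter by the same amount on the first step.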

Constraints

  • For each index i, params[i], grads[i], m[i], and v[i] share the same shape.
  • $t \ge 1$ and is the true step count; bias correction divides by $1-\beta^{\,t}$, which would be zero at $t=0$.
  • Update in place (*=, +=, -=). Rebinding m/v discards the optimiser state across steps.
  • $\epsilon$ is added after the square root: $\sqrt{\hat{v}} + \epsilon$.
  • Tests compare with tolerance atol ≈ 1e-6.
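
The in-place constraint is easy to violate by accident. The sketch below contrasts augmented assignment, which mutates the existing buffer, with plain assignment to a loop variable, which only rebinds a local name and leaves the buffer untouched. The variable names and values are illustrative.

```python
import numpy as np

beta1 = 0.9
g = np.array([1.0, 2.0, 3.0])

# Correct: augmented assignment writes into the existing array.
m = [np.zeros(3)]
buf = m[0]                      # a reference the caller might be holding
for m_i in m:
    m_i *= beta1
    m_i += (1 - beta1) * g
updated_in_place = buf.copy()   # buf now holds (1 - beta1) * g

# Incorrect: plain assignment builds a new array and rebinds the loop name.
m_wrong = [np.zeros(3)]
for m_i in m_wrong:
    m_i = beta1 * m_i + (1 - beta1) * g   # m_wrong[0] is never modified
still_zero = m_wrong[0].copy()
```

In the second loop the buffer stays at zero, so every later step would behave as if it were the first.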

Notes

  • Bias correction. m and v are initialised to 0, which biases them toward 0 early on; dividing by $1-\beta^{\,t}$ corrects this, and the correction vanishes as $t$ grows.
  • State persistence. m and v are optimiser state that lives across steps — the in-place requirement is what lets a training loop reuse them.
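
To make the bias-correction note concrete, the following scalar sketch feeds a constant gradient into the first-moment update and records both the raw and the corrected estimate at every step. The values beta1 = 0.9 and g = 2.0 are arbitrary choices for illustration.

```python
beta1, g = 0.9, 2.0   # illustrative values, not taken from the problem

m = 0.0               # persistent state, carried across steps
history = {}          # t -> (raw m, bias-corrected m_hat)
for t in range(1, 51):
    m = beta1 * m + (1 - beta1) * g
    m_hat = m / (1 - beta1 ** t)
    history[t] = (m, m_hat)
```

With a constant gradient the raw estimate is $m_t = g\,(1-\beta_1^{\,t})$, so it starts near 0.2 and only slowly approaches 2.0, while the corrected estimate equals $g$ at every step up to rounding. As $t$ grows the two converge, which is the sense in which the correction vanishes.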

This problem ships 6 hidden tests, which run in the browser via Pyodide with no backend and no submission queue:

  • First step from zero state matches the bias-corrected formula
  • Diagnostic: at step 1, m_hat = g and v_hat = g^2 (bias correction effect)
  • Constant gradient over many steps: parameter drifts steadily
  • Multiple parameter tensors each update using their own m and v buffers
  • Mutation is in place — caller's references see the new values
  • eps prevents division-by-zero when v_hat is exactly 0
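
Two of the behaviours listed above can be reproduced in a few lines. This sketch is a guess at what such checks might look like, not the hidden tests themselves; the gradient value, step count, and variable names are assumptions.

```python
import numpy as np

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8

# Constant gradient: m_hat stays at g and v_hat at g^2, so each step is about lr.
p, m, v, g = 1.0, 0.0, 0.0, 0.5
for t in range(1, 101):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    p -= lr * m_hat / (np.sqrt(v_hat) + eps)
drift = 1.0 - p   # close to 100 * lr

# Zero gradient at step 1: m_hat and v_hat are both exactly 0.
zero = np.zeros(2)
with_eps = zero / (np.sqrt(zero) + eps)        # 0 / eps, a clean zero update
with np.errstate(divide="ignore", invalid="ignore"):
    without_eps = zero / np.sqrt(zero)         # 0 / 0, which NumPy evaluates to nan
```

Without eps, a parameter whose gradient is exactly zero would be overwritten with nan on its first update.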