BatchNorm forward (train + eval modes)

Difficulty: Medium

Background

BatchNorm (Ioffe & Szegedy, 2015) was one of the first widely adopted normalisation layers for deep networks. It stabilised training by keeping each feature's activations at roughly zero mean and unit variance. It normalises across the batch axis: every feature channel gets its own mean and variance, computed over all examples in the batch. It has two distinct behaviours: a train mode that uses the current batch's statistics (and updates a running average for later), and an eval mode that uses those saved running statistics so inference is deterministic.

Problem statement

Implement batch_norm_forward(x, gamma, beta, running_mean, running_var, momentum, eps, training) for a (B, C) input, normalising along the batch axis (axis=0) per feature.

Train mode (training=True). Normalise with batch statistics and update the running buffers in place:

$$
\begin{aligned}
\mu_c &= \frac{1}{B}\sum_{i=1}^{B} x_{ic}, \qquad \sigma_c^2 = \frac{1}{B}\sum_{i=1}^{B}\left(x_{ic}-\mu_c\right)^2 \\
\hat{x}_{ic} &= \frac{x_{ic}-\mu_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad \text{out}_{ic} = \gamma_c\,\hat{x}_{ic} + \beta_c
\end{aligned}
$$

with running buffers updated by an exponential moving average that puts weight $\rho$ (the momentum argument) on the new batch:

$$
\mu^{\text{run}} \leftarrow (1-\rho)\,\mu^{\text{run}} + \rho\,\mu, \qquad \sigma^2_{\text{run}} \leftarrow (1-\rho)\,\sigma^2_{\text{run}} + \rho\,\sigma^2
$$
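To get a feel for the moving-average update, the sketch below applies it several times with a fixed batch mean. The helper name `ema_update` and the assumption that every batch has the same mean are purely illustrative and not part of the problem's API.

```python
import numpy as np

def ema_update(running, batch_stat, momentum):
    """Hypothetical helper: move `running` toward `batch_stat` in place."""
    running *= (1.0 - momentum)        # keep (1 - rho) of the old value
    running += momentum * batch_stat   # add rho of the new batch statistic
    return running

running_mean = np.array([0.0])
batch_mean = np.array([2.0])  # assumed identical for every step (illustration only)

history = []
for _ in range(3):
    ema_update(running_mean, batch_mean, momentum=0.1)
    history.append(float(running_mean[0]))

# Starting from 0, n identical updates give batch_mean * (1 - 0.9**n),
# i.e. roughly 0.2, 0.38, 0.542 for n = 1, 2, 3.
print(history)
```

With a small momentum such as 0.1, the buffer drifts slowly toward the batch statistic, which smooths out noise from any single batch.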

Eval mode (training=False). Normalise with the saved running statistics and do not update them:

$$
\hat{x} = \frac{x - \mu^{\text{run}}}{\sqrt{\sigma^2_{\text{run}} + \epsilon}}, \qquad \text{out} = \gamma\,\hat{x} + \beta
$$

Input

  • x: np.ndarray of shape (B, C), with B examples and C feature channels.
  • gamma: shape (C,), the learned per-feature scale $\gamma$.
  • beta: shape (C,), the learned per-feature shift $\beta$.
  • running_mean: shape (C,), the running mean buffer. Mutated in place when training=True.
  • running_var: shape (C,), the running variance buffer. Mutated in place when training=True.
  • momentum: float in $[0, 1]$, the EMA weight $\rho$ on the new batch stats (PyTorch default 0.1).
  • eps: float, numerical-stability constant added inside the square root.
  • training: bool, True for train mode, False for eval mode.

Output

Returns an np.ndarray of the same shape as x. In train mode it also mutates running_mean and running_var in place; in eval mode those buffers are left untouched.
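One possible NumPy sketch of the function described above follows. It is a minimal version under the stated assumptions (float arrays, a 2-D (B, C) input) and not necessarily the reference solution the hidden tests were written against.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, running_mean, running_var,
                       momentum, eps, training):
    if training:
        # Per-feature statistics over the batch axis; both have shape (C,).
        mean = x.mean(axis=0)
        var = x.var(axis=0)  # biased variance (ddof=0), divides by B

        # In-place EMA so the caller's buffers are updated.
        running_mean *= (1.0 - momentum)
        running_mean += momentum * mean
        running_var *= (1.0 - momentum)
        running_var += momentum * var
    else:
        # Eval mode only reads the saved statistics.
        mean, var = running_mean, running_var

    x_hat = (x - mean) / np.sqrt(var + eps)  # broadcasts (C,) over (B, C)
    return gamma * x_hat + beta

# Usage with a two-example, one-feature batch in train mode.
x = np.array([[0.0], [4.0]])
running_mean = np.array([0.0])
running_var = np.array([1.0])
out = batch_norm_forward(x, np.array([1.0]), np.array([0.0]),
                         running_mean, running_var,
                         momentum=0.1, eps=0.0, training=True)
print(out.ravel(), running_mean, running_var)
```

The buffers must be float arrays for the in-place `*=` and `+=` to work; integer buffers would raise a casting error.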

Examples

Example 1 — train mode (B=2, C=1)

Input:  x=[[0.0], [4.0]], gamma=[1.0], beta=[0.0],
        running_mean=[0.0], running_var=[1.0],
        momentum=0.1, eps=0, training=True
Output: out=[[-1.0], [1.0]]
        running_mean -> [0.2],  running_var -> [1.3]

Explanation: batch stats along axis=0 are $\mu = 2$ and $\sigma^2 = \tfrac{1}{2}\big((0-2)^2 + (4-2)^2\big) = 4$ (biased, divided by $B$). So $\hat{x} = (x-2)/\sqrt{4}$, giving out $[-1, 1]$. The running buffers then move 10% toward the batch: $\mu^{\text{run}} = 0.9\cdot 0 + 0.1\cdot 2 = 0.2$ and $\sigma^2_{\text{run}} = 0.9\cdot 1 + 0.1\cdot 4 = 1.3$.

Example 2 — eval mode (B=2, C=2)

Input:  x=[[10.0, 20.0], [30.0, 40.0]], gamma=[1.0, 1.0], beta=[0.0, 0.0],
        running_mean=[5.0, 5.0], running_var=[4.0, 4.0],
        momentum=0.1, eps=1e-9, training=False
Output: [[2.5, 7.5], [12.5, 17.5]]   (running_mean / running_var unchanged)

Explanation: eval ignores the batch and uses the saved stats: $(x - \mu^{\text{run}})/\sqrt{\sigma^2_{\text{run}}} = (x-5)/\sqrt{4} = (x-5)/2$. The $\epsilon = 10^{-9}$ term is far below the test tolerance, so it is dropped here. The running buffers are not modified.

Constraints

  • x is (B, C); statistics are computed along axis=0 (the batch axis), so each of the C features has its own $\mu$ and $\sigma^2$.
  • Use the biased variance (divide by $B$, not $B-1$) so it matches the normalisation denominator; this is np.var's default (ddof=0).
  • Train mode mutates running_mean and running_var in place (*=, +=); eval mode must leave them exactly unchanged.
  • momentum weights the new batch stats: $(1-\rho)\cdot\text{old} + \rho\cdot\text{new}$.
  • $\epsilon$ is added inside the square root: $\sqrt{\sigma^2 + \epsilon}$.
  • Tests compare with tolerances from atol=1e-6 (means) to atol=1e-3 (std).
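Two of the constraints above are easy to get wrong, so here is a small illustrative check of both: biased versus unbiased variance, and in-place mutation versus rebinding a local name. The helper names `update_rebinding` and `update_in_place` are hypothetical.

```python
import numpy as np

x = np.array([[0.0], [4.0]])
biased = np.var(x, axis=0)             # ddof=0, divides by B      -> 4.0
unbiased = np.var(x, axis=0, ddof=1)   # ddof=1, divides by B - 1  -> 8.0

def update_rebinding(buf, new, momentum):
    # Creates a new array and rebinds the local name only;
    # the caller's buffer is left unchanged.
    buf = (1 - momentum) * buf + momentum * new
    return buf

def update_in_place(buf, new, momentum):
    # Writes into the existing array, so the caller sees the update.
    buf *= (1 - momentum)
    buf += momentum * new

running_var = np.array([1.0])

update_rebinding(running_var, biased, 0.1)
after_rebinding = running_var.copy()   # still 1.0

update_in_place(running_var, biased, 0.1)
after_in_place = running_var.copy()    # 0.9 * 1 + 0.1 * 4 = 1.3

print(biased, unbiased, after_rebinding, after_in_place)
```

Slice assignment such as `buf[:] = (1 - momentum) * buf + momentum * new` is another way to write into the existing buffer.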

Notes

  • Deterministic inference. Eval normalises with the saved running stats, never the eval batch, so a single example yields the same output regardless of what else shares its batch.
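  • Why BatchNorm is a poor fit for autoregressive LMs. At generation time you produce one token at a time, often with a batch of size 1, where batch statistics are degenerate (each feature's batch variance is exactly zero). The layer would have to rely entirely on running statistics collected during training. That is one reason GPT-2 and friends use LayerNorm instead, which has no batch axis and normalises along axis=-1 per token.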
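The two notes above can be checked numerically. The sketch below uses a hypothetical `bn_eval` helper with gamma=1 and beta=0 assumed, and illustrative input values. It shows that eval output does not depend on batch mates, that batch statistics of a single example are degenerate, and that LayerNorm-style statistics along axis=-1 stay well defined.

```python
import numpy as np

def bn_eval(x, running_mean, running_var, eps=1e-5):
    """Hypothetical eval-mode normalisation (gamma=1, beta=0 assumed)."""
    return (x - running_mean) / np.sqrt(running_var + eps)

running_mean = np.array([5.0, 5.0])
running_var = np.array([4.0, 4.0])

example = np.array([[10.0, 20.0]])                 # one example, two features
others = np.array([[30.0, 40.0], [-7.0, 100.0]])   # arbitrary batch mates

# Eval mode: the example's output is the same alone or inside a larger batch.
alone = bn_eval(example, running_mean, running_var)
in_batch = bn_eval(np.vstack([example, others]), running_mean, running_var)[:1]
same = np.allclose(alone, in_batch)

# Batch statistics over a single example: the mean is the example itself
# and the variance is zero for every feature.
single_mean = example.mean(axis=0)
single_var = example.var(axis=0)

# LayerNorm-style statistics are taken over the features of each example,
# so they do not depend on the batch size.
ln_mean = example.mean(axis=-1, keepdims=True)     # 15.0
ln_var = example.var(axis=-1, keepdims=True)       # 25.0
ln_out = (example - ln_mean) / np.sqrt(ln_var + 1e-5)

print(same, single_var, ln_out)
```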

This problem ships 5 hidden tests, which run in your browser via Pyodide (no backend, no submission queue):

  • Train mode: output shape matches input
  • Diagnostic: train mode normalises each feature to mean ~0, std ~1
  • Train mode mutates running_mean and running_var
  • Eval mode does NOT mutate running stats and uses them for normalisation
  • gamma scales, beta shifts (same as LayerNorm)