He weight initialization (difficulty: Easy)

Background

Weight initialization sets the starting scale of a network's parameters. If weights are too large or too small, signals (and gradients) explode or vanish as they pass through layers. He initialization (He et al., 2015; also known as Kaiming initialization) is designed for ReLU networks: it draws weights from a zero-mean normal whose variance is $2/n_{\text{in}}$, which keeps the variance of activations roughly constant across layers.
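To make the "roughly constant across layers" claim concrete, here is a small experiment sketch. It pushes a random batch through a stack of ReLU layers and records the mean squared activation after each one. The helper name `forward_second_moments` and the width, depth, and batch size are illustrative choices, not part of the problem.

```python
import numpy as np

def forward_second_moments(scale_fn, width=512, depth=10, batch=256, seed=0):
    """Return the mean squared activation after each of `depth` ReLU layers."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))  # unit-variance inputs
    stats = []
    for _ in range(depth):
        # Square weight matrix, so fan_in == width for every layer.
        w = rng.standard_normal((width, width)) * scale_fn(width)
        x = np.maximum(x @ w, 0.0)  # ReLU
        stats.append(float(np.mean(x ** 2)))
    return stats

# He scaling: std = sqrt(2 / fan_in)
he = forward_second_moments(lambda n: np.sqrt(2.0 / n))
# Same setup without the factor 2: std = sqrt(1 / fan_in)
plain = forward_second_moments(lambda n: np.sqrt(1.0 / n))

print("He, last layer:    ", he[-1])     # stays on the order of 1
print("1/n_in, last layer:", plain[-1])  # shrinks by about half per layer
```

Without the factor 2 the signal roughly halves at every layer, so after 10 layers it is about three orders of magnitude smaller; with He scaling it stays near its input scale.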

Problem statement

Implement he_init(fan_in, fan_out, seed=42) that returns a weight matrix of shape (fan_in, fan_out) with entries drawn from:

$$W_{ij} \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}}}\right)$$

Use np.random.default_rng(seed) and scale standard normal samples by $\sqrt{2/n_{\text{in}}}$.
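One possible solution sketch follows. It takes the statement literally (standard normal samples multiplied by the scale factor) and is not necessarily the exact implementation the hidden tests were written against.

```python
import numpy as np

def he_init(fan_in, fan_out, seed=42):
    """Return a (fan_in, fan_out) matrix with entries drawn from N(0, 2 / fan_in)."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(2.0 / fan_in)  # standard deviation, i.e. sqrt of the variance
    return rng.standard_normal((fan_in, fan_out)) * std

w = he_init(4, 3, seed=0)
print(w.shape)  # (4, 3)
```

Note that the scale factor is a standard deviation, so the variance $2/n_{\text{in}}$ must be square-rooted before multiplying.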

Input

  • fan_in: int, number of input units $n_{\text{in}}$ (rows).
  • fan_out: int, number of output units (columns).
  • seed: int, RNG seed for reproducibility.

Output

An np.ndarray of shape (fan_in, fan_out), with each entry drawn from $\mathcal{N}(0, 2/n_{\text{in}})$.

Examples

Example 1

Input:  fan_in = 4, fan_out = 3, seed = 0
Output: array of shape (4, 3); empirical std approximately sqrt(2/4) = 0.707

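Explanation: each weight is a standard normal sample multiplied by $\sqrt{2/4} \approx 0.7071$, so the entries come from a distribution with mean $0$ and standard deviation $\approx 0.707$. With only 12 entries, the empirical mean and standard deviation of this particular matrix match those values only loosely.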
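The following sketch illustrates how the empirical standard deviation approaches $\sqrt{2/4}$ as the number of entries grows. The second, much wider matrix is an illustrative addition and not part of the example above.

```python
import numpy as np

rng = np.random.default_rng(0)
target = np.sqrt(2.0 / 4)  # about 0.7071 for fan_in = 4

few = rng.standard_normal((4, 3)) * target        # 12 entries, as in Example 1
many = rng.standard_normal((4, 30000)) * target   # 120,000 entries

print("target std:      ", target)
print("std of 12 values:", few.std())    # a noisy estimate
print("std of 120k:     ", many.std())   # close to the target
```

This is why a statistical check of the standard deviation is more meaningful on a large matrix than on a 4 × 3 one.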

Constraints

  • The scaling factor is $\sqrt{2/n_{\text{in}}}$. The 2 in the numerator is what distinguishes He from Xavier/Glorot, whose variance is $1/n_{\text{in}}$ or $2/(n_{\text{in}}+n_{\text{out}})$.
  • Use np.random.default_rng(seed) so results are reproducible.
  • Return a real-valued array of shape (fan_in, fan_out).
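The difference between the scaling rules mentioned above can be sketched numerically. The function names `he_std`, `xavier_fan_in_std`, and `xavier_avg_std` are hypothetical helpers introduced only for this comparison.

```python
import numpy as np

def he_std(fan_in):
    # He: variance 2 / fan_in
    return np.sqrt(2.0 / fan_in)

def xavier_fan_in_std(fan_in):
    # Xavier/Glorot, fan-in form: variance 1 / fan_in
    return np.sqrt(1.0 / fan_in)

def xavier_avg_std(fan_in, fan_out):
    # Xavier/Glorot, averaged form: variance 2 / (fan_in + fan_out)
    return np.sqrt(2.0 / (fan_in + fan_out))

fan_in, fan_out = 256, 256
print("He:              ", he_std(fan_in))
print("Xavier (fan-in): ", xavier_fan_in_std(fan_in))
print("Xavier (average):", xavier_avg_std(fan_in, fan_out))
```

For a square layer the two Xavier forms coincide, and the He standard deviation is larger than both by a factor of $\sqrt{2}$.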

Notes

  • The factor 2 compensates for ReLU zeroing out roughly half of its inputs, which halves the second moment of the activations. For tanh you would use Xavier's factor of 1 instead.
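  • Initializing all weights to the same constant fails to break symmetry: every neuron in a layer would compute the same output, receive the same gradient, and therefore learn the same thing. Random initialization is what breaks this symmetry, so it is essential.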
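The symmetry point can be checked with a minimal sketch. It assumes a one-hidden-layer ReLU network trained with a mean squared error loss; the helper `hidden_weight_grad`, the layer sizes, and the constant 0.1 are illustrative choices, not part of the problem.

```python
import numpy as np

def hidden_weight_grad(w1, w2, x, y):
    """Gradient of 0.5 * mean((pred - y)^2) w.r.t. w1 for pred = relu(x @ w1) @ w2."""
    z = x @ w1                  # (batch, hidden) pre-activations
    h = np.maximum(z, 0.0)      # ReLU
    pred = h @ w2               # (batch, 1)
    err = (pred - y) / len(x)   # dL/dpred
    dz = (err @ w2.T) * (z > 0) # backprop through w2 and the ReLU
    return x.T @ dz             # (in, hidden): one column per hidden neuron

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 4))
y = rng.standard_normal((32, 1))

# Constant init: every hidden neuron starts out identical.
g_const = hidden_weight_grad(np.full((4, 3), 0.1), np.full((3, 1), 0.1), x, y)

# He init: hidden neurons start out different.
w1 = rng.standard_normal((4, 3)) * np.sqrt(2.0 / 4)
w2 = rng.standard_normal((3, 1)) * np.sqrt(2.0 / 3)
g_rand = hidden_weight_grad(w1, w2, x, y)

# Compare every neuron's gradient column against the first neuron's column.
same_const = bool(np.allclose(g_const, g_const[:, :1]))
same_rand = bool(np.allclose(g_rand, g_rand[:, :1]))
print("constant init, identical gradients:", same_const)
print("He init, identical gradients:      ", same_rand)
```

With constant weights every hidden neuron receives the same update, so the neurons stay identical after any number of gradient steps.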

This problem ships 4 hidden tests, which run in your browser via Pyodide (no backend, no submission queue). They check:

  • Correct output shape
  • Reproducible with the same seed
  • Empirical std matches sqrt(2/fan_in)
  • Mean is approximately zero
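The hidden tests themselves are not shown, but checks of this kind could look roughly like the sketch below. The function `check_he_init`, the matrix size, and the 2% tolerance are assumptions made for illustration, and `candidate` is a stand-in for your own he_init.

```python
import numpy as np

def check_he_init(init_fn, fan_in=256, fan_out=512, seed=0, tol=0.02):
    """Run four sanity checks on a candidate initializer; return a dict of booleans."""
    w = init_fn(fan_in, fan_out, seed)
    again = init_fn(fan_in, fan_out, seed)
    target = np.sqrt(2.0 / fan_in)
    return {
        "shape": w.shape == (fan_in, fan_out),
        "reproducible": bool(np.array_equal(w, again)),
        "std": bool(abs(w.std() - target) < tol * target),
        "mean": bool(abs(w.mean()) < tol * target),
    }

def candidate(fan_in, fan_out, seed=42):
    # Stand-in implementation so the sketch runs on its own.
    rng = np.random.default_rng(seed)
    return rng.standard_normal((fan_in, fan_out)) * np.sqrt(2.0 / fan_in)

results = check_he_init(candidate)
for name, passed in results.items():
    print(name, "ok" if passed else "FAILED")
```

A large matrix is used on purpose: statistical checks on the mean and standard deviation need many samples before a tight tolerance becomes reliable.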