Noisy top-k gating

Difficulty: Medium

Background

The noisy top-k gating function (Shazeer et al., 2017) is the router that makes sparse mixture-of-experts trainable. To each expert's gate logit it adds tunable Gaussian noise scaled by a learned, per-expert softplus term. The noise encourages exploration and load balancing across experts during training. Only the top-k noisy logits survive; the rest are masked to $-\infty$ before a softmax, producing a sparse gate distribution.

Problem statement

Implement noisy_topk_gating(X, W_g, W_noise, N, k):

$$H = X W_g + \mathcal{N} \odot \operatorname{softplus}(X W_{\text{noise}}), \qquad \operatorname{softplus}(z) = \log(1 + e^{z})$$

For each row, keep the top-k entries of $H$ and set the rest to $-\infty$, then apply a softmax across experts.
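The noisy-logit step can be sketched in NumPy as follows. This is a minimal sketch, not the reference solution: the helper names softplus and noisy_logits are hypothetical, and computing softplus through np.logaddexp is an assumption made here to avoid overflow for large pre-activations.

```python
import numpy as np

def softplus(z):
    # log(1 + e^z) written as logaddexp(0, z), which stays finite for large z.
    return np.logaddexp(0.0, z)

def noisy_logits(X, W_g, W_noise, N):
    # Hypothetical helper: clean gate logits plus noise scaled per token and per expert.
    clean = X @ W_g                      # (n_tokens, n_experts)
    scale = softplus(X @ W_noise)        # strictly positive noise scale
    return clean + N * scale             # elementwise product with the noise samples

# Values taken from Example 1 of this problem.
X = np.array([[1.0, 2.0]])
W_g = np.eye(2)
W_noise = np.full((2, 2), 0.5)
N = np.array([[1.0, -1.0]])
H = noisy_logits(X, W_g, W_noise, N)
print(H)
```

Top-k masking and the softmax are applied to H afterwards.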

Input

  • X: np.ndarray (n_tokens, d_model).
  • W_g: gate weights (d_model, n_experts).
  • W_noise: noise-scale weights (d_model, n_experts).
  • N: precomputed noise samples (n_tokens, n_experts) (passed in for determinism).
  • k: int, number of experts to keep per token.

Output

An np.ndarray (n_tokens, n_experts): each row is a sparse softmax with at most k nonzero entries summing to 1.

Examples

Example 1

Input:  X = [[1, 2]], W_g = [[1,0],[0,1]], W_noise = [[0.5,0.5],[0.5,0.5]], N = [[1,-1]], k = 2
Output: [[0.917, 0.083]]

Explanation: base logits $[1, 2]$; softplus of the noise pre-activations $[1.5, 1.5]$ is $\approx 1.701$; with noise $[+1, -1]$ the logits become $[2.701, 0.299]$. Softmax over both gives $\approx [0.917, 0.083]$.
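The final softmax step can be checked by hand: with only two experts, the softmax reduces to a logistic function of the gap between the two logits. In the worked arithmetic below, values are rounded to four decimals, and the symbol $G$ for the output gate row is introduced here for illustration only.

$$
\begin{aligned}
\operatorname{softplus}(1.5) &= \log(1 + e^{1.5}) \approx 1.7014 \\
H &= [\,1 + 1.7014,\; 2 - 1.7014\,] = [\,2.7014,\; 0.2986\,] \\
G_1 &= \frac{e^{H_1}}{e^{H_1} + e^{H_2}} = \frac{1}{1 + e^{-(H_1 - H_2)}} = \frac{1}{1 + e^{-2.4028}} \approx 0.917 \\
G_2 &= 1 - G_1 \approx 0.083
\end{aligned}
$$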

Constraints

  • Add the noise as N * softplus(X @ W_noise) (elementwise), not raw X @ W_noise.
  • Mask all but the top-k entries per row to $-\infty$ before the softmax.
  • Each output row sums to 1 with at most k nonzero entries.
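One way to satisfy the masking and normalization constraints is sketched below. The helper name topk_masked_softmax is hypothetical, and selecting the top-k indices with np.argpartition is a choice made for this sketch. When several logits tie at the k-th place, which of them survives is arbitrary.

```python
import numpy as np

def topk_masked_softmax(H, k):
    # Hypothetical helper: keep the k largest logits per row, mask the rest to -inf.
    H = np.asarray(H, dtype=float)
    # argpartition puts the indices of the k largest entries in the last k slots (unordered).
    top_idx = np.argpartition(H, -k, axis=1)[:, -k:]
    masked = np.full_like(H, -np.inf)
    np.put_along_axis(masked, top_idx, np.take_along_axis(H, top_idx, axis=1), axis=1)
    # Subtract the row max for numerical stability; exp(-inf) evaluates to exactly 0.
    shifted = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=1, keepdims=True)

H = np.array([[2.0, 0.5, 1.0],
              [0.1, 0.3, 0.2]])
G = topk_masked_softmax(H, k=2)
print(G)
```

Each row of G has exactly two nonzero entries here, because the kept logits are finite and distinct.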

Notes

  • Because the dropped logits become $-\infty$, their softmax weight is exactly 0: the gate is genuinely sparse, not just small.
  • During training the noise term is what spreads tokens across experts; at inference it is usually turned off (N = 0).
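Both notes can be illustrated with a few lines of NumPy. The logits below are made-up illustrative values, and the matrices in the inference part reuse Example 1.

```python
import numpy as np

logits = np.array([2.0, 0.5, 1.0])

# Dense softmax: the smallest logit still receives a small positive weight.
dense = np.exp(logits - logits.max())
dense /= dense.sum()

# Masked softmax (k = 2): the dropped logit is -inf, so its weight is exactly 0.0.
masked = np.array([2.0, -np.inf, 1.0])
sparse = np.exp(masked - masked.max())
sparse /= sparse.sum()

# Inference: with N = 0 the noise term vanishes and H equals the clean logits X @ W_g.
X = np.array([[1.0, 2.0]])
W_g = np.eye(2)
W_noise = np.full((2, 2), 0.5)
N_eval = np.zeros((1, 2))
H_eval = X @ W_g + N_eval * np.logaddexp(0.0, X @ W_noise)

print(dense[1] > 0.0, sparse[1] == 0.0, np.array_equal(H_eval, X @ W_g))
```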

This problem ships 5 hidden tests:

  • Reference example
  • Each row sums to 1
  • k = 1 yields a one-hot gate
  • Zero noise with k = n_experts is plain softmax of X @ W_g
  • Output shape is (n_tokens, n_experts)
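The five checks above could be expressed roughly as follows. The hidden tests themselves are not shown on the page, so the random inputs, seed, and tolerances here are assumptions, and the compact noisy_topk_gating is one possible implementation included only to make the sketch runnable.

```python
import numpy as np

def noisy_topk_gating(X, W_g, W_noise, N, k):
    # One possible implementation, not the site's reference solution.
    H = X @ W_g + N * np.logaddexp(0.0, X @ W_noise)
    top_idx = np.argpartition(H, -k, axis=1)[:, -k:]
    masked = np.full_like(H, -np.inf)
    np.put_along_axis(masked, top_idx, np.take_along_axis(H, top_idx, axis=1), axis=1)
    e = np.exp(masked - masked.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n_tokens, d_model, n_experts = 4, 3, 5
X = rng.normal(size=(n_tokens, d_model))
W_g = rng.normal(size=(d_model, n_experts))
W_noise = rng.normal(size=(d_model, n_experts))
N = rng.normal(size=(n_tokens, n_experts))

# 1. Reference example (Example 1).
ref = noisy_topk_gating(np.array([[1.0, 2.0]]), np.eye(2),
                        np.full((2, 2), 0.5), np.array([[1.0, -1.0]]), 2)
assert np.allclose(ref, [[0.917, 0.083]], atol=1e-3)

# 2. Each row sums to 1.
G = noisy_topk_gating(X, W_g, W_noise, N, 2)
assert np.allclose(G.sum(axis=1), 1.0)

# 3. k = 1 yields a one-hot gate.
G1 = noisy_topk_gating(X, W_g, W_noise, N, 1)
assert np.all((G1 > 0).sum(axis=1) == 1) and np.allclose(G1.max(axis=1), 1.0)

# 4. Zero noise with k = n_experts is the plain softmax of X @ W_g.
logits = X @ W_g
plain = np.exp(logits - logits.max(axis=1, keepdims=True))
plain /= plain.sum(axis=1, keepdims=True)
G_full = noisy_topk_gating(X, W_g, W_noise, np.zeros_like(N), n_experts)
assert np.allclose(G_full, plain)

# 5. Output shape is (n_tokens, n_experts).
assert G.shape == (n_tokens, n_experts)
```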