Noisy top-k gating
Background
The noisy top-k gating function (Shazeer et al., 2017) is the router that makes sparse mixture-of-experts layers trainable. To each expert's gate logit it adds Gaussian noise whose scale is an input-dependent, per-expert softplus term, $\mathrm{softplus}\big((x W_{\text{noise}})_i\big)$. That scale is learned through $W_{\text{noise}}$ and is always positive. The noise encourages exploration and helps balance load across experts during training. Only the top-k noisy logits survive; the rest are masked to $-\infty$ before a softmax, producing a sparse gate distribution.
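Softplus, $\ln(1 + e^z)$, overflows if evaluated literally for large $z$. The following is a minimal sketch, assuming NumPy: the helper name `softplus` is illustrative and is not part of the problem's API. It relies on `np.logaddexp(0, z)`, which computes $\ln(e^0 + e^z)$ without overflow.

```python
import numpy as np


def softplus(z):
    """Numerically stable ln(1 + e^z).

    np.logaddexp(a, b) evaluates ln(e^a + e^b) without overflow, so
    logaddexp(0, z) is ln(1 + e^z) even when e^z exceeds float64 range.
    """
    return np.logaddexp(0.0, np.asarray(z, dtype=float))


z = np.array([-50.0, 0.0, 1.5, 1000.0])
print(softplus(z))  # tiny positive value, ln 2, about 1.701, and 1000.0
```

The output is positive for any moderate input, which is why softplus is used as a noise scale. For very large inputs it approaches the identity rather than overflowing to `inf`.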
Problem statement
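Implement noisy_topk_gating(X, W_g, W_noise, N, k). For token t (row $x_t$ of X), compute the noisy logits

$$H_{t,i} = (x_t W_g)_i + N_{t,i} \cdot \mathrm{softplus}\big((x_t W_{\text{noise}})_i\big), \qquad \mathrm{softplus}(z) = \ln(1 + e^z).$$

For each row, keep the top-k entries of $H_t$ and set the rest to $-\infty$, then apply a softmax across experts:

$$G_t = \mathrm{softmax}\big(\mathrm{KeepTopK}(H_t, k)\big).$$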
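A minimal NumPy sketch of the whole gate follows, assuming float-convertible inputs. Three parts of it are choices made for this sketch and do not come from the problem statement: the `softplus` helper, the use of `np.argpartition` to pick the top-k columns, and raising `ValueError` for an out-of-range `k`. When logits tie at the k-th position, `argpartition` still keeps exactly k entries and breaks the tie arbitrarily.

```python
import numpy as np


def softplus(z):
    # Numerically stable ln(1 + e^z).
    return np.logaddexp(0.0, z)


def noisy_topk_gating(X, W_g, W_noise, N, k):
    X = np.asarray(X, dtype=float)
    W_g = np.asarray(W_g, dtype=float)
    W_noise = np.asarray(W_noise, dtype=float)
    N = np.asarray(N, dtype=float)

    # Noisy logits H = X W_g + N * softplus(X W_noise), shape (n_tokens, n_experts).
    H = X @ W_g + N * softplus(X @ W_noise)
    n_experts = H.shape[1]
    if not 1 <= k <= n_experts:
        # Assumption: the problem does not say how to handle an out-of-range k.
        raise ValueError(f"k must be in [1, {n_experts}], got {k}")

    # Column indices of the k largest logits in each row (unordered).
    top_idx = np.argpartition(H, n_experts - k, axis=1)[:, n_experts - k:]

    # KeepTopK: -inf everywhere except the kept logits.
    masked = np.full_like(H, -np.inf)
    np.put_along_axis(masked, top_idx, np.take_along_axis(H, top_idx, axis=1), axis=1)

    # Row-wise softmax. The row max is a kept (finite) logit, so the shift is safe,
    # and exp(-inf) evaluates to exactly 0.0 for the dropped experts.
    weights = np.exp(masked - masked.max(axis=1, keepdims=True))
    return weights / weights.sum(axis=1, keepdims=True)


out = noisy_topk_gating([[1, 2]], [[1, 0], [0, 1]], [[0.5, 0.5], [0.5, 0.5]], [[1, -1]], k=2)
print(np.round(out, 3))  # [[0.917 0.083]]
```

Subtracting the row maximum before `np.exp` is the standard stable-softmax trick. It does not change the result, because the shift cancels in the normalization.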
Input
- X: np.ndarray of shape (n_tokens, d_model).
- W_g: gate weights, shape (d_model, n_experts).
- W_noise: noise-scale weights, shape (d_model, n_experts).
- N: precomputed noise samples, shape (n_tokens, n_experts), passed in for determinism.
- k: int, the number of experts to keep per token.
Output
An np.ndarray (n_tokens, n_experts): each row is a sparse softmax with at most k nonzero entries summing to 1.
Examples
Example 1
Input: X = [[1, 2]], W_g = [[1,0],[0,1]], W_noise = [[0.5,0.5],[0.5,0.5]], N = [[1,-1]], k = 2
Output: [[0.917, 0.083]]
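Explanation: the base logits are X @ W_g = [1, 2]. The noise pre-activations X @ W_noise = [1.5, 1.5] pass through softplus to give ≈ [1.701, 1.701]. Adding N * softplus yields the noisy logits ≈ [2.701, 0.299]. Since k = 2 = n_experts, nothing is masked, and the softmax over both gives ≈ [0.917, 0.083].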
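The Example 1 arithmetic can be reproduced one intermediate at a time. This sketch assumes NumPy and computes softplus as `np.logaddexp(0, z)`. Because k equals n_experts here, it skips masking entirely.

```python
import numpy as np

X = np.array([[1.0, 2.0]])
W_g = np.array([[1.0, 0.0], [0.0, 1.0]])
W_noise = np.array([[0.5, 0.5], [0.5, 0.5]])
N = np.array([[1.0, -1.0]])

base = X @ W_g                   # [[1.0, 2.0]]
pre = X @ W_noise                # [[1.5, 1.5]]
scale = np.logaddexp(0.0, pre)   # softplus(pre), about [[1.701, 1.701]]
H = base + N * scale             # about [[2.701, 0.299]]

# k = 2 = n_experts, so nothing is masked: this is a plain row-wise softmax.
p = np.exp(H - H.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)
print(np.round(p, 3))            # [[0.917 0.083]]
```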
Constraints
- Add the noise as N * softplus(X @ W_noise) (elementwise), not raw X @ W_noise.
- Mask all but the top-k entries per row to $-\infty$ before the softmax.
- Each output row sums to 1 with at most k nonzero entries.
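The first constraint is easy to get wrong, so the sketch below runs Example 1 both ways: once with the softplus scale and once with the raw pre-activation. It assumes NumPy. The helper `softmax_rows` is a hypothetical name used only for this illustration.

```python
import numpy as np


def softmax_rows(H):
    # Row-wise softmax with the usual max-shift for numerical stability.
    p = np.exp(H - H.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)


# Example 1 inputs; k = 2 = n_experts, so no top-k masking is needed here.
X = np.array([[1.0, 2.0]])
W_g = np.eye(2)
W_noise = np.full((2, 2), 0.5)
N = np.array([[1.0, -1.0]])

pre = X @ W_noise                                             # [[1.5, 1.5]]
correct = softmax_rows(X @ W_g + N * np.logaddexp(0.0, pre))  # softplus scale
wrong = softmax_rows(X @ W_g + N * pre)                       # raw pre-activation as scale

print(np.round(correct, 3))  # [[0.917 0.083]]
print(np.round(wrong, 3))    # [[0.881 0.119]]
```

Beyond giving the wrong numbers, the raw pre-activation can be negative. A negative scale would silently flip the sign of the noise, whereas softplus keeps the scale positive.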
Notes
- Because the dropped logits become $-\infty$, their softmax weight is exactly 0: the gate is genuinely sparse, not just small.
- During training the noise term helps spread tokens across experts; at inference it is usually turned off (N = 0).
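Both notes can be checked numerically on random data. This sketch assumes NumPy. The helper `topk_softmax` and the random shapes (3 tokens, 8 features, 4 experts, k = 2) are illustrative choices and do not come from the problem.

```python
import numpy as np


def topk_softmax(H, k):
    # Keep the k largest logits per row, mask the rest to -inf, then softmax.
    n = H.shape[1]
    idx = np.argpartition(H, n - k, axis=1)[:, n - k:]
    masked = np.full_like(H, -np.inf)
    np.put_along_axis(masked, idx, np.take_along_axis(H, idx, axis=1), axis=1)
    p = np.exp(masked - masked.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
X = rng.normal(size=(3, 8))
W_g = rng.normal(size=(8, 4))
W_noise = rng.normal(size=(8, 4))
scale = np.logaddexp(0.0, X @ W_noise)  # softplus noise scale

# Inference: N = 0 removes the noise term, so routing depends only on X @ W_g.
G_infer = topk_softmax(X @ W_g + np.zeros((3, 4)) * scale, k=2)
# Training: a fresh standard-normal noise sample per step.
G_train = topk_softmax(X @ W_g + rng.standard_normal((3, 4)) * scale, k=2)

# Dropped experts receive exactly 0.0, not merely a small weight.
print((G_infer == 0.0).sum(axis=1))  # [2 2 2]
```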
This problem ships 5 hidden tests:
- Reference example
- Each row sums to 1
- k = 1 yields a one-hot gate
- Zero noise with k = n_experts is a plain softmax of X @ W_g
- Output shape is (n_tokens, n_experts)
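Local checks that mirror the five listed test descriptions might look like the following. The hidden tests' actual inputs are not shown, so the random data, the seed, and the compact `noisy_topk_gating` below are assumptions made for illustration, not the grader's code.

```python
import numpy as np


def noisy_topk_gating(X, W_g, W_noise, N, k):
    # Compact gate: noisy logits, top-k mask to -inf, row-wise softmax.
    H = X @ W_g + N * np.logaddexp(0.0, X @ W_noise)
    n = H.shape[1]
    idx = np.argpartition(H, n - k, axis=1)[:, n - k:]
    masked = np.full_like(H, -np.inf)
    np.put_along_axis(masked, idx, np.take_along_axis(H, idx, axis=1), axis=1)
    p = np.exp(masked - masked.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)


def softmax(Z):
    # Plain row-wise softmax, used as the reference for the zero-noise check.
    p = np.exp(Z - Z.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)


rng = np.random.default_rng(42)
n_tokens, d_model, n_experts = 5, 6, 4
X = rng.normal(size=(n_tokens, d_model))
W_g = rng.normal(size=(d_model, n_experts))
W_noise = rng.normal(size=(d_model, n_experts))
N = rng.standard_normal(size=(n_tokens, n_experts))

# 1. Reference example.
ex = noisy_topk_gating(np.array([[1.0, 2.0]]), np.eye(2), np.full((2, 2), 0.5),
                       np.array([[1.0, -1.0]]), 2)
assert np.allclose(ex, [[0.917, 0.083]], atol=1e-3)

# 2. Each row sums to 1, with at most k nonzero entries.
G = noisy_topk_gating(X, W_g, W_noise, N, 2)
assert np.allclose(G.sum(axis=1), 1.0)
assert ((G > 0).sum(axis=1) <= 2).all()

# 3. k = 1 yields a one-hot gate.
G1 = noisy_topk_gating(X, W_g, W_noise, N, 1)
assert ((G1 > 0).sum(axis=1) == 1).all() and np.allclose(G1.max(axis=1), 1.0)

# 4. Zero noise with k = n_experts is a plain softmax of X @ W_g.
G0 = noisy_topk_gating(X, W_g, W_noise, np.zeros_like(N), n_experts)
assert np.allclose(G0, softmax(X @ W_g))

# 5. Output shape is (n_tokens, n_experts).
assert G.shape == (n_tokens, n_experts)
print("all checks passed")
```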