Sparse mixture-of-experts layer
Background
A sparse mixture-of-experts (MoE) layer replaces a single dense feed-forward block with many "expert" sub-networks plus a lightweight gating network. For each token the gate scores all experts, but only the top-k are actually run; their outputs are combined with the (renormalized) gate weights. This decouples model capacity (many experts) from per-token compute (only k run), which is how models like Mixtral grow their parameter count without a proportional increase in per-token compute.
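To make the capacity/compute split concrete, here is a small sketch that counts parameters and per-token multiply-adds for the layer defined in this problem (square d_model x d_model experts, no biases). The helper name moe_cost and the sizes in the usage lines are illustrative assumptions, not figures from any particular model.

```python
def moe_cost(n_experts: int, d_model: int, top_k: int) -> dict:
    """Hypothetical helper: parameter count and per-token multiply-adds (MACs)
    for an MoE layer with square d_model x d_model experts and no biases."""
    expert_params = n_experts * d_model * d_model  # every expert is stored
    gate_params = d_model * n_experts              # Wg
    # Per token, the gate scores every expert, but only top_k experts run.
    gate_macs = d_model * n_experts
    expert_macs = top_k * d_model * d_model
    return {
        "params": expert_params + gate_params,
        "macs_per_token": gate_macs + expert_macs,
    }


small = moe_cost(n_experts=8, d_model=1024, top_k=2)
large = moe_cost(n_experts=64, d_model=1024, top_k=2)
# 8x the parameters, but only about 2.7% more per-token compute (the gate term).
print(small)
print(large)
```

Only the gate's cost scales with n_experts per token; the expert term depends on top_k alone.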
Problem statement
Implement moe(x, We, Wg, n_experts, top_k):
- Flatten x to (n_tokens, d_model), where n_tokens = n_batch * l_seq.
- Gating: logits = x_flat @ Wg, then softmax over experts.
- Select the top-k experts per token and renormalize their gate weights to sum to 1.
- Each selected expert i computes x_flat[t] @ We[i] for its token t; scale the result by the renormalized weight and scatter-add it into that token's output.
- Reshape back to (n_batch, l_seq, d_model).
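The steps above can be sketched in NumPy as follows. This is one possible implementation, not an official reference solution. It loops over experts rather than tokens, so each expert runs a single batched matmul on the tokens routed to it. Breaking ties toward the lower expert index (via a stable sort) is an assumption, since the statement does not specify tie-breaking.

```python
import numpy as np


def moe(x, We, Wg, n_experts, top_k):
    n_batch, l_seq, d_model = x.shape
    x_flat = x.reshape(-1, d_model)                  # (n_tokens, d_model)

    # Gating: numerically stable softmax over the expert axis.
    logits = x_flat @ Wg                             # (n_tokens, n_experts)
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)

    # Top-k per token; a stable sort on -probs keeps the lower index on ties.
    top_idx = np.argsort(-probs, axis=1, kind="stable")[:, :top_k]
    top_w = np.take_along_axis(probs, top_idx, axis=1)
    top_w /= top_w.sum(axis=1, keepdims=True)        # renormalize to sum to 1

    out = np.zeros((x_flat.shape[0], d_model))
    for e in range(n_experts):
        tok, slot = np.nonzero(top_idx == e)         # tokens routed to expert e
        if tok.size == 0:
            continue
        y = x_flat[tok] @ We[e]                      # run expert e once, batched
        # A token selects expert e at most once, so tok has no repeats and
        # fancy-index += accumulates correctly here.
        out[tok] += top_w[tok, slot][:, None] * y
    return out.reshape(n_batch, l_seq, d_model)


x = np.arange(12).reshape(2, 3, 2)
print(moe(x, np.ones((4, 2, 2)), np.ones((2, 4)), n_experts=4, top_k=1))
```

Grouping tokens by expert is a design choice for efficiency; a per-token loop over its top-k experts gives the same result.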
Input
- x: np.ndarray of shape (n_batch, l_seq, d_model).
- We: experts, shape (n_experts, d_model, d_model).
- Wg: gate weights, shape (d_model, n_experts).
- n_experts: int.
- top_k: int, experts kept per token.
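The shape contract above can be checked up front with a small helper. The name check_moe_inputs is hypothetical, and the requirement 1 <= top_k <= n_experts is an assumption; the statement does not state a range for top_k.

```python
import numpy as np


def check_moe_inputs(x, We, Wg, n_experts, top_k):
    """Hypothetical helper: validate the Input shapes and return d_model."""
    if x.ndim != 3:
        raise ValueError(f"x must be (n_batch, l_seq, d_model), got {x.shape}")
    d_model = x.shape[-1]
    if We.shape != (n_experts, d_model, d_model):
        raise ValueError(f"We must be {(n_experts, d_model, d_model)}, got {We.shape}")
    if Wg.shape != (d_model, n_experts):
        raise ValueError(f"Wg must be {(d_model, n_experts)}, got {Wg.shape}")
    # Assumed bound: at least one expert per token, no more than exist.
    if not 1 <= top_k <= n_experts:
        raise ValueError(f"top_k must be in [1, {n_experts}], got {top_k}")
    return d_model


print(check_moe_inputs(np.zeros((2, 3, 2)), np.ones((4, 2, 2)), np.ones((2, 4)), 4, 1))
```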
Output
An np.ndarray of shape (n_batch, l_seq, d_model).
Examples
Example 1
Input: x = arange(12).reshape(2, 3, 2), We = ones((4, 2, 2)), Wg = ones((2, 4)), n_experts = 4, top_k = 1
Output: [[[1, 1], [5, 5], [9, 9]], [[13, 13], [17, 17], [21, 21]]]
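Explanation: with Wg all ones, every gate logit for a token equals the sum of its features, so the softmax is uniform (0.25 per expert). With top_k = 1 a single expert is kept, and renormalization gives it weight 1. An all-ones expert maps a token to the sum of its features in every coordinate, e.g. [0, 1] -> [1, 1] and [10, 11] -> [21, 21].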
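The explanation can be traced step by step in NumPy. This sketch reproduces only the arithmetic of Example 1 (uniform gate, one all-ones expert); it is not a general implementation.

```python
import numpy as np

x_flat = np.arange(12).reshape(2, 3, 2).reshape(-1, 2)  # six tokens, d_model = 2
Wg = np.ones((2, 4))
We0 = np.ones((2, 2))                                 # every expert is all-ones

logits = x_flat @ Wg          # each row repeats the token's feature sum 4 times
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
print(probs[0])               # [0.25 0.25 0.25 0.25]

# top_k = 1: keep the largest probability, then renormalize (0.25 / 0.25 = 1).
top_w = np.sort(probs, axis=1)[:, -1:]
top_w = top_w / top_w.sum(axis=1, keepdims=True)
out = top_w * (x_flat @ We0)  # all-ones expert: each coordinate = feature sum
print(out.reshape(2, 3, 2).tolist())  # matches the Example 1 output
```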
Constraints
- Renormalize the top-k gate weights so they sum to 1 after selection.
- Only the selected experts contribute; combine via a weighted (scatter) sum.
- The output keeps the input's (n_batch, l_seq, d_model) shape.
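When contributions are gathered per (token, expert) pair, the same token index appears top_k times, and the scatter sum must accumulate all of them. In NumPy, buffered fancy-index assignment (a[idx] += b) does not accumulate repeated indices, while np.add.at does. The pair layout and values below are illustrative assumptions.

```python
import numpy as np

# Hypothetical token-major layout for 2 tokens with top_k = 2: each row is
# (renormalized weight * expert output) for one (token, expert) pair.
token_of_pair = np.array([0, 0, 1, 1])   # each token appears twice
contrib = np.array([[0.75, 0.75],
                    [0.25, 0.25],
                    [0.5, 0.5],
                    [0.5, 0.5]])          # d_model = 2

naive = np.zeros((2, 2))
naive[token_of_pair] += contrib          # repeated indices are NOT accumulated

summed = np.zeros((2, 2))
np.add.at(summed, token_of_pair, contrib)  # unbuffered: every pair is added

print(naive)    # each token keeps only one of its contributions
print(summed)   # every entry is 1.0: both weighted contributions were summed
```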
Notes
- Renormalizing after top-k is important: the dropped experts' probability mass would otherwise leak away, shrinking the output.
- In practice, MoE training adds an auxiliary load-balancing loss so the gate does not collapse onto a few favorite experts.
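Both notes can be illustrated numerically. The first function shows how much gate mass top-k drops when weights are not renormalized. The second is one common load-balancing formulation, the Switch Transformer auxiliary loss (n_experts * sum_i f_i * P_i, usually scaled by a small coefficient); it is offered as an example of the idea, not as the loss any particular model uses, and both function names are hypothetical.

```python
import numpy as np


def top_k_weights(probs, k, renormalize=True):
    """Keep the k largest gate probabilities per token, optionally renormalized."""
    idx = np.argsort(-probs, axis=-1, kind="stable")[..., :k]
    w = np.take_along_axis(probs, idx, axis=-1)
    if renormalize:
        w = w / w.sum(axis=-1, keepdims=True)
    return idx, w


def switch_load_balancing_loss(probs, top1_idx, n_experts):
    """Switch-Transformer-style auxiliary loss. f_i is the fraction of tokens
    whose top-1 expert is i; P_i is the mean gate probability for expert i.
    It equals 1.0 when both are uniform and approaches n_experts when every
    token collapses onto one expert with probability near 1."""
    f = np.bincount(top1_idx, minlength=n_experts) / len(top1_idx)
    P = probs.mean(axis=0)
    return n_experts * np.sum(f * P)


probs = np.array([[0.4, 0.3, 0.2, 0.1]])
_, raw = top_k_weights(probs, 2, renormalize=False)
_, renorm = top_k_weights(probs, 2)
print(raw.sum(), renorm.sum())  # about 0.7 vs 1.0: 30% of the mass would leak

balanced = np.full((4, 4), 0.25)
print(switch_load_balancing_loss(balanced, np.array([0, 1, 2, 3]), 4))
collapsed = np.tile([0.97, 0.01, 0.01, 0.01], (4, 1))
print(switch_load_balancing_loss(collapsed, np.zeros(4, dtype=int), 4))
```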
This problem ships 5 hidden tests, which run in the browser via Pyodide:
- Reference example
- Identity experts with top_k=1 pass tokens through unchanged
- Routing selects the higher-logit expert
- top_k=2 blends experts with renormalized weights
- Output keeps the input shape