Sparse mixture-of-experts layer (Hard)

Background

A sparse mixture-of-experts (MoE) layer replaces a single dense feed-forward block with many "expert" sub-networks plus a lightweight gating network. For each token the gate scores all experts, but only the top-k are actually run, and their outputs are combined using the renormalized gate weights. This decouples model capacity (many experts) from per-token compute (only k experts run), which is how models such as Mixtral increase their parameter count without a proportional increase in per-token compute.
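The capacity/compute split can be made concrete with a little arithmetic. The sketch below uses hypothetical sizes (d_model = 4096, 8 experts, top_k = 2) and the single-matrix experts of this problem. Real models use full feed-forward experts rather than one matrix, so the absolute numbers are illustrative only; the point is the ratio.

```python
# Hypothetical sizes chosen for illustration only (not taken from any specific model).
d_model = 4096
n_experts = 8
top_k = 2

# In this problem each expert is a single (d_model, d_model) matrix.
params_per_expert = d_model * d_model
total_expert_params = n_experts * params_per_expert  # capacity grows with n_experts

# Per token only top_k experts run; each token-times-matrix product costs
# d_model * d_model multiply-adds.
macs_per_token = top_k * params_per_expert  # compute grows with top_k
gate_macs_per_token = d_model * n_experts   # the gate itself is cheap

print(f"expert parameters:              {total_expert_params:,}")   # 134,217,728
print(f"expert multiply-adds per token: {macs_per_token:,}")        # 33,554,432
print(f"gate multiply-adds per token:   {gate_macs_per_token:,}")   # 32,768
print(f"fraction of expert weights used per token: {top_k / n_experts:.2f}")  # 0.25
```

Adding experts raises the first number while leaving the second unchanged; only raising top_k increases per-token expert compute.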

Problem statement

Implement moe(x, We, Wg, n_experts, top_k):

  1. Flatten x to (n_tokens, d_model).
  2. Gating: logits = x_flat @ Wg, then softmax over experts.
  3. Select the top-k experts per token and renormalize their gate weights to sum to 1.
  4. Each selected expert i computes x_t @ We[i] for its token x_t (a row of x_flat); scale the result by that token's renormalized gate weight and scatter-add it into the token's output row.
  5. Reshape back to (n_batch, l_seq, d_model).
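The five steps above can be sketched in NumPy as follows. This is one possible implementation, not a reference solution, and it makes a few choices the statement leaves open: the softmax subtracts each row's maximum for numerical stability, ties between equal gate scores go to the lower expert index (a stable sort), and the code loops over experts so that each expert runs only on the tokens routed to it.

```python
import numpy as np


def moe(x, We, Wg, n_experts, top_k):
    """Sparse MoE forward pass following steps 1-5 (a sketch)."""
    n_batch, l_seq, d_model = x.shape

    # Step 1: flatten to (n_tokens, d_model).
    x_flat = x.reshape(-1, d_model)
    n_tokens = x_flat.shape[0]

    # Step 2: gate logits, then a numerically stable softmax over experts.
    logits = x_flat @ Wg                                   # (n_tokens, n_experts)
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=1, keepdims=True)

    # Step 3: top-k expert indices per token (stable sort: ties -> lower index),
    # then renormalize the kept weights so they sum to 1 for each token.
    top_idx = np.argsort(-probs, axis=1, kind="stable")[:, :top_k]  # (n_tokens, top_k)
    top_w = np.take_along_axis(probs, top_idx, axis=1)
    top_w = top_w / top_w.sum(axis=1, keepdims=True)

    # Step 4: run each expert only on its routed tokens and scatter-add
    # the weighted results into those tokens' output rows.
    out = np.zeros((n_tokens, d_model), dtype=np.result_type(x_flat, We, probs))
    for e in range(n_experts):
        tok, slot = np.nonzero(top_idx == e)  # tokens (and top-k slots) that chose e
        if tok.size == 0:
            continue
        y = x_flat[tok] @ We[e]               # (n_selected, d_model)
        np.add.at(out, tok, top_w[tok, slot][:, None] * y)

    # Step 5: restore (n_batch, l_seq, d_model).
    return out.reshape(n_batch, l_seq, d_model)


# Example 1 from this problem.
x = np.arange(12).reshape(2, 3, 2)
print(moe(x, np.ones((4, 2, 2)), np.ones((2, 4)), n_experts=4, top_k=1).tolist())
# [[[1.0, 1.0], [5.0, 5.0], [9.0, 9.0]], [[13.0, 13.0], [17.0, 17.0], [21.0, 21.0]]]
```

Looping over experts is what makes the layer sparse: a token that no expert selected never touches that expert's weights. A simpler dense variant would run every expert on every token and multiply by a mask, giving the same output at roughly n_experts / top_k times the compute.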

Input

  • x — np.ndarray of shape (n_batch, l_seq, d_model).
  • We — experts, shape (n_experts, d_model, d_model).
  • Wg — gate weights, shape (d_model, n_experts).
  • n_experts — int, the number of experts.
  • top_k — int, the number of experts kept per token.

Output

An np.ndarray of shape (n_batch, l_seq, d_model).

Examples

Example 1

Input:  x = arange(12).reshape(2, 3, 2), We = ones((4, 2, 2)), Wg = ones((2, 4)), n_experts = 4, top_k = 1
Output: [[[1, 1], [5, 5], [9, 9]], [[13, 13], [17, 17], [21, 21]]]

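Explanation: for each token, all four gate logits are equal, so the softmax assigns 0.25 to every expert. top_k = 1 keeps one of them (which one does not matter here, since all experts are identical), and renormalization raises its weight to 1. An all-ones expert maps a token to the sum of its features in every position, e.g. [0, 1] -> [1, 1] and [10, 11] -> [21, 21].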
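The explanation can be checked numerically. This sketch recomputes Example 1 step by step for the top_k = 1 case only (it takes each token's maximum probability directly instead of doing a general top-k), and it also computes what the output would be if the renormalization step were skipped.

```python
import numpy as np

# Example 1 inputs, flattened to (n_tokens, d_model) = (6, 2).
x_flat = np.arange(12, dtype=float).reshape(6, 2)
Wg = np.ones((2, 4))
expert = np.ones((2, 2))  # all four experts in Example 1 are this all-ones matrix

# Gate: within each token the four logits are identical, so softmax gives 0.25 each.
logits = x_flat @ Wg
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# top_k = 1: keep each token's largest probability (0.25) ...
kept = probs.max(axis=1, keepdims=True)
# ... and renormalize over the kept set, which turns 0.25 into 1.0.
renormalized = kept / kept.sum(axis=1, keepdims=True)

# The all-ones expert maps [a, b] to [a + b, a + b].
expert_out = x_flat @ expert
out = (renormalized * expert_out).reshape(2, 3, 2)
out_no_renorm = (kept * expert_out).reshape(2, 3, 2)  # dropped experts' mass is lost

print(out.tolist())         # Example 1's expected output, as floats
print(out_no_renorm[0, 0])  # [0.25 0.25]: a quarter of out[0, 0]
```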

Constraints

  • Renormalize the top-k gate weights so they sum to 1 after selection.
  • Only the selected experts contribute; combine via a weighted (scatter) sum.
  • The output keeps the input's (n_batch, l_seq, d_model) shape.
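In NumPy, the weighted scatter sum has one trap worth knowing. If a vectorized implementation gathers all (token, expert) pairs into one flat index array, each token appears up to top_k times, and a plain `out[rows] += values` does not accumulate repeated indices: only one contribution per repeated row survives. `np.add.at` performs an unbuffered add, so repeated indices are summed. The rows and values below are made up purely to show the difference.

```python
import numpy as np

# Three contributions destined for output rows 0, 0 and 2 (row 0 is repeated,
# as it would be for a token routed to two experts).
rows = np.array([0, 0, 2])
contrib = np.array([[1.0, 1.0], [2.0, 2.0], [5.0, 5.0]])

# Buffered fancy-index update: only one of the two row-0 contributions survives.
buffered = np.zeros((3, 2))
buffered[rows] += contrib

# Unbuffered scatter-add: both row-0 contributions are summed.
scattered = np.zeros((3, 2))
np.add.at(scattered, rows, contrib)

print(buffered[0])   # not [3. 3.]
print(scattered[0])  # [3. 3.]
```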

Notes

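  • Renormalizing after top-k matters: without it, the probability mass of the dropped experts is simply lost and the output shrinks. In Example 1, skipping renormalization would scale every output by 0.25.
  • Production MoE models usually add an auxiliary load-balancing loss during training so the gate does not collapse onto a few favorite experts.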
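The note does not say which load-balancing loss is meant. One widely used form, introduced with the Switch Transformer, is n_experts times the sum over experts of f_i * P_i, where f_i is the fraction of routing assignments sent to expert i and P_i is the mean gate probability for expert i. The sketch below only computes that value in NumPy and is not part of the moe() task. The function name is hypothetical; counting all top_k assignments in f_i is an assumption for top_k > 1 (the original form used top-1 routing); and a real training setup would scale the term by a small coefficient and differentiate through P_i.

```python
import numpy as np


def load_balancing_loss(probs, top_idx, n_experts):
    """Switch-Transformer-style auxiliary loss: n_experts * sum_i f_i * P_i (a sketch).

    probs:   (n_tokens, n_experts) softmax gate probabilities.
    top_idx: (n_tokens, top_k) expert indices selected for each token.
    """
    counts = np.bincount(top_idx.ravel(), minlength=n_experts)
    f = counts / top_idx.size  # fraction of assignments per expert
    P = probs.mean(axis=0)     # mean gate probability per expert
    return n_experts * np.sum(f * P)


n_tokens, n_experts = 8, 4

# Balanced: uniform probabilities, tokens assigned round-robin (top_k = 1).
uniform = np.full((n_tokens, n_experts), 1.0 / n_experts)
balanced_idx = (np.arange(n_tokens) % n_experts)[:, None]
print(load_balancing_loss(uniform, balanced_idx, n_experts))    # 1.0

# Collapsed: every token sends all its probability and its routing to expert 0.
collapsed = np.zeros((n_tokens, n_experts))
collapsed[:, 0] = 1.0
collapsed_idx = np.zeros((n_tokens, 1), dtype=int)
print(load_balancing_loss(collapsed, collapsed_idx, n_experts))  # 4.0
```

Balanced routing gives 1.0 while full collapse onto one expert gives n_experts, so adding this term to the training loss pushes the gate toward spreading tokens across experts.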
This problem ships 5 hidden tests:

  • Reference example
  • Identity experts with top_k=1 pass tokens through unchanged
  • Routing selects the higher-logit expert
  • top_k=2 blends experts with renormalized weights
  • Output keeps the input shape