SwiGLU activation

Difficulty: Medium

Background

SwiGLU is the gated activation used in the feed-forward block of modern LLMs such as PaLM and LLaMA. A GLU (gated linear unit) splits its input in half and lets one half gate the other; SwiGLU uses the Swish gate $\text{swish}(z) = z\,\sigma(z)$. GLU variants of this kind have been reported to outperform plain ReLU/GELU feed-forward layers at a comparable parameter budget.
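To make the gate concrete, here is a minimal NumPy sketch of the Swish function on its own. The helper names `sigmoid` and `swish` are illustrative, and the tanh-based form of the sigmoid is a choice made here to avoid overflow in `exp`, not something the problem prescribes.

```python
import numpy as np

def sigmoid(z):
    # Uses the identity sigma(z) = 0.5 * (1 + tanh(z / 2)), which never overflows.
    z = np.asarray(z, dtype=np.float64)
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def swish(z):
    # swish(z) = z * sigma(z)
    z = np.asarray(z, dtype=np.float64)
    return z * sigmoid(z)

# Evaluate the gate on a few points; note that it dips below zero for
# moderately negative inputs and then returns toward zero.
z = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])
print(swish(z))
```

The values for negative inputs illustrate the non-monotonic shape: `swish(-1)` is lower than both `swish(-5)` and `swish(0)`.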

Problem statement

Implement SwiGLU(x) for an input of shape (batch, 2d). Split the last dimension into halves $x_1, x_2$ and return:

$$\text{SwiGLU}(x) = x_1 \odot \text{swish}(x_2) = x_1 \odot \big(x_2 \cdot \sigma(x_2)\big)$$

where $\sigma(z) = 1/(1 + e^{-z})$ is the sigmoid and $\odot$ is the elementwise product.
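One possible way to implement this is sketched below. It is an assumption-laden sketch rather than the official reference solution: the function name `swiglu`, the cast to `float64`, the explicit even-dimension check, and the tanh-based sigmoid are all choices made here.

```python
import numpy as np

def swiglu(x: np.ndarray) -> np.ndarray:
    # Work in float64 so integer inputs are handled as well.
    x = np.asarray(x, dtype=np.float64)
    if x.shape[-1] % 2 != 0:
        raise ValueError("last dimension must be even")
    # First half is the value x1, second half is the gate input x2.
    x1, x2 = np.split(x, 2, axis=-1)
    # sigma(x2) via tanh, which stays finite for very large |x2|.
    sig = 0.5 * (1.0 + np.tanh(0.5 * x2))
    # Elementwise product x1 * swish(x2), with swish(x2) = x2 * sigma(x2).
    return x1 * (x2 * sig)

print(swiglu(np.array([[1, -1, 1000, -1000]])))
```

The tanh form is used because a sigmoid written as `np.exp(x) / (1 + np.exp(x))` evaluates to `inf / inf` for large positive inputs in floating point.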

Input

  • x: np.ndarray of shape (batch_size, 2d).

Output

Returns an np.ndarray of shape (batch_size, d).

Examples

Example 1

Input:  [[1, -1, 1000, -1000]]
Output: [[1000.0, 0.0]]

Explanation: split into $x_1=[1,-1]$ and $x_2=[1000,-1000]$. $\sigma(1000)\approx1$, so $\text{swish}(1000)\approx1000$; $\sigma(-1000)\approx0$, so $\text{swish}(-1000)\approx0$. Then $x_1\odot\text{swish}(x_2)=[1\cdot1000,\,-1\cdot0]=[1000, 0]$.
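The same steps can be traced numerically. This is a sketch that assumes a tanh-based sigmoid so the intermediate values stay finite; the variable names are illustrative.

```python
import numpy as np

x = np.array([[1.0, -1.0, 1000.0, -1000.0]])

# Split the last dimension into two equal halves.
d = x.shape[-1] // 2
x1, x2 = x[..., :d], x[..., d:]

# The sigmoid saturates to 1 and 0 at +1000 and -1000 in float64.
sigma = 0.5 * (1.0 + np.tanh(0.5 * x2))

# swish(x2) = x2 * sigma(x2), then gate x1 elementwise.
gate = x2 * sigma
out = x1 * gate

print("x1    =", x1)
print("x2    =", x2)
print("sigma =", sigma)
print("gate  =", gate)
print("out   =", out)
```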

Constraints

  • The last dimension is even; split it into equal halves $x_1$ (first) and $x_2$ (second).
  • Gate with $\text{swish}(x_2) = x_2\,\sigma(x_2)$; the output is $x_1 \odot \text{swish}(x_2)$.
  • The output has half the last-dimension size: (batch, d).

Notes

  • The GLU family $x_1 \odot g(x_2)$ lets the network learn which features to pass through. Swish is a smooth, non-monotonic choice of $g$ that has performed well empirically.
  • In a real FFN, $x_1$ and $x_2$ come from two separate linear projections of the input; here they are handed to you pre-projected as the two halves.
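To illustrate the second note, the sketch below builds a toy feed-forward block in which $x_1$ and $x_2$ come from two separate projections, and checks that this matches a single fused projection followed by a split. All names (`W_value`, `W_gate`, `W_out`, `ffn_swiglu`) and sizes are hypothetical, the weights are random, and biases are omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_model, d_hidden = 4, 8, 16  # hypothetical sizes

# Two separate input projections plus an output projection (random weights).
W_value = rng.normal(size=(d_model, d_hidden))  # produces x1
W_gate = rng.normal(size=(d_model, d_hidden))   # produces x2
W_out = rng.normal(size=(d_hidden, d_model))

def swish(z):
    # z * sigma(z), with sigma written via tanh for numerical stability.
    return z * 0.5 * (1.0 + np.tanh(0.5 * z))

def ffn_swiglu(h):
    x1 = h @ W_value            # value branch
    x2 = h @ W_gate             # gate branch
    return (x1 * swish(x2)) @ W_out

h = rng.normal(size=(batch, d_model))
y = ffn_swiglu(h)

# Equivalent view used in this problem: one fused projection of width 2d,
# then split the last dimension into halves.
W_fused = np.concatenate([W_value, W_gate], axis=1)
x1f, x2f = np.split(h @ W_fused, 2, axis=-1)
y_fused = (x1f * swish(x2f)) @ W_out

print(y.shape, np.allclose(y, y_fused))
```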

This problem ships 4 hidden tests, which run in the browser via Pyodide (no backend, no submission queue):

  • Reference example
  • Output halves the last dimension
  • Matches x1 * (x2 * sigmoid(x2))
  • Finite for extreme inputs
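The hidden tests themselves are not shown, so the following is only a guess at what checks with these names might look like. The `swiglu` candidate, the `run_checks` helper, and the specific inputs are all assumptions made for illustration.

```python
import numpy as np

def swiglu(x):
    # Candidate implementation under test (a sketch, not the official reference).
    x = np.asarray(x, dtype=np.float64)
    x1, x2 = np.split(x, 2, axis=-1)
    return x1 * (x2 * 0.5 * (1.0 + np.tanh(0.5 * x2)))

def run_checks(fn):
    # Reference example from the problem statement.
    assert np.allclose(fn(np.array([[1, -1, 1000, -1000]])), [[1000.0, 0.0]])

    # Output halves the last dimension.
    rng = np.random.default_rng(0)
    x = rng.normal(size=(3, 10))
    assert fn(x).shape == (3, 5)

    # Matches x1 * (x2 * sigmoid(x2)) on moderate inputs.
    x1, x2 = x[:, :5], x[:, 5:]
    expected = x1 * (x2 * (1.0 / (1.0 + np.exp(-x2))))
    assert np.allclose(fn(x), expected)

    # Finite for extreme inputs (no inf or nan).
    extreme = np.array([[1.0, -1.0, 1e4, -1e4],
                        [1e4, -1e4, 50.0, -50.0]])
    assert np.all(np.isfinite(fn(extreme)))
    return True

print(run_checks(swiglu))
```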