Affine coupling layer (normalizing flow)
Difficulty: Medium

Background

Normalizing flows transform a simple distribution into a complex one through a chain of invertible maps. The affine coupling layer (RealNVP) is a key building block: it splits the input into two halves, leaves the first half unchanged, and transforms the second half with a scale-and-shift conditioned on the first. Because the first half is untouched and the scale and shift depend only on it, the layer is trivially invertible. Its Jacobian is also triangular, so the log-determinant is just the sum of the log-scales.
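
The triangular-Jacobian claim can be checked numerically. The sketch below is illustrative only: it assumes hypothetical linear conditioners W_s and W_t (with tanh applied to the log-scale) in place of the small neural networks a real RealNVP layer uses. It builds the full Jacobian of one coupling step by central differences and compares log|det J| with the sum of the log-scales.

```python
import numpy as np

# Hypothetical conditioners: one random linear map each, chosen only for
# illustration. A real RealNVP layer would use small neural networks.
rng = np.random.default_rng(0)
D = 4
HALF = D // 2
W_s = rng.normal(size=(HALF, HALF))
W_t = rng.normal(size=(HALF, HALF))


def coupling(x):
    """Affine coupling on a single vector x of shape (D,); returns (y, log_det)."""
    x1, x2 = x[:HALF], x[HALF:]
    s = np.tanh(W_s @ x1)        # log-scale, a function of x1 only
    t = W_t @ x1                 # shift, a function of x1 only
    y = np.concatenate([x1, x2 * np.exp(s) + t])
    return y, s.sum()


def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian with J[i, j] = d f_i / d x_j."""
    J = np.zeros((x.size, x.size))
    for j in range(x.size):
        dx = np.zeros_like(x)
        dx[j] = eps
        J[:, j] = (f(x + dx)[0] - f(x - dx)[0]) / (2 * eps)
    return J


x = rng.normal(size=D)
J = numerical_jacobian(coupling, x)
_, log_det = coupling(x)
sign, logabsdet = np.linalg.slogdet(J)

print(np.allclose(J[:HALF, HALF:], 0.0))          # True: dy1/dx2 block is zero
print(np.isclose(logabsdet, log_det, atol=1e-6))  # True: log|det J| == sum(s)
```

The upper-right block (dy1/dx2) is zero because y1 = x1, and the lower-right block (dy2/dx2) is diag(e^s) because s and t do not depend on x2. The determinant is therefore the product of the e^s entries, whatever the lower-left block contains.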

Problem statement

Implement affine_coupling_forward(x, s, t) for a coupling layer. Split x into x1 and x2 (the first and second halves along the last axis). Given the log-scale s and shift t (functions of x1 in a real layer, passed in precomputed here), compute:

$$y_1 = x_1, \qquad y_2 = x_2 \odot e^{s} + t, \qquad \log|\det J| = \sum s$$

Return (y, log_det) where y is the concatenation of y1, y2.
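
One possible NumPy implementation is sketched below. It assumes array-like inputs with the shapes listed under Input. The ValueError for odd d is an added safety check, not a stated requirement. The last lines reproduce Example 1.

```python
import numpy as np


def affine_coupling_forward(x, s, t):
    """Forward pass of an affine coupling layer.

    x: (batch, d) with d even; s, t: (batch, d // 2), where s is the log-scale.
    Returns (y, log_det): y has shape (batch, d), log_det has shape (batch,).
    """
    x = np.asarray(x, dtype=float)
    s = np.asarray(s, dtype=float)
    t = np.asarray(t, dtype=float)
    d = x.shape[-1]
    if d % 2 != 0:
        raise ValueError("last dimension of x must be even")
    half = d // 2
    x1, x2 = x[..., :half], x[..., half:]
    y1 = x1                          # first half passes through unchanged
    y2 = x2 * np.exp(s) + t          # elementwise scale by e^s, then shift by t
    y = np.concatenate([y1, y2], axis=-1)
    log_det = np.sum(s, axis=-1)     # per-example sum over the feature axis
    return y, log_det


y, log_det = affine_coupling_forward([[1, 2, 3, 4]], [[0, 0]], [[0, 0]])
print(y)        # [[1. 2. 3. 4.]]
print(log_det)  # [0.]
```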

Input

  • x: np.ndarray of shape (batch, d), with d even.
  • s: the scale parameter (a log-scale) for the second half, shape (batch, d/2).
  • t: the shift for the second half, shape (batch, d/2).

Output

A tuple (y, log_det):

  • y: np.ndarray of shape (batch, d), the transformed output.
  • log_det: np.ndarray of shape (batch,), the per-example log-determinant sum(s, axis=-1).

Examples

Example 1

Input:  x = [[1, 2, 3, 4]], s = [[0, 0]], t = [[0, 0]]
Output: y = [[1, 2, 3, 4]], log_det = [0]

Explanation: with s = 0 the scale is $e^0 = 1$, and t = 0, so the second half is unchanged: $y_2 = x_2$. The log-determinant is $\sum s = 0$ (an identity transform).
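
A second case with non-trivial s and t is also easy to check by hand. The values here are chosen for illustration and are not taken from the problem's examples or tests: x = [[1, 2, 3, 4]], s = [[0, ln 2]], t = [[1, -1]]. The scales are e^s = [1, 2], so y2 = [3*1 + 1, 4*2 - 1] = [4, 7] and log_det = ln 2 ≈ 0.693.

```python
import numpy as np

# Illustrative inputs (not from the problem's tests).
x = np.array([[1.0, 2.0, 3.0, 4.0]])
s = np.array([[0.0, np.log(2.0)]])   # log-scales, so e^s = [1, 2]
t = np.array([[1.0, -1.0]])

half = x.shape[-1] // 2
x1, x2 = x[:, :half], x[:, half:]
y = np.concatenate([x1, x2 * np.exp(s) + t], axis=-1)
log_det = s.sum(axis=-1)

print(y)        # approximately [[1, 2, 4, 7]]
print(log_det)  # approximately [0.6931], i.e. ln 2
```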

Constraints

  • The first half passes through unchanged; only the second half is scaled and shifted.
  • The scale is applied as $e^{s}$ (s is the log-scale).
  • log_det is the sum of s over the feature axis (per example).
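
The last constraint is where a batched implementation most easily goes wrong. Summing s over every axis collapses the whole batch into one scalar instead of giving one log-determinant per example. The log-scale values below are made up for illustration.

```python
import numpy as np

# Made-up log-scales for a batch of 2 examples, each with d/2 = 2 entries.
s = np.array([[0.5, -0.25],
              [1.0, 2.0]])

per_example = np.sum(s, axis=-1)   # shape (2,): [0.25, 3.0], one log-det per example
whole_batch = np.sum(s)            # scalar 3.25: mixes examples, wrong shape for log_det

print(per_example.shape, per_example)
print(whole_batch)
```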

Notes

  • Invertibility is immediate: given $y$, recover $x_2 = (y_2 - t) \odot e^{-s}$, using $y_1 = x_1$ to recompute $s$ and $t$.
  • Stacking coupling layers (and swapping which half is transformed between them) lets the flow model highly flexible densities while keeping likelihoods exact and cheap to compute.
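
Both notes can be exercised together in a small sketch. It assumes hypothetical random linear conditioners (with tanh on the log-scale) in place of trained networks. All function names here (conditioner, coupling_forward, coupling_inverse, swap, flow_forward, flow_inverse) are illustrative and not part of the problem's required interface. The inverse recomputes s and t from y1 = x1. A swap between the two layers makes the second layer transform the other half, and the per-layer log-dets simply add. Because the swap is a permutation, it contributes |det| = 1.

```python
import numpy as np

rng = np.random.default_rng(1)
D, HALF = 4, 2

# Hypothetical conditioners (random linear maps), one pair per layer, standing
# in for the small networks that compute s and t from x1 in a real flow.
PARAMS = [
    (rng.normal(size=(HALF, HALF)), rng.normal(size=(HALF, HALF)))
    for _ in range(2)
]


def conditioner(x1, W_s, W_t):
    """Return (s, t) for a batch x1 of shape (batch, HALF)."""
    return np.tanh(x1 @ W_s), x1 @ W_t


def coupling_forward(x, W_s, W_t):
    x1, x2 = x[:, :HALF], x[:, HALF:]
    s, t = conditioner(x1, W_s, W_t)
    return np.concatenate([x1, x2 * np.exp(s) + t], axis=1), s.sum(axis=1)


def coupling_inverse(y, W_s, W_t):
    y1, y2 = y[:, :HALF], y[:, HALF:]
    s, t = conditioner(y1, W_s, W_t)   # y1 == x1, so s and t are recomputed exactly
    return np.concatenate([y1, (y2 - t) * np.exp(-s)], axis=1)


def swap(z):
    """Exchange the two halves; a permutation, so it adds 0 to the log-det."""
    return np.concatenate([z[:, HALF:], z[:, :HALF]], axis=1)


def flow_forward(x):
    total_log_det = np.zeros(x.shape[0])
    z = x
    for i, (W_s, W_t) in enumerate(PARAMS):
        if i > 0:
            z = swap(z)                    # transform the other half this time
        z, ld = coupling_forward(z, W_s, W_t)
        total_log_det += ld                # log-dets of stacked layers add
    return z, total_log_det


def flow_inverse(z):
    # Undo the layers in reverse order, undoing each swap after its coupling.
    for i, (W_s, W_t) in reversed(list(enumerate(PARAMS))):
        z = coupling_inverse(z, W_s, W_t)
        if i > 0:
            z = swap(z)                    # swap is its own inverse
    return z


x = rng.normal(size=(3, D))
z, log_det = flow_forward(x)
x_rec = flow_inverse(z)
print(np.allclose(x, x_rec))   # True
print(log_det.shape)           # (3,)
```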

This problem ships 5 hidden tests, covering:

  • Identity transform when s=0, t=0
  • First half is unchanged
  • Second half is scaled and shifted
  • log_det is the sum of s
  • Invertible: recover x2 from y