VAE ELBO loss (Medium)

Background

A variational autoencoder (VAE) is trained to maximize the evidence lower bound (ELBO), or equivalently to minimize a loss with two terms: a reconstruction term (how well the decoder rebuilds the input) and a KL term that pulls the encoder's latent distribution $q(z\mid x)=\mathcal N(\mu,\sigma^2)$ toward a standard-normal prior. The KL term is what regularizes the latent space so it can be sampled to generate new data.
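
For reference, the bound being maximized can be written in its standard form as below. The decoder likelihood $p(x\mid z)$ and the prior symbol $p(z)$ are notation introduced here for illustration; they do not appear elsewhere on this page.

$$\log p(x) \;\ge\; \mathbb E_{q(z\mid x)}\big[\log p(x\mid z)\big] - \mathrm{KL}\big(q(z\mid x)\,\|\,p(z)\big), \qquad p(z)=\mathcal N(0, I)$$

Negating the right-hand side gives a reconstruction term plus a KL term. A squared-error reconstruction corresponds to assuming a Gaussian decoder with fixed variance, up to a scale factor and an additive constant.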

Problem statement

Implement vae_elbo_loss(x, x_recon, mu, logvar) returning the mean negative-ELBO loss. Use squared-error reconstruction (summed over features) and the closed-form Gaussian KL:

$$\text{recon}_i = \sum_j (x_{ij} - \hat x_{ij})^2, \qquad \text{KL}_i = -\tfrac12 \sum_j \big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$$

$$\mathcal L = \frac{1}{N}\sum_i \big(\text{recon}_i + \text{KL}_i\big), \qquad \sigma_j^2 = e^{\,\text{logvar}_j}$$

Input

  • x, x_recon: np.ndarray (N, d), the input and its reconstruction.
  • mu, logvar: np.ndarray (N, d), the encoder's latent mean and log-variance.

Output

A scalar float: the mean loss over the batch.

Examples

Example 1

Input:  x = [[1, 0]], x_recon = [[1, 0]], mu = [[0, 0]], logvar = [[0, 0]]
Output: 0.0

Explanation: perfect reconstruction makes the SSE term 0; with $\mu=0$ and $\sigma^2=1$ the posterior equals the prior, so the KL is $-\tfrac12\sum(1+0-0-1)=0$. The loss is 0.
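
A second case with non-zero terms may help check the arithmetic by hand. The inputs here are hypothetical and not taken from the problem's tests: x = [[1, 0]], x_recon = [[0.5, 0]], mu = [[1, 0]], logvar = [[0, 0]].

$$\text{recon}_1 = (1-0.5)^2 + (0-0)^2 = 0.25$$

$$\text{KL}_1 = -\tfrac12\big[(1+0-1-1) + (1+0-0-1)\big] = -\tfrac12(-1+0) = 0.5$$

$$\mathcal L = \tfrac{1}{1}\,(0.25 + 0.5) = 0.75$$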

Constraints

  • Convert logvar to variance with exp inside the KL.
  • Sum both terms over the feature axis per example, then average over the batch.
  • Return reconstruction plus KL (the quantity minimized = negative ELBO).
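
The constraints above map onto NumPy fairly directly. What follows is a minimal sketch of one possible solution, assuming 2-D inputs of shape (N, d) as described under Input; the conversion with np.asarray is an added convenience so that nested lists also work, not something the problem requires.

```python
import numpy as np

def vae_elbo_loss(x, x_recon, mu, logvar):
    """Mean negative-ELBO over the batch: summed squared error plus Gaussian KL."""
    x = np.asarray(x, dtype=float)
    x_recon = np.asarray(x_recon, dtype=float)
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)

    # Reconstruction: squared error summed over the feature axis, shape (N,).
    recon = np.sum((x - x_recon) ** 2, axis=1)
    # Closed-form KL to N(0, I), summed over the latent axis, shape (N,).
    # exp(logvar) recovers the variance sigma^2.
    kl = -0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar), axis=1)
    # Average the per-example loss over the batch and return a Python float.
    return float(np.mean(recon + kl))
```

Calling vae_elbo_loss([[1, 0]], [[1, 0]], [[0, 0]], [[0, 0]]) with the inputs from Example 1 returns 0.0.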

Notes

  • The KL is always $\ge 0$ and is 0 exactly when $\mu=0,\ \sigma^2=1$, that is, when the posterior matches the prior.
  • Parameterizing by logvar (rather than $\sigma$) keeps the variance $e^{\text{logvar}}$ positive for any real input and the optimization numerically stable.
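
Both notes can be checked numerically. The sketch below uses a hypothetical helper, kl_per_example, which is not part of the required interface. It writes the KL in the equivalent form ½ Σ (σ² + μ² − 1 − log σ²), which is the problem's expression with the minus sign distributed; non-negativity then follows from e^t ≥ 1 + t. The random inputs are arbitrary.

```python
import numpy as np

def kl_per_example(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    # Same quantity as -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    return 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar, axis=1)

rng = np.random.default_rng(0)
mu = rng.normal(size=(1000, 4))
# Any real number is a valid log-variance, including negative values.
logvar = rng.normal(scale=2.0, size=(1000, 4))

kl = kl_per_example(mu, logvar)
print(bool(np.all(kl >= 0)))                       # KL is non-negative on every row
print(kl_per_example([[0.0, 0.0]], [[0.0, 0.0]]))  # zero when posterior equals prior
print(bool(np.all(np.exp(logvar) > 0)))            # exp(logvar) is a positive variance
```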

This problem ships 5 hidden tests, which run in the browser via Pyodide with no backend or submission queue:

  • Perfect reconstruction with prior-matching posterior is 0
  • KL term for mu=1, logvar=0 equals 0.5 per latent dim
  • Reconstruction term is summed squared error
  • KL is always non-negative
  • Loss is the batch mean
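
Because the actual tests are hidden, a local approximation of them can be useful. The sketch below is a guess at what each named check might exercise, not the real test suite; check_vae_loss and all input values are hypothetical. Under the closed-form KL from the problem statement, mu=1 with logvar=0 contributes 0.5 per latent dimension, so two dimensions give 1.

```python
import numpy as np

def check_vae_loss(loss_fn):
    """Run five local sanity checks on a candidate loss_fn(x, x_recon, mu, logvar)."""
    zeros = np.zeros((1, 2))
    x = np.array([[1.0, 0.0]])

    # 1. Perfect reconstruction with a prior-matching posterior gives 0.
    assert np.isclose(loss_fn(x, x, zeros, zeros), 0.0)

    # 2. mu=1, logvar=0 gives 0.5 per latent dim, so 1.0 for two dims.
    assert np.isclose(loss_fn(x, x, np.ones((1, 2)), zeros), 1.0)

    # 3. Reconstruction is the summed squared error: 1^2 + 2^2 = 5.
    assert np.isclose(loss_fn(np.array([[1.0, 2.0]]), zeros, zeros, zeros), 5.0)

    # 4. With perfect reconstruction the loss is pure KL, so it is non-negative.
    rng = np.random.default_rng(0)
    mu, logvar = rng.normal(size=(8, 3)), rng.normal(size=(8, 3))
    x8 = rng.normal(size=(8, 3))
    assert loss_fn(x8, x8, mu, logvar) >= 0.0

    # 5. Batch mean: per-example losses 0 and 4 average to 2.
    xb = np.array([[0.0, 0.0], [2.0, 0.0]])
    z2 = np.zeros((2, 2))
    assert np.isclose(loss_fn(xb, z2, z2, z2), 2.0)
    return True
```

Pass your own implementation as the argument, for example check_vae_loss(vae_elbo_loss); any failing check raises an AssertionError.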