DPO loss

Difficulty: Medium

Background

Direct Preference Optimization (DPO) aligns a language model to human preferences without training a separate reward model or running RL. Given pairs where a "chosen" response $y_w$ is preferred over a "rejected" one $y_l$, DPO directly increases the policy's relative log-probability of $y_w$ over $y_l$, anchored to a frozen reference model so the policy does not drift too far.

Problem statement

Implement dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1). With per-example sequence log-probabilities, compute:

$$\Delta = \big(\log\pi_\theta(y_w) - \log\pi_{\text{ref}}(y_w)\big) - \big(\log\pi_\theta(y_l) - \log\pi_{\text{ref}}(y_l)\big)$$

$$\mathcal L = -\,\frac{1}{N}\sum \log \sigma(\beta\,\Delta)$$

Return the mean loss over the batch.
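
The two formulas above can be sketched in NumPy as follows. This is one possible implementation, not an official reference solution. It assumes the inputs are array-likes of equal length, and the coercion with `np.asarray` is a convenience added here rather than something the problem requires. It relies on the identity $-\log\sigma(x) = \log(1 + e^{-x})$, evaluated with `np.logaddexp`.

```python
import numpy as np

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Accept lists as well as arrays (an assumption of this sketch).
    policy_chosen = np.asarray(policy_chosen, dtype=float)
    policy_rejected = np.asarray(policy_rejected, dtype=float)
    ref_chosen = np.asarray(ref_chosen, dtype=float)
    ref_rejected = np.asarray(ref_rejected, dtype=float)

    # Delta: chosen-side log-ratio minus rejected-side log-ratio,
    # each measured against the frozen reference model.
    delta = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)

    # -log(sigmoid(x)) = log(1 + exp(-x)) = logaddexp(0, -x), which does not overflow.
    per_example = np.logaddexp(0.0, -beta * delta)

    # Mean over the batch, returned as a plain Python float.
    return float(per_example.mean())

loss = dpo_loss([-1.0], [-2.0], [-1.5], [-2.5], beta=0.1)
print(round(loss, 4))  # 0.6931
```

Grouping the terms as (policy minus reference) for each response gives the same value as subtracting the reference log-ratio from the policy log-ratio; only the order of the subtractions differs.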

Input

  • policy_chosen, policy_rejected: np.ndarray (N,), policy log-probs of the chosen/rejected responses.
  • ref_chosen, ref_rejected: np.ndarray (N,), reference-model log-probs.
  • beta: float, the KL/strength coefficient.

Output

A scalar float: the mean DPO loss.

Examples

Example 1

Input:  policy_chosen=[-1.0], policy_rejected=[-2.0], ref_chosen=[-1.5], ref_rejected=[-2.5], beta=0.1
Output: 0.6931  (= -log sigma(0))

Explanation: the policy log-ratio is $-1-(-2)=1$ and the reference log-ratio is $-1.5-(-2.5)=1$, so $\Delta=0$. With $\sigma(0)=0.5$, the loss is $-\log 0.5 \approx 0.6931$.

Constraints

  • $\Delta$ subtracts the reference log-ratio from the policy log-ratio (the implicit reward difference).
  • Use a numerically stable log-sigmoid (e.g. -softplus(-x)), not log(sigmoid(x)) directly.
  • Return the mean over the batch.
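
To illustrate why the stable form matters, the sketch below compares the two approaches on extreme inputs. The helper names `log_sigmoid_naive` and `log_sigmoid_stable` are hypothetical, chosen only for this example, and `np.logaddexp(0, -x)` stands in for softplus(-x).

```python
import numpy as np

def log_sigmoid_naive(x):
    # Direct transcription of log(sigmoid(x)); exp(-x) overflows for very negative x.
    return np.log(1.0 / (1.0 + np.exp(-x)))

def log_sigmoid_stable(x):
    # log(sigmoid(x)) = -softplus(-x) = -log(1 + exp(-x)), computed without overflow.
    return -np.logaddexp(0.0, -x)

x = np.array([-800.0, 0.0, 800.0])

# Silence the overflow and divide-by-zero warnings that the naive version triggers.
with np.errstate(over="ignore", divide="ignore"):
    naive = log_sigmoid_naive(x)
stable = log_sigmoid_stable(x)

print(naive[0])   # -inf: sigmoid(-800) underflows to 0 before the log is taken
print(stable[0])  # -800.0: the correct value, since log(sigmoid(x)) is close to x for very negative x
```

With the naive form, a single badly mis-ranked pair would make the whole batch loss infinite; the stable form keeps it finite and roughly linear in the margin.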

Notes

  • Minimizing the loss pushes $\Delta$ positive: the policy prefers $y_w$ over $y_l$ more strongly than the reference does.
  • $\beta$ controls how far the policy may move from the reference; larger $\beta$ makes the loss more sensitive to the log-ratio gap.
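
Both notes can be checked numerically with a small sketch. The helper `dpo_loss_from_delta` is hypothetical: it takes $\Delta$ directly instead of the four log-prob arrays, and the $\Delta$ and $\beta$ values are arbitrary illustrative choices.

```python
import numpy as np

def dpo_loss_from_delta(delta, beta):
    # Hypothetical helper: mean of -log(sigmoid(beta * delta)) over the batch.
    delta = np.asarray(delta, dtype=float)
    return float(np.mean(np.logaddexp(0.0, -beta * delta)))

deltas = (-2.0, 0.0, 2.0, 10.0)  # illustrative margins, not data from the problem
for beta in (0.1, 0.5):
    losses = [dpo_loss_from_delta([d], beta) for d in deltas]
    # For a fixed beta, the loss falls as delta grows; at delta = 0 it is log 2 for any beta.
    print(beta, np.round(losses, 4))
```

For a fixed $\beta$ the loss decreases monotonically in $\Delta$, and the larger $\beta$ produces a wider spread of losses across the same set of margins.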

This problem ships 5 hidden tests, which run in your browser via Pyodide (no backend, no submission queue):

  • Reference example -> log 2
  • Equal log-ratios give loss log(2)
  • A larger preference margin lowers the loss
  • Loss is the mean over the batch
  • Loss is always positive