Contrastive loss (Medium)

Background

Contrastive loss trains a Siamese embedding from pairs: similar pairs are pulled together while dissimilar pairs are pushed apart until they are at least a margin away. It is the pairwise cousin of triplet loss and the basis of Siamese networks for verification and similarity learning.

Problem statement

Implement contrastive_loss(x1, x2, label, margin=1.0). With the Euclidean distance $d_i = \lVert x1_i - x2_i \rVert$ and label $y_i$ (1 = similar, 0 = dissimilar), the loss is:

$$\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \Big[\, y_i\, d_i^2 + (1 - y_i)\,\max(0,\, \text{margin} - d_i)^2 \,\Big]$$
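A minimal vectorized NumPy sketch of the formula, assuming `label` holds 0/1 values (cast to float here) and that `x1` and `x2` already have shape (N, D); input validation is omitted.

```python
import numpy as np


def contrastive_loss(x1, x2, label, margin=1.0):
    """Mean contrastive loss over a batch of N embedding pairs (sketch)."""
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    y = np.asarray(label, dtype=float)  # assumes 0/1 labels
    # Non-squared Euclidean distance per pair, shape (N,).
    d = np.linalg.norm(x1 - x2, axis=1)
    # Similar pairs: squared distance (no margin).
    similar_term = y * d ** 2
    # Dissimilar pairs: squared hinge on the margin gap; zero once d >= margin.
    dissimilar_term = (1.0 - y) * np.maximum(0.0, margin - d) ** 2
    return float(np.mean(similar_term + dissimilar_term))


print(contrastive_loss([[0, 0], [0, 0]], [[1, 0], [0.5, 0]], [1, 0], margin=1.0))
# 0.625
```

Computing all distances with one `np.linalg.norm(..., axis=1)` call and masking with `y` and `1 - y` avoids a Python loop over pairs.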

Input

  • x1, x2: np.ndarray of shape (N, D), the paired embeddings.
  • label: np.ndarray of shape (N,), 1 for similar pairs and 0 for dissimilar pairs.
  • margin: float, the minimum desired distance for dissimilar pairs.

Output

Returns a float: the mean contrastive loss ($\ge 0$).

Examples

Example 1

Input:  x1 = [[0,0],[0,0]], x2 = [[1,0],[0.5,0]], label = [1, 0], margin = 1.0
Output: 0.625

Explanation: the similar pair ($y=1$) has $d=1$, contributing $1\cdot 1^2 = 1$. The dissimilar pair ($y=0$) has $d=0.5 < 1$, contributing $\max(0, 1-0.5)^2 = 0.25$. Mean $= (1+0.25)/2 = 0.625$.
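To check the arithmetic by hand, this illustrative loop computes each pair's distance and contribution for Example 1 separately before averaging. It is written for readability, not as the vectorized form.

```python
import numpy as np

x1 = np.array([[0.0, 0.0], [0.0, 0.0]])
x2 = np.array([[1.0, 0.0], [0.5, 0.0]])
label = [1, 0]
margin = 1.0

contributions = []
for a, b, y in zip(x1, x2, label):
    d = float(np.linalg.norm(a - b))  # non-squared Euclidean distance
    if y == 1:
        term = d ** 2  # similar pair: squared distance
    else:
        term = max(0.0, margin - d) ** 2  # dissimilar pair: squared hinged gap
    contributions.append(term)
    print(f"y={y}, d={d}, contribution={term}")

loss = sum(contributions) / len(contributions)
print("mean loss:", loss)
```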

Constraints

  • $d$ is the (non-squared) Euclidean distance; the similar term squares $d$, and the dissimilar term squares the hinged margin gap $\max(0, \text{margin} - d)$.
  • Dissimilar pairs already farther than margin contribute 0.
  • Average over the batch; the result is $\ge 0$.
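A common slip is to put the squared distance inside the hinge. This comparison uses the dissimilar pair from Example 1 ($d = 0.5$, margin $= 1.0$) and shows that the two choices give different penalties. It also adds a hypothetical pair at $d = 1.5$ that lies beyond the margin and contributes 0.

```python
import numpy as np

a = np.array([0.0, 0.0])
b = np.array([0.5, 0.0])
margin = 1.0

d = float(np.linalg.norm(a - b))    # non-squared distance: 0.5
d_sq = float(np.sum((a - b) ** 2))  # squared distance: 0.25

correct = max(0.0, margin - d) ** 2   # (1 - 0.5)^2  = 0.25, as the spec requires
wrong = max(0.0, margin - d_sq) ** 2  # (1 - 0.25)^2 = 0.5625, hinge on d^2 instead
far = max(0.0, margin - 1.5) ** 2     # hypothetical pair already past the margin: 0.0
print(correct, wrong, far)
```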

Notes

  • Similar pairs are pulled together quadratically via $d^2$ (no margin); dissimilar pairs are pushed apart only while inside the margin, and once separated they stop contributing.
  • This margin-based form follows Hadsell, Chopra & LeCun (2006) and is widely used to train Siamese networks for verification tasks such as signature and face verification.
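The push/pull behaviour in the notes follows from the derivative of each per-pair term with respect to $d$. For similar pairs it is $2d$. For dissimilar pairs it is $-2(\text{margin} - d)$ while $d < \text{margin}$ and zero afterwards. This sketch evaluates it at a few distances; the helper name `loss_grad_wrt_distance` is hypothetical, and it differentiates only with respect to $d$, not the embeddings.

```python
import numpy as np


def loss_grad_wrt_distance(d, y, margin=1.0):
    """Derivative of one pair's contrastive term with respect to its distance d."""
    d = np.asarray(d, dtype=float)
    if y == 1:
        # Similar pair: d/dd of d^2 is positive, so gradient descent shrinks d.
        return 2.0 * d
    # Dissimilar pair: negative inside the margin (gradient descent grows d),
    # exactly zero once d >= margin, so separated pairs stop contributing.
    return np.where(d < margin, -2.0 * (margin - d), 0.0)


distances = np.array([0.0, 0.5, 1.0, 1.5])
print("similar:   ", loss_grad_wrt_distance(distances, y=1))
print("dissimilar:", loss_grad_wrt_distance(distances, y=0))
```

`np.where` is used instead of `-2 * np.maximum(0, margin - d)` so that pairs beyond the margin show a plain `0.0` rather than a negative zero.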

Tests

This problem ships four hidden tests:

  • Reference example: 0.625
  • Identical similar pair -> 0
  • Dissimilar pair beyond the margin -> 0
  • Coincident dissimilar pair -> margin squared