Focal loss (multiclass)

Difficulty: Medium

Background

Focal loss (Lin et al., 2017, for RetinaNet object detection) solves a class-imbalance problem: when 99% of candidate boxes are "easy negatives" (obviously background), their many small losses swamp the gradient and drown out the rare hard examples. The fix is to down-weight examples the model already gets right, via a modulating factor on top of cross-entropy, so training keeps focusing on what it hasn't learned yet.

Problem statement

Implement focal_loss(logits, target, gamma=2.0, alpha=1.0) for one example. Let $p_t = \text{softmax}(z)_{\text{target}}$ be the predicted probability of the true class, where $z$ is the logits vector. Then:

$$L_{\text{focal}} = -\,\alpha\,(1 - p_t)^{\gamma}\,\log p_t$$

The factor $(1-p_t)^\gamma$ does the work: for confident-correct examples ($p_t \to 1$) it goes to $0$ and crushes the loss, while for hard examples (small $p_t$) it stays $\approx 1$ and the loss is approximately cross-entropy. Compute the softmax stably and clip $p_t$ away from 0.
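One way to turn the formula above into code is sketched below. It is a minimal illustration rather than the reference solution: the 1e-30 floor is the value suggested in the constraints, and the cast to float64 is an assumption made here for numerical headroom.

```python
import numpy as np

def focal_loss(logits, target, gamma=2.0, alpha=1.0):
    """Focal loss for one example: -alpha * (1 - p_t)**gamma * log(p_t)."""
    z = np.asarray(logits, dtype=np.float64)  # assumption: promote to float64
    z = z - z.max()                           # shift so the largest logit is 0
    probs = np.exp(z) / np.exp(z).sum()       # numerically stable softmax
    p_t = max(probs[target], 1e-30)           # floor keeps log() finite
    return float(-alpha * (1.0 - p_t) ** gamma * np.log(p_t))

print(focal_loss(np.array([0.0, 0.0, 0.0]), target=1))  # about 0.488
```

Subtracting the maximum leaves the softmax unchanged, because the shift cancels between numerator and denominator.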

Input

  • logits: 1-D np.ndarray of shape (K,), the pre-softmax scores for one example.
  • target: int, the true class index in $[0, K)$.
  • gamma: float $\ge 0$, the focusing parameter ($\gamma=0$ reduces to weighted CE; $\gamma=2$ is the RetinaNet default).
  • alpha: float $\ge 0$, the class-balancing weight.

Output

Returns a float: the scalar focal loss.

Examples

Example 1: $\gamma = 0$ reduces to (weighted) cross-entropy

Input:  logits = [1.0, 2.0, 3.0, 0.5], target = 2, gamma = 0, alpha = 1
Output: ≈ 0.461          # -log(p_t), plain cross-entropy

Explanation: with $\gamma=0$ the modulating factor is $(1-p_t)^0 = 1$, so the loss collapses to $-\alpha\log p_t$, which is exactly $\alpha$-weighted cross-entropy.
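The numbers behind Example 1 can be checked directly. The snippet below is an illustrative sketch that computes $p_t$ for the example logits and confirms that the $\gamma=0$ focal loss coincides with cross-entropy; the variable names are chosen here and are not part of the problem.

```python
import numpy as np

logits = np.array([1.0, 2.0, 3.0, 0.5])
target = 2

z = logits - logits.max()                  # stable shift
p_t = np.exp(z[target]) / np.exp(z).sum()  # probability of the true class

cross_entropy = -np.log(p_t)
focal_gamma0 = -1.0 * (1.0 - p_t) ** 0 * np.log(p_t)  # alpha = 1, gamma = 0

print(round(float(p_t), 4))                      # 0.6308
print(round(float(cross_entropy), 4))            # 0.4608
print(np.isclose(focal_gamma0, cross_entropy))   # True
```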

Example 2: the modulating factor scales the loss

Input:  logits = [0.0, 0.0, 0.0], target = 1, gamma = 2, alpha = 1
Output: ≈ 0.488

Explanation: uniform logits give $p_t = 1/3$, so the factor is $(1 - 1/3)^2 = (2/3)^2 \approx 0.444$. The focal loss is that fraction of the plain CE $-\log(1/3)\approx 1.0986$, i.e. $0.444\times 1.0986\approx 0.488$.
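To see how the focusing parameter changes this result, the sketch below repeats the Example 2 arithmetic for a few values of gamma. The particular values swept (0, 1, 2, 5) are picked here for illustration only.

```python
import math

p_t = 1.0 / 3.0          # uniform softmax over 3 classes
ce = -math.log(p_t)      # plain cross-entropy, about 1.0986

losses = {}
for gamma in (0, 1, 2, 5):
    factor = (1.0 - p_t) ** gamma      # modulating factor (2/3)**gamma
    losses[gamma] = factor * ce        # focal loss with alpha = 1
    print(gamma, round(factor, 3), round(losses[gamma], 3))
```

The gamma = 2 row reproduces the 0.488 above, and the gamma = 0 row is plain cross-entropy.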

Constraints

  • Compute $p_t$ from a numerically stable softmax (subtract logits.max()); clip $p_t$ to a tiny floor (e.g. 1e-30) to avoid log(0).
  • $L = -\alpha\,(1-p_t)^\gamma\log p_t$. With $\gamma=0$ it equals $\alpha$-weighted CE.
  • alpha scales the loss linearly (doubling alpha doubles the loss).
  • A confident-correct prediction yields a loss far below plain CE; a hard one (small $p_t$) yields roughly CE. The result must stay finite for large logits (e.g. [1000, 1001, 1002]).
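The stability requirement is easiest to appreciate by comparing a naive softmax with the shifted one on the large logits mentioned above. This is a small demonstration, not part of the required solution; the warnings are silenced only so the overflow can be shown without noise.

```python
import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows to inf, and inf / inf is nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum()

# Shifted softmax: the largest exponent is exp(0) = 1, so nothing overflows.
z = logits - logits.max()
stable = np.exp(z) / np.exp(z).sum()

print(naive)   # all nan
print(stable)  # roughly [0.090, 0.245, 0.665]
```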

Notes

  • Three regimes of $(1-p_t)^\gamma$. $p_t\approx 1$ → factor $\to 0$ (the model "moves on"); $p_t\approx 0$ → factor $\to 1$ (full CE, reweighted by $\alpha$); $p_t\approx 0.5$ → factor $\approx 0.5^\gamma$ (moderately reduced).
  • Defaults & uses. $\gamma=2$, $\alpha=0.25$ (for the rare class) is the canonical RetinaNet setting; focal loss is standard in object detection (RetinaNet, YOLO derivatives) and is widely used for severely imbalanced classification.
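The three regimes can be made concrete by evaluating the modulating factor at the default $\gamma=2$. The probabilities 0.99, 0.5 and 0.01 below are representative values chosen here, not values from the problem.

```python
gamma = 2.0

# Modulating factor (1 - p_t)**gamma for an easy, a borderline and a hard example.
factors = {p_t: (1.0 - p_t) ** gamma for p_t in (0.99, 0.5, 0.01)}

for p_t, factor in factors.items():
    print(f"p_t = {p_t:.2f}  ->  factor = {factor:.4f}")
# p_t = 0.99  ->  factor = 0.0001
# p_t = 0.50  ->  factor = 0.2500
# p_t = 0.01  ->  factor = 0.9801
```

At this setting the easy example contributes about one ten-thousandth of its cross-entropy, while the hard example keeps about 98% of it.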

This problem ships 6 hidden tests, which run in your browser via Pyodide (no backend, no submission queue):

  • gamma=0 reduces to alpha-weighted cross-entropy
  • Diagnostic: matches the explicit formula -alpha * (1-p_t)^gamma * log(p_t)
  • Confident-correct prediction has loss approaching 0 (focal effect)
  • Hard example (low p_t) has loss similar to CE — modulation factor ~1
  • alpha scales the loss linearly
  • Numerically stable for large logits
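Since the tests themselves are hidden, the checks below are only a guess at what each title might verify. The function run_checks, its inputs and its tolerances are all hypothetical and written for local experimentation; pass in your own focal_loss with the signature from the problem statement.

```python
import numpy as np

def run_checks(focal_loss):
    """Hypothetical checks mirroring the six test titles (not the real tests)."""
    def ce(logits, target):
        z = logits - logits.max()
        return -(z[target] - np.log(np.exp(z).sum()))   # stable cross-entropy

    x = np.array([1.0, 2.0, 3.0, 0.5])
    peaked = np.array([10.0, 0.0, 0.0])

    # 1. gamma=0 reduces to alpha-weighted cross-entropy
    assert np.isclose(focal_loss(x, 2, gamma=0.0, alpha=0.5), 0.5 * ce(x, 2))
    # 2. matches the explicit formula -alpha * (1 - p_t)^gamma * log(p_t)
    p_t = np.exp(-ce(x, 2))
    expected = -((1.0 - p_t) ** 2) * np.log(p_t)
    assert np.isclose(focal_loss(x, 2, gamma=2.0, alpha=1.0), expected)
    # 3. confident-correct prediction: loss approaches 0
    assert focal_loss(peaked, 0) < 1e-6
    # 4. hard example (low p_t): loss close to cross-entropy
    assert np.isclose(focal_loss(peaked, 1), ce(peaked, 1), rtol=1e-3)
    # 5. alpha scales the loss linearly
    assert np.isclose(focal_loss(x, 2, alpha=2.0), 2.0 * focal_loss(x, 2, alpha=1.0))
    # 6. numerically stable for large logits
    assert np.isfinite(focal_loss(np.array([1000.0, 1001.0, 1002.0]), 2))
    return True
```

Calling run_checks(focal_loss) returns True when every assertion passes and raises AssertionError at the first one that fails.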