Focal loss (multiclass)
Background
Focal loss (Lin et al., 2017, for RetinaNet object detection) solves a class-imbalance problem: when 99% of candidate boxes are "easy negatives" (obviously background), their many small losses swamp the gradient and drown out the rare hard examples. The fix is to down-weight examples the model already gets right, via a modulating factor on top of cross-entropy, so training keeps focusing on what it hasn't learned yet.
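To put rough numbers on the imbalance argument, the sketch below compares summed cross-entropy and summed focal loss for a large group of easy examples and a small group of hard ones. The counts and probabilities are made-up, illustrative values (not from the paper), and alpha is fixed at 1.

```python
import numpy as np

# Illustrative, made-up numbers: 10,000 easy negatives the model already
# classifies well (p_t = 0.99) and 10 hard examples (p_t = 0.1).
n_easy, p_easy = 10_000, 0.99
n_hard, p_hard = 10, 0.1
gamma = 2.0

def ce(p):
    # Cross-entropy given the true-class probability p.
    return -np.log(p)

def fl(p, gamma):
    # Focal loss with alpha = 1: cross-entropy scaled by (1 - p)^gamma.
    return (1.0 - p) ** gamma * ce(p)

# Total contribution of each group to the summed loss.
ce_easy, ce_hard = n_easy * ce(p_easy), n_hard * ce(p_hard)
fl_easy, fl_hard = n_easy * fl(p_easy, gamma), n_hard * fl(p_hard, gamma)

print(f"CE: easy={ce_easy:.2f}  hard={ce_hard:.2f}")
print(f"FL: easy={fl_easy:.4f}  hard={fl_hard:.2f}")
```

Under plain cross-entropy the easy group contributes about 100.5 versus about 23.0 for the hard group; under focal loss the easy group shrinks to about 0.01 while the hard group keeps about 18.65.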
Problem statement
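Implement focal_loss(logits, target, gamma=2.0, alpha=1.0) for one example. Let $p_t = \mathrm{softmax}(\text{logits})_{\text{target}}$ be the predicted probability of the true class. Then:
$$\mathrm{FL}(p_t) = -\alpha\,(1 - p_t)^{\gamma}\,\log p_t$$
The factor $(1 - p_t)^{\gamma}$ does the work: confident-correct examples ($p_t \to 1$) have it $\approx 0$ (loss crushed), while hard examples ($p_t$ small) keep it $\approx 1$ (loss ≈ cross-entropy). Compute the softmax stably and clip $p_t$ away from 0.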
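A minimal sketch of one way to implement the formula with NumPy, assuming plain Python lists or 1-D arrays as input. The `1e-30` floor is the one suggested in the constraints; this is an illustration, not the reference solution behind the hidden tests.

```python
import numpy as np

def focal_loss(logits, target, gamma=2.0, alpha=1.0):
    """Focal loss for one example: -alpha * (1 - p_t)^gamma * log(p_t)."""
    z = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax: subtracting the max leaves the probabilities
    # unchanged but keeps exp() from overflowing on large logits.
    e = np.exp(z - z.max())
    p = e / e.sum()
    # True-class probability, clipped away from 0 so log() stays finite.
    p_t = max(p[target], 1e-30)
    return float(-alpha * (1.0 - p_t) ** gamma * np.log(p_t))

print(focal_loss([0.0, 0.0, 0.0], 1))  # uniform logits, gamma=2, alpha=1
```

With uniform logits over three classes this reproduces Example 2 (about 0.488). With `gamma=0.0` it returns `alpha` times the plain cross-entropy.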
Input
- logits: 1-D np.ndarray of shape (K,), the pre-softmax scores for one example.
- target: int, the true class index in $\{0, \dots, K-1\}$.
- gamma: float, the focusing parameter $\gamma$ ($\gamma = 0$ reduces to $\alpha$-weighted CE; $\gamma = 2$ is the RetinaNet default).
- alpha: float, the class-balancing weight $\alpha$.
Output
Returns a float: the scalar focal loss.
Examples
Example 1 — reduces to (weighted) cross-entropy
Input: logits = [1.0, 2.0, 3.0, 0.5], target = 2, gamma = 0, alpha = 1
Output: -log(p_t) ≈ 0.4608 # plain cross-entropy
Explanation: with $\gamma = 0$ the modulating factor is $(1 - p_t)^0 = 1$, so the loss collapses to $-\alpha \log p_t$, which is exactly $\alpha$-weighted cross-entropy.
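The collapse to cross-entropy can be checked numerically. This sketch computes the $\gamma = 0$ focal loss from a stable softmax and, independently, the cross-entropy via log-sum-exp, for the logits of Example 1.

```python
import numpy as np

logits = np.array([1.0, 2.0, 3.0, 0.5])
target, gamma, alpha = 2, 0.0, 1.0

# Stable softmax probability of the true class.
shifted = logits - logits.max()
p_t = np.exp(shifted[target]) / np.exp(shifted).sum()

# Focal loss with gamma = 0: the modulating factor (1 - p_t)**0 is exactly 1.
fl = -alpha * (1.0 - p_t) ** gamma * np.log(p_t)

# Plain cross-entropy via log-sum-exp, computed without the softmax above.
ce = np.log(np.exp(shifted).sum()) - shifted[target]

print(round(float(fl), 4), round(float(ce), 4))
```

Both values come out to about 0.4608.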
Example 2 — the modulation factor scales the loss
Input: logits = [0.0, 0.0, 0.0], target = 1, gamma = 2, alpha = 1
Output: ≈ 0.488
Explanation: uniform logits give $p_t = 1/3$, so the factor is $(1 - 1/3)^2 = 4/9$. The focal loss is that fraction of the plain CE $\log 3 \approx 1.0986$, i.e. $\tfrac{4}{9}\log 3 \approx 0.488$.
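The arithmetic of Example 2, spelled out step by step ($K = 3$, $\gamma = 2$, $\alpha = 1$):

```python
import math

K = 3                      # number of classes, all logits equal
p_t = 1.0 / K              # uniform softmax
factor = (1.0 - p_t) ** 2  # gamma = 2: (2/3)^2 = 4/9
ce = -math.log(p_t)        # plain cross-entropy: log 3
fl = factor * ce           # alpha = 1
print(f"factor={factor:.4f}  CE={ce:.4f}  FL={fl:.4f}")
```

This prints `factor=0.4444  CE=1.0986  FL=0.4883`, matching the ≈ 0.488 above.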
Constraints
- Compute $p_t$ from a numerically stable softmax (subtract logits.max()); clip $p_t$ to a tiny floor (e.g. 1e-30) to avoid log(0).
- $\gamma \ge 0$. With $\gamma = 0$ the loss equals $\alpha$-weighted CE.
- alpha scales the loss linearly (doubling alpha doubles the loss).
- A confident-correct prediction yields a loss far below plain CE; an uncertain one (small $p_t$) yields roughly CE.
- Stable for large logits (e.g. [1000, 1001, 1002]).
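Why the max-subtraction and clipping constraints matter: the sketch below contrasts a naive softmax with the shifted one on the large-logit case. The `errstate` context only silences NumPy's overflow and invalid-value warnings so the NaNs are visible.

```python
import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows float64 to inf, and inf / inf is nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum()

# Stable softmax: subtracting the max gives exponents of -2, -1, 0.
shifted = np.exp(logits - logits.max())
stable = shifted / shifted.sum()

# Clipping floor from the constraints: keeps log() finite even when a
# probability underflows to exactly 0 (e.g. for logits like [0, 2000]).
p_clipped = np.clip(stable, 1e-30, 1.0)

print(naive)   # all nan
print(stable)  # identical to softmax([0, 1, 2])
print(-np.log(p_clipped))
```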
Notes
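- Three regimes of $p_t$. $p_t \to 1$ → factor $\to 0$ (the model "moves on"); $p_t \to 0$ → factor $\to 1$ (full CE, reweighted by $\alpha$); intermediate $p_t$ (e.g. $p_t = 0.5$) → factor $0.5^{\gamma}$ (moderately reduced).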
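The three regimes can be tabulated directly. This sketch sweeps a few illustrative values of $p_t$ at $\gamma = 2$ and prints the modulating factor next to plain CE and the resulting focal loss ($\alpha = 1$).

```python
import numpy as np

gamma = 2.0                           # RetinaNet default focusing parameter
probs = [0.01, 0.1, 0.5, 0.9, 0.99]   # from "hard" to "confidently correct"

# Modulating factor (1 - p_t)^gamma for each true-class probability.
factors = {p: (1.0 - p) ** gamma for p in probs}

for p in probs:
    ce = -np.log(p)  # plain cross-entropy
    print(f"p_t={p:<5} factor={factors[p]:.4f}  CE={ce:.4f}  FL={factors[p] * ce:.4f}")
```

The factor is about 0.98 at $p_t = 0.01$, exactly 0.25 at $p_t = 0.5$, and about 0.0001 at $p_t = 0.99$.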
- Defaults & uses. $\gamma = 2$, $\alpha = 0.25$ (for the rare class) is the canonical RetinaNet setting; focal loss is standard in object detection (RetinaNet, YOLO derivatives) and in severely imbalanced classification problems generally.
This problem ships 6 hidden tests, which run in the browser via Pyodide:
- gamma=0 reduces to alpha-weighted cross-entropy
- Diagnostic: matches the explicit formula -alpha * (1-p_t)^gamma * log(p_t)
- Confident-correct prediction has loss approaching 0 (focal effect)
- Hard example (low p_t) has loss similar to CE (modulation factor ~1)
- alpha scales the loss linearly
- Numerically stable for large logits
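The properties listed above can be spot-checked locally. A minimal sketch, assuming NumPy is available; the `focal_loss` and `cross_entropy` helpers here are hypothetical illustrations written for this check, not the hidden tests themselves, and the logits are arbitrary.

```python
import numpy as np

def cross_entropy(logits, target):
    # Plain CE via log-sum-exp: logsumexp(z) - z[target].
    z = np.asarray(logits, dtype=np.float64)
    m = z.max()
    return float(m + np.log(np.exp(z - m).sum()) - z[target])

def focal_loss(logits, target, gamma=2.0, alpha=1.0):
    # Minimal sketch: stable softmax, p_t clipped to a 1e-30 floor.
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())
    p_t = max((e / e.sum())[target], 1e-30)
    return float(-alpha * (1.0 - p_t) ** gamma * np.log(p_t))

logits = [1.0, 2.0, 3.0, 0.5]

# gamma = 0 reduces to alpha-weighted cross-entropy.
assert np.isclose(focal_loss(logits, 2, gamma=0.0, alpha=0.5),
                  0.5 * cross_entropy(logits, 2))

# alpha scales the loss linearly.
assert np.isclose(focal_loss(logits, 2, alpha=2.0),
                  2.0 * focal_loss(logits, 2, alpha=1.0))

# Confident-correct prediction (p_t close to 1): loss far below CE.
confident = [0.0, 0.0, 10.0]
assert focal_loss(confident, 2) < 1e-3 * cross_entropy(confident, 2)

# Hard example (p_t close to 0): focal loss is nearly the full CE.
hard = [10.0, 0.0, 0.0]
assert focal_loss(hard, 2) / cross_entropy(hard, 2) > 0.99

# Large logits stay finite.
assert np.isfinite(focal_loss([1000.0, 1001.0, 1002.0], 0))

print("all property checks passed")
```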