From 60c13c293299cc5160b6cdd3feceeaa887aef8b0 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Sun, 23 Feb 2025 21:22:33 +0000 Subject: [PATCH] update loss calculation for RegressionHead and some cleanups --- examples/mlperf/losses.py | 2 +- extra/models/retinanet.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py index 136071bdf8..c6e9044de5 100644 --- a/examples/mlperf/losses.py +++ b/examples/mlperf/losses.py @@ -8,7 +8,7 @@ def dice_ce_loss(pred, tgt): dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean() return (dice + ce) / 2 -def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor: +def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2.0, reduction:str = "none") -> Tensor: assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}" p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none") p_t = p * tgt + (1 - p) * (1 - tgt) diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py index 37f3a31048..66186228b0 100644 --- a/extra/models/retinanet.py +++ b/extra/models/retinanet.py @@ -176,8 +176,9 @@ class RegressionHead: def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor: mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1) - tgt = self.box_coder.encode(bboxes, anchors) - loss = mask.where(l1_loss(x, tgt), 0).sum(-1).sum(-1) + x = x * mask + tgt = self.box_coder.encode(bboxes, anchors) * mask + loss = l1_loss(x, tgt).sum(-1).sum(-1) loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0] return loss