diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py
index 136071bdf8..c6e9044de5 100644
--- a/examples/mlperf/losses.py
+++ b/examples/mlperf/losses.py
@@ -8,7 +8,7 @@ def dice_ce_loss(pred, tgt):
   dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
   return (dice + ce) / 2
 
-def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
+def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2.0, reduction:str = "none") -> Tensor:
   assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
   p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
   p_t = p * tgt + (1 - p) * (1 - tgt)
diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py
index 37f3a31048..66186228b0 100644
--- a/extra/models/retinanet.py
+++ b/extra/models/retinanet.py
@@ -176,8 +176,9 @@ class RegressionHead:
 
   def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
     mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
-    tgt = self.box_coder.encode(bboxes, anchors)
-    loss = mask.where(l1_loss(x, tgt), 0).sum(-1).sum(-1)
+    x = x * mask
+    tgt = self.box_coder.encode(bboxes, anchors) * mask
+    loss = l1_loss(x, tgt).sum(-1).sum(-1)
     loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
     return loss
 