From bf9a0609dc38a709c57cb731e11466ffd0bc247c Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Fri, 13 Dec 2024 16:19:47 +0000 Subject: [PATCH] remove masking support for sigmoid_focal_loss --- examples/mlperf/losses.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py index 9bf4d2cfd4..de9768f1b5 100644 --- a/examples/mlperf/losses.py +++ b/examples/mlperf/losses.py @@ -8,7 +8,7 @@ def dice_ce_loss(pred, tgt): dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean() return (dice + ce) / 2 -def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor: +def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor: assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}" p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none") p_t = p * tgt + (1 - p) * (1 - tgt) @@ -18,13 +18,6 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, al alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt) loss *= alpha_t - if mask is not None: loss *= mask - - if reduction == "mean": - if mask is not None: - loss = loss.sum(axis=0).mean() / mask.sum() - else: - loss = loss.mean() - elif reduction == "sum": loss = loss.sum(axis=0).sum() - + if reduction == "mean": loss = loss.mean() + elif reduction == "sum": loss = loss.sum() return loss