From aebccf93ac584abb59eccf0de4972d1a44860f08 Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Fri, 21 Mar 2025 20:20:36 +0000
Subject: [PATCH] revert losses changes

---
 examples/mlperf/losses.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py
index 7ce047bc70..d20aac063f 100644
--- a/examples/mlperf/losses.py
+++ b/examples/mlperf/losses.py
@@ -7,18 +7,18 @@ def dice_ce_loss(pred, tgt):
   return (dice + ce) / 2
 
 def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float=0.25, gamma:float=2.0, reduction:str="none") -> Tensor:
-    assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
-    p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
-    p_t = p * tgt + (1 - p) * (1 - tgt)
-    loss = ce_loss * ((1 - p_t) ** gamma)
+  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
+  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
+  p_t = p * tgt + (1 - p) * (1 - tgt)
+  loss = ce_loss * ((1 - p_t) ** gamma)
 
-    if alpha >= 0:
-        alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
-        loss = loss * alpha_t
+  if alpha >= 0:
+    alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
+    loss = loss * alpha_t
 
-    if reduction == "mean": loss = loss.mean()
-    elif reduction == "sum": loss = loss.sum()
-    return loss
+  if reduction == "mean": loss = loss.mean()
+  elif reduction == "sum": loss = loss.sum()
+  return loss
 
 def l1_loss(pred:Tensor, tgt:Tensor, reduction:str="none") -> Tensor:
   assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"