remove masking support for sigmoid_focal_loss

This commit is contained in:
Francis Lata
2024-12-13 16:19:47 +00:00
parent bd68019bc8
commit bf9a0609dc

View File

@@ -8,7 +8,7 @@ def dice_ce_loss(pred, tgt):
dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
return (dice + ce) / 2
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
p_t = p * tgt + (1 - p) * (1 - tgt)
@@ -18,13 +18,6 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, al
alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
loss *= alpha_t
if mask is not None: loss *= mask
if reduction == "mean":
if mask is not None:
loss = loss.sum(axis=0).mean() / mask.sum()
else:
loss = loss.mean()
elif reduction == "sum": loss = loss.sum(axis=0).sum()
if reduction == "mean": loss = loss.mean()
elif reduction == "sum": loss = loss.sum()
return loss