update loss calculation for RegressionHead and some cleanups

This commit is contained in:
Francis Lata
2025-02-23 21:22:33 +00:00
parent 7dba815c47
commit 60c13c2932
2 changed files with 4 additions and 3 deletions

View File

@@ -8,7 +8,7 @@ def dice_ce_loss(pred, tgt):
dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
return (dice + ce) / 2
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2.0, reduction:str = "none") -> Tensor:
assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
p_t = p * tgt + (1 - p) * (1 - tgt)

View File

@@ -176,8 +176,9 @@ class RegressionHead:
def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
    """Compute the normalized L1 box-regression loss over foreground anchors.

    Args:
      x: predicted box-regression deltas, one set per anchor.
      bboxes: ground-truth boxes matched to the anchors.
      matches: per-anchor match values; `matches >= 0` marks a foreground anchor.
      anchors: anchor boxes used to encode the regression targets.

    Returns:
      Scalar loss: per-image sum of masked L1 terms divided by the per-image
      foreground-anchor count, averaged over the batch.
    """
    # Foreground mask reshaped to broadcast over the box-coordinate axis.
    # NOTE(review): assumes x/tgt are (batch, num_anchors, 4) — confirm with caller.
    mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
    # Zero out background entries on both prediction and target so that
    # non-foreground anchors contribute nothing to the L1 sum.
    # (The stale pre-commit duplicate `tgt`/`loss` computation that the diff
    # rendering interleaved here was dead code — it was overwritten below and
    # called box_coder.encode a second time — so it has been removed.)
    x = x * mask
    tgt = self.box_coder.encode(bboxes, anchors) * mask
    loss = l1_loss(x, tgt).sum(-1).sum(-1)
    # Normalize each image by its foreground count, then average over the batch.
    loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
    return loss