add missing test and clean up focal loss

This commit is contained in:
Francis Lata
2024-12-11 23:00:48 +00:00
parent 827b2114e2
commit 2214d13b3d
2 changed files with 8 additions and 3 deletions

View File

@@ -22,9 +22,9 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, al
if reduction == "mean":
if mask is not None:
loss = (avg_loss := loss.sum(0).mean()) / (Tensor.ones_like(avg_loss) if mask is None else mask.sum())
loss = loss.sum(axis=0).mean() / mask.sum()
else:
loss = loss.mean()
elif reduction == "sum": loss = loss.sum(0).sum()
elif reduction == "sum": loss = loss.sum(axis=0).sum()
return loss

View File

@@ -30,7 +30,7 @@ class TestSigmoidFocalLoss(unittest.TestCase):
ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
# NOTE: since boolean indexing is not supported in tinygrad, compare the sum instead.
np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), atol=1e-5)
np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), rtol=1e-6)
def _generate_samples(self, generate_mask=False):
def _apply_logit(p): return np.log(p / (1 - p))
@@ -51,5 +51,10 @@ class TestSigmoidFocalLoss(unittest.TestCase):
pred, tgt, _ = self._generate_samples()
self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, alpha=0.58, gamma=2, reduction=reduction)
def test_loss_correct_ratio_mask(self):
for reduction in ["mean", "sum", "none"]:
pred, tgt, mask = self._generate_samples(generate_mask=True)
self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, mask=mask, alpha=0.58, gamma=2, reduction=reduction)
if __name__ == '__main__':
unittest.main()