From 2214d13b3d2667bc8ffd673e87da079864ccbb36 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Wed, 11 Dec 2024 23:00:48 +0000 Subject: [PATCH] add missing test and cleanup focal loss --- examples/mlperf/losses.py | 4 ++-- test/external/external_test_losses.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/mlperf/losses.py b/examples/mlperf/losses.py index 0c352d79ac..9bf4d2cfd4 100644 --- a/examples/mlperf/losses.py +++ b/examples/mlperf/losses.py @@ -22,9 +22,9 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor] = None, al if reduction == "mean": if mask is not None: - loss = (avg_loss := loss.sum(0).mean()) / (Tensor.ones_like(avg_loss) if mask is None else mask.sum()) + loss = loss.sum(axis=0).mean() / mask.sum() else: loss = loss.mean() - elif reduction == "sum": loss = loss.sum(0).sum() + elif reduction == "sum": loss = loss.sum(axis=0).sum() return loss diff --git a/test/external/external_test_losses.py b/test/external/external_test_losses.py index 0c6ffb7cc5..9281c8271c 100644 --- a/test/external/external_test_losses.py +++ b/test/external/external_test_losses.py @@ -30,7 +30,7 @@ class TestSigmoidFocalLoss(unittest.TestCase): ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs) # NOTE: since boolean indexing is not supported in tinygrad, compare the sum instead. - np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), atol=1e-5) + np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), rtol=1e-6) def _generate_samples(self, generate_mask=False): def _apply_logit(p): return np.log(p / (1 - p)) @@ -51,5 +51,10 @@ class TestSigmoidFocalLoss(unittest.TestCase): pred, tgt, _ = self._generate_samples() self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, alpha=0.58, gamma=2, reduction=reduction) + def test_loss_correct_ratio_mask(self): + for reduction in ["mean", "sum", "none"]: + pred, tgt, mask = self._generate_samples(generate_mask=True) + self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, mask=mask, alpha=0.58, gamma=2, reduction=reduction) + if __name__ == '__main__': unittest.main() \ No newline at end of file