diff --git a/test/external/external_test_losses.py b/test/external/external_test_losses.py
index 9281c8271c..3ed8b334a9 100644
--- a/test/external/external_test_losses.py
+++ b/test/external/external_test_losses.py
@@ -7,54 +7,39 @@
 import numpy as np
 import torch
 import unittest
 
-class TestDiceCELoss(unittest.TestCase):
+class TestLoss(unittest.TestCase):
+  def setUp(self):
+    np.random.seed(1337)
+
+  def assert_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, rtol=1e-07, atol=0, **kwargs):
+    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
+    ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
+    np.testing.assert_allclose(tinygrad_metrics_res.numpy(), ref_metrics_res.numpy(), rtol=rtol, atol=atol)
+
+class TestDiceCELoss(TestLoss):
   def setUp(self):
     np.random.seed(1337)
 
   def test_loss(self):
     pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
-    tinygrad_metrics_res = dice_ce_loss(Tensor(pred), Tensor(label)).numpy()
-    ref_metrics_res = DiceCELoss(True, True, "NCDHW", False)(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
-    np.testing.assert_allclose(tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
+    tinygrad_metrics_res, ref_metrics_res = dice_ce_loss, DiceCELoss(True, True, "NCDHW", False)
+    self.assert_loss(pred, label, tinygrad_metrics_res, ref_metrics_res, atol=1e-4)
 
-class TestSigmoidFocalLoss(unittest.TestCase):
-  def setUp(self):
-    np.random.seed(1337)
-
-  def _test_loss(self, pred, tgt, tinygrad_metrics, ref_metrics, mask=None, **kwargs):
-    if mask is not None:
-      tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), mask=Tensor(mask).reshape(-1, 1), **kwargs)
-      ref_metrics_res = ref_metrics(torch.from_numpy(pred[mask]), torch.from_numpy(tgt[mask]), **kwargs)
-    else:
-      tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(tgt), **kwargs)
-      ref_metrics_res = ref_metrics(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
-
-    # NOTE: since boolean indexing is not supported in tinygrad, compare the sum instead.
-    np.testing.assert_allclose(tinygrad_metrics_res.sum().numpy(), ref_metrics_res.sum().numpy(), rtol=1e-6)
-
-  def _generate_samples(self, generate_mask=False):
+class TestSigmoidFocalLoss(TestLoss):
+  def _generate_samples(self):
     def _apply_logit(p): return np.log(p / (1 - p))
-    return _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32), np.random.randint(0, 2, (5,), dtype=np.bool) if generate_mask else None
+    return _apply_logit(np.random.rand(5,2).astype(np.float32)), np.random.randint(0, 2, (5, 2)).astype(np.float32)
 
   def test_loss_equal_to_ce(self):
     for reduction in ["mean", "sum", "none"]:
-      pred, tgt, _ = self._generate_samples()
-      self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, alpha=-1, gamma=0, reduction=reduction)
-
-  def test_loss_equal_to_ce_mask(self):
-    for reduction in ["mean", "sum", "none"]:
-      pred, tgt, mask = self._generate_samples(generate_mask=True)
-      self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, mask=mask, alpha=-1, gamma=0, reduction=reduction)
+      pred, tgt = self._generate_samples()
+      self.assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=-1, gamma=0, reduction=reduction)
 
   def test_loss_correct_ratio(self):
     for reduction in ["mean", "sum", "none"]:
-      pred, tgt, _ = self._generate_samples()
-      self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, alpha=0.58, gamma=2, reduction=reduction)
+      pred, tgt = self._generate_samples()
+      self.assert_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, rtol=1e-4, alpha=0.58, gamma=2, reduction=reduction)
-  def test_loss_correct_ratio_mask(self):
-    for reduction in ["mean", "sum", "none"]:
-      pred, tgt, mask = self._generate_samples(generate_mask=True)
-      self._test_loss(pred, tgt, sigmoid_focal_loss, ref_sigmoid_focal_loss, mask=mask, alpha=0.58, gamma=2, reduction=reduction)
 
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
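
The refactor hoists the duplicated tinygrad-vs-torch comparison into a shared TestLoss.assert_loss helper: each subclass now only supplies the numpy inputs, the two loss callables, and a tolerance, and the helper compares full output tensors elementwise instead of the old mask path's .sum() workaround. Below is a minimal self-contained sketch of that pattern. The import paths are assumptions (the file's imports sit above the hunk shown), and torchvision's sigmoid_focal_loss stands in for the ref_sigmoid_focal_loss reference; the test body is illustrative, not the PR's code.

# sketch of the shared assert_loss pattern; import paths are assumed,
# the real file's imports are outside the hunk above
import numpy as np
import torch
import unittest
from tinygrad import Tensor
from examples.mlperf.losses import sigmoid_focal_loss          # assumed path
from torchvision.ops import sigmoid_focal_loss as ref_focal    # assumed reference

class TestLoss(unittest.TestCase):
  def setUp(self):
    np.random.seed(1337)  # fixed seed: both frameworks see identical inputs

  def assert_loss(self, pred, tgt, tinygrad_fxn, ref_fxn, rtol=1e-07, atol=0, **kwargs):
    # run the same numpy arrays through both implementations...
    out = tinygrad_fxn(Tensor(pred), Tensor(tgt), **kwargs)
    ref = ref_fxn(torch.from_numpy(pred), torch.from_numpy(tgt), **kwargs)
    # ...and compare elementwise, which the old sum-based check could not do
    np.testing.assert_allclose(out.numpy(), ref.numpy(), rtol=rtol, atol=atol)

class TestSigmoidFocalLoss(TestLoss):
  def test_matches_reference(self):
    probs = np.random.rand(5, 2).astype(np.float32)
    logits = np.log(probs / (1 - probs))  # inverse sigmoid, as in _generate_samples
    tgt = np.random.randint(0, 2, (5, 2)).astype(np.float32)
    for reduction in ["mean", "sum", "none"]:
      self.assert_loss(logits, tgt, sigmoid_focal_loss, ref_focal,
                       rtol=1e-4, alpha=0.58, gamma=2, reduction=reduction)

if __name__ == '__main__':
  unittest.main()

One consequence of dropping the mask path (tinygrad lacks the boolean indexing it needed) is that the elementwise comparison becomes strict enough to need per-test tolerances, hence the explicit rtol=1e-4 on the focal-loss cases above.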