From 727eb52ccf89a86225d7469cc18ee82773964934 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Fri, 17 Jan 2025 10:59:39 -0800 Subject: [PATCH] reenable test --- test/external/external_test_datasets.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/external/external_test_datasets.py b/test/external/external_test_datasets.py index ca7cb58388..eed379588c 100644 --- a/test/external/external_test_datasets.py +++ b/test/external/external_test_datasets.py @@ -150,16 +150,16 @@ class TestOpenImagesDataset(ExternalTestDatasets): np.testing.assert_equal(tinygrad_boxes[0].numpy(), ref_boxes.numpy()) np.testing.assert_equal(tinygrad_labels[0].numpy(), ref_labels.numpy()) - # def test_validation_set(self): - # base_dir, ann_file = self._create_samples(subset := "validation") - # img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4)) - # tinygrad_dataloader = self._create_tinygrad_dataloader(base_dir, ann_file, subset, anchors.numpy()) - # ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val") - # transform = GeneralizedRCNNTransform(img_size, img_mean, img_std) + def test_validation_set(self): + base_dir, ann_file = self._create_samples(subset := "validation") + img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4)) + tinygrad_dataloader = self._create_tinygrad_dataloader(base_dir, ann_file, subset, anchors.numpy()) + ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val") + transform = GeneralizedRCNNTransform(img_size, img_mean, img_std) - # for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader): - # ref_img, _ = transform(ref_img.unsqueeze(0)) - # np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy()) + for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader): + ref_img, _ = transform(ref_img.unsqueeze(0)) + np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy()) if __name__ == '__main__': unittest.main()