reenable test

This commit is contained in:
Francis Lata
2025-01-17 10:59:39 -08:00
parent 8371688f80
commit 727eb52ccf

View File

@@ -150,16 +150,16 @@ class TestOpenImagesDataset(ExternalTestDatasets):
np.testing.assert_equal(tinygrad_boxes[0].numpy(), ref_boxes.numpy())
np.testing.assert_equal(tinygrad_labels[0].numpy(), ref_labels.numpy())
# def test_validation_set(self):
# base_dir, ann_file = self._create_samples(subset := "validation")
# img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4))
# tinygrad_dataloader = self._create_tinygrad_dataloader(base_dir, ann_file, subset, anchors.numpy())
# ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val")
# transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
def test_validation_set(self):
base_dir, ann_file = self._create_samples(subset := "validation")
img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4))
tinygrad_dataloader = self._create_tinygrad_dataloader(base_dir, ann_file, subset, anchors.numpy())
ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val")
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
# for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
# ref_img, _ = transform(ref_img.unsqueeze(0))
# np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
ref_img, _ = transform(ref_img.unsqueeze(0))
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
if __name__ == '__main__':
unittest.main()