fix validation test case

This commit is contained in:
Francis Lata
2025-01-28 03:06:43 -08:00
parent fc957e7377
commit dcd1941b94

View File

@@ -157,7 +157,7 @@ class TestOpenImagesDataset(ExternalTestDatasets):
ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val")
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
for ((tinygrad_img, _, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
ref_img, _ = transform(ref_img.unsqueeze(0))
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())