add validation set test

This commit is contained in:
Francis Lata
2024-10-25 22:55:49 -07:00
parent 2586555bd3
commit 6e3efd4ed6
2 changed files with 34 additions and 17 deletions

View File

@@ -363,24 +363,28 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
while (data:=queue_in.get()) is not None:
idx, img, ann = data
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
img_id = img["id"]
img = image_load(base_dir, img["subset"], img["file_name"])
tgt = prepare_target(ann, img_id, img.size[::-1])
img, tgt = random_horizontal_flip(img, tgt)
img, tgt, _ = resize(img, tgt=tgt)
match_quality_matrix = box_iou(tgt["boxes"], anchors)
matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
matches = np.clip(matches, 0, None)
boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]
if val:
img, _ = resize(img)
else:
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
img, tgt = random_horizontal_flip(img, tgt)
img, tgt, _ = resize(img, tgt=tgt)
match_quality_matrix = box_iou(tgt["boxes"], anchors)
matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
matches = np.clip(matches, 0, None)
boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]
Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes()
Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()
X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes()
Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()
queue_out.put(idx)
queue_out.put(None)
@@ -440,7 +444,11 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
if data_out_count[bc] == batch_size: break
data_out_count[bc] = 0
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
if val:
yield X[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
else:
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
finally:
shutdown = True

View File

@@ -1,7 +1,7 @@
from extra.datasets.kits19 import iterate, preprocess
from examples.mlperf.dataloader import batch_load_unet3d, batch_load_retinanet
from test.external.mlperf_retinanet.openimages import get_openimages, postprocess_targets
from test.external.mlperf_retinanet.presets import DetectionPresetTrain
from test.external.mlperf_retinanet.presets import DetectionPresetTrain, DetectionPresetEval
from test.external.mlperf_retinanet.transforms import GeneralizedRCNNTransform
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
@@ -121,10 +121,10 @@ class TestOpenImagesDataset(ExternalTestDatasets):
return base_dir, ann_file
def _create_ref_dataloader(self, subset, batch_size=1):
def _create_ref_dataloader(self, subset):
self._set_seed()
base_dir, ann_file = self._create_samples(subset)
transforms = DetectionPresetTrain("hflip")
transforms = DetectionPresetTrain("hflip") if subset == "train" else DetectionPresetEval()
dataset = get_openimages(ann_file.stem, base_dir, subset, transforms)
return iter(dataset)
@@ -150,5 +150,14 @@ class TestOpenImagesDataset(ExternalTestDatasets):
np.testing.assert_equal(tinygrad_boxes[0].numpy(), ref_boxes.numpy())
np.testing.assert_equal(tinygrad_labels[0].numpy(), ref_labels.numpy())
def test_validation_set(self):
img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4))
tinygrad_dataloader, ref_dataloader = self._create_tinygrad_dataloader("validation", anchors.numpy()), self._create_ref_dataloader("val")
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
ref_img, _ = transform(ref_img.unsqueeze(0))
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
if __name__ == '__main__':
unittest.main()