mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add validation set test
This commit is contained in:
@@ -363,24 +363,28 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
|
||||
|
||||
while (data:=queue_in.get()) is not None:
|
||||
idx, img, ann = data
|
||||
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
img_id = img["id"]
|
||||
img = image_load(base_dir, img["subset"], img["file_name"])
|
||||
tgt = prepare_target(ann, img_id, img.size[::-1])
|
||||
img, tgt = random_horizontal_flip(img, tgt)
|
||||
img, tgt, _ = resize(img, tgt=tgt)
|
||||
match_quality_matrix = box_iou(tgt["boxes"], anchors)
|
||||
matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
|
||||
matches = np.clip(matches, 0, None)
|
||||
boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]
|
||||
|
||||
if val:
|
||||
img, _ = resize(img)
|
||||
else:
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
img, tgt = random_horizontal_flip(img, tgt)
|
||||
img, tgt, _ = resize(img, tgt=tgt)
|
||||
match_quality_matrix = box_iou(tgt["boxes"], anchors)
|
||||
matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
|
||||
matches = np.clip(matches, 0, None)
|
||||
boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]
|
||||
|
||||
Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes()
|
||||
Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()
|
||||
|
||||
X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes()
|
||||
Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()
|
||||
|
||||
queue_out.put(idx)
|
||||
queue_out.put(None)
|
||||
@@ -440,7 +444,11 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
if data_out_count[bc] == batch_size: break
|
||||
|
||||
data_out_count[bc] = 0
|
||||
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
|
||||
|
||||
if val:
|
||||
yield X[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
|
||||
else:
|
||||
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
|
||||
finally:
|
||||
shutdown = True
|
||||
|
||||
|
||||
15
test/external/external_test_datasets.py
vendored
15
test/external/external_test_datasets.py
vendored
@@ -1,7 +1,7 @@
|
||||
from extra.datasets.kits19 import iterate, preprocess
|
||||
from examples.mlperf.dataloader import batch_load_unet3d, batch_load_retinanet
|
||||
from test.external.mlperf_retinanet.openimages import get_openimages, postprocess_targets
|
||||
from test.external.mlperf_retinanet.presets import DetectionPresetTrain
|
||||
from test.external.mlperf_retinanet.presets import DetectionPresetTrain, DetectionPresetEval
|
||||
from test.external.mlperf_retinanet.transforms import GeneralizedRCNNTransform
|
||||
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
|
||||
from tinygrad.helpers import temp
|
||||
@@ -121,10 +121,10 @@ class TestOpenImagesDataset(ExternalTestDatasets):
|
||||
|
||||
return base_dir, ann_file
|
||||
|
||||
def _create_ref_dataloader(self, subset, batch_size=1):
|
||||
def _create_ref_dataloader(self, subset):
|
||||
self._set_seed()
|
||||
base_dir, ann_file = self._create_samples(subset)
|
||||
transforms = DetectionPresetTrain("hflip")
|
||||
transforms = DetectionPresetTrain("hflip") if subset == "train" else DetectionPresetEval()
|
||||
dataset = get_openimages(ann_file.stem, base_dir, subset, transforms)
|
||||
return iter(dataset)
|
||||
|
||||
@@ -150,5 +150,14 @@ class TestOpenImagesDataset(ExternalTestDatasets):
|
||||
np.testing.assert_equal(tinygrad_boxes[0].numpy(), ref_boxes.numpy())
|
||||
np.testing.assert_equal(tinygrad_labels[0].numpy(), ref_labels.numpy())
|
||||
|
||||
def test_validation_set(self):
|
||||
img_size, img_mean, img_std, anchors = (800, 800), [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], torch.ones((120087, 4))
|
||||
tinygrad_dataloader, ref_dataloader = self._create_tinygrad_dataloader("validation", anchors.numpy()), self._create_ref_dataloader("val")
|
||||
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
|
||||
|
||||
for ((tinygrad_img, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
|
||||
ref_img, _ = transform(ref_img.unsqueeze(0))
|
||||
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user