remove unnecessary targets from validation dataloader

This commit is contained in:
Francis Lata
2025-02-03 19:15:30 +00:00
parent 932cf4b7f2
commit f02cce0049
2 changed files with 7 additions and 11 deletions

View File

@@ -396,11 +396,8 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
if img_ids is not None:
img_ids[idx] = img_id
if isinstance(boxes, list):
boxes[idx] = tgt["boxes"]
if isinstance(labels, list):
labels[idx] = tgt["labels"]
if img_sizes is not None:
img_sizes[idx] = tgt["image_size"]
queue_in.put((idx, img, tgt))
@@ -419,10 +416,10 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
if val:
matches, anchors = None, None
img_ids, boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
boxes, labels, matches, anchors = None, None, None, None
img_ids, img_sizes = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
else:
img_ids = None
img_ids, img_sizes = None, None
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
@@ -469,8 +466,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
if val:
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
img_ids[bc * batch_size:(bc + 1) * batch_size],
boxes[bc * batch_size:(bc + 1) * batch_size],
labels[bc * batch_size:(bc + 1) * batch_size],
img_sizes[bc * batch_size:(bc + 1) * batch_size],
Cookie(bc))
else:
yield (imgs[bc * batch_size:(bc + 1) * batch_size],

View File

@@ -157,7 +157,7 @@ class TestOpenImagesDataset(ExternalTestDatasets):
ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val")
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
for ((tinygrad_img, _, _, _, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
for ((tinygrad_img, _, _, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
ref_img, _ = transform(ref_img.unsqueeze(0))
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())