mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove unnecessary targets from validation dataloader
This commit is contained in:
@@ -396,11 +396,8 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
if img_ids is not None:
|
||||
img_ids[idx] = img_id
|
||||
|
||||
if isinstance(boxes, list):
|
||||
boxes[idx] = tgt["boxes"]
|
||||
|
||||
if isinstance(labels, list):
|
||||
labels[idx] = tgt["labels"]
|
||||
if img_sizes is not None:
|
||||
img_sizes[idx] = tgt["image_size"]
|
||||
|
||||
queue_in.put((idx, img, tgt))
|
||||
|
||||
@@ -419,10 +416,10 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
|
||||
|
||||
if val:
|
||||
matches, anchors = None, None
|
||||
img_ids, boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
|
||||
boxes, labels, matches, anchors = None, None, None, None
|
||||
img_ids, img_sizes = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
|
||||
else:
|
||||
img_ids = None
|
||||
img_ids, img_sizes = None, None
|
||||
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
@@ -469,8 +466,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
if val:
|
||||
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
|
||||
img_ids[bc * batch_size:(bc + 1) * batch_size],
|
||||
boxes[bc * batch_size:(bc + 1) * batch_size],
|
||||
labels[bc * batch_size:(bc + 1) * batch_size],
|
||||
img_sizes[bc * batch_size:(bc + 1) * batch_size],
|
||||
Cookie(bc))
|
||||
else:
|
||||
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
|
||||
|
||||
2
test/external/external_test_datasets.py
vendored
2
test/external/external_test_datasets.py
vendored
@@ -157,7 +157,7 @@ class TestOpenImagesDataset(ExternalTestDatasets):
|
||||
ref_dataloader = self._create_ref_dataloader(base_dir, ann_file, "val")
|
||||
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
|
||||
|
||||
for ((tinygrad_img, _, _, _, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
|
||||
for ((tinygrad_img, _, _, _), (ref_img, _)) in zip(tinygrad_dataloader, ref_dataloader):
|
||||
ref_img, _ = transform(ref_img.unsqueeze(0))
|
||||
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user