From 932cf4b7f2d409e3d3a89ea376018e1c3e70529d Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Sun, 2 Feb 2025 19:21:46 +0000 Subject: [PATCH] fix img_ids repeating its values --- examples/mlperf/dataloader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 0f24ffee3b..49c2dbf3a2 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -393,6 +393,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh ann = dataset.loadAnns(dataset.getAnnIds(img_id:=img["id"])) tgt = prepare_target(ann, img_id, (img["height"], img["width"])) + if img_ids is not None: + img_ids[idx] = img_id + if isinstance(boxes, list): boxes[idx] = tgt["boxes"] @@ -417,8 +420,9 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh if val: matches, anchors = None, None - boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count) + img_ids, boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count), [None] * (batch_size * batch_count) else: + img_ids = None shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32) shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64) shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64) @@ -464,7 +468,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh if val: yield (imgs[bc * batch_size:(bc + 1) * batch_size], - image_ids[bc * batch_size:(bc + 1) * batch_size], + img_ids[bc * batch_size:(bc + 1) * batch_size], boxes[bc * batch_size:(bc + 1) * batch_size], labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc))