diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 9aec24f4a2..56a90d542c 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -351,8 +351,8 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool= ### RetinaNet def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, - imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Tensor, - anchors:Tensor, img_ids:Tensor, seed:Optional[int] = None): + imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Optional[Tensor] = None, + anchors:Optional[Tensor] = None, seed:Optional[int] = None): from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize from examples.mlperf.helpers import box_iou, find_matches, generate_anchors import torch @@ -391,8 +391,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh def _enqueue_batch(bc): for idx in range(bc * batch_size, (bc+1) * batch_size): img = dataset.loadImgs(next(dataset_iter))[0] - ann = dataset.loadAnns(dataset.getAnnIds((img_id:=img["id"]))) - img_ids[idx] = img_id + ann = dataset.loadAnns(dataset.getAnnIds(img["id"])) queue_in.put((idx, img, ann)) def _setup_shared_mem(shm_name:str, size:Tuple[int, ...], dtype:dtypes) -> Tuple[shared_memory.SharedMemory, Tensor]: @@ -408,12 +407,15 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh procs, data_out_count = [], [0] * batch_count shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.float32) - shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32) - shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64) - shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64) - shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64) - img_ids = [None] * (batch_size * batch_count) + if val: + matches, anchors = None, None + boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count) + else: + shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32) + shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64) + shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64) + shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64) shutdown = False class Cookie: @@ -435,8 +437,8 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh for _ in range(cpu_count()): proc = Process( target=load_retinanet_data, - args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels, matches, anchors, img_ids), - kwargs={"seed": seed} + args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels), + kwargs={"matches": matches, "anchors": anchors, "seed": seed} ) proc.daemon = True proc.start() @@ -455,7 +457,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh if val: yield (imgs[bc * batch_size:(bc + 1) * batch_size], - img_ids[bc * batch_size:(bc + 1) * batch_size], + image_ids[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)) else: yield (imgs[bc * batch_size:(bc + 1) * batch_size], @@ -478,17 +480,21 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh for proc in procs: proc.join() shm_imgs.close() - shm_boxes.close() - shm_labels.close() - shm_matches.close() - shm_anchors.close() + + if not val: + shm_boxes.close() + shm_labels.close() + shm_matches.close() + shm_anchors.close() try: shm_imgs.unlink() - shm_boxes.unlink() - shm_labels.unlink() - shm_matches.unlink() - shm_anchors.unlink() + + if not val: + shm_boxes.unlink() + shm_labels.unlink() + shm_matches.unlink() + shm_anchors.unlink() except FileNotFoundError: # happens with BENCHMARK set pass