From 8ca848d542b861ecc48ef28b93dc199706fbc27a Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Sat, 5 Oct 2024 18:21:15 -0700 Subject: [PATCH] reuse existing prepare_target --- examples/mlperf/dataloader.py | 4 ++-- extra/datasets/openimages.py | 18 ------------------ 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index bee4b26bc0..40828f3ff8 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -357,12 +357,12 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool= ### RetinaNet def load_retinanet_data(base_dir:Path, queue_in:Queue, queue_out:Queue, X:Tensor): - from extra.datasets.openimages import image_load, transform_coco_polys_to_mask + from extra.datasets.openimages import image_load, prepare_target while (data:=queue_in.get()) is not None: idx, img, ann = data img_id = img["id"] img, img_size = image_load(base_dir, "train", img["file_name"]) # TODO: resize this with the target! - img, tgt = transform_coco_polys_to_mask(img, img_id, img_size, ann) + tgt = prepare_target(ann, img_id, img_size) X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes() diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index d92fb98aba..f9d2cbadf6 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -179,24 +179,6 @@ def download_dataset(base_dir:Path, subset:str) -> Path: return ann_file -def transform_coco_polys_to_mask(img, img_id, img_size, ann, filter_is_crowd:bool=True): - w, h = img_size - if filter_is_crowd: ann = [obj for obj in ann if obj["iscrowd"] == 0] - - boxes = np.array([obj["bbox"] for obj in ann]) - boxes = boxes.reshape(-1, 4) - boxes[:, 2:] += boxes[:, :2] - boxes[:, 0::2].clip(0, w) - boxes[:, 1::2].clip(0, h) - - classes = np.array([obj["category_id"] for obj in ann]) - keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) - boxes, classes = boxes[keep], classes[keep] - - area, is_crowd = [obj["area"] for obj in ann], [obj["iscrowd"] for obj in ann] - target = {"boxes": boxes, "labels": classes, "image_id": img_id, "area": area, "iscrowd": is_crowd} - return img, target - if __name__ == "__main__": download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")