From 65c561a61854db536c37897af558552a9ca5ec35 Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Fri, 25 Oct 2024 21:18:34 -0700 Subject: [PATCH] update image to be float32 --- examples/mlperf/dataloader.py | 2 +- extra/datasets/openimages.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 6471880244..25bad412fd 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -403,7 +403,7 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b queue_in, queue_out = Queue(), Queue() procs, data_out_count = [], [0] * batch_count - shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.uint8) + shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.float32) shm_y_boxes, Y_boxes = _setup_shared_mem("retinanet_y_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32) shm_y_labels, Y_labels = _setup_shared_mem("retinanet_y_labels", (batch_size * batch_count, 120087), dtypes.int64) diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index cf7fe4dee7..17f6e02f7c 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -190,7 +190,7 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si import torchvision.transforms.functional as F img_size = img.size[::-1] img = F.resize(img, size=size) - img = np.array(img) + img = np.array(img, dtype=np.float32) if tgt is not None: ratios = [s / s_orig for s, s_orig in zip(size, img_size)]