Got the RetinaNet dataloader working with normalize

This commit is contained in:
Francis Lata
2024-10-25 20:25:07 -07:00
parent 967438ca71
commit 4b21a8fb8d
4 changed files with 22 additions and 23 deletions

View File

@@ -1,5 +1,5 @@
import os, random, pickle, queue
from typing import List, Tuple
from typing import List, Tuple, Optional
from pathlib import Path
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -356,17 +356,18 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
### RetinaNet
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor, Y_boxes:Tensor, Y_labels:Tensor, anchors:np.ndarray):
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor, Y_boxes:Tensor, Y_labels:Tensor, anchors:np.ndarray, seed:Optional[int]=None):
from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize
from examples.mlperf.helpers import box_iou, find_matches
import torch
while (data:=queue_in.get()) is not None:
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
idx, img, ann = data
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
img_id = img["id"]
img = image_load(base_dir, img["subset"], img["file_name"])
tgt = prepare_target(ann, img_id, img.size[::-1])
@@ -384,7 +385,7 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
queue_out.put(idx)
queue_out.put(None)
def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, batch_size:int=32, seed:int=None):
def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, batch_size:int=32, shuffle:bool=True, seed:Optional[int]=None):
def _enqueue_batch(bc):
for idx in range(bc * batch_size, (bc+1) * batch_size):
img = dataset.loadImgs(next(dataset_iter))[0]
@@ -415,16 +416,16 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
try: _enqueue_batch(self.bc)
except StopIteration: pass
# def shuffle_indices(file_indices, seed=None):
# rng = random.Random(seed)
# rng.shuffle(file_indices)
def shuffle_indices(indices, seed):
  """Shuffle `indices` in place using a private RNG seeded with `seed`.

  A dedicated `random.Random` instance is used, so the module-level `random`
  state is left untouched. Nothing is returned; the list is mutated.
  """
  random.Random(seed).shuffle(indices)
# if shuffle: shuffle_indices(file_indices, seed=seed)
if shuffle: shuffle_indices(image_ids, seed=seed)
dataset_iter = iter(image_ids)
try:
for _ in range(cpu_count()):
proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, anchors))
proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, anchors, seed))
proc.daemon = True
proc.start()
procs.append(proc)

View File

@@ -193,8 +193,8 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si
img = np.array(img)
if tgt is not None:
ratios = [s / s_orig for s, s_orig in zip(size, img.shape[::-1])]
ratio_w, ratio_h = ratios
ratios = [s / s_orig for s, s_orig in zip(size, img_size)]
ratio_h, ratio_w = ratios
x_min, y_min, x_max, y_max = [tgt["boxes"][:, i] for i in range(tgt["boxes"].shape[-1])]
x_min = x_min * ratio_w
x_max = x_max * ratio_w

View File

@@ -1,5 +1,5 @@
from extra.datasets.kits19 import iterate, preprocess
from extra.datasets.openimages import download_dataset
from extra.datasets.openimages import normalize
from examples.mlperf.dataloader import batch_load_unet3d, batch_load_retinanet
from test.external.mlperf_retinanet.openimages import get_openimages, postprocess_targets
from test.external.mlperf_retinanet.presets import DetectionPresetTrain, DetectionPresetEval
@@ -86,8 +86,6 @@ class TestKiTS19Dataset(ExternalTestDatasets):
class TestOpenImagesDataset(ExternalTestDatasets):
def _create_samples(self, subset):
self._set_seed()
os.makedirs(Path(base_dir:=tempfile.gettempdir() + "/openimages") / f"{subset}/data", exist_ok=True)
os.makedirs(base_dir / Path(f"{subset}/labels"), exist_ok=True)
@@ -111,15 +109,16 @@ class TestOpenImagesDataset(ExternalTestDatasets):
return base_dir, ann_file
def _create_ref_dataloader(self, subset, batch_size=1):
  """Return an iterator over the reference OpenImages dataset for `subset`.

  `self._set_seed()` is called first, before the sample files are created, and
  the dataset is built with the `"hflip"` training preset. `batch_size` is
  accepted for signature parity but is not used in this method.
  """
  self._set_seed()
  sample_root, annotations = self._create_samples(subset)
  hflip_preset = DetectionPresetTrain("hflip")
  reference = get_openimages(annotations.stem, sample_root, subset, hflip_preset)
  return iter(reference)
def _create_tinygrad_dataloader(self, subset, anchors, batch_size=1):
def _create_tinygrad_dataloader(self, subset, anchors, batch_size=1, seed=42):
base_dir, ann_file = self._create_samples(subset)
dataset = COCO(ann_file)
dataloader = batch_load_retinanet(dataset, subset == "validation", anchors, Path(base_dir), batch_size=batch_size)
dataloader = batch_load_retinanet(dataset, subset == "validation", anchors, Path(base_dir), batch_size=batch_size, shuffle=False, seed=seed)
return iter(dataloader)
def test_training_set(self):
@@ -128,7 +127,6 @@ class TestOpenImagesDataset(ExternalTestDatasets):
transform = GeneralizedRCNNTransform(img_size, img_mean, img_std)
for ((tinygrad_img, tinygrad_boxes, tinygrad_labels, _), (ref_img, ref_tgt)) in zip(tinygrad_dataloader, ref_dataloader):
self._set_seed()
ref_tgt = [ref_tgt]
ref_img, ref_tgt = transform(ref_img.unsqueeze(0), ref_tgt)
@@ -136,8 +134,8 @@ class TestOpenImagesDataset(ExternalTestDatasets):
ref_boxes, ref_labels = ref_tgt[0]["boxes"], ref_tgt[0]["labels"]
np.testing.assert_equal(tinygrad_img.numpy(), ref_img.tensors.transpose(1, 3).numpy())
# print(f"{tinygrad_img.shape=} {tinygrad_boxes.shape=} {tinygrad_labels.shape=}")
# print(f"{ref_boxes.shape=} {ref_labels.shape=} {ref_img.tensors.shape=}")
np.testing.assert_equal(tinygrad_boxes[0].numpy(), ref_boxes.numpy())
np.testing.assert_equal(tinygrad_labels[0].numpy(), ref_labels.numpy())
if __name__ == '__main__':
unittest.main()

View File

@@ -181,7 +181,7 @@ class GeneralizedRCNNTransform(nn.Module):
if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape))
image = self.normalize(image)
# image = self.normalize(image)
image, target_index = self.resize(image, target_index)
images[i] = image
if targets is not None and target_index is not None: