diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 40c54f2080..87d3060922 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -232,6 +232,9 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf metrics
       run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
+    - if: ${{ matrix.task == 'onnx' }}
+      name: Test MLPerf datasets
+      run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
     - if: ${{ matrix.task == 'onnx' }}
       name: Test THREEFRY
       run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index 4f2faa9a9b..662e998a7f 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -63,13 +63,13 @@ def eval_resnet():
 def eval_unet3d():
   # UNet3D
   from extra.models.unet3d import UNet3D
-  from extra.datasets.kits19 import iterate, sliding_window_inference
+  from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
   from examples.mlperf.metrics import dice_score
   mdl = UNet3D()
   mdl.load_from_pretrained()
   s = 0
   st = time.perf_counter()
-  for i, (image, label) in enumerate(iterate(), start=1):
+  for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
     mt = time.perf_counter()
     pred, label = sliding_window_inference(mdl, image, label)
     et = time.perf_counter()
diff --git a/extra/datasets/kits19.py b/extra/datasets/kits19.py
index bb2b1bfaec..02dc5e67fe 100644
--- a/extra/datasets/kits19.py
+++ b/extra/datasets/kits19.py
@@ -3,13 +3,16 @@ import functools
 from pathlib import Path
 import numpy as np
 import nibabel as nib
-from scipy import signal
+from scipy import signal, ndimage
+import os
 import torch
 import torch.nn.functional as F
+from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch
 
 BASEDIR = Path(__file__).parent / "kits19" / "data"
+PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
 
 """
 To download the dataset:
 ```
 git clone https://github.com/neheller/kits19
 cd kits19
 pip3 install -r requirements.txt
 python3 -m starter_code.get_imaging
 cd ..
@@ -23,6 +26,10 @@ mv kits19 extra/datasets
 ```
 """
 
+@functools.lru_cache(None)
+def get_train_files():
+  return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
+
 @functools.lru_cache(None)
 def get_val_files():
   data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
@@ -65,15 +72,42 @@ def preprocess(file_path):
   image, label = pad_to_min_shape(image, label)
   return image, label
 
-def iterate(val=True, shuffle=False):
-  if not val: raise NotImplementedError
-  files = get_val_files()
+def preprocess_dataset(filenames, preprocessed_dir, val):
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
+  if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
+  for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
+    case = os.path.basename(fn)
+    image, label = preprocess(fn)
+    image, label = image.astype(np.float32), label.astype(np.uint8)
+    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
+    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
+
+def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
   order = list(range(0, len(files)))
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
   if shuffle: random.shuffle(order)
-  for file in files:
-    X, Y = preprocess(file)
-    X = np.expand_dims(X, axis=0)
-    yield (X, Y)
+  for i in range(0, len(files), bs):
+    samples = []
+    for idx in order[i:i+bs]:
+      if preprocessed_dataset_dir is not None:
+        x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[idx])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[idx])}_y.npy"
+        if x_cached_path.exists() and y_cached_path.exists():
+          samples += [(np.load(x_cached_path), np.load(y_cached_path))]
+      else: samples += [preprocess(files[idx])]
+    X, Y = [x[0] for x in samples], [x[1] for x in samples]
+    if val:
+      yield X[0][None], Y[0]
+    else:
+      X_preprocessed, Y_preprocessed = [], []
+      for x, y in zip(X, Y):
+        x, y = rand_balanced_crop(x, y)
+        x, y = rand_flip(x, y)
+        x, y = x.astype(np.float32), y.astype(np.uint8)
+        x = random_brightness_augmentation(x)
+        x = gaussian_noise(x)
+        X_preprocessed.append(x)
+        Y_preprocessed.append(y)
+      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
 
 def gaussian_kernel(n, std):
   gaussian_1d = signal.windows.gaussian(n, std)
@@ -126,6 +160,73 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
   result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
   return result, labels
 
+def rand_flip(image, label, axis=(1, 2, 3)):
+  prob = 1 / len(axis)
+  for ax in axis:
+    if random.random() < prob:
+      image = np.flip(image, axis=ax).copy()
+      label = np.flip(label, axis=ax).copy()
+  return image, label
+
+def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
+  if random.random() < prob:
+    factor = np.random.uniform(low=low, high=high, size=1)
+    image = (image * (1 + factor)).astype(image.dtype)
+  return image
+
+def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
+  if random.random() < prob:
+    scale = np.random.uniform(low=0.0, high=std)
+    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
+    image += noise
+  return image
+
+def _rand_foreg_cropb(image, label, patch_size):
+  def adjust(foreg_slice, label, idx):
+    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
+    sign = -1 if diff < 0 else 1
+    diff = abs(diff)
+    ladj = 0 if diff == 0 else random.randrange(diff)
+    hadj = diff - ladj
+    low = max(0, foreg_slice[idx].start - sign * ladj)
+    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
+    diff = patch_size[idx - 1] - (high - low)
+    if diff > 0 and low == 0: high += diff
+    elif diff > 0: low -= diff
+    return low, high
+
+  cl = np.random.choice(np.unique(label[label > 0]))
+  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
+  foreg_slices = [x for x in foreg_slices if x is not None]
+  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
+  slice_idx = np.argsort(slice_volumes)[-2:]
+  foreg_slices = [foreg_slices[i] for i in slice_idx]
+  if not foreg_slices: return _rand_crop(image, label, patch_size)
+  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
+  low_x, high_x = adjust(foreg_slice, label, 1)
+  low_y, high_y = adjust(foreg_slice, label, 2)
+  low_z, high_z = adjust(foreg_slice, label, 3)
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def _rand_crop(image, label, patch_size):
+  ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
+  cord = [0 if x == 0 else random.randrange(x) for x in ranges]
+  low_x, high_x = cord[0], cord[0] + patch_size[0]
+  low_y, high_y = cord[1], cord[1] + patch_size[1]
+  low_z, high_z = cord[2], cord[2] + patch_size[2]
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
+  if random.random() < oversampling:
+    image, label = _rand_foreg_cropb(image, label, patch_size)
+  else:
+    image, label = _rand_crop(image, label, patch_size)
+  return image, label
+
 if __name__ == "__main__":
-  for X, Y in iterate():
+  for X, Y in iterate(get_val_files()):
     print(X.shape, Y.shape)
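Note (editorial, not part of the patch): a minimal sketch of how the new kits19 helpers are meant to be combined, based only on the functions added above: cache the preprocessed volumes once with preprocess_dataset, then stream batches with iterate. The train_step/eval_step callables are placeholders for whatever consumes the batches, not functions from this repo.

```python
from extra.datasets.kits19 import PREPROCESSED_DIR, get_train_files, get_val_files, preprocess_dataset, iterate

train_step = eval_step = lambda x, y: None  # placeholders for a real training/eval step

# one-time cache of preprocessed volumes under extra/datasets/kits19/preprocessed/{train,val}
preprocess_dataset(get_train_files(), PREPROCESSED_DIR, val=False)
preprocess_dataset(get_val_files(), PREPROCESSED_DIR, val=True)

# training: shuffled batches of augmented crops, shaped (bs, 1, 128, 128, 128)
for x, y in iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=2):
  train_step(x, y)

# validation: one padded full volume at a time, shaped (1, 1, D, H, W), for sliding-window inference
for x, y in iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR, val=True):
  eval_step(x, y)
```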
diff --git a/setup.py b/setup.py
index e673c0025f..b54e19bcf0 100644
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,7 @@ setup(name='tinygrad',
         "librosa",
         "networkx",
         "hypothesis",
+        "nibabel",
       ],
       'docs': [
         "mkdocs-material",
diff --git a/test/external/external_test_datasets.py b/test/external/external_test_datasets.py
new file mode 100644
index 0000000000..46a01d6444
--- /dev/null
+++ b/test/external/external_test_datasets.py
@@ -0,0 +1,71 @@
+from extra.datasets.kits19 import iterate, preprocess
+from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
+from tinygrad.helpers import temp
+from pathlib import Path
+
+import nibabel as nib
+import numpy as np
+import os
+import random
+import tempfile
+import unittest
+
+class ExternalTestDatasets(unittest.TestCase):
+  def _set_seed(self):
+    np.random.seed(42)
+    random.seed(42)
+
+  def _create_sample(self):
+    self._set_seed()
+
+    img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
+    img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
+
+    os.makedirs(tempfile.gettempdir() + "/case_0000", exist_ok=True)
+    nib.save(img, temp("case_0000/imaging.nii.gz"))
+    nib.save(lbl, temp("case_0000/segmentation.nii.gz"))
+
+    return Path(tempfile.gettempdir()) / "case_0000"
+
+  def _create_kits19_ref_sample(self, sample_pth, val):
+    self._set_seed()
+
+    img, lbl = preprocess(sample_pth)
+    dataset = "val" if val else "train"
+    preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_0000_x.npy"), temp(f"{dataset}/case_0000_y.npy")
+
+    os.makedirs(tempfile.gettempdir() + f"/{dataset}", exist_ok=True)
+    np.save(preproc_img_pth, img, allow_pickle=False)
+    np.save(preproc_lbl_pth, lbl, allow_pickle=False)
+
+    if val:
+      dataset = PytVal([preproc_img_pth], [preproc_lbl_pth])
+    else:
+      dataset = PytTrain([preproc_img_pth], [preproc_lbl_pth], patch_size=(128, 128, 128), oversampling=0.4)
+
+    return dataset[0]
+
+  def _create_kits19_tinygrad_sample(self, sample_pth, val):
+    self._set_seed()
+    return next(iterate([sample_pth], val=val))
+
+  def test_kits19_training_set(self):
+    sample_pth = self._create_sample()
+
+    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, False)
+    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, False)
+
+    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
+    np.testing.assert_equal(tinygrad_lbl[:, 0], ref_lbl)
+
+  def test_kits19_validation_set(self):
+    sample_pth = self._create_sample()
+
+    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, True)
+    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, True)
+
+    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
+    np.testing.assert_equal(tinygrad_lbl, ref_lbl)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/test/external/mlperf_unet3d/kits19.py b/test/external/mlperf_unet3d/kits19.py
new file mode 100644
index 0000000000..ac8a918a8a
--- /dev/null
+++ b/test/external/mlperf_unet3d/kits19.py
@@ -0,0 +1,165 @@
+# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
+
+import random
+import numpy as np
+import scipy.ndimage
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+def get_train_transforms():
+    rand_flip = RandFlip()
+    cast = Cast(types=(np.float32, np.uint8))
+    rand_scale = RandomBrightnessAugmentation(factor=0.3, prob=0.1)
+    rand_noise = GaussianNoise(mean=0.0, std=0.1, prob=0.1)
+    train_transforms = transforms.Compose([rand_flip, cast, rand_scale, rand_noise])
+    return train_transforms
+
+
+class RandBalancedCrop:
+    def __init__(self, patch_size, oversampling):
+        self.patch_size = patch_size
+        self.oversampling = oversampling
+
+    def __call__(self, data):
+        image, label = data["image"], data["label"]
+        if random.random() < self.oversampling:
+            image, label, cords = self.rand_foreg_cropd(image, label)
+        else:
+            image, label, cords = self._rand_crop(image, label)
+        data.update({"image": image, "label": label})
+        return data
+
+    @staticmethod
+    def randrange(max_range):
+        return 0 if max_range == 0 else random.randrange(max_range)
+
+    def get_cords(self, cord, idx):
+        return cord[idx], cord[idx] + self.patch_size[idx]
+
+    def _rand_crop(self, image, label):
+        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
+        cord = [self.randrange(x) for x in ranges]
+        low_x, high_x = self.get_cords(cord, 0)
+        low_y, high_y = self.get_cords(cord, 1)
+        low_z, high_z = self.get_cords(cord, 2)
+        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
+
+    def rand_foreg_cropd(self, image, label):
+        def adjust(foreg_slice, patch_size, label, idx):
+            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
+            sign = -1 if diff < 0 else 1
+            diff = abs(diff)
+            ladj = self.randrange(diff)
+            hadj = diff - ladj
+            low = max(0, foreg_slice[idx].start - sign * ladj)
+            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
+            diff = patch_size[idx - 1] - (high - low)
+            if diff > 0 and low == 0:
+                high += diff
+            elif diff > 0:
+                low -= diff
+            return low, high
+
+        cl = np.random.choice(np.unique(label[label > 0]))
+        foreg_slices = scipy.ndimage.find_objects(scipy.ndimage.measurements.label(label==cl)[0])
+        foreg_slices = [x for x in foreg_slices if x is not None]
+        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
+        slice_idx = np.argsort(slice_volumes)[-2:]
+        foreg_slices = [foreg_slices[i] for i in slice_idx]
+        if not foreg_slices:
+            return self._rand_crop(image, label)
+        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
+        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
+        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
+        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
+        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
+
+
+class RandFlip:
+    def __init__(self):
+        self.axis = [1, 2, 3]
+        self.prob = 1 / len(self.axis)
+
+    def flip(self, data, axis):
+        data["image"] = np.flip(data["image"], axis=axis).copy()
+        data["label"] = np.flip(data["label"], axis=axis).copy()
+        return data
+
+    def __call__(self, data):
+        for axis in self.axis:
+            if random.random() < self.prob:
+                data = self.flip(data, axis)
+        return data
+
+
+class Cast:
+    def __init__(self, types):
+        self.types = types
+
+    def __call__(self, data):
+        data["image"] = data["image"].astype(self.types[0])
+        data["label"] = data["label"].astype(self.types[1])
+        return data
+
+
+class RandomBrightnessAugmentation:
+    def __init__(self, factor, prob):
+        self.prob = prob
+        self.factor = factor
+
+    def __call__(self, data):
+        image = data["image"]
+        if random.random() < self.prob:
+            factor = np.random.uniform(low=1.0-self.factor, high=1.0+self.factor, size=1)
+            image = (image * (1 + factor)).astype(image.dtype)
+        data.update({"image": image})
+        return data
+
+
+class GaussianNoise:
+    def __init__(self, mean, std, prob):
+        self.mean = mean
+        self.std = std
+        self.prob = prob
+
+    def __call__(self, data):
+        image = data["image"]
+        if random.random() < self.prob:
+            scale = np.random.uniform(low=0.0, high=self.std)
+            noise = np.random.normal(loc=self.mean, scale=scale, size=image.shape).astype(image.dtype)
+            data.update({"image": image + noise})
+        return data
+
+
+class PytTrain(Dataset):
+    def __init__(self, images, labels, **kwargs):
+        self.images, self.labels = images, labels
+        self.train_transforms = get_train_transforms()
+        patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"]
+        self.patch_size = patch_size
+        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
+        data = self.rand_crop(data)
+        data = self.train_transforms(data)
+        return data["image"], data["label"]
+
+
+class PytVal(Dataset):
+    def __init__(self, images, labels):
+        self.images, self.labels = images, labels
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        return np.load(self.images[idx]), np.load(self.labels[idx])
\ No newline at end of file
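Note (editorial, not part of the patch): a minimal sketch for running the new parity tests directly, without pytest, while iterating locally; it assumes the optional dependencies exercised here (nibabel, torch, torchvision, scipy) are installed. CI uses the pytest invocation added to test.yml at the top of this diff.

```python
import unittest
from test.external.external_test_datasets import ExternalTestDatasets

# runs the kits19 train/val parity tests against the vendored mlperf reference loader
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExternalTestDatasets)
unittest.TextTestRunner(verbosity=2).run(suite)
```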