[MLPerf] UNet3D dataloader (#4343)

* add support for train/val datasets for kits19

* split dataset into train and val sets

* add tests for kits19 dataloader

* add MLPerf dataset tests to CI

* update unet3d model_eval script

* fix linting

* add nibabel

* fix how mock dataset gets created

* update ref implementation with permalink and no edits

* clean up test and update rand_flip implementation

* cleanups
Francis Lata, 2024-04-28 22:34:18 -04:00 (committed by GitHub)
parent 82d0ed3cf3
commit bb849a57d1
6 changed files with 352 additions and 11 deletions


@@ -232,6 +232,9 @@ jobs:
       - if: ${{ matrix.task == 'onnx' }}
         name: Test MLPerf metrics
         run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
+      - if: ${{ matrix.task == 'onnx' }}
+        name: Test MLPerf datasets
+        run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
       - if: ${{ matrix.task == 'onnx' }}
         name: Test THREEFRY
         run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py
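Note: the added step uses the same entry point you can run locally, GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20 (the test pulls in torch and torchvision through the vendored reference loader below).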


@@ -63,13 +63,13 @@ def eval_resnet():
 def eval_unet3d():
   # UNet3D
   from extra.models.unet3d import UNet3D
-  from extra.datasets.kits19 import iterate, sliding_window_inference
+  from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
   from examples.mlperf.metrics import dice_score
   mdl = UNet3D()
   mdl.load_from_pretrained()
   s = 0
   st = time.perf_counter()
-  for i, (image, label) in enumerate(iterate(), start=1):
+  for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
     mt = time.perf_counter()
     pred, label = sliding_window_inference(mdl, image, label)
     et = time.perf_counter()
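For reference, a minimal standalone sketch of the updated eval path (assumes the kits19 dataset has been downloaded per the instructions in extra/datasets/kits19.py):

```
from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files

mdl = UNet3D()
mdl.load_from_pretrained()
for image, label in iterate(get_val_files()):  # val mode: one padded case per step
  pred, label = sliding_window_inference(mdl, image, label)
  print(pred.shape, label.shape)
```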


@@ -3,13 +3,16 @@ import functools
 from pathlib import Path
 import numpy as np
 import nibabel as nib
-from scipy import signal
+from scipy import signal, ndimage
+import random  # used by iterate() and the augmentations below
 import os
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch
 
 BASEDIR = Path(__file__).parent / "kits19" / "data"
+PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
 
 """
 To download the dataset:
@@ -23,6 +26,10 @@ mv kits19 extra/datasets
 ```
 """
 
+@functools.lru_cache(None)
+def get_train_files():
+  return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
+
+@functools.lru_cache(None)
+def get_val_files():
+  data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
+  return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
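A quick sanity check of the new split (illustrative; requires the downloaded dataset, and the 168/42 counts are the standard MLPerf train/val split of the 210 labeled cases):

```
from extra.datasets.kits19 import get_train_files, get_val_files

train_files, val_files = get_train_files(), get_val_files()
assert not set(train_files) & set(val_files)  # disjoint by construction
print(len(train_files), len(val_files))       # 168 42
```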
@@ -65,15 +72,42 @@ def preprocess(file_path):
   image, label = pad_to_min_shape(image, label)
   return image, label
 
-def iterate(val=True, shuffle=False):
-  if not val: raise NotImplementedError
-  files = get_val_files()
-  for file in files:
-    X, Y = preprocess(file)
-    X = np.expand_dims(X, axis=0)
-    yield (X, Y)
+def preprocess_dataset(filenames, preprocessed_dir, val):
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
+  if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
+  for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
+    case = os.path.basename(fn)
+    image, label = preprocess(fn)
+    image, label = image.astype(np.float32), label.astype(np.uint8)
+    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
+    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
+
+def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
+  order = list(range(0, len(files)))
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
+  if shuffle: random.shuffle(order)
+  for i in range(0, len(files), bs):
+    samples = []
+    for i in order[i:i+bs]:
+      if preprocessed_dataset_dir is not None:
+        x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[i])}_y.npy"
+        if x_cached_path.exists() and y_cached_path.exists():
+          samples += [(np.load(x_cached_path), np.load(y_cached_path))]
+      else: samples += [preprocess(files[i])]
+    X, Y = [x[0] for x in samples], [x[1] for x in samples]
+    if val:
+      yield X[0][None], Y[0]
+    else:
+      X_preprocessed, Y_preprocessed = [], []
+      for x, y in zip(X, Y):
+        x, y = rand_balanced_crop(x, y)
+        x, y = rand_flip(x, y)
+        x, y = x.astype(np.float32), y.astype(np.uint8)
+        x = random_brightness_augmentation(x)
+        x = gaussian_noise(x)
+        X_preprocessed.append(x)
+        Y_preprocessed.append(y)
+      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
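Intended usage is to preprocess once, then stream batches from the .npy cache; a sketch using the PREPROCESSED_DIR default defined above:

```
from extra.datasets.kits19 import get_train_files, preprocess_dataset, iterate, PREPROCESSED_DIR

preprocess_dataset(get_train_files(), PREPROCESSED_DIR, val=False)  # writes case_*_x.npy / case_*_y.npy once
for X, Y in iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=2):
  print(X.shape, Y.shape)  # e.g. (2, 1, 128, 128, 128) crops in train mode
  break
```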
 def gaussian_kernel(n, std):
   gaussian_1d = signal.windows.gaussian(n, std)
@@ -126,6 +160,73 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
   result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
   return result, labels
 
+def rand_flip(image, label, axis=(1, 2, 3)):
+  prob = 1 / len(axis)
+  for ax in axis:
+    if random.random() < prob:
+      image = np.flip(image, axis=ax).copy()
+      label = np.flip(label, axis=ax).copy()
+  return image, label
+
+def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
+  if random.random() < prob:
+    factor = np.random.uniform(low=low, high=high, size=1)
+    image = (image * (1 + factor)).astype(image.dtype)
+  return image
+
+def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
+  if random.random() < prob:
+    scale = np.random.uniform(low=0.0, high=std)
+    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
+    image += noise
+  return image
+
+def _rand_foreg_cropb(image, label, patch_size):
+  def adjust(foreg_slice, label, idx):
+    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
+    sign = -1 if diff < 0 else 1
+    diff = abs(diff)
+    ladj = 0 if diff == 0 else random.randrange(diff)
+    hadj = diff - ladj
+    low = max(0, foreg_slice[idx].start - sign * ladj)
+    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
+    diff = patch_size[idx - 1] - (high - low)
+    if diff > 0 and low == 0: high += diff
+    elif diff > 0: low -= diff
+    return low, high
+  cl = np.random.choice(np.unique(label[label > 0]))
+  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
+  foreg_slices = [x for x in foreg_slices if x is not None]
+  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
+  slice_idx = np.argsort(slice_volumes)[-2:]
+  foreg_slices = [foreg_slices[i] for i in slice_idx]
+  if not foreg_slices: return _rand_crop(image, label, patch_size)
+  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
+  low_x, high_x = adjust(foreg_slice, label, 1)
+  low_y, high_y = adjust(foreg_slice, label, 2)
+  low_z, high_z = adjust(foreg_slice, label, 3)
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def _rand_crop(image, label, patch_size):
+  ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
+  cord = [0 if x == 0 else random.randrange(x) for x in ranges]
+  low_x, high_x = cord[0], cord[0] + patch_size[0]
+  low_y, high_y = cord[1], cord[1] + patch_size[1]
+  low_z, high_z = cord[2], cord[2] + patch_size[2]
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
+  if random.random() < oversampling:
+    image, label = _rand_foreg_cropb(image, label, patch_size)
+  else:
+    image, label = _rand_crop(image, label, patch_size)
+  return image, label
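An illustrative smoke test of the augmentation chain above (toy shapes; assumes the dataloader's (C, D, H, W) layout):

```
import random
import numpy as np
from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise

random.seed(0); np.random.seed(0)
x = np.random.rand(1, 160, 160, 160).astype(np.float32)
y = np.random.randint(0, 3, size=(1, 160, 160, 160)).astype(np.uint8)
x, y = rand_balanced_crop(x, y)  # both crop paths return (1, 128, 128, 128) patches
x, y = rand_flip(x, y)
x = gaussian_noise(random_brightness_augmentation(x.astype(np.float32)))
assert x.shape == y.shape == (1, 128, 128, 128)
```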
 if __name__ == "__main__":
-  for X, Y in iterate():
+  for X, Y in iterate(get_val_files()):
     print(X.shape, Y.shape)


@@ -53,6 +53,7 @@ setup(name='tinygrad',
"librosa",
"networkx",
"hypothesis",
"nibabel",
],
'docs': [
"mkdocs-material",

test/external/external_test_datasets.py (vendored, new file, 71 lines)

@@ -0,0 +1,71 @@
from extra.datasets.kits19 import iterate, preprocess
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
from pathlib import Path
import nibabel as nib
import numpy as np
import os
import random
import tempfile
import unittest

class ExternalTestDatasets(unittest.TestCase):
  def _set_seed(self):
    np.random.seed(42)
    random.seed(42)

  def _create_sample(self):
    self._set_seed()
    img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
    img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
    os.makedirs(tempfile.gettempdir() + "/case_0000", exist_ok=True)
    nib.save(img, temp("case_0000/imaging.nii.gz"))
    nib.save(lbl, temp("case_0000/segmentation.nii.gz"))
    return Path(tempfile.gettempdir()) / "case_0000"

  def _create_kits19_ref_sample(self, sample_pth, val):
    self._set_seed()
    img, lbl = preprocess(sample_pth)
    dataset = "val" if val else "train"
    preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_0000_x.npy"), temp(f"{dataset}/case_0000_y.npy")
    os.makedirs(tempfile.gettempdir() + f"/{dataset}", exist_ok=True)
    np.save(preproc_img_pth, img, allow_pickle=False)
    np.save(preproc_lbl_pth, lbl, allow_pickle=False)
    if val:
      dataset = PytVal([preproc_img_pth], [preproc_lbl_pth])
    else:
      dataset = PytTrain([preproc_img_pth], [preproc_lbl_pth], patch_size=(128, 128, 128), oversampling=0.4)
    return dataset[0]

  def _create_kits19_tinygrad_sample(self, sample_pth, val):
    self._set_seed()
    return next(iterate([sample_pth], val=val))

  def test_kits19_training_set(self):
    sample_pth = self._create_sample()
    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, False)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, False)
    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl[:, 0], ref_lbl)

  def test_kits19_validation_set(self):
    sample_pth = self._create_sample()
    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, True)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, True)
    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl, ref_lbl)

if __name__ == '__main__':
  unittest.main()
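The test strategy: build one mock NIfTI case, reset the RNG seeds, then push the same case through both the tinygrad loader and the vendored MLPerf reference loader, so the random crop/flip/brightness/noise decisions consume the same random stream, and finally assert the resulting arrays match exactly.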

test/external/mlperf_unet3d/kits19.py (vendored, new file, 165 lines)

@@ -0,0 +1,165 @@
# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
import random
import numpy as np
import scipy.ndimage
from torch.utils.data import Dataset
from torchvision import transforms


def get_train_transforms():
    rand_flip = RandFlip()
    cast = Cast(types=(np.float32, np.uint8))
    rand_scale = RandomBrightnessAugmentation(factor=0.3, prob=0.1)
    rand_noise = GaussianNoise(mean=0.0, std=0.1, prob=0.1)
    train_transforms = transforms.Compose([rand_flip, cast, rand_scale, rand_noise])
    return train_transforms


class RandBalancedCrop:
    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, cords = self.rand_foreg_cropd(image, label)
        else:
            image, label, cords = self._rand_crop(image, label)
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        def adjust(foreg_slice, patch_size, label, idx):
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            ladj = self.randrange(diff)
            hadj = diff - ladj
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        cl = np.random.choice(np.unique(label[label > 0]))
        foreg_slices = scipy.ndimage.find_objects(scipy.ndimage.measurements.label(label==cl)[0])
        foreg_slices = [x for x in foreg_slices if x is not None]
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]


class RandFlip:
    def __init__(self):
        self.axis = [1, 2, 3]
        self.prob = 1 / len(self.axis)

    def flip(self, data, axis):
        data["image"] = np.flip(data["image"], axis=axis).copy()
        data["label"] = np.flip(data["label"], axis=axis).copy()
        return data

    def __call__(self, data):
        for axis in self.axis:
            if random.random() < self.prob:
                data = self.flip(data, axis)
        return data


class Cast:
    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        data["image"] = data["image"].astype(self.types[0])
        data["label"] = data["label"].astype(self.types[1])
        return data


class RandomBrightnessAugmentation:
    def __init__(self, factor, prob):
        self.prob = prob
        self.factor = factor

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            factor = np.random.uniform(low=1.0-self.factor, high=1.0+self.factor, size=1)
            image = (image * (1 + factor)).astype(image.dtype)
            data.update({"image": image})
        return data


class GaussianNoise:
    def __init__(self, mean, std, prob):
        self.mean = mean
        self.std = std
        self.prob = prob

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            scale = np.random.uniform(low=0.0, high=self.std)
            noise = np.random.normal(loc=self.mean, scale=scale, size=image.shape).astype(image.dtype)
            data.update({"image": image + noise})
        return data


class PytTrain(Dataset):
    def __init__(self, images, labels, **kwargs):
        self.images, self.labels = images, labels
        self.train_transforms = get_train_transforms()
        patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
        data = self.rand_crop(data)
        data = self.train_transforms(data)
        return data["image"], data["label"]


class PytVal(Dataset):
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return np.load(self.images[idx]), np.load(self.labels[idx])
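For orientation, this is how the tests above consume the vendored loader (the .npy paths are illustrative stand-ins for the files written by the test setup):

```
train_ds = PytTrain(["case_0000_x.npy"], ["case_0000_y.npy"], patch_size=(128, 128, 128), oversampling=0.4)
img, lbl = train_ds[0]  # randomly cropped and augmented pair

val_ds = PytVal(["case_0000_x.npy"], ["case_0000_y.npy"])
img, lbl = val_ds[0]    # full preprocessed volume, no augmentation
```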