[MLPerf] UNet3D dataloader (#4343)

* add support for train/val datasets for kits19

* split dataset into train and val sets

* add tests for kits19 dataloader

* add MLPerf dataset tests to CI

* update unet3d model_eval script

* fix linting

* add nibabel

* fix how mock dataset gets created

* update ref implementation with permalink and no edits

* clean up test and update rand_flip implementation

* cleanups
This commit is contained in:
Francis Lata
2024-04-28 22:34:18 -04:00
committed by GitHub
parent 82d0ed3cf3
commit bb849a57d1
6 changed files with 352 additions and 11 deletions

View File

@@ -232,6 +232,9 @@ jobs:
- if: ${{ matrix.task == 'onnx' }} - if: ${{ matrix.task == 'onnx' }}
name: Test MLPerf metrics name: Test MLPerf metrics
run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20 run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
- if: ${{ matrix.task == 'onnx' }}
name: Test MLPerf datasets
run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
- if: ${{ matrix.task == 'onnx' }} - if: ${{ matrix.task == 'onnx' }}
name: Test THREEFRY name: Test THREEFRY
run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py

View File

@@ -63,13 +63,13 @@ def eval_resnet():
def eval_unet3d(): def eval_unet3d():
# UNet3D # UNet3D
from extra.models.unet3d import UNet3D from extra.models.unet3d import UNet3D
from extra.datasets.kits19 import iterate, sliding_window_inference from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
from examples.mlperf.metrics import dice_score from examples.mlperf.metrics import dice_score
mdl = UNet3D() mdl = UNet3D()
mdl.load_from_pretrained() mdl.load_from_pretrained()
s = 0 s = 0
st = time.perf_counter() st = time.perf_counter()
for i, (image, label) in enumerate(iterate(), start=1): for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
mt = time.perf_counter() mt = time.perf_counter()
pred, label = sliding_window_inference(mdl, image, label) pred, label = sliding_window_inference(mdl, image, label)
et = time.perf_counter() et = time.perf_counter()

View File

@@ -3,13 +3,16 @@ import functools
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import nibabel as nib import nibabel as nib
from scipy import signal from scipy import signal, ndimage
import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from tqdm import tqdm
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch from tinygrad.helpers import fetch
BASEDIR = Path(__file__).parent / "kits19" / "data" BASEDIR = Path(__file__).parent / "kits19" / "data"
PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
""" """
To download the dataset: To download the dataset:
@@ -23,6 +26,10 @@ mv kits19 extra/datasets
``` ```
""" """
@functools.lru_cache(None)
def get_train_files():
  """Return the sorted training case paths found under BASEDIR.

  An entry is kept when its stem starts with "case", its numeric suffix is
  below 210 and it is not one of the validation cases.
  """
  def is_train_case(path):
    if not path.stem.startswith("case"): return False
    if int(path.stem.split("_")[-1]) >= 210: return False
    # only consulted for plausible case entries, so the validation list is resolved lazily
    return path not in get_val_files()
  return sorted(filter(is_train_case, BASEDIR.iterdir()))
@functools.lru_cache(None) @functools.lru_cache(None)
def get_val_files(): def get_val_files():
data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text() data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
@@ -65,15 +72,42 @@ def preprocess(file_path):
image, label = pad_to_min_shape(image, label) image, label = pad_to_min_shape(image, label)
return image, label return image, label
def preprocess_dataset(filenames, preprocessed_dir, val):
  """Preprocess every case in `filenames` and store the result as .npy files.

  For each case two files are written to `preprocessed_dir / "val"` (when
  `val` is truthy) or `preprocessed_dir / "train"`:
  `<case>_x.npy` (image, float32) and `<case>_y.npy` (label, uint8).

  Args:
    filenames: iterable of case paths accepted by `preprocess`.
    preprocessed_dir: root directory of the preprocessed dataset.
    val: selects the "val" or "train" sub-directory (and the progress label).

  Raises:
    ValueError: if `preprocessed_dir` is None; there is nowhere to write to.
  """
  if preprocessed_dir is None:
    raise ValueError("preprocessed_dir must be provided to store the preprocessed dataset")
  preprocessed_dataset_dir = Path(preprocessed_dir) / ("val" if val else "train")
  # exist_ok avoids the check-then-create race when several processes preprocess at once
  preprocessed_dataset_dir.mkdir(parents=True, exist_ok=True)
  for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
    case = os.path.basename(fn)
    image, label = preprocess(fn)
    image, label = image.astype(np.float32), label.astype(np.uint8)
    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
  """Yield KiTS19 samples from `files`, optionally served from a preprocessed cache.

  Args:
    files: sequence of case paths.
    preprocessed_dir: optional root of the preprocessed dataset; samples are read from
      its "val"/"train" sub-directory when both `<case>_x.npy` and `<case>_y.npy` exist.
    val: when truthy, yield un-augmented samples one at a time as `(x[None], y)`.
      Otherwise each sample is cropped and augmented and a stacked batch is yielded.
    shuffle: shuffle the visiting order with the `random` module.
    bs: number of cases loaded per step (the training batch size).

  Yields:
    `(image, label)` numpy arrays.
  """
  order = list(range(len(files)))
  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
  if shuffle: random.shuffle(order)

  def load_sample(fn):
    if preprocessed_dataset_dir is not None:
      case = os.path.basename(fn)
      x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{case}_x.npy", preprocessed_dataset_dir / f"{case}_y.npy"
      if x_cached_path.exists() and y_cached_path.exists(): return np.load(x_cached_path), np.load(y_cached_path)
    # no cache configured or cache miss: preprocess on the fly instead of silently dropping the case
    return preprocess(fn)

  for start in range(0, len(order), bs):
    samples = [load_sample(files[idx]) for idx in order[start:start+bs]]
    if val:
      # every loaded case is yielded individually, so nothing is lost when bs > 1
      for x, y in samples: yield x[None], y
    else:
      X_preprocessed, Y_preprocessed = [], []
      for x, y in samples:
        x, y = rand_balanced_crop(x, y)
        x, y = rand_flip(x, y)
        x, y = x.astype(np.float32), y.astype(np.uint8)
        x = random_brightness_augmentation(x)
        x = gaussian_noise(x)
        X_preprocessed.append(x)
        Y_preprocessed.append(y)
      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
def gaussian_kernel(n, std): def gaussian_kernel(n, std):
gaussian_1d = signal.windows.gaussian(n, std) gaussian_1d = signal.windows.gaussian(n, std)
@@ -126,6 +160,73 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]] result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
return result, labels return result, labels
def rand_flip(image, label, axis=(1, 2, 3)):
  """Flip `image` and `label` together along each axis in `axis` with probability 1/len(axis).

  One `random.random()` draw is made per axis, in order. Flipped arrays are copies.
  """
  threshold = 1 / len(axis)
  for dim in axis:
    if random.random() >= threshold: continue
    image, label = np.flip(image, axis=dim).copy(), np.flip(label, axis=dim).copy()
  return image, label
def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
  """With probability `prob`, scale `image` by `1 + factor`, factor ~ U(low, high).

  The input is returned untouched otherwise. The result keeps `image.dtype`.
  """
  if random.random() >= prob: return image
  factor = np.random.uniform(low=low, high=high, size=1)
  return (image * (1 + factor)).astype(image.dtype)
def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
  """With probability `prob`, return `image` plus Gaussian noise.

  The noise scale is drawn from U(0, std) and the noise is cast to `image.dtype`.
  The caller's array is never modified: a new array is returned when noise is
  applied, and the input itself is returned otherwise.
  """
  if random.random() < prob:
    scale = np.random.uniform(low=0.0, high=std)
    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
    # out-of-place add: `image += noise` would silently mutate the caller's array
    image = image + noise
  return image
def _rand_foreg_cropb(image, label, patch_size):
  """Crop a `patch_size` window around a randomly chosen foreground object.

  A foreground class is picked at random, its connected components are located and
  one of the (up to) two components with the largest bounding boxes is selected.
  The bounding box is then grown or shrunk per axis to `patch_size`. Falls back to
  `_rand_crop` when there is no foreground to crop around.
  """
  def adjust(foreg_slice, label, idx):
    # idx indexes label axes 1..3; patch_size has no leading channel axis, hence idx - 1
    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
    sign = -1 if diff < 0 else 1
    diff = abs(diff)
    ladj = 0 if diff == 0 else random.randrange(diff)
    hadj = diff - ladj
    low = max(0, foreg_slice[idx].start - sign * ladj)
    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
    # clamping at the volume border may leave the window short; extend the free side
    diff = patch_size[idx - 1] - (high - low)
    if diff > 0 and low == 0: high += diff
    elif diff > 0: low -= diff
    return low, high

  foreground_classes = np.unique(label[label > 0])
  # np.random.choice raises on an empty array, so handle label volumes without foreground
  if foreground_classes.size == 0: return _rand_crop(image, label, patch_size)
  cl = np.random.choice(foreground_classes)
  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
  foreg_slices = [x for x in foreg_slices if x is not None]
  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
  slice_idx = np.argsort(slice_volumes)[-2:]
  foreg_slices = [foreg_slices[i] for i in slice_idx]
  # _rand_crop requires patch_size; omitting it raised TypeError on this path
  if not foreg_slices: return _rand_crop(image, label, patch_size)
  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
  low_x, high_x = adjust(foreg_slice, label, 1)
  low_y, high_y = adjust(foreg_slice, label, 2)
  low_z, high_z = adjust(foreg_slice, label, 3)
  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  return image, label
def _rand_crop(image, label, patch_size):
  """Cut the same randomly placed `patch_size` window out of `image` and `label`.

  Axis 0 is kept whole; the window start along each following axis is drawn with
  `random.randrange` (or fixed at 0 when the volume exactly fits the patch).
  """
  slack = [dim - size for dim, size in zip(image.shape[1:], patch_size)]
  starts = [random.randrange(room) if room != 0 else 0 for room in slack]
  x0, y0, z0 = starts[0], starts[1], starts[2]
  window = (
    slice(None),
    slice(x0, x0 + patch_size[0]),
    slice(y0, y0 + patch_size[1]),
    slice(z0, z0 + patch_size[2]),
  )
  return image[window], label[window]
def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
  """Crop `image`/`label` to `patch_size`.

  With probability `oversampling` the crop is centred on foreground
  (`_rand_foreg_cropb`); otherwise it is placed uniformly at random (`_rand_crop`).
  """
  cropper = _rand_foreg_cropb if random.random() < oversampling else _rand_crop
  image, label = cropper(image, label, patch_size)
  return image, label
if __name__ == "__main__": if __name__ == "__main__":
for X, Y in iterate(): for X, Y in iterate(get_val_files()):
print(X.shape, Y.shape) print(X.shape, Y.shape)

View File

@@ -53,6 +53,7 @@ setup(name='tinygrad',
"librosa", "librosa",
"networkx", "networkx",
"hypothesis", "hypothesis",
"nibabel",
], ],
'docs': [ 'docs': [
"mkdocs-material", "mkdocs-material",

71
test/external/external_test_datasets.py vendored Normal file
View File

@@ -0,0 +1,71 @@
from extra.datasets.kits19 import iterate, preprocess
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
from pathlib import Path
import nibabel as nib
import numpy as np
import os
import random
import tempfile
import unittest
class ExternalTestDatasets(unittest.TestCase):
  """Checks the tinygrad KiTS19 dataloader against the MLPerf reference loader.

  Every test works inside its own temporary directory, which is removed afterwards.
  The fixtures are large, and fixed paths in the shared system temp dir could collide
  when the tests run on parallel workers.
  """
  def setUp(self):
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    self.tmp_dir = Path(tmp_dir.name)

  def _set_seed(self):
    # both RNGs are used by the dataloaders, so both must be reset for a like-for-like comparison
    np.random.seed(42)
    random.seed(42)

  def _create_sample(self):
    """Write a random case_0000 (imaging + segmentation NIfTI files) and return its directory."""
    self._set_seed()
    img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
    img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))
    sample_dir = self.tmp_dir / "case_0000"
    sample_dir.mkdir(parents=True, exist_ok=True)
    nib.save(img, str(sample_dir / "imaging.nii.gz"))
    nib.save(lbl, str(sample_dir / "segmentation.nii.gz"))
    return sample_dir

  def _create_kits19_ref_sample(self, sample_pth, val):
    """Preprocess the sample, store it as .npy and load it through the reference dataset."""
    self._set_seed()
    img, lbl = preprocess(sample_pth)
    dataset = "val" if val else "train"
    dataset_dir = self.tmp_dir / dataset
    dataset_dir.mkdir(parents=True, exist_ok=True)
    preproc_img_pth, preproc_lbl_pth = str(dataset_dir / "case_0000_x.npy"), str(dataset_dir / "case_0000_y.npy")
    np.save(preproc_img_pth, img, allow_pickle=False)
    np.save(preproc_lbl_pth, lbl, allow_pickle=False)

    if val:
      dataset = PytVal([preproc_img_pth], [preproc_lbl_pth])
    else:
      dataset = PytTrain([preproc_img_pth], [preproc_lbl_pth], patch_size=(128, 128, 128), oversampling=0.4)

    return dataset[0]

  def _create_kits19_tinygrad_sample(self, sample_pth, val):
    """Return the first sample produced by the tinygrad dataloader for the given case."""
    self._set_seed()
    return next(iterate([sample_pth], val=val))

  def test_kits19_training_set(self):
    sample_pth = self._create_sample()
    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, False)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, False)
    # [:, 0] drops the axis iterate() adds in front of the reference layout
    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl[:, 0], ref_lbl)

  def test_kits19_validation_set(self):
    sample_pth = self._create_sample()
    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, True)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, True)
    # only the image gets an extra leading axis in validation mode
    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl, ref_lbl)
# allow running this test module directly, outside of pytest
if __name__ == '__main__':
  unittest.main()

165
test/external/mlperf_unet3d/kits19.py vendored Normal file
View File

@@ -0,0 +1,165 @@
# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py
import random
import numpy as np
import scipy.ndimage
from torch.utils.data import Dataset
from torchvision import transforms
def get_train_transforms():
    """Build the training augmentation pipeline.

    Composes, in order: RandFlip, Cast (image -> float32, label -> uint8),
    RandomBrightnessAugmentation (factor=0.3, prob=0.1) and GaussianNoise
    (mean=0.0, std=0.1, prob=0.1).
    """
    rand_flip = RandFlip()
    cast = Cast(types=(np.float32, np.uint8))
    rand_scale = RandomBrightnessAugmentation(factor=0.3, prob=0.1)
    rand_noise = GaussianNoise(mean=0.0, std=0.1, prob=0.1)
    train_transforms = transforms.Compose([rand_flip, cast, rand_scale, rand_noise])
    return train_transforms
class RandBalancedCrop:
    """Crop a sample to `patch_size`, biased towards foreground.

    With probability `oversampling` the crop is placed around a foreground object,
    otherwise it is placed at a random position.
    """
    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        """Crop data["image"] and data["label"] and return the updated dict."""
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, cords = self.rand_foreg_cropd(image, label)
        else:
            image, label, cords = self._rand_crop(image, label)
        # cords is computed by both croppers but not stored in data
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        """random.randrange that returns 0 for an empty range instead of raising."""
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        """Return the (low, high) bounds of the patch along spatial axis `idx`."""
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        """Crop a randomly placed patch; axis 0 of image/label is kept whole."""
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        """Crop a patch around one of the two largest components of a random foreground class."""
        def adjust(foreg_slice, patch_size, label, idx):
            # idx addresses label axes 1..3, patch_size has no leading axis, hence idx - 1
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            ladj = self.randrange(diff)
            hadj = diff - ladj
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            # clamping to the volume may leave the window short; extend the side that has room
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        cl = np.random.choice(np.unique(label[label > 0]))
        # NOTE(review): scipy.ndimage.measurements is a deprecated namespace in recent SciPy — confirm
        # the pinned SciPy version still provides it (scipy.ndimage.label is the public spelling).
        foreg_slices = scipy.ndimage.find_objects(scipy.ndimage.measurements.label(label==cl)[0])
        foreg_slices = [x for x in foreg_slices if x is not None]
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        # keep the (up to) two components with the largest bounding boxes
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]
class RandFlip:
    """Randomly flip image and label together along axes 1, 2 and 3."""
    def __init__(self):
        self.axis = [1, 2, 3]
        # each axis is flipped independently with probability 1/3
        self.prob = 1 / len(self.axis)

    def flip(self, data, axis):
        """Flip data["image"] and data["label"] along `axis`, storing copies back into data."""
        data["image"] = np.flip(data["image"], axis=axis).copy()
        data["label"] = np.flip(data["label"], axis=axis).copy()
        return data

    def __call__(self, data):
        """Draw one random number per axis and flip where it falls below self.prob."""
        for axis in self.axis:
            if random.random() < self.prob:
                data = self.flip(data, axis)
        return data
class Cast:
    """Cast data["image"] to types[0] and data["label"] to types[1]."""
    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        data["image"] = data["image"].astype(self.types[0])
        data["label"] = data["label"].astype(self.types[1])
        return data
class RandomBrightnessAugmentation:
    """With probability `prob`, rescale the image intensity by a random factor."""
    def __init__(self, factor, prob):
        self.prob = prob
        self.factor = factor

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            # factor is drawn from [1 - self.factor, 1 + self.factor] and the image is
            # multiplied by (1 + factor); the original dtype is restored afterwards
            factor = np.random.uniform(low=1.0-self.factor, high=1.0+self.factor, size=1)
            image = (image * (1 + factor)).astype(image.dtype)
            data.update({"image": image})
        return data
class GaussianNoise:
    """With probability `prob`, add Gaussian noise to the image."""
    def __init__(self, mean, std, prob):
        self.mean = mean
        self.std = std
        self.prob = prob

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            # the noise scale itself is random, drawn uniformly from [0, self.std)
            scale = np.random.uniform(low=0.0, high=self.std)
            noise = np.random.normal(loc=self.mean, scale=scale, size=image.shape).astype(image.dtype)
            # image + noise builds a new array; the loaded image is not modified in place
            data.update({"image": image + noise})
        return data
class PytTrain(Dataset):
    """Training dataset: loads .npy image/label pairs, crops them and applies the train transforms.

    Requires the keyword arguments `patch_size` and `oversampling`.
    """
    def __init__(self, images, labels, **kwargs):
        self.images, self.labels = images, labels
        self.train_transforms = get_train_transforms()
        patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the cropped and augmented (image, label) pair for sample `idx`."""
        data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
        data = self.rand_crop(data)
        data = self.train_transforms(data)
        return data["image"], data["label"]
class PytVal(Dataset):
    """Validation dataset: returns the stored .npy image/label pairs without augmentation."""
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return np.load(self.images[idx]), np.load(self.labels[idx])