Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-07 22:23:55 -05:00)
[MLPerf] UNet3D dataloader (#4343)
* add support for train/val datasets for kits19
* split dataset into train and val sets
* add tests for kits19 dataloader
* add MLPerf dataset tests to CI
* update unet3d model_eval script
* fix linting
* add nibabel
* fix how mock dataset gets created
* update ref implementation with permalink and no edits
* clean up test and update rand_flip implementation
* cleanups

3  .github/workflows/test.yml
@@ -232,6 +232,9 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf metrics
       run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
+    - if: ${{ matrix.task == 'onnx' }}
+      name: Test MLPerf datasets
+      run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
     - if: ${{ matrix.task == 'onnx' }}
       name: Test THREEFRY
       run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py
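
The new CI step can also be reproduced locally. A minimal sketch, assuming the testing extras are installed (pytest-xdist for `-n=auto`, nibabel for the NIfTI mock data) and a GPU backend is available; the helper script name is hypothetical and not part of this commit:

```python
# run_dataset_tests.py -- hypothetical local helper mirroring the CI step above
import os
import sys
import pytest

os.environ["GPU"] = "1"  # select tinygrad's GPU backend, as the CI step does
sys.exit(pytest.main(["-n=auto", "test/external/external_test_datasets.py", "--durations=20"]))
```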

examples/mlperf/model_eval.py

@@ -63,13 +63,13 @@ def eval_resnet():
 def eval_unet3d():
   # UNet3D
   from extra.models.unet3d import UNet3D
-  from extra.datasets.kits19 import iterate, sliding_window_inference
+  from extra.datasets.kits19 import iterate, sliding_window_inference, get_val_files
   from examples.mlperf.metrics import dice_score
   mdl = UNet3D()
   mdl.load_from_pretrained()
   s = 0
   st = time.perf_counter()
-  for i, (image, label) in enumerate(iterate(), start=1):
+  for i, (image, label) in enumerate(iterate(get_val_files()), start=1):
     mt = time.perf_counter()
     pred, label = sliding_window_inference(mdl, image, label)
     et = time.perf_counter()

extra/datasets/kits19.py

@@ -3,13 +3,16 @@ import functools
 from pathlib import Path
 import numpy as np
 import nibabel as nib
-from scipy import signal
+from scipy import signal, ndimage
+import os
 import torch
 import torch.nn.functional as F
+from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import fetch
 
 BASEDIR = Path(__file__).parent / "kits19" / "data"
+PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed"
 
 """
 To download the dataset:
@@ -23,6 +26,10 @@ mv kits19 extra/datasets
 ```
 """
 
+@functools.lru_cache(None)
+def get_train_files():
+  return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
+
 @functools.lru_cache(None)
 def get_val_files():
   data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt").read_text()
@@ -65,15 +72,42 @@ def preprocess(file_path):
   image, label = pad_to_min_shape(image, label)
   return image, label
 
-def iterate(val=True, shuffle=False):
-  if not val: raise NotImplementedError
-  files = get_val_files()
+def preprocess_dataset(filenames, preprocessed_dir, val):
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
+  if not preprocessed_dataset_dir.is_dir(): os.makedirs(preprocessed_dataset_dir)
+  for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
+    case = os.path.basename(fn)
+    image, label = preprocess(fn)
+    image, label = image.astype(np.float32), label.astype(np.uint8)
+    np.save(preprocessed_dataset_dir / f"{case}_x.npy", image, allow_pickle=False)
+    np.save(preprocessed_dataset_dir / f"{case}_y.npy", label, allow_pickle=False)
+
+def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
   order = list(range(0, len(files)))
+  preprocessed_dataset_dir = (preprocessed_dir / ("val" if val else "train")) if preprocessed_dir is not None else None
   if shuffle: random.shuffle(order)
-  for file in files:
-    X, Y = preprocess(file)
-    X = np.expand_dims(X, axis=0)
-    yield (X, Y)
+  for i in range(0, len(files), bs):
+    samples = []
+    for i in order[i:i+bs]:
+      if preprocessed_dataset_dir is not None:
+        x_cached_path, y_cached_path = preprocessed_dataset_dir / f"{os.path.basename(files[i])}_x.npy", preprocessed_dataset_dir / f"{os.path.basename(files[i])}_y.npy"
+        if x_cached_path.exists() and y_cached_path.exists():
+          samples += [(np.load(x_cached_path), np.load(y_cached_path))]
+      else: samples += [preprocess(files[i])]
+    X, Y = [x[0] for x in samples], [x[1] for x in samples]
+    if val:
+      yield X[0][None], Y[0]
+    else:
+      X_preprocessed, Y_preprocessed = [], []
+      for x, y in zip(X, Y):
+        x, y = rand_balanced_crop(x, y)
+        x, y = rand_flip(x, y)
+        x, y = x.astype(np.float32), y.astype(np.uint8)
+        x = random_brightness_augmentation(x)
+        x = gaussian_noise(x)
+        X_preprocessed.append(x)
+        Y_preprocessed.append(y)
+      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
 
 def gaussian_kernel(n, std):
   gaussian_1d = signal.windows.gaussian(n, std)
@@ -126,6 +160,73 @@ def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), o
   result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
   return result, labels
 
+def rand_flip(image, label, axis=(1, 2, 3)):
+  prob = 1 / len(axis)
+  for ax in axis:
+    if random.random() < prob:
+      image = np.flip(image, axis=ax).copy()
+      label = np.flip(label, axis=ax).copy()
+  return image, label
+
+def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
+  if random.random() < prob:
+    factor = np.random.uniform(low=low, high=high, size=1)
+    image = (image * (1 + factor)).astype(image.dtype)
+  return image
+
+def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
+  if random.random() < prob:
+    scale = np.random.uniform(low=0.0, high=std)
+    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
+    image += noise
+  return image
+
+def _rand_foreg_cropb(image, label, patch_size):
+  def adjust(foreg_slice, label, idx):
+    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
+    sign = -1 if diff < 0 else 1
+    diff = abs(diff)
+    ladj = 0 if diff == 0 else random.randrange(diff)
+    hadj = diff - ladj
+    low = max(0, foreg_slice[idx].start - sign * ladj)
+    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
+    diff = patch_size[idx - 1] - (high - low)
+    if diff > 0 and low == 0: high += diff
+    elif diff > 0: low -= diff
+    return low, high
+
+  cl = np.random.choice(np.unique(label[label > 0]))
+  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
+  foreg_slices = [x for x in foreg_slices if x is not None]
+  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
+  slice_idx = np.argsort(slice_volumes)[-2:]
+  foreg_slices = [foreg_slices[i] for i in slice_idx]
+  if not foreg_slices: return _rand_crop(image, label)
+  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
+  low_x, high_x = adjust(foreg_slice, label, 1)
+  low_y, high_y = adjust(foreg_slice, label, 2)
+  low_z, high_z = adjust(foreg_slice, label, 3)
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def _rand_crop(image, label, patch_size):
+  ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
+  cord = [0 if x == 0 else random.randrange(x) for x in ranges]
+  low_x, high_x = cord[0], cord[0] + patch_size[0]
+  low_y, high_y = cord[1], cord[1] + patch_size[1]
+  low_z, high_z = cord[2], cord[2] + patch_size[2]
+  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
+  return image, label
+
+def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
+  if random.random() < oversampling:
+    image, label = _rand_foreg_cropb(image, label, patch_size)
+  else:
+    image, label = _rand_crop(image, label, patch_size)
+  return image, label
+
 if __name__ == "__main__":
-  for X, Y in iterate():
+  for X, Y in iterate(get_val_files()):
   print(X.shape, Y.shape)
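
Taken together, the functions added above suggest a training-side flow roughly like the following. This is a minimal sketch assembled from the new API (get_train_files, get_val_files, preprocess_dataset, PREPROCESSED_DIR, iterate); the batch size, shuffle flag, and printed shapes are illustrative, and no such driver script is part of this commit:

```python
from extra.datasets.kits19 import (PREPROCESSED_DIR, get_train_files, get_val_files,
                                   iterate, preprocess_dataset)

# One-time preprocessing: cache each case as <case>_x.npy / <case>_y.npy under
# kits19/preprocessed/{train,val} so later epochs skip the nibabel/scipy work.
preprocess_dataset(get_train_files(), PREPROCESSED_DIR, val=False)
preprocess_dataset(get_val_files(), PREPROCESSED_DIR, val=True)

# Training batches: random balanced 128^3 crops plus flip/brightness/noise augmentation.
for X, Y in iterate(get_train_files(), preprocessed_dir=PREPROCESSED_DIR, val=False, shuffle=True, bs=2):
  print(X.shape, Y.shape)  # roughly (bs, 1, 128, 128, 128) for both, given the default patch size

# Validation: one full padded volume at a time, as consumed by eval_unet3d above.
for X, Y in iterate(get_val_files(), preprocessed_dir=PREPROCESSED_DIR, val=True):
  print(X.shape, Y.shape)
```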

1  setup.py
@@ -53,6 +53,7 @@ setup(name='tinygrad',
         "librosa",
         "networkx",
         "hypothesis",
+        "nibabel",
       ],
       'docs': [
         "mkdocs-material",

71  test/external/external_test_datasets.py (new file)
@@ -0,0 +1,71 @@
from extra.datasets.kits19 import iterate, preprocess
from test.external.mlperf_unet3d.kits19 import PytTrain, PytVal
from tinygrad.helpers import temp
from pathlib import Path

import nibabel as nib
import numpy as np
import os
import random
import tempfile
import unittest

class ExternalTestDatasets(unittest.TestCase):
  def _set_seed(self):
    np.random.seed(42)
    random.seed(42)

  def _create_sample(self):
    self._set_seed()

    img, lbl = np.random.rand(190, 392, 392).astype(np.float32), np.random.randint(0, 100, size=(190, 392, 392)).astype(np.uint8)
    img, lbl = nib.Nifti1Image(img, np.eye(4)), nib.Nifti1Image(lbl, np.eye(4))

    os.makedirs(tempfile.gettempdir() + "/case_0000", exist_ok=True)
    nib.save(img, temp("case_0000/imaging.nii.gz"))
    nib.save(lbl, temp("case_0000/segmentation.nii.gz"))

    return Path(tempfile.gettempdir()) / "case_0000"

  def _create_kits19_ref_sample(self, sample_pth, val):
    self._set_seed()

    img, lbl = preprocess(sample_pth)
    dataset = "val" if val else "train"
    preproc_img_pth, preproc_lbl_pth = temp(f"{dataset}/case_0000_x.npy"), temp(f"{dataset}/case_0000_y.npy")

    os.makedirs(tempfile.gettempdir() + f"/{dataset}", exist_ok=True)
    np.save(preproc_img_pth, img, allow_pickle=False)
    np.save(preproc_lbl_pth, lbl, allow_pickle=False)

    if val:
      dataset = PytVal([preproc_img_pth], [preproc_lbl_pth])
    else:
      dataset = PytTrain([preproc_img_pth], [preproc_lbl_pth], patch_size=(128, 128, 128), oversampling=0.4)

    return dataset[0]

  def _create_kits19_tinygrad_sample(self, sample_pth, val):
    self._set_seed()
    return next(iterate([sample_pth], val=val))

  def test_kits19_training_set(self):
    sample_pth = self._create_sample()

    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, False)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, False)

    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl[:, 0], ref_lbl)

  def test_kits19_validation_set(self):
    sample_pth = self._create_sample()

    tinygrad_img, tinygrad_lbl = self._create_kits19_tinygrad_sample(sample_pth, True)
    ref_img, ref_lbl = self._create_kits19_ref_sample(sample_pth, True)

    np.testing.assert_equal(tinygrad_img[:, 0], ref_img)
    np.testing.assert_equal(tinygrad_lbl, ref_lbl)

if __name__ == '__main__':
  unittest.main()

165  test/external/mlperf_unet3d/kits19.py (new file)
@@ -0,0 +1,165 @@
# copied from https://github.com/mlcommons/training/blob/5c08ce57e7f582cc4558035d8324a2bf4c8ca225/image_segmentation/pytorch/data_loading/pytorch_loader.py

import random
import numpy as np
import scipy.ndimage
from torch.utils.data import Dataset
from torchvision import transforms


def get_train_transforms():
    rand_flip = RandFlip()
    cast = Cast(types=(np.float32, np.uint8))
    rand_scale = RandomBrightnessAugmentation(factor=0.3, prob=0.1)
    rand_noise = GaussianNoise(mean=0.0, std=0.1, prob=0.1)
    train_transforms = transforms.Compose([rand_flip, cast, rand_scale, rand_noise])
    return train_transforms


class RandBalancedCrop:
    def __init__(self, patch_size, oversampling):
        self.patch_size = patch_size
        self.oversampling = oversampling

    def __call__(self, data):
        image, label = data["image"], data["label"]
        if random.random() < self.oversampling:
            image, label, cords = self.rand_foreg_cropd(image, label)
        else:
            image, label, cords = self._rand_crop(image, label)
        data.update({"image": image, "label": label})
        return data

    @staticmethod
    def randrange(max_range):
        return 0 if max_range == 0 else random.randrange(max_range)

    def get_cords(self, cord, idx):
        return cord[idx], cord[idx] + self.patch_size[idx]

    def _rand_crop(self, image, label):
        ranges = [s - p for s, p in zip(image.shape[1:], self.patch_size)]
        cord = [self.randrange(x) for x in ranges]
        low_x, high_x = self.get_cords(cord, 0)
        low_y, high_y = self.get_cords(cord, 1)
        low_z, high_z = self.get_cords(cord, 2)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]

    def rand_foreg_cropd(self, image, label):
        def adjust(foreg_slice, patch_size, label, idx):
            diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
            sign = -1 if diff < 0 else 1
            diff = abs(diff)
            ladj = self.randrange(diff)
            hadj = diff - ladj
            low = max(0, foreg_slice[idx].start - sign * ladj)
            high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
            diff = patch_size[idx - 1] - (high - low)
            if diff > 0 and low == 0:
                high += diff
            elif diff > 0:
                low -= diff
            return low, high

        cl = np.random.choice(np.unique(label[label > 0]))
        foreg_slices = scipy.ndimage.find_objects(scipy.ndimage.measurements.label(label==cl)[0])
        foreg_slices = [x for x in foreg_slices if x is not None]
        slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
        slice_idx = np.argsort(slice_volumes)[-2:]
        foreg_slices = [foreg_slices[i] for i in slice_idx]
        if not foreg_slices:
            return self._rand_crop(image, label)
        foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
        low_x, high_x = adjust(foreg_slice, self.patch_size, label, 1)
        low_y, high_y = adjust(foreg_slice, self.patch_size, label, 2)
        low_z, high_z = adjust(foreg_slice, self.patch_size, label, 3)
        image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
        label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
        return image, label, [low_x, high_x, low_y, high_y, low_z, high_z]


class RandFlip:
    def __init__(self):
        self.axis = [1, 2, 3]
        self.prob = 1 / len(self.axis)

    def flip(self, data, axis):
        data["image"] = np.flip(data["image"], axis=axis).copy()
        data["label"] = np.flip(data["label"], axis=axis).copy()
        return data

    def __call__(self, data):
        for axis in self.axis:
            if random.random() < self.prob:
                data = self.flip(data, axis)
        return data


class Cast:
    def __init__(self, types):
        self.types = types

    def __call__(self, data):
        data["image"] = data["image"].astype(self.types[0])
        data["label"] = data["label"].astype(self.types[1])
        return data


class RandomBrightnessAugmentation:
    def __init__(self, factor, prob):
        self.prob = prob
        self.factor = factor

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            factor = np.random.uniform(low=1.0-self.factor, high=1.0+self.factor, size=1)
            image = (image * (1 + factor)).astype(image.dtype)
            data.update({"image": image})
        return data


class GaussianNoise:
    def __init__(self, mean, std, prob):
        self.mean = mean
        self.std = std
        self.prob = prob

    def __call__(self, data):
        image = data["image"]
        if random.random() < self.prob:
            scale = np.random.uniform(low=0.0, high=self.std)
            noise = np.random.normal(loc=self.mean, scale=scale, size=image.shape).astype(image.dtype)
            data.update({"image": image + noise})
        return data


class PytTrain(Dataset):
    def __init__(self, images, labels, **kwargs):
        self.images, self.labels = images, labels
        self.train_transforms = get_train_transforms()
        patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"]
        self.patch_size = patch_size
        self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
        data = self.rand_crop(data)
        data = self.train_transforms(data)
        return data["image"], data["label"]


class PytVal(Dataset):
    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return np.load(self.images[idx]), np.load(self.labels[idx])