diff --git a/.gitignore b/.gitignore
index 808dd71d92..cf65dfb79f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,6 +23,7 @@ disassemblers/cuda_ioctl_sniffer
 datasets/cifar-10-python.tar.gz
 datasets/librispeech/
 datasets/imagenet/
+datasets/kits19/
 datasets/squad/
 datasets/img_align_celeba*
 datasets/open-images-v6-mlperf
diff --git a/datasets/kits19.py b/datasets/kits19.py
new file mode 100644
index 0000000000..bd9e0a5e33
--- /dev/null
+++ b/datasets/kits19.py
@@ -0,0 +1,131 @@
+import random
+import functools
+from pathlib import Path
+import requests
+import numpy as np
+import nibabel as nib
+from scipy import signal
+import torch
+import torch.nn.functional as F
+from tinygrad.tensor import Tensor
+
+BASEDIR = Path(__file__).parent.parent.resolve() / "datasets" / "kits19" / "data"
+
+"""
+To download the dataset:
+```sh
+git clone https://github.com/neheller/kits19
+cd kits19
+pip3 install -r requirements.txt
+python3 -m starter_code.get_imaging
+cd ..
+mv kits19 datasets
+```
+"""
+
+@functools.lru_cache(None)
+def get_val_files():
+  data = requests.get("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt")
+  return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.text.split("\n")])
+
+def load_pair(file_path):
+  image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
+  image_spacings = image.header["pixdim"][1:4].tolist()
+  image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
+  image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
+  return image, label, image_spacings
+
+def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
+  if image_spacings != list(target_spacing):
+    spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
+    new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
+    image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
+    label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
+    image = np.squeeze(image.numpy(), axis=0)
+    label = np.squeeze(label.numpy(), axis=0)
+  return image, label
+
+def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
+  image = np.clip(image, min_clip, max_clip)
+  image = (image - mean) / std
+  return image
+
+def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
+  current_shape = image.shape[1:]
+  bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
+  paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
+  image = np.pad(image, paddings, mode="edge")
+  label = np.pad(label, paddings, mode="edge")
+  return image, label
+
+def preprocess(file_path):
+  image, label, image_spacings = load_pair(file_path)
+  image, label = resample3d(image, label, image_spacings)
+  image = normal_intensity(image.copy())
+  image, label = pad_to_min_shape(image, label)
+  return image, label
+
+def iterate(val=True, shuffle=False):
+  if not val: raise NotImplementedError
+  files = get_val_files()
+  order = list(range(0, len(files)))
+  if shuffle: random.shuffle(order)
+  for i in order:
+    X, Y = preprocess(files[i])
+    X = np.expand_dims(X, axis=0)
+    yield (X, Y)
+
+def gaussian_kernel(n, std):
+  gaussian_1d = signal.gaussian(n, std)
+  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
+  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
+  gaussian_3d = gaussian_3d.reshape(n, n, n)
+  gaussian_3d = np.cbrt(gaussian_3d)
+  gaussian_3d /= gaussian_3d.max()
+  return gaussian_3d
+
+def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
+  bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
+  bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
+  paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
+  return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
+
+def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
+  from tinygrad.jit import TinyJit
+  mdl_run = TinyJit(lambda x: model(x).realize())
+  image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
+  strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
+  bounds = [image_shape[i] % strides[i] for i in range(dim)]
+  bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
+  inputs = inputs[
+    ...,
+    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
+    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
+    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
+  ]
+  labels = labels[
+    ...,
+    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
+    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
+    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
+  ]
+  inputs, paddings = pad_input(inputs, roi_shape, strides)
+  padded_shape = inputs.shape[2:]
+  size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
+  result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
+  norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
+  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
+  norm_patch = np.expand_dims(norm_patch, axis=0)
+  for i in range(0, strides[0] * size[0], strides[0]):
+    for j in range(0, strides[1] * size[1], strides[1]):
+      for k in range(0, strides[2] * size[2], strides[2]):
+        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
+        result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
+        norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
+  result /= norm_map
+  result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
+  return result, labels
+
+if __name__ == "__main__":
+  for X, Y in iterate():
+    print(X.shape, Y.shape)
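A note on `sliding_window_inference` above: each overlapping 128^3 patch prediction is weighted by the Gaussian `norm_patch` before being accumulated into `result`, while `norm_map` accumulates the weights themselves, so the final division recovers an unbiased per-voxel average wherever patches overlap. A minimal 1-D sketch of that blending scheme (the Hanning window and all names here are illustrative, not from the diff):

```python
import numpy as np

roi, stride = 8, 4                    # 50% overlap, like the 128^3 ROI / stride-64 setup above
x = np.arange(16, dtype=np.float32)   # stand-in for a full volume
result = np.zeros_like(x)
norm_map = np.zeros_like(x)
patch_w = np.hanning(roi) + 1e-3      # any positive window works; the diff uses a Gaussian

for i in range(0, len(x) - roi + 1, stride):
  pred = x[i:i+roi]                   # pretend the "model" echoes its input patch
  result[i:i+roi] += pred * patch_w   # weight each patch prediction by the window
  norm_map[i:i+roi] += patch_w        # track the total weight applied to each voxel

result /= norm_map                    # divide out the weights: overlaps blend to an average
assert np.allclose(result, x)         # for an identity model, blending reconstructs the input
```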
diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py
index 454edca64d..fb773b7e61 100644
--- a/examples/mlperf/helpers.py
+++ b/examples/mlperf/helpers.py
@@ -1,5 +1,35 @@
 from collections import OrderedDict
 import unicodedata
+import numpy as np
+from scipy import signal
+
+def gaussian_kernel(n, std):
+  gaussian_1d = signal.gaussian(n, std)
+  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
+  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
+  gaussian_3d = gaussian_3d.reshape(n, n, n)
+  gaussian_3d = np.cbrt(gaussian_3d)
+  gaussian_3d /= gaussian_3d.max()
+  return gaussian_3d
+
+def prepare_arrays(image, roi_shape=(128, 128, 128)):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  image_shape = list(image.shape[2:])
+  result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
+  norm_map = np.zeros_like(result)
+  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
+  return result, norm_map, norm_patch
+
+def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
+  assert len(roi_shape) == 3 and any(roi_shape)
+  assert 0 < overlap_factor < 1
+  image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
+  strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
+  size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
+  for i in range(0, strides[0] * size[0], strides[0]):
+    for j in range(0, strides[1] * size[1], strides[1]):
+      for k in range(0, strides[2] * size[2], strides[2]):
+        yield i, j, k
 
 def _get_best_indices(logits, n_best_size):
   index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
diff --git a/examples/mlperf/metrics.py b/examples/mlperf/metrics.py
index e4ac9f9f15..9bf3399531 100644
--- a/examples/mlperf/metrics.py
+++ b/examples/mlperf/metrics.py
@@ -1,6 +1,7 @@
 import re
 import string
 from collections import Counter
+import numpy as np
 
 def levenshtein(a, b):
   n, m = len(a), len(b)
@@ -28,6 +29,22 @@ def word_error_rate(x, y):
     scores += levenshtein(h_list, r_list)
   return float(scores) / words, float(scores), words
 
+def one_hot(arr, num_classes=3):
+  res = np.eye(num_classes)[np.array(arr).reshape(-1)]
+  arr = res.reshape(list(arr.shape) + [num_classes])
+  arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
+  return arr
+
+def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
+  reduce_axis = tuple(range(2, len(prediction.shape)))
+  prediction = prediction.argmax(axis=channel_axis)
+  prediction, target = one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
+  intersection = np.sum(prediction * target, axis=reduce_axis)
+  target_sum = np.sum(target, axis=reduce_axis)
+  prediction_sum = np.sum(prediction, axis=reduce_axis)
+  result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
+  return result[0]
+
 def normalize_string(s):
   s = "".join(c for c in s.lower() if c not in string.punctuation)
   s = re.sub(r'\b(a|an|the)\b', ' ', s)
diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py
index bc75b1da83..0c6b925447 100644
--- a/examples/mlperf/model_eval.py
+++ b/examples/mlperf/model_eval.py
@@ -4,6 +4,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
 from tinygrad.helpers import getenv
+from examples.mlperf import helpers
 
 def eval_resnet():
   # Resnet50-v1.5
@@ -41,6 +42,24 @@ def eval_resnet():
       d += len(t)
       print(f"****** {n}/{d} {n*100.0/d:.2f}%")
       st = time.perf_counter()
+
+def eval_unet3d():
+  # UNet3D
+  from models.unet3d import UNet3D
+  from datasets.kits19 import iterate, sliding_window_inference
+  from examples.mlperf.metrics import get_dice_score
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
+  s = 0
+  st = time.perf_counter()
+  for i, (image, label) in enumerate(iterate(), start=1):
+    mt = time.perf_counter()
+    pred, label = sliding_window_inference(mdl, image, label)
+    et = time.perf_counter()
+    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
+    s += get_dice_score(pred, label).mean()
+    print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
+    st = time.perf_counter()
 
 def eval_retinanet():
   # RetinaNet with ResNeXt50_32X4D
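The `get_dice_score` added to metrics.py above computes a per-class Dice coefficient, dropping the background channel so only the two foreground classes (kidney and tumor in KiTS19) are scored. A quick sanity check using the functions as defined above; the toy shapes are illustrative:

```python
import numpy as np
from examples.mlperf.metrics import get_dice_score

labels = np.random.randint(0, 3, size=(1, 4, 4, 4))  # 0=background, 1=kidney, 2=tumor
logits = np.eye(3)[labels].transpose(0, 4, 1, 2, 3)  # (1, 3, 4, 4, 4) one-hot "prediction"
print(get_dice_score(logits, labels))                # ~[1. 1.]: perfect overlap on both classes
```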
diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py
index f11f76eb05..69ff7caec1 100644
--- a/examples/mlperf/model_spec.py
+++ b/examples/mlperf/model_spec.py
@@ -28,7 +28,8 @@ def spec_unet3d():
   # 3D UNET
   from models.unet3d import UNet3D
   mdl = UNet3D()
-  img = Tensor.randn(1, 1, 5, 224, 224)
+  mdl.load_from_pretrained()
+  img = Tensor.randn(1, 1, 128, 128, 128)
   test_model(mdl, img)
 
 def spec_rnnt():
diff --git a/models/unet3d.py b/models/unet3d.py
index 687e4dc74c..289ed4c86b 100644
--- a/models/unet3d.py
+++ b/models/unet3d.py
@@ -1,42 +1,59 @@
-# https://github.com/wolny/pytorch-3dunet
 from pathlib import Path
-from extra.utils import download_file, fake_torch_load, get_child
-import tinygrad.nn as nn
+import torch
+from tinygrad import nn
+from tinygrad.tensor import Tensor
+from extra.utils import download_file, get_child
 
-class SingleConv:
-  def __init__(self, in_channels, out_channels):
-    self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group?
-    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
-    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False)
+class DownsampleBlock:
+  def __init__(self, c0, c1, stride=2):
+    self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+
   def __call__(self, x):
-    return self.conv(self.groupnorm(x)).relu()
+    return x.sequential(self.conv1).sequential(self.conv2)
 
-class BasicModule:
-  def __init__(self, c0, c1, c2):
-    self.basic_module = {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
-  def __call__(self, x):
-    return self.basic_module['SingleConv2'](self.basic_module['SingleConv1'](x))
+class UpsampleBlock:
+  def __init__(self, c0, c1):
+    self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
+    self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+
+  def __call__(self, x, skip):
+    x = x.sequential(self.upsample_conv)
+    x = Tensor.cat(x, skip, dim=1)
+    return x.sequential(self.conv1).sequential(self.conv2)
 
 class UNet3D:
-  def __init__(self):
-    ups = [16,32,64,128,256]
-    self.encoders = [BasicModule(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
-    self.decoders = [BasicModule(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]
-    self.final_conv = nn.Conv2d(32, 1, (1,1,1))
+  def __init__(self, in_channels=1, n_class=3):
+    filters = [32, 64, 128, 256, 320]
+    inp, out = filters[:-1], filters[1:]
+    self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
+    self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
+    self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
+    self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
+    self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
 
   def __call__(self, x):
-    intermediates = [x]
-    for e in self.encoders: intermediates.append(e(intermediates[-1]))
-    ret = intermediates[-1]
-    for d,i in zip(self.decoders, intermediates[:-1][::-1]): ret = d(ret.cat(i, dim=1))
-    return ret
+    x = self.input_block(x)
+    outputs = [x]
+    for downsample in self.downsample:
+      x = downsample(x)
+      outputs.append(x)
+    x = self.bottleneck(x)
+    for upsample, skip in zip(self.upsample, outputs[::-1]):
+      x = upsample(x, skip)
+    x = self.output["conv"](x)
+    return x
 
   def load_from_pretrained(self):
-    fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
-    download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
-    state_dict = fake_torch_load(open(fn, "rb").read())['model_state_dict']
+    fn = Path(__file__).parent.parent / "weights" / "unet-3d.ckpt"
+    download_file("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
+    state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
     for k, v in state_dict.items():
-      print(k, v.shape)
       obj = get_child(self, k)
       assert obj.shape == v.shape, (k, obj.shape, v.shape)
       obj.assign(v.numpy())
+
+if __name__ == "__main__":
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
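The rewritten UNet3D follows the MLPerf KiTS19 reference topology: an encoder with filter widths [32, 64, 128, 256, 320] where each DownsampleBlock halves the spatial resolution (except the stride-1 input block), a stride-2 bottleneck, and a mirrored decoder whose UpsampleBlocks concatenate the matching encoder skip before two convolutions. A shape walk-through for the 128^3 ROI used in eval (the spatial sizes in the comments are derived from the stride-2 convs; this mirrors what spec_unet3d runs):

```python
from tinygrad.tensor import Tensor
from models.unet3d import UNet3D

mdl = UNet3D()
x = Tensor.randn(1, 1, 128, 128, 128)  # one single-channel CT patch
# input_block (stride 1): (32, 128^3)
# downsample x4 (stride 2): (64, 64^3) -> (128, 32^3) -> (256, 16^3) -> (320, 8^3)
# bottleneck (stride 2): (320, 4^3)
# upsample x5 (ConvTranspose with 2^3 kernels, stride 2) + skips: back up to (32, 128^3)
print(mdl(x).shape)  # (1, 3, 128, 128, 128): one logit map per class
```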
diff --git a/test/test_nn.py b/test/test_nn.py
index 5fe08080c3..040ed8044c 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 from tinygrad.jit import TinyJit
 from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding
+from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
 import torch
 
 class TestNN(unittest.TestCase):
@@ -168,6 +168,44 @@ class TestNN(unittest.TestCase):
     z = layer(x)
     torch_x = torch.tensor(x.cpu().numpy())
     torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
+  def test_instancenorm_2d(self):
+    N, C, H, W = 20, 5, 10, 10
+
+    # create in tinygrad
+    layer = InstanceNorm(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
+
+  def test_instancenorm_3d(self):
+    N, C, D, H, W = 20, 5, 3, 10, 10
+
+    # create in tinygrad
+    layer = InstanceNorm(C)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.randn(N, C, D, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.cpu().numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
 
   def test_embedding(self):
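These tests pin the `InstanceNorm` added below against torch. The implementation reduces instance norm to a layernorm: flattening each (sample, channel) pair's spatial dims and normalizing over that one axis is exactly per-instance, per-channel normalization. A numpy cross-check of the equivalence (a sketch, not part of the diff):

```python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import InstanceNorm

x = np.random.randn(2, 4, 8, 8, 8).astype(np.float32)
# reference: normalize every (sample, channel) slice with its own spatial statistics
mean = x.mean(axis=(2, 3, 4), keepdims=True)
var = x.var(axis=(2, 3, 4), keepdims=True)  # biased variance, matching layernorm
ref = (x - mean) / np.sqrt(var + 1e-5)
out = InstanceNorm(4, affine=False)(Tensor(x)).numpy()
np.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-3)
```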
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 8c2da44f74..cd67c27a47 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -78,6 +78,17 @@
     # elementwise_affine on channels
     return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
 
+class InstanceNorm:
+  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
+    self.num_features, self.eps = num_features, eps
+    self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
+    self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
+
+  def __call__(self, x:Tensor):
+    x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
+    if self.weight is None or self.bias is None: return x
+    return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
+
 class LayerNorm:
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
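Finally, a minimal way to exercise the whole path end to end, assuming the KiTS19 data has been downloaded as described in datasets/kits19.py (this just calls the eval function added to model_eval.py above):

```python
from tinygrad.tensor import Tensor
from examples.mlperf.model_eval import eval_unet3d

Tensor.training = False  # inference only
eval_unet3d()  # fetches the Zenodo checkpoint, runs sliding-window eval on each
               # validation case, and prints a running mean DICE score
```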