Add MLPerf UNet3D model (#775)

* Add ResNet inference test and cannon

* Test with ResNet50

* test_car works with resnet fix

* Add KiTS19 dataset

* KiTS19: Implement iterate

* No batch load for this dataset

* Save results on iterate

* Implement dice score

* Add data prep and eval functions

* Resolve shape issue

* Conversion works but wrong values

* Segfaults when load_from_pretrained is called

* Fix segfault and assign properly

* Final result generated, though very slow

* Store and load final result to save time

* Fix typo in finalize

* Score computes

* More bug fixes, dice score is very low

* Working broken code

* Assign output values to result

* Getting a much higher score now

* Fix dataset preprocessing

* Mean DICE score of 88.5

* Ugh, typo

* Attempt to reimplement model

* Rename layers

* Tiny model works, kinda

* Accuracy? gone

* Implement InstanceNorm and match torch

* Test instance norm 2d and 3d

* Combined input block with downsample block

* Tiny model works, support strided convtranspose

* Commands to download dataset

* Clean up a bit

* unet3d_v2 -> unet3d

* Remove duplicated code

* Oops, put tests back
Jacky Lee authored 2023-05-28 20:38:19 -07:00, committed by GitHub
parent 65d09031f2, commit 5d212864b5
9 changed files with 297 additions and 31 deletions

.gitignore (1 change)

@@ -23,6 +23,7 @@ disassemblers/cuda_ioctl_sniffer
datasets/cifar-10-python.tar.gz
datasets/librispeech/
datasets/imagenet/
datasets/kits19/
datasets/squad/
datasets/img_align_celeba*
datasets/open-images-v6-mlperf

datasets/kits19.py (new file, 131 lines)

@@ -0,0 +1,131 @@
import random
import functools
from pathlib import Path
import requests
import numpy as np
import nibabel as nib
from scipy import signal
import torch
import torch.nn.functional as F
from tinygrad.tensor import Tensor
BASEDIR = Path(__file__).parent.parent.resolve() / "datasets" / "kits19" / "data"
"""
To download the dataset:
```sh
git clone https://github.com/neheller/kits19
cd kits19
pip3 install -r requirements.txt
python3 -m starter_code.get_imaging
cd ..
mv kits19 datasets
```
"""
@functools.lru_cache(None)
def get_val_files():
  data = requests.get("https://raw.githubusercontent.com/mlcommons/training/master/image_segmentation/pytorch/evaluation_cases.txt")
  return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.text.split("\n")])
def load_pair(file_path):
  image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
  image_spacings = image.header["pixdim"][1:4].tolist()
  image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
  image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
  return image, label, image_spacings
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
  if tuple(image_spacings) != target_spacing:  # compare as tuples; a list never equals a tuple, which forced a resample every time
    spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
    new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
    image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
    label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
    image = np.squeeze(image.numpy(), axis=0)
    label = np.squeeze(label.numpy(), axis=0)
  return image, label
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
  image = np.clip(image, min_clip, max_clip)
  image = (image - mean) / std
  return image
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
  current_shape = image.shape[1:]
  bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
  paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
  image = np.pad(image, paddings, mode="edge")
  label = np.pad(label, paddings, mode="edge")
  return image, label
def preprocess(file_path):
  image, label, image_spacings = load_pair(file_path)
  image, label = resample3d(image, label, image_spacings)
  image = normal_intensity(image.copy())
  image, label = pad_to_min_shape(image, label)
  return image, label
def iterate(val=True, shuffle=False):
  if not val: raise NotImplementedError
  files = get_val_files()
  order = list(range(0, len(files)))
  if shuffle: random.shuffle(order)
  for i in order:  # iterate in the (possibly shuffled) order; iterating files directly made shuffle a no-op
    X, Y = preprocess(files[i])
    X = np.expand_dims(X, axis=0)
    yield (X, Y)
def gaussian_kernel(n, std):
  gaussian_1d = signal.gaussian(n, std)
  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  gaussian_3d = gaussian_3d.reshape(n, n, n)
  gaussian_3d = np.cbrt(gaussian_3d)
  gaussian_3d /= gaussian_3d.max()
  return gaussian_3d
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
  bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
  bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
  paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
  return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
  from tinygrad.jit import TinyJit
  mdl_run = TinyJit(lambda x: model(x).realize())
  image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
  strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
  bounds = [image_shape[i] % strides[i] for i in range(dim)]
  bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
  inputs = inputs[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  labels = labels[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  inputs, paddings = pad_input(inputs, roi_shape, strides)
  padded_shape = inputs.shape[2:]
  size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
  norm_patch = np.expand_dims(norm_patch, axis=0)
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k])).numpy()
        result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
        norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
  result /= norm_map
  result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
  return result, labels
if __name__ == "__main__":
  for X, Y in iterate():
    print(X.shape, Y.shape)
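For intuition about sliding_window_inference above: with roi_shape 128 and overlap 0.5 the stride is 64, so neighbouring patches overlap by half; each patch prediction is weighted by the Gaussian kernel, accumulated into result, and finally divided by the accumulated weights in norm_map. A minimal 1-D sketch of the same accumulate-then-divide scheme (toy sizes, not part of the diff):

```python
import numpy as np
from scipy import signal

# 1-D analogue of the Gaussian-blended sliding window used above
length, roi = 16, 8
stride = int(roi * (1 - 0.5))                  # roi * (1 - overlap) = 4
pred = np.random.rand(length).astype(np.float32)  # stand-in for per-voxel model output
w = signal.windows.gaussian(roi, 0.125 * roi)  # per-patch importance weights
result, norm = np.zeros(length), np.zeros(length)
for i in range(0, length - roi + 1, stride):
  result[i:i+roi] += pred[i:i+roi] * w         # weighted accumulation
  norm[i:i+roi] += w                           # total weight seen per voxel
blended = result / norm                        # renormalize the overlaps
assert np.allclose(blended, pred)              # identical patches blend back to the identity
```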

examples/mlperf/helpers.py

@@ -1,5 +1,35 @@
from collections import OrderedDict
import unicodedata
import numpy as np
from scipy import signal
def gaussian_kernel(n, std):
  gaussian_1d = signal.gaussian(n, std)
  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  gaussian_3d = gaussian_3d.reshape(n, n, n)
  gaussian_3d = np.cbrt(gaussian_3d)
  gaussian_3d /= gaussian_3d.max()
  return gaussian_3d
def prepare_arrays(image, roi_shape=(128, 128, 128)):
  assert len(roi_shape) == 3 and any(roi_shape)
  image_shape = list(image.shape[2:])
  result = np.zeros((1, 3, *image_shape), dtype=image.dtype)
  norm_map = np.zeros_like(result)
  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]).astype(norm_map.dtype)
  return result, norm_map, norm_patch
def get_slice(image, roi_shape=(128, 128, 128), overlap_factor=0.5):
  assert len(roi_shape) == 3 and any(roi_shape)
  assert 0 < overlap_factor < 1
  image_shape, dim = list(image.shape[2:]), len(image.shape[2:])
  strides = [int(roi_shape[i] * (1 - overlap_factor)) for i in range(dim)]
  size = [(image_shape[i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        yield i, j, k
def _get_best_indices(logits, n_best_size):
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
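A quick sanity check of get_slice: for a (1, 1, 128, 192, 128) volume with the default roi_shape and overlap_factor 0.5, strides are 64 per axis and only the second axis fits two patch origins (sketch; assumes examples.mlperf.helpers is importable):

```python
import numpy as np
from examples.mlperf.helpers import get_slice

image = np.zeros((1, 1, 128, 192, 128), dtype=np.float32)
# patch-grid sizes per axis: (128-128)//64+1 = 1, (192-128)//64+1 = 2, 1
assert list(get_slice(image)) == [(0, 0, 0), (0, 64, 0)]
```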

examples/mlperf/metrics.py

@@ -1,6 +1,7 @@
import re
import string
from collections import Counter
import numpy as np
def levenshtein(a, b):
  n, m = len(a), len(b)
@@ -28,6 +29,22 @@ def word_error_rate(x, y):
    scores += levenshtein(h_list, r_list)
  return float(scores) / words, float(scores), words
def one_hot(arr, num_classes=3):
  res = np.eye(num_classes)[np.array(arr).reshape(-1)]
  arr = res.reshape(list(arr.shape) + [num_classes])
  arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
  return arr
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
  reduce_axis = tuple(range(2, len(prediction.shape)))  # use the channel_axis parameter instead of shadowing it
  prediction = prediction.argmax(axis=channel_axis)
  prediction, target = one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
  intersection = np.sum(prediction * target, axis=reduce_axis)
  target_sum = np.sum(target, axis=reduce_axis)
  prediction_sum = np.sum(prediction, axis=reduce_axis)
  result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
  return result[0]
def normalize_string(s):
s = "".join(c for c in s.lower() if c not in string.punctuation)
s = re.sub(r'\b(a|an|the)\b', ' ', s)
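A toy check of get_dice_score above: logits whose argmax reproduces the target give a dice of ~1 for both foreground classes (sketch; assumes examples.mlperf.metrics is importable):

```python
import numpy as np
from examples.mlperf.metrics import get_dice_score

target = np.random.randint(0, 3, size=(1, 16, 16, 16))  # class ids 0..2
# one-hot the target into channel-first "logits" so argmax recovers it exactly
pred = np.eye(3, dtype=np.float32)[target].transpose(0, 4, 1, 2, 3)
print(get_dice_score(pred, target))  # ~[1. 1.] for the two foreground classes
```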

examples/mlperf/model_eval.py

@@ -4,6 +4,7 @@ import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.helpers import getenv
from examples.mlperf import helpers
def eval_resnet():
  # Resnet50-v1.5
@@ -41,6 +42,24 @@ def eval_resnet():
      d += len(t)
      print(f"****** {n}/{d} {n*100.0/d:.2f}%")
      st = time.perf_counter()
def eval_unet3d():
  # UNet3D
  from models.unet3d import UNet3D
  from datasets.kits19 import iterate, sliding_window_inference
  from examples.mlperf.metrics import get_dice_score
  mdl = UNet3D()
  mdl.load_from_pretrained()
  s = 0
  st = time.perf_counter()
  for i, (image, label) in enumerate(iterate(), start=1):
    mt = time.perf_counter()
    pred, label = sliding_window_inference(mdl, image, label)
    et = time.perf_counter()
    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
    s += get_dice_score(pred, label).mean()
    print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
    st = time.perf_counter()
def eval_retinanet():
  # RetinaNet with ResNeXt50_32X4D
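If the eval harness keeps selecting models via getenv as elsewhere in this file, the new eval can presumably be invoked with `MODEL=unet3d python3 examples/mlperf/model_eval.py` (hypothetical command; the dispatch code is outside this hunk).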

examples/mlperf/model_spec.py

@@ -28,7 +28,8 @@ def spec_unet3d():
  # 3D UNET
  from models.unet3d import UNet3D
  mdl = UNet3D()
-  img = Tensor.randn(1, 1, 5, 224, 224)
+  mdl.load_from_pretrained()
+  img = Tensor.randn(1, 1, 128, 128, 128)
  test_model(mdl, img)
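The spec input changes because the new architecture has five stride-2 stages (four downsample blocks plus the bottleneck), so every spatial dim must be divisible by 32; the old (1, 1, 5, 224, 224) shape would not survive the encoder, while 128³ matches the KiTS19 roi_shape.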
def spec_rnnt():

models/unet3d.py

@@ -1,42 +1,59 @@
# https://github.com/wolny/pytorch-3dunet
from pathlib import Path
-from extra.utils import download_file, fake_torch_load, get_child
-import tinygrad.nn as nn
+import torch
+from tinygrad import nn
+from tinygrad.tensor import Tensor
+from extra.utils import download_file, get_child
-class SingleConv:
-  def __init__(self, in_channels, out_channels):
-    self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group?
-    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
-    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False)
-  def __call__(self, x):
-    return self.conv(self.groupnorm(x)).relu()
+class DownsampleBlock:
+  def __init__(self, c0, c1, stride=2):
+    self.conv1 = [nn.Conv2d(c0, c1, kernel_size=(3,3,3), stride=stride, padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
-class BasicModule:
-  def __init__(self, c0, c1, c2):
-    self.basic_module = {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
  def __call__(self, x):
-    return self.basic_module['SingleConv2'](self.basic_module['SingleConv1'](x))
+    return x.sequential(self.conv1).sequential(self.conv2)
+class UpsampleBlock:
+  def __init__(self, c0, c1):
+    self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2,2,2), stride=2)]
+    self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False), nn.InstanceNorm(c1), Tensor.relu]
+  def __call__(self, x, skip):
+    x = x.sequential(self.upsample_conv)
+    x = Tensor.cat(x, skip, dim=1)
+    return x.sequential(self.conv1).sequential(self.conv2)
class UNet3D:
-  def __init__(self):
-    ups = [16,32,64,128,256]
-    self.encoders = [BasicModule(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
-    self.decoders = [BasicModule(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]
-    self.final_conv = nn.Conv2d(32, 1, (1,1,1))
+  def __init__(self, in_channels=1, n_class=3):
+    filters = [32, 64, 128, 256, 320]
+    inp, out = filters[:-1], filters[1:]
+    self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
+    self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
+    self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
+    self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
+    self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}
  def __call__(self, x):
-    intermediates = [x]
-    for e in self.encoders: intermediates.append(e(intermediates[-1]))
-    ret = intermediates[-1]
-    for d,i in zip(self.decoders, intermediates[:-1][::-1]): ret = d(ret.cat(i, dim=1))
-    return ret
+    x = self.input_block(x)
+    outputs = [x]
+    for downsample in self.downsample:
+      x = downsample(x)
+      outputs.append(x)
+    x = self.bottleneck(x)
+    for upsample, skip in zip(self.upsample, outputs[::-1]):
+      x = upsample(x, skip)
+    x = self.output["conv"](x)
+    return x
  def load_from_pretrained(self):
-    fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
-    download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
-    state_dict = fake_torch_load(open(fn, "rb").read())['model_state_dict']
+    fn = Path(__file__).parent.parent / "weights" / "unet-3d.ckpt"
+    download_file("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
+    state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
    for k, v in state_dict.items():
-      print(k, v.shape)
      obj = get_child(self, k)
      assert obj.shape == v.shape, (k, obj.shape, v.shape)
      obj.assign(v.numpy())
+if __name__ == "__main__":
+  mdl = UNet3D()
+  mdl.load_from_pretrained()
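The new UNet3D halves each spatial dim five times (four downsample blocks plus the bottleneck) and the decoder restores them, so spatial dims must be divisible by 32 and are preserved end to end. A quick shape check with random weights (sketch, not part of the diff):

```python
from tinygrad.tensor import Tensor
from models.unet3d import UNet3D

mdl = UNet3D()                         # random init is enough for a shape check
x = Tensor.randn(1, 1, 128, 128, 128)  # (batch, channels, D, H, W)
out = mdl(x)
assert out.shape == (1, 3, 128, 128, 128)  # per-voxel logits for 3 classes
```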

test/test_nn.py

@@ -3,7 +3,7 @@ import unittest
import numpy as np
from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding
+from tinygrad.nn import BatchNorm2d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
import torch
class TestNN(unittest.TestCase):
@@ -168,6 +168,45 @@ class TestNN(unittest.TestCase):
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
  def test_instancenorm_2d(self):
    N, C, H, W = 20, 5, 10, 10
    # create in tinygrad
    layer = InstanceNorm(C)
    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # test
    x = Tensor.randn(N, C, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
  def test_instancenorm_3d(self):
    N, C, D, H, W = 20, 5, 3, 10, 10
    # create in tinygrad
    layer = InstanceNorm(C)
    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # test
    x = Tensor.randn(N, C, D, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
  def test_embedding(self):
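The new tests should be runnable in isolation with something like `python3 -m pytest test/test_nn.py -k instancenorm` (assuming the repo's usual pytest setup).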

tinygrad/nn/__init__.py

@@ -78,6 +78,17 @@ class GroupNorm:
    # elementwise_affine on channels
    return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
class InstanceNorm:
  def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
    self.num_features, self.eps = num_features, eps
    self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
    self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None
  def __call__(self, x:Tensor):
    x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
    if self.weight is None or self.bias is None: return x
    return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])
class LayerNorm:
  def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
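The InstanceNorm above reuses layernorm over the flattened spatial dims of each (sample, channel) slice; a numpy reference of the same computation for comparison (sketch, biased variance as in layernorm):

```python
import numpy as np

def instance_norm_ref(x, weight, bias, eps=1e-5):
  # normalize each (sample, channel) pair over its spatial dims
  axes = tuple(range(2, x.ndim))
  xn = (x - x.mean(axis=axes, keepdims=True)) / np.sqrt(x.var(axis=axes, keepdims=True) + eps)
  shape = (1, -1) + (1,) * (x.ndim - 2)
  return xn * weight.reshape(shape) + bias.reshape(shape)

x = np.random.randn(2, 5, 4, 4).astype(np.float32)
w, b = np.ones(5, np.float32), np.zeros(5, np.float32)
y = instance_norm_ref(x, w, b)  # should match InstanceNorm(5)(Tensor(x)).numpy() within tolerance
```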