From 46d419060b43dee7546800923b46dab3c0ac5c51 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 10 May 2023 16:30:49 -0700 Subject: [PATCH] start on mlperf models --- datasets/__init__.py | 2 +- examples/mlperf/README | 19 +++++++++++++++++++ examples/mlperf/model_spec.py | 33 +++++++++++++++++++++++++++++++++ examples/stable_diffusion.py | 3 +-- extra/utils.py | 2 +- models/resnet.py | 3 ++- models/unet3d.py | 31 +++++++++++++++++++++++++++++++ 7 files changed, 88 insertions(+), 5 deletions(-) create mode 100644 examples/mlperf/README create mode 100644 examples/mlperf/model_spec.py create mode 100644 models/unet3d.py diff --git a/datasets/__init__.py b/datasets/__init__.py index 592b6a738b..019820874f 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -17,7 +17,7 @@ def fetch_cifar(train=True): cifar10_mean = np.array([0.4913997551666284, 0.48215855929893703, 0.4465309133731618], dtype=np.float32).reshape(1,3,1,1) cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1) fn = os.path.dirname(__file__)+"/cifar-10-python.tar.gz" - download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn, True) + download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn) tt = tarfile.open(fn, mode='r:gz') if train: db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)] diff --git a/examples/mlperf/README b/examples/mlperf/README new file mode 100644 index 0000000000..d971e0c57a --- /dev/null +++ b/examples/mlperf/README @@ -0,0 +1,19 @@ +Each model should be a clean single file. +They are imported from the top level `models` directory + +It should be capable of loading weights from the reference imp. 
+ +We will focus on these 5 models: + +# Resnet50-v1.5 (classic) -- 8.2 GOPS/input +# Retinanet +# 3D UNET (upconvs) +# RNNT +# BERT-large (transformer) + +They are used in both the training and inference benchmark: +https://mlcommons.org/en/training-normal-21/ +https://mlcommons.org/en/inference-edge-30/ +And we will submit to both. + +NOTE: we are Edge since we don't have ECC RAM diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py new file mode 100644 index 0000000000..7c9b839902 --- /dev/null +++ b/examples/mlperf/model_spec.py @@ -0,0 +1,33 @@ +# load each model here, quick benchmark +from tinygrad.tensor import Tensor +from tinygrad.helpers import GlobalCounters + +def test_model(model, *inputs): + GlobalCounters.reset() + model(*inputs).numpy() + # TODO: return event future to still get the time_sum_s without DEBUG=2 + print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms") + +if __name__ == "__main__": + # inference only for now + Tensor.training = False + Tensor.no_grad = True + + # Resnet50-v1.5 + """ + from models.resnet import ResNet50 + mdl = ResNet50() + img = Tensor.randn(1, 3, 224, 224) + test_model(mdl, img) + """ + + # Retinanet + + # 3D UNET + from models.unet3d import UNet3D + mdl = UNet3D() + mdl.load_from_pretrained() + + # RNNT + + # BERT-large diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index d7fed82a33..40cfab57cd 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -615,8 +615,7 @@ if __name__ == "__main__": # load in weights download_file( 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', - FILENAME, - skip_if_exists=True + FILENAME ) dat = fake_torch_load_zipped(open(FILENAME, "rb")) for k,v in dat['state_dict'].items(): diff --git a/extra/utils.py b/extra/utils.py index 04d4e31e47..535bf904b7 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -20,7 +20,7 @@ def fetch(url): with 
open(fp, "rb") as f: return f.read() -def download_file(url, fp, skip_if_exists=False): +def download_file(url, fp, skip_if_exists=True): import requests, os if skip_if_exists and os.path.isfile(fp) and os.stat(fp).st_size > 0: return diff --git a/models/resnet.py b/models/resnet.py index 0455331a64..d55574a0dc 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -80,6 +80,7 @@ class ResNet: self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2) self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2) self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2) + # TODO: replace with nn.Linear self.fc = {"weight": Tensor.scaled_uniform(512 * self.block.expansion, num_classes), "bias": Tensor.zeros(num_classes)} def _make_layer(self, block, planes, num_blocks, stride): @@ -105,7 +106,7 @@ class ResNet: def load_from_pretrained(self): # TODO replace with fake torch load - + model_urls = { 18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', diff --git a/models/unet3d.py b/models/unet3d.py new file mode 100644 index 0000000000..ff6935aab6 --- /dev/null +++ b/models/unet3d.py @@ -0,0 +1,31 @@ +# https://github.com/wolny/pytorch-3dunet +from pathlib import Path +from extra.utils import download_file, fake_torch_load +import tinygrad.nn as nn + +class SingleConv: + def __init__(self, in_channels, out_channels): + self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group? 
+    self.conv = nn.Conv2d(in_channels, out_channels, (3,3,3), bias=False)
+  def __call__(self, x):
+    return self.conv(self.groupnorm(x)).relu()
+# One stage = two stacked SingleConv layers chaining channels c0 -> c1 -> c2, keyed "SingleConv1"/"SingleConv2".
+def get_basic_module(c0, c1, c2): return {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
+# 3D U-Net skeleton: 4 encoder stages, 3 decoder stages and a final 1x1x1 conv. The forward pass is not implemented yet.
+class UNet3D:
+  def __init__(self):
+    ups = [16,32,64,128,256]
+    self.encoders = [get_basic_module(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]  # (1,16,32), (32,32,64), (64,64,128), (128,128,256)
+    self.decoders = [get_basic_module(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]  # (384,128,128), (192,64,64), (96,32,32): walk ups downwards so the last stage emits the 32 channels final_conv expects
+    self.final_conv = nn.Conv2d(32, 1, (1,1,1))
+  # Forward pass placeholder: currently returns None for any input.
+  def __call__(self, x):
+    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
+    pass
+  # Fetch the reference checkpoint into weights/ and print each tensor name and shape; nothing is assigned to the model yet.
+  def load_from_pretrained(self):
+    fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
+    download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
+    state = fake_torch_load(fn.read_bytes())['model_state_dict']  # read_bytes() opens, reads and closes the file (no leaked handle)
+    for x in state.keys():
+      print(x, state[x].shape)