From 85d17a2acd1f0687b24d9b3bab2a3279b2e47438 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 11 Jun 2022 13:17:15 -0700 Subject: [PATCH] running resnet onnx --- test/test_efficientnet.py | 16 ++++++----- test/test_onnx.py | 60 ++++++++++++++++++++++++++++++++------- test/test_ops.py | 3 ++ tinygrad/nn.py | 12 ++++---- tinygrad/tensor.py | 9 +++++- 5 files changed, 75 insertions(+), 25 deletions(-) diff --git a/test/test_efficientnet.py b/test/test_efficientnet.py index 74139cd658..8294ed39da 100644 --- a/test/test_efficientnet.py +++ b/test/test_efficientnet.py @@ -17,7 +17,7 @@ def _load_labels(): _LABELS = _load_labels() -def _infer(model: EfficientNet, img, bs=1): +def preprocess(img): # preprocess image aspect_ratio = img.size[0] / img.size[1] img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) @@ -29,17 +29,19 @@ def _infer(model: EfficientNet, img, bs=1): # low level preprocess img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) - img /= 255.0 - img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) - img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) + #img /= 255.0 + #img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) + #img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) + return img +def _infer(model: EfficientNet, img, bs=1): # run the net if bs > 1: img = img.repeat(bs, axis=0) out = model.forward(Tensor(img)).cpu() return _LABELS[np.argmax(out.data[0])] -chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg') -car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg') +chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')) +car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')) class TestEfficientNet(unittest.TestCase): @classmethod @@ -56,7 +58,7 @@ class TestEfficientNet(unittest.TestCase): self.assertEqual(label, "hen") def test_chicken_bigbatch(self): - label = _infer(self.model, chicken_img, 16) + label = _infer(self.model, chicken_img, 4) self.assertEqual(label, "hen") def test_car(self): diff --git a/test/test_onnx.py b/test/test_onnx.py index 89415d8ce5..668de1551b 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -6,10 +6,20 @@ import onnx from extra.utils import fetch from tinygrad.tensor import Tensor -def run_onnx(onnx_model, inputs={}): +def run_onnx(onnx_model, inputs={}, debug=False): def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim) + def buffer_parse(inp): + if inp.data_type == 1: + ret = Tensor(np.frombuffer(inp.raw_data, dtype=np.float32).reshape(inp.dims).copy()) + elif inp.data_type == 7: + ret = Tensor(np.frombuffer(inp.raw_data, dtype=np.int64).reshape(inp.dims).astype(np.float32).copy()) + else: + raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}") + return ret + def attribute_parse(a): if a.type == 7: return tuple([int(x) for x in a.ints]) + elif a.type == 4: return buffer_parse(a.t) # TENSOR elif a.type == 2: return int(a.i) elif a.type == 1: return float(a.f) else: raise Exception(f"can't parse {a.type} {a}") @@ -17,9 +27,19 @@ def run_onnx(onnx_model, inputs={}): tensors = {} + # get weights and biases + for inp in onnx_model.graph.initializer: + #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) + if len(inp.raw_data) == 0: + tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims)) + else: + tensors[inp.name] = buffer_parse(inp) + # get inputs for inp in onnx_model.graph.input: + if inp.name in tensors: continue shape = shape_to_tuple(inp.type.tensor_type.shape) + if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size if inp.name in inputs: input_shape = inputs[inp.name].shape assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" @@ -28,18 +48,13 @@ def run_onnx(onnx_model, inputs={}): print(f"filling {inp.name} shape {shape} with 0") tensors[inp.name] = Tensor.zeros(*shape) - # get weights and biases - for inp in onnx_model.graph.initializer: - assert inp.data_type == 1 - #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) - tensors[inp.name] = Tensor(np.frombuffer(inp.raw_data, dtype=np.float32).reshape(inp.dims).copy()) for num,n in enumerate(onnx_model.graph.node): - #print(f"{num}: op {n.op_type}") + if debug: print(f"{num}: op {n.op_type}") inp = [tensors[x] for x in n.input] opt = attribute_to_dict(n.attribute) if n.op_type == "Conv": - x,w,b = inp + x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None) assert opt['dilations'] == (1,1) ret = x.pad2d(opt['pads']).conv2d(w, b, stride=opt['strides'], groups=opt['group']) elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha']) @@ -49,8 +64,12 @@ def run_onnx(onnx_model, inputs={}): elif n.op_type == "Add": ret = inp[0] + inp[1] elif n.op_type == "Sub": ret = inp[0] - inp[1] elif n.op_type == "Mul": ret = inp[0] * inp[1] - elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis']) + elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0) elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis']) + elif n.op_type == "Clip": + if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max']) + else: ret = inp[0].clip(inp[1], inp[2]) + elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True) elif n.op_type == "Split": i = 0 arg = [(0,x) for x in inp[0].shape] @@ -64,9 +83,18 @@ def run_onnx(onnx_model, inputs={}): #print(a.shape, w.shape, b.shape) if opt['transB'] == 1: w = w.transpose() ret = a.linear(w,b) + elif n.op_type == "BatchNormalization": + from tinygrad.nn import batch_normalize + ret = batch_normalize(inp[0], inp[3], inp[4], inp[1], inp[2], opt['epsilon']) + elif n.op_type == "MaxPool": + ret = inp[0].pad2d(opt['pads']) + ret = ret.max_pool2d(opt['kernel_shape']) + chan = ret.shape[1] + # strides aren't supported in max_pool + w = Tensor.eye(chan).reshape((chan, chan, 1, 1)) + ret = ret.conv2d(w, stride=opt['strides']) else: - print(n.op_type, n.input, n.output) - print(n) + print("UNSUPPORTED", n.op_type, n.input, n.output) raise Exception(f"op_type {n.op_type} not supported") assert len(n.output) == 1 tensors[n.output[0]] = ret @@ -97,6 +125,16 @@ class TestOnnxModel(unittest.TestCase): torch_out = run_onnx_torch(onnx_model, inputs).numpy() print(tinygrad_out, torch_out) np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2) + def test_resnet(self): + # mobilenet requires "Shape", "Gather", "Unsqueeze" + # googlenet doesn't work without dilated convs + dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx") + onnx_model = onnx.load(io.BytesIO(dat)) + from test.test_efficientnet import chicken_img, car_img, _LABELS + inputs = {"data": chicken_img} + tinygrad_out = run_onnx(onnx_model, inputs)['resnetv15_dense0_fwd'].numpy() + cls = tinygrad_out.argmax() + print(cls, _LABELS[cls]) if __name__ == "__main__": unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index b4cddcfb3f..d5e65c1caf 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -266,6 +266,9 @@ class TestOps(unittest.TestCase): for dim in range(-1, 2): helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim)) + def test_clip(self): + helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) + if __name__ == '__main__': np.random.seed(1337) unittest.main(verbosity=2) diff --git a/tinygrad/nn.py b/tinygrad/nn.py index f251fe76df..97bf9d7d2e 100644 --- a/tinygrad/nn.py +++ b/tinygrad/nn.py @@ -1,6 +1,10 @@ from tinygrad.tensor import Tensor import numpy as np +def batch_normalize(x, mean, var, weight, bias, eps): + x = (x - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1]) + return x.mul(var.add(eps).reshape(shape=[1, -1, 1, 1])**-0.5) + bias.reshape(shape=[1, -1, 1, 1]) + class BatchNorm2D: def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1): assert affine == True, "BatchNorm2D is only supported with affine" @@ -24,13 +28,9 @@ class BatchNorm2D: if self.num_batches_tracked is None: self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) self.num_batches_tracked += 1 - return self.normalize(x, batch_mean, batch_var) + return batch_normalize(x, batch_mean, batch_var, self.weight, self.bias, self.eps) - return self.normalize(x, self.running_mean, self.running_var) - - def normalize(self, x, mean, var): - x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1]) - return x.mul(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**-0.5) + self.bias.reshape(shape=[1, -1, 1, 1]) + return batch_normalize(x, self.running_mean, self.running_var, self.weight, self.bias, self.eps) class Conv2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index f860a71cf4..181a41bb85 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -257,7 +257,8 @@ class Tensor: bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2]) cin, cout = w.shape[-2], w.shape[-1] out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1]) - order = tuple(list(range(len(x.shape)-2))+[len(x.shape)-1, len(x.shape)-2]) + if len(x.shape) == 1: order, out_shape_t = (0,), (cout, ) + else: order = tuple(list(range(len(x.shape)-2))+[len(x.shape)-1, len(x.shape)-2]) worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2]) # NOTE: with NHWC we can remove the transposes @@ -317,6 +318,9 @@ class Tensor: def relu6(self): return self.relu() - (self-6).relu() + def clip(self, min, max): + return ((self-min).relu()+min) - (self-max).relu() + def hardswish(self): return self * (self+3).relu6() * (1/6) @@ -380,6 +384,9 @@ class Tensor: # ***** functional nn ops ***** + def reshape(self, shape): + return self._reshape(shape=shape) + def linear(self, weight, bias): shp = [1] * (len(self.shape)-1) + [-1] ret = self.mul(weight.reshape(shape=shp)) if len(weight.shape) == 1 else self.dot(weight)