Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
Commit: running resnet onnx
test/test_efficientnet.py:

```diff
@@ -17,7 +17,7 @@ def _load_labels():
 
 _LABELS = _load_labels()
 
-def _infer(model: EfficientNet, img, bs=1):
+def preprocess(img):
   # preprocess image
   aspect_ratio = img.size[0] / img.size[1]
   img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
```
```diff
@@ -29,17 +29,19 @@ def _infer(model: EfficientNet, img, bs=1):
   # low level preprocess
   img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
   img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
-  img /= 255.0
-  img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
-  img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
+  #img /= 255.0
+  #img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
+  #img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
+  return img
 
+def _infer(model: EfficientNet, img, bs=1):
   # run the net
   if bs > 1: img = img.repeat(bs, axis=0)
   out = model.forward(Tensor(img)).cpu()
   return _LABELS[np.argmax(out.data[0])]
 
-chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
-car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
+chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg'))
+car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg'))
 
 class TestEfficientNet(unittest.TestCase):
   @classmethod
```
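A quick numeric check of the aspect-ratio resize above: the short side always lands on 224, and the later reshape to `(1, 3, 224, 224)` relies on a crop in the unchanged context between the two hunks. Hypothetical 640x480 input:

```python
# hypothetical 640x480 input: aspect_ratio = 4/3 > 1,
# so the width scales up and the height pins to 224
aspect_ratio = 640 / 480
w = int(224 * max(aspect_ratio, 1.0))        # 298
h = int(224 * max(1.0 / aspect_ratio, 1.0))  # 224
print(w, h)  # 298 224
```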
```diff
@@ -56,7 +58,7 @@ class TestEfficientNet(unittest.TestCase):
     self.assertEqual(label, "hen")
 
   def test_chicken_bigbatch(self):
-    label = _infer(self.model, chicken_img, 16)
+    label = _infer(self.model, chicken_img, 4)
     self.assertEqual(label, "hen")
 
   def test_car(self):
```
test/test_onnx.py:

```diff
@@ -6,10 +6,20 @@ import onnx
 from extra.utils import fetch
 from tinygrad.tensor import Tensor
 
-def run_onnx(onnx_model, inputs={}):
+def run_onnx(onnx_model, inputs={}, debug=False):
   def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
+  def buffer_parse(inp):
+    if inp.data_type == 1:
+      ret = Tensor(np.frombuffer(inp.raw_data, dtype=np.float32).reshape(inp.dims).copy())
+    elif inp.data_type == 7:
+      ret = Tensor(np.frombuffer(inp.raw_data, dtype=np.int64).reshape(inp.dims).astype(np.float32).copy())
+    else:
+      raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
+    return ret
+
   def attribute_parse(a):
     if a.type == 7: return tuple([int(x) for x in a.ints])
+    elif a.type == 4: return buffer_parse(a.t) # TENSOR
     elif a.type == 2: return int(a.i)
     elif a.type == 1: return float(a.f)
     else: raise Exception(f"can't parse {a.type} {a}")
```
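The magic numbers in `buffer_parse` and `attribute_parse` come straight from the ONNX protobufs: `TensorProto` data types FLOAT = 1 and INT64 = 7, and `AttributeProto` types FLOAT = 1, INT = 2, TENSOR = 4, INTS = 7. As a cross-check, onnx ships a helper that decodes any initializer; a minimal sketch, assuming the model file from the test further down:

```python
import onnx
from onnx import numpy_helper

# numpy_helper.to_array understands raw_data as well as the typed
# float_data/int64_data fields, for every ONNX dtype
model = onnx.load("resnet18-v1-7.onnx")  # local path assumed
for init in model.graph.initializer[:3]:
    arr = numpy_helper.to_array(init)
    print(init.name, arr.dtype, arr.shape)
```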
```diff
@@ -17,9 +27,19 @@ def run_onnx(onnx_model, inputs={}):
   tensors = {}
 
+  # get weights and biases
+  for inp in onnx_model.graph.initializer:
+    #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
+    if len(inp.raw_data) == 0:
+      tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims))
+    else:
+      tensors[inp.name] = buffer_parse(inp)
+
   # get inputs
   for inp in onnx_model.graph.input:
     if inp.name in tensors: continue
     shape = shape_to_tuple(inp.type.tensor_type.shape)
+    if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
     if inp.name in inputs:
       input_shape = inputs[inp.name].shape
       assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
```
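Two quirks of the proto show up here: an initializer may carry its payload in `float_data` instead of `raw_data` (hence the `len(inp.raw_data) == 0` branch), and a symbolic batch dimension (a `dim_param` such as "N") reports `dim_value` as 0, the proto default, which the `shape[0] == 0` check patches to batch size 1. A sketch of inspecting both, same assumed model file as above:

```python
import onnx

model = onnx.load("resnet18-v1-7.onnx")  # local path assumed

# initializers: the payload lives in raw_data or in a typed field like float_data
for init in model.graph.initializer[:3]:
    print(init.name, "raw_data" if len(init.raw_data) else "float_data")

# inputs: dim_value == 0 means the dim is symbolic (its dim_param is set instead)
for inp in model.graph.input:
    print(inp.name, [(d.dim_value, d.dim_param) for d in inp.type.tensor_type.shape.dim])
```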
```diff
@@ -28,18 +48,13 @@ def run_onnx(onnx_model, inputs={}):
       print(f"filling {inp.name} shape {shape} with 0")
       tensors[inp.name] = Tensor.zeros(*shape)
 
-  # get weights and biases
-  for inp in onnx_model.graph.initializer:
-    assert inp.data_type == 1
-    #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
-    tensors[inp.name] = Tensor(np.frombuffer(inp.raw_data, dtype=np.float32).reshape(inp.dims).copy())
-
   for num,n in enumerate(onnx_model.graph.node):
-    #print(f"{num}: op {n.op_type}")
+    if debug: print(f"{num}: op {n.op_type}")
     inp = [tensors[x] for x in n.input]
     opt = attribute_to_dict(n.attribute)
     if n.op_type == "Conv":
-      x,w,b = inp
+      x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
       assert opt['dilations'] == (1,1)
       ret = x.pad2d(opt['pads']).conv2d(w, b, stride=opt['strides'], groups=opt['group'])
     elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
```
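The `len(inp) == 3` guard is needed because ONNX `Conv` declares its bias input `B` as optional, so `n.input` may list only `(X, W)`. One generic way to write that, if more optional trailing inputs show up later (a sketch, not what the commit does):

```python
# pad a node's input list with None so fixed-arity unpacking always works
def pad_inputs(inp, n):
    return tuple(inp) + (None,) * (n - len(inp))

x, w, b = pad_inputs(["X", "W"], 3)  # b comes back as None
assert b is None
```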
```diff
@@ -49,8 +64,12 @@ def run_onnx(onnx_model, inputs={}):
     elif n.op_type == "Add": ret = inp[0] + inp[1]
     elif n.op_type == "Sub": ret = inp[0] - inp[1]
     elif n.op_type == "Mul": ret = inp[0] * inp[1]
-    elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'])
+    elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
     elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
+    elif n.op_type == "Clip":
+      if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max'])
+      else: ret = inp[0].clip(inp[1], inp[2])
+    elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
     elif n.op_type == "Split":
       i = 0
       arg = [(0,x) for x in inp[0].shape]
```
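The two `Clip` branches exist because the op changed shape across opsets: before opset 11, `min`/`max` are attributes; from opset 11 on, they are inputs. `GlobalAveragePool` is just a mean over every axis after `(N, C)` with dims kept, which the `tuple(range(2, len(shape)))` axis argument expresses; a numpy check:

```python
import numpy as np

x = np.random.randn(1, 8, 7, 7).astype(np.float32)
out = x.mean(axis=tuple(range(2, x.ndim)), keepdims=True)
assert out.shape == (1, 8, 1, 1)  # one average per (batch, channel)
```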
```diff
@@ -64,9 +83,18 @@ def run_onnx(onnx_model, inputs={}):
       #print(a.shape, w.shape, b.shape)
       if opt['transB'] == 1: w = w.transpose()
       ret = a.linear(w,b)
+    elif n.op_type == "BatchNormalization":
+      from tinygrad.nn import batch_normalize
+      ret = batch_normalize(inp[0], inp[3], inp[4], inp[1], inp[2], opt['epsilon'])
+    elif n.op_type == "MaxPool":
+      ret = inp[0].pad2d(opt['pads'])
+      ret = ret.max_pool2d(opt['kernel_shape'])
+      chan = ret.shape[1]
+      # strides aren't supported in max_pool
+      w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
+      ret = ret.conv2d(w, stride=opt['strides'])
     else:
-      print(n.op_type, n.input, n.output)
-      print(n)
+      print("UNSUPPORTED", n.op_type, n.input, n.output)
       raise Exception(f"op_type {n.op_type} not supported")
     assert len(n.output) == 1
     tensors[n.output[0]] = ret
```
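The MaxPool fallback leans on a small identity: a 1x1 convolution whose `chan` x `chan` kernel is the identity matrix copies channels unchanged, so giving that conv a stride is pure spatial subsampling. A numpy sketch of just that equivalence (not the tinygrad kernels):

```python
import numpy as np

x = np.random.randn(1, 4, 8, 8).astype(np.float32)
eye = np.eye(4, dtype=np.float32)  # the (chan, chan) 1x1 kernel

def conv1x1(x, w, stride):
    # a strided 1x1 conv samples every stride-th pixel, then mixes channels
    xs = x[:, :, ::stride, ::stride]
    return np.einsum('kc,nchw->nkhw', w, xs)

out = conv1x1(x, eye, 2)
assert np.allclose(out, x[:, :, ::2, ::2])  # identity kernel => pure subsampling
```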
```diff
@@ -97,6 +125,16 @@ class TestOnnxModel(unittest.TestCase):
     torch_out = run_onnx_torch(onnx_model, inputs).numpy()
     print(tinygrad_out, torch_out)
     np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
+
+  def test_resnet(self):
+    # mobilenet requires "Shape", "Gather", "Unsqueeze"
+    # googlenet doesn't work without dilated convs
+    dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx")
+    onnx_model = onnx.load(io.BytesIO(dat))
+    from test.test_efficientnet import chicken_img, car_img, _LABELS
+    inputs = {"data": chicken_img}
+    tinygrad_out = run_onnx(onnx_model, inputs)['resnetv15_dense0_fwd'].numpy()
+    cls = tinygrad_out.argmax()
+    print(cls, _LABELS[cls])
 
 if __name__ == "__main__":
   unittest.main()
```
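`'resnetv15_dense0_fwd'` is the name this particular model zoo file gives its final output; it does not have to be hard-coded, since the graph declares its outputs. A small sketch, assuming `dat`, `run_onnx`, and `inputs` from test_resnet above:

```python
import io
import onnx

onnx_model = onnx.load(io.BytesIO(dat))
# read the declared output name(s) instead of hard-coding one
out_names = [o.name for o in onnx_model.graph.output]
print(out_names)  # ['resnetv15_dense0_fwd'] for this model zoo file
tinygrad_out = run_onnx(onnx_model, inputs)[out_names[0]].numpy()
```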
test/test_ops.py:

```diff
@@ -266,6 +266,9 @@ class TestOps(unittest.TestCase):
     for dim in range(-1, 2):
       helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
 
+  def test_clip(self):
+    helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
+
 if __name__ == '__main__':
   np.random.seed(1337)
   unittest.main(verbosity=2)
```
tinygrad/nn.py:

```diff
@@ -1,6 +1,10 @@
 from tinygrad.tensor import Tensor
 import numpy as np
 
+def batch_normalize(x, mean, var, weight, bias, eps):
+  x = (x - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1])
+  return x.mul(var.add(eps).reshape(shape=[1, -1, 1, 1])**-0.5) + bias.reshape(shape=[1, -1, 1, 1])
+
 class BatchNorm2D:
   def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
     assert affine == True, "BatchNorm2D is only supported with affine"
```
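`batch_normalize` is inference-style batch norm, y = weight * (x - mean) / sqrt(var + eps) + bias, with each per-channel tensor reshaped to broadcast as `(1, C, 1, 1)`; pulling it out of the class lets the ONNX `BatchNormalization` op above reuse it. A numpy check that the `**-0.5` formulation matches the textbook one:

```python
import numpy as np

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
weight, bias, eps = np.full(3, 2.0, np.float32), np.ones(3, np.float32), 1e-5
r = lambda t: np.reshape(t, (1, -1, 1, 1))

y1 = (x - r(mean)) * r(weight) * (r(var) + eps) ** -0.5 + r(bias)
y2 = r(weight) * (x - r(mean)) / np.sqrt(r(var) + eps) + r(bias)
assert np.allclose(y1, y2, atol=1e-5)
```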
```diff
@@ -24,13 +28,9 @@ class BatchNorm2D:
       if self.num_batches_tracked is None: self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
       self.num_batches_tracked += 1
 
-      return self.normalize(x, batch_mean, batch_var)
+      return batch_normalize(x, batch_mean, batch_var, self.weight, self.bias, self.eps)
 
-    return self.normalize(x, self.running_mean, self.running_var)
-
-  def normalize(self, x, mean, var):
-    x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
-    return x.mul(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**-0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
+    return batch_normalize(x, self.running_mean, self.running_var, self.weight, self.bias, self.eps)
 
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
```
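This refactor also explains the argument shuffle in the ONNX runner above: `BatchNormalization` orders its inputs as (X, scale, B, mean, var), while `batch_normalize` takes (x, mean, var, weight, bias, eps), hence `batch_normalize(inp[0], inp[3], inp[4], inp[1], inp[2], opt['epsilon'])`:

```python
# ONNX BatchNormalization input order, permuted into batch_normalize order
onnx_inputs = ["X", "scale", "B", "mean", "var"]
args = [onnx_inputs[i] for i in (0, 3, 4, 1, 2)]
print(args)  # ['X', 'mean', 'var', 'scale', 'B']
```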
tinygrad/tensor.py:

```diff
@@ -257,7 +257,8 @@ class Tensor:
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
     cin, cout = w.shape[-2], w.shape[-1]
     out_shape_t = tuple(list(x.shape[0:-2])+[cout,-1])
-    order = tuple(list(range(len(x.shape)-2))+[len(x.shape)-1, len(x.shape)-2])
+    if len(x.shape) == 1: order, out_shape_t = (0,), (cout, )
+    else: order = tuple(list(range(len(x.shape)-2))+[len(x.shape)-1, len(x.shape)-2])
     worder = tuple(list(range(len(w.shape)-2))+[len(w.shape)-1, len(w.shape)-2])
 
     # NOTE: with NHWC we can remove the transposes
```
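The new branch lets the batched dot path accept a plain vector: a 1-D input has no row axis to transpose (`order = (0,)`) and the output collapses to `(cout,)`. In numpy terms this is the matrix-vector case:

```python
import numpy as np

x = np.random.randn(512).astype(np.float32)        # 1-D input
w = np.random.randn(512, 1000).astype(np.float32)  # (cin, cout)
out = x @ w
assert out.shape == (1000,)  # matches out_shape_t == (cout,)
```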
```diff
@@ -317,6 +318,9 @@ class Tensor:
   def relu6(self):
     return self.relu() - (self-6).relu()
 
+  def clip(self, min, max):
+    return ((self-min).relu()+min) - (self-max).relu()
+
   def hardswish(self):
     return self * (self+3).relu6() * (1/6)
 
```
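The relu-built clip is exact whenever min <= max: below `min` both relus are zero and the `+min` term supplies the floor; inside the range the `min` terms cancel and x passes through; above `max` the second relu subtracts exactly the overshoot, leaving `max`. A quick numpy check:

```python
import numpy as np

x = np.linspace(-5, 5, 101)
lo, hi = -2.3, 1.2
relu = lambda t: np.maximum(t, 0)
y = (relu(x - lo) + lo) - relu(x - hi)
assert np.allclose(y, np.clip(x, lo, hi))
```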
```diff
@@ -380,6 +384,9 @@ class Tensor:
 
   # ***** functional nn ops *****
 
+  def reshape(self, shape):
+    return self._reshape(shape=shape)
+
   def linear(self, weight, bias):
     shp = [1] * (len(self.shape)-1) + [-1]
     ret = self.mul(weight.reshape(shape=shp)) if len(weight.shape) == 1 else self.dot(weight)
```
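`reshape` here is a thin positional-argument wrapper (the underlying movement op is presumably exposed as `_reshape` after this change), and `linear` picks between an elementwise scale for a 1-D weight and a matmul otherwise, with the bias added in the elided tail of the function. A shape-level sketch of the two paths, using hypothetical sizes:

```python
import numpy as np

x = np.random.randn(4, 512).astype(np.float32)

# 1-D weight: broadcast multiply; shp = [1, -1] reshapes (512,) to (1, 512)
w1 = np.random.randn(512).astype(np.float32)
assert (x * w1.reshape([1] * (x.ndim - 1) + [-1])).shape == (4, 512)

# 2-D weight: a matmul, the path the ONNX Gemm op above goes through
w2 = np.random.randn(512, 1000).astype(np.float32)
assert (x @ w2).shape == (4, 1000)
```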