more working (#8550)

This commit is contained in:
George Hotz
2025-01-09 18:40:08 -08:00
committed by GitHub
parent 2cbb34535c
commit e172b759f0
2 changed files with 48 additions and 4 deletions

View File

@@ -3,14 +3,12 @@ from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad.tensor import _from_np_dtype
from extra.onnx import get_run_onnx
if __name__ == "__main__":
onnx_file = fetch(sys.argv[1])
print(onnx_file)
def load_onnx_model(fn):
onnx_file = fetch(fn)
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
run_onnx = get_run_onnx(onnx_model)
print("loaded model")
# find preinitted tensors and ignore them
initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}
@@ -20,6 +18,11 @@ if __name__ == "__main__":
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in expected_inputs}
input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in expected_inputs}
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
return run_onnx_jit, input_shapes, input_types
if __name__ == "__main__":
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
print("loaded model")
for i in range(3):
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}

View File

@@ -0,0 +1,41 @@
import random, sys
import numpy as np
from extra.datasets.imagenet import get_imagenet_categories, get_val_files, center_crop
from examples.benchmark_onnx import load_onnx_model
from PIL import Image
from tinygrad import Tensor, dtypes
# works:
# ~70% - https://github.com/onnx/models/raw/refs/heads/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx
# ~43% - https://github.com/onnx/models/raw/refs/heads/main/Computer_Vision/alexnet_Opset16_torch_hub/alexnet_Opset16.onnx
# ~64% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
# broken:
# https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7-quantized.onnx
if __name__ == "__main__":
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
t_name, shape = list(input_shapes.items())[0]
assert shape[1:] == (3,224,224), f"shape is {shape}"
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
input_std = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
files = get_val_files()
random.shuffle(files)
cir = get_imagenet_categories()
hit = 0
for i,fn in enumerate(files):
img = Image.open(fn)
img = img.convert('RGB') if img.mode != "RGB" else img
img = center_crop(img)
img = np.array(img)
img = Tensor(img).permute(2,0,1).reshape(1,3,224,224)
img = ((img.cast(dtypes.float32)/255.0) - input_mean) / input_std
y = cir[fn.split("/")[-2]]
p = run_onnx_jit(**{t_name:img})
assert p.shape == (1,1000)
t = p.argmax().item()
hit += y==t
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")