mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
more working (#8550)
This commit is contained in:
@@ -3,14 +3,12 @@ from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
from extra.onnx import get_run_onnx
|
||||
|
||||
if __name__ == "__main__":
|
||||
onnx_file = fetch(sys.argv[1])
|
||||
print(onnx_file)
|
||||
def load_onnx_model(fn):
|
||||
onnx_file = fetch(fn)
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
Tensor.no_grad = True
|
||||
Tensor.training = False
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
print("loaded model")
|
||||
|
||||
# find preinitted tensors and ignore them
|
||||
initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}
|
||||
@@ -20,6 +18,11 @@ if __name__ == "__main__":
|
||||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in expected_inputs}
|
||||
input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in expected_inputs}
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
|
||||
return run_onnx_jit, input_shapes, input_types
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
|
||||
print("loaded model")
|
||||
|
||||
for i in range(3):
|
||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
||||
|
||||
41
examples/test_onnx_imagenet.py
Normal file
41
examples/test_onnx_imagenet.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import random, sys
|
||||
import numpy as np
|
||||
from extra.datasets.imagenet import get_imagenet_categories, get_val_files, center_crop
|
||||
from examples.benchmark_onnx import load_onnx_model
|
||||
from PIL import Image
|
||||
from tinygrad import Tensor, dtypes
|
||||
|
||||
# works:
|
||||
# ~70% - https://github.com/onnx/models/raw/refs/heads/main/validated/vision/classification/resnet/model/resnet50-v2-7.onnx
|
||||
# ~43% - https://github.com/onnx/models/raw/refs/heads/main/Computer_Vision/alexnet_Opset16_torch_hub/alexnet_Opset16.onnx
|
||||
# ~64% - https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx
|
||||
# broken:
|
||||
# https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
|
||||
# https://huggingface.co/qualcomm/MobileNet-v2-Quantized/resolve/main/MobileNet-v2-Quantized.onnx
|
||||
# https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7-quantized.onnx
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
|
||||
t_name, shape = list(input_shapes.items())[0]
|
||||
assert shape[1:] == (3,224,224), f"shape is {shape}"
|
||||
|
||||
input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
input_std = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
|
||||
|
||||
files = get_val_files()
|
||||
random.shuffle(files)
|
||||
cir = get_imagenet_categories()
|
||||
hit = 0
|
||||
for i,fn in enumerate(files):
|
||||
img = Image.open(fn)
|
||||
img = img.convert('RGB') if img.mode != "RGB" else img
|
||||
img = center_crop(img)
|
||||
img = np.array(img)
|
||||
img = Tensor(img).permute(2,0,1).reshape(1,3,224,224)
|
||||
img = ((img.cast(dtypes.float32)/255.0) - input_mean) / input_std
|
||||
y = cir[fn.split("/")[-2]]
|
||||
p = run_onnx_jit(**{t_name:img})
|
||||
assert p.shape == (1,1000)
|
||||
t = p.argmax().item()
|
||||
hit += y==t
|
||||
print(f"target: {y:3d} pred: {t:3d} acc: {hit/(i+1)*100:.2f}%")
|
||||
Reference in New Issue
Block a user