fix try except not catching fxn() in benchmark (#1783)

* have function raise notimplementederror

* more lines

* revert back to 2 lines :D

* aahhhhhhhh shoooot im stupid

* keep it minimal?
This commit is contained in:
geohotstan
2023-09-06 22:36:43 +08:00
committed by GitHub
parent 09e78a9d07
commit 1bbf26d7fd

View File

@@ -14,8 +14,7 @@ from tinygrad.ops import Device
MODELS = {
"resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",
# broken in torch CPU
#"openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
"openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
"efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
"shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
"commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx",
@@ -85,8 +84,7 @@ def benchmark_model(m, validate_outs=False):
torch_mps_model = torch_model.to(torch_device)
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
except NotImplementedError:
print(f"{m:16s}onnx2torch doesn't support this model")
except Exception as e: print(f"{m:16s}onnx2torch {type(e).__name__:>25}")
# bench onnxruntime
ort_options = ort.SessionOptions()