mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix try except not catching fxn() in benchmark (#1783)
* have function raise notimplementederror * more lines * revert back to 2 lines :D * aahhhhhhhh shoooot im stupid * keep it minimal?
This commit is contained in:
6
test/external/external_model_benchmark.py
vendored
6
test/external/external_model_benchmark.py
vendored
@@ -14,8 +14,7 @@ from tinygrad.ops import Device
|
||||
|
||||
MODELS = {
|
||||
"resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",
|
||||
# broken in torch CPU
|
||||
#"openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
|
||||
"openpilot": "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx",
|
||||
"efficientnet": "https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx",
|
||||
"shufflenet": "https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx",
|
||||
"commavq": "https://huggingface.co/commaai/commavq-gpt2m/resolve/main/gpt2m.onnx",
|
||||
@@ -85,8 +84,7 @@ def benchmark_model(m, validate_outs=False):
|
||||
torch_mps_model = torch_model.to(torch_device)
|
||||
torch_mps_inputs = [x.to(torch_device) for x in torch_inputs]
|
||||
benchmark(m, f"torch_{torch_device}", lambda: torch_mps_model(*torch_mps_inputs))
|
||||
except NotImplementedError:
|
||||
print(f"{m:16s}onnx2torch doesn't support this model")
|
||||
except Exception as e: print(f"{m:16s}onnx2torch {type(e).__name__:>25}")
|
||||
|
||||
# bench onnxruntime
|
||||
ort_options = ort.SessionOptions()
|
||||
|
||||
Reference in New Issue
Block a user