fix commavq benchmark (#4712)

* fix _slice and assert explicit device

* with _slice
This commit is contained in:
qazal
2024-05-25 00:40:57 +08:00
committed by GitHub
parent 84255069e7
commit c170ddceaf
2 changed files with 8 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
import csv, pathlib, time, numpy as np
from os import getenv
from tinygrad.device import CompileError
import torch
torch.set_num_threads(1)
import onnx
@@ -60,8 +61,8 @@ def benchmark_model(m, devices, validate_outs=False):
# print input names
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
try:
for device in devices:
for device in devices:
try:
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
@@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
del inputs, tinygrad_model, tinygrad_jitted_model
except Exception as e:
# model crashed
print(f"{m} crashed on {device} with: {e}")
return
except CompileError as e:
# METAL fails with buffer count limit
if m == "dm" and device == "METAL": return
raise e
# convert model to torch
try: