mirror of https://github.com/tinygrad/tinygrad.git
add compile3 benchmark [pr] (#8929)
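
The benchmark added below times each ONNX backend (torch via onnx2torch, and onnxruntime) over several runs and reports the fastest one. Here is a minimal, self-contained sketch of that timing pattern; the benchmark helper name and the dummy workload are illustrative and not part of this commit.

import time

def benchmark(fn, n=5):
  # run fn n times, timing each call with a wall-clock perf_counter
  timings = []
  for _ in range(n):
    st = time.perf_counter()
    fn()
    timings.append(time.perf_counter() - st)
  return timings

if __name__ == "__main__":
  # dummy workload standing in for an ONNX backend inference call
  timings = benchmark(lambda: sum(i*i for i in range(100_000)))
  # report the best (minimum) run in milliseconds, like the new BENCHMARK loop
  print(f"timing: {min(timings)*1000:.2f} ms")

Reporting min(timings) rather than a mean matches the choice in the new BENCHMARK loop: the minimum run is the least noisy estimate of achievable latency.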
@@ -98,26 +98,37 @@ def test_vs_compile(run, new_inputs, test_val=None):
     np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
   return val
 
-def test_vs_onnx(new_inputs, test_val, onnx_file):
+def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False):
   new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
   onnx_model = onnx.load(onnx_file)
 
-  if getenv("ORT"):
+  timings = []
+  if ort:
     # test with onnxruntime
     import onnxruntime as ort
     onnx_session = ort.InferenceSession(onnx_file)
-    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+    for _ in range(1 if test_val is not None else 5):
+      st = time.perf_counter()
+      onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+      timings.append(time.perf_counter() - st)
     new_torch_out = onnx_output[0]
-    print("got ort outputs")
   else:
     # test with torch
-    from test.models.test_onnx import run_onnx_torch
-    # NOTE: we have to correct the order here
-    new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
-    print("got torch outputs")
+    import torch
+    from onnx2torch import convert
+    inputs = {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}
+    torch_model = convert(onnx_model).float()
+    with torch.no_grad():
+      for _ in range(1 if test_val is not None else 5):
+        st = time.perf_counter()
+        torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()])
+        timings.append(time.perf_counter() - st)
+      new_torch_out = torch_out.numpy()
 
-  np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
-  print("test vs onnx passed")
+  if test_val is not None:
+    np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
+    print("test vs onnx passed")
+  return timings
 
 if __name__ == "__main__":
   onnx_file = fetch(OPENPILOT_MODEL)
@@ -131,4 +142,12 @@ if __name__ == "__main__":
     sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
 
   test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
-  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
+  if getenv("BENCHMARK"):
+    for be in ["torch", "ort"]:
+      try:
+        timings = test_vs_onnx(new_inputs, None, onnx_file, be=="ort")
+        print(f"timing {be}: {min(timings)*1000:.2f} ms")
+      except Exception as e:
+        print(f"{be} fail with {e}")
+  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT"))
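
For illustration, a standalone sketch of the onnxruntime timing path added above. The model path and the randomly generated inputs are assumptions; the real script fetches OPENPILOT_MODEL and builds new_inputs in __main__. The float16 cast mirrors the diff and may need adjusting for other models.

import time
import numpy as np
import onnx
import onnxruntime

onnx_file = "model.onnx"  # illustrative path; the real script downloads OPENPILOT_MODEL
onnx_model = onnx.load(onnx_file)
session = onnxruntime.InferenceSession(onnx_file)

# random inputs shaped from the graph declaration (symbolic dims fall back to 1);
# float16 mirrors the cast in the diff above
inputs = {}
for inp in onnx_model.graph.input:
  shape = [d.dim_value if d.dim_value > 0 else 1 for d in inp.type.tensor_type.shape.dim]
  inputs[inp.name] = np.random.randn(*shape).astype(np.float16)

timings = []
for _ in range(5):
  st = time.perf_counter()
  out = session.run([onnx_model.graph.output[0].name], inputs)
  timings.append(time.perf_counter() - st)
print(f"timing ort: {min(timings)*1000:.2f} ms")

The torch branch in the diff follows the same loop, but runs a model produced by onnx2torch's convert under torch.no_grad().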