add onnx_helpers to extra and add ort validate to benchmark_onnx (#8890)

* start

* log severity

* only change this

* change abstraction so it's more usable for huggingface

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
geohotstan
2025-02-05 05:36:01 +08:00
committed by GitHub
parent 89eebd4bfb
commit 057c70b05f
4 changed files with 50 additions and 28 deletions

View File

@@ -1,41 +1,28 @@
import sys, onnx, time
from tinygrad import Tensor, TinyJit, Device, GlobalCounters, fetch
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
from extra.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs, validate
def load_onnx_model(fn):
onnx_file = fetch(fn)
def load_onnx_model(onnx_file):
onnx_model = onnx.load(onnx_file)
run_onnx = OnnxRunner(onnx_model)
# find preinitted tensors and ignore them
initted_tensors = {inp.name:None for inp in onnx_model.graph.initializer}
expected_inputs = [inp for inp in onnx_model.graph.input if inp.name not in initted_tensors]
# get real inputs
input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in expected_inputs}
input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in expected_inputs}
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())), prune=True)
return run_onnx_jit, input_shapes, input_types
def get_new_inputs(input_shapes):
#from tinygrad.tensor import _from_np_dtype
#return {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
import numpy as np
return {k:Tensor(np.random.uniform(size=shp).astype(input_types[k]) * 8).realize() for k,shp in sorted(input_shapes.items())}
return run_onnx_jit, run_onnx.graph_inputs
if __name__ == "__main__":
run_onnx_jit, input_shapes, input_types = load_onnx_model(sys.argv[1])
onnx_file = fetch(sys.argv[1])
run_onnx_jit, input_specs = load_onnx_model(onnx_file)
print("loaded model")
for i in range(3):
new_inputs = get_new_inputs(input_shapes)
new_inputs = get_example_inputs(input_specs)
GlobalCounters.reset()
print(f"run {i}")
run_onnx_jit(**new_inputs)
# run 20 times
for _ in range(20):
new_inputs = get_new_inputs(input_shapes)
new_inputs = get_example_inputs(input_specs)
GlobalCounters.reset()
st = time.perf_counter()
out = run_onnx_jit(**new_inputs)
@@ -43,3 +30,7 @@ if __name__ == "__main__":
val = out.numpy()
et = time.perf_counter()
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
if getenv("ORT"):
validate(onnx_file, new_inputs, rtol=1e-3, atol=1e-3)
print("model validated")

View File

@@ -60,9 +60,9 @@ if __name__ == "__main__":
activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": True})
run_onnx_jit, input_shapes, input_types = load_onnx_model(fn)
t_name, shape = list(input_shapes.items())[0]
assert shape[1:] == (3,224,224), f"shape is {shape}"
run_onnx_jit, input_specs = load_onnx_model(fn)
t_name, t_spec = list(input_specs.items())[0]
assert t_spec.shape[1:] == (3,224,224), f"shape is {t_spec.shape}"
hit = 0
for i,(img,y) in enumerate(imagenet_dataloader(cnt=100)):

View File

@@ -4,6 +4,7 @@ from ultralytics import YOLO
import onnx
from pathlib import Path
from extra.onnx import OnnxRunner
from extra.onnx_helpers import get_example_inputs
from tinygrad.tensor import Tensor
os.chdir("/tmp")
@@ -11,8 +12,5 @@ if not Path("yolov8n-seg.onnx").is_file():
model = YOLO("yolov8n-seg.pt")
model.export(format="onnx", imgsz=[480,640])
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
# TODO: move get example inputs to onnx
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
print(input_shapes)
run_onnx = OnnxRunner(onnx_model)
run_onnx({"images": Tensor.zeros(1,3,480,640)}, debug=True)
run_onnx(get_example_inputs(run_onnx.graph_inputs), debug=True)

33
extra/onnx_helpers.py Normal file
View File

@@ -0,0 +1,33 @@
from tinygrad import Tensor
from tinygrad.tensor import _to_np_dtype
from extra.onnx import OnnxRunner, OnnxValue
import onnx
import numpy as np
import onnxruntime as ort
def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
  """Create one random, realized Tensor for every entry of `graph_inputs`.

  Args:
    graph_inputs: mapping from input name to its spec. Each spec is read for
      `is_optional`, `is_sequence`, `shape` and `dtype`.

  Returns:
    A dict with the same keys (in the same order) whose values are realized Tensors.

  Raises:
    AssertionError: if a spec is marked optional or is a sequence.
  """
  def _random_tensor(spec) -> Tensor:
    assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now"
    # any dimension that is not a plain int is replaced by 1
    dims = tuple(1 if not isinstance(d, int) else d for d in spec.shape)
    samples = np.random.uniform(size=dims)
    np_dtype = _to_np_dtype(spec.dtype)
    # NOTE(review): the cast happens before the scaling by 8, so for integer dtypes the
    # uniform samples look like they truncate to 0 first -- confirm this is intended
    return Tensor(samples.astype(np_dtype) * 8).realize()

  return {name: _random_tensor(spec) for name, spec in graph_inputs.items()}
def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
  """Run `onnx_file` through OnnxRunner and onnxruntime on `inputs` and compare the outputs.

  Args:
    onnx_file: model location, passed to both `onnx.load` and `ort.InferenceSession`.
    inputs: mapping from input name to value. Tensor values are converted with
      `.numpy()` before being handed to onnxruntime; other values are passed through.
    rtol: relative tolerance forwarded to `np.testing.assert_allclose`.
    atol: absolute tolerance forwarded to `np.testing.assert_allclose`.

  Raises:
    AssertionError: if the two runtimes disagree on the set of output names, if one
      produces None for an output and the other does not, or if any output differs
      beyond the given tolerances.
  """
  run_onnx = OnnxRunner(onnx.load(onnx_file))
  tinygrad_out = run_onnx(inputs)

  ort_options = ort.SessionOptions()
  # 3 is the "error" severity in onnxruntime's logging levels, silencing warnings and below
  ort_options.log_severity_level = 3
  ort_sess = ort.InferenceSession(onnx_file, ort_options, ["CPUExecutionProvider"])
  np_inputs = {k:v.numpy() if isinstance(v, Tensor) else v for k,v in inputs.items()}
  out_names = list(run_onnx.graph_outputs)
  ort_out = dict(zip(out_names, ort_sess.run(out_names, np_inputs)))

  # explicit raises (instead of bare asserts) so the checks still run under `python -O`
  if tinygrad_out.keys() != ort_out.keys():
    raise AssertionError(f"output names differ: tinygrad {list(tinygrad_out.keys())} vs onnxruntime {list(ort_out.keys())}")
  for k, tiny_v in tinygrad_out.items():
    onnx_v = ort_out[k]
    if tiny_v is None:
      # identity check: `None == ndarray` is an elementwise comparison, not a single bool
      if onnx_v is not None:
        raise AssertionError(f"For tensor '{k}': tinygrad returned None but onnxruntime returned a value")
    else:
      np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")