onnx parser (#10435)

* onnx parser

* fix compile, lint

* onnx.load -> onnx_load

* compatible with ModelProto

* fix test external_test_onnx_ops.py

* fix tests

* fix signed int

* reduce to 261 lines

* fix TypeProto.Optional

* debug for _parse_message, add TypeProto.Sequence, cleanup

* onnx_load from Tensor

* remove BufferedReader

* 174 lines and reduce tensor copy

* cleanup

* use onnx_load in external_model_benchmark.py

* fix qcom test

* [onnx] parser support external data

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
b1tg
2025-06-10 00:44:28 +08:00
committed by GitHub
parent cfa65bea05
commit 24d328e313
13 changed files with 273 additions and 50 deletions

View File

@@ -7,7 +7,7 @@ try:
import onnx
except ModuleNotFoundError:
raise unittest.SkipTest("onnx not installed, skipping onnx test")
from tinygrad.frontend.onnx import OnnxRunner
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, fetch, temp
@@ -25,7 +25,7 @@ np.random.seed(1337)
class TestOnnxModel(unittest.TestCase):
def test_benchmark_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
run_onnx = OnnxRunner(onnx_model)
def get_inputs():
np_inputs = {
@@ -69,7 +69,7 @@ class TestOnnxModel(unittest.TestCase):
ps.print_stats(30)
def test_openpilot_model(self):
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
run_onnx = OnnxRunner(onnx_model)
print("got run_onnx")
inputs = {
@@ -93,6 +93,7 @@ class TestOnnxModel(unittest.TestCase):
et = time.monotonic()
print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
torch_out = run_onnx_torch(onnx_model, inputs).numpy()
print(tinygrad_out, torch_out)
np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2)
@@ -119,7 +120,7 @@ class TestOnnxModel(unittest.TestCase):
input_name, input_new)
def _test_model(self, fn, input_name, input_new, debug=False):
onnx_model = onnx.load(fn)
onnx_model = onnx_load(fn)
print("onnx loaded")
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
run_onnx = OnnxRunner(onnx_model)