mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
onnx parser (#10435)
* onnx parser * fix compile, lint * onnx.load -> onnx_load * compatible with ModelProto * fix test external_test_onnx_ops.py * fix tests * fix signed int * reduce to 261 lines * fix TypeProto.Optional * debug for _parse_message, add TypeProto.Sequence, cleanup * onnx_load from Tensor * remove BufferedReader * 174 lines and reduce tensor copy * cleanup * use onnx_load in external_model_benchmark.py * fix qcom test * [onnx] parser support external data --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -7,7 +7,7 @@ try:
|
||||
import onnx
|
||||
except ModuleNotFoundError:
|
||||
raise unittest.SkipTest("onnx not installed, skipping onnx test")
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI, fetch, temp
|
||||
|
||||
@@ -25,7 +25,7 @@ np.random.seed(1337)
|
||||
|
||||
class TestOnnxModel(unittest.TestCase):
|
||||
def test_benchmark_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
def get_inputs():
|
||||
np_inputs = {
|
||||
@@ -69,7 +69,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
ps.print_stats(30)
|
||||
|
||||
def test_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
print("got run_onnx")
|
||||
inputs = {
|
||||
@@ -93,6 +93,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
et = time.monotonic()
|
||||
print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
|
||||
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
torch_out = run_onnx_torch(onnx_model, inputs).numpy()
|
||||
print(tinygrad_out, torch_out)
|
||||
np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2)
|
||||
@@ -119,7 +120,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
input_name, input_new)
|
||||
|
||||
def _test_model(self, fn, input_name, input_new, debug=False):
|
||||
onnx_model = onnx.load(fn)
|
||||
onnx_model = onnx_load(fn)
|
||||
print("onnx loaded")
|
||||
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
Reference in New Issue
Block a user