diff --git a/examples/benchmark_onnx.py b/examples/benchmark_onnx.py index b35a8c4be4..ad7c1ebb18 100644 --- a/examples/benchmark_onnx.py +++ b/examples/benchmark_onnx.py @@ -1,11 +1,10 @@ import sys, time from tinygrad import TinyJit, GlobalCounters, fetch, getenv -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from extra.onnx_helpers import get_example_inputs, validate def load_onnx_model(onnx_file): - onnx_model = onnx_load(onnx_file) - run_onnx = OnnxRunner(onnx_model) + run_onnx = OnnxRunner(onnx_file) run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True) return run_onnx_jit, run_onnx.graph_inputs diff --git a/examples/compile_tensorflow.py b/examples/compile_tensorflow.py index 3ad6f31195..1f308c58e3 100644 --- a/examples/compile_tensorflow.py +++ b/examples/compile_tensorflow.py @@ -25,7 +25,7 @@ class TinyOnnx: def __init__(self, keras_model): input_signature = [tf.TensorSpec([1,32], tf.float32, name='x')] onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature, opset=13) - self.run_onnx = OnnxRunner(onnx_model) + self.run_onnx = OnnxRunner(Tensor(onnx_model.SerializeToString(), device="PYTHON")) def forward(self, x): return self.run_onnx({"x": x}, debug=False)['predictions'] diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index 9f284ed6f2..0b2fb40259 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -5,29 +5,26 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0" -from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device +from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device, dtypes from tinygrad.helpers import DEBUG, getenv -from tinygrad.tensor import _from_np_dtype from tinygrad.engine.realize import CompiledRunner import onnx -from onnx.helper import tensor_dtype_to_np_dtype -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx" OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl" def compile(onnx_file): - onnx_model = onnx_load(onnx_file) - run_onnx = OnnxRunner(onnx_model) + run_onnx = OnnxRunner(onnx_file) print("loaded model") - input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input} + input_shapes = {name: spec.shape for name, spec in run_onnx.graph_inputs.items()} + input_types = {name: spec.dtype for name, spec in run_onnx.graph_inputs.items()} # Float inputs and outputs to tinyjits for openpilot are always float32 - input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()} + input_types = {k:(dtypes.float32 if v is dtypes.float16 else v) for k,v in input_types.items()} Tensor.manual_seed(100) - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())} + new_inputs = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in sorted(input_shapes.items())} new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} print("created tensors") diff --git a/examples/openpilot/compile4.py b/examples/openpilot/compile4.py index 867f1df54c..658f9a07d2 100644 --- a/examples/openpilot/compile4.py +++ b/examples/openpilot/compile4.py @@ -1,4 +1,4 @@ -import sys, onnx +import sys from tinygrad import Tensor, fetch, GlobalCounters, dtypes from tinygrad.uop.ops import UOp from tinygrad.frontend.onnx import OnnxRunner @@ -12,10 +12,8 @@ OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/comm OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl" if __name__ == "__main__": - fn = fetch(OPENPILOT_MODEL) onnx_file = fetch(OPENPILOT_MODEL) - onnx_model = onnx.load(onnx_file) - run_onnx = OnnxRunner(onnx_model) + run_onnx = OnnxRunner(onnx_file) inputs = run_onnx.get_empty_input_data("npy", dtypes.float32) out: Tensor = next(iter(run_onnx({k:v.to(None) for k,v in inputs.items()}).values())).to('cpu') diff --git a/examples/yolov8-onnx.py b/examples/yolov8-onnx.py index 3b9bdfba9c..bc3d50ab9e 100644 --- a/examples/yolov8-onnx.py +++ b/examples/yolov8-onnx.py @@ -2,13 +2,12 @@ import os from ultralytics import YOLO from pathlib import Path -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from extra.onnx_helpers import get_example_inputs os.chdir("/tmp") if not Path("yolov8n-seg.onnx").is_file(): model = YOLO("yolov8n-seg.pt") model.export(format="onnx", imgsz=[480,640]) -onnx_model = onnx_load(open("yolov8n-seg.onnx", "rb")) -run_onnx = OnnxRunner(onnx_model) +run_onnx = OnnxRunner("yolov8n-seg.onnx") run_onnx(get_example_inputs(run_onnx.graph_inputs), debug=True) diff --git a/extra/huggingface_onnx/run_models.py b/extra/huggingface_onnx/run_models.py index af920f8cf0..7ac9ef54e1 100644 --- a/extra/huggingface_onnx/run_models.py +++ b/extra/huggingface_onnx/run_models.py @@ -1,6 +1,6 @@ import onnx, yaml, tempfile, time, collections, pprint, argparse, json from pathlib import Path -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from extra.onnx import get_onnx_ops from extra.onnx_helpers import validate, get_example_inputs @@ -13,8 +13,7 @@ def get_config(root_path: Path): return ret def run_huggingface_validate(onnx_model_path, config, rtol, atol): - onnx_model = onnx_load(onnx_model_path) - onnx_runner = OnnxRunner(onnx_model) + onnx_runner = OnnxRunner(onnx_model_path) inputs = get_example_inputs(onnx_runner.graph_inputs, config) validate(onnx_model_path, inputs, rtol=rtol, atol=atol) @@ -46,7 +45,7 @@ def retrieve_op_stats(models:dict[str, tuple[Path, Path]]) -> dict: for model_id, (root_path, relative_path) in models.items(): print(f"examining {model_id}") model_path = root_path / relative_path - onnx_runner = OnnxRunner(onnx.load(model_path)) + onnx_runner = OnnxRunner(model_path) for node in onnx_runner.graph_nodes: op_counter[node.op] += 1 if node.op not in supported_ops: diff --git a/extra/onnx.py b/extra/onnx.py index 098cc7842b..e1d965910e 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,10 +1,11 @@ from types import SimpleNamespace from typing import Any, Sequence, cast, Literal, Callable -import dataclasses, functools, io, math, types, warnings, sys +import dataclasses, functools, io, math, types, warnings, pathlib, sys from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype from tinygrad.device import is_dtype_supported, Device +from extra.onnx_parser import onnx_load # https://github.com/onnx/onnx/blob/rel-1.17.0/onnx/onnx.proto3#L500-L544 data_types: dict[int, DType] = { @@ -24,7 +25,7 @@ attribute_types: dict[int, Callable] = { } # ***** protobuf parsing ****** -from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper +from onnx import AttributeProto, TensorProto, TypeProto, helper import numpy as np def has_field(onnx_type: TypeProto|SimpleNamespace, field): @@ -132,8 +133,14 @@ def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes: debug = int(getenv("DEBUGONNX", "0")) limit = int(getenv("ONNXLIMIT", "-1")) class OnnxRunner: - def __init__(self, model: ModelProto|SimpleNamespace): - # parse model protobuf + """ + `OnnxRunner` executes an ONNX model using Tinygrad. + + Args: + model_path: The ONNX model, provided as a file path (a string or Path object) or a Tensor. + """ + def __init__(self, model_path: Tensor | str | pathlib.Path): + model = onnx_load(model_path) self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node) self.old_training = Tensor.training Tensor.training = True if self.is_training else False @@ -178,6 +185,12 @@ class OnnxRunner: def get_empty_input_data(self, device:str|None=None, dtype:DType|None=None) -> dict[str, Tensor]: return {name:Tensor.empty(*spec.shape, device=device, dtype=dtype or spec.dtype) for name, spec in self.graph_inputs.items()} + def to(self, device:str|None): + self.graph_values = {k:v.to(device) if isinstance(v, Tensor) else v for k,v in self.graph_values.items()} + self.graph_nodes = tuple(OnnxNode(n.num, n.op, tuple(n.inputs), tuple(n.outputs), + {k:v.to(device) if isinstance(v, Tensor) else v for k,v in n.opts.items()}) for n in self.graph_nodes) + return self + def __call__(self, inputs:dict[str, Any], debug=debug): for name, input_spec in self.graph_inputs.items(): if name not in inputs: raise RuntimeError(f"Please provide input data for {name}") diff --git a/extra/onnx_helpers.py b/extra/onnx_helpers.py index d2a8bb9616..750f54d9e3 100644 --- a/extra/onnx_helpers.py +++ b/extra/onnx_helpers.py @@ -1,6 +1,6 @@ from tinygrad import Tensor from tinygrad.tensor import _to_np_dtype -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from extra.onnx import OnnxValue import numpy as np import onnxruntime as ort @@ -46,7 +46,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}): return ret def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5): - run_onnx = OnnxRunner(onnx_load(onnx_file)) + run_onnx = OnnxRunner(onnx_file) ort_options = ort.SessionOptions() ort_options.log_severity_level = 3 diff --git a/test/external/external_benchmark_openpilot.py b/test/external/external_benchmark_openpilot.py index de856b2f83..66a887d92c 100644 --- a/test/external/external_benchmark_openpilot.py +++ b/test/external/external_benchmark_openpilot.py @@ -1,24 +1,21 @@ import time, sys, hashlib from pathlib import Path -from onnx.helper import tensor_dtype_to_np_dtype -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from tinygrad import Tensor, dtypes, TinyJit from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange -from tinygrad.tensor import _from_np_dtype import numpy as np from extra.bench_log import BenchEvent, WallTimeEvent OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" if __name__ == "__main__": - onnx_model = onnx_load(onnx_path := fetch(OPENPILOT_MODEL)) - run_onnx = OnnxRunner(onnx_model) + run_onnx = OnnxRunner(fetch(OPENPILOT_MODEL)) Tensor.manual_seed(100) - input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input} - new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in input_shapes.items()} - new_inputs_junk = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in input_shapes.items()} + input_shapes = {name: spec.shape for name, spec in run_onnx.graph_inputs.items()} + input_types = {name: spec.dtype for name, spec in run_onnx.graph_inputs.items()} + new_inputs = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in input_shapes.items()} + new_inputs_junk = {k:Tensor.randn(*shp, dtype=input_types[k]).mul(8).realize() for k,shp in input_shapes.items()} new_inputs_junk_numpy = {k:v.numpy() for k,v in new_inputs_junk.items()} # benchmark diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 84f5299d7b..b29892f2d9 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -2,11 +2,11 @@ import csv, pathlib, time import numpy as np import torch torch.set_num_threads(1) -from onnx.helper import tensor_dtype_to_np_dtype import onnxruntime as ort from onnx2torch import convert -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from tinygrad.helpers import OSX, DEBUG, fetch, getenv +from tinygrad.dtype import _to_np_dtype from tinygrad import Tensor, Device, dtypes MODELS = { @@ -50,20 +50,19 @@ def benchmark_model(m, devices, validate_outs=False): CSV = {"model": m} fn = fetch(MODELS[m]) - onnx_model = onnx_load(fn) - output_names = [out.name for out in onnx_model.graph.output] - excluded = {inp.name for inp in onnx_model.graph.initializer} - input_shapes = {inp.name:tuple(x.dim_value if hasattr(x, "dim_value") and x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501 - input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded} - np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()} + runner = OnnxRunner(fn) + output_names = runner.graph_outputs + input_shapes = {name: tuple(s if isinstance(s, int) and s != 0 else 1 for s in spec.shape) for name, spec in runner.graph_inputs.items()} + input_types = {name: spec.dtype for name, spec in runner.graph_inputs.items()} + np_inputs = {k:torch.randn(shp).numpy().astype(_to_np_dtype(input_types[k])) for k,shp in input_shapes.items()} assert len(input_shapes) < 30, f"too many input shapes {len(input_shapes)}" # print input names - if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded]) + if DEBUG >= 2: print(list(runner.graph_inputs)) for device in devices: Device.DEFAULT = device inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = OnnxRunner(onnx_model) + tinygrad_model = runner.to(device) benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()}) from tinygrad.engine.jit import TinyJit @@ -107,12 +106,12 @@ def benchmark_model(m, devices, validate_outs=False): rtol, atol = 2e-3, 2e-3 # tolerance for fp16 models Device.DEFAULT = device # force half inputs to float for numerical stability when validating - # this will reply on automatic dtype promotion for converting half weights inside the graph + # this will rely on automatic dtype promotion for converting half weights inside the graph if m in half_models: inputs = {k:Tensor(inp, dtype=dtypes.float32) if inp.dtype == np.float16 else Tensor(inp) for k,inp in np_inputs.items()} else: inputs = {k:Tensor(inp) for k,inp in np_inputs.items()} - tinygrad_model = OnnxRunner(onnx_model) + tinygrad_model = runner.to(device) tinygrad_out = tinygrad_model(inputs) ort_sess = ort.InferenceSession(str(fn), ort_options, ["CPUExecutionProvider"]) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 3cca77a993..b93faae80f 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -6,12 +6,11 @@ import numpy as np from tinygrad import Tensor, Device, dtypes from tinygrad.helpers import getenv, OSX from tinygrad.device import is_dtype_supported +from tinygrad.frontend.onnx import OnnxRunner # pip3 install tabulate pytest_plugins = 'onnx.backend.test.report', -from tinygrad.frontend.onnx import OnnxRunner, onnx_load - class TinygradModel(BackendRep): def __init__(self, run_onnx, input_names): super().__init__() @@ -31,7 +30,7 @@ class TinygradBackend(Backend): net_feed_input = [x for x in input_all if x not in input_initializer] print("prepare", cls, device, net_feed_input) model = Tensor(model.SerializeToString(), device="PYTHON") - run_onnx = OnnxRunner(onnx_load(model)) + run_onnx = OnnxRunner(model) return TinygradModel(run_onnx, net_feed_input) @classmethod diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 8150995d50..d427196b96 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -4,7 +4,7 @@ from typing import Any import unittest, onnx, tempfile -from tinygrad import dtypes +from tinygrad import dtypes, Tensor from tinygrad.frontend.onnx import OnnxRunner import numpy as np from extra.onnx_helpers import validate @@ -88,7 +88,8 @@ class TestMainOnnxOps(TestOnnxOps): attributes = {"detect_negative":1, "detect_positive":1} outputs = ["y"] model = self.helper_build_model("IsInf", inputs, attributes, outputs) - outputs = OnnxRunner(model)(inputs) + runner = OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON")) + outputs = runner(inputs) assert outputs["y"].dtype is dtypes.bool def test_quantize_linear(self): @@ -203,7 +204,7 @@ class TestTrainingOnnxOps(TestOnnxOps): def _validate_training(self, op:str, onnx_fxn, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str]): model = self.helper_build_model(op, inps, opts, outs) if op == "Momentum": del opts['mode'] - runner = OnnxRunner(model) + runner = OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON")) tiny_out = runner(inps) onnx_out = onnx_fxn(**inps, **opts) for (nm, t_out), o_out in zip(tiny_out.items(), onnx_out): diff --git a/test/external/external_test_onnx_runner.py b/test/external/external_test_onnx_runner.py index d804542add..6c44847242 100644 --- a/test/external/external_test_onnx_runner.py +++ b/test/external/external_test_onnx_runner.py @@ -1,8 +1,8 @@ -import unittest, onnx, tempfile -from tinygrad import dtypes -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +import unittest, onnx +from tinygrad import dtypes, Tensor from tinygrad.device import is_dtype_supported from extra.onnx import data_types +from tinygrad.frontend.onnx import OnnxRunner from hypothesis import given, settings, strategies as st import numpy as np @@ -17,11 +17,7 @@ class TestOnnxRunnerDtypes(unittest.TestCase): node = onnx.helper.make_node('Identity', inputs=['input'], outputs=['output']) graph = onnx.helper.make_graph([node], 'identity_test', [input_tensor], [output_tensor]) model = onnx.helper.make_model(graph) - tmp = tempfile.NamedTemporaryFile(suffix='.onnx') - onnx.save(model, tmp.name) - tmp.flush() - model = onnx_load(tmp.name) - runner = OnnxRunner(model) + runner = OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON")) self.assertEqual(len(runner.graph_inputs), 1) self.assertEqual(runner.graph_inputs['input'].dtype, tinygrad_dtype) @@ -33,11 +29,7 @@ class TestOnnxRunnerDtypes(unittest.TestCase): node = onnx.helper.make_node('Identity', inputs=['input'], outputs=['output']) graph = onnx.helper.make_graph([node], 'identity_test', [input_tensor], [output_tensor], [initializer]) model = onnx.helper.make_model(graph) - tmp = tempfile.NamedTemporaryFile(suffix='.onnx') - onnx.save(model, tmp.name) - tmp.flush() - model = onnx_load(tmp.name) - runner = OnnxRunner(model) + runner = OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON")) self.assertEqual(len(runner.graph_inputs), 1) self.assertEqual(runner.graph_values['initializer'].dtype, tinygrad_dtype) @@ -48,11 +40,7 @@ class TestOnnxRunnerDtypes(unittest.TestCase): node = onnx.helper.make_node('Constant', inputs=[], outputs=['output'], value=value_tensor) graph = onnx.helper.make_graph([node], 'attribute_test', [], [output_tensor]) model = onnx.helper.make_model(graph) - tmp = tempfile.NamedTemporaryFile(suffix='.onnx') - tmp.flush() - onnx.save(model, tmp.name) - model = onnx_load(tmp.name) - runner = OnnxRunner(model) + runner = OnnxRunner(Tensor(model.SerializeToString(), device="PYTHON")) self.assertEqual(runner.graph_nodes[0].opts['value'].dtype, tinygrad_dtype) @settings(deadline=1000) # TODO investigate unreliable timing diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index da62c49ff7..432f3b20a5 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -7,7 +7,7 @@ try: import onnx except ModuleNotFoundError: raise unittest.SkipTest("onnx not installed, skipping onnx test") -from tinygrad.frontend.onnx import OnnxRunner, onnx_load +from tinygrad.frontend.onnx import OnnxRunner from tinygrad.tensor import Tensor from tinygrad.helpers import CI, fetch, temp @@ -25,7 +25,7 @@ np.random.seed(1337) class TestOnnxModel(unittest.TestCase): def test_benchmark_openpilot_model(self): - onnx_model = onnx_load(fetch(OPENPILOT_MODEL)) + onnx_model = fetch(OPENPILOT_MODEL) run_onnx = OnnxRunner(onnx_model) def get_inputs(): np_inputs = { @@ -69,7 +69,7 @@ class TestOnnxModel(unittest.TestCase): ps.print_stats(30) def test_openpilot_model(self): - onnx_model = onnx_load(fetch(OPENPILOT_MODEL)) + onnx_model = fetch(OPENPILOT_MODEL) run_onnx = OnnxRunner(onnx_model) print("got run_onnx") inputs = { @@ -121,10 +121,9 @@ class TestOnnxModel(unittest.TestCase): input_name, input_new) def _test_model(self, fn, input_name, input_new, debug=False): - onnx_model = onnx_load(fn) + run_onnx = OnnxRunner(fn) print("onnx loaded") from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS - run_onnx = OnnxRunner(onnx_model) def run(img): inputs = {input_name: preprocess(img, new=input_new)} diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index 7da8c769c6..20b3062c2d 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -70,10 +70,9 @@ class TestQuantizeOnnxCPU(unittest.TestCase): import onnx # noqa: F401 # pylint: disable=unused-import except ImportError: raise unittest.SkipTest() - from tinygrad.frontend.onnx import OnnxRunner, onnx_load + from tinygrad.frontend.onnx import OnnxRunner out_file = get_quantized_model(sz) - onnx_model = onnx_load(out_file) - run_onnx = OnnxRunner(onnx_model) + run_onnx = OnnxRunner(out_file) inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)) with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1): sched = run_onnx({"input":inp})["output"].schedule() diff --git a/tinygrad/frontend/onnx.py b/tinygrad/frontend/onnx.py index db34a70be8..2d6703bf1f 100644 --- a/tinygrad/frontend/onnx.py +++ b/tinygrad/frontend/onnx.py @@ -3,5 +3,4 @@ import sys, pathlib sys.path.append(pathlib.Path(__file__).parent.parent.as_posix()) try: from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import - from extra.onnx_parser import onnx_load # noqa: F401 # pylint: disable=unused-import except ImportError as e: raise ImportError("onnx frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e \ No newline at end of file