mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add onnx frontend stub [pr] (#9558)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import sys, onnx, time
|
||||
from tinygrad import TinyJit, Device, GlobalCounters, fetch, getenv
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx_helpers import get_example_inputs, validate
|
||||
|
||||
def load_onnx_model(onnx_file):
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
import subprocess
|
||||
import tensorflow as tf
|
||||
import tf2onnx
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from extra.export_model import export_model_clang, compile_net, jit_model
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from tinygrad.engine.realize import CompiledRunner
|
||||
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from extra.onnx import OnnxRunner # TODO: port to main tinygrad
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
|
||||
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
|
||||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from ultralytics import YOLO
|
||||
import onnx
|
||||
from pathlib import Path
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx_helpers import get_example_inputs
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
|
||||
from pathlib import Path
|
||||
from extra.onnx import OnnxRunner, get_onnx_ops
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx import get_onnx_ops
|
||||
from extra.onnx_helpers import validate, get_example_inputs
|
||||
|
||||
def get_config(root_path: Path):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from extra.onnx import OnnxRunner, OnnxValue
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from extra.onnx import OnnxValue
|
||||
import onnx
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
@@ -2,7 +2,7 @@ import time, sys, hashlib
|
||||
from pathlib import Path
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad import Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
|
||||
2
test/external/external_model_benchmark.py
vendored
2
test/external/external_model_benchmark.py
vendored
@@ -6,7 +6,7 @@ import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
import onnxruntime as ort
|
||||
from onnx2torch import convert
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.helpers import OSX, DEBUG, fetch
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.device import CompileError
|
||||
|
||||
2
test/external/external_test_onnx_backend.py
vendored
2
test/external/external_test_onnx_backend.py
vendored
@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
|
||||
# pip3 install tabulate
|
||||
pytest_plugins = 'onnx.backend.test.report',
|
||||
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
|
||||
class TinygradModel(BackendRep):
|
||||
def __init__(self, run_onnx, input_names):
|
||||
|
||||
@@ -7,7 +7,7 @@ try:
|
||||
import onnx
|
||||
except ModuleNotFoundError:
|
||||
raise unittest.SkipTest("onnx not installed, skipping onnx test")
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI, fetch, temp
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
|
||||
import onnx
|
||||
except ImportError:
|
||||
raise unittest.SkipTest()
|
||||
from extra.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
out_file = get_quantized_model(sz)
|
||||
onnx_model = onnx.load(out_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
5
tinygrad/frontend/onnx.py
Normal file
5
tinygrad/frontend/onnx.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# type: ignore
|
||||
import sys, pathlib
|
||||
sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
|
||||
try: from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import
|
||||
except ImportError as e: raise ImportError("onnx frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e
|
||||
Reference in New Issue
Block a user