onnx parser (#10435)

* onnx parser

* fix compile, lint

* onnx.load -> onnx_load

* compatible with ModelProto

* fix test external_test_onnx_ops.py

* fix tests

* fix signed int

* reduce to 261 lines

* fix TypeProto.Optional

* debug for _parse_message, add TypeProto.Sequence, cleanup

* onnx_load from Tensor

* remove BufferedReader

* 174 lines and reduce tensor copy

* cleanup

* use onnx_load in external_model_benchmark.py

* fix qcom test

* [onnx] parser: support external data

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
b1tg
2025-06-10 00:44:28 +08:00
committed by GitHub
parent cfa65bea05
commit 24d328e313
13 changed files with 273 additions and 50 deletions
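
The change at every call site below is the same one-line swap: model loading moves from the onnx package to the new pure-Python parser. A minimal usage sketch (the model path here is hypothetical):

from tinygrad.frontend.onnx import OnnxRunner, onnx_load

onnx_model = onnx_load("model.onnx")  # previously: onnx.load("model.onnx")
run_onnx = OnnxRunner(onnx_model)     # OnnxRunner now also accepts the parser's SimpleNamespace tree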

View File

@@ -1,10 +1,10 @@
-import sys, onnx, time, pickle
+import sys, time, pickle
 from tinygrad import TinyJit, GlobalCounters, fetch, getenv
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from extra.onnx_helpers import get_example_inputs, validate
 def load_onnx_model(onnx_file):
-  onnx_model = onnx.load(onnx_file)
+  onnx_model = onnx_load(onnx_file)
   run_onnx = OnnxRunner(onnx_model)
   run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
   return run_onnx_jit, run_onnx.graph_inputs

View File

@@ -12,13 +12,13 @@ from tinygrad.engine.realize import CompiledRunner
 import onnx
 from onnx.helper import tensor_dtype_to_np_dtype
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
 OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
 def compile(onnx_file):
-  onnx_model = onnx.load(onnx_file)
+  onnx_model = onnx_load(onnx_file)
   run_onnx = OnnxRunner(onnx_model)
   print("loaded model")

View File

@@ -1,9 +1,8 @@
 #!/usr/bin/env python3
 import os
 from ultralytics import YOLO
-import onnx
 from pathlib import Path
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from extra.onnx_helpers import get_example_inputs
 from tinygrad.tensor import Tensor
@@ -11,6 +10,6 @@ os.chdir("/tmp")
 if not Path("yolov8n-seg.onnx").is_file():
   model = YOLO("yolov8n-seg.pt")
   model.export(format="onnx", imgsz=[480,640])
-onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
+onnx_model = onnx_load("yolov8n-seg.onnx")
 run_onnx = OnnxRunner(onnx_model)
 run_onnx(get_example_inputs(run_onnx.graph_inputs), debug=True)

View File

@@ -1,6 +1,6 @@
 import onnx, yaml, tempfile, time, collections, pprint, argparse, json
 from pathlib import Path
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from extra.onnx import get_onnx_ops
 from extra.onnx_helpers import validate, get_example_inputs
@@ -13,7 +13,7 @@ def get_config(root_path: Path):
   return ret
 def run_huggingface_validate(onnx_model_path, config, rtol, atol):
-  onnx_model = onnx.load(onnx_model_path)
+  onnx_model = onnx_load(onnx_model_path)
   onnx_runner = OnnxRunner(onnx_model)
   inputs = get_example_inputs(onnx_runner.graph_inputs, config)
   validate(onnx_model_path, inputs, rtol=rtol, atol=atol)

View File

@@ -1,14 +1,19 @@
+from types import SimpleNamespace
 from typing import Any, Sequence, cast, Literal, Callable
 import dataclasses, functools, io, math, types
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
 from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
 from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
-from tinygrad.device import is_dtype_supported
+from tinygrad.device import is_dtype_supported, Device
 # ***** protobuf parsing ******
 from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
 import numpy as np
+def has_field(onnx_type: TypeProto|SimpleNamespace, field):
+  if isinstance(onnx_type, TypeProto): return onnx_type.HasField(field)
+  return hasattr(onnx_type, field)
 def dtype_parse(onnx_dtype: int) -> DType:
   supported: dict[int, DType] = {
     TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
@@ -26,9 +31,10 @@ def dtype_parse(onnx_dtype: int) -> DType:
 def attribute_parse(onnx_attribute: AttributeProto):
   supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
     AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
-    AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
+    AttributeProto.STRING: lambda a: a.s.data().tobytes().decode("utf8") if isinstance(a.s, Tensor) else a.s.decode("utf8"),
+    AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
     AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
-    AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
+    AttributeProto.STRINGS: lambda a: tuple(x.data().tobytes().decode("utf8") for x in a.strings)
   }
   unsupported = {
     AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
@@ -41,24 +47,35 @@ def attribute_parse(onnx_attribute: AttributeProto):
 def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
   if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
   dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
-  if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
-    list(onnx_tensor.uint64_data):
-    if len(data) == 1: return Tensor(data[0], dtype=dtype).reshape(shape)
-    return Tensor(data, dtype=dtype).reshape(shape).realize()
-  if onnx_tensor.HasField("raw_data"):
-    np_buffer = np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
-    if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
-    return Tensor(np_buffer, dtype=dtype)
+  data = None
+  if len(onnx_tensor.float_data): data = onnx_tensor.float_data
+  elif len(onnx_tensor.int32_data): data = onnx_tensor.int32_data
+  elif len(onnx_tensor.int64_data): data = onnx_tensor.int64_data
+  elif len(onnx_tensor.double_data): data = onnx_tensor.double_data
+  elif len(onnx_tensor.uint64_data): data = onnx_tensor.uint64_data
+  if isinstance(data, Tensor):
+    if len(data) == 1: return Tensor(data.tolist()[0], dtype=dtype).reshape(shape)
+    return data.cast(dtype).reshape(shape).to(Device.DEFAULT)
+  if has_field(onnx_tensor, "raw_data"):
+    if onnx_tensor.data_type == TensorProto.FLOAT16:
+      np_buffer = np.frombuffer(onnx_tensor.raw_data.data().tobytes(),
+                                dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
+      if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
+      return Tensor(np_buffer, dtype=dtype)
+    ret = onnx_tensor.raw_data.bitcast(dtype).reshape(shape).to(Device.DEFAULT)
+    if shape == (): ret = Tensor(ret.item(), dtype=dtype).reshape(shape)
+    return ret
   return Tensor(None)
 def type_parse(onnx_type: TypeProto):
   elem_type = onnx_type
-  if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
+  if has_field(elem_type, "map_type") or has_field(elem_type, "sparse_tensor_type") or has_field(elem_type, "opaque_type"):
     raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
-  if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
-  if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
-  if elem_type.HasField("tensor_type"):
-    shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
+  if is_optional := has_field(elem_type, "optional_type"): elem_type = elem_type.optional_type.elem_type
+  if is_sequence := has_field(elem_type, "sequence_type"): elem_type = elem_type.sequence_type.elem_type
+  if has_field(elem_type, "tensor_type"):
+    shape = tuple(getattr(d, "dim_param", None) or getattr(d, "dim_value") for d in elem_type.tensor_type.shape.dim) \
+      if has_field(elem_type.tensor_type, "shape") else None # test_identity_sequence_cpu
     dtype = dtype_parse(elem_type.tensor_type.elem_type)
     return OnnxValue(shape, dtype, is_optional, is_sequence)
   raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
@@ -109,7 +126,7 @@ def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
 debug = int(getenv("DEBUGONNX", "0"))
 limit = int(getenv("ONNXLIMIT", "-1"))
 class OnnxRunner:
-  def __init__(self, model: ModelProto):
+  def __init__(self, model: ModelProto|SimpleNamespace):
     # parse model protobuf
     self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
    self.old_training = Tensor.training
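
The has_field helper introduced above is what lets the runner handle both real protobuf objects and the new parser's SimpleNamespace trees. A standalone sketch of the same dual-path idea (probing for HasField instead of the isinstance check, so it runs without onnx installed):

from types import SimpleNamespace

def has_field(onnx_type, field):
  # protobuf messages expose HasField(); parsed namespace trees only have plain attributes
  if hasattr(onnx_type, "HasField"): return onnx_type.HasField(field)
  return hasattr(onnx_type, field)

t = SimpleNamespace(tensor_type=SimpleNamespace(elem_type=1, shape=SimpleNamespace(dim=[])))
assert has_field(t, "tensor_type") and not has_field(t, "optional_type")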

View File

@@ -1,8 +1,7 @@
 from tinygrad import Tensor
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from extra.onnx import OnnxValue
-import onnx
 import numpy as np
 import onnxruntime as ort
@@ -47,7 +46,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
   return ret
 def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
-  run_onnx = OnnxRunner(onnx.load(onnx_file))
+  run_onnx = OnnxRunner(onnx_load(onnx_file))
   ort_options = ort.SessionOptions()
   ort_options.log_severity_level = 3

extra/onnx_parser.py (new file, 204 lines)
View File

@@ -0,0 +1,204 @@
# https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
import os, pathlib, struct
from io import BufferedReader
from typing import Union
from types import SimpleNamespace
from tinygrad.nn.state import TensorIO
from tinygrad.tensor import Tensor, dtypes

# Protobuf Wire Types
WIRETYPE_VARINT = 0; WIRETYPE_FIXED64 = 1; WIRETYPE_LENGTH_DELIMITED = 2; WIRETYPE_START_GROUP = 3; WIRETYPE_END_GROUP = 4; WIRETYPE_FIXED32 = 5 # noqa: E702

# TensorProto.DataType
class TensorDataType:
  UNDEFINED = 0; FLOAT = 1; UINT8 = 2; INT8 = 3; UINT16 = 4; INT16 = 5; INT32 = 6; INT64 = 7 # noqa: E702
  STRING = 8; BOOL = 9; FLOAT16 = 10; DOUBLE = 11; UINT32 = 12; UINT64 = 13; COMPLEX64 = 14; COMPLEX128 = 15; BFLOAT16 = 16 # noqa: E702

# AttributeProto.AttributeType
class AttributeType:
  UNDEFINED = 0; FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; GRAPH = 5; SPARSE_TENSOR = 11; TYPE_PROTO = 13; FLOATS = 6; INTS = 7 # noqa: E702
  STRINGS = 8; TENSORS = 9; GRAPHS = 10; SPARSE_TENSORS = 12; TYPE_PROTOS = 14 # noqa: E702

class PBType: FLOAT = 1; INT = 2; STRING = 3; FLOATS = 4; INTS = 5; STRINGS = 6; BYTES = 7; SUB = 8 # noqa: E702

# field_id -> (name, type[, repeated[, sub-message parser]]) for each message we need
PB_INFOS = {
  "OperatorSetIdProto": {1: ("domain", PBType.STRING), 2: ("version", PBType.INT)},
  "StringStringEntryProto": {1: ("key", PBType.STRING), 2: ("value", PBType.STRING)},
  "TensorProto": {1: ("dims", PBType.INT, True), 2: ("data_type", PBType.INT), 4: ("float_data", PBType.FLOATS),
    13: ("external_data", PBType.SUB, True, "StringStringEntryProto"), 14: ("data_location", PBType.INT),
    5: ("int32_data", PBType.INTS), 7: ("int64_data", PBType.INTS), 8: ("name", PBType.STRING), 9: ("raw_data", PBType.BYTES)},
  "TensorShapeProtoDimension": {1: ("dim_value", PBType.INT), 2: ("dim_param", PBType.STRING)},
  "TensorShapeProto": {1: ("dim", PBType.SUB, True, "TensorShapeProtoDimension")},
  "ModelProto": {1: ("ir_version", PBType.INT), 5: ("model_version", PBType.INT),
    2: ("producer_name", PBType.STRING), 3: ("producer_version", PBType.STRING), 4: ("domain", PBType.STRING), 6: ("doc_string", PBType.STRING),
    7: ("graph", PBType.SUB, False, ("GraphProto", lambda: {"node": [], "initializer": [], "input": [], "output": [], "value_info": []})),
    8: ("opset_import", PBType.SUB, True, "OperatorSetIdProto")},
  "GraphProto": {2: ("name", PBType.STRING), 10: ("doc_string", PBType.STRING),
    1: ("node", PBType.SUB, True, ("NodeProto", lambda: {"input": [], "output": [], "attribute": [], "domain": None})),
    5: ("initializer", PBType.SUB, True, ("TensorProto", lambda: {"dims": [], "float_data": [], "int32_data": [], "string_data": [],
      "int64_data": [], "double_data": [], "uint64_data": []})),
    11: ("input", PBType.SUB, True, "ValueInfoProto"), 12: ("output", PBType.SUB, True, "ValueInfoProto")},
  "NodeProto": {1: ("input", PBType.STRING, True), 2: ("output", PBType.STRING, True), 3: ("name", PBType.STRING),
    4: ("op_type", PBType.STRING), 6: ("doc_string", PBType.STRING), 7: ("domain", PBType.STRING),
    5: ("attribute", PBType.SUB, True, ("AttributeProto", lambda: {"floats": [], "ints": [], "strings": []}))},
  "AttributeProto": {1: ("name", PBType.STRING), 20: ("type", PBType.INT), 3: ("i", PBType.INT), 8: ("ints", PBType.INT, True),
    2: ("f", PBType.FLOAT), 7: ("floats", PBType.FLOAT, True), 4: ("s", PBType.BYTES), 9: ("strings", PBType.BYTES, True),
    5: ("t", PBType.SUB, False, ("TensorProto", lambda: {"dims": [], "float_data": [], "int32_data": [], "string_data": [], "int64_data": [],
      "double_data": [], "uint64_data": []}))},
  "ValueInfoProto": {1: ("name", PBType.STRING), 2: ("type", PBType.SUB, False, "TypeProto"), 3: ("doc_string", PBType.STRING)},
  "TypeProto": {1: ("tensor_type", PBType.SUB, False, "TypeProtoTensor"), 4: ("sequence_type", PBType.SUB, False, "TypeProtoSequence"),
    9: ("optional_type", PBType.SUB, False, "TypeProtoOptional"), 6: ("denotation", PBType.STRING)},
  "TypeProtoSequence": {1: ("elem_type", PBType.SUB, False, "TypeProto")},
  "TypeProtoOptional": {1: ("elem_type", PBType.SUB, False, "TypeProto")},
  "TypeProtoTensor": {1: ("elem_type", PBType.INT), 2: ("shape", PBType.SUB, False, ("TensorShapeProto", lambda: {"dim": []}))},
}

def onnx_load(fn: Union[Tensor, str, pathlib.Path], load_external_data: bool=True):
  parser = OnnxParser(fn, load_external_data)
  onnx_model = parser.parse()
  model = dict_to_namespace(onnx_model)
  return model

def gen_result(obj: dict, key_name, val, repeated: bool):
  if repeated: obj.setdefault(key_name, []).append(val)
  else: obj[key_name] = val

def dict_to_namespace(d):
  if isinstance(d, dict): return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
  elif isinstance(d, list): return [dict_to_namespace(i) for i in d]
  return d

class OnnxParser:
  def __init__(self, inp: Union[Tensor, str, pathlib.Path], load_external_data: bool=True):
    self.file_path: Union[pathlib.Path, None] = None
    self.load_external_data = load_external_data
    if not isinstance(inp, Tensor):
      self.file_path = pathlib.Path(inp)
      self.tensor = Tensor(self.file_path)
    else: self.tensor = inp
    self.attr_func_dict = {PBType.BYTES: self._handle_bytes, PBType.SUB: self._handle_sub_message, PBType.FLOATS: self._handle_packed_floats,
      PBType.INT: self._handle_int64, PBType.INTS: self._handle_packed_int64s, PBType.STRING: self._handle_string, PBType.FLOAT: self._handle_float}
    self.registered_handles = {}
    for pb_name in PB_INFOS:
      res = {}
      for fid, config in PB_INFOS[pb_name].items():
        parser_fn, repeated = None, False
        if len(config) == 2: name, attr = config
        elif len(config) == 3: name, attr, repeated = config
        elif len(config) == 4: name, attr, repeated, parser_fn = config
        handler_fn = self.attr_func_dict[attr]
        def _wrapper_handler(obj, reader, wt, h=handler_fn, n=name, p=parser_fn, r=repeated): return h(obj, n, reader, wt, parser_func=p, repeated=r)
        _wrapper_handler._debug_info = f"{fid}, {name} => {handler_fn}"
        res[fid] = _wrapper_handler
      self.registered_handles[pb_name] = res
  def parse(self):
    reader = BufferedReader(TensorIO(self.tensor))
    return self._parse_message(reader, "ModelProto", lambda: {"opset_import": [], "domain": None, "graph": None})
  def decode_varint(self, reader: BufferedReader) -> int:
    result = 0
    shift = 0
    while True:
      data = reader.read(1)
      if data == b"": raise EOFError("decode_varint EOF")
      result |= (data[0] & 0x7F) << shift
      if not (data[0] & 0x80): return result
      shift += 7
      if shift >= 70: raise ValueError("Varint too long")
  def skip_field_value(self, reader: BufferedReader, wire_type):
    if wire_type == WIRETYPE_VARINT: self.decode_varint(reader)
    elif wire_type == WIRETYPE_FIXED64: reader.seek(8, os.SEEK_CUR)
    elif wire_type == WIRETYPE_FIXED32: reader.seek(4, os.SEEK_CUR)
    elif wire_type == WIRETYPE_LENGTH_DELIMITED: reader.seek(self.decode_varint(reader), os.SEEK_CUR)
    else: raise ValueError(f"Unknown wire type: {wire_type}")
  def _parse_message(self, reader, message_field_handlers_name, initial_obj_factory=lambda: {}):
    message_field_handlers = self.registered_handles[message_field_handlers_name]
    obj = initial_obj_factory()
    while True:
      try:
        tag_val = self.decode_varint(reader)
        field_number = tag_val >> 3
        wire_type = tag_val & 0x07
        if handler := message_field_handlers.get(field_number):
          handler(obj, reader, wire_type)
        else: self.skip_field_value(reader, wire_type)
      except EOFError: break
    if message_field_handlers_name == "TensorProto" and self.load_external_data and obj.get("data_location", 0) == 1: self._parse_external_data(obj)
    return obj
  def _handle_delimited(self, reader: BufferedReader, use_tensor=False) -> Union[bytes, Tensor]:
    str_len = self.decode_varint(reader)
    if not use_tensor: return reader.read(str_len)
    res = reader.raw._tensor[reader.tell():(reader.tell()+str_len)]
    reader.seek(str_len, os.SEEK_CUR)
    return res
  def _handle_string(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for string field '{key_name}'")
    value = self._handle_delimited(reader)
    gen_result(obj, key_name, value.decode("utf-8"), repeated)
  def _handle_bytes(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for bytes field '{key_name}'")
    value = self._handle_delimited(reader, use_tensor=True)
    gen_result(obj, key_name, value, repeated)
  def _handle_int64(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_VARINT: raise ValueError(f"Expected varint for int64 field '{key_name}'")
    val = self.decode_varint(reader)
    gen_result(obj, key_name, val - 2**64 if val & (1 << 63) else val, repeated)
  def _handle_float(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_FIXED32: raise ValueError(f"Expected fixed32 for float field '{key_name}'")
    val, = struct.unpack("<f", reader.read(4))
    gen_result(obj, key_name, val, repeated)
  def _handle_packed_int64s(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError("Packed int64s expected length_delimited")
    total_bytes_len = self.decode_varint(reader)
    old_pos = reader.tell()
    values = []
    while reader.tell() < total_bytes_len + old_pos:
      val = self.decode_varint(reader) # packed ints are varints, so they must be decoded one by one (no zero-copy slice like floats)
      values.append(val - 2**64 if val & (1 << 63) else val)
    obj[key_name] = Tensor(values, dtype=dtypes.int64)
  def _handle_packed_floats(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError("Packed floats expected length_delimited")
    value = self._handle_delimited(reader, use_tensor=True)
    obj[key_name] = value.bitcast(dtypes.float32)
  def _handle_sub_message(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
    if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for sub-message field '{key_name}'")
    value = self._handle_delimited(reader, use_tensor=True)
    if isinstance(parser_func, str): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func)
    elif isinstance(parser_func, tuple): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func[0], parser_func[1])
    else: sub_obj = parser_func(BufferedReader(TensorIO(value)))
    gen_result(obj, key_name, sub_obj, repeated)
  def _parse_external_data(self, obj):
    if "external_data" not in obj: raise ValueError("no external_data")
    location = None
    length = None
    offset = 0
    for kv in obj["external_data"]:
      if kv["key"] == "location": location = kv["value"]
      if kv["key"] == "offset": offset = int(kv["value"])
      if kv["key"] == "length": length = int(kv["value"])
    if location is None: raise ValueError("no location in external_data")
    if self.file_path is None:
      # recover the onnx file path from a DISK-backed Tensor
      if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
        self.file_path = pathlib.Path(self.tensor.device[5:])
        if not (ext_path := self.file_path.parent.joinpath(location)).exists():
          raise Exception(f"external location does not exist: {ext_path}; this may be caused by a symbolic link, try passing the onnx file path to onnx_load")
      else: raise Exception("onnx external_data needs the original file path, try passing the onnx file path to onnx_load")
    ext_path = self.file_path.parent.joinpath(location)
    if not ext_path.exists(): raise Exception(f"external location does not exist: {ext_path}")
    ext_tensor = Tensor(ext_path)
    obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:]
    obj["data_location"] = 0

View File

@@ -1,8 +1,7 @@
 import time, sys, hashlib
 from pathlib import Path
-import onnx
 from onnx.helper import tensor_dtype_to_np_dtype
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from tinygrad import Tensor, dtypes, TinyJit
 from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
 from tinygrad.tensor import _from_np_dtype
@@ -12,7 +11,7 @@ from extra.bench_log import BenchEvent, WallTimeEvent
 OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
 if __name__ == "__main__":
-  onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL))
+  onnx_model = onnx_load(onnx_path := fetch(OPENPILOT_MODEL))
   run_onnx = OnnxRunner(onnx_model)
   Tensor.manual_seed(100)

View File

@@ -2,11 +2,10 @@ import csv, pathlib, time
 import numpy as np
 import torch
 torch.set_num_threads(1)
-import onnx
 from onnx.helper import tensor_dtype_to_np_dtype
 import onnxruntime as ort
 from onnx2torch import convert
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from tinygrad.helpers import OSX, DEBUG, fetch, getenv
 from tinygrad import Tensor, Device
@@ -50,10 +49,10 @@ def benchmark_model(m, devices, validate_outs=False):
   CSV = {"model": m}
   fn = fetch(MODELS[m])
-  onnx_model = onnx.load(fn)
+  onnx_model = onnx_load(fn)
   output_names = [out.name for out in onnx_model.graph.output]
   excluded = {inp.name for inp in onnx_model.graph.initializer}
-  input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
+  input_shapes = {inp.name:tuple(x.dim_value if hasattr(x, "dim_value") and x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
   input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
   #input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
   np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
@@ -75,7 +74,7 @@ def benchmark_model(m, devices, validate_outs=False):
   # convert model to torch
   try:
-    torch_model = convert(onnx_model)
+    torch_model = convert(fn)
   except Exception as e:
     # model conversion failed
     print(f"{m:16s}onnx2torch {type(e).__name__:>25}")

View File

@@ -1,4 +1,4 @@
-import unittest
+import tempfile, unittest
 from typing import Any, Tuple
 from onnx.backend.base import Backend, BackendRep
 import onnx.backend.test
@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
 # pip3 install tabulate
 pytest_plugins = 'onnx.backend.test.report',
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 class TinygradModel(BackendRep):
   def __init__(self, run_onnx, input_names):
@@ -25,12 +25,15 @@ class TinygradModel(BackendRep):
 class TinygradBackend(Backend):
   @classmethod
-  def prepare(cls, model, device):
+  def prepare(cls, model: onnx.ModelProto, device):
     input_all = [x.name for x in model.graph.input]
     input_initializer = [x.name for x in model.graph.initializer]
     net_feed_input = [x for x in input_all if x not in input_initializer]
     print("prepare", cls, device, net_feed_input)
-    run_onnx = OnnxRunner(model)
+    with tempfile.NamedTemporaryFile(suffix='.onnx') as f:
+      onnx.save(model, f.name)
+      new_model = onnx_load(f.name)
+      run_onnx = OnnxRunner(new_model)
     return TinygradModel(run_onnx, net_feed_input)
   @classmethod
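
The save-then-reload in prepare() above doubles as a round-trip test of the new parser. A standalone sketch of the same check (model built with onnx.helper; field names follow the PB_INFOS table in extra/onnx_parser.py):

import tempfile, onnx
from onnx import helper, TensorProto
from extra.onnx_parser import onnx_load

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])
model = helper.make_model(helper.make_graph([helper.make_node("Relu", ["x"], ["y"])], "g", [x], [y]))
with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
  onnx.save(model, f.name)
  reparsed = onnx_load(f.name) # a SimpleNamespace tree, not a ModelProto
  assert reparsed.graph.node[0].op_type == "Relu"
  assert reparsed.graph.input[0].name == "x"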

View File

@@ -7,7 +7,7 @@ try:
   import onnx
 except ModuleNotFoundError:
   raise unittest.SkipTest("onnx not installed, skipping onnx test")
-from tinygrad.frontend.onnx import OnnxRunner
+from tinygrad.frontend.onnx import OnnxRunner, onnx_load
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import CI, fetch, temp
@@ -25,7 +25,7 @@ np.random.seed(1337)
 class TestOnnxModel(unittest.TestCase):
   def test_benchmark_openpilot_model(self):
-    onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
+    onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
     run_onnx = OnnxRunner(onnx_model)
     def get_inputs():
       np_inputs = {
@@ -69,7 +69,7 @@ class TestOnnxModel(unittest.TestCase):
     ps.print_stats(30)
   def test_openpilot_model(self):
-    onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
+    onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
     run_onnx = OnnxRunner(onnx_model)
     print("got run_onnx")
     inputs = {
@@ -93,6 +93,7 @@ class TestOnnxModel(unittest.TestCase):
     et = time.monotonic()
     print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
+    onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
     torch_out = run_onnx_torch(onnx_model, inputs).numpy()
     print(tinygrad_out, torch_out)
     np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2)
@@ -119,7 +120,7 @@ class TestOnnxModel(unittest.TestCase):
       input_name, input_new)
   def _test_model(self, fn, input_name, input_new, debug=False):
-    onnx_model = onnx.load(fn)
+    onnx_model = onnx_load(fn)
     print("onnx loaded")
     from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
     run_onnx = OnnxRunner(onnx_model)

View File

@@ -67,12 +67,12 @@ def get_quantized_model(sz):
 class TestQuantizeOnnxCPU(unittest.TestCase):
   def test_quant_128(self, sz=128):
     try:
-      import onnx
+      import onnx # noqa: F401 # pylint: disable=unused-import
     except ImportError:
       raise unittest.SkipTest()
-    from tinygrad.frontend.onnx import OnnxRunner
+    from tinygrad.frontend.onnx import OnnxRunner, onnx_load
     out_file = get_quantized_model(sz)
-    onnx_model = onnx.load(out_file)
+    onnx_model = onnx_load(out_file)
     run_onnx = OnnxRunner(onnx_model)
     inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
     with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):

View File

@@ -1,5 +1,7 @@
 # type: ignore
 import sys, pathlib
 sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
-try: from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import
+try:
+  from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import
+  from extra.onnx_parser import onnx_load # noqa: F401 # pylint: disable=unused-import
 except ImportError as e: raise ImportError("onnx frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e