mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
onnx parser (#10435)
* onnx parser * fix compile, lint * onnx.load -> onnx_load * compatible with ModelProto * fix test external_test_onnx_ops.py * fix tests * fix signed int * reduce to 261 lines * fix TypeProto.Optional * debug for _parse_message, add TypeProto.Sequence, cleanup * onnx_load from Tensor * remove BufferedReader * 174 lines and reduce tensor copy * cleanup * use onnx_load in external_model_benchmark.py * fix qcom test * [onnx] parser support external data --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import sys, onnx, time, pickle
|
||||
import sys, time, pickle
|
||||
from tinygrad import TinyJit, GlobalCounters, fetch, getenv
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from extra.onnx_helpers import get_example_inputs, validate
|
||||
|
||||
def load_onnx_model(onnx_file):
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
onnx_model = onnx_load(onnx_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: next(iter(run_onnx({k:v.to(None) for k,v in kwargs.items()}).values())), prune=True, optimize=True)
|
||||
return run_onnx_jit, run_onnx.graph_inputs
|
||||
|
||||
@@ -12,13 +12,13 @@ from tinygrad.engine.realize import CompiledRunner
|
||||
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
|
||||
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
|
||||
OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"
|
||||
|
||||
def compile(onnx_file):
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
onnx_model = onnx_load(onnx_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
print("loaded model")
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import onnx
|
||||
from pathlib import Path
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from extra.onnx_helpers import get_example_inputs
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
@@ -11,6 +10,6 @@ os.chdir("/tmp")
|
||||
if not Path("yolov8n-seg.onnx").is_file():
|
||||
model = YOLO("yolov8n-seg.pt")
|
||||
model.export(format="onnx", imgsz=[480,640])
|
||||
onnx_model = onnx.load(open("yolov8n-seg.onnx", "rb"))
|
||||
onnx_model = onnx_load(open("yolov8n-seg.onnx", "rb"))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
run_onnx(get_example_inputs(run_onnx.graph_inputs), debug=True)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import onnx, yaml, tempfile, time, collections, pprint, argparse, json
|
||||
from pathlib import Path
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from extra.onnx import get_onnx_ops
|
||||
from extra.onnx_helpers import validate, get_example_inputs
|
||||
|
||||
@@ -13,7 +13,7 @@ def get_config(root_path: Path):
|
||||
return ret
|
||||
|
||||
def run_huggingface_validate(onnx_model_path, config, rtol, atol):
|
||||
onnx_model = onnx.load(onnx_model_path)
|
||||
onnx_model = onnx_load(onnx_model_path)
|
||||
onnx_runner = OnnxRunner(onnx_model)
|
||||
inputs = get_example_inputs(onnx_runner.graph_inputs, config)
|
||||
validate(onnx_model_path, inputs, rtol=rtol, atol=atol)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Sequence, cast, Literal, Callable
|
||||
import dataclasses, functools, io, math, types
|
||||
from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
|
||||
from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
|
||||
from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.device import is_dtype_supported, Device
|
||||
|
||||
# ***** protobuf parsing ******
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto, helper
|
||||
import numpy as np
|
||||
|
||||
def has_field(onnx_type: TypeProto|SimpleNamespace, field):
|
||||
if isinstance(onnx_type, TypeProto): return onnx_type.HasField(field)
|
||||
return hasattr(onnx_type, field)
|
||||
|
||||
def dtype_parse(onnx_dtype: int) -> DType:
|
||||
supported: dict[int, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
|
||||
@@ -26,9 +31,10 @@ def dtype_parse(onnx_dtype: int) -> DType:
|
||||
def attribute_parse(onnx_attribute: AttributeProto):
|
||||
supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
|
||||
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
|
||||
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
|
||||
AttributeProto.STRING: lambda a: a.s.data().tobytes().decode("utf8") if isinstance(a.s, Tensor) else a.s.decode("utf8"),
|
||||
AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
|
||||
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
|
||||
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
|
||||
AttributeProto.STRINGS: lambda a: tuple(x.data().tobytes().decode("utf8") for x in a.strings)
|
||||
}
|
||||
unsupported = {
|
||||
AttributeProto.UNDEFINED, AttributeProto.GRAPH, AttributeProto.SPARSE_TENSOR, AttributeProto.TYPE_PROTO, AttributeProto.TENSORS,
|
||||
@@ -41,24 +47,35 @@ def attribute_parse(onnx_attribute: AttributeProto):
|
||||
def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
|
||||
if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
|
||||
dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
|
||||
if data := list(onnx_tensor.float_data) or list(onnx_tensor.int32_data) or list(onnx_tensor.int64_data) or list(onnx_tensor.double_data) or \
|
||||
list(onnx_tensor.uint64_data):
|
||||
if len(data) == 1: return Tensor(data[0], dtype=dtype).reshape(shape)
|
||||
return Tensor(data, dtype=dtype).reshape(shape).realize()
|
||||
if onnx_tensor.HasField("raw_data"):
|
||||
np_buffer = np.frombuffer(onnx_tensor.raw_data, dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
|
||||
if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
|
||||
return Tensor(np_buffer, dtype=dtype)
|
||||
data = None
|
||||
if len(onnx_tensor.float_data): data = onnx_tensor.float_data
|
||||
elif len(onnx_tensor.int32_data): data = onnx_tensor.int32_data
|
||||
elif len(onnx_tensor.int64_data): data = onnx_tensor.int64_data
|
||||
elif len(onnx_tensor.double_data): data = onnx_tensor.double_data
|
||||
elif len(onnx_tensor.uint64_data): data = onnx_tensor.uint64_data
|
||||
if isinstance(data, Tensor):
|
||||
if len(data) == 1: return Tensor(data.tolist()[0], dtype=dtype).reshape(shape)
|
||||
return data.cast(dtype).reshape(shape).to(Device.DEFAULT)
|
||||
if has_field(onnx_tensor, "raw_data"):
|
||||
if onnx_tensor.data_type == TensorProto.FLOAT16:
|
||||
np_buffer = np.frombuffer(onnx_tensor.raw_data.data().tobytes(),
|
||||
dtype=helper.tensor_dtype_to_np_dtype(onnx_tensor.data_type)).copy().reshape(shape)
|
||||
if np_buffer.size == 1: return Tensor(np_buffer.item(), dtype=dtype).reshape(shape)
|
||||
return Tensor(np_buffer, dtype=dtype)
|
||||
ret = onnx_tensor.raw_data.bitcast(dtype).reshape(shape).to(Device.DEFAULT)
|
||||
if shape == (): ret = Tensor(ret.item(), dtype=dtype).reshape(shape)
|
||||
return ret
|
||||
return Tensor(None)
|
||||
|
||||
def type_parse(onnx_type: TypeProto):
|
||||
elem_type = onnx_type
|
||||
if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"):
|
||||
if has_field(elem_type, "map_type") or has_field(elem_type, "sparse_tensor_type") or has_field(elem_type, "opaque_type"):
|
||||
raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented")
|
||||
if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type
|
||||
if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type
|
||||
if elem_type.HasField("tensor_type"):
|
||||
shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim)
|
||||
if is_optional := has_field(elem_type, "optional_type"): elem_type = elem_type.optional_type.elem_type
|
||||
if is_sequence := has_field(elem_type, "sequence_type"): elem_type = elem_type.sequence_type.elem_type
|
||||
if has_field(elem_type, "tensor_type"):
|
||||
shape = tuple(getattr(d, "dim_param", None) or getattr(d, "dim_value") for d in elem_type.tensor_type.shape.dim) \
|
||||
if has_field(elem_type.tensor_type, "shape") else None # test_identity_sequence_cpu
|
||||
dtype = dtype_parse(elem_type.tensor_type.elem_type)
|
||||
return OnnxValue(shape, dtype, is_optional, is_sequence)
|
||||
raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
|
||||
@@ -109,7 +126,7 @@ def to_python_const(t:Any, op:str, idx:int) -> list[ConstType]|ConstType|bytes:
|
||||
debug = int(getenv("DEBUGONNX", "0"))
|
||||
limit = int(getenv("ONNXLIMIT", "-1"))
|
||||
class OnnxRunner:
|
||||
def __init__(self, model: ModelProto):
|
||||
def __init__(self, model: ModelProto|SimpleNamespace):
|
||||
# parse model protobuf
|
||||
self.is_training = any(n.domain in {"ai.onnx.training", "ai.onnx.preview.training"} for n in model.graph.node)
|
||||
self.old_training = Tensor.training
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from extra.onnx import OnnxValue
|
||||
import onnx
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
@@ -47,7 +46,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}):
|
||||
return ret
|
||||
|
||||
def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
|
||||
run_onnx = OnnxRunner(onnx.load(onnx_file))
|
||||
run_onnx = OnnxRunner(onnx_load(onnx_file))
|
||||
|
||||
ort_options = ort.SessionOptions()
|
||||
ort_options.log_severity_level = 3
|
||||
|
||||
204
extra/onnx_parser.py
Normal file
204
extra/onnx_parser.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
|
||||
|
||||
import os, pathlib, struct
|
||||
from io import BufferedReader
|
||||
from typing import Tuple, Union
|
||||
from types import SimpleNamespace
|
||||
from tinygrad.nn.state import TensorIO
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
|
||||
# Protobuf Wire Types
|
||||
WIRETYPE_VARINT = 0; WIRETYPE_FIXED64 = 1; WIRETYPE_LENGTH_DELIMITED = 2; WIRETYPE_START_GROUP = 3; WIRETYPE_END_GROUP = 4; WIRETYPE_FIXED32 = 5 # noqa: E702
|
||||
|
||||
# TensorProto.DataType
|
||||
class TensorDataType:
|
||||
UNDEFINED = 0; FLOAT = 1; UINT8 = 2; INT8 = 3; UINT16 = 4; INT16 = 5; INT32 = 6; INT64 = 7 # noqa: E702
|
||||
STRING = 8; BOOL = 9; FLOAT16 = 10; DOUBLE = 11; UINT32 = 12; UINT64 = 13; COMPLEX64 = 14; COMPLEX128 = 15; BFLOAT16 = 16 # noqa: E702
|
||||
|
||||
# AttributeProto.AttributeType
|
||||
class AttributeType:
|
||||
UNDEFINED = 0; FLOAT = 1; INT = 2; STRING = 3; TENSOR = 4; GRAPH = 5; SPARSE_TENSOR = 11; TYPE_PROTO = 13; FLOATS = 6; INTS = 7 # noqa: E702
|
||||
STRINGS = 8; TENSORS = 9; GRAPHS = 10; SPARSE_TENSORS = 12; TYPE_PROTOS = 14 # noqa: E702
|
||||
|
||||
class PBType: FLOAT = 1; INT = 2; STRING = 3; FLOATS = 4; INTS = 5; STRINGS = 6; BYTES = 7; SUB = 8 # noqa: E702
|
||||
|
||||
PB_INFOS = {
|
||||
"OperatorSetIdProto": {1: ("domain", PBType.STRING), 2: ("version", PBType.INT)},
|
||||
"StringStringEntryProto": {1: ("key", PBType.STRING), 2: ("value", PBType.STRING)},
|
||||
"TensorProto": {1: ("dims", PBType.INT, True), 2: ("data_type", PBType.INT), 4: ("float_data", PBType.FLOATS),
|
||||
13: ("external_data", PBType.SUB, True, "StringStringEntryProto"), 14: ("data_location", PBType.INT),
|
||||
5: ("int32_data", PBType.INTS), 7: ("int64_data", PBType.INTS), 8: ("name", PBType.STRING), 9: ("raw_data", PBType.BYTES)},
|
||||
"TensorShapeProtoDimension": {1: ("dim_value", PBType.INT), 2: ("dim_param", PBType.STRING)},
|
||||
"TensorShapeProto": {1: ("dim", PBType.SUB, True, "TensorShapeProtoDimension")},
|
||||
"ModelProto": {1: ("ir_version", PBType.INT), 5: ("model_version", PBType.INT),
|
||||
2: ("producer_name", PBType.STRING), 3: ("producer_version", PBType.STRING), 4: ("domain", PBType.STRING), 6: ("doc_string", PBType.STRING),
|
||||
7: ("graph", PBType.SUB, False, ("GraphProto", lambda: {"node": [], "initializer": [], "input": [], "output": [], "value_info": []})),
|
||||
8: ("opset_import",PBType.SUB, True, "OperatorSetIdProto")},
|
||||
"GraphProto": {2: ("name", PBType.STRING), 10: ("doc_string", PBType.STRING),
|
||||
1: ("node", PBType.SUB, True, ("NodeProto", lambda: {"input": [], "output": [], "attribute": [], "domain": None})),
|
||||
5: ("initializer", PBType.SUB, True, ("TensorProto", lambda: {"dims": [], "float_data": [], "int32_data": [], "string_data": [],
|
||||
"int64_data": [], "double_data": [], "uint64_data": []})),
|
||||
11: ("input", PBType.SUB, True, "ValueInfoProto"), 12: ("output", PBType.SUB, True, "ValueInfoProto")},
|
||||
"NodeProto": { 1: ("input", PBType.STRING, True), 2: ("output", PBType.STRING, True), 3: ("name", PBType.STRING),
|
||||
4: ("op_type", PBType.STRING), 6: ("doc_string", PBType.STRING), 7: ("domain", PBType.STRING),
|
||||
5: ("attribute", PBType.SUB, True, ("AttributeProto", lambda: {"floats": [], "ints": [], "strings": []}))},
|
||||
"AttributeProto": {1: ("name", PBType.STRING), 20: ("type", PBType.INT), 3: ("i", PBType.INT), 8: ("ints", PBType.INT, True),
|
||||
2: ("f", PBType.FLOAT), 7: ("floats", PBType.FLOAT, True), 4: ("s", PBType.BYTES), 9: ("strings", PBType.BYTES, True),
|
||||
5:("t", PBType.SUB, False, ("TensorProto", lambda: {"dims": [], "float_data": [], "int32_data": [], "string_data": [], "int64_data": [],
|
||||
"double_data": [], "uint64_data": []}))},
|
||||
"ValueInfoProto": {1: ("name", PBType.STRING), 2: ("type", PBType.SUB, False, "TypeProto"), 3: ("doc_string", PBType.STRING)},
|
||||
"TypeProto": {1: ("tensor_type", PBType.SUB, False, "TypeProtoTensor"), 4: ("sequence_type", PBType.SUB, False, "TypeProtoSequence"),
|
||||
9: ("optional_type", PBType.SUB, False, "TypeProtoOptional"), 6: ("denotation", PBType.STRING)},
|
||||
"TypeProtoSequence": {1: ("elem_type", PBType.SUB, False, "TypeProto")},
|
||||
"TypeProtoOptional": {1: ("elem_type", PBType.SUB, False, "TypeProto")},
|
||||
"TypeProtoTensor": {1: ("elem_type", PBType.INT), 2: ("shape", PBType.SUB, False, ("TensorShapeProto", lambda: {"dim": []}))},
|
||||
}
|
||||
|
||||
def onnx_load(fn: Union[Tensor, str, pathlib.Path], load_external_data: bool=True):
|
||||
parser = OnnxParser(fn, load_external_data)
|
||||
onnx_model = parser.parse()
|
||||
model = dict_to_namespace(onnx_model)
|
||||
return model
|
||||
|
||||
def gen_result(obj: dict, key_name, val, repeated: bool):
|
||||
if repeated: obj.setdefault(key_name, []).append(val)
|
||||
else: obj[key_name] = val
|
||||
|
||||
def dict_to_namespace(d):
|
||||
if isinstance(d, dict): return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
|
||||
elif isinstance(d, list): return [dict_to_namespace(i) for i in d]
|
||||
return d
|
||||
|
||||
class OnnxParser:
|
||||
def __init__(self, inp: Union[Tensor, str, pathlib.Path], load_external_data: bool=True):
|
||||
self.file_path: Union[pathlib.Path, None] = None
|
||||
self.load_external_data = load_external_data
|
||||
if not isinstance(inp, Tensor):
|
||||
self.file_path = pathlib.Path(inp)
|
||||
self.tensor = Tensor(self.file_path)
|
||||
else: self.tensor = inp
|
||||
self.attr_func_dict = { PBType.BYTES: self._handle_bytes, PBType.SUB: self._handle_sub_message, PBType.FLOATS: self._handle_packed_floats,
|
||||
PBType.INT: self._handle_int64, PBType.INTS: self._handle_packed_int64s, PBType.STRING: self._handle_string, PBType.FLOAT: self._handle_float}
|
||||
self.registered_handles = {}
|
||||
for pb_name in PB_INFOS:
|
||||
res = {}
|
||||
for fid, config in PB_INFOS[pb_name].items():
|
||||
parser_fn, repeated = None, False
|
||||
if len(config) == 2: name, attr = config
|
||||
elif len(config) == 3: name, attr, repeated = config
|
||||
elif len(config) == 4: name, attr, repeated, parser_fn = config
|
||||
handler_fn = self.attr_func_dict[attr]
|
||||
def _wrapper_handler(obj, reader, wt, h=handler_fn, n=name, p=parser_fn, r=repeated): return h(obj, n, reader, wt, parser_func=p, repeated=r)
|
||||
_wrapper_handler._debug_info = f"{fid}, {name} => {handler_fn}"
|
||||
res[fid] = _wrapper_handler
|
||||
self.registered_handles[pb_name] = res
|
||||
|
||||
def parse(self):
|
||||
reader = BufferedReader(TensorIO(self.tensor))
|
||||
return self._parse_message(reader, "ModelProto", lambda: {"opset_import": [], "domain": None, "graph": None})
|
||||
|
||||
def decode_varint(self, reader: BufferedReader) -> int:
|
||||
result = 0
|
||||
shift = 0
|
||||
while True:
|
||||
data = reader.read(1)
|
||||
if data == b"": raise EOFError("decode_varint EOF")
|
||||
result |= (data[0] & 0x7F) << shift
|
||||
if not (data[0] & 0x80): return result
|
||||
shift += 7
|
||||
if shift >= 70: raise ValueError("Varint too long")
|
||||
|
||||
def skip_field_value(self, reader: BufferedReader, wire_type):
|
||||
if wire_type == WIRETYPE_VARINT: self.decode_varint(reader)
|
||||
elif wire_type == WIRETYPE_FIXED64: reader.seek(8, os.SEEK_CUR)
|
||||
elif wire_type == WIRETYPE_FIXED32: reader.seek(4, os.SEEK_CUR)
|
||||
elif wire_type == WIRETYPE_LENGTH_DELIMITED: reader.seek(self.decode_varint(reader), os.SEEK_CUR)
|
||||
else: raise ValueError(f"Unknown wire type: {wire_type}")
|
||||
|
||||
def _parse_message(self, reader, message_field_handlers_name, initial_obj_factory=lambda: {}):
|
||||
message_field_handlers = self.registered_handles[message_field_handlers_name]
|
||||
obj = initial_obj_factory()
|
||||
while True:
|
||||
try:
|
||||
tag_val = self.decode_varint(reader)
|
||||
field_number = tag_val >> 3
|
||||
wire_type = tag_val & 0x07
|
||||
if handler := message_field_handlers.get(field_number):
|
||||
handler(obj, reader, wire_type)
|
||||
else: self.skip_field_value(reader, wire_type)
|
||||
except EOFError: break
|
||||
if message_field_handlers_name == "TensorProto" and self.load_external_data and obj.get("data_location", 0) == 1: self._parse_external_data(obj)
|
||||
return obj
|
||||
|
||||
def _handle_delimited(self, reader:BufferedReader, use_tensor=False) -> Tuple[bytes, Tensor]:
|
||||
str_len = self.decode_varint(reader)
|
||||
if not use_tensor: return reader.read(str_len)
|
||||
res = reader.raw._tensor[reader.tell():(reader.tell()+str_len)]
|
||||
reader.seek(str_len, os.SEEK_CUR)
|
||||
return res
|
||||
|
||||
def _handle_string(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for string field '{key_name}'")
|
||||
value = self._handle_delimited(reader)
|
||||
gen_result(obj, key_name, value.decode("utf-8"), repeated)
|
||||
|
||||
def _handle_bytes(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for bytes field '{key_name}'")
|
||||
value = self._handle_delimited(reader, use_tensor=True)
|
||||
gen_result(obj, key_name, value, repeated)
|
||||
|
||||
def _handle_int64(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_VARINT: raise ValueError(f"Expected varint for int64 field '{key_name}'")
|
||||
val = self.decode_varint(reader)
|
||||
gen_result(obj, key_name, val - 2**64 if val & (1 << 63) else val, repeated)
|
||||
|
||||
def _handle_float(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_FIXED32: raise ValueError(f"Expected fixed32 for float field '{key_name}'")
|
||||
val, = struct.unpack("<f", reader.read(4))
|
||||
gen_result(obj, key_name, val, repeated)
|
||||
|
||||
def _handle_packed_int64s(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError("Packed int64s expected length_delimited")
|
||||
total_bytes_len = self.decode_varint(reader)
|
||||
old_pos = reader.tell()
|
||||
values = []
|
||||
while reader.tell() < total_bytes_len + old_pos:
|
||||
val = self.decode_varint(reader) # need copy here because packed ints are varint
|
||||
values.append(val - 2**64 if val & (1 << 63) else val)
|
||||
obj[key_name] = Tensor(values, dtype=dtypes.int64)
|
||||
|
||||
def _handle_packed_floats(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError("Packed floats expected length_delimited")
|
||||
value = self._handle_delimited(reader, use_tensor=True)
|
||||
obj[key_name] = value.bitcast(dtypes.float32)
|
||||
|
||||
def _handle_sub_message(self, obj, key_name, reader, wire_type, parser_func=None, repeated=False):
|
||||
if wire_type != WIRETYPE_LENGTH_DELIMITED: raise ValueError(f"Expected length-delimited for sub-message field '{key_name}'")
|
||||
value = self._handle_delimited(reader, use_tensor=True)
|
||||
if isinstance(parser_func, str): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func)
|
||||
elif isinstance(parser_func, tuple): sub_obj = self._parse_message(BufferedReader(TensorIO(value)), parser_func[0], parser_func[1])
|
||||
else: sub_obj = parser_func(BufferedReader(TensorIO(value)))
|
||||
gen_result(obj, key_name, sub_obj, repeated)
|
||||
|
||||
def _parse_external_data(self, obj):
|
||||
if "external_data" not in obj: raise ValueError("no external_data")
|
||||
location = None
|
||||
length = None
|
||||
offset = 0
|
||||
for kv in obj["external_data"]:
|
||||
if kv["key"] == "location": location = kv["value"]
|
||||
if kv["key"] == "offset": offset = int(kv["value"])
|
||||
if kv["key"] == "length": length = int(kv["value"])
|
||||
if location is None: raise ValueError("no location in external_data")
|
||||
if self.file_path is None:
|
||||
# get onnx file path from Tensor
|
||||
if isinstance(self.tensor.device, str) and self.tensor.device.startswith("DISK:"):
|
||||
self.file_path = self.tensor.device[5:]
|
||||
if not (ext_path := self.file_path.parent.joinpath(location)).exists():
|
||||
raise Exception(f"external location not exists: {ext_path}, may caused by symbolic link, try passing onnx file path to onnx_load")
|
||||
else: raise Exception("onnx external_data need the origin file path, try passing onnx file path to onnx_load")
|
||||
ext_path = self.file_path.parent.joinpath(location)
|
||||
if not ext_path.exists(): raise Exception(f"external location not exists: {ext_path}")
|
||||
ext_tensor = Tensor(ext_path)
|
||||
obj["raw_data"] = ext_tensor[offset:offset+length] if length is not None else ext_tensor[offset:]
|
||||
obj["data_location"] = 0
|
||||
@@ -1,8 +1,7 @@
|
||||
import time, sys, hashlib
|
||||
from pathlib import Path
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from tinygrad import Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import IMAGE, GlobalCounters, fetch, colored, getenv, trange
|
||||
from tinygrad.tensor import _from_np_dtype
|
||||
@@ -12,7 +11,7 @@ from extra.bench_log import BenchEvent, WallTimeEvent
|
||||
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
|
||||
|
||||
if __name__ == "__main__":
|
||||
onnx_model = onnx.load(onnx_path := fetch(OPENPILOT_MODEL))
|
||||
onnx_model = onnx_load(onnx_path := fetch(OPENPILOT_MODEL))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
Tensor.manual_seed(100)
|
||||
|
||||
9
test/external/external_model_benchmark.py
vendored
9
test/external/external_model_benchmark.py
vendored
@@ -2,11 +2,10 @@ import csv, pathlib, time
|
||||
import numpy as np
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
import onnx
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
import onnxruntime as ort
|
||||
from onnx2torch import convert
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from tinygrad.helpers import OSX, DEBUG, fetch, getenv
|
||||
from tinygrad import Tensor, Device
|
||||
|
||||
@@ -50,10 +49,10 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
CSV = {"model": m}
|
||||
|
||||
fn = fetch(MODELS[m])
|
||||
onnx_model = onnx.load(fn)
|
||||
onnx_model = onnx_load(fn)
|
||||
output_names = [out.name for out in onnx_model.graph.output]
|
||||
excluded = {inp.name for inp in onnx_model.graph.initializer}
|
||||
input_shapes = {inp.name:tuple(x.dim_value if x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
|
||||
input_shapes = {inp.name:tuple(x.dim_value if hasattr(x, "dim_value") and x.dim_value != 0 else 1 for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input if inp.name not in excluded} # noqa: E501
|
||||
input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input if inp.name not in excluded}
|
||||
#input_types = {k:v if v!=np.float16 else np.float32 for k,v in input_types.items()} # cast
|
||||
np_inputs = {k:torch.randn(shp).numpy().astype(input_types[k]) for k,shp in input_shapes.items()}
|
||||
@@ -75,7 +74,7 @@ def benchmark_model(m, devices, validate_outs=False):
|
||||
|
||||
# convert model to torch
|
||||
try:
|
||||
torch_model = convert(onnx_model)
|
||||
torch_model = convert(fn)
|
||||
except Exception as e:
|
||||
# model conversion failed
|
||||
print(f"{m:16s}onnx2torch {type(e).__name__:>25}")
|
||||
|
||||
11
test/external/external_test_onnx_backend.py
vendored
11
test/external/external_test_onnx_backend.py
vendored
@@ -1,4 +1,4 @@
|
||||
import unittest
|
||||
import tempfile, unittest
|
||||
from typing import Any, Tuple
|
||||
from onnx.backend.base import Backend, BackendRep
|
||||
import onnx.backend.test
|
||||
@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
|
||||
# pip3 install tabulate
|
||||
pytest_plugins = 'onnx.backend.test.report',
|
||||
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
|
||||
class TinygradModel(BackendRep):
|
||||
def __init__(self, run_onnx, input_names):
|
||||
@@ -25,12 +25,15 @@ class TinygradModel(BackendRep):
|
||||
|
||||
class TinygradBackend(Backend):
|
||||
@classmethod
|
||||
def prepare(cls, model, device):
|
||||
def prepare(cls, model: onnx.ModelProto, device):
|
||||
input_all = [x.name for x in model.graph.input]
|
||||
input_initializer = [x.name for x in model.graph.initializer]
|
||||
net_feed_input = [x for x in input_all if x not in input_initializer]
|
||||
print("prepare", cls, device, net_feed_input)
|
||||
run_onnx = OnnxRunner(model)
|
||||
with tempfile.NamedTemporaryFile(suffix='.onnx') as f:
|
||||
onnx.save(model, f.name)
|
||||
new_model = onnx_load(f.name)
|
||||
run_onnx = OnnxRunner(new_model)
|
||||
return TinygradModel(run_onnx, net_feed_input)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -7,7 +7,7 @@ try:
|
||||
import onnx
|
||||
except ModuleNotFoundError:
|
||||
raise unittest.SkipTest("onnx not installed, skipping onnx test")
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import CI, fetch, temp
|
||||
|
||||
@@ -25,7 +25,7 @@ np.random.seed(1337)
|
||||
|
||||
class TestOnnxModel(unittest.TestCase):
|
||||
def test_benchmark_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
def get_inputs():
|
||||
np_inputs = {
|
||||
@@ -69,7 +69,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
ps.print_stats(30)
|
||||
|
||||
def test_openpilot_model(self):
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
onnx_model = onnx_load(fetch(OPENPILOT_MODEL))
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
print("got run_onnx")
|
||||
inputs = {
|
||||
@@ -93,6 +93,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
et = time.monotonic()
|
||||
print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
|
||||
|
||||
onnx_model = onnx.load(fetch(OPENPILOT_MODEL))
|
||||
torch_out = run_onnx_torch(onnx_model, inputs).numpy()
|
||||
print(tinygrad_out, torch_out)
|
||||
np.testing.assert_allclose(tinygrad_out, torch_out, atol=1e-4, rtol=1e-2)
|
||||
@@ -119,7 +120,7 @@ class TestOnnxModel(unittest.TestCase):
|
||||
input_name, input_new)
|
||||
|
||||
def _test_model(self, fn, input_name, input_new, debug=False):
|
||||
onnx_model = onnx.load(fn)
|
||||
onnx_model = onnx_load(fn)
|
||||
print("onnx loaded")
|
||||
from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
|
||||
@@ -67,12 +67,12 @@ def get_quantized_model(sz):
|
||||
class TestQuantizeOnnxCPU(unittest.TestCase):
|
||||
def test_quant_128(self, sz=128):
|
||||
try:
|
||||
import onnx
|
||||
import onnx # noqa: F401 # pylint: disable=unused-import
|
||||
except ImportError:
|
||||
raise unittest.SkipTest()
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
from tinygrad.frontend.onnx import OnnxRunner, onnx_load
|
||||
out_file = get_quantized_model(sz)
|
||||
onnx_model = onnx.load(out_file)
|
||||
onnx_model = onnx_load(out_file)
|
||||
run_onnx = OnnxRunner(onnx_model)
|
||||
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
|
||||
with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# type: ignore
|
||||
import sys, pathlib
|
||||
sys.path.append(pathlib.Path(__file__).parent.parent.as_posix())
|
||||
try: from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import
|
||||
try:
|
||||
from extra.onnx import OnnxRunner # noqa: F401 # pylint: disable=unused-import
|
||||
from extra.onnx_parser import onnx_load # noqa: F401 # pylint: disable=unused-import
|
||||
except ImportError as e: raise ImportError("onnx frontend not in release\nTo fix, install tinygrad from a git checkout with pip install -e .") from e
|
||||
Reference in New Issue
Block a user