fix some type error in onnx [run_process_replay] (#6153)

Author: chenyu
Date: 2024-08-17 19:54:20 -04:00
Committed by: GitHub
Parent: 7c9c8ce22f
Commit: 9db2d0d5c6


@@ -1,20 +1,19 @@
 from __future__ import annotations
-from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
+from typing import List, Dict, Union
 import importlib
 from functools import lru_cache
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.helpers import getenv, DEBUG, CI, OSX
-from tinygrad.dtype import ConstType
-from typing import List, Dict, Union
+from tinygrad.dtype import ConstType, DType
 from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
 try:
   from onnx.helper import tensor_dtype_to_np_dtype
 except ImportError:
   # for onnx < 1.13
   from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
-  tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
+  def tensor_dtype_to_np_dtype(tensor_dtype:int) -> np.dtype: return TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]

 cache_misses = 0
 @lru_cache(None)
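
The lambda-to-def change matters for typing: the except branch now redefines the name with an annotated signature compatible with the function the try branch imports, which is what mypy checks. A quick sanity check of the helper (a sketch, assuming onnx >= 1.13 and numpy are installed):

    import numpy as np
    from onnx import TensorProto
    from onnx.helper import tensor_dtype_to_np_dtype  # shimmed above for onnx < 1.13

    assert tensor_dtype_to_np_dtype(TensorProto.FLOAT) == np.dtype("float32")
    assert tensor_dtype_to_np_dtype(TensorProto.INT64) == np.dtype("int64")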
@@ -41,7 +40,7 @@ def is_dtype_supported(dtype, device: str = Device.DEFAULT):
 # src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
 # not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
 # TODO: use dtypes.float16 for FLOAT16
-DTYPE_MAP = {
+DTYPE_MAP: Dict[TensorProto.DataType, DType] = {
   TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16,
   TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool,
   TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64,
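
The new annotation pins both key and value types for lookups like the one buffer_parse does below. Illustrative lookups inside this module (note the TODO: FLOAT16 still maps to dtypes.float, not dtypes.float16):

    from onnx import TensorProto
    from tinygrad import dtypes

    assert DTYPE_MAP[TensorProto.INT64] == dtypes.int64
    assert DTYPE_MAP[TensorProto.FLOAT16] == dtypes.float  # per the TODO above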
@@ -68,11 +67,11 @@ def get_run_onnx(onnx_model: ModelProto):
       elif attr == 'sequence_type':
         type_proto = getattr(type_proto, attr).elem_type
         ret.append(1)
+      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
       elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
       elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
       elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
-      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
-      else: raise Exception(f"unknown attr: {attr}, {type_proto}")
+      else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")

   def buffer_parse(inp: TensorProto) -> Tensor:
     if inp.data_type not in DTYPE_MAP:
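
Moving the optional_type branch up groups the two cases that unwrap elem_type and continue the loop, ahead of the cases that raise. For context, type_parse dispatches on the TypeProto value oneof; a minimal sketch of that dispatch (make_tensor_type_proto is from onnx.helper, the shape is made up):

    from onnx import TensorProto, helper

    tp = helper.make_tensor_type_proto(TensorProto.FLOAT, [1, 3, 224, 224])
    assert tp.WhichOneof("value") == "tensor_type"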
@@ -81,8 +80,8 @@ def get_run_onnx(onnx_model: ModelProto):
     if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
       return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
     if len(inp.raw_data) > 0:
-      return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy(),
-                    requires_grad=False).reshape(tuple(inp.dims))
+      data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
+      return Tensor(data, requires_grad=False).reshape(tuple(inp.dims))
     return Tensor(None, requires_grad=False)

   def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
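
Hoisting the np.frombuffer chain into a local keeps the Tensor(...) call on one readable line; behavior is unchanged. A round-trip sketch of the raw_data path (make_tensor is from onnx.helper; the name and values are made up, and buffer_parse is nested inside get_run_onnx, so this illustrates its behavior rather than an importable API):

    import numpy as np
    from onnx import TensorProto, helper

    raw = np.arange(4, dtype=np.float32)
    proto = helper.make_tensor("w", TensorProto.FLOAT, [2, 2], raw.tobytes(), raw=True)
    t = buffer_parse(proto)  # decodes raw_data via np.frombuffer, then reshapes to inp.dims
    assert t.shape == (2, 2) and t.requires_grad is False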
@@ -94,9 +93,8 @@ def get_run_onnx(onnx_model: ModelProto):
     elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
     elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
     elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
-    elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
-    else: raise Exception(f"can't parse {a.type} {a}")
-  def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
+    elif a.type == AttributeProto.GRAPH: raise NotImplementedError(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
+    else: raise RuntimeError(f"can't parse {a.type} {a}")

   tensors: Dict[str, Tensor] = {}
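
With attribute_to_dict gone (it had a single caller), the comprehension is inlined at its call site below and the protobuf container import can be dropped. How attribute_parse behaves on simple attributes (make_attribute is from onnx.helper, the attribute names are made up, and this assumes the INT branch above this hunk returns int(a.i)):

    from onnx import helper

    assert attribute_parse(helper.make_attribute("axis", 1)) == 1
    assert attribute_parse(helper.make_attribute("scales", [1.0, 2.0])) == (1.0, 2.0)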
@@ -108,35 +106,37 @@ def get_run_onnx(onnx_model: ModelProto):
   attribute_dict = {}
   domain = ""
   for num,n in enumerate(onnx_model.graph.node):
-    attribute_dict[num] = attribute_to_dict(n.attribute)
+    attribute_dict[num] = {x.name:attribute_parse(x) for x in n.attribute}
     if n.domain: domain = n.domain
   onnx_model_version = onnx_model.opset_import[0].version

   def run_onnx(inputs={}, debug=0):
     debug = getenv("DEBUGONNX") or debug
-    input_tensors: Dict[str,Tensor] = {}
+    input_tensors: Dict[str,Tensor|List[Tensor]] = {}
     intermediate_tensors: Dict[str,Tensor] = {}
     output_tensor_names = [x.name for x in onnx_model.graph.output]

     # get inputs
-    for inp in onnx_model.graph.input:
-      if inp.name in tensors: continue
-      shape = type_parse(inp.type)
-      if inp.name in inputs:
-        if isinstance(inputs[inp.name], Tensor):
-          input_tensors[inp.name] = inputs[inp.name]
-        elif isinstance(inputs[inp.name], list):
-          input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
+    for model_input in onnx_model.graph.input:
+      name = model_input.name
+      if name in tensors: continue
+      shape = type_parse(model_input.type)
+      if name in inputs:
+        if isinstance(inputs[name], Tensor):
+          input_tensors[name] = inputs[name]
+        elif isinstance(inputs[name], list):
+          input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
         elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
-          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
+          input_tensors[name] = Tensor(inputs[name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
         else:
-          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
+          input_tensors[name] = Tensor(inputs[name], requires_grad=False)
         if shape: # if only input_tensor is not variable type
-          input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
-          assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
+          ts = input_tensors[name]
+          input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
+          assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
       else:
-        raise Exception(f"no data for {name} with shape {shape}")
+        raise RuntimeError(f"no data for {name} with shape {shape}")

     def fetch_tensor(x: str):
       if x in tensors: return tensors[x]
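
The widened Dict[str,Tensor|List[Tensor]] annotation now matches what callers may pass: a Tensor, a list (for sequence inputs), or raw data. A hedged usage sketch, assuming get_run_onnx returns the inner run_onnx as the structure suggests (the model path and input name are hypothetical):

    import onnx
    from tinygrad import Tensor

    run = get_run_onnx(onnx.load("model.onnx"))  # hypothetical model file
    out = run({"input": Tensor.randn(1, 3, 224, 224)}, debug=1)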
@@ -213,7 +213,7 @@ def get_run_onnx(onnx_model: ModelProto):
         ret = real_fxn(*inp, **opt)
       else:
         print("UNSUPPORTED", n.op_type, n.input, n.output)
-        raise Exception(f"op_type {n.op_type} not supported")
+        raise NotImplementedError(f"op_type {n.op_type} not supported")

       if not isinstance(ret, tuple): ret = (ret, )
       assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
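
Raising NotImplementedError instead of a bare Exception lets callers distinguish "op not implemented" from genuine failures. Continuing the usage sketch above (the fallback is illustrative only):

    try:
      out = run({"input": Tensor.randn(1, 3, 224, 224)})
    except NotImplementedError as e:
      print(f"unsupported op, handing off to another runtime: {e}")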