mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix some type error in onnx [run_process_replay] (#6153)
This commit is contained in:
@@ -1,20 +1,19 @@
|
||||
from __future__ import annotations
|
||||
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
|
||||
from typing import List, Dict, Union
|
||||
import importlib
|
||||
from functools import lru_cache
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.helpers import getenv, DEBUG, CI, OSX
|
||||
from tinygrad.dtype import ConstType
|
||||
from typing import List, Dict, Union
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
|
||||
try:
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
except ImportError:
|
||||
# for onnx < 1.13
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
|
||||
def tensor_dtype_to_np_dtype(tensor_dtype:int) -> np.dtype: return TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]
|
||||
|
||||
cache_misses = 0
|
||||
@lru_cache(None)
|
||||
@@ -41,7 +40,7 @@ def is_dtype_supported(dtype, device: str = Device.DEFAULT):
|
||||
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
|
||||
# TODO: use dtypes.float16 for FLOAT16
|
||||
DTYPE_MAP = {
|
||||
DTYPE_MAP: Dict[TensorProto.DataType, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16,
|
||||
TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool,
|
||||
TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64,
|
||||
@@ -68,11 +67,11 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
elif attr == 'sequence_type':
|
||||
type_proto = getattr(type_proto, attr).elem_type
|
||||
ret.append(1)
|
||||
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
|
||||
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
|
||||
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
|
||||
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
|
||||
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
|
||||
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
|
||||
else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if inp.data_type not in DTYPE_MAP:
|
||||
@@ -81,8 +80,8 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
|
||||
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
|
||||
if len(inp.raw_data) > 0:
|
||||
return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy(),
|
||||
requires_grad=False).reshape(tuple(inp.dims))
|
||||
data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
|
||||
return Tensor(data, requires_grad=False).reshape(tuple(inp.dims))
|
||||
return Tensor(None, requires_grad=False)
|
||||
|
||||
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
|
||||
@@ -94,9 +93,8 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
|
||||
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
|
||||
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
|
||||
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
|
||||
else: raise Exception(f"can't parse {a.type} {a}")
|
||||
def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
|
||||
elif a.type == AttributeProto.GRAPH: raise NotImplementedError(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
|
||||
else: raise RuntimeError(f"can't parse {a.type} {a}")
|
||||
|
||||
tensors: Dict[str, Tensor] = {}
|
||||
|
||||
@@ -108,35 +106,37 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
attribute_dict = {}
|
||||
domain = ""
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
attribute_dict[num] = attribute_to_dict(n.attribute)
|
||||
attribute_dict[num] = {x.name:attribute_parse(x) for x in n.attribute}
|
||||
if n.domain: domain = n.domain
|
||||
|
||||
onnx_model_version = onnx_model.opset_import[0].version
|
||||
|
||||
def run_onnx(inputs={}, debug=0):
|
||||
debug = getenv("DEBUGONNX") or debug
|
||||
input_tensors: Dict[str,Tensor] = {}
|
||||
input_tensors: Dict[str,Tensor|List[Tensor]] = {}
|
||||
intermediate_tensors: Dict[str,Tensor] = {}
|
||||
output_tensor_names = [x.name for x in onnx_model.graph.output]
|
||||
|
||||
# get inputs
|
||||
for inp in onnx_model.graph.input:
|
||||
if inp.name in tensors: continue
|
||||
shape = type_parse(inp.type)
|
||||
if inp.name in inputs:
|
||||
if isinstance(inputs[inp.name], Tensor):
|
||||
input_tensors[inp.name] = inputs[inp.name]
|
||||
elif isinstance(inputs[inp.name], list):
|
||||
input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
|
||||
for model_input in onnx_model.graph.input:
|
||||
name = model_input.name
|
||||
if name in tensors: continue
|
||||
shape = type_parse(model_input.type)
|
||||
if name in inputs:
|
||||
if isinstance(inputs[name], Tensor):
|
||||
input_tensors[name] = inputs[name]
|
||||
elif isinstance(inputs[name], list):
|
||||
input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
|
||||
elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
|
||||
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
|
||||
input_tensors[name] = Tensor(inputs[name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
|
||||
else:
|
||||
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
|
||||
input_tensors[name] = Tensor(inputs[name], requires_grad=False)
|
||||
if shape: # if only input_tensor is not variable type
|
||||
input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
|
||||
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
|
||||
ts = input_tensors[name]
|
||||
input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
|
||||
assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
|
||||
else:
|
||||
raise Exception(f"no data for {inp.name} with shape {shape}")
|
||||
raise RuntimeError(f"no data for {name} with shape {shape}")
|
||||
|
||||
def fetch_tensor(x: str):
|
||||
if x in tensors: return tensors[x]
|
||||
@@ -213,7 +213,7 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
ret = real_fxn(*inp, **opt)
|
||||
else:
|
||||
print("UNSUPPORTED", n.op_type, n.input, n.output)
|
||||
raise Exception(f"op_type {n.op_type} not supported")
|
||||
raise NotImplementedError(f"op_type {n.op_type} not supported")
|
||||
|
||||
if not isinstance(ret, tuple): ret = (ret, )
|
||||
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
|
||||
|
||||
Reference in New Issue
Block a user