use TensorProto enum in onnx dtype mapping [run_process_replay] (#6151)

This commit is contained in:
chenyu
2024-08-17 17:58:40 -04:00
committed by GitHub
parent f7950fc2b6
commit 7c9c8ce22f

View File

@@ -38,13 +38,16 @@ def is_dtype_supported(dtype, device: str = Device.DEFAULT):
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
# src: onnx/mapping.py
# not supported: STRING = 8, COMPLEX64 = 14, COMPLEX128 = 15
# NOTE: 17, 18, 19, 20 are float8, 10 is half
# Map ONNX TensorProto data-type codes to tinygrad dtypes.
# NOTE: 17-20 (float8 variants) and 10 (half) are widened to dtypes.float here.
DTYPE_MAP = {
    1: dtypes.float, 2: dtypes.uint8, 3: dtypes.int8, 4: dtypes.uint16, 5: dtypes.int16,
    6: dtypes.int32, 7: dtypes.int64, 9: dtypes.bool, 10: dtypes.float, 11: dtypes.double,
    12: dtypes.uint32, 13: dtypes.uint64, 16: dtypes.bfloat16,
    17: dtypes.float, 18: dtypes.float, 19: dtypes.float, 20: dtypes.float,
}
# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
# not supported: STRING = 8, COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
# TODO: use dtypes.float16 for FLOAT16
# Map ONNX TensorProto data-type enum members to tinygrad dtypes.
# NOTE: FLOAT16 and the four FLOAT8 variants are widened to dtypes.float here
# (see the TODO above about eventually using dtypes.float16 for FLOAT16).
DTYPE_MAP = {
    TensorProto.FLOAT: dtypes.float,
    TensorProto.UINT8: dtypes.uint8,
    TensorProto.INT8: dtypes.int8,
    TensorProto.UINT16: dtypes.uint16,
    TensorProto.INT16: dtypes.int16,
    TensorProto.INT32: dtypes.int32,
    TensorProto.INT64: dtypes.int64,
    TensorProto.BOOL: dtypes.bool,
    TensorProto.FLOAT16: dtypes.float,
    TensorProto.DOUBLE: dtypes.double,
    TensorProto.UINT32: dtypes.uint32,
    TensorProto.UINT64: dtypes.uint64,
    TensorProto.BFLOAT16: dtypes.bfloat16,
    TensorProto.FLOAT8E4M3FN: dtypes.float,
    TensorProto.FLOAT8E4M3FNUZ: dtypes.float,
    TensorProto.FLOAT8E5M2: dtypes.float,
    TensorProto.FLOAT8E5M2FNUZ: dtypes.float,
}
# Dynamically load the module holding the ONNX op implementations; ops are
# looked up on it by name at run time rather than imported individually.
onnx_ops = importlib.import_module('extra.onnx_ops')
@@ -72,7 +75,8 @@ def get_run_onnx(onnx_model: ModelProto):
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
if inp.data_type not in DTYPE_MAP:
raise NotImplementedError(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))