mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use TensorProto enum in onnx dtype mapping [run_process_replay] (#6151)
This commit is contained in:
@@ -38,13 +38,16 @@ def is_dtype_supported(dtype, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
|
||||
return True
|
||||
|
||||
# src: onnx/mapping.py
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
|
||||
# NOTE: 17, 18, 19, 20 are float8, 10 is half
|
||||
DTYPE_MAP = {1:dtypes.float, 2:dtypes.uint8, 3:dtypes.int8, 4:dtypes.uint16, 5:dtypes.int16, 6:dtypes.int32, 7:dtypes.int64,
|
||||
9:dtypes.bool, 10:dtypes.float, 11:dtypes.double, 12:dtypes.uint32, 13:dtypes.uint64, 16:dtypes.bfloat16,
|
||||
17:dtypes.float, 18:dtypes.float, 19:dtypes.float, 20:dtypes.float}
|
||||
# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse
|
||||
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
|
||||
# TODO: use dtypes.float16 for FLOAT16
|
||||
DTYPE_MAP = {
|
||||
TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16,
|
||||
TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool,
|
||||
TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64,
|
||||
TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float, TensorProto.FLOAT8E4M3FNUZ:dtypes.float,
|
||||
TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
|
||||
}
|
||||
|
||||
onnx_ops = importlib.import_module('extra.onnx_ops')
|
||||
|
||||
@@ -72,7 +75,8 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
|
||||
if inp.data_type not in DTYPE_MAP:
|
||||
raise NotImplementedError(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
|
||||
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
|
||||
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
|
||||
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
|
||||
|
||||
Reference in New Issue
Block a user