Add fallback dtype to ONNX (#10788)

* start

* still need the float16 workaround in

* tiny nit for correctness

* idk hacks, I need to understand this device stuff better

* no-op?

* remove that assert for true nooooooop

* add fallback_context
Author: geohotstan (committed by GitHub)
Date: 2025-06-13 08:39:21 +08:00
Parent: dcd1928f29
Commit: 806b68c2b3

@@ -1,6 +1,6 @@
 from types import SimpleNamespace
 from typing import Any, Sequence, cast, Literal, Callable
-import dataclasses, functools, io, math, types
+import dataclasses, functools, io, math, types, warnings
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
 from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort
 from tinygrad.dtype import DType, ConstType, dtypes, ImageDType
@@ -14,7 +14,7 @@ def has_field(onnx_type: TypeProto|SimpleNamespace, field):
   if isinstance(onnx_type, TypeProto): return onnx_type.HasField(field)
   return hasattr(onnx_type, field)
 
-def dtype_parse(onnx_dtype: int) -> DType:
+def dtype_parse(onnx_dtype: int, fallback_context: str | None = None) -> DType:
   supported: dict[int, DType] = {
     TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
     TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
@@ -26,7 +26,13 @@ def dtype_parse(onnx_dtype: int) -> DType:
     TensorProto.FLOAT8E5M2, TensorProto.FLOAT8E5M2FNUZ, TensorProto.UINT4, TensorProto.INT4
   }
   if onnx_dtype in unsupported: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
-  return supported[onnx_dtype] if is_dtype_supported(supported[onnx_dtype]) else dtypes.float
+  if is_dtype_supported(dtype := supported[onnx_dtype]): return dtype
+  # if fallback_context is provided, we can fall back to a default dtype
+  if fallback_context is not None:
+    default_dtype = dtypes.float
+    warnings.warn(f"dtype {dtype} on {Device.DEFAULT} from {fallback_context} is not supported, falling back to {default_dtype}")
+    return default_dtype
+  raise RuntimeError(f"dtype {dtype} on device {Device.DEFAULT} is not supported")
 
 def attribute_parse(onnx_attribute: AttributeProto):
   supported: dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
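
In practice the new path behaves as in this minimal standalone sketch, where SUPPORTED, DEVICE, and is_dtype_supported are illustrative stand-ins for tinygrad's real tables and helpers, not its API:

    import warnings

    SUPPORTED = {1: "float32", 10: "float16"}  # TensorProto.FLOAT, TensorProto.FLOAT16
    DEVICE = "WEBGPU"                          # pretend backend without float16
    def is_dtype_supported(dtype: str) -> bool: return dtype != "float16"

    def dtype_parse(onnx_dtype: int, fallback_context: str | None = None) -> str:
      dtype = SUPPORTED[onnx_dtype]
      if is_dtype_supported(dtype): return dtype
      if fallback_context is not None:  # caller opted in to a lossy fallback
        warnings.warn(f"dtype {dtype} on {DEVICE} from {fallback_context} is not supported, falling back to float32")
        return "float32"
      raise RuntimeError(f"dtype {dtype} on device {DEVICE} is not supported")

    print(dtype_parse(10, "buffer parse"))  # warns, prints 'float32'
    # dtype_parse(10)  # no context: still raises RuntimeError

Before this change, dtype_parse silently returned dtypes.float whenever the device lacked the dtype; now the fallback warns and names its origin, and call sites that pass no context fail loudly.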
@@ -46,7 +52,7 @@ def attribute_parse(onnx_attribute: AttributeProto):
 
 def buffer_parse(onnx_tensor: TensorProto) -> Tensor:
   if onnx_tensor.string_data: raise NotImplementedError("Parsing for buffer with string data is not implemented.")
-  dtype, shape = dtype_parse(onnx_tensor.data_type), tuple(onnx_tensor.dims)
+  dtype, shape = dtype_parse(onnx_tensor.data_type, "buffer parse"), tuple(onnx_tensor.dims)
   data = None
   if len(onnx_tensor.float_data): data = onnx_tensor.float_data
   elif len(onnx_tensor.int32_data): data = onnx_tensor.int32_data
@@ -76,7 +82,7 @@ def type_parse(onnx_type: TypeProto):
   if has_field(elem_type, "tensor_type"):
     shape = tuple(getattr(d, "dim_param", None) or getattr(d, "dim_value") for d in elem_type.tensor_type.shape.dim) \
       if has_field(elem_type.tensor_type, "shape") else None # test_identity_sequence_cpu
-    dtype = dtype_parse(elem_type.tensor_type.elem_type)
+    dtype = dtype_parse(elem_type.tensor_type.elem_type, "input type spec parse")
     return OnnxValue(shape, dtype, is_optional, is_sequence)
   raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}")
 
@@ -145,15 +151,15 @@ class OnnxRunner:
     if spec.is_optional and value is None: return None
     # TODO: need true float16 for dtype checking
     if spec.is_sequence:
-      if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type")
+      if not isinstance(value, Sequence): raise RuntimeError(f"input {name} received {value}, expected a sequence type")
       sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value]
-      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous")
+      if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for input {name} sequence must be homogeneous")
       return sequence
     tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value
     for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)):
       if isinstance(onnx_dim, str):
         onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input))
-      if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
+      if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")
     return tensor
 
   def _dispatch_op(self, op, inps, opts):
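
The symbolic dims like 'N' from the type parse are resolved at run time by the loop in this hunk: the first input to bind a symbolic dim fixes its value, and later inputs must agree. A standalone sketch of that check (check_shape and the literal spec shapes are illustrative, not the runner's API):

    variable_dims: dict[str, int] = {}

    def check_shape(name: str, spec_shape: tuple, tensor_shape: tuple):
      for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec_shape, tensor_shape, strict=True)):
        # setdefault binds the symbolic dim on first sight, returns the bound value afterwards
        if isinstance(onnx_dim, str): onnx_dim = variable_dims.setdefault(onnx_dim, int(user_dim_input))
        if user_dim_input != onnx_dim: raise RuntimeError(f"input {name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.")

    check_shape("x", ("N", 3), (2, 3))       # binds N=2
    try: check_shape("y", ("N", 5), (4, 5))  # N is already 2: dim 0 mismatches
    except RuntimeError as e: print(e)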
@@ -284,7 +290,7 @@ def get_onnx_ops():
       raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
 
   def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
-    ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
+    ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype, "EyeLike op") if dtype is not None else x.dtype)
     return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
 
   def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
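
The pad expression in EyeLike is easy to misread. A NumPy sketch of the same idea (np.pad with (0, 0) standing in for tinygrad's None no-pad): build an eye of size min(rows, cols), then pad the longer axis so the ones land on diagonal k; shown here for a non-negative k.

    import numpy as np

    def eye_like(shape: tuple[int, int], k: int = 0) -> np.ndarray:
      n = min(shape)
      ret = np.eye(n)
      if shape[0] == shape[1]: return ret
      # (k, d - n - k) shifts the ones by k along the axis that is longer than n
      return np.pad(ret, [(0, 0) if d == n else (k, d - n - k) for d in shape])

    print(eye_like((2, 4), k=1))
    # [[0. 1. 0. 0.]
    #  [0. 0. 1. 0.]]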
@@ -338,7 +344,7 @@ def get_onnx_ops():
 
   # ***** Casting Ops *****
   # TODO: saturate
-  def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
+  def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to, "Cast op"))
   def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
 
   # ***** Reduce Ops *****
@@ -731,7 +737,9 @@ def get_onnx_ops():
 
   # ***** Quantization Ops *****
   def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
-    out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
+    if isinstance(y_zero_point, Tensor): out_dtype = y_zero_point.dtype
+    elif output_dtype != 0: out_dtype = dtype_parse(output_dtype, "QuantizeLinear op")
+    else: out_dtype = dtypes.uint8
     y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
     if out_dtype == dtypes.uchar:
       # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
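
The rewrite makes the output-dtype precedence explicit: a zero-point tensor's dtype wins, then a nonzero output_dtype attribute, else the ONNX default uint8. A tiny sketch of just that selection (FakeTensor and pick_out_dtype are illustrative stand-ins):

    from dataclasses import dataclass

    @dataclass
    class FakeTensor:  # stand-in for a tinygrad Tensor; only dtype matters here
      dtype: str

    def pick_out_dtype(y_zero_point, output_dtype: int) -> str:
      if isinstance(y_zero_point, FakeTensor): return y_zero_point.dtype  # zero-point dtype wins
      if output_dtype != 0: return f"dtype_parse({output_dtype})"         # explicit attribute
      return "uint8"                                                      # ONNX default

    print(pick_out_dtype(FakeTensor("int8"), 0))  # int8
    print(pick_out_dtype(0, 3))                   # dtype_parse(3), TensorProto.INT8
    print(pick_out_dtype(0, 0))                   # uint8

Spelling the ternary out as branches also lets dtype_parse receive its "QuantizeLinear op" context, so an unsupported output_dtype now warns before falling back.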