From ac713e04db4e7989740674ab52e00dbc745793fe Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:44:53 +0800 Subject: [PATCH] ONNX add output shape validation (#9720) * add output shape validation and remove support for sequence_type * nit better err msg * add sequence_type back * improve err msg * Revert "improve err msg" This reverts commit dc9eaea4bb7bb7934434ab2cde0d43caaa0d672d. * Revert "add sequence_type back" This reverts commit 288170b2d9fbcb8000358f7ea79ce8d8862863be. * do explicit shape equality * small nit --- extra/onnx.py | 41 ++++++++------- extra/onnx_helpers.py | 2 +- test/external/external_test_onnx_backend.py | 6 +++ test/external/external_test_onnx_ops.py | 56 +++++++++++++-------- 4 files changed, 63 insertions(+), 42 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index c9ebe23b1a..58da854212 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -1,7 +1,7 @@ -from typing import Any, Sequence, cast, Literal, Callable +from typing import Any, cast, Literal, Callable import dataclasses, functools, io, math, types from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr -from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort +from tinygrad.helpers import getenv, DEBUG, prod, flatten, make_tuple, argsort from tinygrad.dtype import DType, ConstType, dtypes, ImageDType from tinygrad.device import is_dtype_supported @@ -56,11 +56,10 @@ def type_parse(onnx_type: TypeProto): if elem_type.HasField("map_type") or elem_type.HasField("sparse_tensor_type") or elem_type.HasField("opaque_type"): raise NotImplementedError("parsing for map_type, sparse_tensor_type and opaque_type are not implemented") if is_optional := elem_type.HasField("optional_type"): elem_type = elem_type.optional_type.elem_type - if is_sequence := elem_type.HasField("sequence_type"): elem_type = elem_type.sequence_type.elem_type if elem_type.HasField("tensor_type"): shape = tuple(d.dim_param or d.dim_value for d in elem_type.tensor_type.shape.dim) dtype = dtype_parse(elem_type.tensor_type.elem_type) - return OnnxValue(shape, dtype, is_optional, is_sequence) + return OnnxValue(shape, dtype, is_optional) raise RuntimeError(f"TypeProto was not parsed properly: {onnx_type=}") # ***** onnx spec ***** @@ -69,7 +68,6 @@ class OnnxValue: shape: tuple[str|int, ...] dtype: DType is_optional: bool - is_sequence: bool @dataclasses.dataclass(frozen=True) class OnnxNode: @@ -117,7 +115,7 @@ class OnnxRunner: Tensor.no_grad = False if self.is_training else True self.graph_values = {"": None, **{x.name:buffer_parse(x) for x in model.graph.initializer}} self.graph_inputs = {x.name:type_parse(x.type) for x in model.graph.input if x.name not in self.graph_values} - self.graph_outputs = tuple(x.name for x in model.graph.output) + self.graph_outputs = {x.name:type_parse(x.type) for x in model.graph.output} self.graph_nodes = tuple(OnnxNode(num, n.op_type, tuple(n.input), tuple(n.output), {x.name:attribute_parse(x) for x in n.attribute}) for num,n in enumerate(model.graph.node)) self.opset_version = model.opset_import[0].version @@ -125,20 +123,25 @@ class OnnxRunner: self.onnx_ops = onnx_ops + def _validate_shape(self, name: str, value: Tensor, spec: OnnxValue): + # update new variable dims + self.variable_dims.update({sd:vd for sd,vd in zip(spec.shape, value.shape) if isinstance(sd, str) and sd not in self.variable_dims}) + # resolve dynamic shape + expected_shape = tuple(self.variable_dims[sd] if isinstance(sd, str) else sd for sd in spec.shape) + assert value.shape == expected_shape, f"'{name}' has wrong shape" + def _parse_input(self, name: str, value: Any, spec: OnnxValue): if spec.is_optional and value is None: return None - # TODO: need true float16 for dtype checking - if spec.is_sequence: - if not isinstance(value, Sequence): raise RuntimeError(f"{name} received {value}, expected a sequence type") - sequence = [Tensor(v, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(v, Tensor) else v for v in value] - if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"Shapes for {name} sequence must be homogeneous") - return sequence - tensor = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) if not isinstance(value, Tensor) else value - for dim, (onnx_dim, user_dim_input) in enumerate(zip(spec.shape, tensor.shape, strict=True)): - if isinstance(onnx_dim, str): - onnx_dim = self.variable_dims[onnx_dim] if onnx_dim in self.variable_dims else self.variable_dims.setdefault(onnx_dim, int(user_dim_input)) - if user_dim_input != onnx_dim: raise RuntimeError(f"{name} has mismatch on {dim=}. Expected {onnx_dim}, received {user_dim_input}.") - return tensor + if value is None: raise RuntimeError(f"'{name}' is not marked as optional, but received a None value") + if not isinstance(value, Tensor): value = Tensor(value, dtype=spec.dtype, requires_grad=self.is_training) + self._validate_shape(name, value, spec) + return value + + def _parse_output(self, name: str): + value, spec = self.graph_values[name], self.graph_outputs[name] + if not isinstance(value, Tensor): return value + self._validate_shape(name, value, spec) + return value def _dispatch_op(self, op, inps, opts): if op in self.onnx_ops: @@ -176,7 +179,7 @@ class OnnxRunner: Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad return {name:self.graph_values[name] for name in node.outputs} Tensor.training, Tensor.no_grad = self.old_training, self.old_no_grad - return {name:self.graph_values[name] for name in self.graph_outputs} + return {name:self._parse_output(name) for name in self.graph_outputs} #################### ##### ONNX OPS ##### diff --git a/extra/onnx_helpers.py b/extra/onnx_helpers.py index ab020fdb23..dc4d5d10bb 100644 --- a/extra/onnx_helpers.py +++ b/extra/onnx_helpers.py @@ -40,7 +40,7 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue], config={}): ret: dict[str, Tensor] = {} for name, spec in graph_inputs.items(): - assert not spec.is_optional and not spec.is_sequence, "only allow tensor input for now" + assert not spec.is_optional, "only allow tensor input for now" shape = _get_shape(spec.shape) value = _get_value(name, shape, spec.dtype) ret.update({name:value}) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index e5fa02f2d5..39321d6e79 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -175,6 +175,12 @@ backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # anti backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string +# no support for sequence +backend_test.exclude('test_identity_opt_cpu') +backend_test.exclude('test_identity_sequence_cpu') +backend_test.exclude('test_optional_get_element_optional_sequence_cpu') +backend_test.exclude('test_optional_get_element_sequence_cpu') + backend_test.exclude('test_scatternd_min_cpu') # min not yet supported backend_test.exclude('test_scatternd_max_cpu') # max not yet supported diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 2f0528168d..5e37d15231 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -11,11 +11,11 @@ from extra.onnx_helpers import validate class TestOnnxOps(unittest.TestCase): DOMAIN = None - def helper_build_model(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str]): - onnx_inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()] - onnx_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in outs] + def helper_build_model(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:dict[str, np.ndarray]): + inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()] + outputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in outs.items()] nodes = [onnx.helper.make_node(op, list(inps), list(outs), domain=self.DOMAIN, **opts)] - graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", onnx_inputs, onnx_outputs) + graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", inputs, outputs) model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}") return model @@ -30,7 +30,7 @@ class TestMainOnnxOps(TestOnnxOps): def test_reshape(self): inputs = {"in": np.arange(6, dtype=np.float32), "shape": np.array([2,3], dtype=np.int64)} attributes = {} - outputs = ["out"] + outputs = {"out": np.empty((2,3), dtype=np.float32)} self.helper_test_single_op("Reshape", inputs, attributes, outputs) def test_conv(self): @@ -41,7 +41,7 @@ class TestMainOnnxOps(TestOnnxOps): "b": np.random.randn(1152).astype(np.float32) } attributes = {'auto_pad': 'VALID', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (14, 14), 'strides': (14, 14)} - outputs = ["y"] + outputs = {"y": np.empty((1, 1152, 27, 27), dtype=np.float32)} self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4) def test_gather(self): @@ -51,7 +51,7 @@ class TestMainOnnxOps(TestOnnxOps): "indices": np.array(-2, dtype=np.int64), } attributes = {'axis': 1} - outputs = ["y"] + outputs = {"y": np.empty((1, 3), dtype=np.float32)} self.helper_test_single_op("Gather", inputs, attributes, outputs) def test_maxunpool(self): @@ -61,7 +61,7 @@ class TestMainOnnxOps(TestOnnxOps): output_shape = np.array((1, 1, 5, 5), dtype=np.int64) inputs = {"x": xT, "indices": xI, "output_shape": output_shape} attributes = {"kernel_shape": [2, 2], "strides": [2, 2]} - outputs = ["y"] + outputs = {"y": np.empty((1, 1, 5, 5), dtype=np.float32)} self.helper_test_single_op("MaxUnpool", inputs, attributes, outputs) def test_isinf(self): @@ -70,7 +70,7 @@ class TestMainOnnxOps(TestOnnxOps): x = np.array([-1.2, np.nan, np.inf, 2.8, -np.inf, np.inf], dtype=np.float32) inputs = {"x": x} attributes = {"detect_negative":1, "detect_positive":1} - outputs = ["y"] + outputs = {"y": np.empty((6,), dtype=np.bool_)} model = self.helper_build_model("IsInf", inputs, attributes, outputs) outputs = OnnxRunner(model)(inputs) assert outputs["y"].dtype is dtypes.bool @@ -87,7 +87,8 @@ class TestMainOnnxOps(TestOnnxOps): "y_scale": np.array(case["scale"], dtype=np.float32), "y_zero_point": np.array(case["qzero_point"], dtype=case["qdtype"]) } - self.helper_test_single_op("QuantizeLinear", inputs, {}, ["y"]) + outputs = {"y": np.empty_like(inputs["x"], dtype=case["qdtype"])} + self.helper_test_single_op("QuantizeLinear", inputs, {}, outputs) def test_dynamic_quantize_linear(self): test_cases = [ @@ -100,7 +101,12 @@ class TestMainOnnxOps(TestOnnxOps): ] for case in test_cases: with self.subTest(test_case=case["name"]): - self.helper_test_single_op("DynamicQuantizeLinear", {"x": case["x"]}, {}, ["y", "y_scale", "y_zero_point"]) + outputs = { + "y": np.empty_like(case["x"], dtype=np.uint8), + "y_scale": np.empty((), dtype=np.float32), + "y_zero_point": np.empty((), dtype=np.uint8) + } + self.helper_test_single_op("DynamicQuantizeLinear", {"x": case["x"]}, {}, outputs) def test_qlinear_conv(self): for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]: @@ -119,7 +125,7 @@ class TestMainOnnxOps(TestOnnxOps): "b": b } attributes = {'auto_pad': 'NOTSET', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (3, 3), 'pads': (1, 1, 1, 1), 'strides': (2, 2)} - outputs = ["out"] + outputs = {"out": np.empty((1, 32, 112, 112), dtype=dtype)} self.helper_test_single_op("QLinearConv", inputs, attributes, outputs, atol=1) # occasionally inaccurate def test_qlinear_matmul(self): @@ -137,7 +143,7 @@ class TestMainOnnxOps(TestOnnxOps): "Y_zero_point": np.array(zero_point, dtype=dtype) } attributes = {} - outputs = ["Y"] + outputs = {"Y": np.empty((10, 10), dtype=dtype)} self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs) for name,val in (("round_half_down_to_even", 1), ("round_half_up_to_even", 3)): @@ -153,7 +159,7 @@ class TestMainOnnxOps(TestOnnxOps): "Y_zero_point": np.array(0, dtype=np.int8) } attributes = {} - outputs = ["Y"] + outputs = {"Y": np.empty((), dtype=np.int8)} self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs) class TestContribOnnxOps(TestOnnxOps): @@ -202,7 +208,8 @@ class TestContribOnnxOps(TestOnnxOps): with self.subTest(f"test_attention_{i}"): inps = {**base_inps, **extra_inps} opts = {**base_opts, **extra_opts} - outputs = ["output", "present"] if "past" in inps else ["output"] + if "qkv_hidden_sizes" in opts: outputs = {"output": np.empty((batch_size, seq_len, 128), dtype=np.float32)} + else: outputs = {"output": np.empty((batch_size, seq_len, hidden_size), dtype=np.float32)} self.helper_test_single_op("Attention", inps, opts, outputs, atol=1e-4) def test_skip_layer_normalization(self): @@ -219,7 +226,12 @@ class TestContribOnnxOps(TestOnnxOps): if has_beta: inputs["beta"] = np.random.randn(hidden_size).astype(np.float32) if has_bias: inputs["bias"] = np.random.randn(hidden_size).astype(np.float32) attributes = {"epsilon": 1e-12} - outputs = ["output", "mean", "inv_std_var", "input_skip_bias_sum"] + outputs = { + "output": np.empty(shape, dtype=np.float32), + "mean": np.empty(0, dtype=np.float32), + "inv_std_var": np.empty(0, dtype=np.float32), + "input_skip_bias_sum": np.empty(shape, dtype=np.float32) + } self.helper_test_single_op("SkipLayerNormalization", inputs, attributes, outputs) def test_bias_gelu(self): @@ -229,7 +241,7 @@ class TestContribOnnxOps(TestOnnxOps): "B": np.random.randn(shape[-1]).astype(np.float32) } attributes = {} - outputs = ["C"] + outputs = {"C": np.empty(shape, dtype=np.float32)} self.helper_test_single_op("BiasGelu", inputs, attributes, outputs) def test_qlinear_add(self): @@ -247,7 +259,7 @@ class TestContribOnnxOps(TestOnnxOps): "C_zero_point": np.array(zero_point, dtype=dtype) } attributes = {} - outputs = ["C"] + outputs = {"C": np.empty((10, 10), dtype=dtype)} self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1) # TODO: look into why this is inaccurate with self.subTest(test_case="round_half_to_even"): @@ -262,7 +274,7 @@ class TestContribOnnxOps(TestOnnxOps): "C_zero_point": np.array(0, dtype=np.int8) } attributes = {} - outputs = ["C"] + outputs = {"C": np.empty((4,), dtype=np.int8)} self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs) def test_qlinear_mul(self): @@ -280,7 +292,7 @@ class TestContribOnnxOps(TestOnnxOps): "C_zero_point": np.array(zero_point, dtype=dtype) } attributes = {} - outputs = ["C"] + outputs = {"C": np.empty((10, 10), dtype=dtype)} self.helper_test_single_op("QLinearMul", inputs, attributes, outputs) with self.subTest(test_case="round_half_to_even"): @@ -295,7 +307,7 @@ class TestContribOnnxOps(TestOnnxOps): "C_zero_point": np.array(0, dtype=np.int8) } attributes = {} - outputs = ["C"] + outputs = {"C": np.empty((4,), dtype=np.int8)} self.helper_test_single_op("QLinearMul", inputs, attributes, outputs) def test_qlinear_global_average_pool(self): @@ -310,7 +322,7 @@ class TestContribOnnxOps(TestOnnxOps): "y_zero_point": np.array(zero_point, dtype=dtype) } attributes = {"channels_last": 0} - outputs = ["C"] + outputs = {"C": np.empty((1, 3, 1, 1), dtype=dtype)} self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs) if __name__ == "__main__":