diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 194d734988..76a5042385 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -137,6 +137,43 @@ class TestMainOnnxOps(TestOnnxOps): def test_if_different_shapes_not_broadcastable(self): self._test_if(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32), np.array([[6, 5, 4, 3, 2, 1]]).astype(np.float32)) + def test_if_jit_different_shapes(self): + # TODO: If with different output shapes and non-const condition should raise + # When shapes differ, Python selection evaluates condition at graph build time, breaking JIT + from tinygrad import TinyJit + # then: x+1 shape (3,), else: x[:2]+1 shape (2,) + x_input = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, (3,)) + then_out = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, (3,)) + then_body = onnx.helper.make_graph([ + onnx.helper.make_node("Constant", [], ["one"], value=onnx.numpy_helper.from_array(np.array(1, dtype=np.float32))), + onnx.helper.make_node("Add", ["x", "one"], ["res"])], "then_body", [x_input], [then_out]) + else_out = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, (2,)) + else_body = onnx.helper.make_graph([ + onnx.helper.make_node("Constant", [], ["starts"], value=onnx.numpy_helper.from_array(np.array([0], dtype=np.int32))), + onnx.helper.make_node("Constant", [], ["ends"], value=onnx.numpy_helper.from_array(np.array([2], dtype=np.int32))), + onnx.helper.make_node("Constant", [], ["one"], value=onnx.numpy_helper.from_array(np.array(1, dtype=np.float32))), + onnx.helper.make_node("Slice", ["x", "starts", "ends"], ["x2"]), + onnx.helper.make_node("Add", ["x2", "one"], ["res"])], "else_body", [x_input], [else_out]) + + cond_input = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, (1,)) + main_x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, (3,)) + graph = onnx.helper.make_graph([onnx.helper.make_node("If", ["cond"], ["res"], then_branch=then_body, else_branch=else_body)], + "test", [cond_input, main_x], [onnx.helper.make_empty_tensor_value_info("res")]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 22)]) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: + onnx.save(model, tmp.name) + runner = OnnxRunner(tmp.name) + + @TinyJit + def run_if(cond, x): return runner({"cond": cond, "x": x})["res"] + + x = Tensor([1.0, 2.0, 3.0]) + self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4]) # x + 1 + self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4]) + self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4]) + self.assertEqual(run_if(Tensor([False]), x).tolist(), [2, 3, 4]) # wrong! should be [2, 3] + def test_resize_downsample_scales_linear_align_corners(self): # https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-131 X = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], dtype=np.float32)