mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
failed test case for onnx IF with jit (#14235)
silently fails now since onnx treats IF cond as a const
This commit is contained in:
37
test/external/external_test_onnx_ops.py
vendored
37
test/external/external_test_onnx_ops.py
vendored
@@ -137,6 +137,43 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
def test_if_different_shapes_not_broadcastable(self):
|
||||
self._test_if(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32), np.array([[6, 5, 4, 3, 2, 1]]).astype(np.float32))
|
||||
|
||||
def test_if_jit_different_shapes(self):
|
||||
# TODO: If with different output shapes and non-const condition should raise
|
||||
# When shapes differ, Python selection evaluates condition at graph build time, breaking JIT
|
||||
from tinygrad import TinyJit
|
||||
# then: x+1 shape (3,), else: x[:2]+1 shape (2,)
|
||||
x_input = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, (3,))
|
||||
then_out = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, (3,))
|
||||
then_body = onnx.helper.make_graph([
|
||||
onnx.helper.make_node("Constant", [], ["one"], value=onnx.numpy_helper.from_array(np.array(1, dtype=np.float32))),
|
||||
onnx.helper.make_node("Add", ["x", "one"], ["res"])], "then_body", [x_input], [then_out])
|
||||
else_out = onnx.helper.make_tensor_value_info("res", onnx.TensorProto.FLOAT, (2,))
|
||||
else_body = onnx.helper.make_graph([
|
||||
onnx.helper.make_node("Constant", [], ["starts"], value=onnx.numpy_helper.from_array(np.array([0], dtype=np.int32))),
|
||||
onnx.helper.make_node("Constant", [], ["ends"], value=onnx.numpy_helper.from_array(np.array([2], dtype=np.int32))),
|
||||
onnx.helper.make_node("Constant", [], ["one"], value=onnx.numpy_helper.from_array(np.array(1, dtype=np.float32))),
|
||||
onnx.helper.make_node("Slice", ["x", "starts", "ends"], ["x2"]),
|
||||
onnx.helper.make_node("Add", ["x2", "one"], ["res"])], "else_body", [x_input], [else_out])
|
||||
|
||||
cond_input = onnx.helper.make_tensor_value_info("cond", onnx.TensorProto.BOOL, (1,))
|
||||
main_x = onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, (3,))
|
||||
graph = onnx.helper.make_graph([onnx.helper.make_node("If", ["cond"], ["res"], then_branch=then_body, else_branch=else_body)],
|
||||
"test", [cond_input, main_x], [onnx.helper.make_empty_tensor_value_info("res")])
|
||||
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 22)])
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp:
|
||||
onnx.save(model, tmp.name)
|
||||
runner = OnnxRunner(tmp.name)
|
||||
|
||||
@TinyJit
|
||||
def run_if(cond, x): return runner({"cond": cond, "x": x})["res"]
|
||||
|
||||
x = Tensor([1.0, 2.0, 3.0])
|
||||
self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4]) # x + 1
|
||||
self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4])
|
||||
self.assertEqual(run_if(Tensor([True]), x).tolist(), [2, 3, 4])
|
||||
self.assertEqual(run_if(Tensor([False]), x).tolist(), [2, 3, 4]) # wrong! should be [2, 3]
|
||||
|
||||
def test_resize_downsample_scales_linear_align_corners(self):
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-131
|
||||
X = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]]]], dtype=np.float32)
|
||||
|
||||
Reference in New Issue
Block a user