From e4a06116ecb093be73505620c6669ea51a555729 Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 23 Sep 2021 15:50:15 +0300 Subject: [PATCH] fix: inputset of test_mlir_converter for dot operation --- tests/common/mlir/test_mlir_converter.py | 33 ++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/tests/common/mlir/test_mlir_converter.py b/tests/common/mlir/test_mlir_converter.py index fec21c6e1..2ab61b91a 100644 --- a/tests/common/mlir/test_mlir_converter.py +++ b/tests/common/mlir/test_mlir_converter.py @@ -213,13 +213,28 @@ def datagen(*args): }, (range(0, 8),), ), + ], +) +def test_mlir_converter(func, args_dict, args_ranges): + """Test the conversion to MLIR by calling the parser from the compiler""" + inputset = datagen(*args_ranges) + result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset) + converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) + mlir_result = converter.convert(result_graph) + # testing that this doesn't raise an error + compiler.round_trip(mlir_result) + + +@pytest.mark.parametrize( + "func, args_dict, args_ranges", + [ ( dot, { "x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), "y": ClearTensor(Integer(64, is_signed=False), shape=(4,)), }, - (range(0, 8), range(0, 8)), + (range(0, 4), range(0, 4)), ), ( dot, @@ -227,14 +242,22 @@ def datagen(*args): "x": ClearTensor(Integer(64, is_signed=False), shape=(4,)), "y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)), }, - (range(0, 8), range(0, 8)), + (range(0, 4), range(0, 4)), ), ], ) -def test_mlir_converter(func, args_dict, args_ranges): +def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges): """Test the conversion to MLIR by calling the parser from the compiler""" - inputset = datagen(*args_ranges) - result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset) + assert len(args_dict["x"].shape) == 1 + assert len(args_dict["y"].shape) == 1 + + n = args_dict["x"].shape[0] + + result_graph = compile_numpy_function_into_op_graph( + func, + args_dict, + (([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)), + ) converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(result_graph) # testing that this doesn't raise an error