fix: inputset of test_mlir_converter for dot operation

This commit is contained in:
Umut
2021-09-23 15:50:15 +03:00
parent b971f6b913
commit e4a06116ec

View File

@@ -213,13 +213,28 @@ def datagen(*args):
},
(range(0, 8),),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error
compiler.round_trip(mlir_result)
@pytest.mark.parametrize(
"func, args_dict, args_ranges",
[
(
dot,
{
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
(range(0, 4), range(0, 4)),
),
(
dot,
@@ -227,14 +242,22 @@ def datagen(*args):
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
},
(range(0, 8), range(0, 8)),
(range(0, 4), range(0, 4)),
),
],
)
def test_mlir_converter(func, args_dict, args_ranges):
def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
"""Test the conversion to MLIR by calling the parser from the compiler"""
inputset = datagen(*args_ranges)
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
assert len(args_dict["x"].shape) == 1
assert len(args_dict["y"].shape) == 1
n = args_dict["x"].shape[0]
result_graph = compile_numpy_function_into_op_graph(
func,
args_dict,
(([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)),
)
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(result_graph)
# testing that this doesn't raise an error