mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: inputset of test_mlir_converter for dot operation
This commit is contained in:
@@ -213,13 +213,28 @@ def datagen(*args):
|
||||
},
|
||||
(range(0, 8),),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
inputset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
compiler.round_trip(mlir_result)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, args_dict, args_ranges",
|
||||
[
|
||||
(
|
||||
dot,
|
||||
{
|
||||
"x": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
"y": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
},
|
||||
(range(0, 8), range(0, 8)),
|
||||
(range(0, 4), range(0, 4)),
|
||||
),
|
||||
(
|
||||
dot,
|
||||
@@ -227,14 +242,22 @@ def datagen(*args):
|
||||
"x": ClearTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
"y": EncryptedTensor(Integer(64, is_signed=False), shape=(4,)),
|
||||
},
|
||||
(range(0, 8), range(0, 8)),
|
||||
(range(0, 4), range(0, 4)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_mlir_converter(func, args_dict, args_ranges):
|
||||
def test_mlir_converter_dot_between_vectors(func, args_dict, args_ranges):
|
||||
"""Test the conversion to MLIR by calling the parser from the compiler"""
|
||||
inputset = datagen(*args_ranges)
|
||||
result_graph = compile_numpy_function_into_op_graph(func, args_dict, inputset)
|
||||
assert len(args_dict["x"].shape) == 1
|
||||
assert len(args_dict["y"].shape) == 1
|
||||
|
||||
n = args_dict["x"].shape[0]
|
||||
|
||||
result_graph = compile_numpy_function_into_op_graph(
|
||||
func,
|
||||
args_dict,
|
||||
(([data[0]] * n, [data[1]] * n) for data in datagen(*args_ranges)),
|
||||
)
|
||||
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
|
||||
mlir_result = converter.convert(result_graph)
|
||||
# testing that this doesn't raise an error
|
||||
|
||||
Reference in New Issue
Block a user