feat: an option to show MLIR

closes #224
This commit is contained in:
Benoit Chevallier-Mames
2021-08-27 09:55:45 +02:00
committed by Benoit Chevallier
parent 61daa49e9d
commit 809ce28b38
2 changed files with 36 additions and 0 deletions

View File

@@ -132,6 +132,7 @@ def compile_numpy_function(
dataset: Iterator[Tuple[Any, ...]],
compilation_configuration: Optional[CompilationConfiguration] = None,
compilation_artifacts: Optional[CompilationArtifacts] = None,
show_mlir: bool = False,
) -> CompilerEngine:
"""Main API of hnumpy, to be able to compile an homomorphic program.
@@ -146,6 +147,8 @@ def compile_numpy_function(
during compilation
compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill
during compilation
show_mlir (bool): if set, the MLIR produced by the converter and which is going
to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo
Returns:
CompilerEngine: engine to run and debug the compiled graph
@@ -171,6 +174,9 @@ def compile_numpy_function(
converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS)
mlir_result = converter.convert(op_graph)
if show_mlir:
print(f"MLIR which is going to be compiled: \n{mlir_result}")
# Compile the MLIR representation
engine = CompilerEngine()
engine.compile_fhe(mlir_result)

View File

@@ -236,3 +236,33 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str):
f"\n==================\nExpected {ref_graph_str}"
f"\n==================\n"
)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
],
)
def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names):
"""Test show_mlir option"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names
}
compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
show_mlir=True,
)