From 809ce28b3820a2d208f1679d4665dff836ecb370 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Fri, 27 Aug 2021 09:55:45 +0200 Subject: [PATCH] feat: an option to show MLIR closes #224 --- hdk/hnumpy/compile.py | 6 ++++++ tests/hnumpy/test_compile.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/hdk/hnumpy/compile.py b/hdk/hnumpy/compile.py index 2dbe045d4..98259e0ba 100644 --- a/hdk/hnumpy/compile.py +++ b/hdk/hnumpy/compile.py @@ -132,6 +132,7 @@ def compile_numpy_function( dataset: Iterator[Tuple[Any, ...]], compilation_configuration: Optional[CompilationConfiguration] = None, compilation_artifacts: Optional[CompilationArtifacts] = None, + show_mlir: bool = False, ) -> CompilerEngine: """Main API of hnumpy, to be able to compile an homomorphic program. @@ -146,6 +147,8 @@ def compile_numpy_function( during compilation compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill during compilation + show_mlir (bool): if set, the MLIR produced by the converter and which is going + to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo Returns: CompilerEngine: engine to run and debug the compiled graph @@ -171,6 +174,9 @@ def compile_numpy_function( converter = MLIRConverter(V0_OPSET_CONVERSION_FUNCTIONS) mlir_result = converter.convert(op_graph) + if show_mlir: + print(f"MLIR which is going to be compiled: \n{mlir_result}") + # Compile the MLIR representation engine = CompilerEngine() engine.compile_fhe(mlir_result) diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 0a62b19a0..5f04714d9 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -236,3 +236,33 @@ def test_compile_function_with_dot(function, params, shape, ref_graph_str): f"\n==================\nExpected {ref_graph_str}" f"\n==================\n" ) + + +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(lambda x: x + 64, ((0, 10),), ["x"]), + pytest.param(lambda x: x * 3, ((0, 40),), ["x"]), + pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]), + pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]), + pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]), + pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]), + ], +) +def test_compile_with_show_mlir(function, input_ranges, list_of_arg_names): + """Test show_mlir option""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names + } + + compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + show_mlir=True, + )