diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index 3e204dd49..b26120ece 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -611,6 +611,79 @@ def prepare_op_graph_for_mlir(op_graph: OPGraph): hack_offset_negative_inputs_to_lookup_tables(op_graph) +def _compile_op_graph_to_fhe_circuit_internal( + op_graph: OPGraph, show_mlir: bool, compilation_artifacts: CompilationArtifacts +) -> FHECircuit: + """Compile the OPGraph to an FHECircuit. + + Args: + op_graph (OPGraph): the OPGraph to compile. + show_mlir (bool): determine whether we print the mlir string. + compilation_artifacts (CompilationArtifacts): Artifacts object to fill + during compilation + + Returns: + FHECircuit: the compiled FHECircuit + """ + prepare_op_graph_for_mlir(op_graph) + + # Convert graph to an MLIR representation + converter = NPMLIRConverter() + mlir_result = converter.convert(op_graph) + + # Show MLIR representation if requested + if show_mlir: + print(f"MLIR which is going to be compiled: \n{mlir_result}") + + # Add MLIR representation as an artifact + compilation_artifacts.add_final_operation_graph_mlir(mlir_result) + + # Compile the MLIR representation + engine = CompilerEngine() + engine.compile_fhe(mlir_result) + + return FHECircuit(op_graph, engine) + + +def compile_op_graph_to_fhe_circuit( + op_graph: OPGraph, + show_mlir: bool, + compilation_configuration: Optional[CompilationConfiguration] = None, + compilation_artifacts: Optional[CompilationArtifacts] = None, +) -> FHECircuit: + """Compile the OPGraph to an FHECircuit. + + Args: + op_graph (OPGraph): the OPGraph to compile. + show_mlir (bool): determine whether we print the mlir string. + compilation_configuration (Optional[CompilationConfiguration]): Configuration object to use + during compilation + compilation_artifacts (Optional[CompilationArtifacts]): Artifacts object to fill + during compilation + + Returns: + FHECircuit: the compiled circuit and the compiled FHECircuit + """ + + ( + compilation_configuration, + compilation_artifacts, + ) = sanitize_compilation_configuration_and_artifacts( + compilation_configuration, compilation_artifacts + ) + + def compilation_function(): + return _compile_op_graph_to_fhe_circuit_internal(op_graph, show_mlir, compilation_artifacts) + + result = run_compilation_function_with_error_management( + compilation_function, compilation_configuration, compilation_artifacts + ) + + # for mypy + assert isinstance(result, FHECircuit) + return result + + def _compile_numpy_function_internal( function_to_compile: Callable, function_parameters: Dict[str, BaseValue], @@ -648,24 +721,11 @@ def _compile_numpy_function_internal( compilation_artifacts, ) - prepare_op_graph_for_mlir(op_graph) + fhe_circuit = _compile_op_graph_to_fhe_circuit_internal( + op_graph, show_mlir, compilation_artifacts + ) - # Convert graph to an MLIR representation - converter = NPMLIRConverter() - mlir_result = converter.convert(op_graph) - - # Show MLIR representation if requested - if show_mlir: - print(f"MLIR which is going to be compiled: \n{mlir_result}") - - # Add MLIR representation as an artifact - compilation_artifacts.add_final_operation_graph_mlir(mlir_result) - - # Compile the MLIR representation - engine = CompilerEngine() - engine.compile_fhe(mlir_result) - - return FHECircuit(op_graph, engine) + return fhe_circuit def compile_numpy_function( diff --git a/concrete/numpy/np_fhe_compiler.py b/concrete/numpy/np_fhe_compiler.py index fe90b219c..d97e6fd1d 100644 --- a/concrete/numpy/np_fhe_compiler.py +++ b/concrete/numpy/np_fhe_compiler.py @@ -6,10 +6,15 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer +from ..common.fhe_circuit import FHECircuit from ..common.operator_graph import OPGraph from ..common.representation.intermediate import IntermediateNode from ..common.values import BaseValue -from .compile import compile_numpy_function_into_op_graph, measure_op_graph_bounds_and_update +from .compile import ( + compile_numpy_function_into_op_graph, + compile_op_graph_to_fhe_circuit, + measure_op_graph_bounds_and_update, +) from .np_dtypes_helpers import get_base_value_for_numpy_or_python_constant_data @@ -197,3 +202,32 @@ class NPFHECompiler: if isinstance(dtype := (input_.dtype), Integer): dtype.bit_width = 128 dtype.is_signed = True + + def get_compiled_fhe_circuit(self, show_mlir: bool = False) -> FHECircuit: + """Return a compiled FHECircuit if the instance was evaluated on an inputset. + + Args: + show_mlir (bool, optional): if set, the MLIR produced by the converter and which is + going to be sent to the compiler backend is shown on the screen, e.g., for debugging + or demo. Defaults to False. + + Raises: + RuntimeError: raised if no inputset was passed to the instance. + + Returns: + FHECircuit: the compiled FHECircuit + """ + self._eval_on_current_inputset() + + if self._op_graph is None: + raise RuntimeError( + "Requested FHECircuit but no OPGraph was compiled. " + f"Did you forget to evaluate {self.__class__.__name__} over an inputset?" + ) + + return compile_op_graph_to_fhe_circuit( + self._op_graph, + show_mlir, + self.compilation_configuration, + self.compilation_artifacts, + ) diff --git a/tests/numpy/test_compile_user_friendly_api.py b/tests/numpy/test_compile_user_friendly_api.py index 02fbf4c6d..10061313b 100644 --- a/tests/numpy/test_compile_user_friendly_api.py +++ b/tests/numpy/test_compile_user_friendly_api.py @@ -22,17 +22,20 @@ def complicated_topology(x, y): @pytest.mark.parametrize("input_shape", [(), (3, 1, 2)]) -def test_np_fhe_compiler(input_shape, default_compilation_configuration): +def test_np_fhe_compiler_op_graph(input_shape, default_compilation_configuration): """Test NPFHECompiler in two subtests.""" - subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration) - subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration) + subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration) + subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration) -def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configuration): +def subtest_np_fhe_compiler_1_input_op_graph(input_shape, default_compilation_configuration): """test for NPFHECompiler on one input function""" + def function_to_compile(x): + return complicated_topology(x, 0) + compiler = NPFHECompiler( - lambda x: complicated_topology(x, 0), + function_to_compile, {"x": "encrypted"}, default_compilation_configuration, ) @@ -45,7 +48,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati for i in numpy.arange(5): i = numpy.ones(input_shape, dtype=numpy.int64) * i - assert numpy.array_equal(compiler(i), complicated_topology(i, 0)) + assert numpy.array_equal(compiler(i), function_to_compile(i)) # For coverage, check that we flush the inputset when we query the OPGraph current_op_graph = compiler.op_graph @@ -57,7 +60,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati # Continue a bit more for i in numpy.arange(5, 10): i = numpy.ones(input_shape, dtype=numpy.int64) * i - assert numpy.array_equal(compiler(i), complicated_topology(i, 0)) + assert numpy.array_equal(compiler(i), function_to_compile(i)) if input_shape == (): assert ( @@ -103,7 +106,7 @@ def subtest_np_fhe_compiler_1_input(input_shape, default_compilation_configurati ), got -def subtest_np_fhe_compiler_2_inputs(input_shape, default_compilation_configuration): +def subtest_np_fhe_compiler_2_inputs_op_graph(input_shape, default_compilation_configuration): """test for NPFHECompiler on two inputs function""" compiler = NPFHECompiler( @@ -199,17 +202,51 @@ def test_np_fhe_compiler_auto_flush( default_compilation_configuration, ): """Test the auto flush of NPFHECompiler once the inputset is 128 elements.""" + + def function_to_compile(x): + return x // 2 + compiler = NPFHECompiler( - lambda x: x // 2, + function_to_compile, {"x": "encrypted"}, default_compilation_configuration, ) for i in numpy.arange(inputset_len): - assert numpy.array_equal(compiler(i), i // 2) + assert numpy.array_equal(compiler(i), function_to_compile(i)) # Check the inputset was properly flushed assert ( len(compiler._current_inputset) # pylint: disable=protected-access == expected_remaining_inputset_len ) + + +def test_np_fhe_compiler_full_compilation(default_compilation_configuration): + """Test the case where we generate an FHE circuit.""" + + def function_to_compile(x): + return x + 42 + + compiler = NPFHECompiler( + function_to_compile, + {"x": "encrypted"}, + default_compilation_configuration, + ) + + # For coverage + with pytest.raises(RuntimeError) as excinfo: + compiler.get_compiled_fhe_circuit() + + assert str(excinfo.value) == ( + "Requested FHECircuit but no OPGraph was compiled. " + "Did you forget to evaluate NPFHECompiler over an inputset?" + ) + + for i in numpy.arange(64): + assert numpy.array_equal(compiler(i), function_to_compile(i)) + + fhe_circuit = compiler.get_compiled_fhe_circuit() + + for i in range(64): + assert fhe_circuit.run(i) == function_to_compile(i)