From dbfde466bcc51708f1d7fbade382839d31996073 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 9 Sep 2022 23:18:48 +0200 Subject: [PATCH] feat(python): Add compilation feedback to the python bindings --- .../concretelang-c/Support/CompilerEngine.h | 11 ++++- .../concretelang/Support/LibrarySupport.h | 10 +++-- compiler/lib/Bindings/Python/CMakeLists.txt | 1 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 30 ++++++++++---- .../Python/concrete/compiler/__init__.py | 1 + .../concrete/compiler/compilation_feedback.py | 40 +++++++++++++++++++ .../Python/concrete/compiler/jit_support.py | 23 +++++++++++ .../concrete/compiler/library_support.py | 26 ++++++++++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 23 ++++++++++- compiler/tests/python/test_compilation.py | 21 ++++------ 10 files changed, 160 insertions(+), 26 deletions(-) create mode 100644 compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 11ff45092..c162c3cda 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -47,6 +47,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters jit_load_client_parameters(JITSupport_C support, mlir::concretelang::JitCompilationResult &); +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +jit_load_compilation_feedback(JITSupport_C support, + mlir::concretelang::JitCompilationResult &); + MLIR_CAPI_EXPORTED std::shared_ptr jit_load_server_lambda(JITSupport_C support, mlir::concretelang::JitCompilationResult &); @@ -66,7 +70,8 @@ typedef struct LibrarySupport_C LibrarySupport_C; MLIR_CAPI_EXPORTED LibrarySupport_C library_support(const char *outputPath, const char *runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCppHeader); + bool generateClientParameters, bool generateCompilationFeedback, + bool generateCppHeader); MLIR_CAPI_EXPORTED std::unique_ptr library_compile(LibrarySupport_C support, const char *module, @@ -76,6 +81,10 @@ MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters library_load_client_parameters(LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &); +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +library_load_compilation_feedback( + LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &); + MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda library_load_server_lambda(LibrarySupport_C support, mlir::concretelang::LibraryCompilationResult &); diff --git a/compiler/include/concretelang/Support/LibrarySupport.h b/compiler/include/concretelang/Support/LibrarySupport.h index fe6be0a3f..6f9fde9ed 100644 --- a/compiler/include/concretelang/Support/LibrarySupport.h +++ b/compiler/include/concretelang/Support/LibrarySupport.h @@ -37,11 +37,13 @@ public: LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "", bool generateSharedLib = true, bool generateStaticLib = true, bool generateClientParameters = true, + bool generateCompilationFeedback = true, bool generateCppHeader = true) : outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath), generateSharedLib(generateSharedLib), generateStaticLib(generateStaticLib), generateClientParameters(generateClientParameters), + generateCompilationFeedback(generateCompilationFeedback), generateCppHeader(generateCppHeader) {} llvm::Expected> @@ -52,9 +54,10 @@ public: engine.setCompilationOptions(options); // Compile to a library - auto library = engine.compile(program, outputPath, runtimeLibraryPath, - generateSharedLib, generateStaticLib, - generateClientParameters, generateCppHeader); + auto library = engine.compile( + program, outputPath, runtimeLibraryPath, generateSharedLib, + generateStaticLib, generateClientParameters, + generateCompilationFeedback, generateCppHeader); if (auto err = library.takeError()) { return std::move(err); } @@ -136,6 +139,7 @@ private: bool generateSharedLib; bool generateStaticLib; bool generateClientParameters; + bool generateCompilationFeedback; bool generateCppHeader; }; diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index 3051ff6fa..0738ec662 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -29,6 +29,7 @@ declare_mlir_python_sources(ConcretelangBindingsPythonSources concrete/compiler/__init__.py concrete/compiler/client_parameters.py concrete/compiler/client_support.py + concrete/compiler/compilation_feedback.py concrete/compiler/compilation_options.py concrete/compiler/jit_compilation_result.py concrete/compiler/jit_support.py diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index d05db18d8..cafbd8b76 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -73,7 +73,22 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( }); pybind11::class_( - m, "CompilationFeedback"); + m, "CompilationFeedback") + .def_readonly("complexity", + &mlir::concretelang::CompilationFeedback::complexity) + .def_readonly( + "total_secret_keys_size", + &mlir::concretelang::CompilationFeedback::totalSecretKeysSize) + .def_readonly( + "total_bootstrap_keys_size", + &mlir::concretelang::CompilationFeedback::totalBootstrapKeysSize) + .def_readonly( + "total_keyswitch_keys_size", + &mlir::concretelang::CompilationFeedback::totalKeyswitchKeysSize) + .def_readonly("total_inputs_size", + &mlir::concretelang::CompilationFeedback::totalInputsSize) + .def_readonly("total_output_size", + &mlir::concretelang::CompilationFeedback::totalOutputsSize); pybind11::class_( m, "JITCompilationResult"); @@ -97,7 +112,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def("load_compilation_feedback", [](JITSupport_C &support, mlir::concretelang::JitCompilationResult &result) { - return jit_load_client_parameters(support, result); + return jit_load_compilation_feedback(support, result); }) .def( "load_server_lambda", @@ -127,11 +142,12 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def(pybind11::init( [](std::string outputPath, std::string runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCppHeader) { - return library_support(outputPath.c_str(), - runtimeLibraryPath.c_str(), - generateSharedLib, generateStaticLib, - generateClientParameters, generateCppHeader); + bool generateClientParameters, bool generateCompilationFeedback, + bool generateCppHeader) { + return library_support( + outputPath.c_str(), runtimeLibraryPath.c_str(), + generateSharedLib, generateStaticLib, generateClientParameters, + generateCompilationFeedback, generateCppHeader); })) .def("compile", [](LibrarySupport_C &support, std::string mlir_program, diff --git a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index d270b2953..b9b6d063b 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -15,6 +15,7 @@ from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip from .compilation_options import CompilationOptions from .key_set_cache import KeySetCache from .client_parameters import ClientParameters +from .compilation_feedback import CompilationFeedback from .key_set import KeySet from .public_result import PublicResult from .public_arguments import PublicArguments diff --git a/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py new file mode 100644 index 000000000..5096d6e3c --- /dev/null +++ b/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -0,0 +1,40 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. + +"""Compilation feedback.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + CompilationFeedback as _CompilationFeedback, +) + +# pylint: enable=no-name-in-module,import-error + +from .wrapper import WrapperCpp + + +class CompilationFeedback(WrapperCpp): + """CompilationFeedback is a set of hint computed by the compiler engine.""" + + def __init__(self, compilation_feedback: _CompilationFeedback): + """Wrap the native Cpp object. + + Args: + compilation_feeback (_CompilationFeedback): object to wrap + + Raises: + TypeError: if compilation_feedback is not of type _CompilationFeedback + """ + if not isinstance(compilation_feedback, _CompilationFeedback): + raise TypeError( + f"compilation_feedback must be of type _CompilationFeedback, not {type(compilation_feedback)}" + ) + + self.complexity = compilation_feedback.complexity + self.total_secret_keys_size = compilation_feedback.total_secret_keys_size + self.total_bootstrap_keys_size = compilation_feedback.total_bootstrap_keys_size + self.total_keyswitch_keys_size = compilation_feedback.total_keyswitch_keys_size + self.total_inputs_size = compilation_feedback.total_inputs_size + self.total_output_size = compilation_feedback.total_output_size + + super().__init__(compilation_feedback) diff --git a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py index 8711a1e49..e39508717 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/jit_support.py @@ -19,6 +19,7 @@ from .utils import lookup_runtime_lib from .compilation_options import CompilationOptions from .jit_compilation_result import JITCompilationResult from .client_parameters import ClientParameters +from .compilation_feedback import CompilationFeedback from .jit_lambda import JITLambda from .public_arguments import PublicArguments from .public_result import PublicResult @@ -121,6 +122,28 @@ class JITSupport(WrapperCpp): self.cpp().load_client_parameters(compilation_result.cpp()) ) + def load_compilation_feedback( + self, compilation_result: JITCompilationResult + ) -> CompilationFeedback: + """Load the compilation feedback from the JIT compilation result. + + Args: + compilation_result (JITCompilationResult): result of the JIT compilation + + Raises: + TypeError: if compilation_result is not of type JITCompilationResult + + Returns: + CompilationFeedback: the compilation feedback for the compiled program + """ + if not isinstance(compilation_result, JITCompilationResult): + raise TypeError( + f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}" + ) + return CompilationFeedback.wrap( + self.cpp().load_compilation_feedback(compilation_result.cpp()) + ) + def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda: """Load the JITLambda from the JIT compilation result. diff --git a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py index 8e1bf3e38..5c6445e9e 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/library_support.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/library_support.py @@ -21,6 +21,7 @@ from .public_arguments import PublicArguments from .library_lambda import LibraryLambda from .public_result import PublicResult from .client_parameters import ClientParameters +from .compilation_feedback import CompilationFeedback from .wrapper import WrapperCpp from .utils import lookup_runtime_lib from .evaluation_keys import EvaluationKeys @@ -71,6 +72,7 @@ class LibrarySupport(WrapperCpp): generateSharedLib: bool = True, generateStaticLib: bool = False, generateClientParameters: bool = True, + generateCompilationFeedback: bool = True, generateCppHeader: bool = False, ) -> "LibrarySupport": """Build a LibrarySupport. @@ -104,6 +106,7 @@ class LibrarySupport(WrapperCpp): ("generateSharedLib", generateSharedLib), ("generateStaticLib", generateStaticLib), ("generateClientParameters", generateClientParameters), + ("generateCompilationFeedback", generateCompilationFeedback), ("generateCppHeader", generateCppHeader), ]: if not isinstance(value, bool): @@ -115,6 +118,7 @@ class LibrarySupport(WrapperCpp): generateSharedLib, generateStaticLib, generateClientParameters, + generateCompilationFeedback, generateCppHeader, ) ) @@ -188,6 +192,28 @@ class LibrarySupport(WrapperCpp): self.cpp().load_client_parameters(library_compilation_result.cpp()) ) + def load_compilation_feedback( + self, compilation_result: LibraryCompilationResult + ) -> CompilationFeedback: + """Load the compilation feedback from the JIT compilation result. + + Args: + compilation_result (JITCompilationResult): result of the JIT compilation + + Raises: + TypeError: if compilation_result is not of type JITCompilationResult + + Returns: + CompilationFeedback: the compilation feedback for the compiled program + """ + if not isinstance(compilation_result, LibraryCompilationResult): + raise TypeError( + f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}" + ) + return CompilationFeedback.wrap( + self.cpp().load_compilation_feedback(compilation_result.cpp()) + ) + def load_server_lambda( self, library_compilation_result: LibraryCompilationResult ) -> LibraryLambda: diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index b00d2c6df..0eae46168 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -44,6 +44,14 @@ jit_load_client_parameters(JITSupport_C support, return *clientParameters; } +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +jit_load_compilation_feedback( + JITSupport_C support, mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, + support.support.loadCompilationFeedback(result)); + return *compilationFeedback; +} + MLIR_CAPI_EXPORTED std::shared_ptr jit_load_server_lambda(JITSupport_C support, mlir::concretelang::JitCompilationResult &result) { @@ -64,10 +72,12 @@ jit_server_call(JITSupport_C support, mlir::concretelang::JITLambda &lambda, MLIR_CAPI_EXPORTED LibrarySupport_C library_support(const char *outputPath, const char *runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, - bool generateClientParameters, bool generateCppHeader) { + bool generateClientParameters, bool generateCompilationFeedback, + bool generateCppHeader) { return LibrarySupport_C{mlir::concretelang::LibrarySupport( outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, - generateClientParameters, generateCppHeader)}; + generateClientParameters, generateCompilationFeedback, + generateCppHeader)}; } std::unique_ptr @@ -87,6 +97,15 @@ library_load_client_parameters( return *clientParameters; } +MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback +library_load_compilation_feedback( + LibrarySupport_C support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, + support.support.loadCompilationFeedback(result)); + return *compilationFeedback; +} + MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda library_load_server_lambda( LibrarySupport_C support, diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 08fd84d08..505de40e9 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -7,6 +7,7 @@ from concrete.compiler import ( LibrarySupport, ClientSupport, CompilationOptions, + CompilationFeedback, ) @@ -27,19 +28,16 @@ def run(engine, args, compilation_result, keyset_cache): Perform required loading, encryption, execution, and decryption.""" # Dev - compilation_feedback = engine.load_compilation_feedback( - compilation_result) - assert(compilation_feedback is not None) + compilation_feedback = engine.load_compilation_feedback(compilation_result) + assert isinstance(compilation_feedback, CompilationFeedback) # Client client_parameters = engine.load_client_parameters(compilation_result) key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments( - client_parameters, key_set, args) + public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) # Server server_lambda = engine.load_server_lambda(compilation_result) evaluation_keys = key_set.get_evaluation_keys() - public_result = engine.server_call( - server_lambda, public_arguments, evaluation_keys) + public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) # Client result = ClientSupport.decrypt_result(key_set, public_result) return result @@ -141,10 +139,8 @@ end_to_end_parallel_fixture = [ } """, ( - np.array([[1, 2, 3, 4], [4, 2, 1, 0], [ - 2, 3, 1, 5]], dtype=np.uint8), - np.array([[1, 2, 3, 4], [4, 2, 1, 1], [ - 2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8), ), np.array([[52, 36], [31, 34], [42, 52]]), id="matmul_eint_int_uint8", @@ -228,8 +224,7 @@ def test_lib_compile_and_run_p_error(keyset_cache): options = CompilationOptions.new("main") options.set_p_error(0.00001) options.set_display_optimizer_choice(True) - compile_run_assert(engine, mlir_input, args, - expected_result, keyset_cache, options) + compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options) def test_lib_compile_and_run_p_error(keyset_cache):