From 8867d313eed500ff070b8cbe7bff20f99ff08cd9 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 15 Mar 2022 14:26:15 +0100 Subject: [PATCH] feat(python): Expose Jit and Library compiler support --- .../concretelang-c/Support/CompilerEngine.h | 78 +++- .../concretelang/Support/JitLambdaSupport.h | 3 +- .../Support/LibraryLambdaSupport.h | 3 +- .../lib/Bindings/Python/CompilerAPIModule.cpp | 106 ++++++ .../lib/Bindings/Python/concrete/compiler.py | 206 +++++++++-- compiler/lib/CAPI/Support/CompilerEngine.cpp | 127 ++++++- compiler/lib/Support/JitLambdaSupport.cpp | 8 + compiler/tests/python/test_compiler_engine.py | 332 +++++++++++------- compiler/tests/unittest/end_to_end_jit_fhe.cc | 8 +- 9 files changed, 711 insertions(+), 160 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index ee7ec9468..50a86f889 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -9,6 +9,8 @@ #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Jit.h" #include "concretelang/Support/JitCompilerEngine.h" +#include "concretelang/Support/JitLambdaSupport.h" +#include "concretelang/Support/LibraryLambdaSupport.h" #include "mlir-c/IR.h" #include "mlir-c/Registration.h" @@ -35,11 +37,79 @@ struct executionArguments { }; typedef struct executionArguments executionArguments; +// JIT Support bindings /////////////////////////////////////////////////////// + +struct JITLambdaSupport_C { + mlir::concretelang::JitLambdaSupport support; +}; +typedef struct JITLambdaSupport_C JITLambdaSupport_C; + +MLIR_CAPI_EXPORTED JITLambdaSupport_C +jit_lambda_support(const char *runtimeLibPath); + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_compile(JITLambdaSupport_C support, const char *module, + const char *funcname); + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +jit_load_client_parameters(JITLambdaSupport_C support, + mlir::concretelang::JitCompilationResult &); + +MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda * +jit_load_server_lambda(JITLambdaSupport_C support, + mlir::concretelang::JitCompilationResult &); + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_server_call(JITLambdaSupport_C support, + mlir::concretelang::JITLambda *lambda, + concretelang::clientlib::PublicArguments &args); + +// Library Support bindings /////////////////////////////////////////////////// + +struct LibraryLambdaSupport_C { + mlir::concretelang::LibraryLambdaSupport support; +}; +typedef struct LibraryLambdaSupport_C LibraryLambdaSupport_C; + +MLIR_CAPI_EXPORTED LibraryLambdaSupport_C +library_lambda_support(const char *outputPath); + +MLIR_CAPI_EXPORTED std::unique_ptr +library_compile(LibraryLambdaSupport_C support, const char *module, + const char *funcname); + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +library_load_client_parameters(LibraryLambdaSupport_C support, + mlir::concretelang::LibraryCompilationResult &); + +MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda +library_load_server_lambda(LibraryLambdaSupport_C support, + mlir::concretelang::LibraryCompilationResult &); + +MLIR_CAPI_EXPORTED std::unique_ptr +library_server_call(LibraryLambdaSupport_C support, + concretelang::serverlib::ServerLambda lambda, + concretelang::clientlib::PublicArguments &args); + +// Client Support bindings /////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED std::unique_ptr +key_set(concretelang::clientlib::ClientParameters clientParameters, + llvm::Optional cache); + +MLIR_CAPI_EXPORTED std::unique_ptr +encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, + concretelang::clientlib::KeySet &keySet, + llvm::ArrayRef args); + +MLIR_CAPI_EXPORTED lambdaArgument +decrypt_result(concretelang::clientlib::KeySet &keySet, + concretelang::clientlib::PublicResult &publicResult); + // Build lambda from a textual representation of an MLIR module -// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if not -// null) as a shared library during compilation, -// a path to activate the use a cache for encryption keys for test purpose -// (unsecure), and a set of flags for parallelization. +// The lambda will have `funcName` as entrypoint, and use runtimeLibPath (if +// not null) as a shared library during compilation, a path to activate the +// use a cache for encryption keys for test purpose (unsecure). MLIR_CAPI_EXPORTED mlir::concretelang::JitCompilerEngine::Lambda buildLambda(const char *module, const char *funcName, const char *runtimeLibPath, const char *keySetCachePath, diff --git a/compiler/include/concretelang/Support/JitLambdaSupport.h b/compiler/include/concretelang/Support/JitLambdaSupport.h index 904b7da13..1a4e87b45 100644 --- a/compiler/include/concretelang/Support/JitLambdaSupport.h +++ b/compiler/include/concretelang/Support/JitLambdaSupport.h @@ -35,8 +35,7 @@ public: JitLambdaSupport( llvm::Optional runtimeLibPath = llvm::None, llvm::function_ref llvmOptPipeline = - mlir::makeOptimizingTransformer(3, 0, nullptr)) - : runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {} + mlir::makeOptimizingTransformer(3, 0, nullptr)); llvm::Expected> compile(llvm::SourceMgr &program, std::string funcname = "main") override; diff --git a/compiler/include/concretelang/Support/LibraryLambdaSupport.h b/compiler/include/concretelang/Support/LibraryLambdaSupport.h index 82edc484a..5d036d34a 100644 --- a/compiler/include/concretelang/Support/LibraryLambdaSupport.h +++ b/compiler/include/concretelang/Support/LibraryLambdaSupport.h @@ -33,8 +33,7 @@ class LibraryLambdaSupport : public LambdaSupport { public: - LibraryLambdaSupport(std::string outputPath = "/tmp/toto") - : outputPath(outputPath) {} + LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {} llvm::Expected> compile(llvm::SourceMgr &program, std::string funcname = "main") override { diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 26dc17199..b6871ec04 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -8,6 +8,7 @@ #include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc" #include "concretelang/Support/Jit.h" #include "concretelang/Support/JitCompilerEngine.h" +#include "concretelang/Support/JitLambdaSupport.h" #include #include #include @@ -20,6 +21,7 @@ #include using mlir::concretelang::JitCompilerEngine; +using mlir::concretelang::JitLambdaSupport; using mlir::concretelang::LambdaArgument; const char *noEmptyStringPtr(std::string &s) { @@ -54,6 +56,110 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( auto_parallelize, loop_parallelize, df_parallelize); }); + pybind11::class_( + m, "JitCompilationResult"); + pybind11::class_(m, "JITLambda"); + pybind11::class_(m, "JITLambdaSupport") + .def(pybind11::init([](std::string runtimeLibPath) { + return jit_lambda_support(runtimeLibPath.c_str()); + })) + .def("compile", + [](JITLambdaSupport_C &support, std::string mlir_program, + std::string func_name) { + return jit_compile(support, mlir_program.c_str(), + func_name.c_str()); + }) + .def("load_client_parameters", + [](JITLambdaSupport_C &support, + mlir::concretelang::JitCompilationResult &result) { + return jit_load_client_parameters(support, result); + }) + .def( + "load_server_lambda", + [](JITLambdaSupport_C &support, + mlir::concretelang::JitCompilationResult &result) { + return jit_load_server_lambda(support, result); + }, + pybind11::return_value_policy::reference) + .def("server_call", + [](JITLambdaSupport_C &support, concretelang::JITLambda *lambda, + clientlib::PublicArguments &publicArguments) { + return jit_server_call(support, lambda, publicArguments); + }); + + pybind11::class_( + m, "LibraryCompilationResult") + .def(pybind11::init([](std::string libraryPath, std::string funcname) { + return mlir::concretelang::LibraryCompilationResult{ + libraryPath, + funcname, + }; + })); + pybind11::class_(m, "LibraryLambda"); + pybind11::class_(m, "LibraryLambdaSupport") + .def(pybind11::init([](std::string outputPath) { + return library_lambda_support(outputPath.c_str()); + })) + .def("compile", + [](LibraryLambdaSupport_C &support, std::string mlir_program, + std::string func_name) { + return library_compile(support, mlir_program.c_str(), + func_name.c_str()); + }) + .def("load_client_parameters", + [](LibraryLambdaSupport_C &support, + mlir::concretelang::LibraryCompilationResult &result) { + return library_load_client_parameters(support, result); + }) + .def( + "load_server_lambda", + [](LibraryLambdaSupport_C &support, + mlir::concretelang::LibraryCompilationResult &result) { + return library_load_server_lambda(support, result); + }, + pybind11::return_value_policy::reference) + .def("server_call", + [](LibraryLambdaSupport_C &support, serverlib::ServerLambda lambda, + clientlib::PublicArguments &publicArguments) { + return library_server_call(support, lambda, publicArguments); + }); + + class ClientSupport {}; + pybind11::class_(m, "ClientSupport") + .def(pybind11::init()) + .def_static( + "key_set", + [](clientlib::ClientParameters clientParameters, + clientlib::KeySetCache *cache) { + auto optCache = + cache == nullptr + ? llvm::None + : llvm::Optional(*cache); + return key_set(clientParameters, optCache); + }, + pybind11::arg().none(false), pybind11::arg().none(true)) + .def_static("encrypt_arguments", + [](clientlib::ClientParameters clientParameters, + clientlib::KeySet &keySet, + std::vector args) { + std::vector argsRef; + for (auto i = 0u; i < args.size(); i++) { + argsRef.push_back(args[i].ptr.get()); + } + return encrypt_arguments(clientParameters, keySet, argsRef); + }) + .def_static("decrypt_result", [](clientlib::KeySet &keySet, + clientlib::PublicResult &publicResult) { + return decrypt_result(keySet, publicResult); + }); + pybind11::class_(m, "KeySetCache") + .def(pybind11::init()); + + pybind11::class_(m, "ClientParameters"); + + pybind11::class_(m, "KeySet"); + pybind11::class_(m, "PublicArguments"); + pybind11::class_(m, "PublicResult"); pybind11::class_(m, "LambdaArgument") .def_static("from_tensor", diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py index 31fac3584..0cbda438c 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ b/compiler/lib/Bindings/Python/concrete/compiler.py @@ -14,6 +14,23 @@ from mlir._mlir_libs._concretelang._compiler import ( from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument from mlir._mlir_libs._concretelang._compiler import round_trip as _round_trip from mlir._mlir_libs._concretelang._compiler import library as _library +from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport +from mlir._mlir_libs._concretelang._compiler import LibraryLambdaSupport +from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport + +from mlir._mlir_libs._concretelang._compiler import ClientParameters + +from mlir._mlir_libs._concretelang._compiler import KeySet +from mlir._mlir_libs._concretelang._compiler import KeySetCache + +from mlir._mlir_libs._concretelang._compiler import PublicResult +from mlir._mlir_libs._concretelang._compiler import PublicArguments + +from mlir._mlir_libs._concretelang._compiler import JitCompilationResult +from mlir._mlir_libs._concretelang._compiler import JITLambda + +from mlir._mlir_libs._concretelang._compiler import LibraryCompilationResult +from mlir._mlir_libs._concretelang._compiler import LibraryLambda import numpy as np @@ -46,7 +63,8 @@ def _lookup_runtime_lib() -> str: for filename in os.listdir(libs_path) if filename.startswith("libConcretelangRuntime") ] - assert len(runtime_library_paths) == 1, "should be one and only one runtime library" + assert len( + runtime_library_paths) == 1, "should be one and only one runtime library" return os.path.join(libs_path, runtime_library_paths[0]) @@ -67,10 +85,10 @@ def round_trip(mlir_str: str) -> str: return _round_trip(mlir_str) -_MLIR_MODULES_TYPE = "mlir_modules must be an `iterable` of `str` or a `str" +_MLIR_MODULES_TYPE = 'mlir_modules must be an `iterable` of `str` or a `str' -def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str: +def library(library_path: str, mlir_modules: Union['Iterable[str]', str]) -> str: """Compile the MLIR inputs to a library. Args: @@ -101,7 +119,7 @@ def library(library_path: str, mlir_modules: Union["Iterable[str]", str]) -> str return _library(library_path, mlir_modules) -def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument": +def create_execution_argument(value: Union[int, np.ndarray]) -> _LambdaArgument: """Create an execution argument holding either an int or tensor value. Args: @@ -115,8 +133,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument """ if not isinstance(value, ACCEPTED_TYPES): raise TypeError( - "value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}" - ) + "value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}") if isinstance(value, ACCEPTED_INTS): if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max): raise TypeError( @@ -134,7 +151,7 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument class CompilerEngine: def __init__(self, mlir_str: str = None): - self._engine = _JitCompilerEngine() + self._engine = JITCompilerSupport() self._lambda = None if mlir_str is not None: self.compile_fhe(mlir_str) @@ -182,17 +199,17 @@ class CompilerEngine: ) unsecure_key_set_cache_path = unsecure_key_set_cache_path or "" if not isinstance(unsecure_key_set_cache_path, str): - raise TypeError("unsecure_key_set_cache_path must be a str") - - self._lambda = self._engine.build_lambda( - mlir_str, - func_name, - runtime_lib_path, - unsecure_key_set_cache_path, - auto_parallelize, - loop_parallelize, - df_parallelize, - ) + raise TypeError( + "unsecure_key_set_cache_path must be a str" + ) + self._compilation_result = self._engine.compile(mlir_str) + self._client_parameters = self._engine.load_client_parameters( + self._compilation_result) + keyset_cache = None + if not unsecure_key_set_cache_path is None: + keyset_cache = KeySetCache(unsecure_key_set_cache_path) + self._key_set = ClientSupport.key_set( + self._client_parameters, keyset_cache) def run(self, *args: List[Union[int, np.ndarray]]) -> Union[int, np.ndarray]: """Run the compiled code. @@ -208,10 +225,59 @@ class CompilerEngine: Returns: int or numpy.array: result of execution. """ - if self._lambda is None: + if self._compilation_result is None: raise RuntimeError("need to compile an MLIR code first") + # Client + public_arguments = ClientSupport.encrypt_arguments(self._client_parameters, + self._key_set, args) + # Server + server_lambda = self._engine.load_server_lambda( + self._compilation_result) + public_result = self._engine.server_call( + server_lambda, public_arguments) + # Client + return ClientSupport.decrypt_result(self._key_set, public_result) + + +class ClientSupport: + def key_set(client_parameters: ClientParameters, cache: KeySetCache = None) -> KeySet: + """Generates a key set according to the given client parameters. + If the cache is set the key set is loaded from it if exists, else the new generated key set is saved in the cache + + Args: + client_parameters: A client parameters specification + cache: An optional cache of key set. + + Returns: + KeySet: the key set + """ + return _ClientSupport.key_set(client_parameters, cache) + + def encrypt_arguments(client_parameters: ClientParameters, key_set: KeySet, args: List[Union[int, np.ndarray]]) -> PublicArguments: + """Export clear arguments to public arguments. + For each arguments this method encrypts the argument if it's declared as encrypted and pack to the public arguments object. + + Args: + client_parameters: A client parameters specification + key_set: A key set used to encrypt encrypted arguments + + Returns: + PublicArguments: the public arguments + """ execution_arguments = [create_execution_argument(arg) for arg in args] - lambda_arg = self._lambda.invoke(execution_arguments) + return _ClientSupport.encrypt_arguments(client_parameters, key_set, execution_arguments) + + def decrypt_result(key_set: KeySet, public_result: PublicResult) -> Union[int, np.ndarray]: + """Decrypt a public result thanks the given key set. + + Args: + key_set: The key set used to decrypt the result. + public_result: The public result to descrypt. + + Returns: + int or numpy.array: The result of decryption. + """ + lambda_arg = _ClientSupport.decrypt_result(key_set, public_result) if lambda_arg.is_scalar(): return lambda_arg.get_scalar() elif lambda_arg.is_tensor(): @@ -220,3 +286,103 @@ class CompilerEngine: return tensor else: raise RuntimeError("unknown return type") + + +class JITCompilerSupport: + def __init__(self, runtime_lib_path=None): + if runtime_lib_path is None: + runtime_lib_path = _lookup_runtime_lib() + self._support = JITLambdaSupport(runtime_lib_path) + + def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult: + """JIT Compile a function define in the mlir_program to its homomorphic equivalent. + + Args: + mlir_program: A textual representation of the mlir program to compile. + func_name: The name of the function to compile. + + Returns: + JITCompilationResult: the result of the JIT compilation. + """ + if not isinstance(mlir_program, str): + raise TypeError("mlir_program must be an `str`") + return self._support.compile(mlir_program, func_name) + + def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters: + """Load the client parameters from the JIT compilation result""" + return self._support.load_client_parameters(compilation_result) + + def load_server_lambda(self, compilation_result: JitCompilationResult) -> JITLambda: + """Load the server lambda from the JIT compilation result""" + return self._support.load_server_lambda(compilation_result) + + def server_call(self, server_lambda: JITLambda, public_arguments: PublicArguments): + """Call the server lambda with public_arguments + + Args: + server_lambda: A server lambda to call + public_arguments: The arguments of the call + + Returns: + PublicResult: the result of the call of the server lambda + """ + return self._support.server_call(server_lambda, public_arguments) + + +class LibraryCompilerSupport: + def __init__(self, outputPath="./out"): + self._library_path = outputPath + self._support = LibraryLambdaSupport(outputPath) + + def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult: + """Compile a function define in the mlir_program to its homomorphic equivalent and save as library. + + Args: + mlir_program: A textual representation of the mlir program to compile. + func_name: The name of the function to compile. + + Returns: + LibraryCompilationResult: the result of the compilation. + """ + if not isinstance(mlir_program, str): + raise TypeError("mlir_program must be an `str`") + if not isinstance(func_name, str): + raise TypeError("mlir_program must be an `str`") + return self._support.compile(mlir_program, func_name) + + def reload(self, func_name: str = "main") -> LibraryCompilationResult: + """Reload the library compilation result from the outputPath. + Args: + library-path: The path of the compiled library. + func_name: The name of the compiled function. + + Returns: + LibraryCompilationResult: the result of a compilation. + """ + if not isinstance(func_name, str): + raise TypeError("func_name must be an `str`") + return LibraryCompilationResult(self._library_path, func_name) + + def load_client_parameters(self, compilation_result: LibraryCompilationResult) -> ClientParameters: + """Load the client parameters from the JIT compilation result""" + if not isinstance(compilation_result, LibraryCompilationResult): + raise TypeError( + "compilation_result must be an `LibraryCompilationResult`") + + return self._support.load_client_parameters(compilation_result) + + def load_server_lambda(self, compilation_result: LibraryCompilationResult) -> LibraryLambda: + """Load the server lambda from the JIT compilation result""" + return self._support.load_server_lambda(compilation_result) + + def server_call(self, server_lambda: LibraryLambda, public_arguments: PublicArguments) -> PublicResult: + """Call the server lambda with public_arguments + + Args: + server_lambda: A server lambda to call + public_arguments: The arguments of the call + + Returns: + PublicResult: the result of the call of the server lambda + """ + return self._support.server_call(server_lambda, public_arguments) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 0aeb76097..8145d957b 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -11,9 +11,134 @@ #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Jit.h" #include "concretelang/Support/JitCompilerEngine.h" +#include "concretelang/Support/JitLambdaSupport.h" using mlir::concretelang::JitCompilerEngine; +#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ + auto VARNAME = EXPECTED; \ + if (auto err = VARNAME.takeError()) { \ + throw std::runtime_error(llvm::toString(std::move(err))); \ + } + +// JIT Support bindings /////////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED JITLambdaSupport_C +jit_lambda_support(const char *runtimeLibPath) { + llvm::StringRef str(runtimeLibPath); + auto opt = str.empty() ? llvm::None : llvm::Optional(str); + return JITLambdaSupport_C{mlir::concretelang::JitLambdaSupport(opt)}; +} + +std::unique_ptr +jit_compile(JITLambdaSupport_C support, const char *module, + const char *funcname) { + mlir::concretelang::JitLambdaSupport esupport; + GET_OR_THROW_LLVM_EXPECTED(compilationResult, + esupport.compile(module, funcname)); + return std::move(*compilationResult); +} + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +jit_load_client_parameters(JITLambdaSupport_C support, + mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(clientParameters, + support.support.loadClientParameters(result)); + return *clientParameters; +} + +MLIR_CAPI_EXPORTED mlir::concretelang::JITLambda * +jit_load_server_lambda(JITLambdaSupport_C support, + mlir::concretelang::JitCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(serverLambda, + support.support.loadServerLambda(result)); + return *serverLambda; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +jit_server_call(JITLambdaSupport_C support, + mlir::concretelang::JITLambda *lambda, + concretelang::clientlib::PublicArguments &args) { + GET_OR_THROW_LLVM_EXPECTED(publicResult, + support.support.serverCall(lambda, args)); + return std::move(*publicResult); +} + +// Library Support bindings /////////////////////////////////////////////////// +MLIR_CAPI_EXPORTED LibraryLambdaSupport_C +library_lambda_support(const char *outputPath) { + return LibraryLambdaSupport_C{ + mlir::concretelang::LibraryLambdaSupport(outputPath)}; +} + +std::unique_ptr +library_compile(LibraryLambdaSupport_C support, const char *module, + const char *funcname) { + GET_OR_THROW_LLVM_EXPECTED(compilationResult, + support.support.compile(module, funcname)); + return std::move(*compilationResult); +} + +MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters +library_load_client_parameters( + LibraryLambdaSupport_C support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(clientParameters, + support.support.loadClientParameters(result)); + return *clientParameters; +} + +MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda +library_load_server_lambda( + LibraryLambdaSupport_C support, + mlir::concretelang::LibraryCompilationResult &result) { + GET_OR_THROW_LLVM_EXPECTED(serverLambda, + support.support.loadServerLambda(result)); + return *serverLambda; +} + +MLIR_CAPI_EXPORTED std::unique_ptr +library_server_call(LibraryLambdaSupport_C support, + concretelang::serverlib::ServerLambda lambda, + concretelang::clientlib::PublicArguments &args) { + GET_OR_THROW_LLVM_EXPECTED(publicResult, + support.support.serverCall(lambda, args)); + return std::move(*publicResult); +} + +// Client Support bindings /////////////////////////////////////////////////// + +MLIR_CAPI_EXPORTED std::unique_ptr +key_set(concretelang::clientlib::ClientParameters clientParameters, + llvm::Optional cache) { + GET_OR_THROW_LLVM_EXPECTED( + ks, (mlir::concretelang::LambdaSupport::keySet(clientParameters, + cache))); + return std::move(*ks); +} + +MLIR_CAPI_EXPORTED std::unique_ptr +encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters, + concretelang::clientlib::KeySet &keySet, + llvm::ArrayRef args) { + GET_OR_THROW_LLVM_EXPECTED( + publicArguments, + (mlir::concretelang::LambdaSupport::exportArguments( + clientParameters, keySet, args))); + return std::move(*publicArguments); +} + +MLIR_CAPI_EXPORTED lambdaArgument +decrypt_result(concretelang::clientlib::KeySet &keySet, + concretelang::clientlib::PublicResult &publicResult) { + GET_OR_THROW_LLVM_EXPECTED( + result, mlir::concretelang::typedResult< + std::unique_ptr>( + keySet, publicResult)); + lambdaArgument result_{std::move(*result)}; + return std::move(result_); +} + mlir::concretelang::JitCompilerEngine::Lambda buildLambda(const char *module, const char *funcName, const char *runtimeLibPath, const char *keySetCachePath, @@ -216,7 +341,7 @@ std::string library(std::string libraryPath, using namespace mlir::concretelang; JitCompilerEngine ce{CompilationContext::createShared()}; - auto lib = ce.compile(mlir_modules, libraryPath); + auto lib = ce.compile(mlir_modules, libraryPath); if (!lib) { throw std::runtime_error("Can't link: " + llvm::toString(lib.takeError())); } diff --git a/compiler/lib/Support/JitLambdaSupport.cpp b/compiler/lib/Support/JitLambdaSupport.cpp index 1dba84f4c..5d6b85842 100644 --- a/compiler/lib/Support/JitLambdaSupport.cpp +++ b/compiler/lib/Support/JitLambdaSupport.cpp @@ -4,12 +4,20 @@ // for license information. #include +#include +#include namespace mlir { namespace concretelang { +JitLambdaSupport::JitLambdaSupport( + llvm::Optional runtimeLibPath, + llvm::function_ref llvmOptPipeline) + : runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {} + llvm::Expected> JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) { + // Setup the compiler engine auto context = std::make_shared(); concretelang::CompilerEngine engine(context); diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 530d331e7..9cb72f894 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -4,80 +4,116 @@ import tempfile import pytest import numpy as np from concrete.compiler import CompilerEngine, library +from lib.Bindings.Python.concrete.compiler import JITCompilerSupport, LibraryCompilerSupport +from lib.Bindings.Python.concrete.compiler import ClientSupport +from lib.Bindings.Python.concrete.compiler import KeySetCache -KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') +KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') -@pytest.mark.parametrize( - "mlir_input, args, expected_result", - [ - pytest.param( - """ +keySetCacheTest = KeySetCache(KEY_SET_CACHE_PATH) + + +def compile_and_run(engine, mlir_input, args, expected_result): + compilation_result = engine.compile(mlir_input) + # Client + client_parameters = engine.load_client_parameters(compilation_result) + key_set = ClientSupport.key_set(client_parameters, keySetCacheTest) + public_arguments = ClientSupport.encrypt_arguments( + client_parameters, key_set, args) + # Server + server_lambda = engine.load_server_lambda(compilation_result) + public_result = engine.server_call(server_lambda, public_arguments) + # Client + result = ClientSupport.decrypt_result(key_set, public_result) + # Check result + assert type(expected_result) == type(result) + if isinstance(expected_result, int): + assert result == expected_result + else: + assert np.all(result == expected_result) + + +end_to_end_fixture = [ + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (5, 7), - 12, - id="add_eint_int", - ), - pytest.param( - """ + (5, 7), + 12, + id="add_eint_int", + ), + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)), - 9, - id="add_eint_int_with_ndarray_as_scalar", - ), - pytest.param( - """ + (np.array(4, dtype=np.uint8), np.array(5, dtype=np.uint8)), + 9, + id="add_eint_int_with_ndarray_as_scalar", + ), + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (np.uint8(3), np.uint8(5)), - 8, - id="add_eint_int_with_np_uint8_as_scalar", - ), - pytest.param( - """ + (np.uint8(3), np.uint8(5)), + 8, + id="add_eint_int_with_np_uint8_as_scalar", + ), + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (np.uint16(3), np.uint16(5)), - 8, - id="add_eint_int_with_np_uint16_as_scalar", - ), - pytest.param( - """ + (np.uint16(3), np.uint16(5)), + 8, + id="add_eint_int_with_np_uint16_as_scalar", + ), + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (np.uint32(3), np.uint32(5)), - 8, - id="add_eint_int_with_np_uint32_as_scalar", - ), - pytest.param( - """ + (np.uint32(3), np.uint32(5)), + 8, + id="add_eint_int_with_np_uint32_as_scalar", + ), + pytest.param( + """ func @main(%arg0: !FHE.eint<7>, %arg1: i8) -> !FHE.eint<7> { %1 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } """, - (np.uint64(3), np.uint64(5)), - 8, - id="add_eint_int_with_np_uint64_as_scalar", - ), - pytest.param( - """ + (np.uint64(3), np.uint64(5)), + 8, + id="add_eint_int_with_np_uint64_as_scalar", + ), + pytest.param( + """ + func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { + %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """, + (73,), + 73, + id="apply_lookup_table", + ), + pytest.param( + """ func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> { %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : @@ -85,15 +121,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') return %ret : !FHE.eint<7> } """, - ( - np.array([1, 2, 3, 4], dtype=np.uint8), - np.array([4, 3, 2, 1], dtype=np.uint8), - ), - 20, - id="dot_eint_int_uint8", + ( + np.array([1, 2, 3, 4], dtype=np.uint8), + np.array([4, 3, 2, 1], dtype=np.uint8), ), - pytest.param( - """ + 20, + id="dot_eint_int_uint8", + ), + pytest.param( + """ func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> { %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : @@ -101,15 +137,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') return %ret : !FHE.eint<7> } """, - ( - np.array([1, 2, 3, 4], dtype=np.uint16), - np.array([4, 3, 2, 1], dtype=np.uint16), - ), - 20, - id="dot_eint_int_uint16", + ( + np.array([1, 2, 3, 4], dtype=np.uint16), + np.array([4, 3, 2, 1], dtype=np.uint16), ), - pytest.param( - """ + 20, + id="dot_eint_int_uint16", + ), + pytest.param( + """ func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> { %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : @@ -117,15 +153,15 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') return %ret : !FHE.eint<7> } """, - ( - np.array([1, 2, 3, 4], dtype=np.uint32), - np.array([4, 3, 2, 1], dtype=np.uint32), - ), - 20, - id="dot_eint_int_uint32", + ( + np.array([1, 2, 3, 4], dtype=np.uint32), + np.array([4, 3, 2, 1], dtype=np.uint32), ), - pytest.param( - """ + 20, + id="dot_eint_int_uint32", + ), + pytest.param( + """ func @main(%arg0: tensor<4x!FHE.eint<7>>, %arg1: tensor<4xi8>) -> !FHE.eint<7> { %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : @@ -133,91 +169,128 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache') return %ret : !FHE.eint<7> } """, - ( - np.array([1, 2, 3, 4], dtype=np.uint64), - np.array([4, 3, 2, 1], dtype=np.uint64), - ), - 20, - id="dot_eint_int_uint64", + ( + np.array([1, 2, 3, 4], dtype=np.uint64), + np.array([4, 3, 2, 1], dtype=np.uint64), ), - pytest.param( - """ + 20, + id="dot_eint_int_uint64", + ), + pytest.param( + """ func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!FHE.eint<6>> { %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4xi7>) -> tensor<4x!FHE.eint<6>> return %res : tensor<4x!FHE.eint<6>> } """, - ( - np.array([31, 6, 12, 9], dtype=np.uint8), - np.array([32, 9, 2, 3], dtype=np.uint8), - ), - np.array([63, 15, 14, 12]), - id="add_eint_int_1D", + ( + np.array([31, 6, 12, 9], dtype=np.uint8), + np.array([32, 9, 2, 3], dtype=np.uint8), ), - pytest.param( - """ + np.array([63, 15, 14, 12]), + id="add_eint_int_1D", + ), + pytest.param( + """ func @main(%a0: tensor<4x4x!FHE.eint<6>>, %a1: tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> { %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x4x!FHE.eint<6>>, tensor<4x4xi7>) -> tensor<4x4x!FHE.eint<6>> return %res : tensor<4x4x!FHE.eint<6>> } """, - ( - np.array( - [[31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9], [31, 6, 12, 9]], - dtype=np.uint8, - ), - np.array( - [[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]], - dtype=np.uint8, - ), - ), + ( np.array( - [ - [63, 15, 14, 12], - [63, 15, 14, 12], - [63, 15, 14, 12], - [63, 15, 14, 12], - ], + [[31, 6, 12, 9], [31, 6, 12, 9], [ + 31, 6, 12, 9], [31, 6, 12, 9]], + dtype=np.uint8, + ), + np.array( + [[32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3], [32, 9, 2, 3]], dtype=np.uint8, ), - id="add_eint_int_2D", ), - pytest.param( - """ + np.array( + [ + [63, 15, 14, 12], + [63, 15, 14, 12], + [63, 15, 14, 12], + [63, 15, 14, 12], + ], + dtype=np.uint8, + ), + id="add_eint_int_2D", + ), + pytest.param( + """ func @main(%a0: tensor<2x2x2x!FHE.eint<6>>, %a1: tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> { %res = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<2x2x2x!FHE.eint<6>>, tensor<2x2x2xi7>) -> tensor<2x2x2x!FHE.eint<6>> return %res : tensor<2x2x2x!FHE.eint<6>> } """, - ( - np.array( - [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], - dtype=np.uint8, - ), - np.array( - [[[9, 10], [11, 12]], [[13, 14], [15, 16]]], - dtype=np.uint8, - ), - ), + ( np.array( - [[[10, 12], [14, 16]], [[18, 20], [22, 24]]], + [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], + dtype=np.uint8, + ), + np.array( + [[[9, 10], [11, 12]], [[13, 14], [15, 16]]], dtype=np.uint8, ), - id="add_eint_int_3D", ), - ], -) -def test_compile_and_run(mlir_input, args, expected_result): - engine = CompilerEngine() - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) - if isinstance(expected_result, int): - assert engine.run(*args) == expected_result - else: - # numpy array - assert np.all(engine.run(*args) == expected_result) + np.array( + [[[10, 12], [14, 16]], [[18, 20], [22, 24]]], + dtype=np.uint8, + ), + id="add_eint_int_3D", + ), +] @pytest.mark.parametrize( + "mlir_input, args, expected_result", + end_to_end_fixture +) +def test_jit_compile_and_run(mlir_input, args, expected_result): + engine = JITCompilerSupport() + compile_and_run(engine, mlir_input, args, expected_result) + + +@pytest.mark.parametrize( + "mlir_input, args, expected_result", + end_to_end_fixture +) +def test_lib_compile_and_run(mlir_input, args, expected_result): + engine = LibraryCompilerSupport("py_test_lib_compile_and_run") + compile_and_run(engine, mlir_input, args, expected_result) + + +@pytest.mark.parametrize( + "mlir_input, args, expected_result", + end_to_end_fixture +) +def test_lib_compile_reload_and_run(mlir_input, args, expected_result): + engine = LibraryCompilerSupport("test_lib_compile_reload_and_run") + # Here don't save compilation result, reload + engine.compile(mlir_input) + compilation_result = engine.reload() + # Client + client_parameters = engine.load_client_parameters(compilation_result) + key_set = ClientSupport.key_set(client_parameters, keySetCacheTest) + public_arguments = ClientSupport.encrypt_arguments( + client_parameters, key_set, args) + # Server + server_lambda = engine.load_server_lambda(compilation_result) + public_result = engine.server_call(server_lambda, public_arguments) + # Client + result = ClientSupport.decrypt_result(key_set, public_result) + # Check result + assert type(expected_result) == type(result) + if isinstance(expected_result, int): + assert result == expected_result + else: + assert np.all(result == expected_result) + + +@ pytest.mark.parametrize( "mlir_input, args", [ pytest.param( @@ -234,8 +307,9 @@ def test_compile_and_run(mlir_input, args, expected_result): ) def test_compile_and_run_invalid_arg_number(mlir_input, args): engine = CompilerEngine() - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) - with pytest.raises(ValueError, match=r"wrong number of arguments"): + engine.compile_fhe( + mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) + with pytest.raises(RuntimeError, match=r"function has arity 2 but is applied to too many arguments"): engine.run(*args) @@ -259,7 +333,8 @@ def test_compile_and_run_invalid_arg_number(mlir_input, args): ) def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): engine = CompilerEngine() - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) + engine.compile_fhe( + mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) assert abs(engine.run(*args) - expected_result) / tab_size < 0.1 @@ -281,8 +356,9 @@ def test_compile_and_run_tlu(mlir_input, args, expected_result, tab_size): ) def test_compile_invalid(mlir_input): engine = CompilerEngine() - with pytest.raises(RuntimeError, match=r"Compilation failed:"): - engine.compile_fhe(mlir_input, unsecure_key_set_cache_path = KEY_SET_CACHE_PATH) + with pytest.raises(RuntimeError, match=r"cannot find the function for generate client parameters"): + engine.compile_fhe( + mlir_input, unsecure_key_set_cache_path=KEY_SET_CACHE_PATH) MODULE_1 = """ diff --git a/compiler/tests/unittest/end_to_end_jit_fhe.cc b/compiler/tests/unittest/end_to_end_jit_fhe.cc index 40afd5c37..0e88a6793 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhe.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhe.cc @@ -14,7 +14,7 @@ \ auto desc = GetParam(); \ \ - LambdaSupport support; \ + auto support = LambdaSupport; \ \ /* 1 - Compile the program */ \ auto compilationResult = support.compile(desc.program); \ @@ -84,8 +84,10 @@ /// Instantiate the test suite for Jit INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES( - JitTest, mlir::concretelang::JitLambdaSupport) + JitTest, mlir::concretelang::JitLambdaSupport()) /// Instantiate the test suite for Jit INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES( - LibraryTest, mlir::concretelang::LibraryLambdaSupport) \ No newline at end of file + LibraryTest, + mlir::concretelang::LibraryLambdaSupport("/tmp/end_to_end_test_" + + desc.description)) \ No newline at end of file