From 5b83b700d2666c2cefc1d312087c9049679e05dc Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Thu, 17 Mar 2022 14:34:12 +0100 Subject: [PATCH] enhance(compiler): Expose a compilation options instead of just the funcname --- .../concretelang-c/Support/CompilerEngine.h | 4 +- .../concretelang/Support/CompilerEngine.h | 45 ++++++++ .../concretelang/Support/JitLambdaSupport.h | 2 +- .../concretelang/Support/LambdaSupport.h | 109 +++++++++++++++++- .../Support/LibraryLambdaSupport.h | 10 +- .../lib/Bindings/Python/CompilerAPIModule.cpp | 43 ++++--- .../lib/Bindings/Python/concrete/compiler.py | 12 +- compiler/lib/CAPI/Support/CompilerEngine.cpp | 8 +- compiler/lib/Support/JitLambdaSupport.cpp | 14 ++- compiler/tests/unittest/EndToEndFixture.cpp | 1 - compiler/tests/unittest/end_to_end_jit_test.h | 2 +- 11 files changed, 203 insertions(+), 47 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 50a86f889..4f4fb2a57 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -49,7 +49,7 @@ jit_lambda_support(const char *runtimeLibPath); MLIR_CAPI_EXPORTED std::unique_ptr jit_compile(JITLambdaSupport_C support, const char *module, - const char *funcname); + mlir::concretelang::CompilationOptions options); MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters jit_load_client_parameters(JITLambdaSupport_C support, @@ -76,7 +76,7 @@ library_lambda_support(const char *outputPath); MLIR_CAPI_EXPORTED std::unique_ptr library_compile(LibraryLambdaSupport_C support, const char *module, - const char *funcname); + mlir::concretelang::CompilationOptions options); MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters library_load_client_parameters(LibraryLambdaSupport_C support, diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 6eb806305..15a388f0c 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -36,6 +36,30 @@ protected: llvm::LLVMContext *llvmContext; }; +/// Compilation options allows to configure the compilation pipeline. +struct CompilationOptions { + llvm::Optional v0FHEConstraints; + + bool verifyDiagnostics; + + bool autoParallelize; + bool loopParallelize; + bool dataflowParallelize; + llvm::Optional> fhelinalgTileSizes; + + llvm::Optional clientParametersFuncName; + + CompilationOptions() + : v0FHEConstraints(llvm::None), verifyDiagnostics(false), + autoParallelize(false), loopParallelize(false), + dataflowParallelize(false), clientParametersFuncName(llvm::None){}; + + CompilationOptions(std::string funcname) + : v0FHEConstraints(llvm::None), verifyDiagnostics(false), + autoParallelize(false), loopParallelize(false), + dataflowParallelize(false), clientParametersFuncName(funcname){}; +}; + class CompilerEngine { public: // Result of an invocation of the `CompilerEngine` with optional @@ -176,6 +200,27 @@ public: llvm::Expected compile(llvm::SourceMgr &sm, std::string libraryPath); + void setCompilationOptions(CompilationOptions &options) { + if (options.v0FHEConstraints.hasValue()) { + setFHEConstraints(*options.v0FHEConstraints); + } + + setVerifyDiagnostics(options.verifyDiagnostics); + + setAutoParallelize(options.autoParallelize); + setLoopParallelize(options.loopParallelize); + setDataflowParallelize(options.dataflowParallelize); + + if (options.clientParametersFuncName.hasValue()) { + setGenerateClientParameters(true); + setClientParametersFuncName(*options.clientParametersFuncName); + } + + if (options.fhelinalgTileSizes.hasValue()) { + setFHELinalgTileSizes(*options.fhelinalgTileSizes); + } + } + void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c); void setMaxEintPrecision(size_t v); void setMaxMANP(size_t v); diff --git a/compiler/include/concretelang/Support/JitLambdaSupport.h b/compiler/include/concretelang/Support/JitLambdaSupport.h index 1a4e87b45..7e9be73ae 100644 --- a/compiler/include/concretelang/Support/JitLambdaSupport.h +++ b/compiler/include/concretelang/Support/JitLambdaSupport.h @@ -38,7 +38,7 @@ public: mlir::makeOptimizingTransformer(3, 0, nullptr)); llvm::Expected> - compile(llvm::SourceMgr &program, std::string funcname = "main") override; + compile(llvm::SourceMgr &program, CompilationOptions options) override; using LambdaSupport::compile; llvm::Expected diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index 40361034a..76961605e 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -241,25 +241,29 @@ public: }; template class LambdaSupport { - public: + typedef Lambda lambda; + typedef CompilationResult compilationResult; + virtual ~LambdaSupport() {} /// Compile the mlir program and produces a compilation result if succeed. llvm::Expected> virtual compile( - llvm::SourceMgr &program, std::string funcname = "main"); + llvm::SourceMgr &program, + CompilationOptions options = CompilationOptions("main")); llvm::Expected> - compile(llvm::StringRef program, std::string funcname = "main") { - return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname); + compile(llvm::StringRef program, + CompilationOptions options = CompilationOptions("main")) { + return compile(llvm::MemoryBuffer::getMemBuffer(program), options); } llvm::Expected> compile(std::unique_ptr program, - std::string funcname = "main") { + CompilationOptions options = CompilationOptions("main")) { llvm::SourceMgr sm; sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc()); - return compile(sm, funcname); + return compile(sm, options); } /// Load the server lambda from the compilation result. @@ -312,6 +316,99 @@ public: } }; +template class ClientServer { +public: + static llvm::Expected + create(llvm::StringRef program, + CompilationOptions options = CompilationOptions("main"), + llvm::Optional cache = {}, + LambdaSupport support = LambdaSupport()) { + auto compilationResult = support.compile(program, options); + if (auto err = compilationResult.takeError()) { + return std::move(err); + } + auto lambda = support.loadServerLambda(**compilationResult); + if (auto err = lambda.takeError()) { + return std::move(err); + } + auto clientParameters = support.loadClientParameters(**compilationResult); + if (auto err = clientParameters.takeError()) { + return std::move(err); + } + auto keySet = support.keySet(*clientParameters, cache); + if (auto err = keySet.takeError()) { + return std::move(err); + } + auto f = ClientServer(); + f.lambda = *lambda; + f.compilationResult = std::move(*compilationResult); + f.keySet = std::move(*keySet); + f.clientParameters = *clientParameters; + f.support = support; + return std::move(f); + } + + template + llvm::Expected operator()(llvm::ArrayRef args) { + auto publicArguments = LambdaArgumentAdaptor::exportArguments( + args, clientParameters, *this->keySet); + + if (auto err = publicArguments.takeError()) { + return std::move(err); + } + + auto publicResult = support.serverCall(lambda, **publicArguments); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + return typedResult(*keySet, **publicResult); + } + + template + llvm::Expected operator()(const llvm::ArrayRef args) { + auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args); + if (encryptedArgs.has_error()) { + return StreamStringError(encryptedArgs.error().mesg); + } + auto publicArguments = encryptedArgs.value()->exportPublicArguments( + clientParameters, keySet->runtimeContext()); + if (!publicArguments.has_value()) { + return StreamStringError(publicArguments.error().mesg); + } + auto publicResult = support.serverCall(lambda, *publicArguments.value()); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + return typedResult(*keySet, **publicResult); + } + + template + llvm::Expected operator()(const Args... args) { + auto encryptedArgs = + clientlib::EncryptedArguments::create(*keySet, args...); + if (encryptedArgs.has_error()) { + return StreamStringError(encryptedArgs.error().mesg); + } + auto publicArguments = encryptedArgs.value()->exportPublicArguments( + clientParameters, keySet->runtimeContext()); + if (publicArguments.has_error()) { + return StreamStringError(publicArguments.error().mesg); + } + auto publicResult = support.serverCall(lambda, *publicArguments.value()); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + return typedResult(*keySet, **publicResult); + } + +private: + typename LambdaSupport::lambda lambda; + std::unique_ptr compilationResult; + std::unique_ptr keySet; + clientlib::ClientParameters clientParameters; + LambdaSupport support; +}; + } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Support/LibraryLambdaSupport.h b/compiler/include/concretelang/Support/LibraryLambdaSupport.h index 5d036d34a..a5bac5b67 100644 --- a/compiler/include/concretelang/Support/LibraryLambdaSupport.h +++ b/compiler/include/concretelang/Support/LibraryLambdaSupport.h @@ -36,11 +36,11 @@ public: LibraryLambdaSupport(std::string outputPath) : outputPath(outputPath) {} llvm::Expected> - compile(llvm::SourceMgr &program, std::string funcname = "main") override { + compile(llvm::SourceMgr &program, CompilationOptions options) override { // Setup the compiler engine auto context = CompilationContext::createShared(); concretelang::CompilerEngine engine(context); - engine.setClientParametersFuncName(funcname); + engine.setCompilationOptions(options); // Compile to a library auto library = engine.compile(program, outputPath); @@ -48,9 +48,13 @@ public: return std::move(err); } + if (!options.clientParametersFuncName.hasValue()) { + return StreamStringError("Need to have a funcname to compile library"); + } + auto result = std::make_unique(); result->libraryPath = outputPath; - result->funcName = funcname; + result->funcName = *options.clientParametersFuncName; return std::move(result); } using LambdaSupport::compile; diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index b6871ec04..f40028e7b 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -20,6 +20,7 @@ #include #include +using mlir::concretelang::CompilationOptions; using mlir::concretelang::JitCompilerEngine; using mlir::concretelang::JitLambdaSupport; using mlir::concretelang::LambdaArgument; @@ -43,19 +44,25 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( m.def("terminate_parallelization", &terminateParallelization); - pybind11::class_(m, "JitCompilerEngine") - .def(pybind11::init()) - .def_static("build_lambda", - [](std::string mlir_input, std::string func_name, - std::string runtime_lib_path, std::string keysetcache_path, - bool auto_parallelize, bool loop_parallelize, - bool df_parallelize) { - return buildLambda(mlir_input.c_str(), func_name.c_str(), - noEmptyStringPtr(runtime_lib_path), - noEmptyStringPtr(keysetcache_path), - auto_parallelize, loop_parallelize, - df_parallelize); - }); + pybind11::class_(m, "CompilationOptions") + .def(pybind11::init( + [](std::string funcname) { return CompilationOptions(funcname); })) + .def("set_funcname", + [](CompilationOptions &options, std::string funcname) { + options.clientParametersFuncName = funcname; + }) + .def("set_verify_diagnostics", + [](CompilationOptions &options, bool b) { + options.verifyDiagnostics = b; + }) + .def("auto_parallelize", [](CompilationOptions &options, + bool b) { options.autoParallelize = b; }) + .def("loop_parallelize", [](CompilationOptions &options, + bool b) { options.loopParallelize = b; }) + .def("dataflow_parallelize", [](CompilationOptions &options, bool b) { + options.dataflowParallelize = b; + }); + pybind11::class_( m, "JitCompilationResult"); pybind11::class_(m, "JITLambda"); @@ -65,9 +72,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( })) .def("compile", [](JITLambdaSupport_C &support, std::string mlir_program, - std::string func_name) { - return jit_compile(support, mlir_program.c_str(), - func_name.c_str()); + CompilationOptions options) { + return jit_compile(support, mlir_program.c_str(), options); }) .def("load_client_parameters", [](JITLambdaSupport_C &support, @@ -102,9 +108,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( })) .def("compile", [](LibraryLambdaSupport_C &support, std::string mlir_program, - std::string func_name) { - return library_compile(support, mlir_program.c_str(), - func_name.c_str()); + mlir::concretelang::CompilationOptions options) { + return library_compile(support, mlir_program.c_str(), options); }) .def("load_client_parameters", [](LibraryLambdaSupport_C &support, diff --git a/compiler/lib/Bindings/Python/concrete/compiler.py b/compiler/lib/Bindings/Python/concrete/compiler.py index eb059ffae..5ea5fbfc0 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler.py +++ b/compiler/lib/Bindings/Python/concrete/compiler.py @@ -22,6 +22,8 @@ from mlir._mlir_libs._concretelang._compiler import PublicResult from mlir._mlir_libs._concretelang._compiler import PublicArguments from mlir._mlir_libs._concretelang._compiler import LambdaArgument as _LambdaArgument +from mlir._mlir_libs._concretelang._compiler import CompilationOptions + from mlir._mlir_libs._concretelang._compiler import JITLambdaSupport as _JITLambdaSupport from mlir._mlir_libs._concretelang._compiler import JitCompilationResult from mlir._mlir_libs._concretelang._compiler import JITLambda @@ -259,7 +261,7 @@ class JITCompilerSupport: runtime_lib_path = _lookup_runtime_lib() self._support = _JITLambdaSupport(runtime_lib_path) - def compile(self, mlir_program: str, func_name: str = "main") -> JitCompilationResult: + def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> JitCompilationResult: """JIT Compile a function define in the mlir_program to its homomorphic equivalent. Args: @@ -271,7 +273,7 @@ class JITCompilerSupport: """ if not isinstance(mlir_program, str): raise TypeError("mlir_program must be an `str`") - return self._support.compile(mlir_program, func_name) + return self._support.compile(mlir_program, options) def load_client_parameters(self, compilation_result: JitCompilationResult) -> ClientParameters: """Load the client parameters from the JIT compilation result""" @@ -299,7 +301,7 @@ class LibraryCompilerSupport: self._library_path = outputPath self._support = _LibraryLambdaSupport(outputPath) - def compile(self, mlir_program: str, func_name: str = "main") -> LibraryCompilationResult: + def compile(self, mlir_program: str, options: CompilationOptions = CompilationOptions("main")) -> LibraryCompilationResult: """Compile a function define in the mlir_program to its homomorphic equivalent and save as library. Args: @@ -311,9 +313,9 @@ class LibraryCompilerSupport: """ if not isinstance(mlir_program, str): raise TypeError("mlir_program must be an `str`") - if not isinstance(func_name, str): + if not isinstance(options, CompilationOptions): raise TypeError("mlir_program must be an `str`") - return self._support.compile(mlir_program, func_name) + return self._support.compile(mlir_program, options) def reload(self, func_name: str = "main") -> LibraryCompilationResult: """Reload the library compilation result from the outputPath. diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 8145d957b..8b5d9e318 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -32,10 +32,10 @@ jit_lambda_support(const char *runtimeLibPath) { std::unique_ptr jit_compile(JITLambdaSupport_C support, const char *module, - const char *funcname) { + mlir::concretelang::CompilationOptions options) { mlir::concretelang::JitLambdaSupport esupport; GET_OR_THROW_LLVM_EXPECTED(compilationResult, - esupport.compile(module, funcname)); + esupport.compile(module, options)); return std::move(*compilationResult); } @@ -73,9 +73,9 @@ library_lambda_support(const char *outputPath) { std::unique_ptr library_compile(LibraryLambdaSupport_C support, const char *module, - const char *funcname) { + mlir::concretelang::CompilationOptions options) { GET_OR_THROW_LLVM_EXPECTED(compilationResult, - support.support.compile(module, funcname)); + support.support.compile(module, options)); return std::move(*compilationResult); } diff --git a/compiler/lib/Support/JitLambdaSupport.cpp b/compiler/lib/Support/JitLambdaSupport.cpp index 5d6b85842..701d12c5d 100644 --- a/compiler/lib/Support/JitLambdaSupport.cpp +++ b/compiler/lib/Support/JitLambdaSupport.cpp @@ -16,15 +16,14 @@ JitLambdaSupport::JitLambdaSupport( : runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {} llvm::Expected> -JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) { +JitLambdaSupport::compile(llvm::SourceMgr &program, + CompilationOptions options) { // Setup the compiler engine auto context = std::make_shared(); concretelang::CompilerEngine engine(context); - // We need client parameters to be generated - engine.setGenerateClientParameters(true); - engine.setClientParametersFuncName(funcname); + engine.setCompilationOptions(options); // Compile to LLVM Dialect auto compilationResult = @@ -34,10 +33,15 @@ JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) { return std::move(err); } + if (!options.clientParametersFuncName.hasValue()) { + return StreamStringError("Need to have a funcname to JIT compile"); + } + // Compile from LLVM Dialect to JITLambda auto mlirModule = compilationResult.get().mlirModuleRef->get(); auto lambda = concretelang::JITLambda::create( - funcname, mlirModule, llvmOptPipeline, runtimeLibPath); + *options.clientParametersFuncName, mlirModule, llvmOptPipeline, + runtimeLibPath); if (auto err = lambda.takeError()) { return std::move(err); } diff --git a/compiler/tests/unittest/EndToEndFixture.cpp b/compiler/tests/unittest/EndToEndFixture.cpp index 48137131c..4ea02ffa2 100644 --- a/compiler/tests/unittest/EndToEndFixture.cpp +++ b/compiler/tests/unittest/EndToEndFixture.cpp @@ -1,7 +1,6 @@ #include "EndToEndFixture.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Jit.h" -#include "concretelang/Support/JitCompilerEngine.h" #include "llvm/Support/YAMLParser.h" #include "llvm/Support/YAMLTraits.h" diff --git a/compiler/tests/unittest/end_to_end_jit_test.h b/compiler/tests/unittest/end_to_end_jit_test.h index f15824d7a..783e976be 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.h +++ b/compiler/tests/unittest/end_to_end_jit_test.h @@ -5,7 +5,7 @@ #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Support/CompilerEngine.h" -#include "concretelang/Support/JitCompilerEngine.h" +#include "concretelang/Support/JitLambdaSupport.h" #include "llvm/Support/Path.h" #include "globals.h"