// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #ifndef CONCRETELANG_SUPPORT_LIBRARY_SUPPORT #define CONCRETELANG_SUPPORT_LIBRARY_SUPPORT #include #include #include #include #include #include #include #include namespace mlir { namespace concretelang { namespace clientlib = ::concretelang::clientlib; namespace serverlib = ::concretelang::serverlib; /// LibraryCompilationResult is the result of a compilation to a library. struct LibraryCompilationResult { /// The output directory path where the compilation artifacts have been /// generated. std::string outputDirPath; std::string funcName; }; class LibrarySupport : public LambdaSupport { public: LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "", bool generateSharedLib = true, bool generateStaticLib = true, bool generateClientParameters = true, bool generateCompilationFeedback = true, bool generateCppHeader = true) : outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath), generateSharedLib(generateSharedLib), generateStaticLib(generateStaticLib), generateClientParameters(generateClientParameters), generateCompilationFeedback(generateCompilationFeedback), generateCppHeader(generateCppHeader) {} llvm::Expected> compile(llvm::SourceMgr &program, CompilationOptions options) override { // Setup the compiler engine auto context = CompilationContext::createShared(); concretelang::CompilerEngine engine(context); engine.setCompilationOptions(options); // Compile to a library auto library = engine.compile( program, outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib, generateClientParameters, generateCompilationFeedback, generateCppHeader); if (auto err = library.takeError()) { return std::move(err); } if (!options.clientParametersFuncName.hasValue()) { return StreamStringError("Need to have a funcname to compile library"); } auto result = std::make_unique(); result->outputDirPath = outputPath; result->funcName = *options.clientParametersFuncName; return std::move(result); } using LambdaSupport::compile; /// Load the server lambda from the compilation result. llvm::Expected loadServerLambda(LibraryCompilationResult &result) override { auto lambda = serverlib::ServerLambda::load(result.funcName, result.outputDirPath); if (lambda.has_error()) { return StreamStringError(lambda.error().mesg); } return lambda.value(); } /// Load the client parameters from the compilation result. llvm::Expected loadClientParameters(LibraryCompilationResult &result) override { auto path = CompilerEngine::Library::getClientParametersPath(result.outputDirPath); auto params = ClientParameters::load(path); if (params.has_error()) { return StreamStringError(params.error().mesg); } auto param = llvm::find_if(params.value(), [&](ClientParameters param) { return param.functionName == result.funcName; }); if (param == params.value().end()) { return StreamStringError("ClientLambda: cannot find function(") << result.funcName << ") in client parameters path(" << path << ")"; } return *param; } llvm::Expected loadCompilationFeedback(LibraryCompilationResult &result) override { auto path = CompilerEngine::Library::getCompilationFeedbackPath( result.outputDirPath); auto feedback = CompilationFeedback::load(path); if (feedback.has_error()) { return StreamStringError(feedback.error().mesg); } return feedback.value(); } /// Call the lambda with the public arguments. llvm::Expected> serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args, clientlib::EvaluationKeys &evaluationKeys) override { return lambda.call(args, evaluationKeys); } /// Get path to shared library std::string getSharedLibPath() { return CompilerEngine::Library::getSharedLibraryPath(outputPath); } /// Get path to client parameters file std::string getClientParametersPath() { return CompilerEngine::Library::getClientParametersPath(outputPath); } private: std::string outputPath; std::string runtimeLibraryPath; /// Flags to select generated artifacts bool generateSharedLib; bool generateStaticLib; bool generateClientParameters; bool generateCompilationFeedback; bool generateCppHeader; }; } // namespace concretelang } // namespace mlir #endif