From f4673e82765a0b7cc6dcf7a50914147ae6624ffb Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 9 Sep 2022 23:11:49 +0200 Subject: [PATCH] feat(compiler): First draft or compilation feedback --- .../concretelang/ClientLib/ClientParameters.h | 36 +++++ .../ClientLib/EncryptedArguments.h | 12 -- .../Support/CompilationFeedback.h | 61 ++++++++ .../concretelang/Support/CompilerEngine.h | 15 +- .../include/concretelang/Support/JITSupport.h | 6 + .../concretelang/Support/LambdaSupport.h | 4 + .../concretelang/Support/LibrarySupport.h | 11 ++ .../concretelang/Support/V0Parameters.h | 2 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 13 ++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilationFeedback.cpp | 140 ++++++++++++++++++ compiler/lib/Support/CompilerEngine.cpp | 70 +++++++-- compiler/lib/Support/JITSupport.cpp | 6 +- compiler/lib/Support/V0Parameters.cpp | 3 + compiler/src/main.cpp | 3 +- .../end_to_end_tests/end_to_end_jit_fhe.cc | 5 + compiler/tests/python/test_compilation.py | 19 ++- 17 files changed, 370 insertions(+), 37 deletions(-) create mode 100644 compiler/include/concretelang/Support/CompilationFeedback.h create mode 100644 compiler/lib/Support/CompilationFeedback.cpp diff --git a/compiler/include/concretelang/ClientLib/ClientParameters.h b/compiler/include/concretelang/ClientLib/ClientParameters.h index 58112eb54..cad26c1ed 100644 --- a/compiler/include/concretelang/ClientLib/ClientParameters.h +++ b/compiler/include/concretelang/ClientLib/ClientParameters.h @@ -18,6 +18,19 @@ #include namespace concretelang { + +inline size_t bitWidthAsWord(size_t exactBitWidth) { + if (exactBitWidth <= 8) + return 8; + if (exactBitWidth <= 16) + return 16; + if (exactBitWidth <= 32) + return 32; + if (exactBitWidth <= 64) + return 64; + assert(false && "Bit witdh > 64 not supported"); +} + namespace clientlib { using concretelang::error::StringError; @@ -44,6 +57,7 @@ struct LweSecretKeyParam { void hash(size_t &seed); inline uint64_t lweDimension() { return dimension; } inline uint64_t lweSize() { return dimension + 1; } + inline uint64_t byteSize() { return lweSize() * 8; } }; static bool operator==(const LweSecretKeyParam &lhs, const LweSecretKeyParam &rhs) { @@ -60,6 +74,11 @@ struct BootstrapKeyParam { Variance variance; void hash(size_t &seed); + + uint64_t byteSize(uint64_t inputLweSize, uint64_t outputLweSize) { + return inputLweSize * level * (glweDimension + 1) * (glweDimension + 1) * + outputLweSize * 8; + } }; static inline bool operator==(const BootstrapKeyParam &lhs, const BootstrapKeyParam &rhs) { @@ -78,6 +97,10 @@ struct KeyswitchKeyParam { Variance variance; void hash(size_t &seed); + + size_t byteSize(size_t inputLweSize, size_t outputLweSize) { + return level * inputLweSize * outputLweSize * 8; + } }; static inline bool operator==(const KeyswitchKeyParam &lhs, const KeyswitchKeyParam &rhs) { @@ -125,6 +148,19 @@ struct CircuitGate { CircuitGateShape shape; bool isEncrypted() { return encryption.hasValue(); } + + /// byteSize returns the size in bytes for this gate. + size_t byteSize(std::map secretKeys) { + auto width = shape.width; + auto numElts = shape.size == 0 ? 1 : shape.size; + if (isEncrypted()) { + auto skParam = secretKeys.find(encryption->secretKeyID); + assert(skParam != secretKeys.end()); + return 8 * skParam->second.lweSize() * numElts; + } + width = bitWidthAsWord(width) / 8; + return width * numElts; + } }; static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) { return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape; diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h index 513f70efb..264c4ec78 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -23,18 +23,6 @@ using concretelang::error::StringError; class PublicArguments; -inline size_t bitWidthAsWord(size_t exactBitWidth) { - if (exactBitWidth <= 8) - return 8; - if (exactBitWidth <= 16) - return 16; - if (exactBitWidth <= 32) - return 32; - if (exactBitWidth <= 64) - return 64; - assert(false && "Bit witdh > 64 not supported"); -} - /// Temporary object used to hold and encrypt parameters before calling a /// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...). /// Otherwise convert it to a PublicArguments and use diff --git a/compiler/include/concretelang/Support/CompilationFeedback.h b/compiler/include/concretelang/Support/CompilationFeedback.h new file mode 100644 index 000000000..35e99205c --- /dev/null +++ b/compiler/include/concretelang/Support/CompilationFeedback.h @@ -0,0 +1,61 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_SUPPORT_COMPILATIONFEEDBACK_H_ +#define CONCRETELANG_SUPPORT_COMPILATIONFEEDBACK_H_ + +#include +#include + +#include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/Common/Error.h" +#include "concretelang/Conversion/Utils/GlobalFHEContext.h" +#include "llvm/Support/Error.h" + +namespace mlir { +namespace concretelang { + +using StringError = ::concretelang::error::StringError; + +struct CompilationFeedback { + double complexity; + + /// @brief the total number of bytes of secret keys + size_t totalSecretKeysSize; + + /// @brief the total number of bytes of bootstrap keys + size_t totalBootstrapKeysSize; + + /// @brief the total number of bytes of keyswitch keys + size_t totalKeyswitchKeysSize; + + /// @brief the total number of bytes of inputs + size_t totalInputsSize; + + /// @brief the total number of bytes of outputs + size_t totalOutputsSize; + + /// Fill the sizes from the client parameters. + void + fillFromClientParameters(::concretelang::clientlib::ClientParameters params); + + /// Load the compilation feedback from a path + static outcome::checked + load(std::string path); +}; + +llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &); +bool fromJSON(const llvm::json::Value, + mlir::concretelang::CompilationFeedback &, llvm::json::Path); + +} // namespace concretelang +} // namespace mlir + +static inline llvm::raw_ostream & +operator<<(llvm::raw_string_ostream &OS, + mlir::concretelang::CompilationFeedback cp) { + return OS << llvm::formatv("{0:2}", toJSON(cp)); +} +#endif diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 2bd608dc2..c49da57c5 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -83,6 +83,7 @@ public: llvm::Optional> mlirModuleRef; llvm::Optional clientParameters; + llvm::Optional feedback; std::unique_ptr llvmModule; llvm::Optional fheContext; @@ -94,6 +95,8 @@ public: std::string outputDirPath; std::vector objectsPath; std::vector clientParametersList; + std::vector + compilationFeedbackList; /// Path to the runtime library. Will be linked to the output library if set std::string runtimeLibraryPath; bool cleanUp; @@ -110,7 +113,8 @@ public: llvm::Expected addCompilation(CompilationResult &compilation); /// Emit the library artifacts with the previously added compilation result llvm::Error emitArtifacts(bool sharedLib, bool staticLib, - bool clientParameters, bool cppHeader); + bool clientParameters, bool compilationFeedback, + bool cppHeader); /// After a shared library has been emitted, its path is here std::string sharedLibraryPath; /// After a static library has been emitted, its path is here @@ -122,9 +126,12 @@ public: /// Returns the path of the static library static std::string getStaticLibraryPath(std::string outputDirPath); - /// Returns the path of the static library + /// Returns the path of the client parameters static std::string getClientParametersPath(std::string outputDirPath); + /// Returns the path of the compilation feedback + static std::string getCompilationFeedbackPath(std::string outputDirPath); + // For advanced use const static std::string OBJECT_EXT, LINKER, LINKER_SHARED_OPT, AR, AR_STATIC_OPT, DOT_STATIC_LIB_EXT, DOT_SHARED_LIB_EXT; @@ -141,6 +148,8 @@ public: llvm::Expected emitShared(); /// Emit a json ClientParameters corresponding to library content llvm::Expected emitClientParametersJSON(); + /// Emit a json CompilationFeedback corresponding to library content + llvm::Expected emitCompilationFeedbackJSON(); /// Emit a client header file for this corresponding to library content llvm::Expected emitCppHeader(); }; @@ -211,6 +220,7 @@ public: compile(std::vector inputs, std::string outputDirPath, std::string runtimeLibraryPath = "", bool generateSharedLib = true, bool generateStaticLib = true, bool generateClientParameters = true, + bool generateCompilationFeedback = true, bool generateCppHeader = true); /// Compile and emit artifact to the given outputDirPath from an LLVM source @@ -219,6 +229,7 @@ public: compile(llvm::SourceMgr &sm, std::string outputDirPath, std::string runtimeLibraryPath = "", bool generateSharedLib = true, bool generateStaticLib = true, bool generateClientParameters = true, + bool generateCompilationFeedback = true, bool generateCppHeader = true); void setCompilationOptions(CompilationOptions &options) { diff --git a/compiler/include/concretelang/Support/JITSupport.h b/compiler/include/concretelang/Support/JITSupport.h index bae99423d..bc890477a 100644 --- a/compiler/include/concretelang/Support/JITSupport.h +++ b/compiler/include/concretelang/Support/JITSupport.h @@ -25,6 +25,7 @@ namespace clientlib = ::concretelang::clientlib; struct JitCompilationResult { std::shared_ptr lambda; clientlib::ClientParameters clientParameters; + CompilationFeedback feedback; }; /// JITSupport is the instantiated LambdaSupport for the Jit Compilation. @@ -49,6 +50,11 @@ public: return result.clientParameters; } + llvm::Expected + loadCompilationFeedback(JitCompilationResult &result) override { + return result.feedback; + } + llvm::Expected> serverCall(std::shared_ptr lambda, clientlib::PublicArguments &args, diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h index 8fa73faf0..32aad6c72 100644 --- a/compiler/include/concretelang/Support/LambdaSupport.h +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -273,6 +273,10 @@ public: llvm::Expected virtual loadClientParameters( CompilationResult &result) = 0; + /// Load the compilation feedback from the compilation result. + llvm::Expected virtual loadCompilationFeedback( + CompilationResult &result) = 0; + /// Call the lambda with the public arguments. llvm::Expected> virtual serverCall( Lambda lambda, clientlib::PublicArguments &args, diff --git a/compiler/include/concretelang/Support/LibrarySupport.h b/compiler/include/concretelang/Support/LibrarySupport.h index c973b4c8d..fe6be0a3f 100644 --- a/compiler/include/concretelang/Support/LibrarySupport.h +++ b/compiler/include/concretelang/Support/LibrarySupport.h @@ -101,6 +101,17 @@ public: return *param; } + llvm::Expected + loadCompilationFeedback(LibraryCompilationResult &result) override { + auto path = CompilerEngine::Library::getCompilationFeedbackPath( + result.outputDirPath); + auto feedback = CompilationFeedback::load(path); + if (feedback.has_error()) { + return StreamStringError(feedback.error().mesg); + } + return feedback.value(); + } + /// Call the lambda with the public arguments. llvm::Expected> serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args, diff --git a/compiler/include/concretelang/Support/V0Parameters.h b/compiler/include/concretelang/Support/V0Parameters.h index 8eb5abccf..a6d03d08e 100644 --- a/compiler/include/concretelang/Support/V0Parameters.h +++ b/compiler/include/concretelang/Support/V0Parameters.h @@ -10,6 +10,7 @@ #include "concrete-optimizer.hpp" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" +#include "concretelang/Support/CompilationFeedback.h" namespace mlir { namespace concretelang { @@ -49,6 +50,7 @@ struct Description { } // namespace optimizer llvm::Expected getParameter(optimizer::Description &descr, + CompilationFeedback &feedback, optimizer::Config optimizerConfig); } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index 039ca53e8..d05db18d8 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -72,6 +72,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( options.optimizerConfig.global_p_error = global_p_error; }); + pybind11::class_( + m, "CompilationFeedback"); + pybind11::class_( m, "JITCompilationResult"); pybind11::class_ + +#include "boost/outcome.h" + +#include "concretelang/Support/CompilationFeedback.h" + +namespace mlir { +namespace concretelang { + +void CompilationFeedback::fillFromClientParameters( + ::concretelang::clientlib::ClientParameters params) { + // Compute the size of secret keys + totalSecretKeysSize = 0; + for (auto sk : params.secretKeys) { + totalSecretKeysSize += sk.second.byteSize(); + } + // Compute the boostrap keys size + totalBootstrapKeysSize = 0; + for (auto bsk : params.bootstrapKeys) { + auto bskParam = bsk.second; + auto inputKey = params.secretKeys.find(bskParam.inputSecretKeyID); + assert(inputKey != params.secretKeys.end()); + auto outputKey = params.secretKeys.find(bskParam.outputSecretKeyID); + assert(outputKey != params.secretKeys.end()); + + totalBootstrapKeysSize += bskParam.byteSize(inputKey->second.lweSize(), + outputKey->second.lweSize()); + } + // Compute the keyswitch keys size + totalKeyswitchKeysSize = 0; + for (auto ksk : params.keyswitchKeys) { + auto kskParam = ksk.second; + auto inputKey = params.secretKeys.find(kskParam.inputSecretKeyID); + assert(inputKey != params.secretKeys.end()); + auto outputKey = params.secretKeys.find(kskParam.outputSecretKeyID); + assert(outputKey != params.secretKeys.end()); + totalKeyswitchKeysSize += kskParam.byteSize(inputKey->second.lweSize(), + outputKey->second.lweSize()); + } + // Compute the size of inputs + totalInputsSize = 0; + for (auto gate : params.inputs) { + totalInputsSize += gate.byteSize(params.secretKeys); + } + // Compute the size of outputs + totalOutputsSize = 0; + for (auto gate : params.outputs) { + totalOutputsSize += gate.byteSize(params.secretKeys); + } +} + +outcome::checked +CompilationFeedback::load(std::string jsonPath) { + std::ifstream file(jsonPath); + std::string content((std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + if (file.fail()) { + return StringError("Cannot read file: ") << jsonPath; + } + auto expectedCompFeedback = llvm::json::parse(content); + if (auto err = expectedCompFeedback.takeError()) { + return StringError("Cannot open client parameters: ") + << llvm::toString(std::move(err)) << "\n" + << content << "\n"; + } + return expectedCompFeedback.get(); +} + +llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) { + llvm::json::Object object{ + {"complexity", v.complexity}, + {"totalSecretKeysSize", v.totalSecretKeysSize}, + {"totalBootstrapKeysSize", v.totalBootstrapKeysSize}, + {"totalKeyswitchKeysSize", v.totalKeyswitchKeysSize}, + {"totalInputsSize", v.totalInputsSize}, + {"totalOutputsSize", v.totalOutputsSize}, + }; + return object; +} + +bool fromJSON(const llvm::json::Value j, + mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) { + auto obj = j.getAsObject(); + if (obj == nullptr) { + p.report("should be an object"); + return false; + } + + auto complexity = obj->getInteger("complexity"); + if (!complexity.hasValue()) { + p.report("missing size field"); + return false; + } + v.complexity = *complexity; + + auto totalSecretKeysSize = obj->getInteger("totalSecretKeysSize"); + if (!totalSecretKeysSize.hasValue()) { + p.report("missing totalSecretKeysSize field"); + return false; + } + v.totalSecretKeysSize = *totalSecretKeysSize; + + auto totalBootstrapKeysSize = obj->getInteger("totalBootstrapKeysSize"); + if (!totalBootstrapKeysSize.hasValue()) { + p.report("missing totalBootstrapKeysSize field"); + return false; + } + v.totalBootstrapKeysSize = *totalBootstrapKeysSize; + + auto totalKeyswitchKeysSize = obj->getInteger("totalKeyswitchKeysSize"); + if (!totalKeyswitchKeysSize.hasValue()) { + p.report("missing totalKeyswitchKeysSize field"); + return false; + } + v.totalKeyswitchKeysSize = *totalKeyswitchKeysSize; + + auto totalInputsSize = obj->getInteger("totalInputsSize"); + if (!totalInputsSize.hasValue()) { + p.report("missing totalInputsSize field"); + return false; + } + v.totalInputsSize = *totalInputsSize; + + auto totalOutputsSize = obj->getInteger("totalOutputsSize"); + if (!totalOutputsSize.hasValue()) { + p.report("missing totalOutputsSize field"); + return false; + } + v.totalOutputsSize = *totalOutputsSize; + + return true; +} + +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 8da318f3d..7dcdb36a4 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -192,13 +192,15 @@ llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { if (!descr.get().hasValue()) { return llvm::Error::success(); } - auto v0Params = - getParameter(descr.get().value(), compilerOptions.optimizerConfig); + CompilationFeedback feedback; + auto v0Params = getParameter(descr.get().value(), feedback, + compilerOptions.optimizerConfig); if (auto err = v0Params.takeError()) { return err; } res.fheContext.emplace(mlir::concretelang::V0FHEContext{ descr.get().value().constraint, v0Params.get()}); + res.feedback.emplace(feedback); } return llvm::Error::success(); @@ -347,6 +349,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { return clientParametersOrErr.takeError(); res.clientParameters = clientParametersOrErr.get(); + res.feedback->fillFromClientParameters(*res.clientParameters); } } @@ -440,12 +443,11 @@ CompilerEngine::compile(std::unique_ptr buffer, return this->compile(sm, target, lib); } -llvm::Expected -CompilerEngine::compile(std::vector inputs, - std::string outputDirPath, - std::string runtimeLibraryPath, bool generateSharedLib, - bool generateStaticLib, bool generateClientParameters, - bool generateCppHeader) { +llvm::Expected CompilerEngine::compile( + std::vector inputs, std::string outputDirPath, + std::string runtimeLibraryPath, bool generateSharedLib, + bool generateStaticLib, bool generateClientParameters, + bool generateCompilationFeedback, bool generateCppHeader) { using Library = mlir::concretelang::CompilerEngine::Library; auto outputLib = std::make_shared(outputDirPath, runtimeLibraryPath); auto target = CompilerEngine::Target::LIBRARY; @@ -456,9 +458,9 @@ CompilerEngine::compile(std::vector inputs, << llvm::toString(compilation.takeError()); } } - if (auto err = outputLib->emitArtifacts(generateSharedLib, generateStaticLib, - generateClientParameters, - generateCppHeader)) { + if (auto err = outputLib->emitArtifacts( + generateSharedLib, generateStaticLib, generateClientParameters, + generateCompilationFeedback, generateCppHeader)) { return StreamStringError("Can't emit artifacts: ") << llvm::toString(std::move(err)); } @@ -469,6 +471,7 @@ llvm::Expected CompilerEngine::compile(llvm::SourceMgr &sm, std::string outputDirPath, std::string runtimeLibraryPath, bool generateSharedLib, bool generateStaticLib, bool generateClientParameters, + bool generateCompilationFeedback, bool generateCppHeader) { using Library = mlir::concretelang::CompilerEngine::Library; auto outputLib = std::make_shared(outputDirPath, runtimeLibraryPath); @@ -480,9 +483,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, std::string outputDirPath, << llvm::toString(compilation.takeError()); } - if (auto err = outputLib->emitArtifacts(generateSharedLib, generateStaticLib, - generateClientParameters, - generateCppHeader)) { + if (auto err = outputLib->emitArtifacts( + generateSharedLib, generateStaticLib, generateClientParameters, + generateCompilationFeedback, generateCppHeader)) { return StreamStringError("Can't emit artifacts: ") << llvm::toString(std::move(err)); } @@ -505,7 +508,7 @@ CompilerEngine::Library::getStaticLibraryPath(std::string outputDirPath) { return staticLibraryPath.str().str(); } -/// Returns the path of the static library +/// Returns the path of the client parameter std::string CompilerEngine::Library::getClientParametersPath(std::string outputDirPath) { llvm::SmallString<0> clientParametersPath(outputDirPath); @@ -515,6 +518,14 @@ CompilerEngine::Library::getClientParametersPath(std::string outputDirPath) { return clientParametersPath.str().str(); } +/// Returns the path of the compiler feedback +std::string +CompilerEngine::Library::getCompilationFeedbackPath(std::string outputDirPath) { + llvm::SmallString<0> compilationFeedbackPath(outputDirPath); + llvm::sys::path::append(compilationFeedbackPath, "compilation_feedback.json"); + return compilationFeedbackPath.str().str(); +} + const std::string CompilerEngine::Library::OBJECT_EXT = ".o"; const std::string CompilerEngine::Library::LINKER = "ld"; #ifdef __APPLE__ @@ -558,6 +569,26 @@ CompilerEngine::Library::emitClientParametersJSON() { return clientParamsPath; } +llvm::Expected +CompilerEngine::Library::emitCompilationFeedbackJSON() { + auto path = getCompilationFeedbackPath(outputDirPath); + if (compilationFeedbackList.size() != 1) { + return StreamStringError("multiple compilation feedback not supported"); + } + llvm::json::Value value(compilationFeedbackList[0]); + std::error_code error; + llvm::raw_fd_ostream out(path, error); + + if (error) { + return StreamStringError("cannot emit client parameters, error: ") + << error.message(); + } + out << llvm::formatv("{0:2}", value); + out.close(); + + return path; +} + static std::string ccpResultType(size_t rank) { if (rank == 0) { return "scalar_out"; @@ -663,6 +694,9 @@ CompilerEngine::Library::addCompilation(CompilationResult &compilation) { if (compilation.clientParameters.hasValue()) { clientParametersList.push_back(compilation.clientParameters.getValue()); } + if (compilation.feedback.hasValue()) { + compilationFeedbackList.push_back(compilation.feedback.getValue()); + } return objectPath; } @@ -776,6 +810,7 @@ llvm::Expected CompilerEngine::Library::emitStatic() { llvm::Error CompilerEngine::Library::emitArtifacts(bool sharedLib, bool staticLib, bool clientParameters, + bool compilationFeedback, bool cppHeader) { // Create output directory if doesn't exist llvm::sys::fs::create_directory(outputDirPath); @@ -794,6 +829,11 @@ llvm::Error CompilerEngine::Library::emitArtifacts(bool sharedLib, return err; } } + if (compilationFeedback) { + if (auto err = emitCompilationFeedbackJSON().takeError()) { + return err; + } + } if (cppHeader) { if (auto err = emitCppHeader().takeError()) { return err; diff --git a/compiler/lib/Support/JITSupport.cpp b/compiler/lib/Support/JITSupport.cpp index 9ab819917..9a3dde64a 100644 --- a/compiler/lib/Support/JITSupport.cpp +++ b/compiler/lib/Support/JITSupport.cpp @@ -49,11 +49,13 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) { // Mark the lambda as compiled using DF parallelization result->lambda->setUseDataflow(options.dataflowParallelize || options.autoParallelize); - if (!mlir::concretelang::dfr::_dfr_is_root_node()) + if (!mlir::concretelang::dfr::_dfr_is_root_node()) { result->clientParameters = clientlib::ClientParameters(); - else + } else { result->clientParameters = compilationResult.get().clientParameters.getValue(); + result->feedback = compilationResult.get().feedback.getValue(); + } return std::move(result); } diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp index 7d53466ae..8a62068ae 100644 --- a/compiler/lib/Support/V0Parameters.cpp +++ b/compiler/lib/Support/V0Parameters.cpp @@ -135,6 +135,7 @@ static void display(optimizer::Description &descr, } llvm::Expected getParameter(optimizer::Description &descr, + CompilationFeedback &feedback, optimizer::Config config) { namespace chrono = std::chrono; auto start = chrono::high_resolution_clock::now(); @@ -206,6 +207,8 @@ llvm::Expected getParameter(optimizer::Description &descr, params.largeInteger = lParams; } + feedback.complexity = sol.complexity; + return params; } diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 4fe7625a9..3ac1743c5 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -585,7 +585,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { if (cmdline::action == Action::COMPILE) { auto err = outputLib->emitArtifacts( /*sharedLib=*/true, /*staticLib=*/true, - /*clientParameters=*/true, /*cppHeader=*/true); + /*clientParameters=*/true, /*compilationFeedback=*/true, + /*cppHeader=*/true); if (err) { return mlir::failure(); } diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc index 40367b5eb..2217911ec 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_fhe.cc @@ -3,6 +3,7 @@ #include #include +#include "concretelang/Support/CompilationFeedback.h" #include "concretelang/Support/JITSupport.h" #include "concretelang/Support/LibrarySupport.h" #include "end_to_end_fixture/EndToEndFixture.h" @@ -70,6 +71,10 @@ void compile_and_run_for_config(EndToEndDesc desc, LambdaSupport support, auto serverLambda = support.loadServerLambda(**compilationResult); ASSERT_EXPECTED_SUCCESS(serverLambda); + // Just test that we can load the compilation feedback + auto feedback = support.loadCompilationFeedback(**compilationResult); + ASSERT_EXPECTED_SUCCESS(feedback); + assert_all_test_entries(desc, test_error_rate, support, keySet, evaluationKeys, clientParameters, serverLambda); } diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 0b953f167..08fd84d08 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -26,14 +26,20 @@ def run(engine, args, compilation_result, keyset_cache): """Execute engine on the given arguments. Perform required loading, encryption, execution, and decryption.""" + # Dev + compilation_feedback = engine.load_compilation_feedback( + compilation_result) + assert(compilation_feedback is not None) # Client client_parameters = engine.load_client_parameters(compilation_result) key_set = ClientSupport.key_set(client_parameters, keyset_cache) - public_arguments = ClientSupport.encrypt_arguments(client_parameters, key_set, args) + public_arguments = ClientSupport.encrypt_arguments( + client_parameters, key_set, args) # Server server_lambda = engine.load_server_lambda(compilation_result) evaluation_keys = key_set.get_evaluation_keys() - public_result = engine.server_call(server_lambda, public_arguments, evaluation_keys) + public_result = engine.server_call( + server_lambda, public_arguments, evaluation_keys) # Client result = ClientSupport.decrypt_result(key_set, public_result) return result @@ -135,8 +141,10 @@ end_to_end_parallel_fixture = [ } """, ( - np.array([[1, 2, 3, 4], [4, 2, 1, 0], [2, 3, 1, 5]], dtype=np.uint8), - np.array([[1, 2, 3, 4], [4, 2, 1, 1], [2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 0], [ + 2, 3, 1, 5]], dtype=np.uint8), + np.array([[1, 2, 3, 4], [4, 2, 1, 1], [ + 2, 3, 1, 5]], dtype=np.uint8), ), np.array([[52, 36], [31, 34], [42, 52]]), id="matmul_eint_int_uint8", @@ -220,7 +228,8 @@ def test_lib_compile_and_run_p_error(keyset_cache): options = CompilationOptions.new("main") options.set_p_error(0.00001) options.set_display_optimizer_choice(True) - compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options) + compile_run_assert(engine, mlir_input, args, + expected_result, keyset_cache, options) def test_lib_compile_and_run_p_error(keyset_cache):