diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h index 263f665ee..fcd94e494 100644 --- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h +++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h @@ -36,9 +36,20 @@ public: template static outcome::checked, StringError> create(KeySet &keySet, Args... args) { - auto arguments = std::make_unique(); - OUTCOME_TRYV(arguments->pushArgs(keySet, args...)); - return arguments; + auto encryptedArgs = std::make_unique(); + OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...)); + return std::move(encryptedArgs); + } + + template + static outcome::checked, StringError> + create(KeySet &keySet, const llvm::ArrayRef args) { + auto encryptedArgs = EncryptedArguments::empty(); + for (size_t i = 0; i < args.size(); i++) { + OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet)); + } + OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet)); + return std::move(encryptedArgs); } static std::unique_ptr empty() { diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 0b56c7b06..6eb806305 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -168,8 +168,12 @@ public: compile(llvm::SourceMgr &sm, Target target, llvm::Optional> lib = {}); - template - llvm::Expected compile(std::vector inputs, + llvm::Expected + compile(std::vector inputs, std::string libraryPath); + + /// Compile and emit artifact to the given libraryPath from an LLVM source + /// manager. 
+ llvm::Expected compile(llvm::SourceMgr &sm, std::string libraryPath); void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c); diff --git a/compiler/include/concretelang/Support/JitCompilerEngine.h b/compiler/include/concretelang/Support/JitCompilerEngine.h index 2e0981451..a5dd5d34e 100644 --- a/compiler/include/concretelang/Support/JitCompilerEngine.h +++ b/compiler/include/concretelang/Support/JitCompilerEngine.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace mlir { @@ -19,200 +20,6 @@ namespace concretelang { using ::concretelang::clientlib::KeySetCache; namespace clientlib = ::concretelang::clientlib; -namespace { -// Generic function template as well as specializations of -// `typedResult` must be declared at namespace scope due to return -// type template specialization - -// Helper function for `JitCompilerEngine::Lambda::operator()` -// implementing type-dependent preparation of the result. -template -llvm::Expected typedResult(clientlib::KeySet &keySet, - clientlib::PublicResult &result); - -// Specialization of `typedResult()` for scalar results, forwarding -// scalar value to caller -template <> -inline llvm::Expected typedResult(clientlib::KeySet &keySet, - clientlib::PublicResult &result) { - auto clearResult = result.asClearTextVector(keySet, 0); - if (!clearResult.has_value()) { - return StreamStringError("typedResult cannot get clear text vector") - << clearResult.error().mesg; - } - if (clearResult.value().size() != 1) { - return StreamStringError("typedResult expect only one value but got ") - << clearResult.value().size(); - } - return clearResult.value()[0]; -} - -template -inline llvm::Expected> -typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { - auto clearResult = result.asClearTextVector(keySet, 0); - if (!clearResult.has_value()) { - return StreamStringError("typedVectorResult cannot get clear text vector") - << clearResult.error().mesg; - } - return 
std::move(clearResult.value()); -} - -// Specializations of `typedResult()` for vector results, initializing -// an `std::vector` of the right size with the results and forwarding -// it to the caller with move semantics. -// -// Cannot factor out into a template template inline -// llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due -// to ambiguity with scalar template -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } -// template <> -// inline llvm::Expected> -// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { -// return typedVectorResult(keySet, result); -// } -template <> -inline llvm::Expected> -typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { - return typedVectorResult(keySet, result); -} - -template -llvm::Expected> -buildTensorLambdaResult(clientlib::KeySet &keySet, - clientlib::PublicResult &result) { - llvm::Expected> tensorOrError = - typedResult>(keySet, result); - - if (auto err = tensorOrError.takeError()) - return std::move(err); - std::vector tensorDim(result.buffers[0].sizes.begin(), - result.buffers[0].sizes.end() - 1); - - return std::make_unique>>( - *tensorOrError, tensorDim); -} - -// Specialization of `typedResult()` for a single result wrapped into -// a `LambdaArgument`. 
-template <> -inline llvm::Expected> -typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { - auto gate = keySet.outputGate(0); - // scalar case - if (gate.shape.dimensions.empty()) { - auto clearResult = result.asClearTextVector(keySet, 0); - if (clearResult.has_error()) { - return StreamStringError("typedResult: ") << clearResult.error().mesg; - } - auto res = clearResult.value()[0]; - - return std::make_unique>(res); - } - // tensor case - // auto width = gate.shape.width; - - // if (width > 32) - return buildTensorLambdaResult(keySet, result); - // else if (width > 16) - // return buildTensorLambdaResult(keySet, result); - // else if (width > 8) - // return buildTensorLambdaResult(keySet, result); - // else if (width <= 8) - // return buildTensorLambdaResult(keySet, result); - - // return StreamStringError("Cannot handle scalars with more than 64 bits"); -} - -} // namespace - -// Adaptor class that push arguments specified as instances of -// `LambdaArgument` to `clientlib::EncryptedArguments`. -class JITLambdaArgumentAdaptor { -public: - // Checks if the argument `arg` is an plaintext / encrypted integer - // argument or a plaintext / encrypted tensor argument with a - // backing integer type `IntT` and push the argument to `encryptedArgs`. - // - // Returns `true` if `arg` has one of the types above and its value - // was successfully added to `encryptedArgs`, `false` if none of the types - // matches or an error if a type matched, but adding the argument to - // `encryptedArgs` failed. 
- template - static inline llvm::Expected - tryAddArg(clientlib::EncryptedArguments &encryptedArgs, - const LambdaArgument &arg, clientlib::KeySet &keySet) { - if (auto ila = arg.dyn_cast>()) { - auto res = encryptedArgs.pushArg(ila->getValue(), keySet); - if (!res.has_value()) { - return StreamStringError(res.error().mesg); - } else { - return true; - } - } else if (auto tla = arg.dyn_cast< - TensorLambdaArgument>>()) { - auto res = - encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet); - if (!res.has_value()) { - return StreamStringError(res.error().mesg); - } else { - return true; - } - } - return false; - } - - // Recursive case for `tryAddArg(...)` - template - static inline llvm::Expected - tryAddArg(clientlib::EncryptedArguments &encryptedArgs, - const LambdaArgument &arg, clientlib::KeySet &keySet) { - llvm::Expected successOrError = - tryAddArg(encryptedArgs, arg, keySet); - - if (!successOrError) - return successOrError.takeError(); - - if (successOrError.get() == false) - return tryAddArg(encryptedArgs, arg, keySet); - else - return true; - } - - // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an - // error if either the argument type is unsupported or if the argument types - // is supported, but adding it to `encryptedArgs` failed. 
- static inline llvm::Error - addArgument(clientlib::EncryptedArguments &encryptedArgs, - const LambdaArgument &arg, clientlib::KeySet &keySet) { - // Try the supported integer types; size_t needs explicit - // treatment, since it may alias none of the fixed size integer - // types - llvm::Expected successOrError = - JITLambdaArgumentAdaptor::tryAddArg(encryptedArgs, arg, - keySet); - - if (!successOrError) - return successOrError.takeError(); - - if (successOrError.get() == false) - return StreamStringError("Unknown argument type"); - else - return llvm::Error::success(); - } -}; - // A compiler engine that JIT-compiles a source and produces a lambda // object directly invocable through its call operator. class JitCompilerEngine : public CompilerEngine { @@ -246,32 +53,16 @@ public: // Invocation with an dynamic list of arguments of different // types, specified as `LambdaArgument`s template - llvm::Expected - operator()(llvm::ArrayRef lambdaArgs) { - // Encrypt the arguments - auto encryptedArgs = clientlib::EncryptedArguments::empty(); + llvm::Expected operator()(llvm::ArrayRef args) { + auto publicArguments = LambdaArgumentAdaptor::exportArguments( + args, clientParameters, *this->keySet); - for (size_t i = 0; i < lambdaArgs.size(); i++) { - if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( - *encryptedArgs, *lambdaArgs[i], *this->keySet)) { - return std::move(err); - } - } - - auto check = encryptedArgs->checkAllArgs(*this->keySet); - if (check.has_error()) { - return StreamStringError(check.error().mesg); - } - - // Export as public arguments - auto publicArguments = encryptedArgs->exportPublicArguments( - clientParameters, keySet->runtimeContext()); - if (!publicArguments.has_value()) { - return StreamStringError(publicArguments.error().mesg); + if (auto err = publicArguments.takeError()) { + return err; } // Call the lambda - auto publicResult = this->innerLambda->call(*publicArguments.value()); + auto publicResult = 
this->innerLambda->call(**publicArguments); if (auto err = publicResult.takeError()) { return std::move(err); } @@ -283,22 +74,13 @@ public: template llvm::Expected operator()(const llvm::ArrayRef args) { // Encrypt the arguments - auto encryptedArgs = clientlib::EncryptedArguments::empty(); - - for (size_t i = 0; i < args.size(); i++) { - auto res = encryptedArgs->pushArg(args[i], *keySet); - if (res.has_error()) { - return StreamStringError(res.error().mesg); - } - } - - auto check = encryptedArgs->checkAllArgs(*this->keySet); - if (check.has_error()) { - return StreamStringError(check.error().mesg); + auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args); + if (encryptedArgs.has_error()) { + return StreamStringError(encryptedArgs.error().mesg); } // Export as public arguments - auto publicArguments = encryptedArgs->exportPublicArguments( + auto publicArguments = encryptedArgs.value()->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); diff --git a/compiler/include/concretelang/Support/JitLambdaSupport.h b/compiler/include/concretelang/Support/JitLambdaSupport.h new file mode 100644 index 000000000..904b7da13 --- /dev/null +++ b/compiler/include/concretelang/Support/JitLambdaSupport.h @@ -0,0 +1,69 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT +#define CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT + +#include +#include +#include +#include + +#include +#include +#include + +namespace mlir { +namespace concretelang { + +namespace clientlib = ::concretelang::clientlib; + +/// JitCompilationResult is the result of a Jit compilation, the server JIT +/// lambda and the clientParameters. 
+struct JitCompilationResult { + std::unique_ptr lambda; + clientlib::ClientParameters clientParameters; +}; + +/// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation. +class JitLambdaSupport + : public LambdaSupport { + +public: + JitLambdaSupport( + llvm::Optional runtimeLibPath = llvm::None, + llvm::function_ref llvmOptPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr)) + : runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {} + + llvm::Expected> + compile(llvm::SourceMgr &program, std::string funcname = "main") override; + using LambdaSupport::compile; + + llvm::Expected + loadServerLambda(JitCompilationResult &result) override { + return result.lambda.get(); + } + + llvm::Expected + loadClientParameters(JitCompilationResult &result) override { + return result.clientParameters; + } + + llvm::Expected> + serverCall(concretelang::JITLambda *lambda, + clientlib::PublicArguments &args) override { + return lambda->call(args); + } + +private: + llvm::Optional runtimeLibPath; + llvm::function_ref llvmOptPipeline; +}; + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Support/LambdaSupport.h b/compiler/include/concretelang/Support/LambdaSupport.h new file mode 100644 index 000000000..40361034a --- /dev/null +++ b/compiler/include/concretelang/Support/LambdaSupport.h @@ -0,0 +1,318 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. 
+ +#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT +#define CONCRETELANG_SUPPORT_LAMBDASUPPORT + +#include "boost/outcome.h" + +#include "concretelang/Support/LambdaArgument.h" + +#include "concretelang/ClientLib/ClientLambda.h" +#include "concretelang/ClientLib/ClientParameters.h" +#include "concretelang/ClientLib/KeySetCache.h" +#include "concretelang/ClientLib/Serializers.h" +#include "concretelang/Common/Error.h" +#include "concretelang/ServerLib/ServerLambda.h" + +namespace mlir { +namespace concretelang { + +namespace clientlib = ::concretelang::clientlib; + +namespace { +// Generic function template as well as specializations of +// `typedResult` must be declared at namespace scope due to return +// type template specialization + +// Helper function for `JitCompilerEngine::Lambda::operator()` +// implementing type-dependent preparation of the result. +template +llvm::Expected typedResult(clientlib::KeySet &keySet, + clientlib::PublicResult &result); + +// Specialization of `typedResult()` for scalar results, forwarding +// scalar value to caller +template <> +inline llvm::Expected typedResult(clientlib::KeySet &keySet, + clientlib::PublicResult &result) { + auto clearResult = result.asClearTextVector(keySet, 0); + if (!clearResult.has_value()) { + return StreamStringError("typedResult cannot get clear text vector") + << clearResult.error().mesg; + } + if (clearResult.value().size() != 1) { + return StreamStringError("typedResult expect only one value but got ") + << clearResult.value().size(); + } + return clearResult.value()[0]; +} + +template +inline llvm::Expected> +typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { + auto clearResult = result.asClearTextVector(keySet, 0); + if (!clearResult.has_value()) { + return StreamStringError("typedVectorResult cannot get clear text vector") + << clearResult.error().mesg; + } + return std::move(clearResult.value()); +} + +// Specializations of `typedResult()` for vector results, initializing 
+// an `std::vector` of the right size with the results and forwarding
+// it to the caller with move semantics.
+//
+// Cannot factor out into a template template inline
+// llvm::Expected>
+// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
+// to ambiguity with scalar template
+// template <>
+// inline llvm::Expected>
+// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
+// return typedVectorResult(keySet, result);
+// }
+// template <>
+// inline llvm::Expected>
+// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
+// return typedVectorResult(keySet, result);
+// }
+// template <>
+// inline llvm::Expected>
+// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
+// return typedVectorResult(keySet, result);
+// }
+template <>
+inline llvm::Expected>
+typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
+ return typedVectorResult(keySet, result);
+}
+
+template
+llvm::Expected>
+buildTensorLambdaResult(clientlib::KeySet &keySet,
+ clientlib::PublicResult &result) {
+ llvm::Expected> tensorOrError =
+ typedResult>(keySet, result);
+
+ if (auto err = tensorOrError.takeError())
+ return std::move(err);
+ std::vector tensorDim(result.buffers[0].sizes.begin(),
+ result.buffers[0].sizes.end() - 1);
+
+ return std::make_unique>>(
+ *tensorOrError, tensorDim);
+}
+
+// Specialization of `typedResult()` for a single result wrapped into
+// a `LambdaArgument`. 
+template <>
+inline llvm::Expected>
+typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
+ auto gate = keySet.outputGate(0);
+ // scalar case
+ if (gate.shape.dimensions.empty()) {
+ auto clearResult = result.asClearTextVector(keySet, 0);
+ if (clearResult.has_error()) {
+ return StreamStringError("typedResult: ") << clearResult.error().mesg;
+ }
+ auto res = clearResult.value()[0];
+
+ return std::make_unique>(res);
+ }
+ // tensor case
+ // auto width = gate.shape.width;
+
+ // if (width > 32)
+ return buildTensorLambdaResult(keySet, result);
+ // else if (width > 16)
+ // return buildTensorLambdaResult(keySet, result);
+ // else if (width > 8)
+ // return buildTensorLambdaResult(keySet, result);
+ // else if (width <= 8)
+ // return buildTensorLambdaResult(keySet, result);
+
+ // return StreamStringError("Cannot handle scalars with more than 64 bits");
+}
+
+} // namespace
+
+// Adaptor class that pushes arguments specified as instances of
+// `LambdaArgument` to `clientlib::EncryptedArguments`.
+class LambdaArgumentAdaptor {
+public:
+ // Checks if the argument `arg` is a plaintext / encrypted integer
+ // argument or a plaintext / encrypted tensor argument with a
+ // backing integer type `IntT` and pushes the argument to `encryptedArgs`.
+ //
+ // Returns `true` if `arg` has one of the types above and its value
+ // was successfully added to `encryptedArgs`, `false` if none of the types
+ // matches or an error if a type matched, but adding the argument to
+ // `encryptedArgs` failed. 
+ template + static inline llvm::Expected + tryAddArg(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { + if (auto ila = arg.dyn_cast>()) { + auto res = encryptedArgs.pushArg(ila->getValue(), keySet); + if (!res.has_value()) { + return StreamStringError(res.error().mesg); + } else { + return true; + } + } else if (auto tla = arg.dyn_cast< + TensorLambdaArgument>>()) { + auto res = + encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet); + if (!res.has_value()) { + return StreamStringError(res.error().mesg); + } else { + return true; + } + } + return false; + } + + // Recursive case for `tryAddArg(...)` + template + static inline llvm::Expected + tryAddArg(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { + llvm::Expected successOrError = + tryAddArg(encryptedArgs, arg, keySet); + + if (!successOrError) + return successOrError.takeError(); + + if (successOrError.get() == false) + return tryAddArg(encryptedArgs, arg, keySet); + else + return true; + } + + // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an + // error if either the argument type is unsupported or if the argument types + // is supported, but adding it to `encryptedArgs` failed. 
+ static inline llvm::Error + addArgument(clientlib::EncryptedArguments &encryptedArgs, + const LambdaArgument &arg, clientlib::KeySet &keySet) { + // Try the supported integer types; size_t needs explicit + // treatment, since it may alias none of the fixed size integer + // types + llvm::Expected successOrError = + LambdaArgumentAdaptor::tryAddArg(encryptedArgs, arg, keySet); + + if (!successOrError) + return successOrError.takeError(); + + if (successOrError.get() == false) + return StreamStringError("Unknown argument type"); + else + return llvm::Error::success(); + } + + /// Encrypts and build public arguments from lambda arguments + static llvm::Expected> + exportArguments(llvm::ArrayRef args, + clientlib::ClientParameters clientParameters, + clientlib::KeySet &keySet) { + + auto encryptedArgs = clientlib::EncryptedArguments::empty(); + for (auto arg : args) { + if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg, + keySet)) { + return std::move(err); + } + } + auto check = encryptedArgs->checkAllArgs(keySet); + if (check.has_error()) { + return StreamStringError(check.error().mesg); + } + auto publicArguments = encryptedArgs->exportPublicArguments( + clientParameters, keySet.runtimeContext()); + if (publicArguments.has_error()) { + return StreamStringError(publicArguments.error().mesg); + } + return std::move(publicArguments.value()); + } +}; + +template class LambdaSupport { + +public: + virtual ~LambdaSupport() {} + + /// Compile the mlir program and produces a compilation result if succeed. 
+ llvm::Expected> virtual compile( + llvm::SourceMgr &program, std::string funcname = "main"); + + llvm::Expected> + compile(llvm::StringRef program, std::string funcname = "main") { + return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname); + } + + llvm::Expected> + compile(std::unique_ptr program, + std::string funcname = "main") { + llvm::SourceMgr sm; + sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc()); + return compile(sm, funcname); + } + + /// Load the server lambda from the compilation result. + llvm::Expected virtual loadServerLambda(CompilationResult &result); + + /// Load the client parameters from the compilation result. + llvm::Expected virtual loadClientParameters( + CompilationResult &result); + + /// Call the lambda with the public arguments. + llvm::Expected> virtual serverCall( + Lambda lambda, clientlib::PublicArguments &args); + + /// Build the client KeySet from the client parameters. + static llvm::Expected> + keySet(clientlib::ClientParameters clientParameters, + llvm::Optional cache) { + std::shared_ptr cachePtr; + if (cache.hasValue()) { + cachePtr = std::make_shared(cache.getValue()); + } + auto keySet = + clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0); + if (keySet.has_error()) { + return StreamStringError(keySet.error().mesg); + } + return std::move(keySet.value()); + } + + static llvm::Expected> + exportArguments(clientlib::ClientParameters clientParameters, + clientlib::KeySet &keySet, + llvm::ArrayRef args) { + return LambdaArgumentAdaptor::exportArguments(args, clientParameters, + keySet); + } + + template + static llvm::Expected + call(Lambda lambda, clientlib::PublicArguments &publicArguments) { + // Call the lambda + auto publicResult = LambdaSupport().serverCall( + lambda, publicArguments); + if (auto err = publicResult.takeError()) { + return std::move(err); + } + + // Decrypt the result + return typedResult(keySet, **publicResult); + } +}; + +} // namespace concretelang +} // namespace 
mlir + +#endif diff --git a/compiler/include/concretelang/Support/LibraryLambdaSupport.h b/compiler/include/concretelang/Support/LibraryLambdaSupport.h new file mode 100644 index 000000000..82edc484a --- /dev/null +++ b/compiler/include/concretelang/Support/LibraryLambdaSupport.h @@ -0,0 +1,103 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT +#define CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace mlir { +namespace concretelang { + +namespace clientlib = ::concretelang::clientlib; +namespace serverlib = ::concretelang::serverlib; + +/// LibraryCompilationResult is the result of a compilation to a library. +struct LibraryCompilationResult { + /// The output path where the compilation artifact has been generated. + std::string libraryPath; + std::string funcName; +}; + +class LibraryLambdaSupport + : public LambdaSupport { + +public: + LibraryLambdaSupport(std::string outputPath = "/tmp/toto") + : outputPath(outputPath) {} + + llvm::Expected> + compile(llvm::SourceMgr &program, std::string funcname = "main") override { + // Setup the compiler engine + auto context = CompilationContext::createShared(); + concretelang::CompilerEngine engine(context); + engine.setClientParametersFuncName(funcname); + + // Compile to a library + auto library = engine.compile(program, outputPath); + if (auto err = library.takeError()) { + return std::move(err); + } + + auto result = std::make_unique(); + result->libraryPath = outputPath; + result->funcName = funcname; + return std::move(result); + } + using LambdaSupport::compile; + + /// Load the server lambda from the compilation result. 
+ llvm::Expected + loadServerLambda(LibraryCompilationResult &result) override { + auto lambda = + serverlib::ServerLambda::load(result.funcName, result.libraryPath); + if (lambda.has_error()) { + return StreamStringError(lambda.error().mesg); + } + return lambda.value(); + } + + /// Load the client parameters from the compilation result. + llvm::Expected + loadClientParameters(LibraryCompilationResult &result) override { + auto path = ClientParameters::getClientParametersPath(result.libraryPath); + auto params = ClientParameters::load(path); + if (params.has_error()) { + return StreamStringError(params.error().mesg); + } + auto param = llvm::find_if(params.value(), [&](ClientParameters param) { + return param.functionName == result.funcName; + }); + if (param == params.value().end()) { + return StreamStringError("ClientLambda: cannot find function(") + << result.funcName << ") in client parameters path(" << path + << ")"; + } + return *param; + } + + /// Call the lambda with the public arguments. 
+ llvm::Expected> + serverCall(serverlib::ServerLambda lambda, + clientlib::PublicArguments &args) override { + return lambda.call(args); + } + +private: + std::string outputPath; +}; + +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 6cad01739..fdab75984 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_library(ConcretelangSupport Jit.cpp CompilerEngine.cpp JitCompilerEngine.cpp + JitLambdaSupport.cpp LambdaArgument.cpp V0Parameters.cpp V0Curves.cpp @@ -32,4 +33,5 @@ add_mlir_library(ConcretelangSupport ConcretelangRuntime ConcretelangClientLib + ConcretelangServerLib ) diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index d1eefcd08..9b7387de7 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -383,19 +383,17 @@ llvm::Expected CompilerEngine::compile(std::unique_ptr buffer, Target target, OptionalLib lib) { llvm::SourceMgr sm; - sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); return this->compile(sm, target, lib); } -template llvm::Expected -CompilerEngine::compile(std::vector inputs, std::string libraryPath) { +CompilerEngine::compile(std::vector inputs, + std::string libraryPath) { using Library = mlir::concretelang::CompilerEngine::Library; auto outputLib = std::make_shared(libraryPath); auto target = CompilerEngine::Target::LIBRARY; - for (auto input : inputs) { auto compilation = compile(input, target, outputLib); if (!compilation) { @@ -403,6 +401,24 @@ CompilerEngine::compile(std::vector inputs, std::string libraryPath) { << llvm::toString(compilation.takeError()); } } + if (auto err = outputLib->emitArtifacts()) { + return StreamStringError("Can't emit artifacts: ") + << llvm::toString(std::move(err)); + } + return *outputLib.get(); +} + +llvm::Expected 
+CompilerEngine::compile(llvm::SourceMgr &sm, std::string libraryPath) { + using Library = mlir::concretelang::CompilerEngine::Library; + auto outputLib = std::make_shared(libraryPath); + auto target = CompilerEngine::Target::LIBRARY; + + auto compilation = compile(sm, target, outputLib); + if (!compilation) { + return StreamStringError("Can't compile: ") + << llvm::toString(compilation.takeError()); + } if (auto err = outputLib->emitArtifacts()) { return StreamStringError("Can't emit artifacts: ") @@ -411,11 +427,6 @@ CompilerEngine::compile(std::vector inputs, std::string libraryPath) { return *outputLib.get(); } -// explicit instantiation for a vector of string (for linking with lib/CAPI) -template llvm::Expected -CompilerEngine::compile(std::vector inputs, - std::string libraryPath); - /** Returns the path of the shared library */ std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) { return path + DOT_SHARED_LIB_EXT; diff --git a/compiler/lib/Support/JitLambdaSupport.cpp b/compiler/lib/Support/JitLambdaSupport.cpp new file mode 100644 index 000000000..1dba84f4c --- /dev/null +++ b/compiler/lib/Support/JitLambdaSupport.cpp @@ -0,0 +1,49 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. 
+
+#include
+
+namespace mlir {
+namespace concretelang {
+
+llvm::Expected>
+JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
+ // Setup the compiler engine
+ auto context = std::make_shared();
+ concretelang::CompilerEngine engine(context);
+
+ // We need client parameters to be generated
+ engine.setGenerateClientParameters(true);
+ engine.setClientParametersFuncName(funcname);
+
+ // Compile to LLVM Dialect
+ auto compilationResult =
+ engine.compile(program, CompilerEngine::Target::LLVM_IR);
+
+ if (auto err = compilationResult.takeError()) {
+ return std::move(err);
+ }
+
+ // Compile from LLVM Dialect to JITLambda
+ auto mlirModule = compilationResult.get().mlirModuleRef->get();
+ auto lambda = concretelang::JITLambda::create(
+ funcname, mlirModule, llvmOptPipeline, runtimeLibPath);
+ if (auto err = lambda.takeError()) {
+ return std::move(err);
+ }
+
+ if (!compilationResult.get().clientParameters.hasValue()) {
+ // i.e. that should not occur
+ return StreamStringError("No client parameters has been generated");
+ }
+ auto result = std::make_unique();
+ result->lambda = std::move(*lambda);
+ result->clientParameters =
+ compilationResult.get().clientParameters.getValue();
+ return std::move(result);
+}
+
+} // namespace concretelang
+} // namespace mlir \ No newline at end of file diff --git a/compiler/tests/unittest/EndToEndFixture.cpp b/compiler/tests/unittest/EndToEndFixture.cpp index b0885671f..48137131c 100644 --- a/compiler/tests/unittest/EndToEndFixture.cpp +++ b/compiler/tests/unittest/EndToEndFixture.cpp @@ -61,8 +61,8 @@ valueDescriptionToLambdaArgument(ValueDescription desc) { } llvm::Error checkResult(ScalarDesc &desc, - mlir::concretelang::LambdaArgument *res) { - auto res64 = res->dyn_cast>(); + mlir::concretelang::LambdaArgument &res) { + auto res64 = res.dyn_cast>(); if (res64 == nullptr) { return StreamStringError("invocation result is not a scalar"); } @@ -85,7 +85,7 @@ checkTensorResult(TensorDescription 
&desc, << resShape.size() << " expected " << desc.shape.size(); } for (size_t i = 0; i < desc.shape.size(); i++) { - if ((uint64_t)resShape[i] != desc.shape[i]) { + if (resShape[i] != desc.shape[i]) { return StreamStringError("shape differs at pos ") << i << ", got " << resShape[i] << " expected " << desc.shape[i]; } @@ -112,37 +112,36 @@ checkTensorResult(TensorDescription &desc, } llvm::Error checkResult(TensorDescription &desc, - mlir::concretelang::LambdaArgument *res) { + mlir::concretelang::LambdaArgument &res) { switch (desc.width) { case 8: return checkTensorResult( - desc, res->dyn_cast>>()); case 16: return checkTensorResult( - desc, res->dyn_cast>>()); case 32: return checkTensorResult( - desc, res->dyn_cast>>()); case 64: return checkTensorResult( - desc, res->dyn_cast>>()); default: return StreamStringError("Unsupported width"); } } -llvm::Error -checkResult(ValueDescription &desc, - std::unique_ptr &res) { +llvm::Error checkResult(ValueDescription &desc, + mlir::concretelang::LambdaArgument &res) { switch (desc.tag) { case ValueDescription::SCALAR: - return checkResult(desc.scalar, res.get()); + return checkResult(desc.scalar, res); case ValueDescription::TENSOR: - return checkResult(desc.tensor, res.get()); + return checkResult(desc.tensor, res); } assert(false); } @@ -190,7 +189,7 @@ template <> struct llvm::yaml::MappingTraits { } }; -LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc); +LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc) std::vector loadEndToEndDesc(std::string path) { std::ifstream file(path); diff --git a/compiler/tests/unittest/end_to_end_jit_fhe.cc b/compiler/tests/unittest/end_to_end_jit_fhe.cc index 66b14f6eb..40afd5c37 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhe.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhe.cc @@ -4,53 +4,88 @@ #include #include "EndToEndFixture.h" +#include "concretelang/Support/JitLambdaSupport.h" +#include "concretelang/Support/LibraryLambdaSupport.h" -class EndToEndJitTest : public 
testing::TestWithParam {}; - -TEST_P(EndToEndJitTest, compile_and_run) { - EndToEndDesc desc = GetParam(); - - // Compile program - // mlir::concretelang::JitCompilerEngine::Lambda lambda = - checkedJit(lambda, desc.program); - - // Prepare arguments - for (auto test : desc.tests) { - std::vector inputArguments; - inputArguments.reserve(test.inputs.size()); - for (auto input : test.inputs) { - auto arg = valueDescriptionToLambdaArgument(input); - ASSERT_EXPECTED_SUCCESS(arg); - inputArguments.push_back(arg.get()); - } - - // Call the lambda - auto res = - lambda.operator()>( - llvm::ArrayRef( - inputArguments)); - ASSERT_EXPECTED_SUCCESS(res); - if (test.outputs.size() != 1) { - FAIL() << "Only one result function are supported."; - } - ASSERT_LLVM_ERROR(checkResult(test.outputs[0], res.get())); - - // Free arguments - for (auto arg : inputArguments) { - delete arg; - } - } -} - -#define INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(prefix, path) \ - namespace prefix { \ - auto valuesVector = loadEndToEndDesc(path); \ - auto values = testing::ValuesIn>(valuesVector); \ - INSTANTIATE_TEST_SUITE_P(prefix, EndToEndJitTest, values, \ - printEndToEndDesc); \ +// Macro to define and end to end TestSuite that run test thanks the +// LambdaSupport according a EndToEndDesc +#define INSTANTIATE_END_TO_END_COMPILE_AND_RUN(TestSuite, LambdaSupport) \ + TEST_P(TestSuite, compile_and_run) { \ + \ + auto desc = GetParam(); \ + \ + LambdaSupport support; \ + \ + /* 1 - Compile the program */ \ + auto compilationResult = support.compile(desc.program); \ + ASSERT_EXPECTED_SUCCESS(compilationResult); \ + \ + /* 2 - Load the client parameters and build the keySet */ \ + auto clientParameters = support.loadClientParameters(**compilationResult); \ + ASSERT_EXPECTED_SUCCESS(clientParameters); \ + \ + auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); \ + ASSERT_EXPECTED_SUCCESS(keySet); \ + \ + /* 3 - Load the server lambda */ \ + auto serverLambda = 
support.loadServerLambda(**compilationResult); \ + ASSERT_EXPECTED_SUCCESS(serverLambda); \ + \ + /* For each test entries */ \ + for (auto test : desc.tests) { \ + std::vector inputArguments; \ + inputArguments.reserve(test.inputs.size()); \ + for (auto input : test.inputs) { \ + auto arg = valueDescToLambdaArgument(input); \ + ASSERT_EXPECTED_SUCCESS(arg); \ + inputArguments.push_back(arg.get()); \ + } \ + /* 4 - Create the public arguments */ \ + auto publicArguments = support.exportArguments( \ + *clientParameters, **keySet, inputArguments); \ + ASSERT_EXPECTED_SUCCESS(publicArguments); \ + \ + /* 5 - Call the server lambda */ \ + auto publicResult = \ + support.serverCall(*serverLambda, **publicArguments); \ + ASSERT_EXPECTED_SUCCESS(publicResult); \ + \ + /* 6 - Decrypt the public result */ \ + auto result = mlir::concretelang::typedResult< \ + std::unique_ptr>( \ + **keySet, **publicResult); \ + \ + ASSERT_EXPECTED_SUCCESS(result); \ + \ + for (auto arg : inputArguments) { \ + delete arg; \ + } \ + } \ } -INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE( - FHE, "tests/unittest/end_to_end_fhe.yaml") -INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE( - EncryptedTensor, "tests/unittest/end_to_end_encrypted_tensor.yaml") +#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(prefix, suite, \ + lambdasupport, path) \ + namespace prefix##suite { \ + auto valuesVector = loadEndToEndDesc(path); \ + auto values = testing::ValuesIn>(valuesVector); \ + INSTANTIATE_TEST_SUITE_P(prefix, suite, values, printEndToEndDesc); \ + } + +#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(suite, \ + lambdasupport) \ + \ + class suite : public testing::TestWithParam {}; \ + INSTANTIATE_END_TO_END_COMPILE_AND_RUN(suite, lambdasupport) \ + INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE( \ + FHE, suite, lambdasupport, "tests/unittest/end_to_end_fhe.yaml") \ + INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE( \ + EncryptedTensor, suite, lambdasupport, \ + 
"tests/unittest/end_to_end_encrypted_tensor.yaml") + +/// Instantiate the test suite for Jit +INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES( + JitTest, mlir::concretelang::JitLambdaSupport) + +/// Instantiate the test suite for Jit +INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES( + LibraryTest, mlir::concretelang::LibraryLambdaSupport) \ No newline at end of file