// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H #define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H #include "concretelang/ClientLib/KeySetCache.h" #include #include #include #include #include namespace mlir { namespace concretelang { using ::concretelang::clientlib::KeySetCache; namespace clientlib = ::concretelang::clientlib; namespace { // Generic function template as well as specializations of // `typedResult` must be declared at namespace scope due to return // type template specialization // Helper function for `JitCompilerEngine::Lambda::operator()` // implementing type-dependent preparation of the result. template llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); // Specialization of `typedResult()` for scalar results, forwarding // scalar value to caller template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto clearResult = result.asClearTextVector(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedResult cannot get clear text vector") << clearResult.error().mesg; } if (clearResult.value().size() != 1) { return StreamStringError("typedResult expect only one value but got ") << clearResult.value().size(); } return clearResult.value()[0]; } template inline llvm::Expected> typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto clearResult = result.asClearTextVector(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedVectorResult cannot get clear text vector") << clearResult.error().mesg; } return std::move(clearResult.value()); } // Specializations of `typedResult()` for vector results, initializing // an `std::vector` of the right size with the results and forwarding // it to the caller with move semantics. // // Cannot factor out into a template template inline // llvm::Expected> // typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due // to ambiguity with scalar template // template <> // inline llvm::Expected> // typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { // return typedVectorResult(keySet, result); // } // template <> // inline llvm::Expected> // typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { // return typedVectorResult(keySet, result); // } // template <> // inline llvm::Expected> // typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { // return typedVectorResult(keySet, result); // } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template llvm::Expected> buildTensorLambdaResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { llvm::Expected> tensorOrError = typedResult>(keySet, result); if (auto err = tensorOrError.takeError()) return std::move(err); std::vector tensorDim(result.buffers[0].sizes.begin(), result.buffers[0].sizes.end() - 1); return std::make_unique>>( *tensorOrError, tensorDim); } // Specialization of `typedResult()` for a single result wrapped into // a `LambdaArgument`. template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto gate = keySet.outputGate(0); // scalar case if (gate.shape.dimensions.empty()) { auto clearResult = result.asClearTextVector(keySet, 0); if (clearResult.has_error()) { return StreamStringError("typedResult: ") << clearResult.error().mesg; } auto res = clearResult.value()[0]; return std::make_unique>(res); } // tensor case // auto width = gate.shape.width; // if (width > 32) return buildTensorLambdaResult(keySet, result); // else if (width > 16) // return buildTensorLambdaResult(keySet, result); // else if (width > 8) // return buildTensorLambdaResult(keySet, result); // else if (width <= 8) // return buildTensorLambdaResult(keySet, result); // return StreamStringError("Cannot handle scalars with more than 64 bits"); } } // namespace // Adaptor class that push arguments specified as instances of // `LambdaArgument` to `clientlib::EncryptedArguments`. class JITLambdaArgumentAdaptor { public: // Checks if the argument `arg` is an plaintext / encrypted integer // argument or a plaintext / encrypted tensor argument with a // backing integer type `IntT` and push the argument to `encryptedArgs`. // // Returns `true` if `arg` has one of the types above and its value // was successfully added to `encryptedArgs`, `false` if none of the types // matches or an error if a type matched, but adding the argument to // `encryptedArgs` failed. template static inline llvm::Expected tryAddArg(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { if (auto ila = arg.dyn_cast>()) { auto res = encryptedArgs.pushArg(ila->getValue(), keySet); if (!res.has_value()) { return StreamStringError(res.error().mesg); } else { return true; } } else if (auto tla = arg.dyn_cast< TensorLambdaArgument>>()) { auto res = encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet); if (!res.has_value()) { return StreamStringError(res.error().mesg); } else { return true; } } return false; } // Recursive case for `tryAddArg(...)` template static inline llvm::Expected tryAddArg(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { llvm::Expected successOrError = tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) return successOrError.takeError(); if (successOrError.get() == false) return tryAddArg(encryptedArgs, arg, keySet); else return true; } // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an // error if either the argument type is unsupported or if the argument types // is supported, but adding it to `encryptedArgs` failed. static inline llvm::Error addArgument(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { // Try the supported integer types; size_t needs explicit // treatment, since it may alias none of the fixed size integer // types llvm::Expected successOrError = JITLambdaArgumentAdaptor::tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) return successOrError.takeError(); if (successOrError.get() == false) return StreamStringError("Unknown argument type"); else return llvm::Error::success(); } }; // A compiler engine that JIT-compiles a source and produces a lambda // object directly invocable through its call operator. class JitCompilerEngine : public CompilerEngine { public: // Wrapper class around `JITLambda` and `JITLambda::Argument` that // allows for direct invocation of a compiled function through // `operator ()`. class Lambda { public: Lambda(Lambda &&other) : innerLambda(std::move(other.innerLambda)), keySet(std::move(other.keySet)), compilationContext(other.compilationContext), clientParameters(other.clientParameters) {} Lambda(std::shared_ptr compilationContext, std::unique_ptr lambda, std::unique_ptr keySet, clientlib::ClientParameters clientParameters) : innerLambda(std::move(lambda)), keySet(std::move(keySet)), compilationContext(compilationContext), clientParameters(clientParameters) {} // Returns the number of arguments required for an invocation of // the lambda size_t getNumArguments() { return this->keySet->numInputs(); } // Returns the number of results an invocation of the lambda // produces size_t getNumResults() { return this->keySet->numOutputs(); } // Invocation with an dynamic list of arguments of different // types, specified as `LambdaArgument`s template llvm::Expected operator()(llvm::ArrayRef lambdaArgs) { // Encrypt the arguments auto encryptedArgs = clientlib::EncryptedArguments::empty(); for (size_t i = 0; i < lambdaArgs.size(); i++) { if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( *encryptedArgs, *lambdaArgs[i], *this->keySet)) { return std::move(err); } } auto check = encryptedArgs->checkAllArgs(*this->keySet); if (check.has_error()) { return StreamStringError(check.error().mesg); } // Export as public arguments auto publicArguments = encryptedArgs->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } // Call the lambda auto publicResult = this->innerLambda->call(*publicArguments.value()); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } // Invocation with an array of arguments of the same type template llvm::Expected operator()(const llvm::ArrayRef args) { // Encrypt the arguments auto encryptedArgs = clientlib::EncryptedArguments::empty(); for (size_t i = 0; i < args.size(); i++) { auto res = encryptedArgs->pushArg(args[i], *keySet); if (res.has_error()) { return StreamStringError(res.error().mesg); } } auto check = encryptedArgs->checkAllArgs(*this->keySet); if (check.has_error()) { return StreamStringError(check.error().mesg); } // Export as public arguments auto publicArguments = encryptedArgs->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } // Call the lambda auto publicResult = this->innerLambda->call(*publicArguments.value()); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } // Invocation with arguments of different types template llvm::Expected operator()(const Ts... ts) { // Encrypt the arguments auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, ts...); if (encryptedArgs.has_error()) { return StreamStringError(encryptedArgs.error().mesg); } // Export as public arguments auto publicArguments = encryptedArgs.value()->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } // Call the lambda auto publicResult = this->innerLambda->call(*publicArguments.value()); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } protected: std::unique_ptr innerLambda; std::unique_ptr keySet; std::shared_ptr compilationContext; const clientlib::ClientParameters clientParameters; }; JitCompilerEngine(std::shared_ptr compilationContext = CompilationContext::createShared(), unsigned int optimizationLevel = 3); /// Build a Lambda from a source MLIR, with `funcName` as entrypoint. /// Use runtimeLibPath as a shared library if specified. llvm::Expected buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(std::unique_ptr buffer, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); protected: llvm::Expected findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name); unsigned int optimizationLevel; }; } // namespace concretelang } // namespace mlir #endif