// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT #define CONCRETELANG_SUPPORT_LAMBDASUPPORT #include "boost/outcome.h" #include "concretelang/Support/LambdaArgument.h" #include "concretelang/ClientLib/ClientLambda.h" #include "concretelang/ClientLib/ClientParameters.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/ClientLib/Serializers.h" #include "concretelang/Common/Error.h" #include "concretelang/ServerLib/ServerLambda.h" namespace mlir { namespace concretelang { namespace clientlib = ::concretelang::clientlib; namespace { // Generic function template as well as specializations of // `typedResult` must be declared at namespace scope due to return // type template specialization /// Helper function for implementing type-dependent preparation of the result. template llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); template inline llvm::Expected typedScalarResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto clearResult = result.asClearTextScalar(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedResult cannot get clear text scalar") << clearResult.error().mesg; } return clearResult.value(); } /// Specializations of `typedResult()` for scalar results, forwarding /// scalar value to caller. template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template <> inline llvm::Expected typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedScalarResult(keySet, result); } template inline llvm::Expected> typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto clearResult = result.asClearTextVector(keySet, 0); if (!clearResult.has_value()) { return StreamStringError("typedVectorResult cannot get clear text vector") << clearResult.error().mesg; } return std::move(clearResult.value()); } /// Specializations of `typedResult()` for vector results, initializing /// an `std::vector` of the right size with the results and forwarding /// it to the caller with move semantics. /// Cannot factor out into a template template inline /// llvm::Expected> /// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due /// to ambiguity with scalar template template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { return typedVectorResult(keySet, result); } template llvm::Expected> buildTensorLambdaResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { llvm::Expected> tensorOrError = typedResult>(keySet, result); if (auto err = tensorOrError.takeError()) return std::move(err); auto tensorDim = result.asClearTextShape(0); if (tensorDim.has_error()) return StreamStringError(tensorDim.error().mesg); return std::make_unique>>( *tensorOrError, tensorDim.value()); } template llvm::Expected> buildScalarLambdaResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { llvm::Expected scalarOrError = typedResult(keySet, result); if (auto err = scalarOrError.takeError()) return std::move(err); return std::make_unique>(*scalarOrError); } /// pecialization of `typedResult()` for a single result wrapped into /// a `LambdaArgument`. template <> inline llvm::Expected> typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) { auto gate = keySet.outputGate(0); auto width = gate.shape.width; bool sign = gate.shape.sign; if (width > 64) return StreamStringError("Cannot handle values with more than 64 bits"); // By convention, decrypted integers are always 64 bits wide if (gate.isEncrypted()) width = 64; if (gate.shape.dimensions.empty()) { // scalar case if (width > 32) { return (sign) ? buildScalarLambdaResult(keySet, result) : buildScalarLambdaResult(keySet, result); } else if (width > 16) { return (sign) ? buildScalarLambdaResult(keySet, result) : buildScalarLambdaResult(keySet, result); } else if (width > 8) { return (sign) ? buildScalarLambdaResult(keySet, result) : buildScalarLambdaResult(keySet, result); } else if (width <= 8) { return (sign) ? buildScalarLambdaResult(keySet, result) : buildScalarLambdaResult(keySet, result); } } else { // tensor case if (width > 32) { return (sign) ? buildTensorLambdaResult(keySet, result) : buildTensorLambdaResult(keySet, result); } else if (width > 16) { return (sign) ? buildTensorLambdaResult(keySet, result) : buildTensorLambdaResult(keySet, result); } else if (width > 8) { return (sign) ? buildTensorLambdaResult(keySet, result) : buildTensorLambdaResult(keySet, result); } else if (width <= 8) { return (sign) ? buildTensorLambdaResult(keySet, result) : buildTensorLambdaResult(keySet, result); } } assert(false && "Cannot happen"); } } // namespace /// Adaptor class that push arguments specified as instances of /// `LambdaArgument` to `clientlib::EncryptedArguments`. class LambdaArgumentAdaptor { public: /// Checks if the argument `arg` is an plaintext / encrypted integer /// argument or a plaintext / encrypted tensor argument with a /// backing integer type `IntT` and push the argument to `encryptedArgs`. /// /// Returns `true` if `arg` has one of the types above and its value /// was successfully added to `encryptedArgs`, `false` if none of the types /// matches or an error if a type matched, but adding the argument to /// `encryptedArgs` failed. template static inline llvm::Expected tryAddArg(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { if (auto ila = arg.dyn_cast>()) { auto res = encryptedArgs.pushArg(ila->getValue(), keySet); if (!res.has_value()) { return StreamStringError(res.error().mesg); } else { return true; } } else if (auto tla = arg.dyn_cast< TensorLambdaArgument>>()) { auto res = encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet); if (!res.has_value()) { return StreamStringError(res.error().mesg); } else { return true; } } return false; } /// Recursive case for `tryAddArg(...)` template static inline llvm::Expected tryAddArg(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { llvm::Expected successOrError = tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) return successOrError.takeError(); if (successOrError.get() == false) return tryAddArg(encryptedArgs, arg, keySet); else return true; } /// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an /// error if either the argument type is unsupported or if the argument types /// is supported, but adding it to `encryptedArgs` failed. static inline llvm::Error addArgument(clientlib::EncryptedArguments &encryptedArgs, const LambdaArgument &arg, clientlib::KeySet &keySet) { // Try the supported integer types; size_t needs explicit // treatment, since it may alias none of the fixed size integer // types llvm::Expected successOrError = LambdaArgumentAdaptor::tryAddArg(encryptedArgs, arg, keySet); if (!successOrError) return successOrError.takeError(); if (successOrError.get() == false) return StreamStringError("Unknown argument type"); else return llvm::Error::success(); } /// Encrypts and build public arguments from lambda arguments static llvm::Expected> exportArguments(llvm::ArrayRef args, clientlib::ClientParameters clientParameters, clientlib::KeySet &keySet) { auto encryptedArgs = clientlib::EncryptedArguments::empty(); for (auto arg : args) { if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg, keySet)) { return std::move(err); } } auto check = encryptedArgs->checkAllArgs(keySet); if (check.has_error()) { return StreamStringError(check.error().mesg); } auto publicArguments = encryptedArgs->exportPublicArguments( clientParameters, keySet.runtimeContext()); if (publicArguments.has_error()) { return StreamStringError(publicArguments.error().mesg); } return std::move(publicArguments.value()); } }; template class LambdaSupport { public: typedef Lambda lambda; typedef CompilationResult compilationResult; virtual ~LambdaSupport() {} /// Compile the mlir program and produces a compilation result if succeed. llvm::Expected> virtual compile( llvm::SourceMgr &program, CompilationOptions options = CompilationOptions("main")) = 0; llvm::Expected> compile(llvm::StringRef program, CompilationOptions options = CompilationOptions("main")) { return compile(llvm::MemoryBuffer::getMemBuffer(program), options); } llvm::Expected> compile(std::unique_ptr program, CompilationOptions options = CompilationOptions("main")) { llvm::SourceMgr sm; sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc()); return compile(sm, options); } /// Load the server lambda from the compilation result. llvm::Expected virtual loadServerLambda( CompilationResult &result) = 0; /// Load the client parameters from the compilation result. llvm::Expected virtual loadClientParameters( CompilationResult &result) = 0; /// Load the compilation feedback from the compilation result. llvm::Expected virtual loadCompilationFeedback( CompilationResult &result) = 0; /// Call the lambda with the public arguments. llvm::Expected> virtual serverCall( Lambda lambda, clientlib::PublicArguments &args, clientlib::EvaluationKeys &evaluationKeys) = 0; /// Build the client KeySet from the client parameters. static llvm::Expected> keySet(clientlib::ClientParameters clientParameters, llvm::Optional cache) { std::shared_ptr cachePtr; if (cache.hasValue()) { cachePtr = std::make_shared(cache.getValue()); } auto keySet = clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0); if (keySet.has_error()) { return StreamStringError(keySet.error().mesg); } return std::move(keySet.value()); } static llvm::Expected> exportArguments(clientlib::ClientParameters clientParameters, clientlib::KeySet &keySet, llvm::ArrayRef args) { return LambdaArgumentAdaptor::exportArguments(args, clientParameters, keySet); } template static llvm::Expected call(Lambda lambda, clientlib::PublicArguments &publicArguments, clientlib::EvaluationKeys &evaluationKeys) { // Call the lambda auto publicResult = LambdaSupport().serverCall( lambda, publicArguments, evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } // Decrypt the result return typedResult(keySet, **publicResult); } }; template class ClientServer { public: static llvm::Expected create(llvm::StringRef program, CompilationOptions options = CompilationOptions("main"), llvm::Optional cache = {}, LambdaSupport support = LambdaSupport()) { auto compilationResult = support.compile(program, options); if (auto err = compilationResult.takeError()) { return std::move(err); } auto lambda = support.loadServerLambda(**compilationResult); if (auto err = lambda.takeError()) { return std::move(err); } auto clientParameters = support.loadClientParameters(**compilationResult); if (auto err = clientParameters.takeError()) { return std::move(err); } auto keySet = support.keySet(*clientParameters, cache); if (auto err = keySet.takeError()) { return std::move(err); } auto f = ClientServer(); f.lambda = *lambda; f.compilationResult = std::move(*compilationResult); f.keySet = std::move(*keySet); f.clientParameters = *clientParameters; f.support = support; return std::move(f); } template llvm::Expected operator()(llvm::ArrayRef args) { auto publicArguments = LambdaArgumentAdaptor::exportArguments( args, clientParameters, *this->keySet); if (auto err = publicArguments.takeError()) { return std::move(err); } auto evaluationKeys = this->keySet->evaluationKeys(); auto publicResult = support.serverCall(lambda, **publicArguments, evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } template llvm::Expected operator()(const llvm::ArrayRef args) { auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args); if (encryptedArgs.has_error()) { return StreamStringError(encryptedArgs.error().mesg); } auto publicArguments = encryptedArgs.value()->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (!publicArguments.has_value()) { return StreamStringError(publicArguments.error().mesg); } auto evaluationKeys = keySet->evaluationKeys(); auto publicResult = support.serverCall(lambda, *publicArguments.value(), evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } template llvm::Expected operator()(const Args... args) { auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args...); if (encryptedArgs.has_error()) { return StreamStringError(encryptedArgs.error().mesg); } auto publicArguments = encryptedArgs.value()->exportPublicArguments( clientParameters, keySet->runtimeContext()); if (publicArguments.has_error()) { return StreamStringError(publicArguments.error().mesg); } auto evaluationKeys = keySet->evaluationKeys(); auto publicResult = support.serverCall(lambda, *publicArguments.value(), evaluationKeys); if (auto err = publicResult.takeError()) { return std::move(err); } return typedResult(*keySet, **publicResult); } private: typename LambdaSupport::lambda lambda; std::unique_ptr compilationResult; std::unique_ptr keySet; clientlib::ClientParameters clientParameters; LambdaSupport support; }; } // namespace concretelang } // namespace mlir #endif