// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H #define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H #include "concretelang/ClientLib/KeySetCache.h" #include #include #include #include #include namespace mlir { namespace concretelang { namespace { // Generic function template as well as specializations of // `typedResult` must be declared at namespace scope due to return // type template specialization // Helper function for `JitCompilerEngine::Lambda::operator()` // implementing type-dependent preparation of the result. template llvm::Expected typedResult(JITLambda::Argument &arguments); // Specialization of `typedResult()` for scalar results, forwarding // scalar value to caller template <> inline llvm::Expected typedResult(JITLambda::Argument &arguments) { uint64_t res = 0; if (auto err = arguments.getResult(0, res)) return StreamStringError() << "Cannot retrieve result:" << err; return res; } template inline llvm::Expected> typedVectorResult(JITLambda::Argument &arguments) { llvm::Expected n = arguments.getResultVectorSize(0); if (auto err = n.takeError()) return std::move(err); std::vector res(*n); if (auto err = arguments.getResult(0, res.data(), res.size())) return StreamStringError() << "Cannot retrieve result:" << err; return std::move(res); } // Specializations of `typedResult()` for vector results, initializing // an `std::vector` of the right size with the results and forwarding // it to the caller with move semantics. // // Cannot factor out into a template template inline // llvm::Expected> // typedResult(JITLambda::Argument &arguments); due to ambiguity with // scalar template template <> inline llvm::Expected> typedResult(JITLambda::Argument &arguments) { return std::move(typedVectorResult(arguments)); } template <> inline llvm::Expected> typedResult(JITLambda::Argument &arguments) { return std::move(typedVectorResult(arguments)); } template <> inline llvm::Expected> typedResult(JITLambda::Argument &arguments) { return std::move(typedVectorResult(arguments)); } template <> inline llvm::Expected> typedResult(JITLambda::Argument &arguments) { return std::move(typedVectorResult(arguments)); } template llvm::Expected> buildTensorLambdaResult(JITLambda::Argument &arguments) { llvm::Expected> tensorOrError = typedResult>(arguments); if (!tensorOrError) return std::move(tensorOrError.takeError()); llvm::Expected> tensorDimOrError = arguments.getResultDimensions(0); if (!tensorDimOrError) return std::move(tensorDimOrError.takeError()); return std::move(std::make_unique>>( *tensorOrError, *tensorDimOrError)); } // Specialization of `typedResult()` for a single result wrapped into // a `LambdaArgument`. template <> inline llvm::Expected> typedResult(JITLambda::Argument &arguments) { llvm::Expected resTy = arguments.getResultType(0); if (!resTy) return std::move(resTy.takeError()); switch (*resTy) { case JITLambda::Argument::ResultType::SCALAR: { uint64_t res; if (llvm::Error err = arguments.getResult(0, res)) return std::move(err); return std::move(std::make_unique>(res)); } case JITLambda::Argument::ResultType::TENSOR: { llvm::Expected width = arguments.getResultWidth(0); if (!width) return std::move(width.takeError()); if (*width > 64) return StreamStringError("Cannot handle scalars with more than 64 bits"); if (*width > 32) return buildTensorLambdaResult(arguments); else if (*width > 16) return buildTensorLambdaResult(arguments); else if (*width > 8) return buildTensorLambdaResult(arguments); else if (*width <= 8) return buildTensorLambdaResult(arguments); } } return StreamStringError("Unknown result type"); } // namespace // Adaptor class that adds arguments specified as instances of // `LambdaArgument` to `JitLambda::Argument`. class JITLambdaArgumentAdaptor { public: // Checks if the argument `arg` is an plaintext / encrypted integer // argument or a plaintext / encrypted tensor argument with a // backing integer type `IntT` and adds the argument to `jla` at // position `pos`. // // Returns `true` if `arg` has one of the types above and its value // was successfully added to `jla`, `false` if none of the types // matches or an error if a type matched, but adding the argument to // `jla` failed. template static inline llvm::Expected tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { if (auto ila = arg.dyn_cast>()) { if (llvm::Error err = jla.setArg(pos, ila->getValue())) return std::move(err); else return true; } else if (auto tla = arg.dyn_cast< TensorLambdaArgument>>()) { if (llvm::Error err = jla.setArg(pos, tla->getValue(), tla->getDimensions())) return std::move(err); else return true; } return false; } // Recursive case for `tryAddArg(...)` template static inline llvm::Expected tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { llvm::Expected successOrError = tryAddArg(jla, pos, arg); if (!successOrError) return std::move(successOrError.takeError()); if (successOrError.get() == false) return tryAddArg(jla, pos, arg); else return true; } // Attempts to add a single argument `arg` to `jla` at position // `pos`. Returns an error if either the argument type is // unsupported or if the argument types is supported, but adding it // to `jla` failed. static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { // Try the supported integer types; size_t needs explicit // treatment, since it may alias none of the fixed size integer // types llvm::Expected successOrError = JITLambdaArgumentAdaptor::tryAddArg(jla, pos, arg); if (!successOrError) return std::move(successOrError.takeError()); if (successOrError.get() == false) return StreamStringError("Unknown argument type"); else return llvm::Error::success(); } }; } // namespace // A compiler engine that JIT-compiles a source and produces a lambda // object directly invocable through its call operator. class JitCompilerEngine : public CompilerEngine { public: // Wrapper class around `JITLambda` and `JITLambda::Argument` that // allows for direct invocation of a compiled function through // `operator ()`. class Lambda { public: Lambda(Lambda &&other) : innerLambda(std::move(other.innerLambda)), keySet(std::move(other.keySet)), compilationContext(other.compilationContext) {} Lambda(std::shared_ptr compilationContext, std::unique_ptr lambda, std::unique_ptr keySet) : innerLambda(std::move(lambda)), keySet(std::move(keySet)), compilationContext(compilationContext) {} // Returns the number of arguments required for an invocation of // the lambda size_t getNumArguments() { return this->keySet->numInputs(); } // Returns the number of results an invocation of the lambda // produces size_t getNumResults() { return this->keySet->numOutputs(); } // Invocation with an dynamic list of arguments of different // types, specified as `LambdaArgument`s template llvm::Expected operator()(llvm::ArrayRef lambdaArgs) { // Create the arguments of the JIT lambda llvm::Expected> argsOrErr = mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); if (llvm::Error err = argsOrErr.takeError()) return StreamStringError("Could not create lambda arguments"); // Set the arguments std::unique_ptr arguments = std::move(argsOrErr.get()); for (size_t i = 0; i < lambdaArgs.size(); i++) { if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( *arguments, i, *lambdaArgs[i])) { return std::move(err); } } // Invoke the lambda if (auto err = this->innerLambda->invoke(*arguments)) return StreamStringError() << "Cannot invoke lambda:" << err; return std::move(typedResult(*arguments)); } // Invocation with an array of arguments of the same type template llvm::Expected operator()(const llvm::ArrayRef args) { // Create the arguments of the JIT lambda llvm::Expected> argsOrErr = mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); if (llvm::Error err = argsOrErr.takeError()) return StreamStringError("Could not create lambda arguments"); // Set the arguments std::unique_ptr arguments = std::move(argsOrErr.get()); for (size_t i = 0; i < args.size(); i++) { if (auto err = arguments->setArg(i, args[i])) { return StreamStringError() << "Cannot push argument " << i << ": " << err; } } // Invoke the lambda if (auto err = this->innerLambda->invoke(*arguments)) return StreamStringError() << "Cannot invoke lambda:" << err; return std::move(typedResult(*arguments)); } // Invocation with arguments of different types template llvm::Expected operator()(const Ts... ts) { // Create the arguments of the JIT lambda llvm::Expected> argsOrErr = mlir::concretelang::JITLambda::Argument::create(*this->keySet.get()); if (llvm::Error err = argsOrErr.takeError()) return StreamStringError("Could not create lambda arguments"); // Set the arguments std::unique_ptr arguments = std::move(argsOrErr.get()); if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...)) return std::move(err); // Invoke the lambda if (auto err = this->innerLambda->invoke(*arguments)) return StreamStringError() << "Cannot invoke lambda:" << err; return std::move(typedResult(*arguments)); } protected: template inline llvm::Error addArgs(JITLambda::Argument *jitArgs) { // base case -- nothing to do return llvm::Error::success(); } // Recursive case for scalars: extract first scalar argument from // parameter pack and forward rest template inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg, Ts... remainder) { if (auto err = jitArgs->setArg(pos, arg)) { return StreamStringError() << "Cannot push scalar argument " << pos << ": " << err; } return this->addArgs(jitArgs, remainder...); } // Recursive case for tensors: extract pointer and size from // parameter pack and forward rest template inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg, size_t size, Ts... remainder) { if (auto err = jitArgs->setArg(pos, arg, size)) { return StreamStringError() << "Cannot push tensor argument " << pos << ": " << err; } return this->addArgs(jitArgs, remainder...); } std::unique_ptr innerLambda; std::unique_ptr keySet; std::shared_ptr compilationContext; }; JitCompilerEngine(std::shared_ptr compilationContext = CompilationContext::createShared(), unsigned int optimizationLevel = 3); /// Build a Lambda from a source MLIR, with `funcName` as entrypoint. /// Use runtimeLibPath as a shared library if specified. llvm::Expected buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(std::unique_ptr buffer, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); llvm::Expected buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main", llvm::Optional cachePath = {}, llvm::Optional runtimeLibPath = {}); protected: llvm::Expected findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name); unsigned int optimizationLevel; }; } // namespace concretelang } // namespace mlir #endif