// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/homomorphizer/blob/master/LICENSE.txt for license // information. #include "llvm/Support/Error.h" #include #include #include #include #include namespace mlir { namespace concretelang { JitCompilerEngine::JitCompilerEngine( std::shared_ptr compilationContext, unsigned int optimizationLevel) : CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) { } // Returns the `LLVMFuncOp` operation in the compiled module with the // specified name. If no LLVMFuncOp with that name exists or if there // was no prior call to `compile()` resulting in an MLIR module in the // LLVM dialect, an error is returned. llvm::Expected JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) { auto funcOps = module.getOps(); auto funcOp = llvm::find_if( funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; }); if (funcOp == funcOps.end()) { return StreamStringError() << "Module does not contain function named '" << name.str() << "'"; } return *funcOp; } // Build a lambda from the function with the name given in // `funcName` from the sources in `buffer`. llvm::Expected JitCompilerEngine::buildLambda(std::unique_ptr buffer, llvm::StringRef funcName, llvm::Optional cache, llvm::Optional runtimeLibPath) { llvm::SourceMgr sm; sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); return this->buildLambda(sm, funcName, cache, runtimeLibPath); } // Build a lambda from the function with the name given in `funcName` // from the source string `s`. llvm::Expected JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName, llvm::Optional cache, llvm::Optional runtimeLibPath) { std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); llvm::Expected res = this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath); return std::move(res); } // Build a lambda from the function with the name given in // `funcName` from the sources managed by the source manager `sm`. llvm::Expected JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, llvm::Optional cache, llvm::Optional runtimeLibPath) { MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); this->setGenerateClientParameters(true); this->setClientParametersFuncName(funcName); // First, compile to LLVM Dialect llvm::Expected compResOrErr = this->compile(sm, Target::LLVM_IR); if (!compResOrErr) return std::move(compResOrErr.takeError()); auto compRes = std::move(compResOrErr.get()); mlir::ModuleOp module = compRes.mlirModuleRef->get(); // Locate function to JIT-compile llvm::Expected funcOrError = this->findLLVMFuncOp(compRes.mlirModuleRef->get(), funcName); if (!funcOrError) return StreamStringError() << "Cannot find function \"" << funcName << "\": " << std::move(funcOrError.takeError()); // Prepare LLVM infrastructure for JIT compilation llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); mlir::registerLLVMDialectTranslation(mlirContext); auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); llvm::Expected> lambdaOrErr = mlir::concretelang::JITLambda::create(funcName, module, optPipeline, runtimeLibPath); if (!lambdaOrErr) { return StreamStringError() << "Cannot create lambda: " << lambdaOrErr.takeError(); } auto lambda = std::move(lambdaOrErr.get()); // Generate the KeySet for encrypting lambda arguments, decrypting lambda // results if (!compRes.clientParameters.hasValue()) { return StreamStringError("Cannot generate the keySet since client " "parameters has not been computed"); } llvm::Expected> keySetOrErr = (cache.hasValue()) ? cache->tryLoadOrGenerateSave(*compRes.clientParameters, 0, 0) : KeySet::generate(*compRes.clientParameters, 0, 0); if (!keySetOrErr) { return keySetOrErr.takeError(); } auto keySet = std::move(keySetOrErr.get()); return Lambda{this->compilationContext, std::move(lambda), std::move(keySet)}; } } // namespace concretelang } // namespace mlir