#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Conversion/Passes.h"
// NOTE(review): the header names of the next two directives are missing in
// this copy of the file (text between angle brackets appears to have been
// stripped). Restore them from version control before building.
#include
#include
#include <cstddef>

namespace {

/// Builds the llvm::StringError used for every failure reported by this file.
llvm::Error makeEngineError(const char *message) {
  return llvm::make_error<llvm::StringError>(message,
                                             llvm::inconvertibleErrorCode());
}

/// Marks `cause` as handled and returns a StringError carrying `message` in
/// its place.
///
/// An llvm::Error that holds a failure must be handled before it is
/// destroyed; dropping it aborts the process in builds that enable LLVM's
/// error checking. Consuming it here keeps the caller's error message
/// unchanged while making the failure path safe.
llvm::Error replaceError(llvm::Error cause, const char *message) {
  llvm::consumeError(std::move(cause));
  return makeEngineError(message);
}

} // namespace

namespace mlir {
namespace zamalang {

/// Loads seven dialects into `context`.
// NOTE(review): the template argument of each getOrLoadDialect call (the
// dialect class) is missing in this copy of the file and cannot be recovered
// from the visible code. Restore the seven dialect types from version control.
void CompilerEngine::loadDialects() {
  context->getOrLoadDialect();
  context->getOrLoadDialect();
  context->getOrLoadDialect();
  context->getOrLoadDialect();
  context->getOrLoadDialect();
  context->getOrLoadDialect();
  context->getOrLoadDialect();
}

/// Returns the textual form of the module currently held in `module_ref`.
///
/// @return The printed module, or an empty string when no module is held
///         (parsing has not run yet or has failed).
std::string CompilerEngine::getCompiledModule() {
  std::string compiledModule;
  // Printing through a null module reference would dereference a null
  // operation, so report "nothing compiled" as an empty string instead.
  if (!module_ref) {
    return compiledModule;
  }
  llvm::raw_string_ostream os(compiledModule);
  module_ref->print(os);
  return os.str();
}

/// Parses `mlir_input`, lowers it to the LLVM dialect and generates the key
/// set needed to run it.
///
/// Steps, in order: parse the source into `module_ref`, lower HLFHE to the
/// MLIR Std dialects, create the client parameters for the function "main",
/// generate the key set (stored in `keySet`), lower to the MLIR LLVM dialect.
///
/// @param mlir_input MLIR source text to compile.
/// @return `mlir::success()` when every step succeeds, otherwise an
///         llvm::StringError naming the step that failed.
llvm::Expected<mlir::LogicalResult>
CompilerEngine::compileFHE(std::string mlir_input) {
  module_ref = mlir::parseSourceString(mlir_input, context);
  if (!module_ref) {
    return makeEngineError("mlir parsing failed");
  }

  mlir::zamalang::V0FHEContext fheContext;

  // Lower to MLIR Std
  if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
          *context, module_ref.get(), fheContext)
          .failed()) {
    return makeEngineError("failed to lower to MLIR Std");
  }

  // Create the client parameters
  auto clientParameter = mlir::zamalang::createClientParametersForV0(
      fheContext, "main", module_ref.get());
  if (auto err = clientParameter.takeError()) {
    return replaceError(std::move(err), "cannot generate client parameters");
  }

  // NOTE(review): the two trailing 0 arguments look like fixed seed values;
  // confirm against KeySet::generate that this is intended.
  auto maybeKeySet =
      mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0);
  if (auto err = maybeKeySet.takeError()) {
    return replaceError(std::move(err), "cannot generate keyset");
  }
  keySet = std::move(maybeKeySet.get());

  // Lower to MLIR LLVM Dialect
  if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
          *context, module_ref.get())
          .failed()) {
    return makeEngineError("failed to lower to LLVM dialect");
  }

  return mlir::success();
}

/// JIT-compiles the function "main" of the compiled module, runs it on
/// `args` and returns its first result.
///
/// @param args Values passed positionally to the lambda via setArg.
/// @return Result 0 of the invocation, or an llvm::StringError naming the
///         step that failed (including calling run without a successfully
///         compiled module).
// NOTE(review): the element type of `args` is missing in this copy of the
// file; uint64_t is inferred from the result type. Confirm against the
// declaration in CompilerEngine.h.
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
  // Both members are dereferenced below; they are only set by compileFHE.
  if (!module_ref || !keySet) {
    return makeEngineError("no compiled module: call compileFHE first");
  }

  // Create the JIT lambda
  auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
  auto module = module_ref.get();
  auto maybeLambda =
      mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
  // takeError() (rather than a plain bool test) so the failure held by the
  // Expected is actually handled before it goes out of scope.
  if (auto err = maybeLambda.takeError()) {
    return replaceError(std::move(err), "couldn't create lambda");
  }
  auto lambda = std::move(maybeLambda.get());

  // Create the arguments of the JIT lambda
  auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet);
  if (auto err = maybeArguments.takeError()) {
    return replaceError(std::move(err), "cannot create lambda args");
  }

  // Set the arguments
  auto arguments = std::move(maybeArguments.get());
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (auto err = arguments->setArg(i, args[i])) {
      return replaceError(std::move(err), "cannot push argument");
    }
  }

  // Invoke the lambda
  if (auto err = lambda->invoke(*arguments)) {
    return replaceError(std::move(err), "failed execution");
  }

  uint64_t res = 0;
  if (auto err = arguments->getResult(0, res)) {
    return replaceError(std::move(err), "cannot get result");
  }
  return res;
}

} // namespace zamalang
} // namespace mlir