#include #include #include #include #include #include #include #include #include #include #include #include #include #include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" #include "zamalang/Conversion/Utils/GlobalFHEContext.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" #include "zamalang/Support/Jit.h" #include "zamalang/Support/KeySet.h" #include "zamalang/Support/Pipeline.h" #include "zamalang/Support/logging.h" enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; enum Action { ROUND_TRIP, DUMP_HLFHE_MANP, DUMP_MIDLFHE, DUMP_LOWLFHE, DUMP_STD, DUMP_LLVM_DIALECT, DUMP_LLVM_IR, DUMP_OPTIMIZED_LLVM_IR, JIT_INVOKE }; namespace cmdline { class OptionalSizeTParser : public llvm::cl::parser> { public: OptionalSizeTParser(llvm::cl::Option &option) : llvm::cl::parser>(option) {} bool parse(llvm::cl::Option &option, llvm::StringRef argName, llvm::StringRef arg, llvm::Optional &value) { size_t parsedVal; std::istringstream iss(arg.str()); iss >> parsedVal; if (iss.fail()) return option.error("Invalid value " + arg); value.emplace(parsedVal); return false; } }; llvm::cl::list inputs(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::OneOrMore); llvm::cl::opt output("o", llvm::cl::desc("Specify output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("-")); llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); llvm::cl::opt parametrizeMidLFHE( "parametrize-midlfhe", llvm::cl::desc("Perform MidLFHE global parametrization pass"), llvm::cl::init(true)); static llvm::cl::opt entryDialect( "e", "entry-dialect", llvm::cl::desc("Entry dialect"), llvm::cl::init(EntryDialect::HLFHE), llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, llvm::cl::values( clEnumValN(EntryDialect::HLFHE, "hlfhe", "Input module is composed of HLFHE operations")), llvm::cl::values( clEnumValN(EntryDialect::MIDLFHE, "midlfhe", "Input module is composed of MidLFHE operations")), llvm::cl::values( clEnumValN(EntryDialect::LOWLFHE, "lowlfhe", "Input module is composed of LowLFHE operations")), llvm::cl::values( clEnumValN(EntryDialect::STD, "std", "Input module is composed of operations from std")), llvm::cl::values( clEnumValN(EntryDialect::LLVM, "llvm", "Input module is composed of operations from llvm"))); static llvm::cl::opt action( "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, llvm::cl::values( clEnumValN(Action::ROUND_TRIP, "roundtrip", "Parse input module and regenerate textual representation")), llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp", "Dump HLFHE module after running the Minimal " "Arithmetic Noise Padding pass")), llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", "Lower to MidLFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", "Lower to LowLFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std", "Lower to std and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect", "Lower to LLVM dialect and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LLVM_IR, "dump-llvm-ir", "Lower to LLVM-IR and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_OPTIMIZED_LLVM_IR, "dump-optimized-llvm-ir", "Lower to LLVM-IR, optimize and dump result")), llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke", "Lower and JIT-compile input module and invoke " "function specified with --jit-funcname"))); llvm::cl::opt verifyDiagnostics( "verify-diagnostics", llvm::cl::desc("Check that emitted diagnostics match " "expected-* lines on the corresponding line"), llvm::cl::init(false)); llvm::cl::opt splitInputFile( "split-input-file", llvm::cl::desc("Split the input file into pieces and process each " "chunk independently"), llvm::cl::init(false)); llvm::cl::opt jitFuncName( "jit-funcname", llvm::cl::desc("Name of the function to execute, default 'main'"), llvm::cl::init("main")); llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); llvm::cl::opt, false, OptionalSizeTParser> assumeMaxEintPrecision( "assume-max-eint-precision", llvm::cl::desc("Assume a maximum precision for encrypted integers")); llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( "assume-max-manp", llvm::cl::desc( "Assume a maximum for the Minimum Arithmetic Noise Padding")); }; // namespace cmdline std::function defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); std::unique_ptr generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, const std::string &jitFuncName) { std::unique_ptr keySet; mlir::zamalang::log_verbose() << "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 << ", p:" << fheContext.constraint.p << "}\n"; mlir::zamalang::log_verbose() << "### FHE parameters for the atomic pattern: {k: " << fheContext.parameter.k << ", polynomialSize: " << fheContext.parameter.polynomialSize << ", nSmall: " << fheContext.parameter.nSmall << ", brLevel: " << fheContext.parameter.brLevel << ", brLogBase: " << fheContext.parameter.brLogBase << ", ksLevel: " << fheContext.parameter.ksLevel << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"; // Create the client parameters auto clientParameter = mlir::zamalang::createClientParametersForV0( fheContext, jitFuncName, module); if (auto err = clientParameter.takeError()) { mlir::zamalang::log_error() << "cannot generate client parameters: " << err << "\n"; return nullptr; } mlir::zamalang::log_verbose() << "### Generate the key set\n"; auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); // TODO: seed if (auto err = maybeKeySet.takeError()) { llvm::errs() << err; return nullptr; } return std::move(maybeKeySet.get()); } llvm::Expected buildFHEContext( llvm::Optional autoFHEConstraints, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP) { if (!autoFHEConstraints.hasValue() && (!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) { return llvm::make_error( "Maximum encrypted integer precision and maximum for the Minimal" "Arithmetic Noise Passing are required, but were neither specified" "explicitly nor determined automatically", llvm::inconvertibleErrorCode()); } mlir::zamalang::V0FHEConstraint fheConstraints{ .norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() : autoFHEConstraints.getValue().norm2, .p = overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue() : autoFHEConstraints.getValue().p}; const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); if (!parameter) { std::string buffer; llvm::raw_string_ostream strs(buffer); strs << "Could not determine V0 parameters for 2-norm of " << fheConstraints.norm2 << " and p of " << fheConstraints.p; return llvm::make_error(strs.str(), llvm::inconvertibleErrorCode()); } return mlir::zamalang::V0FHEContext{fheConstraints, *parameter}; } mlir::LogicalResult buildAssignFHEContext( llvm::Optional &fheContext, llvm::Optional autoFHEConstraints, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP) { if (fheContext.hasValue()) return mlir::success(); llvm::Expected fheContextOrErr = buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision, overrideMaxMANP); if (auto err = fheContextOrErr.takeError()) { mlir::zamalang::log_error() << err; return mlir::failure(); } fheContext.emplace(fheContextOrErr.get()); return mlir::success(); } // Process a single source buffer // // The parameter `entryDialect` must specify the FHE dialect to which // belong all FHE operations used in the source buffer. The input // program must only contain FHE operations from that single FHE // dialect, otherwise processing might fail. // // The parameter `action` specifies how the buffer should be processed // and thus defines the output. // // If the specified action involves JIT compilation, `jitFuncName` // designates the function to JIT compile. This function is invoked // using the parameters given in `jitArgs`. // // The parameter `parametrizeMidLFHE` defines, whether the // parametrization pass for MidLFHE is executed. If the pair of // `entryDialect` and `action` does not involve any MidlFHE // manipulation, this parameter does not have any effect. // // The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if // set, override the values for the maximum required precision of // encrypted integers and the maximum value for the Minimum Arithmetic // Noise Padding otherwise determined automatically if the entry // dialect is HLFHE.. // // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using // `expected-error` are produced. If `verifyDiagnostics` is `false`, // the procedure checks if the parsed module is valid and if all // requested transformations succeeded. // // If `verbose` is true, debug messages are displayed throughout the // compilation process. // // Compilation output is written to the stream specified by `os`. mlir::LogicalResult processInputBuffer( mlir::MLIRContext &context, std::unique_ptr buffer, enum EntryDialect entryDialect, enum Action action, const std::string &jitFuncName, llvm::ArrayRef jitArgs, bool parametrizeMidlHFE, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, bool verbose, llvm::raw_ostream &os) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); llvm::Optional fheConstraints; llvm::Optional fheContext; std::unique_ptr keySet = nullptr; if (verbose) context.disableMultithreading(); if (verifyDiagnostics) return sourceMgrHandler.verify(); if (!moduleRef) return mlir::failure(); mlir::ModuleOp module = moduleRef.get(); if (action == Action::ROUND_TRIP) { module->print(os); return mlir::success(); } // Lowering pipeline. Each stage is represented as a label in the // switch statement, from the most abstract dialect to the lowest // level. Every labels acts as an entry point into the pipeline with // a fallthrough mechanism to the next stage. Actions act as exit // points from the pipeline. switch (entryDialect) { case EntryDialect::HLFHE: if (action == Action::DUMP_HLFHE_MANP) { if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) .failed()) { return mlir::failure(); } module.print(os); return mlir::success(); } else { llvm::Expected> fheConstraintsOrErr = mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context, module); if (auto err = fheConstraintsOrErr.takeError()) { mlir::zamalang::log_error() << err; return mlir::failure(); } else { fheConstraints = fheConstraintsOrErr.get(); } } if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) .failed()) return mlir::failure(); // fallthrough case EntryDialect::MIDLFHE: if (action == Action::DUMP_MIDLFHE) { module.print(os); return mlir::success(); } if (buildAssignFHEContext(fheContext, fheConstraints, overrideMaxEintPrecision, overrideMaxMANP) .failed()) { return mlir::failure(); } if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( context, module, fheContext.getValue(), parametrizeMidlHFE) .failed()) return mlir::failure(); // fallthrough case EntryDialect::LOWLFHE: if (action == Action::DUMP_LOWLFHE) { module.print(os); return mlir::success(); } if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed()) return mlir::failure(); // fallthrough case EntryDialect::STD: if (action == Action::DUMP_STD) { module.print(os); return mlir::success(); } else if (action == Action::JIT_INVOKE) { if (buildAssignFHEContext(fheContext, fheConstraints, overrideMaxEintPrecision, overrideMaxMANP) .failed()) { return mlir::failure(); } keySet = generateKeySet(module, fheContext.getValue(), jitFuncName); } if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, verbose) .failed()) return mlir::failure(); // fallthrough case EntryDialect::LLVM: { if (action == Action::DUMP_LLVM_DIALECT) { module.print(os); return mlir::success(); } else if (action == Action::JIT_INVOKE) { return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet, defaultOptPipeline, os); } llvm::LLVMContext llvmContext; std::unique_ptr llvmModule = mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext, module); if (!llvmModule) { mlir::zamalang::log_error() << "Failed to translate LLVM dialect to LLVM IR\n"; return mlir::failure(); } if (action == Action::DUMP_LLVM_IR) { llvmModule->dump(); return mlir::success(); } if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule) .failed()) { mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n"; return mlir::failure(); } if (action == Action::DUMP_OPTIMIZED_LLVM_IR) { llvmModule->dump(); return mlir::success(); } break; } } return mlir::success(); } mlir::LogicalResult compilerMain(int argc, char **argv) { // Parse command line arguments llvm::cl::ParseCommandLineOptions(argc, argv); // Initialize the MLIR context mlir::MLIRContext context; mlir::zamalang::setupLogging(cmdline::verbose); // String for error messages from library functions std::string errorMessage; if (cmdline::action == Action::DUMP_HLFHE_MANP && cmdline::entryDialect != EntryDialect::HLFHE) { mlir::zamalang::log_error() << "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs"; return mlir::failure(); } if (cmdline::action == Action::JIT_INVOKE && cmdline::entryDialect != EntryDialect::HLFHE && cmdline::entryDialect != EntryDialect::MIDLFHE && cmdline::entryDialect != EntryDialect::LOWLFHE && cmdline::entryDialect != EntryDialect::STD) { mlir::zamalang::log_error() << "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs"; return mlir::failure(); } // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); if (cmdline::verifyDiagnostics) context.printOpOnDiagnostic(false); auto output = mlir::openOutputFile(cmdline::output, &errorMessage); if (!output) { llvm::errs() << errorMessage << "\n"; return mlir::failure(); } // Iterate over all input files specified on the command line for (const auto &fileName : cmdline::inputs) { auto file = mlir::openInputFile(fileName, &errorMessage); if (!file) { llvm::errs() << errorMessage << "\n"; return mlir::failure(); } // If `--split-input-file` is set, the file is split into // individual chunks separated by `// -----` markers. Each chunk // is then processed individually as if it were part of a separate // source file. if (cmdline::splitInputFile) { if (mlir::failed(mlir::splitAndProcessBuffer( std::move(file), [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { return processInputBuffer( context, std::move(inputBuffer), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, cmdline::verbose, os); }, output->os()))) return mlir::failure(); } else { return processInputBuffer( context, std::move(file), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, cmdline::verbose, output->os()); } } return mlir::success(); } int main(int argc, char **argv) { if (mlir::failed(compilerMain(argc, argv))) return 1; return 0; }