#include #include #include #include #include #include #include #include #include #include #include #include #include "zamalang/Conversion/Passes.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" #include "zamalang/Support/CompilerTools.h" namespace cmdline { llvm::cl::list inputs(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::OneOrMore); llvm::cl::opt output("o", llvm::cl::desc("Specify output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("-")); llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); llvm::cl::list passes( "passes", llvm::cl::desc("Specify the passes to run (use only for compiler tests)"), llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore); llvm::cl::opt roundTrip("round-trip", llvm::cl::desc("Just parse and dump"), llvm::cl::init(false)); llvm::cl::opt verifyDiagnostics( "verify-diagnostics", llvm::cl::desc("Check that emitted diagnostics match " "expected-* lines on the corresponding line"), llvm::cl::init(false)); llvm::cl::opt splitInputFile( "split-input-file", llvm::cl::desc("Split the input file into pieces and process each " "chunk independently"), llvm::cl::init(false)); llvm::cl::opt generateKeySet( "generate-keyset", llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"), llvm::cl::init(false)); llvm::cl::opt runJit("run-jit", llvm::cl::desc("JIT the code and run it"), llvm::cl::init(false)); llvm::cl::opt jitFuncname( "jit-funcname", llvm::cl::desc("Name of the function to execute, default 'main'"), llvm::cl::init("main")); llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); llvm::cl::opt toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "), llvm::cl::init(false)); }; // namespace cmdline #define LOG_VERBOSE(expr) \ if (cmdline::verbose) \ llvm::errs() << expr; #define LOG_ERROR(expr) llvm::errs() << expr; auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { llvm::LLVMContext context; auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule( context, module, defaultOptPipeline); if (!llvmModule) { return mlir::failure(); } os << **llvmModule; return mlir::success(); } mlir::LogicalResult runJit(mlir::ModuleOp module, mlir::zamalang::KeySet &keySet, llvm::raw_ostream &os) { // Create the JIT lambda auto maybeLambda = mlir::zamalang::JITLambda::create( cmdline::jitFuncname, module, defaultOptPipeline); if (!maybeLambda) { return mlir::failure(); } auto lambda = std::move(maybeLambda.get()); // Create the arguments of the JIT lambda auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); if (auto err = maybeArguments.takeError()) { LOG_ERROR("Cannot create lambda arguments: " << err << "\n"); return mlir::failure(); } // Set the arguments auto arguments = std::move(maybeArguments.get()); for (auto i = 0; i < cmdline::jitArgs.size(); i++) { if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) { LOG_ERROR("Cannot push argument " << i << ": " << err << "\n"); return mlir::failure(); } } // Invoke the lambda if (auto err = lambda->invoke(*arguments)) { LOG_ERROR("Cannot invoke : " << err << "\n"); return mlir::failure(); } uint64_t res = 0; if (auto err = arguments->getResult(0, res)) { LOG_ERROR("Cannot get result : " << err << "\n"); return mlir::failure(); } llvm::errs() << res << "\n"; return mlir::success(); } // Process a single source buffer // // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using // `expected-error` are produced. // // If `verifyDiagnostics` is `false`, the procedure checks if the // parsed module is valid and if all requested transformations // succeeded. mlir::LogicalResult processInputBuffer(mlir::MLIRContext &context, std::unique_ptr buffer, llvm::raw_ostream &os, bool verifyDiagnostics) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); auto module = mlir::parseSourceFile(sourceMgr, &context); if (verifyDiagnostics) return sourceMgrHandler.verify(); if (!module) return mlir::failure(); if (cmdline::roundTrip) { module->print(os); return mlir::success(); } auto enablePass = [](std::string passName) { return cmdline::passes.size() == 0 || std::any_of(cmdline::passes.begin(), cmdline::passes.end(), [&](const std::string &p) { return passName == p; }); }; // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. mlir::zamalang::CompilerTools::LowerOptions lowerOptions; lowerOptions.enablePass = enablePass; lowerOptions.verbose = cmdline::verbose; mlir::zamalang::V0FHEContext fheContext; if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( context, *module, fheContext, lowerOptions) .failed()) { return mlir::failure(); } LOG_VERBOSE("### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 << ", p:" << fheContext.constraint.p << "}\n"); LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: " << fheContext.parameter.k << ", polynomialSize: " << fheContext.parameter.polynomialSize << ", nSmall: " << fheContext.parameter.nSmall << ", brLevel: " << fheContext.parameter.brLevel << ", brLogBase: " << fheContext.parameter.brLogBase << ", ksLevel: " << fheContext.parameter.ksLevel << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"); // Generate the keySet std::unique_ptr keySet; if (cmdline::generateKeySet || cmdline::runJit) { // Create the client parameters auto clientParameter = mlir::zamalang::createClientParametersForV0( fheContext, cmdline::jitFuncname, *module); if (auto err = clientParameter.takeError()) { LOG_ERROR("cannot generate client parameters: " << err << "\n"); return mlir::failure(); } LOG_VERBOSE("### Generate the key set\n"); auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); // TODO: seed if (auto err = maybeKeySet.takeError()) { llvm::errs() << err; return mlir::failure(); } keySet = std::move(maybeKeySet.get()); } // Lower to MLIR LLVM Dialect if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( context, *module, lowerOptions) .failed()) { return mlir::failure(); } if (cmdline::runJit) { LOG_VERBOSE("### JIT compile & running\n"); return runJit(module.get(), *keySet, os); } if (cmdline::toLLVM) { return dumpLLVMIR(module.get(), os); } module->print(os); return mlir::success(); } mlir::LogicalResult compilerMain(int argc, char **argv) { // Parse command line arguments llvm::cl::ParseCommandLineOptions(argc, argv); // Initialize the MLIR context mlir::MLIRContext context; // String for error messages from library functions std::string errorMessage; // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); context.getOrLoadDialect(); if (cmdline::verifyDiagnostics) context.printOpOnDiagnostic(false); auto output = mlir::openOutputFile(cmdline::output, &errorMessage); if (!output) { llvm::errs() << errorMessage << "\n"; return mlir::failure(); } // Iterate over all inpiut files specified on the command line for (const auto &fileName : cmdline::inputs) { auto file = mlir::openInputFile(fileName, &errorMessage); if (!file) { llvm::errs() << errorMessage << "\n"; return mlir::failure(); } // If `--split-input-file` is set, the file is split into // individual chunks separated by `// -----` markers. Each chunk // is then processed individually as if it were part of a separate // source file. if (cmdline::splitInputFile) { if (mlir::failed(mlir::splitAndProcessBuffer( std::move(file), [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { return processInputBuffer(context, std::move(inputBuffer), os, cmdline::verifyDiagnostics); }, output->os()))) return mlir::failure(); } else { return processInputBuffer(context, std::move(file), output->os(), cmdline::verifyDiagnostics); } } return mlir::success(); } int main(int argc, char **argv) { if (mlir::failed(compilerMain(argc, argv))) return 1; return 0; }