// concrete/compiler/src/main.cpp
#include <algorithm>
#include <iostream>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Dialect/Tensor/IR/Tensor.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
namespace cmdline {
llvm::cl::list<std::string> inputs(llvm::cl::Positional,
llvm::cl::desc("<Input files>"),
llvm::cl::OneOrMore);
llvm::cl::opt<std::string> output("o",
llvm::cl::desc("Specify output filename"),
llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> roundTrip("round-trip",
llvm::cl::desc("Just parse and dump"),
llvm::cl::init(false));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> generateKeySet(
"generate-keyset",
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::opt<std::string> jitFuncname(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::list<uint64_t>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> toLLVM("to-llvm",
                           llvm::cl::desc("Compile to LLVM IR and dump it"),
                           llvm::cl::init<bool>(false));
} // namespace cmdline
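
// Illustrative invocations using the options above. The binary name
// `concrete-compiler` is a placeholder for whatever target this file is built
// into; adjust it to the actual executable:
//
//   concrete-compiler input.mlir --round-trip -o roundtrip.mlir
//   concrete-compiler input.mlir --to-llvm -o module.ll
//   concrete-compiler input.mlir --run-jit --jit-funcname=main --jit-args=3 --jit-args=4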
// Logging helpers; LOG_VERBOSE only prints when --verbose is set.
#define LOG_VERBOSE(expr)                                                      \
  do {                                                                         \
    if (cmdline::verbose)                                                      \
      llvm::errs() << expr;                                                    \
  } while (false)
#define LOG_ERROR(expr) llvm::errs() << expr
// Default LLVM optimization pipeline: -O3, no size optimization, no target machine.
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
context, module, defaultOptPipeline);
if (!llvmModule) {
return mlir::failure();
}
os << **llvmModule;
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module,
mlir::zamalang::KeySet &keySet,
llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda = mlir::zamalang::JITLambda::create(
cmdline::jitFuncname, module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = std::move(maybeLambda.get());
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
if (auto err = maybeArguments.takeError()) {
LOG_ERROR("Cannot create lambda arguments: " << err << "\n");
return mlir::failure();
}
// Set the arguments
auto arguments = std::move(maybeArguments.get());
  for (size_t i = 0; i < cmdline::jitArgs.size(); i++) {
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
LOG_ERROR("Cannot push argument " << i << ": " << err << "\n");
return mlir::failure();
}
}
// Invoke the lambda
if (auto err = lambda->invoke(*arguments)) {
LOG_ERROR("Cannot invoke : " << err << "\n");
return mlir::failure();
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
LOG_ERROR("Cannot get result : " << err << "\n");
return mlir::failure();
}
llvm::errs() << res << "\n";
return mlir::success();
}
// Process a single source buffer
//
// If `verifyDiagnostics` is `true`, the procedure only checks whether the
// diagnostics emitted while processing the buffer match the `expected-*`
// annotations embedded in the source.
//
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid and if all requested transformations
// succeeded.
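//
// For illustration, a buffer checked with `--verify-diagnostics` carries
// standard MLIR `expected-*` annotations such as the following (hypothetical
// sketch, not an op from the dialects registered here):
//
//   func @invalid() {
//     // expected-error @+1 {{some diagnostic message}}
//     ...
//   }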
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
auto module = mlir::parseSourceFile(sourceMgr, &context);
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (!module)
return mlir::failure();
if (cmdline::roundTrip) {
module->print(os);
return mlir::success();
}
auto enablePass = [](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
};
  // Lower HLFHE to the MLIR standard dialects and compute the constraints on
  // the FHE circuit.
mlir::zamalang::CompilerTools::LowerOptions lowerOptions;
lowerOptions.enablePass = enablePass;
lowerOptions.verbose = cmdline::verbose;
mlir::zamalang::V0FHEContext fheContext;
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, fheContext, lowerOptions)
.failed()) {
return mlir::failure();
}
LOG_VERBOSE("### Global FHE constraint: {norm2:"
<< fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n");
LOG_VERBOSE("### FHE parameters for the atomic pattern: {k: "
<< fheContext.parameter.k
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
<< ", nSmall: " << fheContext.parameter.nSmall
<< ", brLevel: " << fheContext.parameter.brLevel
<< ", brLogBase: " << fheContext.parameter.brLogBase
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n");
  // Generate the key set (also required for JIT execution).
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
LOG_ERROR("cannot generate client parameters: " << err << "\n");
return mlir::failure();
}
LOG_VERBOSE("### Generate the key set\n");
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
    if (auto err = maybeKeySet.takeError()) {
      LOG_ERROR("Cannot generate the key set: " << err << "\n");
      return mlir::failure();
    }
keySet = std::move(maybeKeySet.get());
}
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
context, *module, lowerOptions)
.failed()) {
return mlir::failure();
}
if (cmdline::runJit) {
LOG_VERBOSE("### JIT compile & running\n");
return runJit(module.get(), *keySet, os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);
}
module->print(os);
return mlir::success();
}
mlir::LogicalResult compilerMain(int argc, char **argv) {
// Parse command line arguments
llvm::cl::ParseCommandLineOptions(argc, argv);
// Initialize the MLIR context
mlir::MLIRContext context;
// String for error messages from library functions
std::string errorMessage;
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);
auto output = mlir::openOutputFile(cmdline::output, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
  // Iterate over all input files specified on the command line.
for (const auto &fileName : cmdline::inputs) {
auto file = mlir::openInputFile(fileName, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
// If `--split-input-file` is set, the file is split into
// individual chunks separated by `// -----` markers. Each chunk
// is then processed individually as if it were part of a separate
// source file.
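    //
    // For instance (illustrative):
    //
    //   func @first() { ... }
    //   // -----
    //   func @second() { ... }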
if (cmdline::splitInputFile) {
if (mlir::failed(mlir::splitAndProcessBuffer(
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
},
output->os())))
return mlir::failure();
    } else {
      // Process the whole file as a single buffer and keep going with any
      // remaining input files on success.
      if (mlir::failed(processInputBuffer(context, std::move(file),
                                          output->os(),
                                          cmdline::verifyDiagnostics)))
        return mlir::failure();
    }
}
return mlir::success();
}
int main(int argc, char **argv) {
if (mlir::failed(compilerMain(argc, argv)))
return 1;
return 0;
}