refactor(compiler): Introduce compilation pipeline with multiple entries / exits

This refactoring commit restructures the compilation pipeline of
`zamacompiler`, such that it is possible to enter and exit the
pipeline at different points, effectively defining the level of
abstraction at the input and the required level of abstraction for the
output.

The entry point is specified using the `--entry-dialect`
argument. Valid choices are:

  `--entry-dialect=hlfhe`:   Source contains HLFHE operations
  `--entry-dialect=midlfhe`: Source contains MidLFHE operations
  `--entry-dialect=lowlfhe`: Source contains LowLFHE operations
  `--entry-dialect=std`:     Source does not contain any FHE Operations
  `--entry-dialect=llvm`:    Source is in LLVM dialect

The exit point is defined by an action, specified using --action.

  `--action=roundtrip`:
     Parse the source file to in-memory representation and immediately
     dump as text without any processing

  `--action=dump-midlfhe`:
     Lower source to MidLFHE and dump result as text

  `--action=dump-lowlfhe`:
     Lower source to LowLFHE and dump result as text

  `--action=dump-std`:
     Lower source to only standard MLIR dialects (i.e., all FHE
     operations have already been lowered)

  `--action=dump-llvm-dialect`:
     Lower source to MLIR's LLVM dialect (i.e., the LLVM dialect, not
     LLVM IR)

  `--action=dump-llvm-ir`:
     Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
     actual LLVM IR)

  `--action=dump-optimized-llvm-ir`:
     Lower source to plain LLVM IR (i.e., not the LLVM dialect, but
     actual LLVM IR), pass the result through the LLVM optimizer and
     print the result.

  `--action=dump-jit-invoke`:
     Execute the full lowering pipeline to optimized LLVM IR, JIT
     compile the result, invoke the function specified in
     `--jit-funcname` with the parameters from `--jit-args` and print
     the functions return value.
This commit is contained in:
Andi Drebes
2021-09-17 10:45:53 +02:00
committed by Quentin Bourgerie
parent ddebedd1d6
commit 30374ebb2c
58 changed files with 1014 additions and 862 deletions

View File

@@ -12,16 +12,32 @@
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "mlir/IR/BuiltinOps.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
#include "zamalang/Support/logging.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/Pipeline.h"
#include "zamalang/Support/logging.h"
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
DUMP_STD,
DUMP_LLVM_DIALECT,
DUMP_LLVM_IR,
DUMP_OPTIMIZED_LLVM_IR,
JIT_INVOKE
};
namespace cmdline {
@@ -37,14 +53,53 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> parametrizeMidLFHE(
"parametrize-midlfhe",
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
llvm::cl::init<bool>(true));
llvm::cl::opt<bool> roundTrip("round-trip",
llvm::cl::desc("Just parse and dump"),
llvm::cl::init(false));
static llvm::cl::opt<enum EntryDialect> entryDialect(
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(EntryDialect::HLFHE, "hlfhe",
"Input module is composed of HLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
"Input module is composed of MidLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
"Input module is composed of LowLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::STD, "std",
"Input module is composed of operations from std")),
llvm::cl::values(
clEnumValN(EntryDialect::LLVM, "llvm",
"Input module is composed of operations from llvm")));
static llvm::cl::opt<enum Action> action(
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(Action::ROUND_TRIP, "roundtrip",
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
"Lower to LowLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std",
"Lower to std and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect",
"Lower to LLVM dialect and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LLVM_IR, "dump-llvm-ir",
"Lower to LLVM-IR and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_OPTIMIZED_LLVM_IR,
"dump-optimized-llvm-ir",
"Lower to LLVM-IR, optimize and dump result")),
llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke",
"Lower and JIT-compile input module and invoke "
"function specified with --jit-funcname")));
llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
@@ -58,15 +113,7 @@ llvm::cl::opt<bool> splitInputFile(
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> generateKeySet(
"generate-keyset",
llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::opt<std::string> jitFuncname(
llvm::cl::opt<std::string> jitFuncName(
"jit-funcname",
llvm::cl::desc("Name of the function to execute, default 'main'"),
llvm::cl::init<std::string>("main"));
@@ -75,73 +122,16 @@ llvm::cl::list<uint64_t>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
context, module, defaultOptPipeline);
if (!llvmModule) {
return mlir::failure();
}
os << **llvmModule;
return mlir::success();
}
std::unique_ptr<mlir::zamalang::KeySet>
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
const std::string &jitFuncName) {
std::unique_ptr<mlir::zamalang::KeySet> keySet;
// Process a single source buffer
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced.
//
// If `verifyDiagnostics` is `false`, the procedure checks if the
// parsed module is valid and if all requested transformations
// succeeded.
mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
auto module = mlir::parseSourceFile(sourceMgr, &context);
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (!module)
return mlir::failure();
if (cmdline::roundTrip) {
module->print(os);
return mlir::success();
}
auto enablePass = [](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
};
// Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit.
mlir::zamalang::CompilerTools::LowerOptions lowerOptions;
lowerOptions.enablePass = enablePass;
lowerOptions.verbose = cmdline::verbose;
mlir::zamalang::V0FHEContext fheContext;
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
context, *module, fheContext, lowerOptions)
.failed()) {
return mlir::failure();
}
mlir::zamalang::log_verbose()
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n";
@@ -155,45 +145,196 @@ processInputBuffer(mlir::MLIRContext &context,
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
// Generate the keySet
std::unique_ptr<mlir::zamalang::KeySet> keySet;
if (cmdline::generateKeySet || cmdline::runJit) {
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, cmdline::jitFuncname, *module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return mlir::failure();
}
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return mlir::failure();
}
keySet = std::move(maybeKeySet.get());
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, jitFuncName, module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return nullptr;
}
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
context, *module, lowerOptions)
.failed()) {
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return nullptr;
}
return std::move(maybeKeySet.get());
}
// Process a single source buffer
//
// The parameter `entryDialect` must specify the FHE dialect to which
// belong all FHE operations used in the source buffer. The input
// program must only contain FHE operations from that single FHE
// dialect, otherwise processing might fail.
//
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
// If the specified action involves JIT compilation, `jitFuncName`
// designates the function to JIT compile. This function is invoked
// using the parameters given in `jitArgs`.
//
// The parameter `parametrizeMidLFHE` defines, whether the
// parametrization pass for MidLFHE is executed. If the pair of
// `entryDialect` and `action` does not involve any MidlFHE
// manipulation, this parameter does not have any effect.
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced. If `verifyDiagnostics` is `false`,
// the procedure checks if the parsed module is valid and if all
// requested transformations succeeded.
//
// If `verbose` is true, debug messages are displayed throughout the
// compilation process.
//
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
enum EntryDialect entryDialect, enum Action action,
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose,
llvm::raw_ostream &os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
// This is temporary until we have the high-level verification pass
// determining these parameters automatically
mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10,
.p = 7};
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
const mlir::zamalang::V0Parameter *parameter =
getV0Parameter(defaultGlobalFHECircuitConstraint);
if (!parameter) {
mlir::zamalang::log_error()
<< "Could not determine V0 parameters for 2-norm of "
<< defaultGlobalFHECircuitConstraint.norm2 << " and p of "
<< defaultGlobalFHECircuitConstraint.p << "\n";
return mlir::failure();
}
if (cmdline::runJit) {
mlir::zamalang::log_verbose() << "### JIT compile & running\n";
return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname,
cmdline::jitArgs, *keySet,
defaultOptPipeline, os);
mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint,
*parameter};
if (verbose)
context.disableMultithreading();
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (!moduleRef)
return mlir::failure();
mlir::ModuleOp module = moduleRef.get();
if (action == Action::ROUND_TRIP) {
module->print(os);
return mlir::success();
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);
// Lowering pipeline. Each stage is represented as a label in the
// switch statement, from the most abstract dialect to the lowest
// level. Every labels acts as an entry point into the pipeline with
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::MIDLFHE:
if (action == Action::DUMP_MIDLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext, parametrizeMidlHFE)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LOWLFHE:
if (action == Action::DUMP_LOWLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
return mlir::failure();
// fallthrough
case EntryDialect::STD:
if (action == Action::DUMP_STD) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
keySet = generateKeySet(module, fheContext, jitFuncName);
}
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LLVM: {
if (action == Action::DUMP_LLVM_DIALECT) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
defaultOptPipeline, os);
}
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
module);
if (!llvmModule) {
mlir::zamalang::log_error()
<< "Failed to translate LLVM dialect to LLVM IR\n";
return mlir::failure();
}
if (action == Action::DUMP_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
.failed()) {
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
return mlir::failure();
}
if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
break;
}
module->print(os);
}
return mlir::success();
}
@@ -209,6 +350,16 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// String for error messages from library functions
std::string errorMessage;
if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&
cmdline::entryDialect != EntryDialect::LOWLFHE &&
cmdline::entryDialect != EntryDialect::STD) {
mlir::zamalang::log_error()
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
return mlir::failure();
}
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
@@ -229,7 +380,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
return mlir::failure();
}
// Iterate over all inpiut files specified on the command line
// Iterate over all input files specified on the command line
for (const auto &fileName : cmdline::inputs) {
auto file = mlir::openInputFile(fileName, &errorMessage);
@@ -247,14 +398,19 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
std::move(file),
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(context, std::move(inputBuffer), os,
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics,
cmdline::verbose, os);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(context, std::move(file), output->os(),
cmdline::verifyDiagnostics);
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
}
}