refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -1,3 +1,4 @@
#include <cstdint>
#include <iostream>
#include <llvm/Support/CommandLine.h>
@@ -22,15 +23,15 @@
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/Error.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/KeySet.h"
#include "zamalang/Support/Pipeline.h"
#include "zamalang/Support/logging.h"
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DUMP_HLFHE,
DUMP_HLFHE_MANP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
@@ -80,26 +81,6 @@ llvm::cl::opt<bool> parametrizeMidLFHE(
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
llvm::cl::init<bool>(true));
static llvm::cl::opt<enum EntryDialect> entryDialect(
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
llvm::cl::values(
clEnumValN(EntryDialect::HLFHE, "hlfhe",
"Input module is composed of HLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
"Input module is composed of MidLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
"Input module is composed of LowLFHE operations")),
llvm::cl::values(
clEnumValN(EntryDialect::STD, "std",
"Input module is composed of operations from std")),
llvm::cl::values(
clEnumValN(EntryDialect::LLVM, "llvm",
"Input module is composed of operations from llvm")));
static llvm::cl::opt<enum Action> action(
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
llvm::cl::NumOccurrencesFlag::Required,
@@ -109,6 +90,8 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
"Dump HLFHE module after running the Minimal "
"Arithmetic Noise Padding pass")),
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe",
"Dump HLFHE module")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
@@ -158,50 +141,7 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
llvm::cl::desc(
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
}; // namespace cmdline
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
std::unique_ptr<mlir::zamalang::KeySet>
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
const std::string &jitFuncName) {
std::unique_ptr<mlir::zamalang::KeySet> keySet;
mlir::zamalang::log_verbose()
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
<< ", p:" << fheContext.constraint.p << "}\n";
mlir::zamalang::log_verbose()
<< "### FHE parameters for the atomic pattern: {k: "
<< fheContext.parameter.k
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
<< ", nSmall: " << fheContext.parameter.nSmall
<< ", brLevel: " << fheContext.parameter.brLevel
<< ", brLogBase: " << fheContext.parameter.brLogBase
<< ", ksLevel: " << fheContext.parameter.ksLevel
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
fheContext, jitFuncName, module);
if (auto err = clientParameter.takeError()) {
mlir::zamalang::log_error()
<< "cannot generate client parameters: " << err << "\n";
return nullptr;
}
mlir::zamalang::log_verbose() << "### Generate the key set\n";
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
0); // TODO: seed
if (auto err = maybeKeySet.takeError()) {
llvm::errs() << err;
return nullptr;
}
return std::move(maybeKeySet.get());
}
} // namespace cmdline
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
@@ -209,65 +149,48 @@ llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
llvm::Optional<size_t> overrideMaxMANP) {
if (!autoFHEConstraints.hasValue() &&
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
return llvm::make_error<llvm::StringError>(
return mlir::zamalang::StreamStringError(
"Maximum encrypted integer precision and maximum for the Minimal"
"Arithmetic Noise Passing are required, but were neither specified"
"explicitly nor determined automatically",
llvm::inconvertibleErrorCode());
"explicitly nor determined automatically");
}
mlir::zamalang::V0FHEConstraint fheConstraints{
.norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
: autoFHEConstraints.getValue().norm2,
.p = overrideMaxEintPrecision.hasValue()
? overrideMaxEintPrecision.getValue()
: autoFHEConstraints.getValue().p};
overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
: autoFHEConstraints.getValue().norm2,
overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue()
: autoFHEConstraints.getValue().p};
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
if (!parameter) {
std::string buffer;
llvm::raw_string_ostream strs(buffer);
strs << "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
return llvm::make_error<llvm::StringError>(strs.str(),
llvm::inconvertibleErrorCode());
return mlir::zamalang::StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
}
return mlir::zamalang::V0FHEContext{fheConstraints, *parameter};
}
mlir::LogicalResult buildAssignFHEContext(
llvm::Optional<mlir::zamalang::V0FHEContext> &fheContext,
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP) {
namespace llvm {
// This needs to be wrapped into the llvm namespace for proper
// operator lookup
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const llvm::ArrayRef<uint64_t> arr) {
os << "(";
for (size_t i = 0; i < arr.size(); i++) {
os << arr[i];
if (fheContext.hasValue())
return mlir::success();
llvm::Expected<mlir::zamalang::V0FHEContext> fheContextOrErr =
buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision,
overrideMaxMANP);
if (auto err = fheContextOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
if (i != arr.size() - 1)
os << ", ";
}
fheContext.emplace(fheContextOrErr.get());
return mlir::success();
return os;
}
} // namespace llvm
// Process a single source buffer
//
// The parameter `entryDialect` must specify the FHE dialect to which
// belong all FHE operations used in the source buffer. The input
// program must only contain FHE operations from that single FHE
// dialect, otherwise processing might fail.
//
// The parameter `action` specifies how the buffer should be processed
// and thus defines the output.
//
@@ -276,15 +199,14 @@ mlir::LogicalResult buildAssignFHEContext(
// using the parameters given in `jitArgs`.
//
// The parameter `parametrizeMidLFHE` defines, whether the
// parametrization pass for MidLFHE is executed. If the pair of
// `entryDialect` and `action` does not involve any MidlFHE
// manipulation, this parameter does not have any effect.
// parametrization pass for MidLFHE is executed. If the `action` does
// not involve any MidlFHE manipulation, this parameter does not have
// any effect.
//
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
// set, override the values for the maximum required precision of
// encrypted integers and the maximum value for the Minimum Arithmetic
// Noise Padding otherwise determined automatically if the entry
// dialect is HLFHE..
// Noise Padding otherwise determined automatically.
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
@@ -292,164 +214,103 @@ mlir::LogicalResult buildAssignFHEContext(
// the procedure checks if the parsed module is valid and if all
// requested transformations succeeded.
//
// If `verbose` is true, debug messages are displayed throughout the
// compilation process.
//
// Compilation output is written to the stream specified by `os`.
mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
enum EntryDialect entryDialect, enum Action action,
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
bool parametrizeMidlHFE, llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
bool verbose, llvm::raw_ostream &os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::LogicalResult
processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
enum Action action, const std::string &jitFuncName,
llvm::ArrayRef<uint64_t> jitArgs, bool parametrizeMidlHFE,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP,
bool verifyDiagnostics, llvm::raw_ostream &os) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
mlir::zamalang::JitCompilerEngine ce{ccx};
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setParametrizeMidLFHE(parametrizeMidlHFE);
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
if (overrideMaxEintPrecision.hasValue())
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
if (verbose)
context.disableMultithreading();
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
if (verifyDiagnostics)
return sourceMgrHandler.verify();
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName);
if (!moduleRef)
return mlir::failure();
mlir::ModuleOp module = moduleRef.get();
if (action == Action::ROUND_TRIP) {
module->print(os);
return mlir::success();
}
// Lowering pipeline. Each stage is represented as a label in the
// switch statement, from the most abstract dialect to the lowest
// level. Every labels acts as an entry point into the pipeline with
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (action == Action::DUMP_HLFHE_MANP) {
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
module.print(os);
return mlir::success();
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context,
module);
if (auto err = fheConstraintsOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
} else {
fheConstraints = fheConstraintsOrErr.get();
}
}
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::MIDLFHE:
if (action == Action::DUMP_MIDLFHE) {
module.print(os);
return mlir::success();
}
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext.getValue(), parametrizeMidlHFE)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LOWLFHE:
if (action == Action::DUMP_LOWLFHE) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
return mlir::failure();
// fallthrough
case EntryDialect::STD:
if (action == Action::DUMP_STD) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
keySet = generateKeySet(module, fheContext.getValue(), jitFuncName);
}
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
verbose)
.failed())
return mlir::failure();
// fallthrough
case EntryDialect::LLVM: {
if (action == Action::DUMP_LLVM_DIALECT) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
defaultOptPipeline, os);
}
llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
module);
if (!llvmModule) {
if (!lambdaOrErr) {
mlir::zamalang::log_error()
<< "Failed to translate LLVM dialect to LLVM IR\n";
<< "Failed to JIT-compile " << jitFuncName << ": "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
return mlir::failure();
}
if (action == Action::DUMP_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
.failed()) {
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
if (!resOrErr) {
mlir::zamalang::log_error()
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
return mlir::failure();
}
if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
llvmModule->dump();
return mlir::success();
os << *resOrErr << "\n";
} else {
enum mlir::zamalang::CompilerEngine::Target target;
switch (action) {
case Action::ROUND_TRIP:
target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP;
break;
case Action::DUMP_HLFHE:
target = mlir::zamalang::CompilerEngine::Target::HLFHE;
break;
case Action::DUMP_HLFHE_MANP:
target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP;
break;
case Action::DUMP_MIDLFHE:
target = mlir::zamalang::CompilerEngine::Target::MIDLFHE;
break;
case Action::DUMP_LOWLFHE:
target = mlir::zamalang::CompilerEngine::Target::LOWLFHE;
break;
case Action::DUMP_STD:
target = mlir::zamalang::CompilerEngine::Target::STD;
break;
case Action::DUMP_LLVM_DIALECT:
target = mlir::zamalang::CompilerEngine::Target::LLVM;
break;
case Action::DUMP_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::LLVM_IR;
break;
case Action::DUMP_OPTIMIZED_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
break;
case JIT_INVOKE:
// Case just here to satisfy the compiler; already handled above
break;
}
break;
}
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
ce.compile(std::move(buffer), target);
if (!retOrErr) {
mlir::zamalang::log_error()
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
return mlir::failure();
}
if (verifyDiagnostics) {
return mlir::success();
} else if (action == Action::DUMP_LLVM_IR ||
action == Action::DUMP_OPTIMIZED_LLVM_IR) {
retOrErr->llvmModule->print(os, nullptr);
} else {
retOrErr->mlirModuleRef->get().print(os);
}
}
return mlir::success();
@@ -459,44 +320,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// Parse command line arguments
llvm::cl::ParseCommandLineOptions(argc, argv);
// Initialize the MLIR context
mlir::MLIRContext context;
mlir::zamalang::setupLogging(cmdline::verbose);
// String for error messages from library functions
std::string errorMessage;
if (cmdline::action == Action::DUMP_HLFHE_MANP &&
cmdline::entryDialect != EntryDialect::HLFHE) {
mlir::zamalang::log_error()
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
return mlir::failure();
}
if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&
cmdline::entryDialect != EntryDialect::LOWLFHE &&
cmdline::entryDialect != EntryDialect::STD) {
mlir::zamalang::log_error()
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
return mlir::failure();
}
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);
auto output = mlir::openOutputFile(cmdline::output, &errorMessage);
if (!output) {
@@ -523,20 +351,20 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
std::move(inputBuffer), cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, os);
cmdline::verifyDiagnostics, os);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
std::move(file), cmdline::action, cmdline::jitFuncName,
cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
cmdline::verifyDiagnostics, output->os());
}
}