refactor(compiler): Simplify the compiler flow and re enable --passes compiler option

No more need to compute the fhe context at high level
This commit is contained in:
Quentin Bourgerie
2021-10-22 15:08:26 +02:00
committed by Andi Drebes
parent 41cba63113
commit 85d102c9b2
25 changed files with 200 additions and 332 deletions

View File

@@ -6,6 +6,7 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
#include <zamalang/Support/ClientParameters.h>
@@ -57,10 +58,6 @@ public:
// Read sources and exit before any lowering
HLFHE,
// Read sources and attempt to run the Minimal Arithmetic Noise
// Padding pass
HLFHE_MANP,
// Read sources and lower all HLFHE operations to MidLFHE
// operations
MIDLFHE,
@@ -91,7 +88,8 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(),
clientParametersFuncName(), verifyDiagnostics(false),
generateClientParameters(false), parametrizeMidLFHE(true),
generateClientParameters(false),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
@@ -106,8 +104,8 @@ public:
void setMaxMANP(size_t v);
void setVerifyDiagnostics(bool v);
void setGenerateClientParameters(bool v);
void setParametrizeMidLFHE(bool v);
void setClientParametersFuncName(const llvm::StringRef &name);
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
protected:
llvm::Optional<size_t> overrideMaxEintPrecision;
@@ -115,18 +113,14 @@ protected:
llvm::Optional<std::string> clientParametersFuncName;
bool verifyDiagnostics;
bool generateClientParameters;
bool parametrizeMidLFHE;
std::function<bool(mlir::Pass *)> enablePass;
std::shared_ptr<CompilationContext> compilationContext;
// Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`)
// or indicating that no FHE dialect is used (`NONE`).
enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE };
static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module);
private:
llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride);
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getV0FHEConstraint(CompilationResult &res);
llvm::Error determineFHEParameters(CompilationResult &res);
};
} // namespace zamalang

View File

@@ -4,43 +4,43 @@
#include <llvm/IR/Module.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <zamalang/Support/V0Parameters.h>
namespace mlir {
namespace zamalang {
namespace pipeline {
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool debug);
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module);
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);
mlir::LogicalResult
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext,
bool parametrize);
mlir::LogicalResult
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module);
mlir::LogicalResult
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext, bool verbose);
std::unique_ptr<llvm::Module>
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
llvm::LLVMContext &llvmContext,
mlir::ModuleOp &module);
} // namespace pipeline
} // namespace zamalang
} // namespace mlir

View File

@@ -32,6 +32,7 @@ private:
StreamWrap<llvm::raw_ostream> &log_error(void);
StreamWrap<llvm::raw_ostream> &log_verbose(void);
void setupLogging(bool verbose);
bool isVerbose();
} // namespace zamalang
} // namespace mlir

View File

@@ -82,180 +82,61 @@ void CompilerEngine::setMaxEintPrecision(size_t v) {
this->overrideMaxEintPrecision = v;
}
void CompilerEngine::setParametrizeMidLFHE(bool v) {
this->parametrizeMidLFHE = v;
}
void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; }
void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
this->clientParametersFuncName = name.str();
}
// Helper function detecting the FHE dialect with the highest level of
// abstraction used in `module`. If no FHE dialect is used, the
// function returns `CompilerEngine::FHEDialect::NONE`.
CompilerEngine::FHEDialect
CompilerEngine::detectHighestFHEDialect(mlir::ModuleOp module) {
CompilerEngine::FHEDialect highestDialect = CompilerEngine::FHEDialect::NONE;
mlir::TypeID hlfheID =
mlir::TypeID::get<mlir::zamalang::HLFHE::HLFHEDialect>();
mlir::TypeID midlfheID =
mlir::TypeID::get<mlir::zamalang::MidLFHE::MidLFHEDialect>();
mlir::TypeID lowlfheID =
mlir::TypeID::get<mlir::zamalang::LowLFHE::LowLFHEDialect>();
// Helper lambda updating the currently highest dialect if necessary
// by dialect type ID
auto updateDialectFromDialectID = [&](mlir::TypeID dialectID) {
if (dialectID == hlfheID) {
highestDialect = CompilerEngine::FHEDialect::HLFHE;
return true;
} else if (dialectID == lowlfheID &&
highestDialect == CompilerEngine::FHEDialect::NONE) {
highestDialect = CompilerEngine::FHEDialect::LOWLFHE;
} else if (dialectID == midlfheID &&
(highestDialect == CompilerEngine::FHEDialect::NONE ||
highestDialect == CompilerEngine::FHEDialect::LOWLFHE)) {
highestDialect = CompilerEngine::FHEDialect::MIDLFHE;
}
return false;
};
// Helper lambda updating the currently highest dialect if necessary
// by value type
std::function<bool(mlir::Type)> updateDialectFromType =
[&](mlir::Type ty) -> bool {
if (updateDialectFromDialectID(ty.getDialect().getTypeID()))
return true;
if (mlir::TensorType tensorTy = ty.dyn_cast_or_null<mlir::TensorType>())
return updateDialectFromType(tensorTy.getElementType());
return false;
};
module.walk([&](mlir::Operation *op) {
// Check operation itself
if (updateDialectFromDialectID(op->getDialect()->getTypeID()))
return mlir::WalkResult::interrupt();
// Check types of operands
for (mlir::Value operand : op->getOperands()) {
if (updateDialectFromType(operand.getType()))
return mlir::WalkResult::interrupt();
}
// Check types of results
for (mlir::Value res : op->getResults()) {
if (updateDialectFromType(res.getType())) {
return mlir::WalkResult::interrupt();
}
}
return mlir::WalkResult::advance();
});
return highestDialect;
void CompilerEngine::setEnablePass(
std::function<bool(mlir::Pass *)> enablePass) {
this->enablePass = enablePass;
}
// Sets the FHE parameters of `res` either through autodetection or
// fixed constraints provided in
// `CompilerEngine::overrideMaxEintPrecision` and
// `CompilerEngine::overrideMaxMANP`.
//
// Autodetected values can be partially or fully overridden through
// `CompilerEngine::overrideMaxEintPrecision` and
// `CompilerEngine::overrideMaxMANP`.
//
// If `noOverrideAutodetected` is true, autodetected values are not
// overriden and used directly for `res`.
//
// Return an error if autodetection fails.
llvm::Error
CompilerEngine::determineFHEParameters(CompilationResult &res,
bool noOverrideAutodetected) {
// Returns the overwritten V0FHEConstraint or try to compute them from HLFHE
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
CompilerEngine::getV0FHEConstraint(CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
// Determine FHE constraints either through autodetection or through
// overridden values
// If the values has been overwritten returns
if (this->overrideMaxEintPrecision.hasValue() &&
this->overrideMaxMANP.hasValue() && !noOverrideAutodetected) {
fheConstraints.emplace(mlir::zamalang::V0FHEConstraint{
this->overrideMaxMANP.hasValue()) {
return mlir::zamalang::V0FHEConstraint{
this->overrideMaxMANP.getValue(),
this->overrideMaxEintPrecision.getValue()});
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(mlirContext,
module);
if (auto err = fheConstraintsOrErr.takeError())
return std::move(err);
if (!fheConstraintsOrErr.get().hasValue()) {
return StreamStringError("Could not determine maximum required precision "
"for encrypted integers and maximum value for "
"the Minimal Arithmetic Noise Padding");
}
if (noOverrideAutodetected)
return llvm::Error::success();
fheConstraints = fheConstraintsOrErr.get();
// Override individual values if requested
if (this->overrideMaxEintPrecision.hasValue())
fheConstraints->p = this->overrideMaxEintPrecision.getValue();
if (this->overrideMaxMANP.hasValue())
fheConstraints->norm2 = this->overrideMaxMANP.getValue();
this->overrideMaxEintPrecision.getValue()};
}
// Else compute constraint from HLFHE
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(
mlirContext, module, enablePass);
if (auto err = fheConstraintsOrErr.takeError())
return std::move(err);
return fheConstraintsOrErr.get();
}
// set the fheContext field if the v0Constraint can be computed
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
auto fheConstraintOrErr = getV0FHEConstraint(res);
if (auto err = fheConstraintOrErr.takeError())
return std::move(err);
if (!fheConstraintOrErr.get().hasValue()) {
return llvm::Error::success();
}
const mlir::zamalang::V0Parameter *fheParams =
getV0Parameter(fheConstraints.getValue());
getV0Parameter(fheConstraintOrErr.get().getValue());
if (!fheParams) {
return StreamStringError()
<< "Could not determine V0 parameters for 2-norm of "
<< fheConstraints->norm2 << " and p of " << fheConstraints->p;
}
res.fheContext.emplace(
mlir::zamalang::V0FHEContext{*fheConstraints, *fheParams});
return llvm::Error::success();
}
// Performs all lowering from HLFHE to the FHE dialect with the lwoest
// level of abstraction that requires FHE parameters.
//
// Returns an error if any of the lowerings fails.
llvm::Error CompilerEngine::lowerParamDependentHalf(Target target,
CompilationResult &res) {
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
mlir::ModuleOp module = res.mlirModuleRef->get();
// HLFHE -> MidLFHE
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, false)
.failed()) {
return StreamStringError("Lowering from HLFHE to MidLFHE failed");
}
if (target == Target::MIDLFHE)
return llvm::Error::success();
// MidLFHE -> LowLFHE
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
mlirContext, module, *res.fheContext, this->parametrizeMidLFHE)
.failed()) {
return StreamStringError("Lowering from MidLFHE to LowLFHE failed");
<< (*fheConstraintOrErr)->norm2 << " and p of "
<< (*fheConstraintOrErr)->p;
}
res.fheContext.emplace(mlir::zamalang::V0FHEContext{
(*fheConstraintOrErr).getValue(), *fheParams});
return llvm::Error::success();
}
@@ -289,43 +170,40 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
res.mlirModuleRef = std::move(mlirModuleRef);
mlir::ModuleOp module = res.mlirModuleRef->get();
if (target == Target::HLFHE || target == Target::ROUND_TRIP)
if (target == Target::ROUND_TRIP)
return res;
// Detect highest FHE dialect and check if FHE parameter
// autodetection / lowering of parameter-dependent dialects can be
// skipped
FHEDialect highestFHEDialect = this->detectHighestFHEDialect(module);
if (highestFHEDialect == FHEDialect::HLFHE ||
highestFHEDialect == FHEDialect::MIDLFHE ||
this->generateClientParameters) {
bool noOverrideAutoDetected = (target == Target::HLFHE_MANP);
if (auto err = this->determineFHEParameters(res, noOverrideAutoDetected))
return std::move(err);
}
// return early if only the MANP pass was requested
if (target == Target::HLFHE_MANP)
// HLFHE High level pass to determine FHE parameters
if (auto err = this->determineFHEParameters(res))
return std::move(err);
if (target == Target::HLFHE)
return res;
if (highestFHEDialect == FHEDialect::HLFHE ||
highestFHEDialect == FHEDialect::MIDLFHE) {
if (llvm::Error err = this->lowerParamDependentHalf(target, res))
return std::move(err);
// HLFHE -> MidLFHE
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module,
enablePass)
.failed()) {
return StreamStringError("Lowering from HLFHE to MidLFHE failed");
}
if (target == Target::MIDLFHE)
return res;
if (target == Target::HLFHE_MANP || target == Target::MIDLFHE ||
target == Target::LOWLFHE)
// MidLFHE -> LowLFHE
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
mlirContext, module, res.fheContext, this->enablePass)
.failed()) {
return StreamStringError("Lowering from MidLFHE to LowLFHE failed");
}
if (target == Target::LOWLFHE)
return res;
// LowLFHE -> Canonical dialects
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module)
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module,
enablePass)
.failed()) {
return StreamStringError(
"Lowering from LowLFHE to canonical MLIR dialects failed");
}
if (target == Target::STD)
return res;
@@ -336,6 +214,10 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
"Generation of client parameters requested, but no function name "
"specified");
}
if (!res.fheContext.hasValue()) {
return StreamStringError(
"Cannot generate client parameters, the fhe context is empty");
}
llvm::Expected<mlir::zamalang::ClientParameters> clientParametersOrErr =
mlir::zamalang::createClientParametersForV0(
@@ -349,7 +231,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
// MLIR canonical dialects -> LLVM Dialect
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
false)
enablePass)
.failed()) {
return StreamStringError("Failed to lower to LLVM dialect");
}

View File

@@ -88,7 +88,7 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(mlirContext);
std::function<llvm::Error(llvm::Module *)> optPipeline =
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =

View File

@@ -19,8 +19,30 @@
namespace mlir {
namespace zamalang {
namespace pipeline {
static void addPotentiallyNestedPass(mlir::PassManager &pm,
std::unique_ptr<Pass> pass) {
static void pipelinePrinting(llvm::StringRef name, mlir::PassManager &pm,
mlir::MLIRContext &ctx) {
if (mlir::zamalang::isVerbose()) {
mlir::zamalang::log_verbose()
<< "##################################################\n"
<< "### " << name << " pipeline\n";
auto isModule = [](mlir::Pass *, mlir::Operation *op) {
return mlir::isa<mlir::ModuleOp>(op);
};
ctx.disableMultithreading(true);
pm.enableIRPrinting(isModule, isModule);
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
}
static void
addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr<Pass> pass,
std::function<bool(mlir::Pass *)> enablePass) {
if (!enablePass(pass.get())) {
return;
}
if (!pass->getOpName() || *pass->getOpName() == "builtin.module") {
pm.addPass(std::move(pass));
} else {
@@ -29,26 +51,20 @@ static void addPotentiallyNestedPass(mlir::PassManager &pm,
}
}
// Creates an instance of the Minimal Arithmetic Noise Padding pass
// and invokes it for all functions of `module`.
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool debug) {
mlir::PassManager pm(&context);
pm.addNestedPass<mlir::FuncOp>(mlir::zamalang::createMANPPass(debug));
return pm.run(module);
}
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
llvm::Optional<size_t> oMax2norm;
llvm::Optional<size_t> oMaxWidth;
mlir::PassManager pm(&context);
addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass());
pipelinePrinting("ComputeFHEConstraintOnHLFHE", pm, context);
addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
unsigned currMaxWidth) {
pm,
mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
unsigned currMaxWidth) {
assert((uint64_t)currMaxWidth < std::numeric_limits<size_t>::max() &&
"Maximum width does not fit into size_t");
@@ -64,15 +80,14 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
oMaxWidth.emplace(width);
}));
}),
enablePass);
if (pm.run(module.getOperation()).failed()) {
return llvm::make_error<llvm::StringError>(
"Failed to determine the maximum Arithmetic Noise Padding and maximum"
"required precision",
llvm::inconvertibleErrorCode());
}
llvm::Optional<mlir::zamalang::V0FHEConstraint> ret;
if (oMax2norm.hasValue() && oMaxWidth.hasValue()) {
@@ -84,86 +99,76 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
return ret;
}
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose) {
mlir::LogicalResult
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("HLFHEToMidLFHE", pm, context);
if (verbose) {
mlir::zamalang::log_verbose()
<< "##################################################\n"
<< "### HLFHE to MidLFHE pipeline\n";
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
llvm::Optional<V0FHEContext> &fheContext,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("MidLFHEToLowLFHE", pm, context);
if (fheContext.hasValue()) {
addPotentiallyNestedPass(
pm,
mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(
fheContext.getValue()),
enablePass);
}
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
addPotentiallyNestedPass(pm,
mlir::zamalang::createConvertHLFHEToMidLFHEPass());
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext,
bool parametrize) {
mlir::PassManager pm(&context);
if (parametrize) {
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(
fheContext));
}
addPotentiallyNestedPass(pm,
mlir::zamalang::createConvertMidLFHEToLowLFHEPass());
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module) {
mlir::LogicalResult
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
pipelinePrinting("LowLFHEToStd", pm, context);
pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass());
return pm.run(module.getOperation());
}
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
mlir::ModuleOp &module,
bool verbose) {
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);
if (verbose) {
mlir::zamalang::log_verbose()
<< "##################################################\n"
<< "### MlirStdsDialectToMlirLLVMDialect pipeline\n";
context.disableMultithreading();
pm.enableIRPrinting();
pm.enableStatistics();
pm.enableTiming();
pm.enableVerifier();
}
pipelinePrinting("StdToLLVM", pm, context);
// Unparametrize LowLFHE
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass());
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), enablePass);
// Bufferize
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass());
addPotentiallyNestedPass(pm, mlir::createStdBufferizePass());
addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass());
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass());
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass());
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass());
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass());
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createStdBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
enablePass);
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
enablePass);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass());
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
enablePass);
return pm.run(module);
}
@@ -181,7 +186,7 @@ lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module) {
std::function<llvm::Error(llvm::Module *)> optPipeline =
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
if (optPipeline(&module))
@@ -190,18 +195,6 @@ mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
return mlir::success();
}
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
mlir::ModuleOp &module,
V0FHEContext &fheContext, bool verbose) {
if (lowerHLFHEToMidLFHE(context, module, verbose).failed() ||
lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() ||
lowerLowLFHEToStd(context, module).failed()) {
return mlir::failure();
} else {
return mlir::success();
}
}
} // namespace pipeline
} // namespace zamalang
} // namespace mlir

View File

@@ -18,5 +18,6 @@ StreamWrap<llvm::raw_ostream> &log_verbose(void) {
// Sets up logging. If `verbose` is false, messages passed to
// `log_verbose` will be discarded.
void setupLogging(bool verbose) { ::mlir::zamalang::verbose = verbose; }
bool isVerbose() { return verbose; }
} // namespace zamalang
} // namespace mlir

View File

@@ -32,7 +32,6 @@
enum Action {
ROUND_TRIP,
DUMP_HLFHE,
DUMP_HLFHE_MANP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
DUMP_STD,
@@ -76,10 +75,10 @@ llvm::cl::opt<std::string> output("o",
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
llvm::cl::init<bool>(false));
llvm::cl::opt<bool> parametrizeMidLFHE(
"parametrize-midlfhe",
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
llvm::cl::init<bool>(true));
llvm::cl::list<std::string> passes(
"passes",
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
static llvm::cl::opt<enum Action> action(
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
@@ -87,9 +86,6 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(
clEnumValN(Action::ROUND_TRIP, "roundtrip",
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
"Dump HLFHE module after running the Minimal "
"Arithmetic Noise Padding pass")),
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe",
"Dump HLFHE module")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
@@ -218,7 +214,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
mlir::LogicalResult
processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
enum Action action, const std::string &jitFuncName,
llvm::ArrayRef<uint64_t> jitArgs, bool parametrizeMidlHFE,
llvm::ArrayRef<uint64_t> jitArgs,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP,
bool verifyDiagnostics, llvm::raw_ostream &os) {
@@ -228,7 +224,13 @@ processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
mlir::zamalang::JitCompilerEngine ce{ccx};
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setParametrizeMidLFHE(parametrizeMidlHFE);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {
return std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; });
});
}
if (overrideMaxEintPrecision.hasValue())
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
@@ -267,9 +269,6 @@ processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
case Action::DUMP_HLFHE:
target = mlir::zamalang::CompilerEngine::Target::HLFHE;
break;
case Action::DUMP_HLFHE_MANP:
target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP;
break;
case Action::DUMP_MIDLFHE:
target = mlir::zamalang::CompilerEngine::Target::MIDLFHE;
break;
@@ -353,7 +352,6 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
return processInputBuffer(
std::move(inputBuffer), cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, os);
},
@@ -362,9 +360,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
} else {
return processInputBuffer(
std::move(file), cmdline::action, cmdline::jitFuncName,
cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, output->os());
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, output->os());
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK: #map0 = affine_map<(d0) -> (d0)>
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
// CHECK-LABEL: module
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
func @single_zero() -> !HLFHE.eint<2>
{