feat(compiler): Add action dump-hlfhe-manp

The new option --acion=dump-hlfhe-manp invokes the Minimal Arithmetic
Noise Padding Analysis pass based on the squared 2-norm metric from
`lib/Dialect/HLFHE/Analysis/MANP.cpp` and dumps the module afterwards
with an extra attribute `MANP` for each HLFHE operation.
This commit is contained in:
Andi Drebes
2021-09-24 17:34:59 +02:00
committed by Quentin Bourgerie
parent ed762942c1
commit 54661528a8
5 changed files with 32 additions and 22 deletions

View File

@@ -9,6 +9,8 @@
namespace mlir {
namespace zamalang {
namespace pipeline {
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool debug);
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);

View File

@@ -20,6 +20,7 @@ add_mlir_library(ZamalangSupport
HLFHEToMidLFHE
LowLFHEUnparametrize
MLIRLowerableDialectsToLLVM
HLFHEDialectAnalysis
MLIRExecutionEngine
${LLVM_PTHREAD_LIB}

View File

@@ -10,6 +10,7 @@
#include <mlir/Transforms/Passes.h>
#include <zamalang/Conversion/Passes.h>
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Support/Pipeline.h>
#include <zamalang/Support/logging.h>
@@ -25,6 +26,15 @@ static void addPotentiallyNestedPass(mlir::PassManager &pm,
}
}
// Creates an instance of the Minimal Arithmetic Noise Padding pass
// and invokes it for all functions of `module`.
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool debug) {
mlir::PassManager pm(&context);
pm.addNestedPass<mlir::FuncOp>(mlir::zamalang::createMANPPass(debug));
return pm.run(module);
}
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose) {
mlir::PassManager pm(&context);

View File

@@ -14,7 +14,6 @@ target_link_libraries(zamacompiler
LowLFHEDialect
MidLFHEDialect
HLFHEDialect
HLFHEDialectAnalysis
MLIRIR
MLIRLLVMIR

View File

@@ -15,7 +15,6 @@
#include "mlir/IR/BuiltinOps.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
#include "zamalang/Dialect/HLFHE/Analysis/MANP.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
@@ -31,7 +30,7 @@ enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
enum Action {
ROUND_TRIP,
DEBUG_MANP,
DUMP_HLFHE_MANP,
DUMP_MIDLFHE,
DUMP_LOWLFHE,
DUMP_STD,
@@ -86,9 +85,9 @@ static llvm::cl::opt<enum Action> action(
llvm::cl::values(
clEnumValN(Action::ROUND_TRIP, "roundtrip",
"Parse input module and regenerate textual representation")),
llvm::cl::values(clEnumValN(
Action::DEBUG_MANP, "debug-manp",
"Minimal Arithmetic Noise Padding for each HLFHE operation")),
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
"Dump HLFHE module after running the Minimal "
"Arithmetic Noise Padding pass")),
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
"Lower to MidLFHE and dump result")),
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
@@ -258,28 +257,20 @@ mlir::LogicalResult processInputBuffer(
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE: {
bool debugMANP = (action == Action::DEBUG_MANP);
case EntryDialect::HLFHE:
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
mlir::LogicalResult manpRes =
mlir::zamalang::invokeMANPPass(module, debugMANP);
if (action == Action::DEBUG_MANP) {
if (manpRes.failed()) {
mlir::zamalang::log_error()
<< "Could not calculate Minimal Arithmetic Noise Padding";
if (!verifyDiagnostics)
return mlir::failure();
} else {
return mlir::success();
}
if (action == Action::DUMP_HLFHE_MANP) {
module.print(os);
return mlir::success();
}
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
}
// fallthrough
case EntryDialect::MIDLFHE:
@@ -373,6 +364,13 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// String for error messages from library functions
std::string errorMessage;
if (cmdline::action == Action::DUMP_HLFHE_MANP &&
cmdline::entryDialect != EntryDialect::HLFHE) {
mlir::zamalang::log_error()
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
return mlir::failure();
}
if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&