mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): Add action dump-hlfhe-manp
The new option --acion=dump-hlfhe-manp invokes the Minimal Arithmetic Noise Padding Analysis pass based on the squared 2-norm metric from `lib/Dialect/HLFHE/Analysis/MANP.cpp` and dumps the module afterwards with an extra attribute `MANP` for each HLFHE operation.
This commit is contained in:
committed by
Quentin Bourgerie
parent
ed762942c1
commit
54661528a8
@@ -9,6 +9,8 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool debug);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
|
||||
@@ -20,6 +20,7 @@ add_mlir_library(ZamalangSupport
|
||||
HLFHEToMidLFHE
|
||||
LowLFHEUnparametrize
|
||||
MLIRLowerableDialectsToLLVM
|
||||
HLFHEDialectAnalysis
|
||||
|
||||
MLIRExecutionEngine
|
||||
${LLVM_PTHREAD_LIB}
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <zamalang/Conversion/Passes.h>
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Support/Pipeline.h>
|
||||
#include <zamalang/Support/logging.h>
|
||||
|
||||
@@ -25,6 +26,15 @@ static void addPotentiallyNestedPass(mlir::PassManager &pm,
|
||||
}
|
||||
}
|
||||
|
||||
// Creates an instance of the Minimal Arithmetic Noise Padding pass
|
||||
// and invokes it for all functions of `module`.
|
||||
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool debug) {
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::zamalang::createMANPPass(debug));
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
@@ -14,7 +14,6 @@ target_link_libraries(zamacompiler
|
||||
LowLFHEDialect
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
HLFHEDialectAnalysis
|
||||
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "zamalang/Dialect/HLFHE/Analysis/MANP.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
@@ -31,7 +30,7 @@ enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DEBUG_MANP,
|
||||
DUMP_HLFHE_MANP,
|
||||
DUMP_MIDLFHE,
|
||||
DUMP_LOWLFHE,
|
||||
DUMP_STD,
|
||||
@@ -86,9 +85,9 @@ static llvm::cl::opt<enum Action> action(
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::ROUND_TRIP, "roundtrip",
|
||||
"Parse input module and regenerate textual representation")),
|
||||
llvm::cl::values(clEnumValN(
|
||||
Action::DEBUG_MANP, "debug-manp",
|
||||
"Minimal Arithmetic Noise Padding for each HLFHE operation")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
|
||||
"Dump HLFHE module after running the Minimal "
|
||||
"Arithmetic Noise Padding pass")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
|
||||
"Lower to MidLFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
|
||||
@@ -258,28 +257,20 @@ mlir::LogicalResult processInputBuffer(
|
||||
// a fallthrough mechanism to the next stage. Actions act as exit
|
||||
// points from the pipeline.
|
||||
switch (entryDialect) {
|
||||
case EntryDialect::HLFHE: {
|
||||
bool debugMANP = (action == Action::DEBUG_MANP);
|
||||
case EntryDialect::HLFHE:
|
||||
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
|
||||
.failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
mlir::LogicalResult manpRes =
|
||||
mlir::zamalang::invokeMANPPass(module, debugMANP);
|
||||
|
||||
if (action == Action::DEBUG_MANP) {
|
||||
if (manpRes.failed()) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Could not calculate Minimal Arithmetic Noise Padding";
|
||||
|
||||
if (!verifyDiagnostics)
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return mlir::success();
|
||||
}
|
||||
if (action == Action::DUMP_HLFHE_MANP) {
|
||||
module.print(os);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
|
||||
.failed())
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// fallthrough
|
||||
case EntryDialect::MIDLFHE:
|
||||
@@ -373,6 +364,13 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
|
||||
// String for error messages from library functions
|
||||
std::string errorMessage;
|
||||
|
||||
if (cmdline::action == Action::DUMP_HLFHE_MANP &&
|
||||
cmdline::entryDialect != EntryDialect::HLFHE) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (cmdline::action == Action::JIT_INVOKE &&
|
||||
cmdline::entryDialect != EntryDialect::HLFHE &&
|
||||
cmdline::entryDialect != EntryDialect::MIDLFHE &&
|
||||
|
||||
Reference in New Issue
Block a user