feat(compiler): Add HLFHE pass selecting maximum MANP and encrypted integer width

This pass calculates the squared Minimal Arithmetic Noise Padding
(MANP) for each operation using the MANP pass and extracts the maximum
(non-squared) Minimal Arithmetic Noise Padding and the maximum
ecrypted integer width from.
This commit is contained in:
Andi Drebes
2021-09-10 16:38:27 +02:00
committed by Quentin Bourgerie
parent 6a6fae96f6
commit 1200a46e49
3 changed files with 73 additions and 0 deletions

View File

@@ -1,11 +1,15 @@
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
#include <functional>
#include <mlir/Pass/Pass.h>
namespace mlir {
namespace zamalang {
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createMaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> setMax);
} // namespace zamalang
} // namespace mlir

View File

@@ -92,4 +92,14 @@ def MANP : FunctionPass<"MANP"> {
}];
}
def MaxMANP : FunctionPass<"MaxMANP"> {
let summary = "Extract maximum HLFHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
let description = [{
This pass calculates the squared Minimal Arithmetic Noise Padding
(MANP) for each operation using the MANP pass and extracts the
maximum (non-squared) Minimal Arithmetic Noise Padding and the
maximum ecrypted integer width from.
}];
}
#endif

View File

@@ -1,3 +1,4 @@
#include <mlir/IR/BuiltinOps.h>
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
@@ -481,5 +482,63 @@ protected:
std::unique_ptr<mlir::Pass> createMANPPass(bool debug) {
return std::make_unique<MANPPass>(debug);
}
namespace {
// For documentation see MANP.td
struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
void runOnFunction() override {
mlir::FuncOp func = getFunction();
func.walk(
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
}
MaxMANPPass() = delete;
MaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> updateMax)
: maxMANP(llvm::APInt{1, 0, false}), maxEintWidth(0),
updateMax(updateMax){};
protected:
void processOperation(mlir::Operation *op) {
for (mlir::OpResult res : op->getResults()) {
mlir::zamalang::HLFHE::EncryptedIntegerType eTy =
res.getType()
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (eTy) {
bool upd = false;
if (this->maxEintWidth < eTy.getWidth()) {
this->maxEintWidth = eTy.getWidth();
upd = true;
}
mlir::IntegerAttr MANP = op->getAttrOfType<mlir::IntegerAttr>("MANP");
if (!MANP) {
op->emitError("Maximum Arithmetic Noise Padding value not set");
this->signalPassFailure();
}
if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) {
this->maxMANP = MANP.getValue();
upd = true;
}
if (upd)
this->updateMax(this->maxMANP, this->maxEintWidth);
}
}
}
std::function<void(const llvm::APInt &, unsigned)> updateMax;
llvm::APInt maxMANP;
unsigned int maxEintWidth;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createMaxMANPPass(
std::function<void(const llvm::APInt &, unsigned)> updateMax) {
return std::make_unique<MaxMANPPass>(updateMax);
}
} // namespace zamalang
} // namespace mlir