diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h index 32c364948..c18334aac 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.h @@ -1,11 +1,15 @@ #ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H #define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H +#include #include namespace mlir { namespace zamalang { std::unique_ptr createMANPPass(bool debug = false); + +std::unique_ptr +createMaxMANPPass(std::function setMax); } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td index cda14a55d..fed6ea590 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td @@ -92,4 +92,14 @@ def MANP : FunctionPass<"MANP"> { }]; } +def MaxMANP : FunctionPass<"MaxMANP"> { + let summary = "Extract maximum HLFHE Minimal Arithmetic Noise Padding and maximum encrypted integer width"; + let description = [{ + This pass calculates the squared Minimal Arithmetic Noise Padding + (MANP) for each operation using the MANP pass and extracts the + maximum (non-squared) Minimal Arithmetic Noise Padding and the + maximum ecrypted integer width from. + }]; +} + #endif diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index c454a054f..4eac5be2e 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -481,5 +482,63 @@ protected: std::unique_ptr createMANPPass(bool debug) { return std::make_unique(debug); } + +namespace { +// For documentation see MANP.td +struct MaxMANPPass : public MaxMANPBase { + void runOnFunction() override { + mlir::FuncOp func = getFunction(); + + func.walk( + [&](mlir::Operation *childOp) { this->processOperation(childOp); }); + } + MaxMANPPass() = delete; + MaxMANPPass(std::function updateMax) + : maxMANP(llvm::APInt{1, 0, false}), maxEintWidth(0), + updateMax(updateMax){}; + +protected: + void processOperation(mlir::Operation *op) { + for (mlir::OpResult res : op->getResults()) { + mlir::zamalang::HLFHE::EncryptedIntegerType eTy = + res.getType() + .dyn_cast_or_null(); + + if (eTy) { + bool upd = false; + if (this->maxEintWidth < eTy.getWidth()) { + this->maxEintWidth = eTy.getWidth(); + upd = true; + } + + mlir::IntegerAttr MANP = op->getAttrOfType("MANP"); + + if (!MANP) { + op->emitError("Maximum Arithmetic Noise Padding value not set"); + this->signalPassFailure(); + } + + if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) { + this->maxMANP = MANP.getValue(); + upd = true; + } + + if (upd) + this->updateMax(this->maxMANP, this->maxEintWidth); + } + } + } + + std::function updateMax; + llvm::APInt maxMANP; + unsigned int maxEintWidth; +}; +} // end anonymous namespace + +std::unique_ptr createMaxMANPPass( + std::function updateMax) { + return std::make_unique(updateMax); +} + } // namespace zamalang } // namespace mlir