mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Add HLFHE pass selecting maximum MANP and encrypted integer width
This pass calculates the squared Minimal Arithmetic Noise Padding (MANP) for each operation using the MANP pass and extracts the maximum (non-squared) Minimal Arithmetic Noise Padding and the maximum ecrypted integer width from.
This commit is contained in:
committed by
Quentin Bourgerie
parent
6a6fae96f6
commit
1200a46e49
@@ -1,11 +1,15 @@
|
||||
#ifndef ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
#define ZAMALANG_DIALECT_HLFHE_ANALYSIS_MANP_H
|
||||
|
||||
#include <functional>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug = false);
|
||||
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createMaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> setMax);
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -92,4 +92,14 @@ def MANP : FunctionPass<"MANP"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def MaxMANP : FunctionPass<"MaxMANP"> {
|
||||
let summary = "Extract maximum HLFHE Minimal Arithmetic Noise Padding and maximum encrypted integer width";
|
||||
let description = [{
|
||||
This pass calculates the squared Minimal Arithmetic Noise Padding
|
||||
(MANP) for each operation using the MANP pass and extracts the
|
||||
maximum (non-squared) Minimal Arithmetic Noise Padding and the
|
||||
maximum ecrypted integer width from.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
|
||||
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
|
||||
@@ -481,5 +482,63 @@ protected:
|
||||
std::unique_ptr<mlir::Pass> createMANPPass(bool debug) {
|
||||
return std::make_unique<MANPPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// For documentation see MANP.td
|
||||
struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp func = getFunction();
|
||||
|
||||
func.walk(
|
||||
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
|
||||
}
|
||||
MaxMANPPass() = delete;
|
||||
MaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> updateMax)
|
||||
: maxMANP(llvm::APInt{1, 0, false}), maxEintWidth(0),
|
||||
updateMax(updateMax){};
|
||||
|
||||
protected:
|
||||
void processOperation(mlir::Operation *op) {
|
||||
for (mlir::OpResult res : op->getResults()) {
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType eTy =
|
||||
res.getType()
|
||||
.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
|
||||
if (eTy) {
|
||||
bool upd = false;
|
||||
if (this->maxEintWidth < eTy.getWidth()) {
|
||||
this->maxEintWidth = eTy.getWidth();
|
||||
upd = true;
|
||||
}
|
||||
|
||||
mlir::IntegerAttr MANP = op->getAttrOfType<mlir::IntegerAttr>("MANP");
|
||||
|
||||
if (!MANP) {
|
||||
op->emitError("Maximum Arithmetic Noise Padding value not set");
|
||||
this->signalPassFailure();
|
||||
}
|
||||
|
||||
if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) {
|
||||
this->maxMANP = MANP.getValue();
|
||||
upd = true;
|
||||
}
|
||||
|
||||
if (upd)
|
||||
this->updateMax(this->maxMANP, this->maxEintWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void(const llvm::APInt &, unsigned)> updateMax;
|
||||
llvm::APInt maxMANP;
|
||||
unsigned int maxEintWidth;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMaxMANPPass(
|
||||
std::function<void(const llvm::APInt &, unsigned)> updateMax) {
|
||||
return std::make_unique<MaxMANPPass>(updateMax);
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user