fix: Do not assert fail with too large weight and fix computation of 2-Norm with negative weigth (close #892)

This commit is contained in:
Quentin Bourgerie
2023-01-19 13:31:38 +01:00
parent 227a0747fa
commit 49b8bf484c
6 changed files with 40 additions and 57 deletions

View File

@@ -136,6 +136,7 @@ struct FunctionToDag {
addDot(dag, val, encrypted_inputs, weightsOpt.getValue());
return;
}
// If can't find weights return default leveled op
DEBUG("Replace Dot by LevelledOp on " << op);
}
// default
@@ -229,13 +230,16 @@ struct FunctionToDag {
return value.isa<mlir::BlockArgument>();
}
std::vector<std::int64_t>
llvm::Optional<std::vector<std::int64_t>>
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
std::vector<std::int64_t> values;
mlir::DenseIntElementsAttr denseVals =
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
if (val.getActiveBits() > 64) {
return llvm::None;
}
values.push_back(val.getSExtValue());
}
return values;

View File

@@ -168,12 +168,8 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
/// Calculates the square of the value of `i`.
static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
// Make sure the required number of bits can be represented by the
// `unsigned` argument of `zext`.
assert(i.getActiveBits() < 32 &&
"Square of the constant cannot be represented on 64 bits");
return llvm::APInt(2 * i.getActiveBits(),
i.getZExtValue() * i.getZExtValue());
llvm::APInt extI(2 * i.getActiveBits(), i.getSExtValue());
return extI * extI;
}
/// Calculates the square root of `i` and rounds it to the next highest
@@ -1394,14 +1390,11 @@ struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
[&](mlir::Operation *childOp) { this->processOperation(childOp); });
}
MaxMANPPass() = delete;
MaxMANPPass(std::function<void(const llvm::APInt &, unsigned)> updateMax)
: updateMax(updateMax), maxMANP(llvm::APInt{1, 0, false}),
maxEintWidth(0){};
MaxMANPPass(std::function<void(const uint64_t, unsigned)> updateMax)
: updateMax(updateMax){};
protected:
void processOperation(mlir::Operation *op) {
static const llvm::APInt one{1, 1, false};
bool upd = false;
// Process all function arguments and use the default value of 1
// for MANP and the declarend precision
@@ -1410,15 +1403,7 @@ protected:
for (mlir::BlockArgument blockArg : func.getBody().getArguments()) {
if (isEncryptedFunctionParameter(blockArg)) {
unsigned int width = fhe::utils::getEintPrecision(blockArg);
if (this->maxEintWidth < width) {
this->maxEintWidth = width;
}
if (APIntWidthExtendULT(this->maxMANP, one)) {
this->maxMANP = one;
upd = true;
}
this->updateMax(1, width);
}
}
}
@@ -1439,38 +1424,31 @@ protected:
}
if (eTy) {
if (this->maxEintWidth < eTy.getWidth()) {
this->maxEintWidth = eTy.getWidth();
upd = true;
}
mlir::IntegerAttr MANP = op->getAttrOfType<mlir::IntegerAttr>("MANP");
if (!MANP) {
op->emitError("Maximum Arithmetic Noise Padding value not set");
op->emitError("2-Norm has not been computed");
this->signalPassFailure();
return;
}
if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) {
this->maxMANP = MANP.getValue();
upd = true;
auto manp = MANP.getValue();
if (!manp.isIntN(64)) {
op->emitError("2-Norm cannot be reprensented on 64bits");
this->signalPassFailure();
return;
}
this->updateMax(manp.getSExtValue(), eTy.getWidth());
}
}
if (upd)
this->updateMax(this->maxMANP, this->maxEintWidth);
}
std::function<void(const llvm::APInt &, unsigned)> updateMax;
llvm::APInt maxMANP;
unsigned int maxEintWidth;
std::function<void(const uint64_t, unsigned)> updateMax;
};
} // end anonymous namespace
std::unique_ptr<mlir::Pass> createMaxMANPPass(
std::function<void(const llvm::APInt &, unsigned)> updateMax) {
std::unique_ptr<mlir::Pass>
createMaxMANPPass(std::function<void(const uint64_t, unsigned)> updateMax) {
return std::make_unique<MaxMANPPass>(updateMax);
}