diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h index a88710bae..c61511e8d 100644 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h @@ -16,7 +16,7 @@ unsigned int getEintPrecision(mlir::Value value); std::unique_ptr createMANPPass(bool debug = false); std::unique_ptr -createMaxMANPPass(std::function setMax); +createMaxMANPPass(std::function setMax); } // namespace concretelang } // namespace mlir diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 978760b71..a7738b555 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -136,6 +136,7 @@ struct FunctionToDag { addDot(dag, val, encrypted_inputs, weightsOpt.getValue()); return; } + // If can't find weights return default leveled op DEBUG("Replace Dot by LevelledOp on " << op); } // default @@ -229,13 +230,16 @@ struct FunctionToDag { return value.isa(); } - std::vector + llvm::Optional> resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { std::vector values; mlir::DenseIntElementsAttr denseVals = cstOp->getAttrOfType("value"); for (llvm::APInt val : denseVals.getValues()) { + if (val.getActiveBits() > 64) { + return llvm::None; + } values.push_back(val.getSExtValue()); } return values; diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index ea3f1d634..0a5209c77 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -168,12 +168,8 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { /// Calculates the square of the value of `i`. static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) { - // Make sure the required number of bits can be represented by the - // `unsigned` argument of `zext`. - assert(i.getActiveBits() < 32 && - "Square of the constant cannot be represented on 64 bits"); - return llvm::APInt(2 * i.getActiveBits(), - i.getZExtValue() * i.getZExtValue()); + llvm::APInt extI(2 * i.getActiveBits(), i.getSExtValue()); + return extI * extI; } /// Calculates the square root of `i` and rounds it to the next highest @@ -1394,14 +1390,11 @@ struct MaxMANPPass : public MaxMANPBase { [&](mlir::Operation *childOp) { this->processOperation(childOp); }); } MaxMANPPass() = delete; - MaxMANPPass(std::function updateMax) - : updateMax(updateMax), maxMANP(llvm::APInt{1, 0, false}), - maxEintWidth(0){}; + MaxMANPPass(std::function updateMax) + : updateMax(updateMax){}; protected: void processOperation(mlir::Operation *op) { - static const llvm::APInt one{1, 1, false}; - bool upd = false; // Process all function arguments and use the default value of 1 // for MANP and the declarend precision @@ -1410,15 +1403,7 @@ protected: for (mlir::BlockArgument blockArg : func.getBody().getArguments()) { if (isEncryptedFunctionParameter(blockArg)) { unsigned int width = fhe::utils::getEintPrecision(blockArg); - - if (this->maxEintWidth < width) { - this->maxEintWidth = width; - } - - if (APIntWidthExtendULT(this->maxMANP, one)) { - this->maxMANP = one; - upd = true; - } + this->updateMax(1, width); } } } @@ -1439,38 +1424,31 @@ protected: } if (eTy) { - if (this->maxEintWidth < eTy.getWidth()) { - this->maxEintWidth = eTy.getWidth(); - upd = true; - } - mlir::IntegerAttr MANP = op->getAttrOfType("MANP"); if (!MANP) { - op->emitError("Maximum Arithmetic Noise Padding value not set"); + op->emitError("2-Norm has not been computed"); this->signalPassFailure(); return; } - if (APIntWidthExtendULT(this->maxMANP, MANP.getValue())) { - this->maxMANP = MANP.getValue(); - upd = true; + auto manp = MANP.getValue(); + if (!manp.isIntN(64)) { + op->emitError("2-Norm cannot be reprensented on 64bits"); + this->signalPassFailure(); + return; } + this->updateMax(manp.getSExtValue(), eTy.getWidth()); } } - - if (upd) - this->updateMax(this->maxMANP, this->maxEintWidth); } - std::function updateMax; - llvm::APInt maxMANP; - unsigned int maxEintWidth; + std::function updateMax; }; } // end anonymous namespace -std::unique_ptr createMaxMANPPass( - std::function updateMax) { +std::unique_ptr +createMaxMANPPass(std::function updateMax) { return std::make_unique(updateMax); } diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index da3b51675..6450cd7eb 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -94,24 +94,14 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, enablePass); addPotentiallyNestedPass( pm, - mlir::concretelang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP, - unsigned currMaxWidth) { - assert((uint64_t)currMaxWidth < std::numeric_limits::max() && - "Maximum width does not fit into size_t"); + mlir::concretelang::createMaxMANPPass( + [&](const uint64_t manp, unsigned width) { + if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp) + oMax2norm.emplace(manp); - assert(sizeof(uint64_t) >= sizeof(size_t) && - currMaxMANP.ult(std::numeric_limits::max()) && - "Maximum MANP does not fit into size_t"); - - size_t manp = (size_t)currMaxMANP.getZExtValue(); - size_t width = (size_t)currMaxWidth; - - if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp) - oMax2norm.emplace(manp); - - if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width) - oMaxWidth.emplace(width); - }), + if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width) + oMaxWidth.emplace(width); + }), enablePass); if (pm.run(module.getOperation()).failed()) { return llvm::make_error( diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index 4bbd79db6..3195f1612 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -183,6 +183,17 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> } +// ----- + +func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> +{ + %cst = arith.constant -1 : i3 + + // %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + + return %0 : !FHE.eint<2> +} // ----- diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index 4d59ff786..c2fcdccaa 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -270,7 +270,7 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) %cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3> // sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 56 : ui{{[[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 28 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2>