diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.td b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.td index 5511dffa9..7a848a94e 100644 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.td +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.td @@ -46,6 +46,7 @@ def MANP : Pass<"MANP", "::mlir::func::FuncOp"> { - FHE.apply_lookup_table -> FHELinalg.dot_eint_int([LUT result], [1]) - FHE.zero() -> FHELinalg.dot_eint_int([encrypted 0], [1]) - FHE.add_eint_int(e, c) -> FHELinalg.dot_eint_int([e, 1], [1, c]) + with the encrypted 1 trivialy encrypted, i.e. without noise so 1xc is not take into account - FHE.add_eint(e0, e1) -> FHELinalg.dot_eint_int([e0, e1], [1, 1]) - FHE.sub_int_eint(c, e) -> FHELinalg.dot_eint_int([e, c], [1, -1]) - FHE.neg_eint(e) -> FHELinalg.dot_eint_int([e], [-1]) @@ -84,7 +85,7 @@ def MANP : Pass<"MANP", "::mlir::func::FuncOp"> { - FHE.zero() -> 1 - FHELinalg.dot_eint_int([e0, e1, ...], [c0, c1, ...]) -> c0*c0*sqN(e0) + c1*c1*sqN(e1) + ... - - FHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c + - FHE.add_eint_int(e, c) -> 1*1*sqN(e) = sqN(e) - FHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2) - FHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c - FHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e) @@ -96,7 +97,8 @@ def MANP : Pass<"MANP", "::mlir::func::FuncOp"> { } def MaxMANP : Pass<"MaxMANP", "::mlir::func::FuncOp"> { - let summary = "Extract maximum FHE Minimal Arithmetic Noise Padding and maximum encrypted integer width"; + let summary = "Extract maximum FHE Minimal Arithmetic Noise Padding and " + "maximum encrypted integer width"; let description = [{ This pass calculates the squared Minimal Arithmetic Noise Padding (MANP) for each operation using the MANP pass and extracts the diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 513a13a3e..6fd26d894 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -320,34 +320,14 @@ static llvm::APInt getSqMANP( static llvm::APInt getSqMANP( mlir::concretelang::FHE::AddEintIntOp op, llvm::ArrayRef *> operandMANPs) { - mlir::Type iTy = op->getOpOperand(1).get().getType(); - - assert(iTy.isSignlessInteger() && - "Only additions with signless integers are currently allowed"); - assert( operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().hasValue() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - if (cstOp) { - // For a constant operand use actual constant to calculate 2-norm - mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); - } else { - // For a dynamic operand conservatively assume that the value is - // the maximum for the integer width - sqNorm = conservativeIntNorm2Sq(iTy); - } - - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation @@ -372,10 +352,6 @@ static llvm::APInt getSqMANP( static llvm::APInt getSqMANP( mlir::concretelang::FHE::SubIntEintOp op, llvm::ArrayRef *> operandMANPs) { - mlir::Type iTy = op->getOpOperand(0).get().getType(); - - assert(iTy.isSignlessInteger() && - "Only subtractions with signless integers are currently allowed"); assert( operandMANPs.size() == 2 && @@ -383,22 +359,8 @@ static llvm::APInt getSqMANP( "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(0).get().getDefiningOp()); - - if (cstOp) { - // For constant plaintext operands simply use the constant value - mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); - } else { - // For dynamic plaintext operands conservatively assume that the integer has - // its maximum possible value - sqNorm = conservativeIntNorm2Sq(iTy); - } - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation @@ -406,10 +368,6 @@ static llvm::APInt getSqMANP( static llvm::APInt getSqMANP( mlir::concretelang::FHE::SubEintIntOp op, llvm::ArrayRef *> operandMANPs) { - mlir::Type iTy = op->getOpOperand(1).get().getType(); - - assert(iTy.isSignlessInteger() && - "Only subtractions with signless integers are currently allowed"); assert( operandMANPs.size() == 2 && @@ -417,22 +375,8 @@ static llvm::APInt getSqMANP( "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - - if (cstOp) { - // For constant plaintext operands simply use the constant value - mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); - } else { - // For dynamic plaintext operands conservatively assume that the integer has - // its maximum possible value - sqNorm = conservativeIntNorm2Sq(iTy); - } - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } /// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation @@ -509,39 +453,14 @@ static llvm::APInt getSqMANP( mlir::concretelang::FHELinalg::AddEintIntOp op, llvm::ArrayRef *> operandMANPs) { - mlir::RankedTensorType op1Ty = - op->getOpOperand(1).get().getType().cast(); - - mlir::Type iTy = op1Ty.getElementType(); - - assert(iTy.isSignlessInteger() && - "Only additions with signless integers are currently allowed"); - assert( operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().hasValue() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - if (denseVals) { - // For a constant operand use actual constant to calculate 2-norm - sqNorm = maxIntNorm2Sq(denseVals); - } else { - // For a dynamic operand conservatively assume that the value is - // the maximum for the integer width - sqNorm = conservativeIntNorm2Sq(iTy); - } - - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } static llvm::APInt getSqMANP( @@ -565,74 +484,28 @@ static llvm::APInt getSqMANP( mlir::concretelang::FHELinalg::SubIntEintOp op, llvm::ArrayRef *> operandMANPs) { - mlir::RankedTensorType op0Ty = - op->getOpOperand(0).get().getType().cast(); - - mlir::Type iTy = op0Ty.getElementType(); - - assert(iTy.isSignlessInteger() && - "Only subtractions with signless integers are currently allowed"); - assert( operandMANPs.size() == 2 && operandMANPs[1]->getValue().getMANP().hasValue() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(0).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - if (denseVals) { - sqNorm = maxIntNorm2Sq(denseVals); - } else { - // For dynamic plaintext operands conservatively assume that the integer has - // its maximum possible value - sqNorm = conservativeIntNorm2Sq(iTy); - } - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } static llvm::APInt getSqMANP( mlir::concretelang::FHELinalg::SubEintIntOp op, llvm::ArrayRef *> operandMANPs) { - mlir::RankedTensorType op1Ty = - op->getOpOperand(1).get().getType().cast(); - - mlir::Type iTy = op1Ty.getElementType(); - - assert(iTy.isSignlessInteger() && - "Only subtractions with signless integers are currently allowed"); - assert( operandMANPs.size() == 2 && operandMANPs[0]->getValue().getMANP().hasValue() && "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); - llvm::APInt sqNorm; - mlir::arith::ConstantOp cstOp = - llvm::dyn_cast_or_null( - op->getOpOperand(1).get().getDefiningOp()); - mlir::DenseIntElementsAttr denseVals = - cstOp ? cstOp->getAttrOfType("value") - : nullptr; - - if (denseVals) { - sqNorm = maxIntNorm2Sq(denseVals); - } else { - // For dynamic plaintext operands conservatively assume that the integer has - // its maximum possible value - sqNorm = conservativeIntNorm2Sq(iTy); - } - return APIntWidthExtendUAdd(sqNorm, eNorm); + return eNorm; } static llvm::APInt getSqMANP( diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index 30a14e4e2..4bbd79db6 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -24,7 +24,7 @@ func.func @single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -36,7 +36,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -46,8 +46,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -69,7 +68,7 @@ func.func @single_cst_sub_int_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%cst, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -81,7 +80,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%cst, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -91,8 +90,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -104,7 +102,7 @@ func.func @single_cst_sub_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -116,7 +114,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -126,8 +124,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -215,13 +212,13 @@ func.func @chain_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %cst2 = arith.constant 2 : i4 %cst3 = arith.constant 1 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %0 = "FHE.add_eint_int"(%e, %cst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %1 = "FHE.add_eint_int"(%0, %cst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %2 = "FHE.add_eint_int"(%1, %cst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %3 = "FHE.add_eint_int"(%2, %cst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> return %3 : !FHE.eint<3> @@ -236,13 +233,13 @@ func.func @dag_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %Acst2 = arith.constant 2 : i4 %Acst3 = arith.constant 1 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %A0 = "FHE.add_eint_int"(%e, %Acst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %A1 = "FHE.add_eint_int"(%A0, %Acst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %A2 = "FHE.add_eint_int"(%A1, %Acst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %A3 = "FHE.add_eint_int"(%A2, %Acst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> %Bcst0 = arith.constant 1 : i4 @@ -252,20 +249,20 @@ func.func @dag_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %Bcst4 = arith.constant 4 : i4 %Bcst5 = arith.constant 7 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B0 = "FHE.add_eint_int"(%e, %Bcst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B1 = "FHE.add_eint_int"(%B0, %Bcst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B2 = "FHE.add_eint_int"(%B1, %Bcst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B3 = "FHE.add_eint_int"(%B2, %Bcst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint_int"(%[[V3]], %[[op1:.*]]) {MANP = 10 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint_int"(%[[V3]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B4 = "FHE.add_eint_int"(%B3, %Bcst4) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint_int"(%[[V4]], %[[op1:.*]]) {MANP = 13 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint_int"(%[[V4]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> %B5 = "FHE.add_eint_int"(%B4, %Bcst5) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[op1:.*]]) {MANP = 15 : ui{{[0-9]+}}} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> + // CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> %res = "FHE.add_eint"(%B5, %A3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> return %A3 : !FHE.eint<3> @@ -297,9 +294,9 @@ func.func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst0 = arith.constant 3 : i3 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %cst0) : (!FHE.eint<2>, i3) -> !FHE.eint<2> - // CHECK-NEXT: %[[ret:.*]] = "FHE.neg_eint"(%[[V0]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: %[[ret:.*]] = "FHE.neg_eint"(%[[V0]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2> %1 = "FHE.neg_eint"(%0) : (!FHE.eint<2>) -> !FHE.eint<2> return %1 : !FHE.eint<2> diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index a1a2cbcd7..1bb848b1c 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -4,7 +4,7 @@ func.func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. { %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -17,7 +17,7 @@ func.func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -26,7 +26,7 @@ func.func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) // ----- func.func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -48,7 +48,7 @@ func.func @single_cst_sub_int_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. { %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -61,7 +61,7 @@ func.func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -73,7 +73,7 @@ func.func @single_cst_sub_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. { %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -86,7 +86,7 @@ func.func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -116,7 +116,8 @@ func.func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> func.func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // sqrt(1 + (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -166,13 +167,13 @@ func.func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint< %cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi4> %cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi4> %cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi4> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> %1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> %2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> %3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> return %3 : tensor<8x!FHE.eint<3>> } @@ -182,9 +183,9 @@ func.func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint< func.func @chain_add_eint_int_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { %cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %1 = "FHELinalg.neg_eint"(%0) : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %1 : tensor<8x!FHE.eint<2>> } @@ -556,7 +557,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint< %z = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>> %a = arith.constant dense<[[4, 6, 5], [2, 6, 3], [5, 6, 1], [5, 5, 3]]> : tensor<4x3xi8> - // CHECK: {MANP = 7 : ui{{[0-9]+}}} + // CHECK: {MANP = 1 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%z, %a) : (tensor<4x3x!FHE.eint<7>>, tensor<4x3xi8>) -> tensor<4x3x!FHE.eint<7>> // =============================== @@ -566,7 +567,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint< [2, 1, 5] > : tensor<3xi8> - // CHECK: MANP = 34 : ui{{[0-9]+}} + // CHECK: MANP = 6 : ui{{[0-9]+}} %2 = "FHELinalg.matmul_eint_int"(%0, %1) : (tensor<4x3x!FHE.eint<7>>, tensor<3xi8>) -> tensor<4x!FHE.eint<7>> // =============================== @@ -823,7 +824,7 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<3x2x!FHE.eint< %z = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>> %a = arith.constant dense<[[4, 6], [2, 6], [5, 6]]> : tensor<3x2xi8> - // CHECK: {MANP = 7 : ui{{[0-9]+}}} + // CHECK: {MANP = 1 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%z, %a) : (tensor<3x2x!FHE.eint<7>>, tensor<3x2xi8>) -> tensor<3x2x!FHE.eint<7>> // =============================== @@ -833,7 +834,7 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<3x2x!FHE.eint< [2, 1, 5] > : tensor<3xi8> - // CHECK: MANP = 34 : ui{{[0-9]+}} + // CHECK: MANP = 6 : ui{{[0-9]+}} %2 = "FHELinalg.matmul_int_eint"(%1, %0) : (tensor<3xi8>, tensor<3x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> // =============================== diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir index 25bc57f7c..f3ece1225 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir @@ -15,11 +15,11 @@ func.func @tensor_from_elements_2(%a: !FHE.eint<2>, %b: !FHE.eint<2>, %c: !FHE.e { %cst = arith.constant 3 : i3 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> - %0 = "FHE.add_eint_int"(%a, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[V0:.*]] = "FHE.mul_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + %0 = "FHE.mul_eint_int"(%a, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> - // The MANP value is 4, i.e. the max of all of its operands - // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> + // The MANP value is 3, i.e. the max of all of its operands + // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 3 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> %1 = tensor.from_elements %0, %b, %c, %d : tensor<4x!FHE.eint<2>> return %1 : tensor<4x!FHE.eint<2>> @@ -44,9 +44,9 @@ func.func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> { %c1 = arith.constant 1 : index %c3 = arith.constant dense<3> : tensor<4xi3> - // CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} - %0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> - // CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 3 : ui{{[0-9]+}}} + %0 = "FHELinalg.mul_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 3 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> %2 = tensor.extract %0[%c1] : tensor<4x!FHE.eint<2>> return %2 : !FHE.eint<2> @@ -68,10 +68,10 @@ func.func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.e { %c3 = arith.constant dense <3> : tensor<4xi3> - // CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} - %0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 3 : ui{{[0-9]+}}} + %0 = "FHELinalg.mul_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> - // CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> + // CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 3 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> %2 = tensor.extract_slice %0[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> return %2 : tensor<2x!FHE.eint<2>> @@ -99,9 +99,9 @@ func.func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8 func.func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 8 : ui{{[0-9]+}}} - %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>> - // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 8 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 7 : ui{{[0-9]+}}} + %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>> + // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 7 : ui{{[0-9]+}}} %1 = tensor.collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>> return %1 : tensor<2x8x!FHE.eint<2>> } @@ -118,9 +118,9 @@ func.func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x! func.func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 8 : ui{{[0-9]+}}} - %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>> - // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 8 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 7 : ui{{[0-9]+}}} + %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>> + // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 7 : ui{{[0-9]+}}} %1 = tensor.expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>> return %1 : tensor<2x2x4x!FHE.eint<2>> }