diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index b4b02f980..521a842f1 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -166,14 +166,14 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { return ie * ie; } -/// Calculates the square of the absolute value of `i`. +/// Calculates the square of the value of `i`. static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) { // Make sure the required number of bits can be represented by the // `unsigned` argument of `zext`. assert(i.getActiveBits() < 32 && "Square of the constant cannot be represented on 64 bits"); return llvm::APInt(2 * i.getActiveBits(), - i.abs().getZExtValue() * i.abs().getZExtValue()); + i.getZExtValue() * i.getZExtValue()); } /// Calculates the square root of `i` and rounds it to the next highest @@ -277,7 +277,7 @@ static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { assert(t.isSignlessInteger() && "Type must be a signless integer type"); assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); - llvm::APInt maxVal = APInt::getSignedMaxValue(t.getIntOrFloatBitWidth()); + llvm::APInt maxVal = APInt::getMaxValue(t.getIntOrFloatBitWidth()); return APIntWidthExtendUnsignedSq(maxVal); } diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index 0d1664270..b42876a84 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -36,7 +36,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -47,7 +47,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -81,7 +81,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%cst, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -92,7 +92,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -116,7 +116,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -127,7 +127,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index b755583e7..150602d06 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -26,8 +26,7 @@ func.func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) // ----- func.func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // sqrt(1 + (2^2-1)^2) = 3..16 - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -117,8 +116,7 @@ func.func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> func.func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -140,10 +138,10 @@ func.func @single_cst_mul_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. func.func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { - %cst1 = arith.constant 1 : i3 + %cst1 = arith.constant 2 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -208,7 +206,7 @@ func.func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.e func.func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> { %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> @@ -227,7 +225,7 @@ func.func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor< // ----- func.func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> { - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> @@ -266,12 +264,12 @@ func.func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FH func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> %cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3> // sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 56 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -282,11 +280,11 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 7 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> // sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 18 : ui{{[[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 42 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -304,7 +302,7 @@ func.func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tenso // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 9 + 1 = 10 // ceil(sqrt(65)) = 4 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -320,7 +318,7 @@ func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tenso // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 10 + 9 = 19 // ceil(sqrt(19)) = 5 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -554,6 +552,7 @@ func.func @matmul_eint_int_cst() -> tensor<4x3x!FHE.eint<7>> { // ----- func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<7>> { + // CHECK: {MANP = 1 : ui{{[0-9]+}}} %z = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>> %a = arith.constant dense<[[4, 6, 5], [2, 6, 3], [5, 6, 1], [5, 5, 3]]> : tensor<4x3xi8> @@ -563,7 +562,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint< // =============================== %1 = arith.constant dense< - // ceil(sqrt(37 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(1111)) = 34 + // ceil(sqrt(1 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(31)) = 6 [2, 1, 5] > : tensor<3xi8> @@ -587,7 +586,7 @@ func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 64 + 1 = 10 // ceil(sqrt(65)) = 4 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 8 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -603,7 +602,7 @@ func.func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 10 + 9 = 19 // ceil(sqrt(129)) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 10 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -830,7 +829,7 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<3x2x!FHE.eint< // =============================== %1 = arith.constant dense< - // ceil(sqrt(37 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(1111)) = 34 + // ceil(sqrt(37 * (2^2 + 1^2 + 5^2) + 1)) = ceil(sqrt(31)) = 6 [2, 1, 5] > : tensor<3xi8> @@ -1046,7 +1045,7 @@ func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 64 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 128 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> @@ -1057,7 +1056,7 @@ func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : ten func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { %bias = arith.constant dense<[5]> : tensor<1xi3> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 15 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -1067,7 +1066,7 @@ func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tens // ----- func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 16 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -1077,9 +1076,10 @@ func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weigh // ----- func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 26 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> return %0 : tensor<100x5x2x2x!FHE.eint<2>> } + diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir index ec5c4b6da..be65deb0d 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir @@ -99,9 +99,9 @@ func.func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8 func.func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 8 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>> - // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}} + // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 8 : ui{{[0-9]+}}} %1 = tensor.collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>> return %1 : tensor<2x8x!FHE.eint<2>> } @@ -118,9 +118,9 @@ func.func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x! func.func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 8 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>> - // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}} + // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 8 : ui{{[0-9]+}}} %1 = tensor.expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>> return %1 : tensor<2x2x4x!FHE.eint<2>> } diff --git a/compiler/tests/end_to_end_fixture/end_to_end_leveled.yaml b/compiler/tests/end_to_end_fixture/end_to_end_leveled.yaml index d3991e8a2..74fab6664 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_leveled.yaml +++ b/compiler/tests/end_to_end_fixture/end_to_end_leveled.yaml @@ -5974,29 +5974,6 @@ tests: outputs: - scalar: 536870911 --- -description: add_eint_int_arg_29bits -program: | - func.func @main(%arg0: !FHE.eint<29>, %arg1: i30) -> !FHE.eint<29> { - %0 = "FHE.add_eint_int"(%arg0, %arg1): (!FHE.eint<29>, i30) -> (!FHE.eint<29>) - return %0: !FHE.eint<29> - } -tests: - - inputs: - - scalar: 536870910 - - scalar: 1 - outputs: - - scalar: 536870911 - - inputs: - - scalar: 536870911 - - scalar: 0 - outputs: - - scalar: 536870911 - - inputs: - - scalar: 268435455 - - scalar: 268435456 - outputs: - - scalar: 536870911 ---- description: add_eint_29_bits program: | func.func @main(%arg0: !FHE.eint<29>, %arg1: !FHE.eint<29>) -> !FHE.eint<29> { @@ -6023,29 +6000,6 @@ tests: outputs: - scalar: 268435456 --- -description: sub_eint_int_arg_29bits -program: | - func.func @main(%arg0: !FHE.eint<29>, %arg1: i30) -> !FHE.eint<29> { - %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<29>, i30) -> (!FHE.eint<29>) - return %1: !FHE.eint<29> - } -tests: - - inputs: - - scalar: 536870911 - - scalar: 536870911 - outputs: - - scalar: 0 - - inputs: - - scalar: 536870911 - - scalar: 0 - outputs: - - scalar: 536870911 - - inputs: - - scalar: 536870910 - - scalar: 268435455 - outputs: - - scalar: 268435455 ---- description: sub_int_eint_cst_29bits program: | func.func @main(%arg0: !FHE.eint<29>) -> !FHE.eint<29> { @@ -6059,29 +6013,6 @@ tests: outputs: - scalar: 0 --- -description: sub_int_eint_arg_29bits -program: | - func.func @main(%arg0: i30, %arg1: !FHE.eint<29>) -> !FHE.eint<29> { - %1 = "FHE.sub_int_eint"(%arg0, %arg1): (i30, !FHE.eint<29>) -> (!FHE.eint<29>) - return %1: !FHE.eint<29> - } -tests: - - inputs: - - scalar: 536870911 - - scalar: 536870911 - outputs: - - scalar: 0 - - inputs: - - scalar: 536870911 - - scalar: 0 - outputs: - - scalar: 536870911 - - inputs: - - scalar: 536870910 - - scalar: 268435455 - outputs: - - scalar: 268435455 ---- description: sub_eint_29bits program: | func.func @main(%arg0: !FHE.eint<29>, %arg1: !FHE.eint<29>) -> !FHE.eint<29> { @@ -6122,34 +6053,6 @@ tests: outputs: - scalar: 536870910 --- -description: mul_eint_int_arg_29bits -program: | - func.func @main(%arg0: !FHE.eint<29>, %arg1: i30) -> !FHE.eint<29> { - %1 = "FHE.mul_eint_int"(%arg0, %arg1): (!FHE.eint<29>, i30) -> (!FHE.eint<29>) - return %1: !FHE.eint<29> - } -tests: - - inputs: - - scalar: 0 - - scalar: 536870911 - outputs: - - scalar: 0 - - inputs: - - scalar: 536870911 - - scalar: 0 - outputs: - - scalar: 0 - - inputs: - - scalar: 1 - - scalar: 536870911 - outputs: - - scalar: 536870911 - - inputs: - - scalar: 536870911 - - scalar: 1 - outputs: - - scalar: 536870911 ---- --- description: identity_30bits program: | diff --git a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py index b038db3c3..7051ac7e6 100644 --- a/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py +++ b/compiler/tests/end_to_end_fixture/end_to_end_leveled_gen.py @@ -57,8 +57,8 @@ def main(): print(" - scalar: {0}".format(max_value)) print("---") # add_eint_int_arg - if p <= 29: - # above 29 bits the *arg test doesn't have solution + if p <= 28: + # above 28 bits the *arg test doesn't have solution # TODO: Make a test that test that print("description: add_eint_int_arg_{0}bits".format(p)) print("program: |") @@ -119,8 +119,8 @@ def main(): print(" - scalar: {0}".format(max_value-max_constant)) print("---") # sub_eint_int_arg - if p <= 29: - # above 29 bits the *arg test doesn't have solution + if p <= 28: + # above 28 bits the *arg test doesn't have solution # TODO: Make a test that test that print("description: sub_eint_int_arg_{0}bits".format(p)) print("program: |") @@ -165,8 +165,8 @@ def main(): print(" - scalar: 0") print("---") # sub_int_eint_arg - if p <= 29: - # above 29 bits the *arg test doesn't have solution + if p <= 28: + # above 28 bits the *arg test doesn't have solution # TODO: Make a test that test that print("description: sub_int_eint_arg_{0}bits".format(p)) print("program: |") @@ -240,8 +240,8 @@ def main(): print(" - scalar: {0}".format(max_value - 1)) print("---") # mul_eint_int_arg - if p <= 29: - # above 29 bits the *arg test doesn't have solution + if p <= 28: + # above 28 bits the *arg test doesn't have solution # TODO: Make a test that test that print("description: mul_eint_int_arg_{0}bits".format(p)) print("program: |")