From 4e0f0fa5b0c6a4c5b0eef537db6186066abbe96b Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 14 Feb 2023 15:18:28 +0100 Subject: [PATCH] fix(manp): Fixing computation of negative constant --- compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 3 +-- .../tests/check_tests/Dialect/FHE/Analysis/MANP.mlir | 9 ++++----- .../check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir | 9 +++++++++ 3 files changed, 14 insertions(+), 7 deletions(-) create mode 100644 compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 0f7ac4a48..73425913e 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -160,7 +160,6 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { // `unsigned` argument of `zext`. assert(i.getBitWidth() < std::numeric_limits::max() / 2 && "Required number of bits cannot be represented with an APInt"); - llvm::APInt ie = i.zext(2 * i.getBitWidth()); return ie * ie; @@ -168,7 +167,7 @@ static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { /// Calculates the square of the value of `i`. static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) { - llvm::APInt extI(2 * i.getActiveBits(), i.getSExtValue()); + auto extI = i.sext(2 * i.getBitWidth()); return extI * extI; } diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index 3195f1612..db61c2521 100644 --- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -165,7 +165,7 @@ func.func @single_cst_mul_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -177,7 +177,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -189,7 +189,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -1 : i3 - // %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -199,8 +199,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // sqrt(1 + (2^2-1)^2) = 3 - // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir new file mode 100644 index 000000000..94c5b00cf --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir @@ -0,0 +1,9 @@ +// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s + +func.func @main(%arg0: tensor<1x10x!FHE.eint<33>>) -> tensor<1x1x!FHE.eint<33>> { + // sqrt(7282^2 + 20329^2 + 7232^2 + 32768 ^2 + 6446^2 + 32767^2 + 4708^2 + 20050^2 + 28812^2 + 17300^2) = 65277.528491817 + %cst_1 = arith.constant dense<[[-7282], [-20329], [-7232], [-32768], [6446], [32767], [-4708], [-20050], [-28812], [-17300]]> : tensor<10x1xi34> + // CHECK: MANP = 65278 + %2 = "FHELinalg.matmul_eint_int"(%arg0, %cst_1) : (tensor<1x10x!FHE.eint<33>>, tensor<10x1xi34>) -> tensor<1x1x!FHE.eint<33>> + return %2 : tensor<1x1x!FHE.eint<33>> +}