From 63334e138f997b2c76a56ff127342926b823c387 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 20 Dec 2022 16:17:47 +0100 Subject: [PATCH] fix: Fixing integer extension for plaintext encoding (close #847) --- .../Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp | 2 +- .../FHEToTFHEScalar/FHEToTFHEScalar.cpp | 2 +- .../Conversion/FHEToTFHECrt/add_eint_int.mlir | 2 +- .../Conversion/FHEToTFHECrt/conv2d.mlir | 2 +- .../Conversion/FHEToTFHECrt/sub_int_eint.mlir | 2 +- .../FHEToTFHEScalar/add_eint_int.mlir | 2 +- .../Conversion/FHEToTFHEScalar/conv2d.mlir | 2 +- .../FHEToTFHEScalar/sub_int_eint.mlir | 2 +- .../tests_cpu/end_to_end_fhe.yaml | 20 +++++++++++++++++++ 9 files changed, 28 insertions(+), 8 deletions(-) diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 4f8a08ffc..c57ec5db6 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -164,7 +164,7 @@ struct CrtOpPattern : public mlir::OpConversionPattern { mlir::Value writePlaintextCrtEncoding(mlir::Location location, mlir::Value rawPlaintext, mlir::PatternRewriter &rewriter) const { - mlir::Value castedPlaintext = rewriter.create( + mlir::Value castedPlaintext = rewriter.create( location, rewriter.getI64Type(), rawPlaintext); return rewriter.create( location, diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 79172bec0..d2254beb8 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -122,7 +122,7 @@ struct ScalarOpPattern : public mlir::OpConversionPattern { int64_t encryptedWidth, mlir::ConversionPatternRewriter &rewriter) const { int64_t intShift = 64 - 1 - encryptedWidth; - mlir::Value castedInt = rewriter.create( + mlir::Value castedInt = rewriter.create( location, rewriter.getIntegerType(64), rawPlaintext); mlir::Value constantShiftOp = rewriter.create( location, rewriter.getI64IntegerAttr(intShift)); diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir index 257bbce39..89f20b32d 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 -// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 +// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64 // CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> // CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> // CHECK-NEXT: %c0 = arith.constant 0 : index diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir index 020471298..ec45fd288 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir @@ -17,7 +17,7 @@ //CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3> //CHECK-NEXT: %c0_0 = arith.constant 0 : index //CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: %8 = arith.extui %6 : i3 to i64 +//CHECK-NEXT: %8 = arith.extsi %6 : i3 to i64 //CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> //CHECK-NEXT: %10 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>> //CHECK-NEXT: %c0_1 = arith.constant 0 : index diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir index fbb4fa1f1..42d0c5dd1 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 -// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 +// CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64 // CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> // CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> // CHECK-NEXT: %c0 = arith.constant 0 : index diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir index ae48bae5e..481577b39 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/add_eint_int.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @add_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 - // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 + // CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64 // CHECK-NEXT: %c56_i64 = arith.constant 56 : i64 // CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64 // CHECK-NEXT: %2 = "TFHE.add_glwe_int"(%arg0, %1) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir index 24a4e3bc6..2f05cac72 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/conv2d.mlir @@ -15,7 +15,7 @@ // CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) { // CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3> // CHECK-NEXT: %7 = tensor.extract %0[%arg3, %arg5, %arg7, %arg9] : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64 +// CHECK-NEXT: %8 = arith.extsi %6 : i3 to i64 // CHECK-NEXT: %c61_i64 = arith.constant 61 : i64 // CHECK-NEXT: %9 = arith.shli %8, %c61_i64 : i64 // CHECK-NEXT: %10 = "TFHE.add_glwe_int"(%7, %9) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir index 3ff9b3e47..2ec79d939 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHEScalar/sub_int_eint.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @sub_int_eint(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 - // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 + // CHECK-NEXT: %0 = arith.extsi %c1_i8 : i8 to i64 // CHECK-NEXT: %c56_i64 = arith.constant 56 : i64 // CHECK-NEXT: %1 = arith.shli %0, %c56_i64 : i64 // CHECK-NEXT: %2 = "TFHE.sub_int_glwe"(%1, %arg0) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml index 244270c95..334cea23d 100644 --- a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_fhe.yaml @@ -184,6 +184,26 @@ tests: - tensor: [9,4,7,7,10,9,9,4,7,7,10,9] shape: [4,3] --- +# Minimized bug 847 (CRT) +# https://github.com/zama-ai/concrete-compiler-internal/issues/847 +description: bug_847_crt +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + %c32768_i17 = arith.constant 32768 : i17 + %0 = "FHE.sub_eint_int"(%arg0, %c32768_i17) : (!FHE.eint<16>, i17) -> !FHE.eint<16> + return %0 : !FHE.eint<16> + } +encoding: crt +tests: + - inputs: + - scalar: 32769 + outputs: + - scalar: 1 + - inputs: + - scalar: 32770 + outputs: + - scalar: 2 +--- description: boolean_and program: | func.func @main(%arg0: !FHE.ebool, %arg1: !FHE.ebool) -> !FHE.ebool {