diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 5f6f63275..d29ba721a 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -123,41 +123,6 @@ template struct CrtOpPattern : public mlir::OpRewritePattern { : mlir::OpRewritePattern(context, benefit), loweringParameters(params) {} - /// Writes an `scf::for` that loops over the crt dimension of two tensors and - /// execute the input lambda to write the loop body. Returns the first result - /// of the op. - /// - /// Note: - /// ----- - /// - /// + The type of `firstArgTensor` type is used as output type. - mlir::Value writeBinaryTensorLoop( - mlir::Location location, mlir::Value firstTensor, - mlir::Value secondTensor, mlir::PatternRewriter &rewriter, - mlir::function_ref - body) const { - - // Create the loop - mlir::arith::ConstantOp zeroConstantOp = - rewriter.create(location, 0); - mlir::arith::ConstantOp oneConstantOp = - rewriter.create(location, 1); - mlir::arith::ConstantOp crtSizeConstantOp = - rewriter.create(location, - loweringParameters.nMods); - mlir::scf::ForOp newOp = rewriter.create( - location, zeroConstantOp, crtSizeConstantOp, oneConstantOp, - mlir::ValueRange{firstTensor, secondTensor}, body); - - // Convert the types of the new operation - typing::TypeConverter converter(loweringParameters); - concretelang::convertOperandAndResultTypes(rewriter, newOp, - converter.getConversionLambda()); - - return newOp.getResult(0); - } - /// Writes an `scf::for` that loops over the crt dimension of one tensor and /// execute the input lambda to write the loop body. Returns the first result /// of the op. @@ -167,12 +132,16 @@ template struct CrtOpPattern : public mlir::OpRewritePattern { /// /// + The type of `firstArgTensor` type is used as output type. mlir::Value writeUnaryTensorLoop( - mlir::Location location, mlir::Value tensor, + mlir::Location location, mlir::Type returnType, mlir::PatternRewriter &rewriter, mlir::function_ref body) const { + mlir::Value tensor = rewriter.create( + location, returnType.cast(), + mlir::ValueRange{}); + // Create the loop mlir::arith::ConstantOp zeroConstantOp = rewriter.create(location, 0); @@ -239,20 +208,19 @@ struct AddEintIntOpPattern : public CrtOpPattern { converter.convertType(eintOperand.getType()) .cast() .getElementType(); - mlir::Value output = writeBinaryTensorLoop( - location, eintOperand, encodedPlaintextTensor, rewriter, + mlir::Value output = writeUnaryTensorLoop( + location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = - builder.create(loc, args[0], iter); - mlir::Value extractedInt = - builder.create(loc, args[1], iter); + builder.create(loc, eintOperand, iter); + mlir::Value extractedInt = builder.create( + loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedEint, extractedInt); mlir::Value newTensor = builder.create( loc, output, args[0], iter); - builder.create( - loc, mlir::ValueRange{newTensor, args[1]}); + builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. @@ -292,20 +260,19 @@ struct SubIntEintOpPattern : public CrtOpPattern { converter.convertType(eintOperand.getType()) .cast() .getElementType(); - mlir::Value output = writeBinaryTensorLoop( - location, eintOperand, encodedPlaintextTensor, rewriter, + mlir::Value output = writeUnaryTensorLoop( + location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = - builder.create(loc, args[0], iter); - mlir::Value extractedInt = - builder.create(loc, args[1], iter); + builder.create(loc, eintOperand, iter); + mlir::Value extractedInt = builder.create( + loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedInt, extractedEint); mlir::Value newTensor = builder.create( loc, output, args[0], iter); - builder.create( - loc, mlir::ValueRange{newTensor, args[1]}); + builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. @@ -355,20 +322,19 @@ struct SubEintIntOpPattern : public CrtOpPattern { converter.convertType(eintOperand.getType()) .cast() .getElementType(); - mlir::Value output = writeBinaryTensorLoop( - location, eintOperand, encodedPlaintextTensor, rewriter, + mlir::Value output = writeUnaryTensorLoop( + location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedEint = - builder.create(loc, args[0], iter); - mlir::Value extractedInt = - builder.create(loc, args[1], iter); + builder.create(loc, eintOperand, iter); + mlir::Value extractedInt = builder.create( + loc, encodedPlaintextTensor, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedEint, extractedInt); mlir::Value newTensor = builder.create( loc, output, args[0], iter); - builder.create( - loc, mlir::ValueRange{newTensor, args[1]}); + builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. @@ -404,20 +370,19 @@ struct AddEintOpPattern : CrtOpPattern { converter.convertType(lhsOperand.getType()) .cast() .getElementType(); - mlir::Value output = writeBinaryTensorLoop( - location, lhsOperand, rhsOperand, rewriter, + mlir::Value output = writeUnaryTensorLoop( + location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = - builder.create(loc, args[0], iter); + builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = - builder.create(loc, args[1], iter); + builder.create(loc, rhsOperand, iter); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, extractedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); - builder.create( - loc, mlir::ValueRange{newTensor, args[1]}); + builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. @@ -453,22 +418,21 @@ struct SubEintOpPattern : CrtOpPattern { converter.convertType(lhsOperand.getType()) .cast() .getElementType(); - mlir::Value output = writeBinaryTensorLoop( - location, lhsOperand, rhsOperand, rewriter, + mlir::Value output = writeUnaryTensorLoop( + location, lhsOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedLhs = - builder.create(loc, args[0], iter); + builder.create(loc, lhsOperand, iter); mlir::Value extractedRhs = - builder.create(loc, args[1], iter); + builder.create(loc, rhsOperand, iter); mlir::Value negatedRhs = builder.create( loc, ciphertextScalarType, extractedRhs); mlir::Value output = builder.create( loc, ciphertextScalarType, extractedLhs, negatedRhs); mlir::Value newTensor = builder.create( loc, output, args[0], iter); - builder.create( - loc, mlir::ValueRange{newTensor, args[1]}); + builder.create(loc, mlir::ValueRange{newTensor}); }); // Rewrite original op. @@ -502,11 +466,11 @@ struct NegEintOpPattern : CrtOpPattern { .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( - location, operand, rewriter, + location, operand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = - builder.create(loc, args[0], iter); + builder.create(loc, operand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext); mlir::Value newTensor = builder.create( @@ -551,11 +515,11 @@ struct MulEintIntOpPattern : CrtOpPattern { .cast() .getElementType(); mlir::Value loopRes = writeUnaryTensorLoop( - location, eintOperand, rewriter, + location, eintOperand.getType(), rewriter, [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange args) { mlir::Value extractedCiphertext = - builder.create(loc, args[0], iter); + builder.create(loc, eintOperand, iter); mlir::Value negatedCiphertext = builder.create( loc, ciphertextScalarType, extractedCiphertext, encodedCleartext); mlir::Value newTensor = builder.create( diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir index f6d2441ad..99d991adc 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint.mlir @@ -1,19 +1,19 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-LABEL: func.func @add_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>, %arg1: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { +//CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-NEXT: %c0 = arith.constant 0 : index +//CHECK-NEXT: %c1 = arith.constant 1 : index +//CHECK-NEXT: %c5 = arith.constant 5 : index +//CHECK-NEXT: %1 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { +//CHECK-NEXT: %2 = tensor.extract %arg0[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-NEXT: %3 = tensor.extract %arg1[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-NEXT: %4 = "TFHE.add_glwe"(%2, %3) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> +//CHECK-NEXT: %5 = tensor.insert %4 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +//CHECK-NEXT: } +//CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> func.func @add_eint(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %c0 = arith.constant 0 : index - // CHECK-NEXT: %c1 = arith.constant 1 : index - // CHECK-NEXT: %c5 = arith.constant 5 : index - // CHECK-NEXT: %0:2 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { - // CHECK-NEXT: %1 = tensor.extract %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %2 = tensor.extract %arg4[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %3 = "TFHE.add_glwe"(%1, %2) : (!TFHE.glwe<{_,_,_}{7}>, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> - // CHECK-NEXT: %4 = tensor.insert %3 into %arg3[%arg2] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: scf.yield %4, %arg4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: } - // CHECK-NEXT: return %0#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir index ddc760db1..257bbce39 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/add_eint_int.mlir @@ -1,22 +1,22 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-LABEL: func.func @add_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { +// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 +// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 +// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> +// CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %3 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %2) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { +// CHECK-NEXT: %4 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %5 = tensor.extract %1[%arg1] : tensor<5xi64> +// CHECK-NEXT: %6 = "TFHE.add_glwe_int"(%4, %5) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}> +// CHECK-NEXT: %7 = tensor.insert %6 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: scf.yield %7 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } +// CHECK-NEXT: return %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> func.func @add_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 - // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 - // CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> - // CHECK-NEXT: %c0 = arith.constant 0 : index - // CHECK-NEXT: %c1 = arith.constant 1 : index - // CHECK-NEXT: %c5 = arith.constant 5 : index - // CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) { - // CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64> - // CHECK-NEXT: %5 = "TFHE.add_glwe_int"(%3, %4) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}> - // CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64> - // CHECK-NEXT: } - // CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - %0 = arith.constant 1 : i8 %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir index 0eac328dc..020471298 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/conv2d.mlir @@ -1,94 +1,99 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %c4 = arith.constant 4 : index -// CHECK-NEXT: %c100 = arith.constant 100 : index -// CHECK-NEXT: %c15 = arith.constant 15 : index -// CHECK-NEXT: %c0 = arith.constant 0 : index -// CHECK-NEXT: %c1 = arith.constant 1 : index -// CHECK-NEXT: %c3 = arith.constant 3 : index -// CHECK-NEXT: %c14 = arith.constant 14 : index -// CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3> -// CHECK-NEXT: %c0_0 = arith.constant 0 : index -// CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %8 = arith.extui %6 : i3 to i64 -// CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> -// CHECK-NEXT: %c0_1 = arith.constant 0 : index -// CHECK-NEXT: %c1_2 = arith.constant 1 : index -// CHECK-NEXT: %c5 = arith.constant 5 : index -// CHECK-NEXT: %10:2 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %7, %arg13 = %9) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64>) { -// CHECK-NEXT: %12 = tensor.extract %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %13 = tensor.extract %arg13[%arg11] : tensor<5xi64> -// CHECK-NEXT: %14 = "TFHE.add_glwe_int"(%12, %13) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}> -// CHECK-NEXT: %15 = tensor.insert %14 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: scf.yield %15, %arg13 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5xi64> -// CHECK-NEXT: } -// CHECK-NEXT: %c0_3 = arith.constant 0 : index -// CHECK-NEXT: %11 = tensor.insert_slice %10#0 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: scf.yield %11 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13) -// CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15) -// CHECK-NEXT: %c0_0 = arith.constant 0 : index -// CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3> -// CHECK-NEXT: %c0_1 = arith.constant 0 : index -// CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64 -// CHECK-NEXT: %c0_2 = arith.constant 0 : index -// CHECK-NEXT: %c1_3 = arith.constant 1 : index -// CHECK-NEXT: %c5 = arith.constant 5 : index -// CHECK-NEXT: %15 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %11) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %19 = "TFHE.mul_glwe_int"(%18, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}> -// CHECK-NEXT: %20 = tensor.insert %19 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: scf.yield %20 : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: %c0_4 = arith.constant 0 : index -// CHECK-NEXT: %c1_5 = arith.constant 1 : index -// CHECK-NEXT: %c5_6 = arith.constant 5 : index -// CHECK-NEXT: %16:2 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %13, %arg19 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>>) { -// CHECK-NEXT: %18 = tensor.extract %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %19 = tensor.extract %arg19[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: %20 = "TFHE.add_glwe"(%18, %19) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> -// CHECK-NEXT: %21 = tensor.insert %20 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: scf.yield %21, %arg19 : tensor<5x!TFHE.glwe<{_,_,_}{2}>>, tensor<5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: %c0_7 = arith.constant 0 : index -// CHECK-NEXT: %17 = tensor.insert_slice %16#0 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: scf.yield %17 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> -// CHECK-NEXT: } -// CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> + +//CHECK-LABEL: func.func @conv2d(%arg0: tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> { +//CHECK-NEXT: %c4 = arith.constant 4 : index +//CHECK-NEXT: %c100 = arith.constant 100 : index +//CHECK-NEXT: %c15 = arith.constant 15 : index +//CHECK-NEXT: %c0 = arith.constant 0 : index +//CHECK-NEXT: %c1 = arith.constant 1 : index +//CHECK-NEXT: %c3 = arith.constant 3 : index +//CHECK-NEXT: %c14 = arith.constant 14 : index +//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %1 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %0) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %6 = tensor.extract %arg2[%arg5] : tensor<4xi3> +//CHECK-NEXT: %c0_0 = arith.constant 0 : index +//CHECK-NEXT: %7 = tensor.extract_slice %0[%arg3, %arg5, %arg7, %arg9, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %8 = arith.extui %6 : i3 to i64 +//CHECK-NEXT: %9 = "TFHE.encode_plaintext_with_crt"(%8) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> +//CHECK-NEXT: %10 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %c0_1 = arith.constant 0 : index +//CHECK-NEXT: %c1_2 = arith.constant 1 : index +//CHECK-NEXT: %c5 = arith.constant 5 : index +//CHECK-NEXT: %11 = scf.for %arg11 = %c0_1 to %c5 step %c1_2 iter_args(%arg12 = %10) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %13 = tensor.extract %7[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %14 = tensor.extract %9[%arg11] : tensor<5xi64> +//CHECK-NEXT: %15 = "TFHE.add_glwe_int"(%13, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %16 = tensor.insert %15 into %arg12[%arg11] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: scf.yield %16 : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: %c0_3 = arith.constant 0 : index +//CHECK-NEXT: %12 = tensor.insert_slice %11 into %arg10[%arg3, %arg5, %arg7, %arg9, %c0_3] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: scf.yield %12 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: %2 = scf.for %arg3 = %c0 to %c100 step %c1 iter_args(%arg4 = %1) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %3 = scf.for %arg5 = %c0 to %c4 step %c1 iter_args(%arg6 = %arg4) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %4 = scf.for %arg7 = %c0 to %c15 step %c1 iter_args(%arg8 = %arg6) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %5 = scf.for %arg9 = %c0 to %c15 step %c1 iter_args(%arg10 = %arg8) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %6 = scf.for %arg11 = %c0 to %c3 step %c1 iter_args(%arg12 = %arg10) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %7 = scf.for %arg13 = %c0 to %c14 step %c1 iter_args(%arg14 = %arg12) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %8 = scf.for %arg15 = %c0 to %c14 step %c1 iter_args(%arg16 = %arg14) -> (tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %9 = affine.apply #map(%arg7, %arg13) +//CHECK-NEXT: %10 = affine.apply #map(%arg9, %arg15) +//CHECK-NEXT: %c0_0 = arith.constant 0 : index +//CHECK-NEXT: %11 = tensor.extract_slice %arg0[%arg3, %arg11, %9, %10, %c0_0] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x3x28x28x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %12 = tensor.extract %arg1[%arg5, %arg11, %arg13, %arg15] : tensor<4x3x14x14xi3> +//CHECK-NEXT: %c0_1 = arith.constant 0 : index +//CHECK-NEXT: %13 = tensor.extract_slice %1[%arg3, %arg5, %arg7, %arg9, %c0_1] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> to tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %14 = arith.extsi %12 : i3 to i64 +//CHECK-NEXT: %15 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %c0_2 = arith.constant 0 : index +//CHECK-NEXT: %c1_3 = arith.constant 1 : index +//CHECK-NEXT: %c5 = arith.constant 5 : index +//CHECK-NEXT: %16 = scf.for %arg17 = %c0_2 to %c5 step %c1_3 iter_args(%arg18 = %15) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %20 = tensor.extract %11[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %21 = "TFHE.mul_glwe_int"(%20, %14) : (!TFHE.glwe<{_,_,_}{2}>, i64) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %22 = tensor.insert %21 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: scf.yield %22 : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: %17 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %c0_4 = arith.constant 0 : index +//CHECK-NEXT: %c1_5 = arith.constant 1 : index +//CHECK-NEXT: %c5_6 = arith.constant 5 : index +//CHECK-NEXT: %18 = scf.for %arg17 = %c0_4 to %c5_6 step %c1_5 iter_args(%arg18 = %17) -> (tensor<5x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: %20 = tensor.extract %13[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %21 = tensor.extract %16[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %22 = "TFHE.add_glwe"(%20, %21) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %23 = tensor.insert %22 into %arg18[%arg17] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: scf.yield %23 : tensor<5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: %c0_7 = arith.constant 0 : index +//CHECK-NEXT: %19 = tensor.insert_slice %18 into %arg16[%arg3, %arg5, %arg7, %arg9, %c0_7] [1, 1, 1, 1, 5] [1, 1, 1, 1, 1] : tensor<5x!TFHE.glwe<{_,_,_}{2}>> into tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: scf.yield %19 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %8 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %7 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %6 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %5 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %4 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: scf.yield %3 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: return %2 : tensor<100x4x15x15x5x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> return %1 : tensor<100x4x15x15x!FHE.eint<2>> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir index 632fe39b1..bbc98bb1b 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/mul_eint_int.mlir @@ -1,20 +1,20 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-LABEL: func.func @mul_eint_int(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { +// CHECK-NEXT: %c2_i8 = arith.constant 2 : i8 +// CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64 +// CHECK-NEXT: %1 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { +// CHECK-NEXT: %3 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %4 = "TFHE.mul_glwe_int"(%3, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}> +// CHECK-NEXT: %5 = tensor.insert %4 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: scf.yield %5 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } +// CHECK-NEXT: return %2 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> func.func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %c2_i8 = arith.constant 2 : i8 - // CHECK-NEXT: %0 = arith.extsi %c2_i8 : i8 to i64 - // CHECK-NEXT: %c0 = arith.constant 0 : index - // CHECK-NEXT: %c1 = arith.constant 1 : index - // CHECK-NEXT: %c5 = arith.constant 5 : index - // CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { - // CHECK-NEXT: %2 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%2, %0) : (!TFHE.glwe<{_,_,_}{7}>, i64) -> !TFHE.glwe<{_,_,_}{7}> - // CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: } - // CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - %0 = arith.constant 2 : i8 %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir index bda49d337..3d18ce33a 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/neg_eint.mlir @@ -1,18 +1,19 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-LABEL: func.func @neg_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %1 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { +// CHECK-NEXT: %2 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %3 = "TFHE.neg_glwe"(%2) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> +// CHECK-NEXT: %4 = tensor.insert %3 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: scf.yield %4 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } +// CHECK-NEXT: return %1 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } func.func @neg_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %c0 = arith.constant 0 : index - // CHECK-NEXT: %c1 = arith.constant 1 : index - // CHECK-NEXT: %c5 = arith.constant 5 : index - // CHECK-NEXT: %0 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { - // CHECK-NEXT: %1 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %2 = "TFHE.neg_glwe"(%1) : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> - // CHECK-NEXT: %3 = tensor.insert %2 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: scf.yield %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: } - // CHECK-NEXT: return %0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<7>) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } diff --git a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir index 4d863c514..fbb4fa1f1 100644 --- a/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir +++ b/compiler/tests/check_tests/Conversion/FHEToTFHECrt/sub_int_eint.mlir @@ -1,23 +1,23 @@ // RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// RUN: concretecompiler --action=dump-tfhe %s --large-integer-crt-decomposition=2,3,5,7,11 --large-integer-circuit-bootstrap=2,9 --large-integer-packing-keyswitch=694,1024,4,9 --v0-parameter=2,10,693,4,9,7,2 2>&1| FileCheck %s -// CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-LABEL: func.func @sub_int_eint(%arg0: tensor<5x!TFHE.glwe<{_,_,_}{7}>>) -> tensor<5x!TFHE.glwe<{_,_,_}{7}>> { +// CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 +// CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 +// CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> +// CHECK-NEXT: %2 = bufferization.alloc_tensor() : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %3 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %2) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>) { +// CHECK-NEXT: %4 = tensor.extract %arg0[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: %5 = tensor.extract %1[%arg1] : tensor<5xi64> +// CHECK-NEXT: %6 = "TFHE.sub_int_glwe"(%5, %4) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> +// CHECK-NEXT: %7 = tensor.insert %6 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: scf.yield %7 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } +// CHECK-NEXT: return %3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> +// CHECK-NEXT: } func.func @sub_int_eint(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %c1_i8 = arith.constant 1 : i8 - // CHECK-NEXT: %0 = arith.extui %c1_i8 : i8 to i64 - // CHECK-NEXT: %1 = "TFHE.encode_plaintext_with_crt"(%0) {mods = [2, 3, 5, 7, 11], modsProd = 2310 : i64} : (i64) -> tensor<5xi64> - // CHECK-NEXT: %c0 = arith.constant 0 : index - // CHECK-NEXT: %c1 = arith.constant 1 : index - // CHECK-NEXT: %c5 = arith.constant 5 : index - // CHECK-NEXT: %2:2 = scf.for %arg1 = %c0 to %c5 step %c1 iter_args(%arg2 = %arg0, %arg3 = %1) -> (tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64>) { - // CHECK-NEXT: %3 = tensor.extract %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: %4 = tensor.extract %arg3[%arg1] : tensor<5xi64> - // CHECK-NEXT: %5 = "TFHE.sub_int_glwe"(%4, %3) : (i64, !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> - // CHECK-NEXT: %6 = tensor.insert %5 into %arg2[%arg1] : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - // CHECK-NEXT: scf.yield %6, %arg3 : tensor<5x!TFHE.glwe<{_,_,_}{7}>>, tensor<5xi64> - // CHECK-NEXT: } - // CHECK-NEXT: return %2#0 : tensor<5x!TFHE.glwe<{_,_,_}{7}>> - %0 = arith.constant 1 : i8 %1 = "FHE.sub_int_eint"(%0, %arg0): (i8, !FHE.eint<7>) -> (!FHE.eint<7>) return %1: !FHE.eint<7>