diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 8a0b2e718..d0954832c 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -26,7 +26,7 @@ def TensorBinaryEint : NativeOpTrait<"TensorBinaryEint">; def TensorUnaryEint : NativeOpTrait<"TensorUnaryEint">; -def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ @@ -81,7 +81,7 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRul let hasFolder = 1; } -def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers."; let description = [{ @@ -133,7 +133,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten ]; } -def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> { +def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [Pure, TensorBroadcastingRules, TensorBinaryIntEint]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers."; let description = [{ @@ -186,7 +186,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul ]; } -def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers."; let description = [{ @@ -242,7 +242,7 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRul let hasFolder = 1; } -def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers."; let description = [{ @@ -294,7 +294,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten ]; } -def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> { +def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [Pure, TensorUnaryEint]> { let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers."; let description = [{ @@ -326,7 +326,7 @@ def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> { ]; } -def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { +def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [Pure, TensorBroadcastingRules, TensorBinaryEintInt]> { let summary = "Returns a tensor that contains the multiplication of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ @@ -377,7 +377,7 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul let hasCanonicalizer = 1; } -def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [TensorBroadcastingRules, TensorBinaryEint]> { +def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [Pure, TensorBroadcastingRules, TensorBinaryEint]> { let summary = "Returns a tensor that contains the multiplication of two tensor of encrypted integers."; let description = [{ @@ -429,7 +429,7 @@ def FHELinalg_MulEintOp : FHELinalg_Op<"mul_eint", [TensorBroadcastingRules, Ten ]; } -def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> { +def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure]> { let summary = "Returns a tensor that contains the result of the lookup on a table."; let description = [{ @@ -465,7 +465,7 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> { let hasVerifier = 1; } -def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", []> { +def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure]> { let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element."; let description = [{ @@ -512,7 +512,7 @@ def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_tab let hasVerifier = 1; } -def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", []> { +def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_table", [Pure]> { let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map."; let description = [{ @@ -566,7 +566,7 @@ def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_t let hasVerifier = 1; } -def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> { +def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure]> { let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers."; let description = [{ @@ -590,7 +590,7 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> { } -def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint"> { +def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint", [Pure]> { let summary = "Returns the encrypted dot product between two vectors of encrypted integers."; let description = [{ @@ -614,7 +614,7 @@ def FHELinalg_DotEint : FHELinalg_Op<"dot_eint_eint"> { } -def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> { +def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [Pure, TensorBinaryEintInt]> { let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a matrix of clear integers."; let description = [{ @@ -759,7 +759,7 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEin }]; } -def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> { +def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [Pure, TensorBinaryIntEint]> { let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers."; let description = [{ @@ -905,7 +905,7 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt } -def FHELinalg_MatMulEintEintOp : FHELinalg_Op<"matmul_eint_eint", [TensorBinaryEint]> { +def FHELinalg_MatMulEintEintOp : FHELinalg_Op<"matmul_eint_eint", [Pure, TensorBinaryEint]> { let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of encrypted integers and a second matrix of encrypted integers."; let description = [{ @@ -1042,7 +1042,7 @@ def FHELinalg_MatMulEintEintOp : FHELinalg_Op<"matmul_eint_eint", [TensorBinaryE let hasVerifier = 1; } -def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> { +def FHELinalg_SumOp : FHELinalg_Op<"sum", [Pure, TensorUnaryEint]> { let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes."; let description = [{ @@ -1125,7 +1125,7 @@ def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> { let hasVerifier = 1; } -def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> { +def FHELinalg_ConcatOp : FHELinalg_Op<"concat", [Pure]> { let summary = "Concatenates a sequence of tensors along an existing axis."; let description = [{ @@ -1167,7 +1167,7 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> { let hasVerifier = 1; } -def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { +def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", [Pure]> { let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW"; let arguments = (ins Type.predicate, HasStaticShapePred]>>:$input, @@ -1183,7 +1183,7 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { let hasVerifier = 1; } -def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> { +def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", [Pure]> { let summary = "Returns the 2D maxpool of a tensor in the form NCHW"; let arguments = (ins Type.predicate, HasStaticShapePred]>>:$input, @@ -1195,7 +1195,7 @@ def FHELinalg_Maxpool2dOp : FHELinalg_Op<"maxpool2d", []> { let hasVerifier = 1; } -def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { +def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", [Pure]> { let summary = "Returns a tensor that contains the transposition of the input tensor."; let description = [{ @@ -1234,7 +1234,7 @@ def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { let hasVerifier = 1; } -def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", []> { +def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", [Pure]> { let summary = "Creates a tensor with a single element."; let description = [{ @@ -1251,7 +1251,7 @@ def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", []> { let hasVerifier = 1; } -def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", []> { +def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> { let summary = "Cast an unsigned integer tensor to a signed one"; let description = [{ @@ -1281,7 +1281,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", []> { let hasVerifier = 1; } -def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", []> { +def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> { let summary = "Cast a signed integer tensor to an unsigned one"; let description = [{ @@ -1311,7 +1311,7 @@ def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", []> { let hasVerifier = 1; } -def FHELinalg_RoundOp : FHELinalg_Op<"round", [TensorUnaryEint]> { +def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint]> { let summary = "Rounds a tensor of ciphertexts into a smaller precision."; let description = [{ diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index 199f48b8e..e48c3d9f9 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -379,7 +379,7 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso // ----- -func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> tensor<4x3x!FHE.eint<7>> { +func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> (tensor<4x!FHE.eint<7>>, tensor<4x2x!FHE.eint<7>>, tensor<5x4x2x!FHE.eint<7>>, tensor<2x5x4x2x!FHE.eint<7>>) { // =============================== @@ -549,12 +549,12 @@ func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> tensor<4x3x!FHE. // =============================== - return %0 : tensor<4x3x!FHE.eint<7>> + return %2, %4, %6, %8 : tensor<4x!FHE.eint<7>>, tensor<4x2x!FHE.eint<7>>, tensor<5x4x2x!FHE.eint<7>>, tensor<2x5x4x2x!FHE.eint<7>> } // ----- -func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<7>> { +func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x!FHE.eint<7>> { // CHECK: {MANP = 1 : ui{{[0-9]+}}} %z = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>> %a = arith.constant dense<[[4, 6, 5], [2, 6, 3], [5, 6, 1], [5, 5, 3]]> : tensor<4x3xi8> @@ -574,7 +574,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint< // =============================== - return %0 : tensor<4x3x!FHE.eint<7>> + return %2 : tensor<4x!FHE.eint<7>> } ///////////////////////////////////////////////// @@ -662,7 +662,7 @@ func.func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso // ----- -func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { +func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> (tensor<2x!FHE.eint<7>>, tensor<2x2x!FHE.eint<7>>, tensor<5x2x2x!FHE.eint<7>>, tensor<2x5x2x2x!FHE.eint<7>>) { // =============================== @@ -816,12 +816,12 @@ func.func @matmul_int_eint_cst(%0: tensor<3x2x!FHE.eint<7>>) -> tensor<3x2x!FHE. // =============================== - return %0 : tensor<3x2x!FHE.eint<7>> + return %2, %4, %6, %8 : tensor<2x!FHE.eint<7>>, tensor<2x2x!FHE.eint<7>>, tensor<5x2x2x!FHE.eint<7>>, tensor<2x5x2x2x!FHE.eint<7>> } // ----- -func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<3x2x!FHE.eint<7>> { +func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<2x!FHE.eint<7>> { %z = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>> %a = arith.constant dense<[[4, 6], [2, 6], [5, 6]]> : tensor<3x2xi8> @@ -840,190 +840,10 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<3x2x!FHE.eint< // =============================== - return %0 : tensor<3x2x!FHE.eint<7>> + return %2 : tensor<2x!FHE.eint<7>> } // ----- -func.func @sum() -> !FHE.eint<7> { - %0 = "FHE.zero_tensor"() : () -> tensor<5x3x4x2x!FHE.eint<7>> - - // CHECK: MANP = 11 : ui{{[0-9]+}} - %1 = "FHELinalg.sum"(%0) : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %2 = "FHELinalg.sum"(%0) { axes = [0] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %3 = "FHELinalg.sum"(%0) { axes = [1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %4 = "FHELinalg.sum"(%0) { axes = [2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %5 = "FHELinalg.sum"(%0) { axes = [3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %6 = "FHELinalg.sum"(%0) { axes = [0, 1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x2x!FHE.eint<7>> - - // CHECK: MANP = 5 : ui{{[0-9]+}} - %7 = "FHELinalg.sum"(%0) { axes = [0, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %8 = "FHELinalg.sum"(%0) { axes = [0, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %9 = "FHELinalg.sum"(%0) { axes = [1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x2x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %10 = "FHELinalg.sum"(%0) { axes = [1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %11 = "FHELinalg.sum"(%0) { axes = [2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> - - // CHECK: MANP = 8 : ui{{[0-9]+}} - %12 = "FHELinalg.sum"(%0) { axes = [0, 1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> - - // CHECK: MANP = 6 : ui{{[0-9]+}} - %13 = "FHELinalg.sum"(%0) { axes = [0, 1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> - - // CHECK: MANP = 7 : ui{{[0-9]+}} - %14 = "FHELinalg.sum"(%0) { axes = [0, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> - - // CHECK: MANP = 5 : ui{{[0-9]+}} - %15 = "FHELinalg.sum"(%0) { axes = [1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x!FHE.eint<7>> - - // CHECK: MANP = 11 : ui{{[0-9]+}} - %16 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> - - // CHECK: MANP = 11 : ui{{[0-9]+}} - %17 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %18 = "FHELinalg.sum"(%0) { axes = [0], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %19 = "FHELinalg.sum"(%0) { axes = [1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %20 = "FHELinalg.sum"(%0) { axes = [2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x2x!FHE.eint<7>> - - // CHECK: MANP = 2 : ui{{[0-9]+}} - %21 = "FHELinalg.sum"(%0) { axes = [3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x1x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %22 = "FHELinalg.sum"(%0) { axes = [0, 1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x2x!FHE.eint<7>> - - // CHECK: MANP = 5 : ui{{[0-9]+}} - %23 = "FHELinalg.sum"(%0) { axes = [0, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x2x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %24 = "FHELinalg.sum"(%0) { axes = [0, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x1x!FHE.eint<7>> - - // CHECK: MANP = 4 : ui{{[0-9]+}} - %25 = "FHELinalg.sum"(%0) { axes = [1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x2x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %26 = "FHELinalg.sum"(%0) { axes = [1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x1x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %27 = "FHELinalg.sum"(%0) { axes = [2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x1x!FHE.eint<7>> - - // CHECK: MANP = 8 : ui{{[0-9]+}} - %28 = "FHELinalg.sum"(%0) { axes = [0, 1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x2x!FHE.eint<7>> - - // CHECK: MANP = 6 : ui{{[0-9]+}} - %29 = "FHELinalg.sum"(%0) { axes = [0, 1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x1x!FHE.eint<7>> - - // CHECK: MANP = 7 : ui{{[0-9]+}} - %30 = "FHELinalg.sum"(%0) { axes = [0, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x1x!FHE.eint<7>> - - // CHECK: MANP = 5 : ui{{[0-9]+}} - %31 = "FHELinalg.sum"(%0) { axes = [1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x1x!FHE.eint<7>> - - // CHECK: MANP = 11 : ui{{[0-9]+}} - %32 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> - - // =============================== - - %35 = "FHE.zero_tensor"() : () -> tensor<2x0x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %36 = "FHELinalg.sum"(%35) : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %37 = "FHELinalg.sum"(%35) { axes = [0] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %38 = "FHELinalg.sum"(%35) { axes = [1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %39 = "FHELinalg.sum"(%35) { axes = [2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %40 = "FHELinalg.sum"(%35) { axes = [0, 1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %41 = "FHELinalg.sum"(%35) { axes = [0, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %42 = "FHELinalg.sum"(%35) { axes = [1, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %43 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2] } : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %44 = "FHELinalg.sum"(%35) { keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %45 = "FHELinalg.sum"(%35) { axes = [0], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %46 = "FHELinalg.sum"(%35) { axes = [1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %47 = "FHELinalg.sum"(%35) { axes = [2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x1x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %48 = "FHELinalg.sum"(%35) { axes = [0, 1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x3x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %49 = "FHELinalg.sum"(%35) { axes = [0, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x1x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %50 = "FHELinalg.sum"(%35) { axes = [1, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x1x!FHE.eint<7>> - - // CHECK: MANP = 1 : ui{{[0-9]+}} - %51 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> - - return %1 : !FHE.eint<7> -} - -// ----- - -func.func @concat() -> tensor<3x!FHE.eint<7>> { - %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> - // CHECK: MANP = 2 : ui{{[0-9]+}} - %1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> - - %2 = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<7>> - // CHECK: MANP = 3 : ui{{[0-9]+}} - %3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> - - %4 = "FHE.zero_tensor"() : () -> tensor<10x!FHE.eint<7>> - // CHECK: MANP = 4 : ui{{[0-9]+}} - %5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> - - // CHECK: MANP = 3 : ui{{[0-9]+}} - %6 = "FHELinalg.concat"(%1, %3) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> - // CHECK: MANP = 4 : ui{{[0-9]+}} - %7 = "FHELinalg.concat"(%1, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> - // CHECK: MANP = 4 : ui{{[0-9]+}} - %8 = "FHELinalg.concat"(%3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> - // CHECK: MANP = 4 : ui{{[0-9]+}} - %9 = "FHELinalg.concat"(%1, %3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> - - return %9 : tensor<3x!FHE.eint<7>> -} diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir new file mode 100644 index 000000000..61dcb304e --- /dev/null +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir @@ -0,0 +1,182 @@ +// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s + +func.func @sum() -> !FHE.eint<7> { + %0 = "FHE.zero_tensor"() : () -> tensor<5x3x4x2x!FHE.eint<7>> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %1 = "FHELinalg.sum"(%0) : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %2 = "FHELinalg.sum"(%0) { axes = [0] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %3 = "FHELinalg.sum"(%0) { axes = [1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %4 = "FHELinalg.sum"(%0) { axes = [2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %5 = "FHELinalg.sum"(%0) { axes = [3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %6 = "FHELinalg.sum"(%0) { axes = [0, 1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x2x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %7 = "FHELinalg.sum"(%0) { axes = [0, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %8 = "FHELinalg.sum"(%0) { axes = [0, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %9 = "FHELinalg.sum"(%0) { axes = [1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x2x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %10 = "FHELinalg.sum"(%0) { axes = [1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %11 = "FHELinalg.sum"(%0) { axes = [2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> + + // CHECK: MANP = 8 : ui{{[0-9]+}} + %12 = "FHELinalg.sum"(%0) { axes = [0, 1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + + // CHECK: MANP = 6 : ui{{[0-9]+}} + %13 = "FHELinalg.sum"(%0) { axes = [0, 1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + + // CHECK: MANP = 7 : ui{{[0-9]+}} + %14 = "FHELinalg.sum"(%0) { axes = [0, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %15 = "FHELinalg.sum"(%0) { axes = [1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x!FHE.eint<7>> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %16 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %17 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %18 = "FHELinalg.sum"(%0) { axes = [0], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %19 = "FHELinalg.sum"(%0) { axes = [1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %20 = "FHELinalg.sum"(%0) { axes = [2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %21 = "FHELinalg.sum"(%0) { axes = [3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x1x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %22 = "FHELinalg.sum"(%0) { axes = [0, 1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x2x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %23 = "FHELinalg.sum"(%0) { axes = [0, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x2x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %24 = "FHELinalg.sum"(%0) { axes = [0, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x1x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %25 = "FHELinalg.sum"(%0) { axes = [1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x2x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %26 = "FHELinalg.sum"(%0) { axes = [1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x1x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %27 = "FHELinalg.sum"(%0) { axes = [2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x1x!FHE.eint<7>> + + // CHECK: MANP = 8 : ui{{[0-9]+}} + %28 = "FHELinalg.sum"(%0) { axes = [0, 1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x2x!FHE.eint<7>> + + // CHECK: MANP = 6 : ui{{[0-9]+}} + %29 = "FHELinalg.sum"(%0) { axes = [0, 1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x1x!FHE.eint<7>> + + // CHECK: MANP = 7 : ui{{[0-9]+}} + %30 = "FHELinalg.sum"(%0) { axes = [0, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x1x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %31 = "FHELinalg.sum"(%0) { axes = [1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %32 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> + + // =============================== + + %35 = "FHE.zero_tensor"() : () -> tensor<2x0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %36 = "FHELinalg.sum"(%35) : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %37 = "FHELinalg.sum"(%35) { axes = [0] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %38 = "FHELinalg.sum"(%35) { axes = [1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %39 = "FHELinalg.sum"(%35) { axes = [2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %40 = "FHELinalg.sum"(%35) { axes = [0, 1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %41 = "FHELinalg.sum"(%35) { axes = [0, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %42 = "FHELinalg.sum"(%35) { axes = [1, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %43 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2] } : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %44 = "FHELinalg.sum"(%35) { keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %45 = "FHELinalg.sum"(%35) { axes = [0], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %46 = "FHELinalg.sum"(%35) { axes = [1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %47 = "FHELinalg.sum"(%35) { axes = [2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %48 = "FHELinalg.sum"(%35) { axes = [0, 1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %49 = "FHELinalg.sum"(%35) { axes = [0, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %50 = "FHELinalg.sum"(%35) { axes = [1, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %51 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + + return %1 : !FHE.eint<7> +} + +// ----- + +func.func @concat() -> tensor<3x!FHE.eint<7>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> + // CHECK: MANP = 2 : ui{{[0-9]+}} + %1 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + + %2 = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<7>> + // CHECK: MANP = 3 : ui{{[0-9]+}} + %3 = "FHELinalg.sum"(%2) { keep_dims = true } : (tensor<5x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + + %4 = "FHE.zero_tensor"() : () -> tensor<10x!FHE.eint<7>> + // CHECK: MANP = 4 : ui{{[0-9]+}} + %5 = "FHELinalg.sum"(%4) { keep_dims = true } : (tensor<10x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %6 = "FHELinalg.concat"(%1, %3) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + // CHECK: MANP = 4 : ui{{[0-9]+}} + %7 = "FHELinalg.concat"(%1, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + // CHECK: MANP = 4 : ui{{[0-9]+}} + %8 = "FHELinalg.concat"(%3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + // CHECK: MANP = 4 : ui{{[0-9]+}} + %9 = "FHELinalg.concat"(%1, %3, %5) : (tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>, tensor<1x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + + return %9 : tensor<3x!FHE.eint<7>> +}