From 670af021125769eb6b4fc68f9f6912554db7f891 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 9 Jun 2022 10:14:16 +0100 Subject: [PATCH] fix: remove subinteint folder it actually requires to negate the result which can't be done via standard folder, so we remove it can cause erroneous computation --- .../concretelang/Dialect/FHE/IR/FHEOps.td | 2 -- .../Dialect/FHELinalg/IR/FHELinalgOps.td | 2 -- compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 13 --------- .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 14 ---------- compiler/tests/Dialect/FHE/FHE/folding.mlir | 9 ------- .../Dialect/FHELinalg/FHELinalg/folding.mlir | 27 ------------------- 6 files changed, 67 deletions(-) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index c5c923ed0..ee4bcf7cb 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -153,8 +153,6 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> { let verifier = [{ return ::mlir::concretelang::FHE::verifySubIntEintOp(*this); }]; - - let hasFolder = 1; } def NegEintOp : FHE_Op<"neg_eint"> { diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index c7a649d49..d7f3a23d1 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -180,8 +180,6 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor build($_builder, $_state, lhs.getType(), rhs, lhs); }]> ]; - - let hasFolder = 1; } def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> { diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index b0bed8e45..9bf86054f 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -139,19 +139,6 @@ OpFoldResult AddEintIntOp::fold(ArrayRef operands) { return nullptr; } -// Avoid subtraction with constant 0 -OpFoldResult SubIntEintOp::fold(ArrayRef operands) { - assert(operands.size() == 2); - auto toSub = operands[0].dyn_cast_or_null(); - if (toSub != nullptr) { - auto intToSub = toSub.getInt(); - if (intToSub == 0) { - return getOperand(1); - } - } - return nullptr; -} - // Avoid multiplication with constant 1 OpFoldResult MulEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index cd77499e8..f45abefce 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1655,20 +1655,6 @@ OpFoldResult AddEintIntOp::fold(ArrayRef operands) { return getOperand(0); } -// Avoid subtraction with constant tensor of 0s -OpFoldResult SubIntEintOp::fold(ArrayRef operands) { - assert(operands.size() == 2); - auto toSub = operands[0].dyn_cast_or_null(); - if (toSub == nullptr) - return nullptr; - for (int64_t i = 0; i < toSub.size(); i++) { - llvm::APInt cst = toSub.getFlatValue(i); - if (cst != 0) - return nullptr; - } - return getOperand(1); -} - // Avoid multiplication with constant tensor of 1s OpFoldResult MulEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/tests/Dialect/FHE/FHE/folding.mlir b/compiler/tests/Dialect/FHE/FHE/folding.mlir index a01a3ca46..796abb71b 100644 --- a/compiler/tests/Dialect/FHE/FHE/folding.mlir +++ b/compiler/tests/Dialect/FHE/FHE/folding.mlir @@ -9,15 +9,6 @@ func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } -// CHECK-LABEL: func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> -func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK-NEXT: return %arg0 : !FHE.eint<2> - - %0 = arith.constant 0 : i3 - %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>) - return %1: !FHE.eint<2> -} - // CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: return %arg0 : !FHE.eint<2> diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir index a504a6d40..ff600ddb7 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir @@ -27,33 +27,6 @@ func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FH return %1: tensor<4x3x!FHE.eint<2>> } -// CHECK: func @sub_int_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { -// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> -// CHECK-NEXT: } -func @sub_int_eint_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { - %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3> - %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<4xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> - return %1: tensor<4x!FHE.eint<2>> -} - -// CHECK: func @sub_int_eint_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { -// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> -// CHECK-NEXT: } -func @sub_int_eint_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { - %a1 = arith.constant dense<[0]> : tensor<1xi3> - %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> - return %1: tensor<4x!FHE.eint<2>> -} - -// CHECK: func @sub_int_eint_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> { -// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>> -// CHECK-NEXT: } -func @sub_int_eint_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> { - %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3> - %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1x1xi3>, tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> - return %1: tensor<4x3x!FHE.eint<2>> -} - // CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { // CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> // CHECK-NEXT: }