diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
index f758b24f1..c7a649d49 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
@@ -73,6 +73,8 @@ def AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRules, Tensor
       build($_builder, $_state, rhs.getType(), rhs, lhs);
     }]>
   ];
+
+  let hasFolder = 1;
 }
 
 def AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
@@ -178,6 +180,8 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor
       build($_builder, $_state, lhs.getType(), rhs, lhs);
     }]>
   ];
+
+  let hasFolder = 1;
 }
 
 def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
@@ -257,6 +261,8 @@ def MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tensor
   );
 
   let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
+
+  let hasFolder = 1;
 }
 
 def ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
index e41d6f0d8..754db4878 100644
--- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
+++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
@@ -1641,6 +1641,48 @@ mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
   return mlir::success();
 }
 
+// Avoid addition with constant tensor of 0s
+OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2);
+  auto toAdd = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (toAdd == nullptr)
+    return nullptr;
+  for (int64_t i = 0; i < toAdd.size(); i++) {
+    llvm::APInt cst = toAdd.getFlatValue<llvm::APInt>(i);
+    if (cst != 0)
+      return nullptr;
+  }
+  return getOperand(0);
+}
+
+// Avoid subtraction with constant tensor of 0s
+OpFoldResult SubIntEintOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2);
+  auto toSub = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (toSub == nullptr)
+    return nullptr;
+  for (int64_t i = 0; i < toSub.size(); i++) {
+    llvm::APInt cst = toSub.getFlatValue<llvm::APInt>(i);
+    if (cst != 0)
+      return nullptr;
+  }
+  return getOperand(1);
+}
+
+// Avoid multiplication with constant tensor of 1s
+OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2);
+  auto toMul = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (toMul == nullptr)
+    return nullptr;
+  for (int64_t i = 0; i < toMul.size(); i++) {
+    llvm::APInt cst = toMul.getFlatValue<llvm::APInt>(i);
+    if (cst != 1)
+      return nullptr;
+  }
+  return getOperand(0);
+}
+
 } // namespace FHELinalg
 } // namespace concretelang
 } // namespace mlir
diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir
new file mode 100644
index 000000000..a504a6d40
--- /dev/null
+++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir
@@ -0,0 +1,82 @@
+// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
+
+// CHECK: func @add_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
+  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @add_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[0]> : tensor<1xi3>
+  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @add_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
+  %1 = "FHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
+  return %1: tensor<4x3x!FHE.eint<2>>
+}
+
+// CHECK: func @sub_int_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @sub_int_eint_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
+  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<4xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @sub_int_eint_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @sub_int_eint_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[0]> : tensor<1xi3>
+  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1xi3>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @sub_int_eint_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @sub_int_eint_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
+  %1 = "FHELinalg.sub_int_eint"(%a1, %a0) : (tensor<1x1xi3>, tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>>
+  return %1: tensor<4x3x!FHE.eint<2>>
+}
+
+// CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @mul_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi3>
+  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @mul_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @mul_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[1]> : tensor<1xi3>
+  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
+  return %1: tensor<4x!FHE.eint<2>>
+}
+
+// CHECK: func @mul_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
+// CHECK-NEXT: }
+func @mul_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
+  %a1 = arith.constant dense<[[1]]> : tensor<1x1xi3>
+  %1 = "FHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
+  return %1: tensor<4x3x!FHE.eint<2>>
+}
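
Note on how the folds are wired up, for reviewers less familiar with MLIR: `let hasFolder = 1` makes ODS declare a `fold(ArrayRef<Attribute>)` hook on the op, which the folder and the canonicalizer invoke with one entry per operand holding that operand's constant value as an attribute, or null if the operand is not a compile-time constant. That is why the `dyn_cast_or_null` above also rejects the non-constant case, and why returning `nullptr` simply means "no fold" while returning an existing `Value` replaces the op's result.

The three hooks share the same "constant operand is all-N" scan. Below is a minimal sketch of how that check could be factored out, assuming the same MLIR C++ API the patch already uses; `isSplatOf` is a hypothetical helper name introduced here for illustration, not part of the patch:

#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/STLExtras.h"

// Returns true iff `attr` is a dense integer constant whose elements all
// equal `value`. Fold hooks receive a null attribute for operands that are
// not compile-time constants, so that case is rejected here as well.
static bool isSplatOf(mlir::Attribute attr, uint64_t value) {
  auto elements = attr.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (!elements)
    return false;
  // Fast path: a splat attribute stores a single value for every element,
  // which covers broadcast constants such as dense<[[0]]> above.
  if (elements.isSplat())
    return elements.getSplatValue<llvm::APInt>() == value;
  // General path: inspect each element.
  return llvm::all_of(elements.getValues<llvm::APInt>(),
                      [&](const llvm::APInt &v) { return v == value; });
}

With such a helper, each hook collapses to a one-liner, e.g. `return isSplatOf(operands[1], 0) ? getOperand(0) : nullptr;` for `AddEintIntOp`.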