From 8aa6f3e809ed39e32558664c243536b6159aff1d Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 18 May 2022 11:36:58 +0100 Subject: [PATCH] feat: fold add/sub/mul with specific constant values Remove add/sub with constant zero values Remove mul with constant one values --- .../concretelang/Dialect/FHE/IR/FHEOps.td | 6 +++ compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 39 +++++++++++++++++++ .../FHEToTFHE/FHEToTFHE/mul_eint_int.mlir | 4 +- compiler/tests/Dialect/FHE/FHE/folding.mlir | 28 +++++++++++++ 4 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 compiler/tests/Dialect/FHE/FHE/folding.mlir diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 0665832de..c5c923ed0 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -84,6 +84,8 @@ def AddEintIntOp : FHE_Op<"add_eint_int"> { let verifier = [{ return ::mlir::concretelang::FHE::verifyAddEintIntOp(*this); }]; + + let hasFolder = 1; } def AddEintOp : FHE_Op<"add_eint"> { @@ -151,6 +153,8 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> { let verifier = [{ return ::mlir::concretelang::FHE::verifySubIntEintOp(*this); }]; + + let hasFolder = 1; } def NegEintOp : FHE_Op<"neg_eint"> { @@ -217,6 +221,8 @@ def MulEintIntOp : FHE_Op<"mul_eint_int"> { let verifier = [{ return ::mlir::concretelang::FHE::verifyMulEintIntOp(*this); }]; + + let hasFolder = 1; } def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> { diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index 8637773bd..c90af992d 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -126,6 +126,45 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, return mlir::success(); } +// Avoid addition with constant 0 +OpFoldResult AddEintIntOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto toAdd = operands[1].dyn_cast_or_null(); + if (toAdd != nullptr) { + auto intToAdd = toAdd.getInt(); + if (intToAdd == 0) { + return getOperand(0); + } + } + return nullptr; +} + +// Avoid subtraction with constant 0 +OpFoldResult SubIntEintOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto toSub = operands[0].dyn_cast_or_null(); + if (toSub != nullptr) { + auto intToSub = toSub.getInt(); + if (intToSub == 0) { + return getOperand(1); + } + } + return nullptr; +} + +// Avoid multiplication with constant 1 +OpFoldResult MulEintIntOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto toMul = operands[1].dyn_cast_or_null(); + if (toMul != nullptr) { + auto intToMul = toMul.getInt(); + if (intToMul == 1) { + return getOperand(0); + } + } + return nullptr; +} + } // namespace FHE } // namespace concretelang } // namespace mlir diff --git a/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/mul_eint_int.mlir b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/mul_eint_int.mlir index 8e63c56a4..dd45dee1f 100644 --- a/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/mul_eint_int.mlir @@ -2,11 +2,11 @@ // CHECK-LABEL: func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { - // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8 + // CHECK-NEXT: %[[V1:.*]] = arith.constant 2 : i8 // CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}> // CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}> - %0 = arith.constant 1 : i8 + %0 = arith.constant 2 : i8 %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>) return %1: !FHE.eint<7> } diff --git a/compiler/tests/Dialect/FHE/FHE/folding.mlir b/compiler/tests/Dialect/FHE/FHE/folding.mlir new file mode 100644 index 000000000..a01a3ca46 --- /dev/null +++ b/compiler/tests/Dialect/FHE/FHE/folding.mlir @@ -0,0 +1,28 @@ +// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: return %arg0 : !FHE.eint<2> + + %0 = arith.constant 0 : i3 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// CHECK-LABEL: func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: return %arg0 : !FHE.eint<2> + + %0 = arith.constant 0 : i3 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: return %arg0 : !FHE.eint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +}