feat: fold add/sub/mul with specific constant values

Remove add/sub with constant zero values
Remove mul with constant one values
This commit is contained in:
youben11
2022-05-18 11:36:58 +01:00
committed by Ayoub Benaissa
parent b052157fae
commit 8aa6f3e809
4 changed files with 75 additions and 2 deletions

View File

@@ -84,6 +84,8 @@ def AddEintIntOp : FHE_Op<"add_eint_int"> {
let verifier = [{
return ::mlir::concretelang::FHE::verifyAddEintIntOp(*this);
}];
let hasFolder = 1;
}
def AddEintOp : FHE_Op<"add_eint"> {
@@ -151,6 +153,8 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> {
let verifier = [{
return ::mlir::concretelang::FHE::verifySubIntEintOp(*this);
}];
let hasFolder = 1;
}
def NegEintOp : FHE_Op<"neg_eint"> {
@@ -217,6 +221,8 @@ def MulEintIntOp : FHE_Op<"mul_eint_int"> {
let verifier = [{
return ::mlir::concretelang::FHE::verifyMulEintIntOp(*this);
}];
let hasFolder = 1;
}
def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {

View File

@@ -126,6 +126,45 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
return mlir::success();
}
// Avoid addition with constant 0
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toAdd = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
if (toAdd != nullptr) {
auto intToAdd = toAdd.getInt();
if (intToAdd == 0) {
return getOperand(0);
}
}
return nullptr;
}
// Avoid subtraction with constant 0
OpFoldResult SubIntEintOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toSub = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
if (toSub != nullptr) {
auto intToSub = toSub.getInt();
if (intToSub == 0) {
return getOperand(1);
}
}
return nullptr;
}
// Avoid multiplication with constant 1
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);
auto toMul = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
if (toMul != nullptr) {
auto intToMul = toMul.getInt();
if (intToMul == 1) {
return getOperand(0);
}
}
return nullptr;
}
} // namespace FHE
} // namespace concretelang
} // namespace mlir

View File

@@ -2,11 +2,11 @@
// CHECK-LABEL: func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
// CHECK-NEXT: %[[V1:.*]] = arith.constant 2 : i8
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
%0 = arith.constant 1 : i8
%0 = arith.constant 2 : i8
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
return %1: !FHE.eint<7>
}

View File

@@ -0,0 +1,28 @@
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
%0 = arith.constant 0 : i3
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
%0 = arith.constant 0 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
%0 = arith.constant 1 : i3
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}