mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: fold add/sub/mul with specific constant values
Remove add/sub with constant zero values Remove mul with constant one values
This commit is contained in:
@@ -84,6 +84,8 @@ def AddEintIntOp : FHE_Op<"add_eint_int"> {
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::FHE::verifyAddEintIntOp(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def AddEintOp : FHE_Op<"add_eint"> {
|
||||
@@ -151,6 +153,8 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> {
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::FHE::verifySubIntEintOp(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def NegEintOp : FHE_Op<"neg_eint"> {
|
||||
@@ -217,6 +221,8 @@ def MulEintIntOp : FHE_Op<"mul_eint_int"> {
|
||||
let verifier = [{
|
||||
return ::mlir::concretelang::FHE::verifyMulEintIntOp(*this);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {
|
||||
|
||||
@@ -126,6 +126,45 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Avoid addition with constant 0
|
||||
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toAdd = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (toAdd != nullptr) {
|
||||
auto intToAdd = toAdd.getInt();
|
||||
if (intToAdd == 0) {
|
||||
return getOperand(0);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Avoid subtraction with constant 0
|
||||
OpFoldResult SubIntEintOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toSub = operands[0].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (toSub != nullptr) {
|
||||
auto intToSub = toSub.getInt();
|
||||
if (intToSub == 0) {
|
||||
return getOperand(1);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Avoid multiplication with constant 1
|
||||
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toMul = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (toMul != nullptr) {
|
||||
auto intToMul = toMul.getInt();
|
||||
if (intToMul == 1) {
|
||||
return getOperand(0);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace FHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
func @mul_eint_int(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i8
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 2 : i8
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.mul_glwe_int"(%arg0, %[[V1]]) : (!TFHE.glwe<{_,_,_}{7}>, i8) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
|
||||
%0 = arith.constant 1 : i8
|
||||
%0 = arith.constant 2 : i8
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<7>, i8) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
|
||||
28
compiler/tests/Dialect/FHE/FHE/folding.mlir
Normal file
28
compiler/tests/Dialect/FHE/FHE/folding.mlir
Normal file
@@ -0,0 +1,28 @@
|
||||
// RUN: concretecompiler --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 0 : i3
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 0 : i3
|
||||
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 1 : i3
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
Reference in New Issue
Block a user