feat(compiler): Add canonicalization of FHE/FHELinalg to_signed to_unsigned ops

This commit is contained in:
Bourgerie Quentin
2023-06-23 11:35:53 +02:00
committed by Quentin Bourgerie
parent 5a80e22cf3
commit 5147ac8418
6 changed files with 106 additions and 0 deletions

View File

@@ -382,6 +382,7 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
let results = (outs FHE_EncryptedSignedIntegerType);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
@@ -407,6 +408,7 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
let results = (outs FHE_EncryptedIntegerType);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {

View File

@@ -1279,6 +1279,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> {
);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> {
@@ -1309,6 +1310,7 @@ def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> {
);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint]> {

View File

@@ -394,6 +394,39 @@ void MulEintIntOp::getCanonicalizationPatterns(
patterns.add<ZeroEncOpPattern>(context);
}
template <typename SignedConvOp>
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {
// Replace to_signed of zero to signed zero
class ZeroOpPattern : public mlir::OpRewritePattern<SignedConvOp> {
public:
ZeroOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<SignedConvOp>(context, 0) {}
mlir::LogicalResult
matchAndRewrite(SignedConvOp op,
mlir::PatternRewriter &rewriter) const override {
auto cstOp = op.getInput().template getDefiningOp<FHE::ZeroEintOp>();
if (cstOp == nullptr)
return mlir::failure();
rewriter.replaceOpWithNewOp<FHE::ZeroEintOp>(op,
op.getResult().getType());
return mlir::success();
}
};
patterns.add<ZeroOpPattern>(context);
}
void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {
getSignedConvCanonicalizationPatterns<ToSignedOp>(patterns, context);
}
void ToUnsignedOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
}
} // namespace FHE
} // namespace concretelang
} // namespace mlir

View File

@@ -1602,6 +1602,39 @@ void MatMulEintIntOp::getCanonicalizationPatterns(
getMatMulCanonicalizationPatterns<MatMulEintIntOp>(patterns, context);
}
template <typename SignedConvOp>
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {
// Replace to_signed of zero to signed zero
class ZeroOpPattern : public mlir::OpRewritePattern<SignedConvOp> {
public:
ZeroOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<SignedConvOp>(context, 0) {}
mlir::LogicalResult
matchAndRewrite(SignedConvOp op,
mlir::PatternRewriter &rewriter) const override {
auto cstOp = op.getInput().template getDefiningOp<FHE::ZeroTensorOp>();
if (cstOp == nullptr)
return mlir::failure();
rewriter.replaceOpWithNewOp<FHE::ZeroTensorOp>(op,
op.getResult().getType());
return mlir::success();
}
};
patterns.add<ZeroOpPattern>(context);
}
void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {
getSignedConvCanonicalizationPatterns<ToSignedOp>(patterns, context);
}
void ToUnsignedOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
}
} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir

View File

@@ -54,3 +54,21 @@ func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
%1 = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<5>
return %1: !FHE.eint<5>
}
// CHECK: func.func @to_signed_zero() -> !FHE.esint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
// CHECK-NEXT: return %[[v0]]
func.func @to_signed_zero() -> !FHE.esint<7> {
%0 = "FHE.zero"() : () -> !FHE.eint<7>
%1 = "FHE.to_signed"(%0) : (!FHE.eint<7>) -> !FHE.esint<7>
return %1 : !FHE.esint<7>
}
// CHECK: func.func @to_unsigned_zero() -> !FHE.eint<7> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
// CHECK-NEXT: return %[[v0]]
func.func @to_unsigned_zero() -> !FHE.eint<7> {
%0 = "FHE.zero"() : () -> !FHE.esint<7>
%1 = "FHE.to_unsigned"(%0) : (!FHE.esint<7>) -> !FHE.eint<7>
return %1 : !FHE.eint<7>
}

View File

@@ -142,3 +142,21 @@ func.func @matmul_int_eint_encrypted_zero(%x: tensor<4x3xi3>) -> tensor<4x2x!FHE
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
return %0 : tensor<4x2x!FHE.eint<2>>
}
// CHECK: func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
// CHECK-NEXT: return %[[v0]]
func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> {
%0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>>
%1 = "FHELinalg.to_signed"(%0) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.esint<7>>
return %1 : tensor<4x!FHE.esint<7>>
}
// CHECK: func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
// CHECK-NEXT: return %[[v0]]
func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> {
%0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<7>>
%1 = "FHELinalg.to_unsigned"(%0) : (tensor<4x!FHE.esint<7>>) -> tensor<4x!FHE.eint<7>>
return %1 : tensor<4x!FHE.eint<7>>
}