mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 03:25:05 -05:00
feat(compiler): Add canonicalization of FHE/FHELinalg to_signed to_unsigned ops
This commit is contained in:
committed by
Quentin Bourgerie
parent
5a80e22cf3
commit
5147ac8418
@@ -382,6 +382,7 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> {
|
||||
let results = (outs FHE_EncryptedSignedIntegerType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
|
||||
@@ -407,6 +408,7 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> {
|
||||
let results = (outs FHE_EncryptedIntegerType);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> {
|
||||
|
||||
@@ -1279,6 +1279,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> {
|
||||
@@ -1309,6 +1310,7 @@ def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint]> {
|
||||
|
||||
@@ -394,6 +394,39 @@ void MulEintIntOp::getCanonicalizationPatterns(
|
||||
patterns.add<ZeroEncOpPattern>(context);
|
||||
}
|
||||
|
||||
template <typename SignedConvOp>
|
||||
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
// Replace to_signed of zero to signed zero
|
||||
class ZeroOpPattern : public mlir::OpRewritePattern<SignedConvOp> {
|
||||
public:
|
||||
ZeroOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<SignedConvOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(SignedConvOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getInput().template getDefiningOp<FHE::ZeroEintOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroEintOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<ZeroOpPattern>(context);
|
||||
}
|
||||
|
||||
void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
getSignedConvCanonicalizationPatterns<ToSignedOp>(patterns, context);
|
||||
}
|
||||
|
||||
void ToUnsignedOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
|
||||
}
|
||||
|
||||
} // namespace FHE
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1602,6 +1602,39 @@ void MatMulEintIntOp::getCanonicalizationPatterns(
|
||||
getMatMulCanonicalizationPatterns<MatMulEintIntOp>(patterns, context);
|
||||
}
|
||||
|
||||
template <typename SignedConvOp>
|
||||
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
// Replace to_signed of zero to signed zero
|
||||
class ZeroOpPattern : public mlir::OpRewritePattern<SignedConvOp> {
|
||||
public:
|
||||
ZeroOpPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<SignedConvOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(SignedConvOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cstOp = op.getInput().template getDefiningOp<FHE::ZeroTensorOp>();
|
||||
if (cstOp == nullptr)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<FHE::ZeroTensorOp>(op,
|
||||
op.getResult().getType());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<ZeroOpPattern>(context);
|
||||
}
|
||||
|
||||
void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
getSignedConvCanonicalizationPatterns<ToSignedOp>(patterns, context);
|
||||
}
|
||||
|
||||
void ToUnsignedOp::getCanonicalizationPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
|
||||
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
|
||||
}
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -54,3 +54,21 @@ func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%1 = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
|
||||
// CHECK: func.func @to_signed_zero() -> !FHE.esint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @to_signed_zero() -> !FHE.esint<7> {
|
||||
%0 = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
%1 = "FHE.to_signed"(%0) : (!FHE.eint<7>) -> !FHE.esint<7>
|
||||
return %1 : !FHE.esint<7>
|
||||
}
|
||||
|
||||
// CHECK: func.func @to_unsigned_zero() -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @to_unsigned_zero() -> !FHE.eint<7> {
|
||||
%0 = "FHE.zero"() : () -> !FHE.esint<7>
|
||||
%1 = "FHE.to_unsigned"(%0) : (!FHE.esint<7>) -> !FHE.eint<7>
|
||||
return %1 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
@@ -142,3 +142,21 @@ func.func @matmul_int_eint_encrypted_zero(%x: tensor<4x3xi3>) -> tensor<4x2x!FHE
|
||||
%0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>>
|
||||
return %0 : tensor<4x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> {
|
||||
%0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>>
|
||||
%1 = "FHELinalg.to_signed"(%0) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.esint<7>>
|
||||
return %1 : tensor<4x!FHE.esint<7>>
|
||||
}
|
||||
|
||||
// CHECK: func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"()
|
||||
// CHECK-NEXT: return %[[v0]]
|
||||
func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> {
|
||||
%0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<7>>
|
||||
%1 = "FHELinalg.to_unsigned"(%0) : (tensor<4x!FHE.esint<7>>) -> tensor<4x!FHE.eint<7>>
|
||||
return %1 : tensor<4x!FHE.eint<7>>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user