diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index d83406db2..3883f6da5 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -382,6 +382,7 @@ def FHE_ToSignedOp : FHE_Op<"to_signed", [Pure]> { let results = (outs FHE_EncryptedSignedIntegerType); let hasVerifier = 1; + let hasCanonicalizer = 1; } def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> { @@ -407,6 +408,7 @@ def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [Pure]> { let results = (outs FHE_EncryptedIntegerType); let hasVerifier = 1; + let hasCanonicalizer = 1; } def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure]> { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index d0954832c..31dce47d8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -1279,6 +1279,7 @@ def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", [Pure]> { ); let hasVerifier = 1; + let hasCanonicalizer = 1; } def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> { @@ -1309,6 +1310,7 @@ def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", [Pure]> { ); let hasVerifier = 1; + let hasCanonicalizer = 1; } def FHELinalg_RoundOp : FHELinalg_Op<"round", [Pure, TensorUnaryEint]> { diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index e72a64d86..c0064c327 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -394,6 +394,39 @@ void MulEintIntOp::getCanonicalizationPatterns( patterns.add(context); } +template +void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + // Replace to_signed of zero to signed zero + class ZeroOpPattern : public mlir::OpRewritePattern { + public: + ZeroOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(SignedConvOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getInput().template getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + patterns.add(context); +} + +void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + getSignedConvCanonicalizationPatterns(patterns, context); +} + +void ToUnsignedOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + getSignedConvCanonicalizationPatterns(patterns, context); +} + } // namespace FHE } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 535303e47..8b60970eb 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1602,6 +1602,39 @@ void MatMulEintIntOp::getCanonicalizationPatterns( getMatMulCanonicalizationPatterns(patterns, context); } +template +void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + // Replace to_signed of zero to signed zero + class ZeroOpPattern : public mlir::OpRewritePattern { + public: + ZeroOpPattern(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context, 0) {} + + mlir::LogicalResult + matchAndRewrite(SignedConvOp op, + mlir::PatternRewriter &rewriter) const override { + auto cstOp = op.getInput().template getDefiningOp(); + if (cstOp == nullptr) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, + op.getResult().getType()); + return mlir::success(); + } + }; + patterns.add(context); +} + +void ToSignedOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + getSignedConvCanonicalizationPatterns(patterns, context); +} + +void ToUnsignedOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + getSignedConvCanonicalizationPatterns(patterns, context); +} + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir index c9643f16a..16aa1e897 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/folding.mlir @@ -54,3 +54,21 @@ func.func @round(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { %1 = "FHE.round"(%arg0) : (!FHE.eint<5>) -> !FHE.eint<5> return %1: !FHE.eint<5> } + +// CHECK: func.func @to_signed_zero() -> !FHE.esint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() +// CHECK-NEXT: return %[[v0]] +func.func @to_signed_zero() -> !FHE.esint<7> { + %0 = "FHE.zero"() : () -> !FHE.eint<7> + %1 = "FHE.to_signed"(%0) : (!FHE.eint<7>) -> !FHE.esint<7> + return %1 : !FHE.esint<7> +} + +// CHECK: func.func @to_unsigned_zero() -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() +// CHECK-NEXT: return %[[v0]] +func.func @to_unsigned_zero() -> !FHE.eint<7> { + %0 = "FHE.zero"() : () -> !FHE.esint<7> + %1 = "FHE.to_unsigned"(%0) : (!FHE.esint<7>) -> !FHE.eint<7> + return %1 : !FHE.eint<7> +} diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir index d5d62e800..db55f3fa7 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/folding.mlir @@ -142,3 +142,21 @@ func.func @matmul_int_eint_encrypted_zero(%x: tensor<4x3xi3>) -> tensor<4x2x!FHE %0 = "FHELinalg.matmul_int_eint"(%x, %y): (tensor<4x3xi3>, tensor<3x2x!FHE.eint<2>>) -> tensor<4x2x!FHE.eint<2>> return %0 : tensor<4x2x!FHE.eint<2>> } + +// CHECK: func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @to_signed_zero() -> tensor<4x!FHE.esint<7>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<7>> + %1 = "FHELinalg.to_signed"(%0) : (tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.esint<7>> + return %1 : tensor<4x!FHE.esint<7>> +} + +// CHECK: func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() +// CHECK-NEXT: return %[[v0]] +func.func @to_unsigned_zero() -> tensor<4x!FHE.eint<7>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<7>> + %1 = "FHELinalg.to_unsigned"(%0) : (tensor<4x!FHE.esint<7>>) -> tensor<4x!FHE.eint<7>> + return %1 : tensor<4x!FHE.eint<7>> +}